diff options
191 files changed, 14634 insertions, 8136 deletions
@@ -1,4 +1,10 @@ -== Core development team == -Matthew Wild (matthew.wild AT heavy-horse.co.uk) -Waqas Hussain (waqas20 AT gmail.com) +The Prosody project is open to contributions (see HACKERS file), but is +maintained daily by: + + - Matthew Wild (mail: matthew [at] prosody.im) + - Waqas Hussain (mail: waqas [at] prosody.im) + - Kim Alvefur (mail: zash [at] prosody.im) + +You can reach us collectively by email: developers [at] prosody.im +or in realtime in the Prosody chatroom: prosody@conference.prosody.im @@ -1,5 +1,5 @@ -Copyright (c) 2008-2009 Matthew Wild -Copyright (c) 2008-2009 Waqas Hussain +Copyright (c) 2008-2011 Matthew Wild +Copyright (c) 2008-2011 Waqas Hussain Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal @@ -11,10 +11,10 @@ furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. @@ -1,8 +1,11 @@ -The easiest way to install dependencies is using the luarocks tool. -Rocks: -luaexpat -luasocket +For full information on our dependencies, version requirements, and +where to find them, see http://prosody.im/doc/depends + +If you have luarocks available on your platform, install the following: + + - luaexpat + - luasocket + - luafilesystem + - luasec -Non-rocks: -LuaSec for SSL connections @@ -1,18 +1,21 @@ -(This file was created from -http://prosody.im/doc/installing_from_source on 2009-05-22) +(This file was created from +http://prosody.im/doc/installing_from_source on 2013-03-31) -===== Building ===== +====== Installing from source ====== ==== Dependencies ==== There are a couple of libraries which Prosody needs installed before you can build it. These are: + * lua5.1: The Lua 5.1 interpreter * liblua5.1: Lua 5.1 library * libssl 0.9.8: OpenSSL * libidn11: GNU libidn library, version 1.1 -Both of these can be installed on Debian/Ubuntu with the packages: +These can be installed on Debian/Ubuntu with the packages: lua5.1 liblua5.1-dev libidn11-dev libssl-dev +On Mandriva try: urpmi lua liblua-devel libidn-devel libopenssl-devel + On other systems... good luck, but please let me know of the best way of getting the dependencies for your system and I can add it here. @@ -30,7 +33,8 @@ accepts. You can load a preset using: ./configure --ostype=PRESET -Where PRESET can currently be one of: debian, macosx +Where PRESET can currently be one of: 'debian', 'macosx' or (in 0.8 +and later) 'freebsd' ==== make ==== Once you have run configure successfully, then you can simply run: @@ -13,6 +13,8 @@ INSTALLEDCONFIG = $(SYSCONFDIR) INSTALLEDMODULES = $(PREFIX)/lib/prosody/modules INSTALLEDDATA = $(DATADIR) +.PHONY: all clean install + all: prosody.install prosodyctl.install prosody.cfg.lua.install prosody.version $(MAKE) -C util-src install @@ -25,20 +27,18 @@ install: prosody.install prosodyctl.install prosody.cfg.lua.install util/encodin install -m755 ./prosody.install $(BIN)/prosody install -m755 ./prosodyctl.install $(BIN)/prosodyctl install -m644 core/* $(SOURCE)/core - install -m644 net/* $(SOURCE)/net + install -m644 net/*.lua $(SOURCE)/net + install -d $(SOURCE)/net/http + install -m644 net/http/*.lua $(SOURCE)/net/http install -m644 util/*.lua $(SOURCE)/util install -m644 util/*.so $(SOURCE)/util install -d $(SOURCE)/util/sasl install -m644 util/sasl/* $(SOURCE)/util/sasl - install -m644 plugins/*.lua $(MODULES) - install -d $(MODULES)/muc - install -m644 plugins/muc/* $(MODULES)/muc + umask 0022 && cp -r plugins/* $(MODULES) install -m644 certs/* $(CONFIG)/certs - install -d $(MODULES)/adhoc - install -m644 plugins/adhoc/*.lua $(MODULES)/adhoc install -m644 man/prosodyctl.man $(MAN)/man1/prosodyctl.1 test -e $(CONFIG)/prosody.cfg.lua || install -m644 prosody.cfg.lua.install $(CONFIG)/prosody.cfg.lua - test -e prosody.version && install prosody.version $(SOURCE)/prosody.version || true + test -e prosody.version && install -m644 prosody.version $(SOURCE)/prosody.version || true $(MAKE) install -C util-src clean: @@ -48,36 +48,21 @@ clean: rm -f prosody.version $(MAKE) clean -C util-src -util/encodings.so: - $(MAKE) install -C util-src - -util/hashes.so: +util/%.so: $(MAKE) install -C util-src -util/pposix.so: - $(MAKE) install -C util-src - -util/signal.so: - $(MAKE) install -C util-src - -prosody.install: prosody - sed "s|^CFG_SOURCEDIR=.*;$$|CFG_SOURCEDIR='$(INSTALLEDSOURCE)';|; \ +%.install: % + sed "1s/\blua\b/$(RUNWITH)/; \ + s|^CFG_SOURCEDIR=.*;$$|CFG_SOURCEDIR='$(INSTALLEDSOURCE)';|; \ s|^CFG_CONFIGDIR=.*;$$|CFG_CONFIGDIR='$(INSTALLEDCONFIG)';|; \ s|^CFG_DATADIR=.*;$$|CFG_DATADIR='$(INSTALLEDDATA)';|; \ - s|^CFG_PLUGINDIR=.*;$$|CFG_PLUGINDIR='$(INSTALLEDMODULES)/';|;" < prosody > prosody.install - -prosodyctl.install: prosodyctl - sed "s|^CFG_SOURCEDIR=.*;$$|CFG_SOURCEDIR='$(INSTALLEDSOURCE)';|; \ - s|^CFG_CONFIGDIR=.*;$$|CFG_CONFIGDIR='$(INSTALLEDCONFIG)';|; \ - s|^CFG_DATADIR=.*;$$|CFG_DATADIR='$(INSTALLEDDATA)';|; \ - s|^CFG_PLUGINDIR=.*;$$|CFG_PLUGINDIR='$(INSTALLEDMODULES)/';|;" < prosodyctl > prosodyctl.install - -prosody.cfg.lua.install: - sed 's|certs/|$(INSTALLEDCONFIG)/certs/|' prosody.cfg.lua.dist > prosody.cfg.lua.install + s|^CFG_PLUGINDIR=.*;$$|CFG_PLUGINDIR='$(INSTALLEDMODULES)/';|;" < $^ > $@ -prosody.release: - test -e .hg/dirstate && hexdump -n6 -e'6/1 "%02x"' .hg/dirstate \ - > prosody.version || true +prosody.cfg.lua.install: prosody.cfg.lua.dist + sed 's|certs/|$(INSTALLEDCONFIG)/certs/|' $^ > $@ -prosody.version: prosody.release - cp prosody.release prosody.version || true +prosody.version: $(wildcard prosody.release .hg/dirstate) + test -e .hg/dirstate && \ + hexdump -n6 -e'6/1 "%02x"' .hg/dirstate > $@ || true + test -f prosody.release && \ + cp prosody.release $@ || true @@ -1,16 +1,5 @@ -== 0.8 == -- Ad-hoc commands: - http://code.google.com/p/prosody-modules/wiki/mod_adhoc - http://code.google.com/p/prosody-modules/wiki/mod_adhoc_cmd_admin - http://code.google.com/p/prosody-modules/wiki/mod_adhoc_cmd_ping - http://code.google.com/p/prosody-modules/wiki/mod_adhoc_cmd_uptime - -- Pubsub -- Data storage backend abstraction - -== 0.9 == -- Clustering - == 1.0 == -- Web interface? +- Roster providers +- Statistics +- Clustering - World domination diff --git a/certs/Makefile b/certs/Makefile index c5e4294c..f3854c5f 100644 --- a/certs/Makefile +++ b/certs/Makefile @@ -1,4 +1,4 @@ -.DEFAULT: localhost.cert +.DEFAULT: localhost.crt keysize=2048 # How to: @@ -8,7 +8,7 @@ keysize=2048 # Then `make yourhost.key` to create your private key, you can # include keysize=number to change the size of the key. # Then you can either `make yourhost.csr` to generate a certificate -# signing request that you can submit to a CA, or `make yourhost.cert` +# signing request that you can submit to a CA, or `make yourhost.crt` # to generate a self signed certificate. .PRECIOUS: %.cnf %.key @@ -18,7 +18,7 @@ keysize=2048 openssl req -new -key $(lastword $^) -out $@ -utf8 -config $(firstword $^) # Self signed -%.cert: %.cnf %.key +%.crt: %.cnf %.key openssl req -new -x509 -nodes -key $(lastword $^) -days 365 \ -sha1 -out $@ -utf8 -config $(firstword $^) diff --git a/certs/localhost.cert b/certs/localhost.crt index 5156d307..5156d307 100644 --- a/certs/localhost.cert +++ b/certs/localhost.crt diff --git a/certs/openssl.cnf b/certs/openssl.cnf index 44fc0424..091409c4 100644 --- a/certs/openssl.cnf +++ b/certs/openssl.cnf @@ -2,7 +2,7 @@ oid_section = new_oids [ new_oids ] -# RFC 3920 section 5.1.1 defines this OID +# RFC 6120 section 13.7.1.4. defines this OID xmppAddr = 1.3.6.1.5.5.7.8.5 # RFC 4985 defines this OID @@ -40,13 +40,13 @@ subjectAltName = @subject_alternative_name [ subject_alternative_name ] -# See http://tools.ietf.org/html/draft-ietf-xmpp-3920bis#section-13.7.1.2 for more info. +# See http://tools.ietf.org/html/rfc6120#section-13.7.1.2 for more info. DNS.0 = example.com -otherName.0 = xmppAddr;UTF8:example.com +otherName.0 = xmppAddr;FORMAT:UTF8,UTF8:example.com otherName.1 = SRVName;IA5STRING:_xmpp-client.example.com otherName.2 = SRVName;IA5STRING:_xmpp-server.example.com DNS.1 = conference.example.com -otherName.3 = xmppAddr;UTF8:conference.example.com +otherName.3 = xmppAddr;FORMAT:UTF8,UTF8:conference.example.com otherName.4 = SRVName;IA5STRING:_xmpp-server.conference.example.com @@ -16,6 +16,7 @@ OPENSSL_LIB=crypto CC=gcc CXX=g++ LD=gcc +RUNWITH=lua CFLAGS="-fPIC -Wall" LDFLAGS="-shared" @@ -40,15 +41,17 @@ Configure Prosody prior to building. Default is "$LUA_SUFFIX" (lua$LUA_SUFFIX...) --with-lua=PREFIX Use Lua from given prefix. Default is $LUA_DIR +--runwith=BINARY What Lua binary to set as runtime environment. + Default is $RUNWITH --with-lua-include=DIR You can also specify Lua's includes dir. Default is \$LUA_DIR/include --with-lua-lib=DIR You can also specify Lua's libraries dir. Default is \$LUA_DIR/lib --with-idn=LIB The name of the IDN library to link with. Default is $IDN_LIB ---idn-library=(idn|icu) Select library to use for IDNA functionality. - idn: use GNU libidn (default) - icu: use ICU from IBM +--idn-library=(idn|icu) Select library to use for IDNA functionality. + idn: use GNU libidn (default) + icu: use ICU from IBM --with-ssl=LIB The name of the SSL to link with. Default is $OPENSSL_LIB --cflags=FLAGS Flags to pass to the compiler @@ -91,29 +94,31 @@ do --ostype=*) OSTYPE="$value" OSTYPE_SET=yes - if [ "$OSTYPE" = "debian" ] - then LUA_SUFFIX="5.1"; - LUA_SUFFIX_SET=yes - LUA_INCDIR=/usr/include/lua5.1; - LUA_INCDIR_SET=yes - fi - if [ "$OSTYPE" = "macosx" ] - then LUA_INCDIR=/usr/local/include; - LUA_INCDIR_SET=yes - LUA_LIBDIR=/usr/local/lib - LUA_LIBDIR_SET=yes - LDFLAGS="-bundle -undefined dynamic_lookup" - fi - if [ "$OSTYPE" = "linux" ] - then LUA_INCDIR=/usr/local/include; + if [ "$OSTYPE" = "debian" ]; then + LUA_SUFFIX="5.1"; + LUA_SUFFIX_SET=yes + RUNWITH="lua5.1" + LUA_INCDIR=/usr/include/lua5.1; + LUA_INCDIR_SET=yes + CFLAGS="$CFLAGS -D_GNU_SOURCE" + fi + if [ "$OSTYPE" = "macosx" ]; then + LUA_INCDIR=/usr/local/include; + LUA_INCDIR_SET=yes + LUA_LIBDIR=/usr/local/lib + LUA_LIBDIR_SET=yes + LDFLAGS="-bundle -undefined dynamic_lookup" + fi + if [ "$OSTYPE" = "linux" ]; then + LUA_INCDIR=/usr/local/include; LUA_INCDIR_SET=yes LUA_LIBDIR=/usr/local/lib LUA_LIBDIR_SET=yes - CFLAGS="-Wall -fPIC" + CFLAGS="-Wall -fPIC -D_GNU_SOURCE" LDFLAGS="-shared" - fi - if [ "$OSTYPE" = "freebsd" ] - then LUA_INCDIR="/usr/local/include/lua51" + fi + if [ "$OSTYPE" = "freebsd" -o "$OSTYPE" = "openbsd" ]; then + LUA_INCDIR="/usr/local/include/lua51" LUA_INCDIR_SET=yes CFLAGS="-Wall -fPIC -I/usr/local/include" LDFLAGS="-I/usr/local/include -L/usr/local/lib -shared" @@ -121,7 +126,10 @@ do LUA_SUFFIX_SET=yes LUA_DIR=/usr/local LUA_DIR_SET=yes - fi + fi + if [ "$OSTYPE" = "openbsd" ]; then + LUA_INCDIR="/usr/local/include"; + fi ;; --datadir=*) DATADIR="$value" @@ -166,6 +174,9 @@ do --linker=*) LD="$value" ;; + --runwith=*) + RUNWITH="$value" + ;; *) echo "Error: Unknown flag: $1" exit 1 @@ -274,7 +285,7 @@ then IDNA_LIBS="$ICU_FLAGS" CFLAGS="$CFLAGS -DUSE_STRINGPREP_ICU" fi -if [ "$IDN_LIBRARY" = "idn" ] +if [ "$IDN_LIBRARY" = "idn" ] then IDNA_LIBS="-l$IDN_LIB" fi @@ -336,6 +347,7 @@ LDFLAGS=$LDFLAGS CC=$CC CXX=$CXX LD=$LD +RUNWITH=$RUNWITH EOF diff --git a/core/certmanager.lua b/core/certmanager.lua index 0dc0bfd4..caa4afce 100644 --- a/core/certmanager.lua +++ b/core/certmanager.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -11,49 +11,109 @@ local log = require "util.logger".init("certmanager"); local ssl = ssl; local ssl_newcontext = ssl and ssl.newcontext; -local setmetatable, tostring = setmetatable, tostring; +local tostring = tostring; +local pairs = pairs; +local type = type; +local io_open = io.open; local prosody = prosody; local resolve_path = configmanager.resolve_relative_path; local config_path = prosody.paths.config; +local luasec_has_noticket, luasec_has_verifyext, luasec_has_no_compression; +if ssl then + local luasec_major, luasec_minor = ssl._VERSION:match("^(%d+)%.(%d+)"); + luasec_has_noticket = tonumber(luasec_major)>0 or tonumber(luasec_minor)>=4; + luasec_has_verifyext = tonumber(luasec_major)>0 or tonumber(luasec_minor)>=5; + luasec_has_no_compression = tonumber(luasec_major)>0 or tonumber(luasec_minor)>=5; +end + module "certmanager" -- Global SSL options if not overridden per-host -local default_ssl_config = configmanager.get("*", "core", "ssl"); -local default_capath = "/etc/ssl/certs"; -local default_verify = (ssl and ssl.x509 and { "peer", "client_once", "continue", "ignore_purpose" }) or "none"; -local default_options = { "no_sslv2" }; +local global_ssl_config = configmanager.get("*", "ssl"); + +local core_defaults = { + capath = "/etc/ssl/certs"; + protocol = "sslv23"; + verify = (ssl and ssl.x509 and { "peer", "client_once", }) or "none"; + options = { "no_sslv2", luasec_has_noticket and "no_ticket" or nil }; + verifyext = { "lsec_continue", "lsec_ignore_purpose" }; + curve = "secp384r1"; + ciphers = "HIGH:!DSS:!aNULL@STRENGTH"; +} +local path_options = { -- These we pass through resolve_path() + key = true, certificate = true, cafile = true, capath = true, dhparam = true +} + +if ssl and not luasec_has_verifyext and ssl.x509 then + -- COMPAT mw/luasec-hg + for i=1,#core_defaults.verifyext do -- Remove lsec_ prefix + core_defaults.verify[#core_defaults.verify+1] = core_defaults.verifyext[i]:sub(6); + end +end + +if luasec_has_no_compression and configmanager.get("*", "ssl_compression") ~= true then + core_defaults.options[#core_defaults.options+1] = "no_compression"; +end function create_context(host, mode, user_ssl_config) - user_ssl_config = user_ssl_config or default_ssl_config; + user_ssl_config = user_ssl_config or {} + user_ssl_config.mode = mode; if not ssl then return nil, "LuaSec (required for encryption) was not found"; end - if not user_ssl_config then return nil, "No SSL/TLS configuration present for "..host; end - - local ssl_config = { - mode = mode; - protocol = user_ssl_config.protocol or "sslv23"; - key = resolve_path(config_path, user_ssl_config.key); - password = user_ssl_config.password; - certificate = resolve_path(config_path, user_ssl_config.certificate); - capath = resolve_path(config_path, user_ssl_config.capath or default_capath); - cafile = resolve_path(config_path, user_ssl_config.cafile); - verify = user_ssl_config.verify or default_verify; - options = user_ssl_config.options or default_options; - ciphers = user_ssl_config.ciphers; - depth = user_ssl_config.depth; - }; - - local ctx, err = ssl_newcontext(ssl_config); + + if global_ssl_config then + for option,default_value in pairs(global_ssl_config) do + if not user_ssl_config[option] then + user_ssl_config[option] = default_value; + end + end + end + for option,default_value in pairs(core_defaults) do + if not user_ssl_config[option] then + user_ssl_config[option] = default_value; + end + end + user_ssl_config.password = user_ssl_config.password or function() log("error", "Encrypted certificate for %s requires 'ssl' 'password' to be set in config", host); end; + for option in pairs(path_options) do + if type(user_ssl_config[option]) == "string" then + user_ssl_config[option] = resolve_path(config_path, user_ssl_config[option]); + end + end + + if not user_ssl_config.key then return nil, "No key present in SSL/TLS configuration for "..host; end + if not user_ssl_config.certificate then return nil, "No certificate present in SSL/TLS configuration for "..host; end + + -- LuaSec expects dhparam to be a callback that takes two arguments. + -- We ignore those because it is mostly used for having a separate + -- set of params for EXPORT ciphers, which we don't have by default. + if type(user_ssl_config.dhparam) == "string" then + local f, err = io_open(user_ssl_config.dhparam); + if not f then return nil, "Could not open DH parameters: "..err end + local dhparam = f:read("*a"); + f:close(); + user_ssl_config.dhparam = function() return dhparam; end + end + + local ctx, err = ssl_newcontext(user_ssl_config); + + -- COMPAT Older LuaSec ignores the cipher list from the config, so we have to take care + -- of it ourselves (W/A for #x) + if ctx and user_ssl_config.ciphers then + local success; + success, err = ssl.context.setcipher(ctx, user_ssl_config.ciphers); + if not success then ctx = nil; end + end + if not ctx then err = err or "invalid ssl config" local file = err:match("^error loading (.-) %("); if file then if file == "private key" then - file = ssl_config.key or "your private key"; + file = user_ssl_config.key or "your private key"; elseif file == "certificate" then - file = ssl_config.certificate or "your certificate file"; + file = user_ssl_config.certificate or "your certificate file"; end local reason = err:match("%((.+)%)$") or "some reason"; if reason == "Permission denied" then @@ -67,16 +127,16 @@ function create_context(host, mode, user_ssl_config) else reason = "Reason: "..tostring(reason):lower(); end - log("error", "SSL/TLS: Failed to load %s: %s", file, reason); + log("error", "SSL/TLS: Failed to load '%s': %s (for %s)", file, reason, host); else - log("error", "SSL/TLS: Error initialising for host %s: %s", host, err ); + log("error", "SSL/TLS: Error initialising for %s: %s", host, err); end end return ctx, err; end function reload_ssl_config() - default_ssl_config = configmanager.get("*", "core", "ssl"); + global_ssl_config = configmanager.get("*", "ssl"); end prosody.events.add_handler("config-reloaded", reload_ssl_config); diff --git a/core/configmanager.lua b/core/configmanager.lua index b703bb8c..d92120d0 100644 --- a/core/configmanager.lua +++ b/core/configmanager.lua @@ -1,18 +1,19 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- local _G = _G; -local setmetatable, loadfile, pcall, rawget, rawset, io, error, dofile, type, pairs, table = - setmetatable, loadfile, pcall, rawget, rawset, io, error, dofile, type, pairs, table; +local setmetatable, rawget, rawset, io, error, dofile, type, pairs, table = + setmetatable, rawget, rawset, io, error, dofile, type, pairs, table; local format, math_max = string.format, math.max; local fire_event = prosody and prosody.events.fire_event or function () end; +local envload = require"util.envload".envload; local lfs = require "lfs"; local path_sep = package.config:sub(1,1); @@ -21,65 +22,62 @@ module "configmanager" local parsers = {}; local config_mt = { __index = function (t, k) return rawget(t, "*"); end}; -local config = setmetatable({ ["*"] = { core = {} } }, config_mt); +local config = setmetatable({ ["*"] = { } }, config_mt); -- When host not found, use global -local host_mt = { }; - --- When key not found in section, check key in global's section -function section_mt(section_name) - return { __index = function (t, k) - local section = rawget(config["*"], section_name); - if not section then return nil; end - return section[k]; - end - }; -end +local host_mt = { __index = function(_, k) return config["*"][k] end } function getconfig() return config; end -function get(host, section, key) - local sec = config[host][section]; - if sec then - return sec[key]; +function get(host, key, _oldkey) + if key == "core" then + key = _oldkey; -- COMPAT with code that still uses "core" + end + return config[host][key]; +end +function _M.rawget(host, key, _oldkey) + if key == "core" then + key = _oldkey; -- COMPAT with code that still uses "core" + end + local hostconfig = rawget(config, host); + if hostconfig then + return rawget(hostconfig, key); end - return nil; end -local function set(config, host, section, key, value) - if host and section and key then +local function set(config, host, key, value) + if host and key then local hostconfig = rawget(config, host); if not hostconfig then hostconfig = rawset(config, host, setmetatable({}, host_mt))[host]; end - if not rawget(hostconfig, section) then - hostconfig[section] = setmetatable({}, section_mt(section)); - end - hostconfig[section][key] = value; + hostconfig[key] = value; return true; end return false; end -function _M.set(host, section, key, value) - return set(config, host, section, key, value); +function _M.set(host, key, value, _oldvalue) + if key == "core" then + key, value = value, _oldvalue; --COMPAT with code that still uses "core" + end + return set(config, host, key, value); end -- Helper function to resolve relative paths (needed by config) do - local rel_path_start = ".."..path_sep; function resolve_relative_path(parent_path, path) if path then -- Some normalization parent_path = parent_path:gsub("%"..path_sep.."+$", ""); path = path:gsub("^%.%"..path_sep.."+", ""); - + local is_relative; if path_sep == "/" and path:sub(1,1) ~= "/" then is_relative = true; - elseif path_sep == "\\" and (path:sub(1,1) ~= "/" and path:sub(2,3) ~= ":\\") then + elseif path_sep == "\\" and (path:sub(1,1) ~= "/" and (path:sub(2,3) ~= ":\\" and path:sub(2,3) ~= ":/")) then is_relative = true; end if is_relative then @@ -87,7 +85,7 @@ do end end return path; - end + end end -- Helper function to convert a glob to a Lua pattern @@ -109,7 +107,7 @@ function load(filename, format) if parsers[format] and parsers[format].load then local f, err = io.open(filename); if f then - local new_config = setmetatable({ ["*"] = { core = {} } }, config_mt); + local new_config = setmetatable({ ["*"] = { } }, config_mt); local ok, err = parsers[format].load(f:read("*a"), filename, new_config); f:close(); if ok then @@ -152,8 +150,8 @@ end -- Built-in Lua parser do - local loadstring, pcall, setmetatable = _G.loadstring, _G.pcall, _G.setmetatable; - local setfenv, rawget, tostring = _G.setfenv, _G.rawget, _G.tostring; + local pcall, setmetatable = _G.pcall, _G.setmetatable; + local rawget = _G.rawget; parsers.lua = {}; function parsers.lua.load(data, config_file, config) local env; @@ -163,61 +161,58 @@ do Component = true, component = true, Include = true, include = true, RunScript = true }, { __index = function (t, k) - return rawget(_G, k) or - function (settings_table) - config[__currenthost or "*"][k] = settings_table; - end; + return rawget(_G, k); end, __newindex = function (t, k, v) - set(config, env.__currenthost or "*", "core", k, v); + set(config, env.__currenthost or "*", k, v); end }); - + rawset(env, "__currenthost", "*") -- Default is global function env.VirtualHost(name) - if rawget(config, name) and rawget(config[name].core, "component_module") then + if rawget(config, name) and rawget(config[name], "component_module") then error(format("Host %q clashes with previously defined %s Component %q, for services use a sub-domain like conference.%s", - name, config[name].core.component_module:gsub("^%a+$", { component = "external", muc = "MUC"}), name, name), 0); + name, config[name].component_module:gsub("^%a+$", { component = "external", muc = "MUC"}), name, name), 0); end rawset(env, "__currenthost", name); -- Needs at least one setting to logically exist :) - set(config, name or "*", "core", "defined", true); + set(config, name or "*", "defined", true); return function (config_options) rawset(env, "__currenthost", "*"); -- Return to global scope for option_name, option_value in pairs(config_options) do - set(config, name or "*", "core", option_name, option_value); + set(config, name or "*", option_name, option_value); end end; end env.Host, env.host = env.VirtualHost, env.VirtualHost; - + function env.Component(name) - if rawget(config, name) and rawget(config[name].core, "defined") and not rawget(config[name].core, "component_module") then + if rawget(config, name) and rawget(config[name], "defined") and not rawget(config[name], "component_module") then error(format("Component %q clashes with previously defined Host %q, for services use a sub-domain like conference.%s", name, name, name), 0); end - set(config, name, "core", "component_module", "component"); + set(config, name, "component_module", "component"); -- Don't load the global modules by default - set(config, name, "core", "load_global_modules", false); + set(config, name, "load_global_modules", false); rawset(env, "__currenthost", name); local function handle_config_options(config_options) rawset(env, "__currenthost", "*"); -- Return to global scope for option_name, option_value in pairs(config_options) do - set(config, name or "*", "core", option_name, option_value); + set(config, name or "*", option_name, option_value); end end - + return function (module) if type(module) == "string" then - set(config, name, "core", "component_module", module); + set(config, name, "component_module", module); return handle_config_options; end return handle_config_options(module); end end env.component = env.Component; - - function env.Include(file, wildcard) + + function env.Include(file) if file:match("[*?]") then local path_pos, glob = file:match("()([^"..path_sep.."]+)$"); local path = file:sub(1, math_max(path_pos-2,0)); @@ -234,11 +229,10 @@ do end end else + local file = resolve_relative_path(config_file:gsub("[^"..path_sep.."]+$", ""), file); local f, err = io.open(file); if f then - local data = f:read("*a"); - local file = resolve_relative_path(config_file:gsub("[^"..path_sep.."]+$", ""), file); - local ret, err = parsers.lua.load(data, file, config); + local ret, err = parsers.lua.load(f:read("*a"), file, config); if not ret then error(err:gsub("%[string.-%]", file), 0); end end if not f then error("Error loading included "..file..": "..err, 0); end @@ -246,28 +240,26 @@ do end end env.include = env.Include; - + function env.RunScript(file) return dofile(resolve_relative_path(config_file:gsub("[^"..path_sep.."]+$", ""), file)); end - - local chunk, err = loadstring(data, "@"..config_file); - + + local chunk, err = envload(data, "@"..config_file, env); + if not chunk then return nil, err; end - - setfenv(chunk, env); - + local ok, err = pcall(chunk); - + if not ok then return nil, err; end - + return true; end - + end return _M; diff --git a/core/hostmanager.lua b/core/hostmanager.lua index 9e74cd6b..91b052d1 100644 --- a/core/hostmanager.lua +++ b/core/hostmanager.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -12,18 +12,20 @@ local events_new = require "util.events".new; local disco_items = require "util.multitable".new(); local NULL = {}; +local jid_split = require "util.jid".split; local uuid_gen = require "util.uuid".generate; local log = require "util.logger".init("hostmanager"); -local hosts = hosts; +local hosts = prosody.hosts; local prosody_events = prosody.events; if not _G.prosody.incoming_s2s then require "core.s2smanager"; end local incoming_s2s = _G.prosody.incoming_s2s; +local core_route_stanza = _G.prosody.core_route_stanza; -local pairs, setmetatable = pairs, setmetatable; +local pairs, select, rawget = pairs, select, rawget; local tostring, type = tostring, type; module "hostmanager" @@ -33,38 +35,50 @@ local hosts_loaded_once; local function load_enabled_hosts(config) local defined_hosts = config or configmanager.getconfig(); local activated_any_host; - + for host, host_config in pairs(defined_hosts) do - if host ~= "*" and host_config.core.enabled ~= false then - if not host_config.core.component_module then + if host ~= "*" and host_config.enabled ~= false then + if not host_config.component_module then activated_any_host = true; end activate(host, host_config); end end - + if not activated_any_host then log("error", "No active VirtualHost entries in the config file. This may cause unexpected behaviour as no modules will be loaded."); end - + prosody_events.fire_event("hosts-activated", defined_hosts); hosts_loaded_once = true; end prosody_events.add_handler("server-starting", load_enabled_hosts); +local function host_send(stanza) + local name, type = stanza.name, stanza.attr.type; + if type == "error" or (name == "iq" and type == "result") then + local dest_host_name = select(2, jid_split(stanza.attr.to)); + local dest_host = hosts[dest_host_name] or { type = "unknown" }; + log("warn", "Unhandled response sent to %s host %s: %s", dest_host.type, dest_host_name, tostring(stanza)); + return; + end + core_route_stanza(nil, stanza); +end + function activate(host, host_config) - if hosts[host] then return nil, "The host "..host.." is already activated"; end + if rawget(hosts, host) then return nil, "The host "..host.." is already activated"; end host_config = host_config or configmanager.getconfig()[host]; if not host_config then return nil, "Couldn't find the host "..tostring(host).." defined in the current config"; end local host_session = { host = host; s2sout = {}; events = events_new(); - dialback_secret = configmanager.get(host, "core", "dialback_secret") or uuid_gen(); - disallow_s2s = configmanager.get(host, "core", "disallow_s2s"); + dialback_secret = configmanager.get(host, "dialback_secret") or uuid_gen(); + send = host_send; + modules = {}; }; - if not host_config.core.component_module then -- host + if not host_config.component_module then -- host host_session.type = "local"; host_session.sessions = {}; else -- component @@ -72,16 +86,16 @@ function activate(host, host_config) end hosts[host] = host_session; if not host:match("[@/]") then - disco_items:set(host:match("%.(.*)") or "*", host, true); + disco_items:set(host:match("%.(.*)") or "*", host, host_config.name or true); end - for option_name in pairs(host_config.core) do + for option_name in pairs(host_config) do if option_name:match("_ports$") or option_name:match("_interface$") then log("warn", "%s: Option '%s' has no effect for virtual hosts - put it in the server-wide section instead", host, option_name); end end - + log((hosts_loaded_once and "info") or "debug", "Activated host: %s", host); - prosody_events.fire_event("host-activated", host, host_config); + prosody_events.fire_event("host-activated", host); return true; end @@ -89,13 +103,14 @@ function deactivate(host, reason) local host_session = hosts[host]; if not host_session then return nil, "The host "..tostring(host).." is not activated"; end log("info", "Deactivating host: %s", host); - prosody_events.fire_event("host-deactivating", host, host_session); - + prosody_events.fire_event("host-deactivating", { host = host, host_session = host_session, reason = reason }); + if type(reason) ~= "table" then reason = { condition = "host-gone", text = tostring(reason or "This server has stopped serving "..host) }; end - + -- Disconnect local users, s2s connections + -- TODO: These should move to mod_c2s and mod_s2s (how do they know they're being unloaded and not reloaded?) if host_session.sessions then for username, user in pairs(host_session.sessions) do for resource, session in pairs(user.sessions) do @@ -120,6 +135,7 @@ function deactivate(host, reason) end end + -- TODO: This should be done in modulemanager if host_session.modules then for module in pairs(host_session.modules) do modulemanager.unload(host, module); diff --git a/core/loggingmanager.lua b/core/loggingmanager.lua index 40b96d52..c6361146 100644 --- a/core/loggingmanager.lua +++ b/core/loggingmanager.lua @@ -1,20 +1,18 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- -local format, rep = string.format, string.rep; -local pcall = pcall; -local debug = debug; -local tostring, setmetatable, rawset, pairs, ipairs, type = - tostring, setmetatable, rawset, pairs, ipairs, type; +local format = string.format; +local setmetatable, rawset, pairs, ipairs, type = + setmetatable, rawset, pairs, ipairs, type; local io_open, io_write = io.open, io.write; local math_max, rep = math.max, string.rep; -local os_date, os_getenv = os.date, os.getenv; +local os_date = os.date; local getstyle, setstyle = require "util.termcolours".getstyle, require "util.termcolours".setstyle; if os.getenv("__FLUSH_LOG") then @@ -27,8 +25,6 @@ local config = require "core.configmanager"; local logger = require "util.logger"; local prosody = prosody; -local debug_mode = config.get("*", "core", "debug"); - _G.log = logger.init("general"); module "loggingmanager" @@ -43,41 +39,19 @@ local logging_config; local apply_sink_rules; local log_sink_types = setmetatable({}, { __newindex = function (t, k, v) rawset(t, k, v); apply_sink_rules(k); end; }); local get_levels; -local logging_levels = { "debug", "info", "warn", "error", "critical" } +local logging_levels = { "debug", "info", "warn", "error" } -- Put a rule into action. Requires that the sink type has already been registered. -- This function is called automatically when a new sink type is added [see apply_sink_rules()] local function add_rule(sink_config) local sink_maker = log_sink_types[sink_config.to]; if sink_maker then - if sink_config.levels and not sink_config.source then - -- Create sink - local sink = sink_maker(sink_config); - - -- Set sink for all chosen levels - for level in pairs(get_levels(sink_config.levels)) do - logger.add_level_sink(level, sink); - end - elseif sink_config.source and not sink_config.levels then - logger.add_name_sink(sink_config.source, sink_maker(sink_config)); - elseif sink_config.source and sink_config.levels then - local levels = get_levels(sink_config.levels); - local sink = sink_maker(sink_config); - logger.add_name_sink(sink_config.source, - function (name, level, ...) - if levels[level] then - return sink(name, level, ...); - end - end); - else - -- All sources - -- Create sink - local sink = sink_maker(sink_config); - - -- Set sink for all levels - for _, level in pairs(logging_levels) do - logger.add_level_sink(level, sink); - end + -- Create sink + local sink = sink_maker(sink_config); + + -- Set sink for all chosen levels + for level in pairs(get_levels(sink_config.levels or logging_levels)) do + logger.add_level_sink(level, sink); end else -- No such sink type @@ -89,21 +63,27 @@ end -- the log_sink_types table. function apply_sink_rules(sink_type) if type(logging_config) == "table" then - - if sink_type == "file" then - for _, level in ipairs(logging_levels) do - if type(logging_config[level]) == "string" then + + for _, level in ipairs(logging_levels) do + if type(logging_config[level]) == "string" then + local value = logging_config[level]; + if sink_type == "file" and not value:match("^%*") then + add_rule({ + to = sink_type; + filename = value; + timestamps = true; + levels = { min = level }; + }); + elseif value == "*"..sink_type then add_rule({ - to = "file", - filename = logging_config[level], - timestamps = true, - levels = { min = level }, + to = sink_type; + levels = { min = level }; }); end end end - - for _, sink_config in pairs(logging_config) do + + for _, sink_config in ipairs(logging_config) do if (type(sink_config) == "table" and sink_config.to == sink_type) then add_rule(sink_config); elseif (type(sink_config) == "string" and sink_config:match("^%*(.+)") == sink_type) then @@ -148,7 +128,7 @@ function get_levels(criteria, set) end end end - + for _, level in ipairs(criteria) do set[level] = true; end @@ -158,27 +138,29 @@ end -- Initialize config, etc. -- function reload_logging() local old_sink_types = {}; - + for name, sink_maker in pairs(log_sink_types) do old_sink_types[name] = sink_maker; log_sink_types[name] = nil; end - + logger.reset(); + local debug_mode = config.get("*", "debug"); + default_logging = { { to = "console" , levels = { min = (debug_mode and "debug") or "info" } } }; default_file_logging = { { to = "file", levels = { min = (debug_mode and "debug") or "info" }, timestamps = true } }; default_timestamp = "%b %d %H:%M:%S"; - logging_config = config.get("*", "core", "log") or default_logging; - - + logging_config = config.get("*", "log") or default_logging; + + for name, sink_maker in pairs(old_sink_types) do log_sink_types[name] = sink_maker; end - + prosody.events.fire_event("logging-reloaded"); end @@ -195,13 +177,13 @@ end -- Column width for "source" (used by stdout and console) local sourcewidth = 20; -function log_sink_types.stdout() +function log_sink_types.stdout(config) local timestamps = config.timestamps; - + if timestamps == true then timestamps = default_timestamp; -- Default format end - + return function (name, level, message, ...) sourcewidth = math_max(#name+2, sourcewidth); local namelen = #name; @@ -218,7 +200,7 @@ end do local do_pretty_printing = true; - + local logstyles = {}; if do_pretty_printing then logstyles["info"] = getstyle("bold"); @@ -230,7 +212,7 @@ do if not do_pretty_printing then return log_sink_types.stdout(config); end - + local timestamps = config.timestamps; if timestamps == true then @@ -240,7 +222,7 @@ do return function (name, level, message, ...) sourcewidth = math_max(#name+2, sourcewidth); local namelen = #name; - + if timestamps then io_write(os_date(timestamps), " "); end @@ -266,12 +248,6 @@ function log_sink_types.file(config) end local write, flush = logfile.write, logfile.flush; - prosody.events.add_handler("logging-reloading", function () - if logfile then - logfile:close(); - end - end); - local timestamps = config.timestamps; if timestamps == nil or timestamps == true then diff --git a/core/moduleapi.lua b/core/moduleapi.lua new file mode 100644 index 00000000..65e00d41 --- /dev/null +++ b/core/moduleapi.lua @@ -0,0 +1,371 @@ +-- Prosody IM +-- Copyright (C) 2008-2012 Matthew Wild +-- Copyright (C) 2008-2012 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local config = require "core.configmanager"; +local modulemanager = require "modulemanager"; -- This is necessary to avoid require loops +local array = require "util.array"; +local set = require "util.set"; +local logger = require "util.logger"; +local pluginloader = require "util.pluginloader"; +local timer = require "util.timer"; + +local t_insert, t_remove, t_concat = table.insert, table.remove, table.concat; +local error, setmetatable, type = error, setmetatable, type; +local ipairs, pairs, select, unpack = ipairs, pairs, select, unpack; +local tonumber, tostring = tonumber, tostring; + +local prosody = prosody; +local hosts = prosody.hosts; + +-- FIXME: This assert() is to try and catch an obscure bug (2013-04-05) +local core_post_stanza = assert(prosody.core_post_stanza, + "prosody.core_post_stanza is nil, please report this as a bug"); + +-- Registry of shared module data +local shared_data = setmetatable({}, { __mode = "v" }); + +local NULL = {}; + +local api = {}; + +-- Returns the name of the current module +function api:get_name() + return self.name; +end + +-- Returns the host that the current module is serving +function api:get_host() + return self.host; +end + +function api:get_host_type() + return (self.host == "*" and "global") or hosts[self.host].type or "local"; +end + +function api:set_global() + self.host = "*"; + -- Update the logger + local _log = logger.init("mod_"..self.name); + self.log = function (self, ...) return _log(...); end; + self._log = _log; + self.global = true; +end + +function api:add_feature(xmlns) + self:add_item("feature", xmlns); +end +function api:add_identity(category, type, name) + self:add_item("identity", {category = category, type = type, name = name}); +end +function api:add_extension(data) + self:add_item("extension", data); +end +function api:has_feature(xmlns) + for _, feature in ipairs(self:get_host_items("feature")) do + if feature == xmlns then return true; end + end + return false; +end +function api:has_identity(category, type, name) + for _, id in ipairs(self:get_host_items("identity")) do + if id.category == category and id.type == type and id.name == name then + return true; + end + end + return false; +end + +function api:fire_event(...) + return (hosts[self.host] or prosody).events.fire_event(...); +end + +function api:hook_object_event(object, event, handler, priority) + self.event_handlers:set(object, event, handler, true); + return object.add_handler(event, handler, priority); +end + +function api:unhook_object_event(object, event, handler) + return object.remove_handler(event, handler); +end + +function api:hook(event, handler, priority) + return self:hook_object_event((hosts[self.host] or prosody).events, event, handler, priority); +end + +function api:hook_global(event, handler, priority) + return self:hook_object_event(prosody.events, event, handler, priority); +end + +function api:hook_tag(xmlns, name, handler, priority) + if not handler and type(name) == "function" then + -- If only 2 options then they specified no xmlns + xmlns, name, handler, priority = nil, xmlns, name, handler; + elseif not (handler and name) then + self:log("warn", "Error: Insufficient parameters to module:hook_stanza()"); + return; + end + return self:hook("stanza/"..(xmlns and (xmlns..":") or "")..name, function (data) return handler(data.origin, data.stanza, data); end, priority); +end +api.hook_stanza = api.hook_tag; -- COMPAT w/pre-0.9 + +function api:unhook(event, handler) + return self:unhook_object_event((hosts[self.host] or prosody).events, event, handler); +end + +function api:require(lib) + local f, n = pluginloader.load_code(self.name, lib..".lib.lua", self.environment); + if not f then + f, n = pluginloader.load_code(lib, lib..".lib.lua", self.environment); + end + if not f then error("Failed to load plugin library '"..lib.."', error: "..n); end -- FIXME better error message + return f(); +end + +function api:depends(name) + if not self.dependencies then + self.dependencies = {}; + self:hook("module-reloaded", function (event) + if self.dependencies[event.module] and not self.reloading then + self:log("info", "Auto-reloading due to reload of %s:%s", event.host, event.module); + modulemanager.reload(self.host, self.name); + return; + end + end); + self:hook("module-unloaded", function (event) + if self.dependencies[event.module] then + self:log("info", "Auto-unloading due to unload of %s:%s", event.host, event.module); + modulemanager.unload(self.host, self.name); + end + end); + end + local mod = modulemanager.get_module(self.host, name) or modulemanager.get_module("*", name); + if mod and mod.module.host == "*" and self.host ~= "*" + and modulemanager.module_has_method(mod, "add_host") then + mod = nil; -- Target is a shared module, so we still want to load it on our host + end + if not mod then + local err; + mod, err = modulemanager.load(self.host, name); + if not mod then + return error(("Unable to load required module, mod_%s: %s"):format(name, ((err or "unknown error"):gsub("%-", " ")) )); + end + end + self.dependencies[name] = true; + return mod; +end + +-- Returns one or more shared tables at the specified virtual paths +-- Intentionally does not allow the table at a path to be _set_, it +-- is auto-created if it does not exist. +function api:shared(...) + if not self.shared_data then self.shared_data = {}; end + local paths = { n = select("#", ...), ... }; + local data_array = {}; + local default_path_components = { self.host, self.name }; + for i = 1, paths.n do + local path = paths[i]; + if path:sub(1,1) ~= "/" then -- Prepend default components + local n_components = select(2, path:gsub("/", "%1")); + path = (n_components<#default_path_components and "/" or "")..t_concat(default_path_components, "/", 1, #default_path_components-n_components).."/"..path; + end + local shared = shared_data[path]; + if not shared then + shared = {}; + if path:match("%-cache$") then + setmetatable(shared, { __mode = "kv" }); + end + shared_data[path] = shared; + end + t_insert(data_array, shared); + self.shared_data[path] = shared; + end + return unpack(data_array); +end + +function api:get_option(name, default_value) + local value = config.get(self.host, name); + if value == nil then + value = default_value; + end + return value; +end + +function api:get_option_string(name, default_value) + local value = self:get_option(name, default_value); + if type(value) == "table" then + if #value > 1 then + self:log("error", "Config option '%s' does not take a list, using just the first item", name); + end + value = value[1]; + end + if value == nil then + return nil; + end + return tostring(value); +end + +function api:get_option_number(name, ...) + local value = self:get_option(name, ...); + if type(value) == "table" then + if #value > 1 then + self:log("error", "Config option '%s' does not take a list, using just the first item", name); + end + value = value[1]; + end + local ret = tonumber(value); + if value ~= nil and ret == nil then + self:log("error", "Config option '%s' not understood, expecting a number", name); + end + return ret; +end + +function api:get_option_boolean(name, ...) + local value = self:get_option(name, ...); + if type(value) == "table" then + if #value > 1 then + self:log("error", "Config option '%s' does not take a list, using just the first item", name); + end + value = value[1]; + end + if value == nil then + return nil; + end + local ret = value == true or value == "true" or value == 1 or nil; + if ret == nil then + ret = (value == false or value == "false" or value == 0); + if ret then + ret = false; + else + ret = nil; + end + end + if ret == nil then + self:log("error", "Config option '%s' not understood, expecting true/false", name); + end + return ret; +end + +function api:get_option_array(name, ...) + local value = self:get_option(name, ...); + + if value == nil then + return nil; + end + + if type(value) ~= "table" then + return array{ value }; -- Assume any non-list is a single-item list + end + + return array():append(value); -- Clone +end + +function api:get_option_set(name, ...) + local value = self:get_option_array(name, ...); + + if value == nil then + return nil; + end + + return set.new(value); +end + +function api:get_option_inherited_set(name, ...) + local value = self:get_option_set(name, ...); + local global_value = self:context("*"):get_option_set(name, ...); + if not value then + return global_value; + elseif not global_value then + return value; + end + value:include(global_value); + return value; +end + +function api:context(host) + return setmetatable({host=host or "*"}, {__index=self,__newindex=self}); +end + +function api:add_item(key, value) + self.items = self.items or {}; + self.items[key] = self.items[key] or {}; + t_insert(self.items[key], value); + self:fire_event("item-added/"..key, {source = self, item = value}); +end +function api:remove_item(key, value) + local t = self.items and self.items[key] or NULL; + for i = #t,1,-1 do + if t[i] == value then + t_remove(self.items[key], i); + self:fire_event("item-removed/"..key, {source = self, item = value}); + return value; + end + end +end + +function api:get_host_items(key) + local result = modulemanager.get_items(key, self.host) or {}; + return result; +end + +function api:handle_items(type, added_cb, removed_cb, existing) + self:hook("item-added/"..type, added_cb); + self:hook("item-removed/"..type, removed_cb); + if existing ~= false then + for _, item in ipairs(self:get_host_items(type)) do + added_cb({ item = item }); + end + end +end + +function api:provides(name, item) + -- if not item then item = setmetatable({}, { __index = function(t,k) return rawget(self.environment, k); end }); end + if not item then + item = {} + for k,v in pairs(self.environment) do + if k ~= "module" then item[k] = v; end + end + end + if not item.name then + local item_name = self.name; + -- Strip a provider prefix to find the item name + -- (e.g. "auth_foo" -> "foo" for an auth provider) + if item_name:find(name.."_", 1, true) == 1 then + item_name = item_name:sub(#name+2); + end + item.name = item_name; + end + item._provided_by = self.name; + self:add_item(name.."-provider", item); +end + +function api:send(stanza) + return core_post_stanza(hosts[self.host], stanza); +end + +function api:add_timer(delay, callback) + return timer.add_task(delay, function (t) + if self.loaded == false then return; end + return callback(t); + end); +end + +local path_sep = package.config:sub(1,1); +function api:get_directory() + return self.path and (self.path:gsub("%"..path_sep.."[^"..path_sep.."]*$", "")) or nil; +end + +function api:load_resource(path, mode) + path = config.resolve_relative_path(self:get_directory(), path); + return io.open(path, mode); +end + +function api:open_store(name, type) + return storagemanager.open(self.host, name or self.name, type); +end + +return api; diff --git a/core/modulemanager.lua b/core/modulemanager.lua index 0af9a271..2ad2fc17 100644 --- a/core/modulemanager.lua +++ b/core/modulemanager.lua @@ -1,33 +1,25 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- -local plugin_dir = CFG_PLUGINDIR or "./plugins/"; - local logger = require "util.logger"; local log = logger.init("modulemanager"); local config = require "core.configmanager"; -local multitable_new = require "util.multitable".new; -local st = require "util.stanza"; local pluginloader = require "util.pluginloader"; +local set = require "util.set"; + +local new_multitable = require "util.multitable".new; local hosts = hosts; local prosody = prosody; -local prosody_events = prosody.events; - -local loadfile, pcall, xpcall = loadfile, pcall, xpcall; -local setmetatable, setfenv, getfenv = setmetatable, setfenv, getfenv; -local pairs, ipairs = pairs, ipairs; -local t_insert, t_concat = table.insert, table.concat; -local type = type; -local next = next; -local rawget = rawget; -local error = error; -local tostring, tonumber = tostring, tonumber; + +local pcall, xpcall = pcall, xpcall; +local setmetatable, rawget = setmetatable, rawget; +local ipairs, pairs, type, tostring, t_insert = ipairs, pairs, type, tostring, table.insert; local debug_traceback = debug.traceback; local unpack, select = unpack, select; @@ -37,53 +29,44 @@ pcall = function(f, ...) return xpcall(function() return f(unpack(params, 1, n)) end, function(e) return tostring(e).."\n"..debug_traceback(); end); end -local array, set = require "util.array", require "util.set"; - -local autoload_modules = {"presence", "message", "iq", "offline"}; -local component_inheritable_modules = {"tls", "dialback", "iq"}; +local autoload_modules = {"presence", "message", "iq", "offline", "c2s", "s2s"}; +local component_inheritable_modules = {"tls", "dialback", "iq", "s2s"}; -- We need this to let modules access the real global namespace local _G = _G; module "modulemanager" -api = {}; -local api = api; -- Module API container +local api = _G.require "core.moduleapi"; -- Module API container +-- [host] = { [module] = module_env } local modulemap = { ["*"] = {} }; -local modulehelpers = setmetatable({}, { __index = _G }); - -local hooks = multitable_new(); - -local NULL = {}; - -- Load modules when a host is activated function load_modules_for_host(host) - local component = config.get(host, "core", "component_module"); - - local global_modules_enabled = config.get("*", "core", "modules_enabled"); - local global_modules_disabled = config.get("*", "core", "modules_disabled"); - local host_modules_enabled = config.get(host, "core", "modules_enabled"); - local host_modules_disabled = config.get(host, "core", "modules_disabled"); - + local component = config.get(host, "component_module"); + + local global_modules_enabled = config.get("*", "modules_enabled"); + local global_modules_disabled = config.get("*", "modules_disabled"); + local host_modules_enabled = config.get(host, "modules_enabled"); + local host_modules_disabled = config.get(host, "modules_disabled"); + if host_modules_enabled == global_modules_enabled then host_modules_enabled = nil; end if host_modules_disabled == global_modules_disabled then host_modules_disabled = nil; end - - local host_modules = set.new(host_modules_enabled) - set.new(host_modules_disabled); + local global_modules = set.new(autoload_modules) + set.new(global_modules_enabled) - set.new(global_modules_disabled); if component then global_modules = set.intersection(set.new(component_inheritable_modules), global_modules); end - local modules = global_modules + host_modules; - + local modules = (global_modules + set.new(host_modules_enabled)) - set.new(host_modules_disabled); + -- COMPAT w/ pre 0.8 if modules:contains("console") then log("error", "The mod_console plugin has been renamed to mod_admin_telnet. Please update your config."); modules:remove("console"); modules:add("admin_telnet"); end - + if component then load(host, component); end @@ -91,111 +74,134 @@ function load_modules_for_host(host) load(host, module); end end -prosody_events.add_handler("host-activated", load_modules_for_host); --- +prosody.events.add_handler("host-activated", load_modules_for_host); +prosody.events.add_handler("host-deactivated", function (host) + modulemap[host] = nil; +end); + +--- Private helpers --- + +local function do_unload_module(host, name) + local mod = get_module(host, name); + if not mod then return nil, "module-not-loaded"; end + + if module_has_method(mod, "unload") then + local ok, err = call_module_method(mod, "unload"); + if (not ok) and err then + log("warn", "Non-fatal error unloading module '%s' on '%s': %s", name, host, err); + end + end + + for object, event, handler in mod.module.event_handlers:iter(nil, nil, nil) do + object.remove_handler(event, handler); + end + + if mod.module.items then -- remove items + local events = (host == "*" and prosody.events) or hosts[host].events; + for key,t in pairs(mod.module.items) do + for i = #t,1,-1 do + local value = t[i]; + t[i] = nil; + events.fire_event("item-removed/"..key, {source = mod.module, item = value}); + end + end + end + mod.module.loaded = false; + modulemap[host][name] = nil; + return true; +end -function load(host, module_name, config) +local function do_load_module(host, module_name, state) if not (host and module_name) then return nil, "insufficient-parameters"; - elseif not hosts[host] then + elseif not hosts[host] and host ~= "*"then return nil, "unknown-host"; end - + if not modulemap[host] then - modulemap[host] = {}; + modulemap[host] = hosts[host].modules; end - + if modulemap[host][module_name] then log("warn", "%s is already loaded for %s, so not loading again", module_name, host); return nil, "module-already-loaded"; elseif modulemap["*"][module_name] then + local mod = modulemap["*"][module_name]; + if module_has_method(mod, "add_host") then + local _log = logger.init(host..":"..module_name); + local host_module_api = setmetatable({ + host = host, event_handlers = new_multitable(), items = {}; + _log = _log, log = function (self, ...) return _log(...); end; + },{ + __index = modulemap["*"][module_name].module; + }); + local host_module = setmetatable({ module = host_module_api }, { __index = mod }); + host_module_api.environment = host_module; + modulemap[host][module_name] = host_module; + local ok, result, module_err = call_module_method(mod, "add_host", host_module_api); + if not ok or result == false then + modulemap[host][module_name] = nil; + return nil, ok and module_err or result; + end + return host_module; + end return nil, "global-module-already-loaded"; end - - local mod, err = pluginloader.load_code(module_name); + + + local _log = logger.init(host..":"..module_name); + local api_instance = setmetatable({ name = module_name, host = host, + _log = _log, log = function (self, ...) return _log(...); end, event_handlers = new_multitable(), + reloading = not not state, saved_state = state~=true and state or nil } + , { __index = api }); + + local pluginenv = setmetatable({ module = api_instance }, { __index = _G }); + api_instance.environment = pluginenv; + + local mod, err = pluginloader.load_code(module_name, nil, pluginenv); if not mod then log("error", "Unable to load module '%s': %s", module_name or "nil", err or "nil"); return nil, err; end - local _log = logger.init(host..":"..module_name); - local api_instance = setmetatable({ name = module_name, host = host, config = config, _log = _log, log = function (self, ...) return _log(...); end }, { __index = api }); + api_instance.path = err; - local pluginenv = setmetatable({ module = api_instance }, { __index = _G }); - api_instance.environment = pluginenv; - - setfenv(mod, pluginenv); - hosts[host].modules = modulemap[host]; modulemap[host][module_name] = pluginenv; - - local success, err = pcall(mod); - if success then + local ok, err = pcall(mod); + if ok then + -- Call module's "load" if module_has_method(pluginenv, "load") then - success, err = call_module_method(pluginenv, "load"); - if not success then + ok, err = call_module_method(pluginenv, "load"); + if not ok then log("warn", "Error loading module '%s' on '%s': %s", module_name, host, err or "nil"); end end - - -- Use modified host, if the module set one - if api_instance.host == "*" and host ~= "*" then + api_instance.reloading, api_instance.saved_state = nil, nil; + + if api_instance.host == "*" then + if not api_instance.global then -- COMPAT w/pre-0.9 + if host ~= "*" then + log("warn", "mod_%s: Setting module.host = '*' deprecated, call module:set_global() instead", module_name); + end + api_instance:set_global(); + end modulemap[host][module_name] = nil; - modulemap["*"][module_name] = pluginenv; - api_instance:set_global(); - end - else - log("error", "Error initializing module '%s' on '%s': %s", module_name, host, err or "nil"); - end - if success then - (hosts[api_instance.host] or prosody).events.fire_event("module-loaded", { module = module_name, host = host }); - return true; - else -- load failed, unloading - unload(api_instance.host, module_name); - return nil, err; - end -end - -function get_module(host, name) - return modulemap[host] and modulemap[host][name]; -end - -function is_loaded(host, name) - return modulemap[host] and modulemap[host][name] and true; -end - -function unload(host, name, ...) - local mod = get_module(host, name); - if not mod then return nil, "module-not-loaded"; end - - if module_has_method(mod, "unload") then - local ok, err = call_module_method(mod, "unload"); - if (not ok) and err then - log("warn", "Non-fatal error unloading module '%s' on '%s': %s", name, host, err); - end - end - -- unhook event handlers hooked by module:hook - for event, handlers in pairs(hooks:get(host, name) or NULL) do - for handler in pairs(handlers or NULL) do - (hosts[host] or prosody).events.remove_handler(event, handler); - end - end - hooks:remove(host, name); - if mod.module.items then -- remove items - for key,t in pairs(mod.module.items) do - for i = #t,1,-1 do - local value = t[i]; - t[i] = nil; - hosts[host].events.fire_event("item-removed/"..key, {source = self, item = value}); + modulemap[api_instance.host][module_name] = pluginenv; + if host ~= api_instance.host and module_has_method(pluginenv, "add_host") then + -- Now load the module again onto the host it was originally being loaded on + ok, err = do_load_module(host, module_name); end end end - modulemap[host][name] = nil; - (hosts[host] or prosody).events.fire_event("module-unloaded", { module = name, host = host }); - return true; + if not ok then + modulemap[api_instance.host][module_name] = nil; + log("error", "Error initializing module '%s' on '%s': %s", module_name, host, err or "nil"); + end + return ok and pluginenv, err; end -function reload(host, name, ...) +local function do_reload_module(host, name) local mod = get_module(host, name); if not mod then return nil, "module-not-loaded"; end @@ -206,14 +212,13 @@ function reload(host, name, ...) end local saved; - if module_has_method(mod, "save") then local ok, ret, err = call_module_method(mod, "save"); if ok then saved = ret; else - log("warn", "Error saving module '%s:%s' state: %s", host, module, ret); - if not config.get(host, "core", "force_module_reload") then + log("warn", "Error saving module '%s:%s' state: %s", host, name, ret); + if not config.get(host, "force_module_reload") then log("warn", "Aborting reload due to error, set force_module_reload to ignore this"); return nil, "save-state-failed"; else @@ -222,8 +227,9 @@ function reload(host, name, ...) end end - unload(host, name, ...); - local ok, err = load(host, name, ...); + mod.module.reloading = true; + do_unload_module(host, name); + local ok, err = do_load_module(host, name, saved or true); if ok then mod = get_module(host, name); if module_has_method(mod, "restore") then @@ -232,214 +238,82 @@ function reload(host, name, ...) log("warn", "Error restoring module '%s' from '%s': %s", name, host, err); end end - return true; end - return ok, err; + return ok and mod, err; end -function module_has_method(module, method) - return type(module.module[method]) == "function"; -end +--- Public API --- -function call_module_method(module, method, ...) - if module_has_method(module, method) then - local f = module.module[method]; - return pcall(f, ...); - else - return false, "no-such-method"; +-- Load a module and fire module-loaded event +function load(host, name) + local mod, err = do_load_module(host, name); + if mod then + (hosts[mod.module.host] or prosody).events.fire_event("module-loaded", { module = name, host = mod.module.host }); end + return mod, err; end ------ API functions exposed to modules ----------- --- Must all be in api.* - --- Returns the name of the current module -function api:get_name() - return self.name; -end - --- Returns the host that the current module is serving -function api:get_host() - return self.host; -end - -function api:get_host_type() - return hosts[self.host].type; -end - -function api:set_global() - self.host = "*"; - -- Update the logger - local _log = logger.init("mod_"..self.name); - self.log = function (self, ...) return _log(...); end; - self._log = _log; -end - -function api:add_feature(xmlns) - self:add_item("feature", xmlns); -end -function api:add_identity(category, type, name) - self:add_item("identity", {category = category, type = type, name = name}); -end - -function api:fire_event(...) - return (hosts[self.host] or prosody).events.fire_event(...); -end - -function api:hook(event, handler, priority) - hooks:set(self.host, self.name, event, handler, true); - (hosts[self.host] or prosody).events.add_handler(event, handler, priority); -end - -function api:hook_stanza(xmlns, name, handler, priority) - if not handler and type(name) == "function" then - -- If only 2 options then they specified no xmlns - xmlns, name, handler, priority = nil, xmlns, name, handler; - elseif not (handler and name) then - self:log("warn", "Error: Insufficient parameters to module:hook_stanza()"); - return; +-- Unload a module and fire module-unloaded +function unload(host, name) + local ok, err = do_unload_module(host, name); + if ok then + (hosts[host] or prosody).events.fire_event("module-unloaded", { module = name, host = host }); end - return api.hook(self, "stanza/"..(xmlns and (xmlns..":") or "")..name, function (data) return handler(data.origin, data.stanza, data); end, priority); + return ok, err; end -function api:require(lib) - local f, n = pluginloader.load_code(self.name, lib..".lib.lua"); - if not f then - f, n = pluginloader.load_code(lib, lib..".lib.lua"); +function reload(host, name) + local mod, err = do_reload_module(host, name); + if mod then + modulemap[host][name].module.reloading = true; + (hosts[host] or prosody).events.fire_event("module-reloaded", { module = name, host = host }); + mod.module.reloading = nil; + elseif not is_loaded(host, name) then + (hosts[host] or prosody).events.fire_event("module-unloaded", { module = name, host = host }); end - if not f then error("Failed to load plugin library '"..lib.."', error: "..n); end -- FIXME better error message - setfenv(f, self.environment); - return f(); + return mod, err; end -function api:get_option(name, default_value) - local value = config.get(self.host, self.name, name); - if value == nil then - value = config.get(self.host, "core", name); - if value == nil then - value = default_value; - end - end - return value; -end - -function api:get_option_string(name, default_value) - local value = self:get_option(name, default_value); - if type(value) == "table" then - if #value > 1 then - self:log("error", "Config option '%s' does not take a list, using just the first item", name); - end - value = value[1]; - end - if value == nil then - return nil; - end - return tostring(value); +function get_module(host, name) + return modulemap[host] and modulemap[host][name]; end -function api:get_option_number(name, ...) - local value = self:get_option(name, ...); - if type(value) == "table" then - if #value > 1 then - self:log("error", "Config option '%s' does not take a list, using just the first item", name); +function get_items(key, host) + local result = {}; + local modules = modulemap[host]; + if not key or not host or not modules then return nil; end + + for _, module in pairs(modules) do + local mod = module.module; + if mod.items and mod.items[key] then + for _, value in ipairs(mod.items[key]) do + t_insert(result, value); + end end - value = value[1]; end - local ret = tonumber(value); - if value ~= nil and ret == nil then - self:log("error", "Config option '%s' not understood, expecting a number", name); - end - return ret; -end -function api:get_option_boolean(name, ...) - local value = self:get_option(name, ...); - if type(value) == "table" then - if #value > 1 then - self:log("error", "Config option '%s' does not take a list, using just the first item", name); - end - value = value[1]; - end - if value == nil then - return nil; - end - local ret = value == true or value == "true" or value == 1 or nil; - if ret == nil then - ret = (value == false or value == "false" or value == 0); - if ret then - ret = false; - else - ret = nil; - end - end - if ret == nil then - self:log("error", "Config option '%s' not understood, expecting true/false", name); - end - return ret; + return result; end -function api:get_option_array(name, ...) - local value = self:get_option(name, ...); - - if value == nil then - return nil; - end - - if type(value) ~= "table" then - return array{ value }; -- Assume any non-list is a single-item list - end - - return array():append(value); -- Clone +function get_modules(host) + return modulemap[host]; end -function api:get_option_set(name, ...) - local value = self:get_option_array(name, ...); - - if value == nil then - return nil; - end - - return set.new(value); +function is_loaded(host, name) + return modulemap[host] and modulemap[host][name] and true; end -local t_remove = _G.table.remove; -local module_items = multitable_new(); -function api:add_item(key, value) - self.items = self.items or {}; - self.items[key] = self.items[key] or {}; - t_insert(self.items[key], value); - self:fire_event("item-added/"..key, {source = self, item = value}); -end -function api:remove_item(key, value) - local t = self.items and self.items[key] or NULL; - for i = #t,1,-1 do - if t[i] == value then - t_remove(self.items[key], i); - self:fire_event("item-removed/"..key, {source = self, item = value}); - return value; - end - end +function module_has_method(module, method) + return type(rawget(module.module, method)) == "function"; end -function api:get_host_items(key) - local result = {}; - for mod_name, module in pairs(modulemap[self.host]) do - module = module.module; - if module.items then - for _, item in ipairs(module.items[key] or NULL) do - t_insert(result, item); - end - end - end - for mod_name, module in pairs(modulemap["*"]) do - module = module.module; - if module.items then - for _, item in ipairs(module.items[key] or NULL) do - t_insert(result, item); - end - end +function call_module_method(module, method, ...) + local f = rawget(module.module, method); + if type(f) == "function" then + return pcall(f, ...); + else + return false, "no-such-method"; end - return result; end return _M; diff --git a/core/portmanager.lua b/core/portmanager.lua new file mode 100644 index 00000000..95900c08 --- /dev/null +++ b/core/portmanager.lua @@ -0,0 +1,239 @@ +local config = require "core.configmanager"; +local certmanager = require "core.certmanager"; +local server = require "net.server"; +local socket = require "socket"; + +local log = require "util.logger".init("portmanager"); +local multitable = require "util.multitable"; +local set = require "util.set"; + +local table = table; +local setmetatable, rawset, rawget = setmetatable, rawset, rawget; +local type, tonumber, tostring, ipairs, pairs = type, tonumber, tostring, ipairs, pairs; + +local prosody = prosody; +local fire_event = prosody.events.fire_event; + +module "portmanager"; + +--- Config + +local default_interfaces = { }; +local default_local_interfaces = { }; +if config.get("*", "use_ipv4") ~= false then + table.insert(default_interfaces, "*"); + table.insert(default_local_interfaces, "127.0.0.1"); +end +if socket.tcp6 and config.get("*", "use_ipv6") ~= false then + table.insert(default_interfaces, "::"); + table.insert(default_local_interfaces, "::1"); +end + +--- Private state + +-- service_name -> { service_info, ... } +local services = setmetatable({}, { __index = function (t, k) rawset(t, k, {}); return rawget(t, k); end }); + +-- service_name, interface (string), port (number) +local active_services = multitable.new(); + +--- Private helpers + +local function error_to_friendly_message(service_name, port, err) + local friendly_message = err; + if err:match(" in use") then + -- FIXME: Use service_name here + if port == 5222 or port == 5223 or port == 5269 then + friendly_message = "check that Prosody or another XMPP server is " + .."not already running and using this port"; + elseif port == 80 or port == 81 then + friendly_message = "check that a HTTP server is not already using " + .."this port"; + elseif port == 5280 then + friendly_message = "check that Prosody or a BOSH connection manager " + .."is not already running"; + else + friendly_message = "this port is in use by another application"; + end + elseif err:match("permission") then + friendly_message = "Prosody does not have sufficient privileges to use this port"; + end + return friendly_message; +end + +prosody.events.add_handler("item-added/net-provider", function (event) + local item = event.item; + register_service(item.name, item); +end); +prosody.events.add_handler("item-removed/net-provider", function (event) + local item = event.item; + unregister_service(item.name, item); +end); + +local function duplicate_ssl_config(ssl_config) + local ssl_config = type(ssl_config) == "table" and ssl_config or {}; + + local _config = {}; + for k, v in pairs(ssl_config) do + _config[k] = v; + end + return _config; +end + +--- Public API + +function activate(service_name) + local service_info = services[service_name][1]; + if not service_info then + return nil, "Unknown service: "..service_name; + end + + local listener = service_info.listener; + + local config_prefix = (service_info.config_prefix or service_name).."_"; + if config_prefix == "_" then + config_prefix = ""; + end + + local bind_interfaces = config.get("*", config_prefix.."interfaces") + or config.get("*", config_prefix.."interface") -- COMPAT w/pre-0.9 + or (service_info.private and (config.get("*", "local_interfaces") or default_local_interfaces)) + or config.get("*", "interfaces") + or config.get("*", "interface") -- COMPAT w/pre-0.9 + or listener.default_interface -- COMPAT w/pre0.9 + or default_interfaces + bind_interfaces = set.new(type(bind_interfaces)~="table" and {bind_interfaces} or bind_interfaces); + + local bind_ports = config.get("*", config_prefix.."ports") + or service_info.default_ports + or {service_info.default_port + or listener.default_port -- COMPAT w/pre-0.9 + } + bind_ports = set.new(type(bind_ports) ~= "table" and { bind_ports } or bind_ports ); + + local mode, ssl = listener.default_mode or "*a"; + local hooked_ports = {}; + + for interface in bind_interfaces do + for port in bind_ports do + local port_number = tonumber(port); + if not port_number then + log("error", "Invalid port number specified for service '%s': %s", service_info.name, tostring(port)); + elseif #active_services:search(nil, interface, port_number) > 0 then + log("error", "Multiple services configured to listen on the same port ([%s]:%d): %s, %s", interface, port, active_services:search(nil, interface, port)[1][1].service.name or "<unnamed>", service_name or "<unnamed>"); + else + local err; + -- Create SSL context for this service/port + if service_info.encryption == "ssl" then + local ssl_config = duplicate_ssl_config((config.get("*", config_prefix.."ssl") and config.get("*", config_prefix.."ssl")[interface]) + or (config.get("*", config_prefix.."ssl") and config.get("*", config_prefix.."ssl")[port]) + or config.get("*", config_prefix.."ssl") + or (config.get("*", "ssl") and config.get("*", "ssl")[interface]) + or (config.get("*", "ssl") and config.get("*", "ssl")[port]) + or config.get("*", "ssl")); + -- add default entries for, or override ssl configuration + if ssl_config and service_info.ssl_config then + for key, value in pairs(service_info.ssl_config) do + if not service_info.ssl_config_override and not ssl_config[key] then + ssl_config[key] = value; + elseif service_info.ssl_config_override then + ssl_config[key] = value; + end + end + end + + ssl, err = certmanager.create_context(service_info.name.." port "..port, "server", ssl_config); + if not ssl then + log("error", "Error binding encrypted port for %s: %s", service_info.name, error_to_friendly_message(service_name, port_number, err) or "unknown error"); + end + end + if not err then + -- Start listening on interface+port + local handler, err = server.addserver(interface, port_number, listener, mode, ssl); + if not handler then + log("error", "Failed to open server port %d on %s, %s", port_number, interface, error_to_friendly_message(service_name, port_number, err)); + else + table.insert(hooked_ports, "["..interface.."]:"..port_number); + log("debug", "Added listening service %s to [%s]:%d", service_name, interface, port_number); + active_services:add(service_name, interface, port_number, { + server = handler; + service = service_info; + }); + end + end + end + end + end + log("info", "Activated service '%s' on %s", service_name, #hooked_ports == 0 and "no ports" or table.concat(hooked_ports, ", ")); + return true; +end + +function deactivate(service_name, service_info) + for name, interface, port, n, active_service + in active_services:iter(service_name or service_info and service_info.name, nil, nil, nil) do + if service_info == nil or active_service.service == service_info then + close(interface, port); + end + end + log("info", "Deactivated service '%s'", service_name or service_info.name); +end + +function register_service(service_name, service_info) + table.insert(services[service_name], service_info); + + if not active_services:get(service_name) then + log("debug", "No active service for %s, activating...", service_name); + local ok, err = activate(service_name); + if not ok then + log("error", "Failed to activate service '%s': %s", service_name, err or "unknown error"); + end + end + + fire_event("service-added", { name = service_name, service = service_info }); + return true; +end + +function unregister_service(service_name, service_info) + log("debug", "Unregistering service: %s", service_name); + local service_info_list = services[service_name]; + for i, service in ipairs(service_info_list) do + if service == service_info then + table.remove(service_info_list, i); + end + end + deactivate(nil, service_info); + if #service_info_list > 0 then -- Other services registered with this name + activate(service_name); -- Re-activate with the next available one + end + fire_event("service-removed", { name = service_name, service = service_info }); +end + +function close(interface, port) + local service, server = get_service_at(interface, port); + if not service then + return false, "port-not-open"; + end + server:close(); + active_services:remove(service.name, interface, port); + log("debug", "Removed listening service %s from [%s]:%d", service.name, interface, port); + return true; +end + +function get_service_at(interface, port) + local data = active_services:search(nil, interface, port)[1][1]; + return data.service, data.server; +end + +function get_service(service_name) + return (services[service_name] or {})[1]; +end + +function get_active_services(...) + return active_services; +end + +function get_registered_services() + return services; +end + +return _M; diff --git a/core/rostermanager.lua b/core/rostermanager.lua index 59ba6579..5266afb5 100644 --- a/core/rostermanager.lua +++ b/core/rostermanager.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -11,16 +11,14 @@ local log = require "util.logger".init("rostermanager"); -local setmetatable = setmetatable; -local format = string.format; -local loadfile, setfenv, pcall = loadfile, setfenv, pcall; -local pairs, ipairs = pairs, ipairs; +local pairs = pairs; local tostring = tostring; local hosts = hosts; local bare_sessions = bare_sessions; local datamanager = require "util.datamanager" +local um_user_exists = require "core.usermanager".user_exists; local st = require "util.stanza"; module "rostermanager" @@ -83,15 +81,15 @@ end function load_roster(username, host) local jid = username.."@"..host; - log("debug", "load_roster: asked for: "..jid); + log("debug", "load_roster: asked for: %s", jid); local user = bare_sessions[jid]; local roster; if user then roster = user.roster; if roster then return roster; end - log("debug", "load_roster: loading for new user: "..username.."@"..host); + log("debug", "load_roster: loading for new user: %s@%s", username, host); else -- Attempt to load roster for non-loaded user - log("debug", "load_roster: loading for offline user: "..username.."@"..host); + log("debug", "load_roster: loading for offline user: %s@%s", username, host); end local data, err = datamanager.load(username, host, "roster"); roster = data or {}; @@ -99,16 +97,21 @@ function load_roster(username, host) if not roster[false] then roster[false] = { broken = err or nil }; end if roster[jid] then roster[jid] = nil; - log("warn", "roster for "..jid.." has a self-contact"); + log("warn", "roster for %s has a self-contact", jid); end if not err then - hosts[host].events.fire_event("roster-load", username, host, roster); + hosts[host].events.fire_event("roster-load", { username = username, host = host, roster = roster }); end return roster, err; end function save_roster(username, host, roster) - log("debug", "save_roster: saving roster for "..username.."@"..host); + if not um_user_exists(username, host) then + log("debug", "not saving roster for %s@%s: the user doesn't exist", username, host); + return nil; + end + + log("debug", "save_roster: saving roster for %s@%s", username, host); if not roster then roster = hosts[host] and hosts[host].sessions[username] and hosts[host].sessions[username].roster; --if not roster then @@ -238,7 +241,7 @@ function set_contact_pending_out(username, host, jid) -- subscribe roster[jid] = item; end item.ask = "subscribe"; - log("debug", "set_contact_pending_out: saving roster; set "..username.."@"..host..".roster["..jid.."].ask=subscribe"); + log("debug", "set_contact_pending_out: saving roster; set %s@%s.roster[%q].ask=subscribe", username, host, jid); return save_roster(username, host, roster); end function unsubscribe(username, host, jid) @@ -278,23 +281,21 @@ function unsubscribed(username, host, jid) local roster = load_roster(username, host); local item = roster[jid]; local pending = is_contact_pending_in(username, host, jid); - local changed = nil; - if is_contact_pending_in(username, host, jid) then + if pending then roster.pending[jid] = nil; -- TODO maybe delete roster.pending if empty? - changed = true; end + local subscribed; if item then if item.subscription == "from" then item.subscription = "none"; - changed = true; + subscribed = true; elseif item.subscription == "both" then item.subscription = "to"; - changed = true; + subscribed = true; end end - if changed then - return save_roster(username, host, roster); - end + local success = (pending or subscribed) and save_roster(username, host, roster); + return success, pending, subscribed; end function process_outbound_subscription_request(username, host, jid) diff --git a/core/s2smanager.lua b/core/s2smanager.lua index 7e6f8135..59c1831b 100644 --- a/core/s2smanager.lua +++ b/core/s2smanager.lua @@ -1,599 +1,43 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- -local hosts = hosts; -local sessions = sessions; -local core_process_stanza = function(a, b) core_process_stanza(a, b); end -local add_task = require "util.timer".add_task; -local socket = require "socket"; -local format = string.format; -local t_insert, t_sort = table.insert, table.sort; -local get_traceback = debug.traceback; -local tostring, pairs, ipairs, getmetatable, newproxy, error, tonumber, setmetatable - = tostring, pairs, ipairs, getmetatable, newproxy, error, tonumber, setmetatable; - -local idna_to_ascii = require "util.encodings".idna.to_ascii; -local connlisteners_get = require "net.connlisteners".get; -local initialize_filters = require "util.filters".initialize; -local wrapclient = require "net.server".wrapclient; -local modulemanager = require "core.modulemanager"; -local st = require "stanza"; -local stanza = st.stanza; -local nameprep = require "util.encodings".stringprep.nameprep; -local cert_verify_identity = require "util.x509".verify_identity; - -local fire_event = prosody.events.fire_event; -local uuid_gen = require "util.uuid".generate; +local hosts = prosody.hosts; +local tostring, pairs, setmetatable + = tostring, pairs, setmetatable; local logger_init = require "util.logger".init; local log = logger_init("s2smanager"); -local sha256_hash = require "util.hashes".sha256; - -local adns, dns = require "net.adns", require "net.dns"; -local config = require "core.configmanager"; -local connect_timeout = config.get("*", "core", "s2s_timeout") or 60; -local dns_timeout = config.get("*", "core", "dns_timeout") or 15; -local max_dns_depth = config.get("*", "core", "dns_max_depth") or 3; - -dns.settimeout(dns_timeout); - local prosody = _G.prosody; incoming_s2s = {}; prosody.incoming_s2s = incoming_s2s; local incoming_s2s = incoming_s2s; +local fire_event = prosody.events.fire_event; module "s2smanager" -function compare_srv_priorities(a,b) - return a.priority < b.priority or (a.priority == b.priority and a.weight > b.weight); -end - -local bouncy_stanzas = { message = true, presence = true, iq = true }; -local function bounce_sendq(session, reason) - local sendq = session.sendq; - if sendq then - session.log("info", "sending error replies for "..#sendq.." queued stanzas because of failed outgoing connection to "..tostring(session.to_host)); - local dummy = { - type = "s2sin"; - send = function(s) - (session.log or log)("error", "Replying to to an s2s error reply, please report this! Traceback: %s", get_traceback()); - end; - dummy = true; - }; - for i, data in ipairs(sendq) do - local reply = data[2]; - local xmlns = reply.attr.xmlns; - if not(xmlns) and bouncy_stanzas[reply.name] then - reply.attr.type = "error"; - reply:tag("error", {type = "cancel"}) - :tag("remote-server-not-found", {xmlns = "urn:ietf:params:xml:ns:xmpp-stanzas"}):up(); - if reason then - reply:tag("text", {xmlns = "urn:ietf:params:xml:ns:xmpp-stanzas"}):text("Connection failed: "..reason):up(); - end - core_process_stanza(dummy, reply); - end - sendq[i] = nil; - end - session.sendq = nil; - end -end - -function send_to_host(from_host, to_host, data) - if not hosts[from_host] then - log("warn", "Attempt to send stanza from %s - a host we don't serve", from_host); - return false; - end - local host = hosts[from_host].s2sout[to_host]; - if host then - -- We have a connection to this host already - if host.type == "s2sout_unauthed" and (data.name ~= "db:verify" or not host.dialback_key) then - (host.log or log)("debug", "trying to send over unauthed s2sout to "..to_host); - - -- Queue stanza until we are able to send it - if host.sendq then t_insert(host.sendq, {tostring(data), st.reply(data)}); - else host.sendq = { {tostring(data), st.reply(data)} }; end - host.log("debug", "stanza [%s] queued ", data.name); - elseif host.type == "local" or host.type == "component" then - log("error", "Trying to send a stanza to ourselves??") - log("error", "Traceback: %s", get_traceback()); - log("error", "Stanza: %s", tostring(data)); - return false; - else - (host.log or log)("debug", "going to send stanza to "..to_host.." from "..from_host); - -- FIXME - if host.from_host ~= from_host then - log("error", "WARNING! This might, possibly, be a bug, but it might not..."); - log("error", "We are going to send from %s instead of %s", tostring(host.from_host), tostring(from_host)); - end - host.sends2s(data); - host.log("debug", "stanza sent over "..host.type); - end - else - log("debug", "opening a new outgoing connection for this stanza"); - local host_session = new_outgoing(from_host, to_host); - - -- Store in buffer - host_session.sendq = { {tostring(data), st.reply(data)} }; - log("debug", "stanza [%s] queued until connection complete", tostring(data.name)); - if (not host_session.connecting) and (not host_session.conn) then - log("warn", "Connection to %s failed already, destroying session...", to_host); - if not destroy_session(host_session, "Connection failed") then - -- Already destroyed, we need to bounce our stanza - bounce_sendq(host_session, host_session.destruction_reason); - end - return false; - end - end - return true; -end - -local open_sessions = 0; - function new_incoming(conn) local session = { conn = conn, type = "s2sin_unauthed", direction = "incoming", hosts = {} }; - if true then - session.trace = newproxy(true); - getmetatable(session.trace).__gc = function () open_sessions = open_sessions - 1; end; - end - open_sessions = open_sessions + 1; - local w, log = conn.write, logger_init("s2sin"..tostring(conn):match("[a-f0-9]+$")); - session.log = log; - local filter = initialize_filters(session); - session.sends2s = function (t) - log("debug", "sending: %s", t.top_tag and t:top_tag() or t:match("^([^>]*>?)")); - if t.name then - t = filter("stanzas/out", t); - end - if t then - t = filter("bytes/out", tostring(t)); - if t then - return w(conn, t); - end - end - end + session.log = logger_init("s2sin"..tostring(session):match("[a-f0-9]+$")); incoming_s2s[session] = true; - add_task(connect_timeout, function () - if session.conn ~= conn or - session.type == "s2sin" then - return; -- Ok, we're connect[ed|ing] - end - -- Not connected, need to close session and clean up - (session.log or log)("warn", "Destroying incomplete session %s->%s due to inactivity", - session.from_host or "(unknown)", session.to_host or "(unknown)"); - session:close("connection-timeout"); - end); return session; end -function new_outgoing(from_host, to_host, connect) - local host_session = { to_host = to_host, from_host = from_host, host = from_host, - notopen = true, type = "s2sout_unauthed", direction = "outgoing", - open_stream = session_open_stream }; - - hosts[from_host].s2sout[to_host] = host_session; - - host_session.close = destroy_session; -- This gets replaced by xmppserver_listener later - - local log; - do - local conn_name = "s2sout"..tostring(host_session):match("[a-f0-9]*$"); - log = logger_init(conn_name); - host_session.log = log; - end - - initialize_filters(host_session); - - if connect ~= false then - -- Kick the connection attempting machine into life - if not attempt_connection(host_session) then - -- Intentionally not returning here, the - -- session is needed, connected or not - destroy_session(host_session); - end - end - - if not host_session.sends2s then - -- A sends2s which buffers data (until the stream is opened) - -- note that data in this buffer will be sent before the stream is authed - -- and will not be ack'd in any way, successful or otherwise - local buffer; - function host_session.sends2s(data) - if not buffer then - buffer = {}; - host_session.send_buffer = buffer; - end - log("debug", "Buffering data on unconnected s2sout to %s", to_host); - buffer[#buffer+1] = data; - log("debug", "Buffered item %d: %s", #buffer, tostring(data)); - end - end - - return host_session; -end - - -function attempt_connection(host_session, err) - local from_host, to_host = host_session.from_host, host_session.to_host; - local connect_host, connect_port = to_host and idna_to_ascii(to_host), 5269; - - if not connect_host then - return false; - end - - if not err then -- This is our first attempt - log("debug", "First attempt to connect to %s, starting with SRV lookup...", to_host); - host_session.connecting = true; - local handle; - handle = adns.lookup(function (answer) - handle = nil; - host_session.connecting = nil; - if answer then - log("debug", to_host.." has SRV records, handling..."); - local srv_hosts = {}; - host_session.srv_hosts = srv_hosts; - for _, record in ipairs(answer) do - t_insert(srv_hosts, record.srv); - end - t_sort(srv_hosts, compare_srv_priorities); - - local srv_choice = srv_hosts[1]; - host_session.srv_choice = 1; - if srv_choice then - connect_host, connect_port = srv_choice.target or to_host, srv_choice.port or connect_port; - log("debug", "Best record found, will connect to %s:%d", connect_host, connect_port); - end - else - log("debug", to_host.." has no SRV records, falling back to A"); - end - -- Try with SRV, or just the plain hostname if no SRV - local ok, err = try_connect(host_session, connect_host, connect_port); - if not ok then - if not attempt_connection(host_session, err) then - -- No more attempts will be made - destroy_session(host_session, err); - end - end - end, "_xmpp-server._tcp."..connect_host..".", "SRV"); - - return true; -- Attempt in progress - elseif host_session.srv_hosts and #host_session.srv_hosts > host_session.srv_choice then -- Not our first attempt, and we also have SRV - host_session.srv_choice = host_session.srv_choice + 1; - local srv_choice = host_session.srv_hosts[host_session.srv_choice]; - connect_host, connect_port = srv_choice.target or to_host, srv_choice.port or connect_port; - host_session.log("info", "Connection failed (%s). Attempt #%d: This time to %s:%d", tostring(err), host_session.srv_choice, connect_host, connect_port); - else - host_session.log("info", "Out of connection options, can't connect to %s", tostring(host_session.to_host)); - -- We're out of options - return false; - end - - if not (connect_host and connect_port) then - -- Likely we couldn't resolve DNS - log("warn", "Hmm, we're without a host (%s) and port (%s) to connect to for %s, giving up :(", tostring(connect_host), tostring(connect_port), tostring(to_host)); - return false; - end - - return try_connect(host_session, connect_host, connect_port); -end - -function try_connect(host_session, connect_host, connect_port) - host_session.connecting = true; - local handle; - handle = adns.lookup(function (reply, err) - handle = nil; - host_session.connecting = nil; - - -- COMPAT: This is a compromise for all you CNAME-(ab)users :) - if not (reply and reply[#reply] and reply[#reply].a) then - local count = max_dns_depth; - reply = dns.peek(connect_host, "CNAME", "IN"); - while count > 0 and reply and reply[#reply] and not reply[#reply].a and reply[#reply].cname do - log("debug", "Looking up %s (DNS depth is %d)", tostring(reply[#reply].cname), count); - reply = dns.peek(reply[#reply].cname, "A", "IN") or dns.peek(reply[#reply].cname, "CNAME", "IN"); - count = count - 1; - end - end - -- end of CNAME resolving - - if reply and reply[#reply] and reply[#reply].a then - log("debug", "DNS reply for %s gives us %s", connect_host, reply[#reply].a); - local ok, err = make_connect(host_session, reply[#reply].a, connect_port); - if not ok then - if not attempt_connection(host_session, err or "closed") then - err = err and (": "..err) or ""; - destroy_session(host_session, "Connection failed"..err); - end - end - else - log("debug", "DNS lookup failed to get a response for %s", connect_host); - if not attempt_connection(host_session, "name resolution failed") then -- Retry if we can - log("debug", "No other records to try for %s - destroying", host_session.to_host); - err = err and (": "..err) or ""; - destroy_session(host_session, "DNS resolution failed"..err); -- End of the line, we can't - end - end - end, connect_host, "A", "IN"); - - return true; -end - -function make_connect(host_session, connect_host, connect_port) - (host_session.log or log)("info", "Beginning new connection attempt to %s (%s:%d)", host_session.to_host, connect_host, connect_port); - -- Ok, we're going to try to connect - - local from_host, to_host = host_session.from_host, host_session.to_host; - - local conn, handler = socket.tcp(); - - if not conn then - log("warn", "Failed to create outgoing connection, system error: %s", handler); - return false, handler; - end - - conn:settimeout(0); - local success, err = conn:connect(connect_host, connect_port); - if not success and err ~= "timeout" then - log("warn", "s2s connect() to %s (%s:%d) failed: %s", host_session.to_host, connect_host, connect_port, err); - return false, err; - end - - local cl = connlisteners_get("xmppserver"); - conn = wrapclient(conn, connect_host, connect_port, cl, cl.default_mode or 1 ); - host_session.conn = conn; - - local filter = initialize_filters(host_session); - local w, log = conn.write, host_session.log; - host_session.sends2s = function (t) - log("debug", "sending: %s", (t.top_tag and t:top_tag()) or t:match("^[^>]*>?")); - if t.name then - t = filter("stanzas/out", t); - end - if t then - t = filter("bytes/out", tostring(t)); - if t then - return w(conn, tostring(t)); - end - end - end - - -- Register this outgoing connection so that xmppserver_listener knows about it - -- otherwise it will assume it is a new incoming connection - cl.register_outgoing(conn, host_session); - - host_session:open_stream(from_host, to_host); - - log("debug", "Connection attempt in progress..."); - add_task(connect_timeout, function () - if host_session.conn ~= conn or - host_session.type == "s2sout" or - host_session.connecting then - return; -- Ok, we're connect[ed|ing] - end - -- Not connected, need to close session and clean up - (host_session.log or log)("warn", "Destroying incomplete session %s->%s due to inactivity", - host_session.from_host or "(unknown)", host_session.to_host or "(unknown)"); - host_session:close("connection-timeout"); - end); - return true; -end - -function session_open_stream(session, from, to) - session.sends2s(st.stanza("stream:stream", { - xmlns='jabber:server', ["xmlns:db"]='jabber:server:dialback', - ["xmlns:stream"]='http://etherx.jabber.org/streams', - from=from, to=to, version='1.0', ["xml:lang"]='en'}):top_tag()); -end - -local function check_cert_status(session) - local conn = session.conn:socket() - local cert - if conn.getpeercertificate then - cert = conn:getpeercertificate() - end - - if cert then - local chain_valid, err = conn:getpeerchainvalid() - if not chain_valid then - session.cert_chain_status = "invalid"; - (session.log or log)("debug", "certificate chain validation result: %s", err); - else - session.cert_chain_status = "valid"; - - local host = session.direction == "incoming" and session.from_host or session.to_host - - -- We'll go ahead and verify the asserted identity if the - -- connecting server specified one. - if host then - if cert_verify_identity(host, "xmpp-server", cert) then - session.cert_identity_status = "valid" - else - session.cert_identity_status = "invalid" - end - end - end - end -end - -function streamopened(session, attr) - local send = session.sends2s; - - -- TODO: #29: SASL/TLS on s2s streams - session.version = tonumber(attr.version) or 0; - - -- TODO: Rename session.secure to session.encrypted - if session.secure == false then - session.secure = true; - end - - if session.direction == "incoming" then - -- Send a reply stream header - session.to_host = attr.to and nameprep(attr.to); - session.from_host = attr.from and nameprep(attr.from); - - session.streamid = uuid_gen(); - (session.log or log)("debug", "incoming s2s received <stream:stream>"); - if session.to_host then - if not hosts[session.to_host] then - -- Attempting to connect to a host we don't serve - session:close({ - condition = "host-unknown"; - text = "This host does not serve "..session.to_host - }); - return; - elseif hosts[session.to_host].disallow_s2s then - -- Attempting to connect to a host that disallows s2s - session:close({ - condition = "policy-violation"; - text = "Server-to-server communication is not allowed to this host"; - }); - return; - end - end - - if session.secure and not session.cert_chain_status then check_cert_status(session); end - - send("<?xml version='1.0'?>"); - send(stanza("stream:stream", { xmlns='jabber:server', ["xmlns:db"]='jabber:server:dialback', - ["xmlns:stream"]='http://etherx.jabber.org/streams', id=session.streamid, from=session.to_host, to=session.from_host, version=(session.version > 0 and "1.0" or nil) }):top_tag()); - if session.version >= 1.0 then - local features = st.stanza("stream:features"); - - if session.to_host then - hosts[session.to_host].events.fire_event("s2s-stream-features", { origin = session, features = features }); - else - (session.log or log)("warn", "No 'to' on stream header from %s means we can't offer any features", session.from_host or "unknown host"); - end - - log("debug", "Sending stream features: %s", tostring(features)); - send(features); - end - elseif session.direction == "outgoing" then - -- If we are just using the connection for verifying dialback keys, we won't try and auth it - if not attr.id then error("stream response did not give us a streamid!!!"); end - session.streamid = attr.id; - - if session.secure and not session.cert_chain_status then check_cert_status(session); end - - -- Send unauthed buffer - -- (stanzas which are fine to send before dialback) - -- Note that this is *not* the stanza queue (which - -- we can only send if auth succeeds) :) - local send_buffer = session.send_buffer; - if send_buffer and #send_buffer > 0 then - log("debug", "Sending s2s send_buffer now..."); - for i, data in ipairs(send_buffer) do - session.sends2s(tostring(data)); - send_buffer[i] = nil; - end - end - session.send_buffer = nil; - - -- If server is pre-1.0, don't wait for features, just do dialback - if session.version < 1.0 then - if not session.dialback_verifying then - log("debug", "Initiating dialback..."); - initiate_dialback(session); - else - mark_connected(session); - end - end - end - session.notopen = nil; -end - -function streamclosed(session) - (session.log or log)("debug", "Received </stream:stream>"); - session:close(); -end - -function initiate_dialback(session) - -- generate dialback key - session.dialback_key = generate_dialback(session.streamid, session.to_host, session.from_host); - session.sends2s(format("<db:result from='%s' to='%s'>%s</db:result>", session.from_host, session.to_host, session.dialback_key)); - session.log("info", "sent dialback key on outgoing s2s stream"); -end - -function generate_dialback(id, to, from) - return sha256_hash(id..to..from..hosts[from].dialback_secret, true); -end - -function verify_dialback(id, to, from, key) - return key == generate_dialback(id, to, from); -end - -function make_authenticated(session, host) - if not session.secure then - local local_host = session.direction == "incoming" and session.to_host or session.from_host; - if config.get(local_host, "core", "s2s_require_encryption") then - session:close({ - condition = "policy-violation", - text = "Encrypted server-to-server communication is required but was not " - ..((session.direction == "outgoing" and "offered") or "used") - }); - end - end - if session.type == "s2sout_unauthed" then - session.type = "s2sout"; - elseif session.type == "s2sin_unauthed" then - session.type = "s2sin"; - if host then - if not session.hosts[host] then session.hosts[host] = {}; end - session.hosts[host].authed = true; - end - elseif session.type == "s2sin" and host then - if not session.hosts[host] then session.hosts[host] = {}; end - session.hosts[host].authed = true; - else - return false; - end - session.log("debug", "connection %s->%s is now authenticated", session.from_host or "(unknown)", session.to_host or "(unknown)"); - - mark_connected(session); - - return true; -end - --- Stream is authorised, and ready for normal stanzas -function mark_connected(session) - local sendq, send = session.sendq, session.sends2s; - - local from, to = session.from_host, session.to_host; - - session.log("info", session.direction.." s2s connection "..from.."->"..to.." complete"); - - local send_to_host = send_to_host; - function session.send(data) return send_to_host(to, from, data); end - - local event_data = { session = session }; - if session.type == "s2sout" then - prosody.events.fire_event("s2sout-established", event_data); - hosts[session.from_host].events.fire_event("s2sout-established", event_data); - else - prosody.events.fire_event("s2sin-established", event_data); - hosts[session.to_host].events.fire_event("s2sin-established", event_data); - end - - if session.direction == "outgoing" then - if sendq then - session.log("debug", "sending "..#sendq.." queued stanzas across new outgoing connection to "..session.to_host); - for i, data in ipairs(sendq) do - send(data[1]); - sendq[i] = nil; - end - session.sendq = nil; - end - - session.srv_hosts = nil; - end +function new_outgoing(from_host, to_host) + local host_session = { to_host = to_host, from_host = from_host, host = from_host, + notopen = true, type = "s2sout_unauthed", direction = "outgoing" }; + hosts[from_host].s2sout[to_host] = host_session; + local conn_name = "s2sout"..tostring(host_session):match("[a-f0-9]*$"); + host_session.log = logger_init(conn_name); + return host_session; end local resting_session = { -- Resting, not dead @@ -611,7 +55,7 @@ local resting_session = { -- Resting, not dead function retire_session(session, reason) local log = session.log or log; for k in pairs(session) do - if k ~= "trace" and k ~= "log" and k ~= "id" then + if k ~= "log" and k ~= "id" and k ~= "conn" then session[k] = nil; end end @@ -625,28 +69,28 @@ end function destroy_session(session, reason) if session.destroyed then return; end - (session.log or log)("info", "Destroying "..tostring(session.direction).." session "..tostring(session.from_host).."->"..tostring(session.to_host)); - + (session.log or log)("debug", "Destroying "..tostring(session.direction).." session "..tostring(session.from_host).."->"..tostring(session.to_host)..(reason and (": "..reason) or "")); + if session.direction == "outgoing" then hosts[session.from_host].s2sout[session.to_host] = nil; - bounce_sendq(session, reason); + session:bounce_sendq(reason); elseif session.direction == "incoming" then incoming_s2s[session] = nil; end - + local event_data = { session = session, reason = reason }; if session.type == "s2sout" then - prosody.events.fire_event("s2sout-destroyed", event_data); + fire_event("s2sout-destroyed", event_data); if hosts[session.from_host] then hosts[session.from_host].events.fire_event("s2sout-destroyed", event_data); end elseif session.type == "s2sin" then - prosody.events.fire_event("s2sin-destroyed", event_data); + fire_event("s2sin-destroyed", event_data); if hosts[session.to_host] then hosts[session.to_host].events.fire_event("s2sin-destroyed", event_data); end end - + retire_session(session, reason); -- Clean session until it is GC'd return true; end diff --git a/core/sessionmanager.lua b/core/sessionmanager.lua index d36591bf..5f7f688e 100644 --- a/core/sessionmanager.lua +++ b/core/sessionmanager.lua @@ -1,57 +1,33 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- - - -local tonumber, tostring, setmetatable = tonumber, tostring, setmetatable; -local ipairs, pairs, print, next= ipairs, pairs, print, next; -local format = import("string", "format"); +local tostring, setmetatable = tostring, setmetatable; +local pairs, next= pairs, next; local hosts = hosts; local full_sessions = full_sessions; local bare_sessions = bare_sessions; -local modulemanager = require "core.modulemanager"; local logger = require "util.logger"; local log = logger.init("sessionmanager"); -local error = error; -local uuid_generate = require "util.uuid".generate; local rm_load_roster = require "core.rostermanager".load_roster; local config_get = require "core.configmanager".get; -local nameprep = require "util.encodings".stringprep.nameprep; local resourceprep = require "util.encodings".stringprep.resourceprep; local nodeprep = require "util.encodings".stringprep.nodeprep; +local uuid_generate = require "util.uuid".generate; local initialize_filters = require "util.filters".initialize; -local fire_event = prosody.events.fire_event; -local add_task = require "util.timer".add_task; local gettime = require "socket".gettime; -local st = require "util.stanza"; - -local c2s_timeout = config_get("*", "core", "c2s_timeout"); - -local newproxy = newproxy; -local getmetatable = getmetatable; - module "sessionmanager" -local open_sessions = 0; - function new_session(conn) local session = { conn = conn, type = "c2s_unauthed", conntime = gettime() }; - if true then - session.trace = newproxy(true); - getmetatable(session.trace).__gc = function () open_sessions = open_sessions - 1; end; - end - open_sessions = open_sessions + 1; - log("debug", "open sessions now: ".. open_sessions); - local filter = initialize_filters(session); local w = conn.write; session.send = function (t) @@ -66,17 +42,9 @@ function new_session(conn) end end session.ip = conn:ip(); - local conn_name = "c2s"..tostring(conn):match("[a-f0-9]+$"); + local conn_name = "c2s"..tostring(session):match("[a-f0-9]+$"); session.log = logger.init(conn_name); - - if c2s_timeout then - add_task(c2s_timeout, function () - if session.type == "c2s_unauthed" then - session:close("connection-timeout"); - end - end); - end - + return session; end @@ -92,34 +60,41 @@ local resting_session = { -- Resting, not dead function retire_session(session) local log = session.log or log; for k in pairs(session) do - if k ~= "trace" and k ~= "log" and k ~= "id" then + if k ~= "log" and k ~= "id" then session[k] = nil; end end - function session.send(data) log("debug", "Discarding data sent to resting session: %s", tostring(data)); end + function session.send(data) log("debug", "Discarding data sent to resting session: %s", tostring(data)); return false; end function session.data(data) log("debug", "Discarding data received from resting session: %s", tostring(data)); end return setmetatable(session, resting_session); end function destroy_session(session, err) - (session.log or log)("info", "Destroying session for %s (%s@%s)", session.full_jid or "(unknown)", session.username or "(unknown)", session.host or "(unknown)"); + (session.log or log)("debug", "Destroying session for %s (%s@%s)%s", session.full_jid or "(unknown)", session.username or "(unknown)", session.host or "(unknown)", err and (": "..err) or ""); if session.destroyed then return; end - + -- Remove session/resource from user's session list if session.full_jid then - hosts[session.host].sessions[session.username].sessions[session.resource] = nil; + local host_session = hosts[session.host]; + + -- Allow plugins to prevent session destruction + if host_session.events.fire_event("pre-resource-unbind", {session=session, error=err}) then + return; + end + + host_session.sessions[session.username].sessions[session.resource] = nil; full_sessions[session.full_jid] = nil; - - if not next(hosts[session.host].sessions[session.username].sessions) then + + if not next(host_session.sessions[session.username].sessions) then log("debug", "All resources of %s are now offline", session.username); - hosts[session.host].sessions[session.username] = nil; + host_session.sessions[session.username] = nil; bare_sessions[session.username..'@'..session.host] = nil; end - hosts[session.host].events.fire_event("resource-unbind", {session=session, error=err}); + host_session.events.fire_event("resource-unbind", {session=session, error=err}); end - + retire_session(session); end @@ -144,20 +119,16 @@ function bind_resource(session, resource) resource = resourceprep(resource); resource = resource ~= "" and resource or uuid_generate(); --FIXME: Randomly-generated resources must be unique per-user, and never conflict with existing - + if not hosts[session.host].sessions[session.username] then local sessions = { sessions = {} }; hosts[session.host].sessions[session.username] = sessions; bare_sessions[session.username..'@'..session.host] = sessions; else local sessions = hosts[session.host].sessions[session.username].sessions; - local limit = config_get(session.host, "core", "max_resources") or 10; - if #sessions >= limit then - return nil, "cancel", "resource-constraint", "Resource limit reached; only "..limit.." resources allowed"; - end if sessions[resource] then -- Resource conflict - local policy = config_get(session.host, "core", "conflict_resolve"); + local policy = config_get(session.host, "conflict_resolve"); local increment; if policy == "random" then resource = uuid_generate(); @@ -185,12 +156,12 @@ function bind_resource(session, resource) end end end - + session.resource = resource; session.full_jid = session.username .. '@' .. session.host .. '/' .. resource; hosts[session.host].sessions[session.username].sessions[resource] = session; full_sessions[session.full_jid] = session; - + local err; session.roster, err = rm_load_roster(session.username, session.host); if err then @@ -202,56 +173,13 @@ function bind_resource(session, resource) bare_sessions[session.username..'@'..session.host] = nil; hosts[session.host].sessions[session.username] = nil; end + session.log("error", "Roster loading failed: %s", err); return nil, "cancel", "internal-server-error", "Error loading roster"; end - - hosts[session.host].events.fire_event("resource-bind", {session=session}); - - return true; -end - -function streamopened(session, attr) - local send = session.send; - session.host = attr.to; - if not session.host then - session:close{ condition = "improper-addressing", - text = "A 'to' attribute is required on stream headers" }; - return; - end - session.host = nameprep(session.host); - session.version = tonumber(attr.version) or 0; - session.streamid = uuid_generate(); - (session.log or session)("debug", "Client sent opening <stream:stream> to %s", session.host); - - if not hosts[session.host] then - -- We don't serve this host... - session:close{ condition = "host-unknown", text = "This server does not serve "..tostring(session.host)}; - return; - end - - send("<?xml version='1.0'?>"); - send(format("<stream:stream xmlns='jabber:client' xmlns:stream='http://etherx.jabber.org/streams' id='%s' from='%s' version='1.0' xml:lang='en'>", session.streamid, session.host)); - - (session.log or log)("debug", "Sent reply <stream:stream> to client"); - session.notopen = nil; - - -- If session.secure is *false* (not nil) then it means we /were/ encrypting - -- since we now have a new stream header, session is secured - if session.secure == false then - session.secure = true; - end - local features = st.stanza("stream:features"); - hosts[session.host].events.fire_event("stream-features", { origin = session, features = features }); - fire_event("stream-features", session, features); - - send(features); - -end + hosts[session.host].events.fire_event("resource-bind", {session=session}); -function streamclosed(session) - session.log("debug", "Received </stream:stream>"); - session:close(); + return true; end function send_to_available_resources(user, host, stanza) diff --git a/core/stanza_router.lua b/core/stanza_router.lua index 97d328a1..c78a657a 100644 --- a/core/stanza_router.lua +++ b/core/stanza_router.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -11,13 +11,24 @@ local log = require "util.logger".init("stanzarouter") local hosts = _G.prosody.hosts; local tostring = tostring; local st = require "util.stanza"; -local send_s2s = require "core.s2smanager".send_to_host; local jid_split = require "util.jid".split; local jid_prepped_split = require "util.jid".prepped_split; local full_sessions = _G.prosody.full_sessions; local bare_sessions = _G.prosody.bare_sessions; +local core_post_stanza, core_process_stanza, core_route_stanza; + +function deprecated_warning(f) + _G[f] = function(...) + log("warn", "Using the global %s() is deprecated, use module:send() or prosody.%s(). %s", f, f, debug.traceback()); + return prosody[f](...); + end +end +deprecated_warning"core_post_stanza"; +deprecated_warning"core_process_stanza"; +deprecated_warning"core_route_stanza"; + local function handle_unhandled_stanza(host, origin, stanza) local name, xmlns, origin_type = stanza.name, stanza.attr.xmlns or "jabber:client", origin.type; if name == "iq" and xmlns == "jabber:client" then @@ -29,17 +40,18 @@ local function handle_unhandled_stanza(host, origin, stanza) return true; end end - if stanza.attr.xmlns == nil then + if stanza.attr.xmlns == nil and origin.send then log("debug", "Unhandled %s stanza: %s; xmlns=%s", origin.type, stanza.name, xmlns); -- we didn't handle it if stanza.attr.type ~= "error" and stanza.attr.type ~= "result" then origin.send(st.error_reply(stanza, "cancel", "service-unavailable")); end elseif not((name == "features" or name == "error") and xmlns == "http://etherx.jabber.org/streams") then -- FIXME remove check once we handle S2S features - log("warn", "Unhandled %s stream element: %s; xmlns=%s: %s", origin.type, stanza.name, xmlns, tostring(stanza)); -- we didn't handle it + log("warn", "Unhandled %s stream element or stanza: %s; xmlns=%s: %s", origin.type, stanza.name, xmlns, tostring(stanza)); -- we didn't handle it origin:close("unsupported-stanza-type"); end end +local iq_types = { set=true, get=true, result=true, error=true }; function core_process_stanza(origin, stanza) (origin.log or log)("debug", "Received[%s]: %s", origin.type, stanza:top_tag()) @@ -47,8 +59,8 @@ function core_process_stanza(origin, stanza) if stanza.attr.type == "error" and #stanza.tags == 0 then return; end -- TODO invalid stanza, log if stanza.name == "iq" then if not stanza.attr.id then stanza.attr.id = ""; end -- COMPAT Jabiru doesn't send the id attribute on roster requests - if (stanza.attr.type == "set" or stanza.attr.type == "get") and (#stanza.tags ~= 1) then - origin.send(st.error_reply(stanza, "modify", "bad-request")); + if not iq_types[stanza.attr.type] or ((stanza.attr.type == "set" or stanza.attr.type == "get") and (#stanza.tags ~= 1)) then + origin.send(st.error_reply(stanza, "modify", "bad-request", "Invalid IQ type or incorrect number of children")); return; end end @@ -104,17 +116,17 @@ function core_process_stanza(origin, stanza) stanza.attr.from = from; end - --[[if to and not(hosts[to]) and not(hosts[to_bare]) and (hosts[host] and hosts[host].type ~= "local") then -- not for us? - log("warn", "stanza recieved for a non-local server"); - return; -- FIXME what should we do here? - end]] -- FIXME - if (origin.type == "s2sin" or origin.type == "c2s" or origin.type == "component") and xmlns == nil then if origin.type == "s2sin" and not origin.dummy then local host_status = origin.hosts[from_host]; if not host_status or not host_status.authed then -- remote server trying to impersonate some other server? log("warn", "Received a stanza claiming to be from %s, over a stream authed for %s!", from_host, origin.from_host); - return; -- FIXME what should we do here? does this work with subdomains? + origin:close("not-authorized"); + return; + elseif not hosts[host] then + log("warn", "Remote server %s sent us a stanza for %s, closing stream", origin.from_host, host); + origin:close("host-unknown"); + return; end end core_post_stanza(origin, stanza, origin.full_jid); @@ -184,30 +196,28 @@ function core_route_stanza(origin, stanza) -- Auto-detect origin if not specified origin = origin or hosts[from_host]; if not origin then return false; end - + if hosts[host] then -- old stanza routing code removed core_post_stanza(origin, stanza); - elseif origin.type == "c2s" then - -- Remote host - if not hosts[from_host] then + else + log("debug", "Routing to remote..."); + local host_session = hosts[from_host]; + if not host_session then log("error", "No hosts[from_host] (please report): %s", tostring(stanza)); - end - if (not hosts[from_host]) or (not hosts[from_host].disallow_s2s) then + else local xmlns = stanza.attr.xmlns; - --stanza.attr.xmlns = "jabber:server"; stanza.attr.xmlns = nil; - log("debug", "sending s2s stanza: %s", tostring(stanza.top_tag and stanza:top_tag()) or stanza); - send_s2s(origin.host, host, stanza); -- TODO handle remote routing errors + local routed = host_session.events.fire_event("route/remote", { origin = origin, stanza = stanza, from_host = from_host, to_host = host }); stanza.attr.xmlns = xmlns; -- reset - else - core_route_stanza(hosts[from_host], st.error_reply(stanza, "cancel", "not-allowed", "Communication with remote servers is not allowed")); + if not routed then + log("debug", "... no, just kidding."); + if stanza.attr.type == "error" or (stanza.name == "iq" and stanza.attr.type == "result") then return; end + core_route_stanza(host_session, st.error_reply(stanza, "cancel", "not-allowed", "Communication with remote domains is not enabled")); + end end - elseif origin.type == "component" or origin.type == "local" then - -- Route via s2s for components and modules - log("debug", "Routing outgoing stanza for %s to %s", from_host, host); - send_s2s(from_host, host, stanza); - else - log("warn", "received %s stanza from unhandled connection type: %s", tostring(stanza.name), tostring(origin.type)); end end +prosody.core_process_stanza = core_process_stanza; +prosody.core_post_stanza = core_post_stanza; +prosody.core_route_stanza = core_route_stanza; diff --git a/core/storagemanager.lua b/core/storagemanager.lua index 43409960..5674ff32 100644 --- a/core/storagemanager.lua +++ b/core/storagemanager.lua @@ -1,5 +1,5 @@ -local error, type = error, type; +local error, type, pairs = error, type, pairs; local setmetatable = setmetatable; local config = require "core.configmanager"; @@ -9,55 +9,57 @@ local multitable = require "util.multitable"; local hosts = hosts; local log = require "util.logger".init("storagemanager"); -local olddm = {}; -- maintain old datamanager, for backwards compatibility -for k,v in pairs(datamanager) do olddm[k] = v; end local prosody = prosody; module("storagemanager") -local default_driver_mt = { name = "internal" }; -default_driver_mt.__index = default_driver_mt; -function default_driver_mt:open(store) - return setmetatable({ host = self.host, store = store }, default_driver_mt); -end -function default_driver_mt:get(user) return olddm.load(user, self.host, self.store); end -function default_driver_mt:set(user, data) return olddm.store(user, self.host, self.store, data); end +local olddm = {}; -- maintain old datamanager, for backwards compatibility +for k,v in pairs(datamanager) do olddm[k] = v; end +_M.olddm = olddm; + +local null_storage_method = function () return false, "no data storage active"; end +local null_storage_driver = setmetatable( + { + name = "null", + open = function (self) return self; end + }, { + __index = function (self, method) + return null_storage_method; + end + } +); local stores_available = multitable.new(); function initialize_host(host) local host_session = hosts[host]; - host_session.events.add_handler("item-added/data-driver", function (event) + host_session.events.add_handler("item-added/storage-provider", function (event) local item = event.item; stores_available:set(host, item.name, item); end); - - host_session.events.add_handler("item-removed/data-driver", function (event) + + host_session.events.add_handler("item-removed/storage-provider", function (event) local item = event.item; stores_available:set(host, item.name, nil); end); end prosody.events.add_handler("host-activated", initialize_host, 101); -local function load_driver(host, driver_name) - if not driver_name then - return; +function load_driver(host, driver_name) + if driver_name == "null" then + return null_storage_driver; end local driver = stores_available:get(host, driver_name); if driver then return driver; end - if driver_name ~= "internal" then - local ok, err = modulemanager.load(host, "storage_"..driver_name); - if not ok then - log("error", "Failed to load storage driver plugin %s: %s", driver_name, err); - end - return stores_available:get(host, driver_name); - else - return setmetatable({host = host}, default_driver_mt); + local ok, err = modulemanager.load(host, "storage_"..driver_name); + if not ok then + log("error", "Failed to load storage driver plugin %s on %s: %s", driver_name, host, err); end + return stores_available:get(host, driver_name); end -function open(host, store, typ) - local storage = config.get(host, "core", "storage"); +function get_driver(host, store) + local storage = config.get(host, "storage"); local driver_name; local option_type = type(storage); if option_type == "string" then @@ -65,38 +67,69 @@ function open(host, store, typ) elseif option_type == "table" then driver_name = storage[store]; end - + if not driver_name then + driver_name = config.get(host, "default_storage") or "internal"; + end + local driver = load_driver(host, driver_name); if not driver then - driver_name = config.get(host, "core", "default_storage"); - driver = load_driver(host, driver_name); - if not driver then - if driver_name or (type(storage) == "string" - or type(storage) == "table" and storage[store]) then - log("warn", "Falling back to default driver for %s storage on %s", store, host); - end - driver_name = "internal"; - driver = load_driver(host, driver_name); - end + log("warn", "Falling back to null driver for %s storage on %s", store, host); + driver_name = "null"; + driver = null_storage_driver; end - + return driver, driver_name; +end + +function open(host, store, typ) + local driver, driver_name = get_driver(host, store); local ret, err = driver:open(store, typ); if not ret then if err == "unsupported-store" then - log("debug", "Storage driver %s does not support store %s (%s), falling back to internal driver", - driver_name, store, typ); - ret = setmetatable({ host = host, store = store }, default_driver_mt); -- default to default driver + log("debug", "Storage driver %s does not support store %s (%s), falling back to null driver", + driver_name, store, typ or "<nil>"); + ret = null_storage_driver; err = nil; end end return ret, err; end +function purge(user, host) + local storage = config.get(host, "storage"); + if type(storage) == "table" then + -- multiple storage backends in use that we need to purge + local purged = {}; + for store, driver in pairs(storage) do + if not purged[driver] then + purged[driver] = get_driver(host, store):purge(user); + end + end + end + get_driver(host):purge(user); -- and the default driver + + olddm.purge(user, host); -- COMPAT list stores, like offline messages end up in the old datamanager + + return true; +end + function datamanager.load(username, host, datastore) return open(host, datastore):get(username); end function datamanager.store(username, host, datastore, data) return open(host, datastore):set(username, data); end +function datamanager.users(host, datastore, typ) + local driver = open(host, datastore, typ); + if not driver.users then + return function() log("warn", "storage driver %s does not support listing users", driver.name) end + end + return driver:users(); +end +function datamanager.stores(username, host, typ) + return get_driver(host):stores(username, typ); +end +function datamanager.purge(username, host) + return purge(username, host); +end return _M; diff --git a/core/usermanager.lua b/core/usermanager.lua index 2e64af8c..4ac288a4 100644 --- a/core/usermanager.lua +++ b/core/usermanager.lua @@ -11,9 +11,11 @@ local log = require "util.logger".init("usermanager"); local type = type; local ipairs = ipairs; local jid_bare = require "util.jid".bare; +local jid_prep = require "util.jid".prep; local config = require "core.configmanager"; local hosts = hosts; local sasl_new = require "util.sasl".new; +local storagemanager = require "core.storagemanager"; local prosody = _G.prosody; @@ -24,21 +26,28 @@ local default_provider = "internal_plain"; module "usermanager" function new_null_provider() - local function dummy() end; + local function dummy() return nil, "method not implemented"; end; local function dummy_get_sasl_handler() return sasl_new(nil, {}); end - return setmetatable({name = "null", get_sasl_handler = dummy_get_sasl_handler}, { __index = function() return dummy; end }); + return setmetatable({name = "null", get_sasl_handler = dummy_get_sasl_handler}, { + __index = function(self, method) return dummy; end + }); end +local provider_mt = { __index = new_null_provider() }; + function initialize_host(host) local host_session = hosts[host]; if host_session.type ~= "local" then return; end - + host_session.events.add_handler("item-added/auth-provider", function (event) local provider = event.item; - local auth_provider = config.get(host, "core", "authentication") or default_provider; - if config.get(host, "core", "anonymous_login") then auth_provider = "anonymous"; end -- COMPAT 0.7 + local auth_provider = config.get(host, "authentication") or default_provider; + if config.get(host, "anonymous_login") then + log("error", "Deprecated config option 'anonymous_login'. Use authentication = 'anonymous' instead."); + auth_provider = "anonymous"; + end -- COMPAT 0.7 if provider.name == auth_provider then - host_session.users = provider; + host_session.users = setmetatable(provider, provider_mt); end if host_session.users ~= nil and host_session.users.name ~= nil then log("debug", "host '%s' now set to use user provider '%s'", host, host_session.users.name); @@ -51,8 +60,8 @@ function initialize_host(host) end end); host_session.users = new_null_provider(); -- Start with the default usermanager provider - local auth_provider = config.get(host, "core", "authentication") or default_provider; - if config.get(host, "core", "anonymous_login") then auth_provider = "anonymous"; end -- COMPAT 0.7 + local auth_provider = config.get(host, "authentication") or default_provider; + if config.get(host, "anonymous_login") then auth_provider = "anonymous"; end -- COMPAT 0.7 if auth_provider ~= "null" then modulemanager.load(host, "auth_"..auth_provider); end @@ -79,8 +88,19 @@ function create_user(username, password, host) return hosts[host].users.create_user(username, password); end -function get_sasl_handler(host) - return hosts[host].users.get_sasl_handler(); +function delete_user(username, host) + local ok, err = hosts[host].users.delete_user(username); + if not ok then return nil, err; end + prosody.events.fire_event("user-deleted", { username = username, host = host }); + return storagemanager.purge(username, host); +end + +function users(host) + return hosts[host].users.users(); +end + +function get_sasl_handler(host, session) + return hosts[host].users.get_sasl_handler(session); end function get_provider(host) @@ -88,17 +108,20 @@ function get_provider(host) end function is_admin(jid, host) + if host and not hosts[host] then return false; end + if type(jid) ~= "string" then return false; end + local is_admin; jid = jid_bare(jid); host = host or "*"; - - local host_admins = config.get(host, "core", "admins"); - local global_admins = config.get("*", "core", "admins"); - + + local host_admins = config.get(host, "admins"); + local global_admins = config.get("*", "admins"); + if host_admins and host_admins ~= global_admins then if type(host_admins) == "table" then for _,admin in ipairs(host_admins) do - if admin == jid then + if jid_prep(admin) == jid then is_admin = true; break; end @@ -107,11 +130,11 @@ function is_admin(jid, host) log("error", "Option 'admins' for host '%s' is not a list", host); end end - + if not is_admin and global_admins then if type(global_admins) == "table" then for _,admin in ipairs(global_admins) do - if admin == jid then + if jid_prep(admin) == jid then is_admin = true; break; end @@ -120,7 +143,7 @@ function is_admin(jid, host) log("error", "Global option 'admins' is not a list"); end end - + -- Still not an admin, check with auth provider if not is_admin and host ~= "*" and hosts[host].users and hosts[host].users.is_admin then is_admin = hosts[host].users.is_admin(jid); diff --git a/fallbacks/bit.lua b/fallbacks/bit.lua index 2482c473..28dca4e6 100644 --- a/fallbacks/bit.lua +++ b/fallbacks/bit.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- diff --git a/fallbacks/lxp.lua b/fallbacks/lxp.lua index 6d3297d1..ac1c9a03 100644 --- a/fallbacks/lxp.lua +++ b/fallbacks/lxp.lua @@ -61,7 +61,7 @@ local function parser(data, handlers, ns_separator) while #data == 0 do data = coroutine.yield(); end return data:sub(1,1); end - + local ns = { xml = "http://www.w3.org/XML/1998/namespace" }; ns.__index = ns; local function apply_ns(name, dodefault) @@ -100,7 +100,7 @@ local function parser(data, handlers, ns_separator) ns = getmetatable(ns); return tag; end - + while true do if peek() == "<" then local elem = read_until(">"):sub(2,-2); diff --git a/net/adns.lua b/net/adns.lua index 2f7b6804..08421f77 100644 --- a/net/adns.lua +++ b/net/adns.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -64,21 +64,25 @@ function new_async_socket(sock, resolver) if resolver.socketset[conn] == resolver.best_server and resolver.best_server == #servers then log("error", "Exhausted all %d configured DNS servers, next lookup will try %s again", #servers, servers[1]); end - + resolver:servfail(conn); -- Let the magic commence end end - handler = server.wrapclient(sock, "dns", 53, listener); + handler, err = server.wrapclient(sock, "dns", 53, listener); if not handler then - log("warn", "handler is nil"); + return nil, err; end - + handler.settimeout = function () end handler.setsockname = function (_, ...) return sock:setsockname(...); end handler.setpeername = function (_, ...) peername = (...); local ret = sock:setpeername(...); _:set_send(dummy_send); return ret; end handler.connect = function (_, ...) return sock:connect(...) end --handler.send = function (_, data) _:write(data); return _.sendbuffer and _.sendbuffer(); end - handler.send = function (_, data) return sock:send(data); end + handler.send = function (_, data) + local getpeername = sock.getpeername; + log("debug", "Sending DNS query to %s", (getpeername and getpeername(sock)) or "<unconnected>"); + return sock:send(data); + end return handler; end diff --git a/net/connlisteners.lua b/net/connlisteners.lua index e13f85de..99ddc720 100644 --- a/net/connlisteners.lua +++ b/net/connlisteners.lua @@ -1,69 +1,15 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- +-- COMPAT w/pre-0.9 +local log = require "util.logger".init("net.connlisteners"); +local traceback = debug.traceback; +module "httpserver" - -local listeners_dir = (CFG_SOURCEDIR or ".").."/net/"; -local server = require "net.server"; -local log = require "util.logger".init("connlisteners"); -local tostring = tostring; - -local dofile, pcall, error = - dofile, pcall, error - -module "connlisteners" - -local listeners = {}; - -function register(name, listener) - if listeners[name] and listeners[name] ~= listener then - log("debug", "Listener %s is already registered, not registering any more", name); - return false; - end - listeners[name] = listener; - log("debug", "Registered connection listener %s", name); - return true; +function fail() + log("error", "Attempt to use legacy connlisteners API. For more info see http://prosody.im/doc/developers/network"); + log("error", "Legacy connlisteners API usage, %s", traceback("", 2)); end -function deregister(name) - listeners[name] = nil; -end - -function get(name) - local h = listeners[name]; - if not h then - local ok, ret = pcall(dofile, listeners_dir..name:gsub("[^%w%-]", "_").."_listener.lua"); - if not ok then - log("error", "Error while loading listener '%s': %s", tostring(name), tostring(ret)); - return nil, ret; - end - h = listeners[name]; - end - return h; -end - -function start(name, udata) - local h, err = get(name); - if not h then - error("No such connection module: "..name.. (err and (" ("..err..")") or ""), 0); - end - - local interface = (udata and udata.interface) or h.default_interface or "*"; - local port = (udata and udata.port) or h.default_port or error("Can't start listener "..name.." because no port was specified, and it has no default port", 0); - local mode = (udata and udata.mode) or h.default_mode or 1; - local ssl = (udata and udata.ssl) or nil; - local autossl = udata and udata.type == "ssl"; - - if autossl and not ssl then - return nil, "no ssl context"; - end - - return server.addserver(interface, port, h, mode, autossl and ssl or nil); -end +register, deregister = fail, fail; +get, start = fail, fail, epic_fail; return _M; diff --git a/net/dns.lua b/net/dns.lua index 61fb62e8..bd5c260e 100644 --- a/net/dns.lua +++ b/net/dns.lua @@ -14,6 +14,7 @@ local socket = require "socket"; local timer = require "util.timer"; +local new_ip = require "util.ip".new_ip; local _, windows = pcall(require, "util.windows"); local is_windows = (_ and windows) or os.getenv("WINDIR"); @@ -158,8 +159,6 @@ resolver.__index = resolver; resolver.timeout = default_timeout; -local SRV_tostring; - local function default_rr_tostring(rr) local rr_val = rr.type and rr[rr.type:lower()]; if type(rr_val) ~= "string" then @@ -170,8 +169,13 @@ end local special_tostrings = { LOC = resolver.LOC_tostring; - MX = function (rr) return string.format('%2i %s', rr.pref, rr.mx); end; - SRV = SRV_tostring; + MX = function (rr) + return string.format('%2i %s', rr.pref, rr.mx); + end; + SRV = function (rr) + local s = rr.srv; + return string.format('%5d %5d %5d %s', s.priority, s.weight, s.port, s.target); + end; }; local rr_metatable = {}; -- - - - - - - - - - - - - - - - - - - rr_metatable @@ -220,7 +224,7 @@ end function dns.random(...) -- - - - - - - - - - - - - - - - - - - dns.random - math.randomseed(math.floor(10000*socket.gettime())); + math.randomseed(math.floor(10000*socket.gettime()) % 0x100000000); dns.random = math.random; return dns.random(...); end @@ -355,6 +359,7 @@ function resolver:name() -- - - - - - - - - - - - - - - - - - - - - - name local remember, pointers = nil, 0; local len = self:byte(); local n = {}; + if len == 0 then return "." end -- Root label while len > 0 do if len >= 0xc0 then -- name is "compressed" pointers = pointers + 1; @@ -386,6 +391,25 @@ function resolver:A(rr) -- - - - - - - - - - - - - - - - - - - - - - - - A rr.a = string.format('%i.%i.%i.%i', b1, b2, b3, b4); end +function resolver:AAAA(rr) + local addr = {}; + for i = 1, rr.rdlength, 2 do + local b1, b2 = self:byte(2); + table.insert(addr, ("%02x%02x"):format(b1, b2)); + end + addr = table.concat(addr, ":"):gsub("%f[%x]0+(%x)","%1"); + local zeros = {}; + for item in addr:gmatch(":[0:]+:") do + table.insert(zeros, item) + end + if #zeros == 0 then + rr.aaaa = addr; + return + elseif #zeros > 1 then + table.sort(zeros, function(a, b) return #a > #b end); + end + rr.aaaa = addr:gsub(zeros[1], "::", 1):gsub("^0::", "::"):gsub("::0$", "::"); +end function resolver:CNAME(rr) -- - - - - - - - - - - - - - - - - - - - CNAME rr.cname = self:name(); @@ -475,14 +499,8 @@ function resolver:PTR(rr) rr.ptr = self:name(); end -function SRV_tostring(rr) -- - - - - - - - - - - - - - - - - - SRV_tostring - local s = rr.srv; - return string.format( '%5d %5d %5d %s', s.priority, s.weight, s.port, s.target ); -end - - function resolver:TXT(rr) -- - - - - - - - - - - - - - - - - - - - - - TXT - rr.txt = self:sub (rr.rdlength); + rr.txt = self:sub (self:byte()); end @@ -532,6 +550,7 @@ function resolver:decode(packet, force) -- - - - - - - - - - - - - - decode if not force then if not self.active[response.header.id] or not self.active[response.header.id][response.question.raw] then + self.active[response.header.id] = nil; return nil; end end @@ -579,11 +598,12 @@ function resolver:adddefaultnameservers() -- - - - - adddefaultnameservers if resolv_conf then for line in resolv_conf:lines() do line = line:gsub("#.*$", "") - :match('^%s*nameserver%s+(.*)%s*$'); + :match('^%s*nameserver%s+([%x:%.]*)%s*$'); if line then - line:gsub("%f[%d.](%d+%.%d+%.%d+%.%d+)%f[^%d.]", function (address) - self:addnameserver(address) - end); + local ip = new_ip(line); + if ip then + self:addnameserver(ip.addr); + end end end end @@ -603,15 +623,20 @@ function resolver:getsocket(servernum) -- - - - - - - - - - - - - getsocket if sock then return sock; end local err; - sock, err = socket.udp(); + local peer = self.server[servernum]; + if peer:find(":") then + sock, err = socket.udp6(); + else + sock, err = socket.udp(); + end + if sock and self.socket_wrapper then sock, err = self.socket_wrapper(sock, self); end if not sock then return nil, err; end - if self.socket_wrapper then sock = self.socket_wrapper(sock, self); end sock:settimeout(0); -- todo: attempt to use a random port, fallback to 0 sock:setsockname('*', 0); - sock:setpeername(self.server[servernum], 53); + sock:setpeername(peer, 53); self.socket[servernum] = sock; self.socketset[sock] = servernum; return sock; @@ -625,6 +650,7 @@ function resolver:voidsocket(sock) self.socket[self.socketset[sock]] = nil; self.socketset[sock] = nil; end + sock:close(); end function resolver:socket_wrapper_set(func) -- - - - - - - socket_wrapper_set @@ -689,7 +715,7 @@ function resolver:purge(soft) -- - - - - - - - - - - - - - - - - - - purge end end end - else self.cache = {}; end + else self.cache = setmetatable({}, cache_metatable); end end @@ -727,7 +753,7 @@ function resolver:query(qname, qtype, qclass) -- - - - - - - - - - -- query return nil, err; end conn:send (o.packet) - + if timer and self.timeout then local num_servers = #self.server; local i = 1; @@ -779,6 +805,9 @@ function resolver:servfail(sock) end end end + if next(queries) == nil then + self.active[id] = nil; + end end if num == self.best_server then @@ -820,7 +849,7 @@ function resolver:receive(rset) -- - - - - - - - - - - - - - - - - receive -- retire the query local queries = self.active[response.header.id]; queries[response.question.raw] = nil; - + if not next(queries) then self.active[response.header.id] = nil; end if not next(self.active) then self:closeall(); end @@ -835,6 +864,7 @@ function resolver:receive(rset) -- - - - - - - - - - - - - - - - - receive set(self.wanted, q.class, q.type, q.name, nil); end end + end end end @@ -1049,6 +1079,10 @@ function dns.settimeout(...) return _resolver:settimeout(...); end +function dns.cache() + return _resolver.cache; +end + function dns.socket_wrapper_set(...) -- - - - - - - - - socket_wrapper_set return _resolver:socket_wrapper_set(...); end diff --git a/net/http.lua b/net/http.lua index 6c8e0a68..5ec3163c 100644 --- a/net/http.lua +++ b/net/http.lua @@ -1,64 +1,94 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- - local socket = require "socket" -local mime = require "mime" +local b64 = require "util.encodings".base64.encode; local url = require "socket.url" -local httpstream_new = require "util.httpstream".new; +local httpstream_new = require "net.http.parser".new; +local util_http = require "util.http"; -local server = require "net.server" +local ssl_available = pcall(require, "ssl"); -local connlisteners_get = require "net.connlisteners".get; -local listener = connlisteners_get("httpclient") or error("No httpclient listener!"); +local server = require "net.server" local t_insert, t_concat = table.insert, table.concat; -local pairs, ipairs = pairs, ipairs; -local tonumber, tostring, xpcall, select, debug_traceback, char, format = - tonumber, tostring, xpcall, select, debug.traceback, string.char, string.format; +local pairs = pairs; +local tonumber, tostring, xpcall, select, traceback = + tonumber, tostring, xpcall, select, debug.traceback; local log = require "util.logger".init("http"); module "http" -function urlencode(s) return s and (s:gsub("%W", function (c) return format("%%%02x", c:byte()); end)); end -function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return char(tonumber(c,16)); end)); end +local requests = {}; -- Open requests -local function _formencodepart(s) - return s and (s:gsub("%W", function (c) - if c ~= " " then - return format("%%%02x", c:byte()); - else - return "+"; - end - end)); +local listener = { default_port = 80, default_mode = "*a" }; + +function listener.onconnect(conn) + local req = requests[conn]; + -- Send the request + local request_line = { req.method or "GET", " ", req.path, " HTTP/1.1\r\n" }; + if req.query then + t_insert(request_line, 4, "?"..req.query); + end + + conn:write(t_concat(request_line)); + local t = { [2] = ": ", [4] = "\r\n" }; + for k, v in pairs(req.headers) do + t[1], t[3] = k, v; + conn:write(t_concat(t)); + end + conn:write("\r\n"); + + if req.body then + conn:write(req.body); + end +end + +function listener.onincoming(conn, data) + local request = requests[conn]; + + if not request then + log("warn", "Received response from connection %s with no request attached!", tostring(conn)); + return; + end + + if data and request.reader then + request:reader(data); + end end -function formencode(form) - local result = {}; - for _, field in ipairs(form) do - t_insert(result, _formencodepart(field.name).."=".._formencodepart(field.value)); + +function listener.ondisconnect(conn, err) + local request = requests[conn]; + if request and request.conn then + request:reader(nil, err); end - return t_concat(result, "&"); + requests[conn] = nil; end -local function request_reader(request, data, startpos) +local function request_reader(request, data, err) if not request.parser then - local function success_cb(r) + local function error_cb(reason) if request.callback then - for k,v in pairs(r) do request[k] = v; end - request.callback(r.body, r.code, request); + request.callback(reason or "connection-closed", 0, request); request.callback = nil; end destroy_request(request); end - local function error_cb(r) + + if not data then + error_cb(err); + return; + end + + local function success_cb(r) if request.callback then - request.callback(r or "connection-closed", 0, request); + request.callback(r.body, r.code, r, request); request.callback = nil; end destroy_request(request); @@ -71,82 +101,86 @@ local function request_reader(request, data, startpos) request.parser:feed(data); end -local function handleerr(err) log("error", "Traceback[http]: %s: %s", tostring(err), debug_traceback()); end +local function handleerr(err) log("error", "Traceback[http]: %s", traceback(tostring(err), 2)); end function request(u, ex, callback) local req = url.parse(u); - + if not (req and req.host) then callback(nil, 0, req); return nil, "invalid-url"; end - + if not req.path then req.path = "/"; end - - local custom_headers, body; - local default_headers = { ["Host"] = req.host, ["User-Agent"] = "Prosody XMPP Server" } - - + + local method, headers, body; + + local host, port = req.host, req.port; + local host_header = host; + if (port == "80" and req.scheme == "http") + or (port == "443" and req.scheme == "https") then + port = nil; + elseif port then + host_header = host_header..":"..port; + end + + headers = { + ["Host"] = host_header; + ["User-Agent"] = "Prosody XMPP Server"; + }; + if req.userinfo then - default_headers["Authorization"] = "Basic "..mime.b64(req.userinfo); + headers["Authorization"] = "Basic "..b64(req.userinfo); end - + if ex then - custom_headers = ex.headers; req.onlystatus = ex.onlystatus; body = ex.body; if body then - req.method = "POST "; - default_headers["Content-Length"] = tostring(#body); - default_headers["Content-Type"] = "application/x-www-form-urlencoded"; + method = "POST"; + headers["Content-Length"] = tostring(#body); + headers["Content-Type"] = "application/x-www-form-urlencoded"; + end + if ex.method then method = ex.method; end + if ex.headers then + for k, v in pairs(ex.headers) do + headers[k] = v; + end end - if ex.method then req.method = ex.method; end end - - req.handler, req.conn = server.wrapclient(socket.tcp(), req.host, req.port or 80, listener, "*a"); - req.write = function (...) return req.handler:write(...); end - req.conn:settimeout(0); - local ok, err = req.conn:connect(req.host, req.port or 80); + + -- Attach to request object + req.method, req.headers, req.body = method, headers, body; + + local using_https = req.scheme == "https"; + if using_https and not ssl_available then + error("SSL not available, unable to contact https URL"); + end + local port_number = port and tonumber(port) or (using_https and 443 or 80); + + -- Connect the socket, and wrap it with net.server + local conn = socket.tcp(); + conn:settimeout(10); + local ok, err = conn:connect(host, port_number); if not ok and err ~= "timeout" then callback(nil, 0, req); return nil, err; end - - local request_line = { req.method or "GET", " ", req.path, " HTTP/1.1\r\n" }; - - if req.query then - t_insert(request_line, 4, "?"); - t_insert(request_line, 5, req.query); - end - - req.write(t_concat(request_line)); - local t = { [2] = ": ", [4] = "\r\n" }; - if custom_headers then - for k, v in pairs(custom_headers) do - t[1], t[3] = k, v; - req.write(t_concat(t)); - default_headers[k] = nil; - end - end - - for k, v in pairs(default_headers) do - t[1], t[3] = k, v; - req.write(t_concat(t)); - default_headers[k] = nil; - end - req.write("\r\n"); - - if body then - req.write(body); + + local sslctx = false; + if using_https then + sslctx = ex and ex.sslctx or { mode = "client", protocol = "sslv23", options = { "no_sslv2" } }; end - - req.callback = function (content, code, request) log("debug", "Calling callback, status %s", code or "---"); return select(2, xpcall(function () return callback(content, code, request) end, handleerr)); end + + req.handler, req.conn = server.wrapclient(conn, host, port_number, listener, "*a", sslctx); + req.write = function (...) return req.handler:write(...); end + + req.callback = function (content, code, request, response) log("debug", "Calling callback, status %s", code or "---"); return select(2, xpcall(function () return callback(content, code, request, response) end, handleerr)); end req.reader = request_reader; req.state = "status"; - - listener.register_request(req.handler, req); + requests[req.handler] = req; return req; end @@ -154,10 +188,13 @@ function destroy_request(request) if request.conn then request.conn = nil; request.handler:close() - listener.ondisconnect(request.handler, "closed"); end end -_M.urlencode = urlencode; +local urlencode, urldecode = util_http.urlencode, util_http.urldecode; +local formencode, formdecode = util_http.formencode, util_http.formdecode; + +_M.urlencode, _M.urldecode = urlencode, urldecode; +_M.formencode, _M.formdecode = formencode, formdecode; return _M; diff --git a/net/http/codes.lua b/net/http/codes.lua new file mode 100644 index 00000000..0cadd079 --- /dev/null +++ b/net/http/codes.lua @@ -0,0 +1,67 @@ + +local response_codes = { + -- Source: http://www.iana.org/assignments/http-status-codes + -- s/^\(\d*\)\s*\(.*\S\)\s*\[RFC.*\]\s*$/^I["\1"] = "\2"; + [100] = "Continue"; + [101] = "Switching Protocols"; + [102] = "Processing"; + + [200] = "OK"; + [201] = "Created"; + [202] = "Accepted"; + [203] = "Non-Authoritative Information"; + [204] = "No Content"; + [205] = "Reset Content"; + [206] = "Partial Content"; + [207] = "Multi-Status"; + [208] = "Already Reported"; + [226] = "IM Used"; + + [300] = "Multiple Choices"; + [301] = "Moved Permanently"; + [302] = "Found"; + [303] = "See Other"; + [304] = "Not Modified"; + [305] = "Use Proxy"; + -- The 306 status code was used in a previous version of [RFC2616], is no longer used, and the code is reserved. + [307] = "Temporary Redirect"; + + [400] = "Bad Request"; + [401] = "Unauthorized"; + [402] = "Payment Required"; + [403] = "Forbidden"; + [404] = "Not Found"; + [405] = "Method Not Allowed"; + [406] = "Not Acceptable"; + [407] = "Proxy Authentication Required"; + [408] = "Request Timeout"; + [409] = "Conflict"; + [410] = "Gone"; + [411] = "Length Required"; + [412] = "Precondition Failed"; + [413] = "Request Entity Too Large"; + [414] = "Request-URI Too Long"; + [415] = "Unsupported Media Type"; + [416] = "Requested Range Not Satisfiable"; + [417] = "Expectation Failed"; + [418] = "I'm a teapot"; + [422] = "Unprocessable Entity"; + [423] = "Locked"; + [424] = "Failed Dependency"; + -- The 425 status code is reserved for the WebDAV advanced collections expired proposal [RFC2817] + [426] = "Upgrade Required"; + + [500] = "Internal Server Error"; + [501] = "Not Implemented"; + [502] = "Bad Gateway"; + [503] = "Service Unavailable"; + [504] = "Gateway Timeout"; + [505] = "HTTP Version Not Supported"; + [506] = "Variant Also Negotiates"; -- Experimental + [507] = "Insufficient Storage"; + [508] = "Loop Detected"; + [510] = "Not Extended"; +}; + +for k,v in pairs(response_codes) do response_codes[k] = k.." "..v; end +return setmetatable(response_codes, { __index = function(t, k) return k.." Unassigned"; end }) diff --git a/net/http/parser.lua b/net/http/parser.lua new file mode 100644 index 00000000..f9e6cea0 --- /dev/null +++ b/net/http/parser.lua @@ -0,0 +1,160 @@ +local tonumber = tonumber; +local assert = assert; +local url_parse = require "socket.url".parse; +local urldecode = require "util.http".urldecode; + +local function preprocess_path(path) + path = urldecode((path:gsub("//+", "/"))); + if path:sub(1,1) ~= "/" then + path = "/"..path; + end + local level = 0; + for component in path:gmatch("([^/]+)/") do + if component == ".." then + level = level - 1; + elseif component ~= "." then + level = level + 1; + end + if level < 0 then + return nil; + end + end + return path; +end + +local httpstream = {}; + +function httpstream.new(success_cb, error_cb, parser_type, options_cb) + local client = true; + if not parser_type or parser_type == "server" then client = false; else assert(parser_type == "client", "Invalid parser type"); end + local buf = ""; + local chunked, chunk_size, chunk_start; + local state = nil; + local packet; + local len; + local have_body; + local error; + return { + feed = function(self, data) + if error then return nil, "parse has failed"; end + if not data then -- EOF + if state and client and not len then -- reading client body until EOF + packet.body = buf; + success_cb(packet); + elseif buf ~= "" then -- unexpected EOF + error = true; return error_cb(); + end + return; + end + buf = buf..data; + while #buf > 0 do + if state == nil then -- read request + local index = buf:find("\r\n\r\n", nil, true); + if not index then return; end -- not enough data + local method, path, httpversion, status_code, reason_phrase; + local first_line; + local headers = {}; + for line in buf:sub(1,index+1):gmatch("([^\r\n]+)\r\n") do -- parse request + if first_line then + local key, val = line:match("^([^%s:]+): *(.*)$"); + if not key then error = true; return error_cb("invalid-header-line"); end -- TODO handle multi-line and invalid headers + key = key:lower(); + headers[key] = headers[key] and headers[key]..","..val or val; + else + first_line = line; + if client then + httpversion, status_code, reason_phrase = line:match("^HTTP/(1%.[01]) (%d%d%d) (.*)$"); + status_code = tonumber(status_code); + if not status_code then error = true; return error_cb("invalid-status-line"); end + have_body = not + ( (options_cb and options_cb().method == "HEAD") + or (status_code == 204 or status_code == 304 or status_code == 301) + or (status_code >= 100 and status_code < 200) ); + else + method, path, httpversion = line:match("^(%w+) (%S+) HTTP/(1%.[01])$"); + if not method then error = true; return error_cb("invalid-status-line"); end + end + end + end + if not first_line then error = true; return error_cb("invalid-status-line"); end + chunked = have_body and headers["transfer-encoding"] == "chunked"; + len = tonumber(headers["content-length"]); -- TODO check for invalid len + if client then + -- FIXME handle '100 Continue' response (by skipping it) + if not have_body then len = 0; end + packet = { + code = status_code; + httpversion = httpversion; + headers = headers; + body = have_body and "" or nil; + -- COMPAT the properties below are deprecated + responseversion = httpversion; + responseheaders = headers; + }; + else + local parsed_url; + if path:byte() == 47 then -- starts with / + local _path, _query = path:match("([^?]*).?(.*)"); + if _query == "" then _query = nil; end + parsed_url = { path = _path, query = _query }; + else + parsed_url = url_parse(path); + if not(parsed_url and parsed_url.path) then error = true; return error_cb("invalid-url"); end + end + path = preprocess_path(parsed_url.path); + headers.host = parsed_url.host or headers.host; + + len = len or 0; + packet = { + method = method; + url = parsed_url; + path = path; + httpversion = httpversion; + headers = headers; + body = nil; + }; + end + buf = buf:sub(index + 4); + state = true; + end + if state then -- read body + if client then + if chunked then + if not buf:find("\r\n", nil, true) then + return; + end -- not enough data + if not chunk_size then + chunk_size, chunk_start = buf:match("^(%x+)[^\r\n]*\r\n()"); + chunk_size = chunk_size and tonumber(chunk_size, 16); + if not chunk_size then error = true; return error_cb("invalid-chunk-size"); end + end + if chunk_size == 0 and buf:find("\r\n\r\n", chunk_start-2, true) then + state, chunk_size = nil, nil; + buf = buf:gsub("^.-\r\n\r\n", ""); -- This ensure extensions and trailers are stripped + success_cb(packet); + elseif #buf - chunk_start + 2 >= chunk_size then -- we have a chunk + packet.body = packet.body..buf:sub(chunk_start, chunk_start + (chunk_size-1)); + buf = buf:sub(chunk_start + chunk_size + 2); + chunk_size, chunk_start = nil, nil; + else -- Partial chunk remaining + break; + end + elseif len and #buf >= len then + packet.body, buf = buf:sub(1, len), buf:sub(len + 1); + state = nil; success_cb(packet); + else + break; + end + elseif #buf >= len then + packet.body, buf = buf:sub(1, len), buf:sub(len + 1); + state = nil; success_cb(packet); + else + break; + end + end + end + end; + }; +end + +return httpstream; diff --git a/net/http/server.lua b/net/http/server.lua new file mode 100644 index 00000000..5961169f --- /dev/null +++ b/net/http/server.lua @@ -0,0 +1,303 @@ + +local t_insert, t_remove, t_concat = table.insert, table.remove, table.concat; +local parser_new = require "net.http.parser".new; +local events = require "util.events".new(); +local addserver = require "net.server".addserver; +local log = require "util.logger".init("http.server"); +local os_date = os.date; +local pairs = pairs; +local s_upper = string.upper; +local setmetatable = setmetatable; +local xpcall = xpcall; +local traceback = debug.traceback; +local tostring = tostring; +local codes = require "net.http.codes"; + +local _M = {}; + +local sessions = {}; +local listener = {}; +local hosts = {}; +local default_host; + +local function is_wildcard_event(event) + return event:sub(-2, -1) == "/*"; +end +local function is_wildcard_match(wildcard_event, event) + return wildcard_event:sub(1, -2) == event:sub(1, #wildcard_event-1); +end + +local recent_wildcard_events, max_cached_wildcard_events = {}, 10000; + +local event_map = events._event_map; +setmetatable(events._handlers, { + -- Called when firing an event that doesn't exist (but may match a wildcard handler) + __index = function (handlers, curr_event) + if is_wildcard_event(curr_event) then return; end -- Wildcard events cannot be fired + -- Find all handlers that could match this event, sort them + -- and then put the array into handlers[curr_event] (and return it) + local matching_handlers_set = {}; + local handlers_array = {}; + for event, handlers_set in pairs(event_map) do + if event == curr_event or + is_wildcard_event(event) and is_wildcard_match(event, curr_event) then + for handler, priority in pairs(handlers_set) do + matching_handlers_set[handler] = { (select(2, event:gsub("/", "%1"))), is_wildcard_event(event) and 0 or 1, priority }; + table.insert(handlers_array, handler); + end + end + end + if #handlers_array > 0 then + table.sort(handlers_array, function(b, a) + local a_score, b_score = matching_handlers_set[a], matching_handlers_set[b]; + for i = 1, #a_score do + if a_score[i] ~= b_score[i] then -- If equal, compare next score value + return a_score[i] < b_score[i]; + end + end + return false; + end); + else + handlers_array = false; + end + rawset(handlers, curr_event, handlers_array); + if not event_map[curr_event] then -- Only wildcard handlers match, if any + table.insert(recent_wildcard_events, curr_event); + if #recent_wildcard_events > max_cached_wildcard_events then + rawset(handlers, table.remove(recent_wildcard_events, 1), nil); + end + end + return handlers_array; + end; + __newindex = function (handlers, curr_event, handlers_array) + if handlers_array == nil + and is_wildcard_event(curr_event) then + -- Invalidate the indexes of all matching events + for event in pairs(handlers) do + if is_wildcard_match(curr_event, event) then + handlers[event] = nil; + end + end + end + rawset(handlers, curr_event, handlers_array); + end; +}); + +local handle_request; +local _1, _2, _3; +local function _handle_request() return handle_request(_1, _2, _3); end + +local last_err; +local function _traceback_handler(err) last_err = err; log("error", "Traceback[httpserver]: %s", traceback(tostring(err), 2)); end +events.add_handler("http-error", function (error) + return "Error processing request: "..codes[error.code]..". Check your error log for more information."; +end, -1); + +function listener.onconnect(conn) + local secure = conn:ssl() and true or nil; + local pending = {}; + local waiting = false; + local function process_next() + if waiting then log("debug", "can't process_next, waiting"); return; end + waiting = true; + while sessions[conn] and #pending > 0 do + local request = t_remove(pending); + --log("debug", "process_next: %s", request.path); + --handle_request(conn, request, process_next); + _1, _2, _3 = conn, request, process_next; + if not xpcall(_handle_request, _traceback_handler) then + conn:write("HTTP/1.0 500 Internal Server Error\r\n\r\n"..events.fire_event("http-error", { code = 500, private_message = last_err })); + conn:close(); + end + end + --log("debug", "ready for more"); + waiting = false; + end + local function success_cb(request) + --log("debug", "success_cb: %s", request.path); + if waiting then + log("error", "http connection handler is not reentrant: %s", request.path); + assert(false, "http connection handler is not reentrant"); + end + request.secure = secure; + t_insert(pending, request); + process_next(); + end + local function error_cb(err) + log("debug", "error_cb: %s", err or "<nil>"); + -- FIXME don't close immediately, wait until we process current stuff + -- FIXME if err, send off a bad-request response + sessions[conn] = nil; + conn:close(); + end + sessions[conn] = parser_new(success_cb, error_cb); +end + +function listener.ondisconnect(conn) + local open_response = conn._http_open_response; + if open_response and open_response.on_destroy then + open_response.finished = true; + open_response:on_destroy(); + end + sessions[conn] = nil; +end + +function listener.onincoming(conn, data) + sessions[conn]:feed(data); +end + +local headerfix = setmetatable({}, { + __index = function(t, k) + local v = "\r\n"..k:gsub("_", "-"):gsub("%f[%w].", s_upper)..": "; + t[k] = v; + return v; + end +}); + +function _M.hijack_response(response, listener) + error("TODO"); +end +function handle_request(conn, request, finish_cb) + --log("debug", "handler: %s", request.path); + local headers = {}; + for k,v in pairs(request.headers) do headers[k:gsub("-", "_")] = v; end + request.headers = headers; + request.conn = conn; + + local date_header = os_date('!%a, %d %b %Y %H:%M:%S GMT'); -- FIXME use + local conn_header = request.headers.connection; + conn_header = conn_header and ","..conn_header:gsub("[ \t]", ""):lower().."," or "" + local httpversion = request.httpversion + local persistent = conn_header:find(",keep-alive,", 1, true) + or (httpversion == "1.1" and not conn_header:find(",close,", 1, true)); + + local response_conn_header; + if persistent then + response_conn_header = "Keep-Alive"; + else + response_conn_header = httpversion == "1.1" and "close" or nil + end + + local response = { + request = request; + status_code = 200; + headers = { date = date_header, connection = response_conn_header }; + persistent = persistent; + conn = conn; + send = _M.send_response; + finish_cb = finish_cb; + }; + conn._http_open_response = response; + + local host = (request.headers.host or ""):match("[^:]+"); + + -- Some sanity checking + local err_code, err; + if not request.path then + err_code, err = 400, "Invalid path"; + elseif not hosts[host] then + if hosts[default_host] then + host = default_host; + elseif host then + err_code, err = 404, "Unknown host: "..host; + else + err_code, err = 400, "Missing or invalid 'Host' header"; + end + end + + if err then + response.status_code = err_code; + response:send(events.fire_event("http-error", { code = err_code, message = err })); + return; + end + + local event = request.method.." "..host..request.path:match("[^?]*"); + local payload = { request = request, response = response }; + --log("debug", "Firing event: %s", event); + local result = events.fire_event(event, payload); + if result ~= nil then + if result ~= true then + local body; + local result_type = type(result); + if result_type == "number" then + response.status_code = result; + if result >= 400 then + body = events.fire_event("http-error", { code = result }); + end + elseif result_type == "string" then + body = result; + elseif result_type == "table" then + for k, v in pairs(result) do + if k ~= "headers" then + response[k] = v; + else + for header_name, header_value in pairs(v) do + response.headers[header_name] = header_value; + end + end + end + end + response:send(body); + end + return; + end + + -- if handler not called, return 404 + response.status_code = 404; + response:send(events.fire_event("http-error", { code = 404 })); +end +function _M.send_response(response, body) + if response.finished then return; end + response.finished = true; + response.conn._http_open_response = nil; + + local status_line = "HTTP/"..response.request.httpversion.." "..(response.status or codes[response.status_code]); + local headers = response.headers; + body = body or response.body or ""; + headers.content_length = #body; + + local output = { status_line }; + for k,v in pairs(headers) do + t_insert(output, headerfix[k]..v); + end + t_insert(output, "\r\n\r\n"); + t_insert(output, body); + + response.conn:write(t_concat(output)); + if response.on_destroy then + response:on_destroy(); + response.on_destroy = nil; + end + if response.persistent then + response:finish_cb(); + else + response.conn:close(); + end +end +function _M.add_handler(event, handler, priority) + events.add_handler(event, handler, priority); +end +function _M.remove_handler(event, handler) + events.remove_handler(event, handler); +end + +function _M.listen_on(port, interface, ssl) + addserver(interface or "*", port, listener, "*a", ssl); +end +function _M.add_host(host) + hosts[host] = true; +end +function _M.remove_host(host) + hosts[host] = nil; +end +function _M.set_default_host(host) + default_host = host; +end +function _M.fire_event(event, ...) + return events.fire_event(event, ...); +end + +_M.listener = listener; +_M.codes = codes; +_M._events = events; +return _M; diff --git a/net/httpclient_listener.lua b/net/httpclient_listener.lua deleted file mode 100644 index dfa25062..00000000 --- a/net/httpclient_listener.lua +++ /dev/null @@ -1,44 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - -local log = require "util.logger".init("httpclient_listener"); - -local connlisteners_register = require "net.connlisteners".register; - -local requests = {}; -- Open requests -local buffers = {}; -- Buffers of partial lines - -local httpclient = { default_port = 80, default_mode = "*a" }; - -function httpclient.onincoming(conn, data) - local request = requests[conn]; - - if not request then - log("warn", "Received response from connection %s with no request attached!", tostring(conn)); - return; - end - - if data and request.reader then - request:reader(data); - end -end - -function httpclient.ondisconnect(conn, err) - local request = requests[conn]; - if request and err ~= "closed" then - request:reader(nil); - end - requests[conn] = nil; -end - -function httpclient.register_request(conn, req) - log("debug", "Attaching request %s to connection %s", tostring(req.id or req), tostring(conn)); - requests[conn] = req; -end - -connlisteners_register("httpclient", httpclient); diff --git a/net/httpserver.lua b/net/httpserver.lua index 4c1200ac..7d574788 100644 --- a/net/httpserver.lua +++ b/net/httpserver.lua @@ -1,226 +1,15 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - - -local socket = require "socket" -local server = require "net.server" -local url_parse = require "socket.url".parse; -local httpstream_new = require "util.httpstream".new; - -local connlisteners_start = require "net.connlisteners".start; -local connlisteners_get = require "net.connlisteners".get; -local listener; - -local t_insert, t_concat = table.insert, table.concat; -local s_match, s_gmatch = string.match, string.gmatch; -local tonumber, tostring, pairs, ipairs, type = tonumber, tostring, pairs, ipairs, type; - -local urlencode = function (s) return s and (s:gsub("%W", function (c) return string.format("%%%02x", c:byte()); end)); end - -local log = require "util.logger".init("httpserver"); - -local http_servers = {}; +-- COMPAT w/pre-0.9 +local log = require "util.logger".init("net.httpserver"); +local traceback = debug.traceback; module "httpserver" -local default_handler; - -local function expectbody(reqt) - return reqt.method == "POST"; -end - -local function send_response(request, response) - -- Write status line - local resp; - if response.body or response.headers then - local body = response.body and tostring(response.body); - log("debug", "Sending response to %s", request.id); - resp = { "HTTP/1.0 "..(response.status or "200 OK").."\r\n" }; - local h = response.headers; - if h then - for k, v in pairs(h) do - t_insert(resp, k..": "..v.."\r\n"); - end - end - if body and not (h and h["Content-Length"]) then - t_insert(resp, "Content-Length: "..#body.."\r\n"); - end - t_insert(resp, "\r\n"); - - if body and request.method ~= "HEAD" then - t_insert(resp, body); - end - request.write(t_concat(resp)); - else - -- Response we have is just a string (the body) - log("debug", "Sending 200 response to %s", request.id or "<none>"); - - local resp = "HTTP/1.0 200 OK\r\n" - .. "Connection: close\r\n" - .. "Content-Type: text/html\r\n" - .. "Content-Length: "..#response.."\r\n" - .. "\r\n" - .. response; - - request.write(resp); - end - if not request.stayopen then - request:destroy(); - end -end - -local function call_callback(request, err) - if request.handled then return; end - request.handled = true; - local callback = request.callback; - if not callback and request.path then - local path = request.url.path; - local base = path:match("^/([^/?]+)"); - if not base then - base = path:match("^http://[^/?]+/([^/?]+)"); - end - - callback = (request.server and request.server.handlers[base]) or default_handler; - end - if callback then - if err then - log("debug", "Request error: "..err); - if not callback(nil, err, request) then - destroy_request(request); - end - return; - end - - local response = callback(request.method, request.body and t_concat(request.body), request); - if response then - if response == true and not request.destroyed then - -- Keep connection open, we will reply later - log("debug", "Request %s left open, on_destroy is %s", request.id, tostring(request.on_destroy)); - elseif response ~= true then - -- Assume response - send_response(request, response); - destroy_request(request); - end - else - log("debug", "Request handler provided no response, destroying request..."); - -- No response, close connection - destroy_request(request); - end - end -end - -local function request_reader(request, data, startpos) - if not request.parser then - local function success_cb(r) - for k,v in pairs(r) do request[k] = v; end - request.url = url_parse(request.path); - request.body = { request.body }; - call_callback(request); - end - local function error_cb(r) - call_callback(request, r or "connection-closed"); - destroy_request(request); - end - request.parser = httpstream_new(success_cb, error_cb); - end - request.parser:feed(data); -end - --- The default handler for requests -default_handler = function (method, body, request) - log("debug", method.." request for "..tostring(request.path) .. " on port "..request.handler:serverport()); - return { status = "404 Not Found", - headers = { ["Content-Type"] = "text/html" }, - body = "<html><head><title>Page Not Found</title></head><body>Not here :(</body></html>" }; -end - - -function new_request(handler) - return { handler = handler, conn = handler, - write = function (...) return handler:write(...); end, state = "request", - server = http_servers[handler:serverport()], - send = send_response, - destroy = destroy_request, - id = tostring{}:match("%x+$") - }; -end - -function destroy_request(request) - log("debug", "Destroying request %s", request.id); - listener = listener or connlisteners_get("httpserver"); - if not request.destroyed then - request.destroyed = true; - if request.on_destroy then - log("debug", "Request has destroy callback"); - request.on_destroy(request); - else - log("debug", "Request has no destroy callback"); - end - request.handler:close() - if request.conn then - listener.ondisconnect(request.conn, "closed"); - end - end -end - -function new(params) - local http_server = http_servers[params.port]; - if not http_server then - http_server = { handlers = {} }; - http_servers[params.port] = http_server; - -- We weren't already listening on this port, so start now - connlisteners_start("httpserver", params); - end - if params.base then - http_server.handlers[params.base] = params.handler; - end -end - -function set_default_handler(handler) - default_handler = handler; -end - -function new_from_config(ports, handle_request, default_options) - if type(handle_request) == "string" then -- COMPAT with old plugins - log("warn", "Old syntax of httpserver.new_from_config being used to register %s", handle_request); - handle_request, default_options = default_options, { base = handle_request }; - end - ports = ports or {5280}; - for _, options in ipairs(ports) do - local port = default_options.port or 5280; - local base = default_options.base; - local ssl = default_options.ssl or false; - local interface = default_options.interface; - if type(options) == "number" then - port = options; - elseif type(options) == "table" then - port = options.port or port; - base = options.path or base; - ssl = options.ssl or ssl; - interface = options.interface or interface; - elseif type(options) == "string" then - base = options; - end - - if ssl then - ssl.mode = "server"; - ssl.protocol = "sslv23"; - ssl.options = "no_sslv2"; - end - - new{ port = port, interface = interface, - base = base, handler = handle_request, - ssl = ssl, type = (ssl and "ssl") or "tcp" }; - end +function fail() + log("error", "Attempt to use legacy HTTP API. For more info see http://prosody.im/doc/developers/legacy_http"); + log("error", "Legacy HTTP API usage, %s", traceback("", 2)); end -_M.request_reader = request_reader; -_M.send_response = send_response; -_M.urlencode = urlencode; +new, new_from_config = fail, fail; +set_default_handler = fail; return _M; diff --git a/net/httpserver_listener.lua b/net/httpserver_listener.lua deleted file mode 100644 index dd14b43c..00000000 --- a/net/httpserver_listener.lua +++ /dev/null @@ -1,46 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - - - -local connlisteners_register = require "net.connlisteners".register; -local new_request = require "net.httpserver".new_request; -local request_reader = require "net.httpserver".request_reader; - -local requests = {}; -- Open requests - -local httpserver = { default_port = 80, default_mode = "*a" }; - -function httpserver.onincoming(conn, data) - local request = requests[conn]; - - if not request then - request = new_request(conn); - requests[conn] = request; - - -- If using HTTPS, request is secure - if conn:ssl() then - request.secure = true; - end - end - - if data and data ~= "" then - request_reader(request, data); - end -end - -function httpserver.ondisconnect(conn, err) - local request = requests[conn]; - if request and not request.destroyed then - request.conn = nil; - request_reader(request, nil); - end - requests[conn] = nil; -end - -connlisteners_register("httpserver", httpserver); diff --git a/net/multiplex_listener.lua b/net/multiplex_listener.lua deleted file mode 100644 index b515ccce..00000000 --- a/net/multiplex_listener.lua +++ /dev/null @@ -1,50 +0,0 @@ - -local connlisteners_register = require "net.connlisteners".register; -local connlisteners_get = require "net.connlisteners".get; - -local httpserver_listener = connlisteners_get("httpserver"); -local xmppserver_listener = connlisteners_get("xmppserver"); -local xmppclient_listener = connlisteners_get("xmppclient"); -local xmppcomponent_listener = connlisteners_get("xmppcomponent"); - -local server = { default_mode = "*a" }; - -local buffer = {}; - -function server.onincoming(conn, data) - if not data then return; end - local buf = buffer[conn]; - buffer[conn] = nil; - buf = buf and buf..data or data; - if buf:match("^[a-zA-Z]") then - local listener = httpserver_listener; - conn:setlistener(listener); - local onconnect = listener.onconnect; - if onconnect then onconnect(conn) end - listener.onincoming(conn, buf); - elseif buf:match(">") then - local listener; - local xmlns = buf:match("%sxmlns%s*=%s*['\"]([^'\"]*)"); - if xmlns == "jabber:server" then - listener = xmppserver_listener; - elseif xmlns == "jabber:component:accept" then - listener = xmppcomponent_listener; - else - listener = xmppclient_listener; - end - conn:setlistener(listener); - local onconnect = listener.onconnect; - if onconnect then onconnect(conn) end - listener.onincoming(conn, buf); - elseif #buf > 1024 then - conn:close(); - else - buffer[conn] = buf; - end -end - -function server.ondisconnect(conn, err) - buffer[conn] = nil; -- warn if no buffer? -end - -connlisteners_register("multiplex", server); diff --git a/net/server.lua b/net/server.lua index 1c1a63a4..2a0b89ae 100644 --- a/net/server.lua +++ b/net/server.lua @@ -1,12 +1,12 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- -local use_luaevent = prosody and require "core.configmanager".get("*", "core", "use_libevent"); +local use_luaevent = prosody and require "core.configmanager".get("*", "use_libevent"); if use_luaevent then use_luaevent = pcall(require, "luaevent.core"); @@ -19,18 +19,7 @@ local server; if use_luaevent then server = require "net.server_event"; - -- util.timer requires "net.server", so instead of having - -- Lua look for, and load us again (causing a loop) - set this here - -- (usually it isn't set until we return, look down there...) - package.loaded["net.server"] = server; - - -- Backwards compatibility for timers, addtimer - -- called a function roughly every second - local add_task = require "util.timer".add_task; - function server.addtimer(f) - return add_task(1, function (...) f(...); return 1; end); - end - + -- Overwrite signal.signal() because we need to ask libevent to -- handle them instead local ok, signal = pcall(require, "util.signal"); @@ -47,8 +36,47 @@ if use_luaevent then end end else + use_luaevent = false; server = require "net.server_select"; - package.loaded["net.server"] = server; +end + +if prosody then + local config_get = require "core.configmanager".get; + local defaults = {}; + for k,v in pairs(server.cfg or server.getsettings()) do + defaults[k] = v; + end + local function load_config() + local settings = config_get("*", "network_settings") or {}; + if use_luaevent then + local event_settings = { + ACCEPT_DELAY = settings.event_accept_retry_interval; + ACCEPT_QUEUE = settings.tcp_backlog; + CLEAR_DELAY = settings.event_clear_interval; + CONNECT_TIMEOUT = settings.connect_timeout; + DEBUG = settings.debug; + HANDSHAKE_TIMEOUT = settings.ssl_handshake_timeout; + MAX_CONNECTIONS = settings.max_connections; + MAX_HANDSHAKE_ATTEMPTS = settings.max_ssl_handshake_roundtrips; + MAX_READ_LENGTH = settings.max_receive_buffer_size; + MAX_SEND_LENGTH = settings.max_send_buffer_size; + READ_TIMEOUT = settings.read_timeout; + WRITE_TIMEOUT = settings.send_timeout; + }; + + for k,default in pairs(defaults) do + server.cfg[k] = event_settings[k] or default; + end + else + local select_settings = {}; + for k,default in pairs(defaults) do + select_settings[k] = settings[k] or default; + end + server.changesettings(select_settings); + end + end + load_config(); + prosody.events.add_handler("config-reloaded", load_config); end -- require "net.server" shall now forever return this, diff --git a/net/server_event.lua b/net/server_event.lua index 122d80fc..59217a0c 100644 --- a/net/server_event.lua +++ b/net/server_event.lua @@ -6,7 +6,6 @@ notes: -- when using luaevent, never register 2 or more EV_READ at one socket, same for EV_WRITE -- you cant even register a new EV_READ/EV_WRITE callback inside another one - -- never call eventcallback:close( ) from inside eventcallback -- to do some of the above, use timeout events or something what will called from outside -- dont let garbagecollect eventcallbacks, as long they are running -- when using luasec, there are 4 cases of timeout errors: wantread or wantwrite during reading or writing @@ -24,6 +23,7 @@ local cfg = { HANDSHAKE_TIMEOUT = 60, -- timeout in seconds per handshake attempt MAX_READ_LENGTH = 1024 * 1024 * 1024 * 1024, -- max bytes allowed to read from sockets MAX_SEND_LENGTH = 1024 * 1024 * 1024 * 1024, -- max bytes size of write buffer (for writing on sockets) + ACCEPT_QUEUE = 128, -- might influence the length of the pending sockets queue ACCEPT_DELAY = 10, -- seconds to wait until the next attempt of a full server to accept READ_TIMEOUT = 60 * 60 * 6, -- timeout in seconds for read data from socket WRITE_TIMEOUT = 180, -- timeout in seconds for write data on socket @@ -33,8 +33,6 @@ local cfg = { } local function use(x) return rawget(_G, x); end -local print = use "print" -local pcall = use "pcall" local ipairs = use "ipairs" local string = use "string" local select = use "select" @@ -43,6 +41,9 @@ local tostring = use "tostring" local coroutine = use "coroutine" local setmetatable = use "setmetatable" +local t_insert = table.insert +local t_concat = table.concat + local ssl = use "ssl" local socket = use "socket" or require "socket" @@ -114,26 +115,19 @@ end )( ) local interface_mt do interface_mt = {}; interface_mt.__index = interface_mt; - + local addevent = base.addevent local coroutine_wrap, coroutine_yield = coroutine.wrap,coroutine.yield - local string_len = string.len - + -- Private methods function interface_mt:_position(new_position) self.position = new_position or self.position return self.position; end - function interface_mt:_close() -- regs event to start self:_destroy() - local callback = function( ) - self:_destroy(); - self.eventclose = nil - return -1 - end - self.eventclose = addevent( base, nil, EV_TIMEOUT, callback, 0 ) - return true + function interface_mt:_close() + return self:_destroy(); end - + function interface_mt:_start_connection(plainssl) -- should be called from addclient local callback = function( event ) if EV_TIMEOUT == event then -- timeout during connection @@ -143,7 +137,7 @@ do debug( "new connection failed. id:", self.id, "error:", self.fatalerror ) else if plainssl and ssl then -- start ssl session - self:starttls(nil, true) + self:starttls(self._sslctx, true) else -- normal connection self:_start_session(true) end @@ -212,7 +206,6 @@ do self:_lock( false, false, false ) -- unlock the interface; sending, closing etc allowed self.send = self.conn.send -- caching table lookups with new client object self.receive = self.conn.receive - local onsomething if not call_onconnect then -- trigger listener self:onstatus("ssl-handshake-complete"); end @@ -221,12 +214,12 @@ do self.eventhandshake = nil return -1 end - debug( "error during ssl handshake:", err ) if err == "wantwrite" then event = EV_WRITE elseif err == "wantread" then event = EV_READ else + debug( "ssl handshake error:", err ) self.fatalerror = err end end @@ -249,10 +242,10 @@ do return true end function interface_mt:_destroy() -- close this interface + events and call last listener - debug( "closing client with id:", self.id ) + debug( "closing client with id:", self.id, self.fatalerror ) self:_lock( true, true, true ) -- first of all, lock the interface to avoid further actions local _ - _ = self.eventread and self.eventread:close( ) -- close events; this must be called outside of the event callbacks! + _ = self.eventread and self.eventread:close( ) if self.type == "client" then _ = self.eventwrite and self.eventwrite:close( ) _ = self.eventhandshake and self.eventhandshake:close( ) @@ -262,7 +255,7 @@ do _ = self.eventwritetimeout and self.eventwritetimeout:close( ) _ = self.eventreadtimeout and self.eventreadtimeout:close( ) _ = self.ondisconnect and self:ondisconnect( self.fatalerror ~= "client to close" and self.fatalerror) -- call ondisconnect listener (wont be the case if handshake failed on connect) - _ = self.conn and self.conn:close( ) -- close connection, must also be called outside of any socket registered events! + _ = self.conn and self.conn:close( ) -- close connection _ = self._server and self._server:counter(-1); self.eventread, self.eventwrite = nil, nil self.eventstarthandshake, self.eventhandshake, self.eventclose = nil, nil, nil @@ -275,12 +268,12 @@ do interfacelist( "delete", self ) return true end - + function interface_mt:_lock(nointerface, noreading, nowriting) -- lock or unlock this interface or events self.nointerface, self.noreading, self.nowriting = nointerface, noreading, nowriting return nointerface, noreading, nowriting end - + --TODO: Deprecate function interface_mt:lock_read(switch) if switch then @@ -295,7 +288,10 @@ do end function interface_mt:resume() - return self:_lock(self.nointerface, false, self.nowriting); + self:_lock(self.nointerface, false, self.nowriting); + if not self.eventread then + self.eventread = addevent( base, self.conn, EV_READ, self.readcallback, cfg.READ_TIMEOUT ); -- register callback + end end function interface_mt:counter(c) @@ -304,20 +300,20 @@ do end return self._connections end - + -- Public methods function interface_mt:write(data) if self.nowriting then return nil, "locked" end --vdebug( "try to send data to client, id/data:", self.id, data ) data = tostring( data ) - local len = string_len( data ) + local len = #data local total = len + self.writebufferlen if total > cfg.MAX_SEND_LENGTH then -- check buffer length local err = "send buffer exceeded" debug( "error:", err ) -- to much, check your app return nil, err end - self.writebuffer = self.writebuffer .. data -- new buffer + t_insert(self.writebuffer, data) -- new buffer self.writebufferlen = total if not self.eventwrite then -- register new write event --vdebug( "register new write event" ) @@ -325,62 +321,49 @@ do end return true end - function interface_mt:close(now) + function interface_mt:close() if self.nointerface then return nil, "locked"; end debug( "try to close client connection with id:", self.id ) if self.type == "client" then self.fatalerror = "client to close" - if ( not self.eventwrite ) or now then -- try to close immediately - self:_lock( true, true, true ) - self:_close() - return true - else -- wait for incomplete write request + if self.eventwrite then -- wait for incomplete write request self:_lock( true, true, false ) debug "closing delayed until writebuffer is empty" return nil, "writebuffer not empty, waiting" + else -- close now + self:_lock( true, true, true ) + self:_close() + return true end else - debug( "try to close server with id:", self.id, "args:", now ) + debug( "try to close server with id:", tostring(self.id)) self.fatalerror = "server to close" self:_lock( true ) - local count = 0 - for _, item in ipairs( interfacelist( ) ) do - if ( item.type ~= "server" ) and ( item._server == self ) then -- client/server match - if item:close( now ) then -- writebuffer was empty - count = count + 1 - end - end - end - local timeout = 0 -- dont wait for unfinished writebuffers of clients... - if not now then - timeout = cfg.WRITE_TIMEOUT -- ...or wait for it - end - self:_close( timeout ) -- add new event to remove the server interface - debug( "seconds remained until server is closed:", timeout ) - return count -- returns finished clients with empty writebuffer + self:_close( 0 ) + return true end end - + function interface_mt:socket() return self.conn end - + function interface_mt:server() return self._server or self; end - + function interface_mt:port() return self._port end - + function interface_mt:serverport() return self._serverport end - + function interface_mt:ip() return self._ip end - + function interface_mt:ssl() return self._usingssl end @@ -388,15 +371,15 @@ do function interface_mt:type() return self._type or "client" end - + function interface_mt:connections() return self._connections end - + function interface_mt:address() return self.addr end - + function interface_mt:set_sslctx(sslctx) self._sslctx = sslctx; if sslctx then @@ -412,11 +395,11 @@ do end return self._pattern; end - + function interface_mt:set_send(new_send) -- No-op, we always use the underlying connection's send end - + function interface_mt:starttls(sslctx, call_onconnect) debug( "try to start ssl at client id:", self.id ) local err @@ -445,22 +428,22 @@ do self.starttls = false; return true end - + function interface_mt:setoption(option, value) if self.conn.setoption then return self.conn:setoption(option, value); end return false, "setoption not implemented"; end - + function interface_mt:setlistener(listener) - self.onconnect, self.ondisconnect, self.onincoming, self.ontimeout, self.onstatus - = listener.onconnect, listener.ondisconnect, listener.onincoming, listener.ontimeout, listener.onstatus; + self.onconnect, self.ondisconnect, self.onincoming, self.ontimeout, self.onreadtimeout, self.onstatus + = listener.onconnect, listener.ondisconnect, listener.onincoming, + listener.ontimeout, listener.onreadtimeout, listener.onstatus; end - + -- Stub handlers function interface_mt:onconnect() - return self:onincoming(nil); end function interface_mt:onincoming() end @@ -468,6 +451,12 @@ do end function interface_mt:ontimeout() end + function interface_mt:onreadtimeout() + self.fatalerror = "timeout during receiving" + debug( "connection failed:", self.fatalerror ) + self:_close() + self.eventread = nil + end function interface_mt:ondrain() end function interface_mt:onstatus() @@ -479,18 +468,15 @@ end local handleclient; do local string_sub = string.sub -- caching table lookups - local string_len = string.len local addevent = base.addevent - local coroutine_wrap = coroutine.wrap local socket_gettime = socket.gettime - local coroutine_yield = coroutine.yield - function handleclient( client, ip, port, server, pattern, listener, _, sslctx ) -- creates an client interface + function handleclient( client, ip, port, server, pattern, listener, sslctx ) -- creates an client interface --vdebug("creating client interfacce...") local interface = { type = "client"; conn = client; currenttime = socket_gettime( ); -- safe the origin - writebuffer = ""; -- writebuffer + writebuffer = {}; -- writebuffer writebufferlen = 0; -- length of writebuffer send = client.send; -- caching table lookups receive = client.receive; @@ -498,6 +484,8 @@ do ondisconnect = listener.ondisconnect; -- will be called when client disconnects onincoming = listener.onincoming; -- will be called when client sends data ontimeout = listener.ontimeout; -- called when fatal socket timeout occurs + onreadtimeout = listener.onreadtimeout; -- called when socket inactivity timeout occurs + ondrain = listener.ondrain; -- called when writebuffer is empty onstatus = listener.onstatus; -- called for status changes (e.g. of SSL/TLS) eventread = false, eventwrite = false, eventclose = false, eventhandshake = false, eventstarthandshake = false; -- event handler @@ -511,7 +499,7 @@ do noreading = false, nowriting = false; -- locks of the read/writecallback startsslcallback = false; -- starting handshake callback position = false; -- position of client in interfacelist - + -- Properties _ip = ip, _port = port, _server = server, _pattern = pattern, _serverport = (server and server:port() or nil), @@ -544,10 +532,11 @@ do interface.eventwritetimeout = false end end - local succ, err, byte = interface.conn:send( interface.writebuffer, 1, interface.writebufferlen ) + interface.writebuffer = { t_concat(interface.writebuffer) } + local succ, err, byte = interface.conn:send( interface.writebuffer[1], 1, interface.writebufferlen ) --vdebug( "write data:", interface.writebuffer, "error:", err, "part:", byte ) if succ then -- writing succesful - interface.writebuffer = "" + interface.writebuffer[1] = nil interface.writebufferlen = 0 interface:ondrain(); if interface.fatalerror then @@ -563,7 +552,7 @@ do return -1 elseif byte and (err == "timeout" or err == "wantwrite") then -- want write again --vdebug( "writebuffer is not empty:", err ) - interface.writebuffer = string_sub( interface.writebuffer, byte + 1, interface.writebufferlen ) -- new buffer + interface.writebuffer[1] = string_sub( interface.writebuffer[1], byte + 1, interface.writebufferlen ) -- new buffer interface.writebufferlen = interface.writebufferlen - byte if "wantread" == err then -- happens only with luasec local callback = function( ) @@ -586,7 +575,7 @@ do end end end - + interface.readcallback = function( event ) -- called on read events --vdebug( "new client read event, id/ip/port:", tostring(interface.id), tostring(ip), tostring(port) ) if interface.noreading or interface.fatalerror then -- leave this event @@ -594,57 +583,56 @@ do interface.eventread = nil return -1 end - if EV_TIMEOUT == event then -- took too long to get some data from client -> disconnect - interface.fatalerror = "timeout during receiving" - debug( "connection failed:", interface.fatalerror ) + if EV_TIMEOUT == event and interface:onreadtimeout() ~= true then + return -1 -- took too long to get some data from client -> disconnect + end + if interface._usingssl then -- handle luasec + if interface.eventwritetimeout then -- ok, in the past writecallback was regged + local ret = interface.writecallback( ) -- call it + --vdebug( "tried to write in readcallback, result:", tostring(ret) ) + end + if interface.eventreadtimeout then + interface.eventreadtimeout:close( ) + interface.eventreadtimeout = nil + end + end + local buffer, err, part = interface.conn:receive( interface._pattern ) -- receive buffer with "pattern" + --vdebug( "read data:", tostring(buffer), "error:", tostring(err), "part:", tostring(part) ) + buffer = buffer or part + if buffer and #buffer > cfg.MAX_READ_LENGTH then -- check buffer length + interface.fatalerror = "receive buffer exceeded" + debug( "fatal error:", interface.fatalerror ) interface:_close() interface.eventread = nil return -1 - else -- can read - if interface._usingssl then -- handle luasec - if interface.eventwritetimeout then -- ok, in the past writecallback was regged - local ret = interface.writecallback( ) -- call it - --vdebug( "tried to write in readcallback, result:", tostring(ret) ) - end - if interface.eventreadtimeout then - interface.eventreadtimeout:close( ) - interface.eventreadtimeout = nil + end + if err and ( err ~= "timeout" and err ~= "wantread" ) then + if "wantwrite" == err then -- need to read on write event + if not interface.eventwrite then -- register new write event if needed + interface.eventwrite = addevent( base, interface.conn, EV_WRITE, interface.writecallback, cfg.WRITE_TIMEOUT ) end - end - local buffer, err, part = interface.conn:receive( interface._pattern ) -- receive buffer with "pattern" - --vdebug( "read data:", tostring(buffer), "error:", tostring(err), "part:", tostring(part) ) - buffer = buffer or part or "" - local len = string_len( buffer ) - if len > cfg.MAX_READ_LENGTH then -- check buffer length - interface.fatalerror = "receive buffer exceeded" - debug( "fatal error:", interface.fatalerror ) + interface.eventreadtimeout = addevent( base, nil, EV_TIMEOUT, + function( ) + interface:_close() + end, cfg.READ_TIMEOUT + ) + debug( "wantwrite during read attempt, reg it in writecallback but dont know what really happens next..." ) + -- to be honest i dont know what happens next, if it is allowed to first read, the write etc... + else -- connection was closed or fatal error + interface.fatalerror = err + debug( "connection failed in read event:", interface.fatalerror ) interface:_close() interface.eventread = nil return -1 end + else interface.onincoming( interface, buffer, err ) -- send new data to listener - if err and ( err ~= "timeout" and err ~= "wantread" ) then - if "wantwrite" == err then -- need to read on write event - if not interface.eventwrite then -- register new write event if needed - interface.eventwrite = addevent( base, interface.conn, EV_WRITE, interface.writecallback, cfg.WRITE_TIMEOUT ) - end - interface.eventreadtimeout = addevent( base, nil, EV_TIMEOUT, - function( ) - interface:_close() - end, cfg.READ_TIMEOUT - ) - debug( "wantwrite during read attempt, reg it in writecallback but dont know what really happens next..." ) - -- to be honest i dont know what happens next, if it is allowed to first read, the write etc... - else -- connection was closed or fatal error - interface.fatalerror = err - debug( "connection failed in read event:", interface.fatalerror ) - interface:_close() - interface.eventread = nil - return -1 - end - end - return EV_READ, cfg.READ_TIMEOUT end + if interface.noreading then + interface.eventread = nil; + return -1; + end + return EV_READ, cfg.READ_TIMEOUT end client:settimeout( 0 ) -- set non blocking @@ -660,7 +648,7 @@ do debug "creating server interface..." local interface = { _connections = 0; - + conn = server; onconnect = listener.onconnect; -- will be called when new client connected eventread = false; -- read event handler @@ -668,7 +656,7 @@ do readcallback = false; -- read event callback fatalerror = false; -- error message nointerface = true; -- lock/unlock parameter - + _ip = addr, _port = port, _pattern = pattern, _sslctx = sslctx; } @@ -699,7 +687,7 @@ do end local client_ip, client_port = client:getpeername( ) interface._connections = interface._connections + 1 -- increase connection count - local clientinterface = handleclient( client, client_ip, client_port, interface, pattern, listener, nil, sslctx ) + local clientinterface = handleclient( client, client_ip, client_port, interface, pattern, listener, sslctx ) --vdebug( "client id:", clientinterface, "startssl:", startssl ) if ssl and sslctx then clientinterface:starttls(sslctx, true) @@ -707,12 +695,12 @@ do clientinterface:_start_session( true ) end debug( "accepted incoming client connection from:", client_ip or "<unknown IP>", client_port or "<unknown port>", "to", port or "<unknown port>"); - + client, err = server:accept() -- try to accept again end return EV_READ end - + server:settimeout( 0 ) setmetatable(interface, interface_mt) interfacelist( "add", interface ) @@ -726,7 +714,7 @@ local addserver = ( function( ) --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslcfg or "nil", startssl or "nil") local server, err = socket.bind( addr, port, cfg.ACCEPT_QUEUE ) -- create server socket if not server then - debug( "creating server socket failed because:", err ) + debug( "creating server socket on "..addr.." port "..port.." failed:", err ) return nil, err end local sslctx @@ -749,13 +737,13 @@ end )( ) local addclient, wrapclient do - function wrapclient( client, ip, port, listeners, pattern, sslctx, startssl ) + function wrapclient( client, ip, port, listeners, pattern, sslctx ) local interface = handleclient( client, ip, port, nil, pattern, listeners, sslctx ) - interface:_start_session() + interface:_start_connection(sslctx) return interface, client --function handleclient( client, ip, port, server, pattern, listener, _, sslctx ) -- creates an client interface end - + function addclient( addr, serverport, listener, pattern, localaddr, localport, sslcfg, startssl ) local client, err = socket.tcp() -- creating new socket if not client then @@ -785,9 +773,6 @@ do local res, err = client:connect( addr, serverport ) -- connect if res or ( err == "timeout" ) then local ip, port = client:getsockname( ) - local server = function( ) - return nil, "this is a dummy server interface" - end local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx, startssl ) interface:_start_connection( startssl ) debug( "new connection id:", interface.id ) @@ -828,14 +813,14 @@ local function setquitting(yes) end end -function get_backend() +local function get_backend() return base:method(); end -- We need to hold onto the events to stop them -- being garbage-collected local signal_events = {}; -- [signal_num] -> event object -function hook_signal(signal_num, handler) +local function hook_signal(signal_num, handler) local function _handler(event) local ret = handler(); if ret ~= false then -- Continue handling this signal? @@ -849,14 +834,14 @@ end local function link(sender, receiver, buffersize) local sender_locked; - + function receiver:ondrain() if sender_locked then sender:resume(); sender_locked = nil; end end - + function sender:onincoming(data) receiver:write(data); if receiver.writebufferlen >= buffersize then diff --git a/net/server_select.lua b/net/server_select.lua index cfd7f3cd..ca55d2d5 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -1,7 +1,7 @@ --- +-- -- server.lua by blastbeat of the luadch project -- Re-used here under the MIT/X Consortium License --- +-- -- Modifications (C) 2008-2010 Matthew Wild, Waqas Hussain -- @@ -10,16 +10,10 @@ local use = function( what ) return _G[ what ] end -local clean = function( tbl ) - for i, k in pairs( tbl ) do - tbl[ i ] = nil - end -end local log, table_concat = require ("util.logger").init("socket"), table.concat; local out_put = function (...) return log("debug", table_concat{...}); end local out_error = function (...) return log("warn", table_concat{...}); end -local mem_free = collectgarbage ----------------------------------// DECLARATION //-- @@ -34,7 +28,6 @@ local pairs = use "pairs" local ipairs = use "ipairs" local tonumber = use "tonumber" local tostring = use "tostring" -local collectgarbage = use "collectgarbage" --// lua libs //-- @@ -49,8 +42,6 @@ local os_difftime = os.difftime local math_min = math.min local math_huge = math.huge local table_concat = table.concat -local table_remove = table.remove -local string_len = string.len local string_sub = string.sub local coroutine_wrap = coroutine.wrap local coroutine_yield = coroutine.yield @@ -67,7 +58,6 @@ local ssl_wrap = ( luasec and luasec.wrap ) local socket_bind = luasocket.bind local socket_sleep = luasocket.sleep local socket_select = luasocket.select -local ssl_newcontext = ( luasec and luasec.newcontext ) --// functions //-- @@ -75,17 +65,16 @@ local id local loop local stats local idfalse -local addtimer local closeall local addsocket local addserver +local addtimer local getserver local wrapserver local getsettings local closesocket local removesocket local removeserver -local changetimeout local wrapconnection local changesettings @@ -112,6 +101,7 @@ local _readtraffic local _selecttimeout local _sleeptime +local _tcpbacklog local _starttime local _currenttime @@ -123,11 +113,10 @@ local _checkinterval local _sendtimeout local _readtimeout -local _cleanqueue - local _timer -local _maxclientsperserver +local _maxselectlen +local _maxfd local _maxsslhandshake @@ -151,29 +140,34 @@ _readtraffic = 0 _selecttimeout = 1 -- timeout of socket.select _sleeptime = 0 -- time to wait at the end of every loop +_tcpbacklog = 128 -- some kind of hint to the OS _maxsendlen = 51000 * 1024 -- max len of send buffer _maxreadlen = 25000 * 1024 -- max len of read buffer -_checkinterval = 1200000 -- interval in secs to check idle clients +_checkinterval = 30 -- interval in secs to check idle clients _sendtimeout = 60000 -- allowed send idle time in secs _readtimeout = 6 * 60 * 60 -- allowed read idle time in secs -_cleanqueue = false -- clean bufferqueue after using - -_maxclientsperserver = 1000 +local is_windows = package.config:sub(1,1) == "\\" -- check the directory separator, to detemine whether this is Windows +_maxfd = (is_windows and math.huge) or luasocket._SETSIZE or 1024 -- max fd number, limit to 1024 by default to prevent glibc buffer overflow, but not on Windows +_maxselectlen = luasocket._SETSIZE or 1024 -- But this still applies on Windows _maxsslhandshake = 30 -- max handshake round-trips ----------------------------------// PRIVATE //-- -wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxconnections ) -- this function wraps a server +wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- this function wraps a server -- FIXME Make sure FD < _maxfd - maxconnections = maxconnections or _maxclientsperserver + if socket:getfd() >= _maxfd then + out_error("server.lua: Disallowed FD number: "..socket:getfd()) + socket:close() + return nil, "fd-too-large" + end local connections = 0 - local dispatch, disconnect = listeners.onconnect or listeners.onincoming, listeners.ondisconnect + local dispatch, disconnect = listeners.onconnect, listeners.ondisconnect local accept = socket.accept @@ -191,23 +185,43 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco end handler.remove = function( ) connections = connections - 1 - end - handler.close = function( ) - for _, handler in pairs( _socketlist ) do - if handler.serverport == serverport then - handler.disconnect( handler, "server closed" ) - handler:close( true ) - end + if handler then + handler.resume( ) end + end + handler.close = function() socket:close( ) _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) _readlistlen = removesocket( _readlist, socket, _readlistlen ) + _server[ip..":"..serverport] = nil; _socketlist[ socket ] = nil handler = nil socket = nil --mem_free( ) out_put "server.lua: closed server handler and removed sockets from list" end + handler.pause = function( hard ) + if not handler.paused then + _readlistlen = removesocket( _readlist, socket, _readlistlen ) + if hard then + _socketlist[ socket ] = nil + socket:close( ) + socket = nil; + end + handler.paused = true; + end + end + handler.resume = function( ) + if handler.paused then + if not socket then + socket = socket_bind( ip, serverport, _tcpbacklog ); + socket:settimeout( 0 ) + end + _readlistlen = addsocket(_readlist, socket, _readlistlen) + _socketlist[ socket ] = handler + handler.paused = false; + end + end handler.ip = function( ) return ip end @@ -218,21 +232,24 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco return socket end handler.readbuffer = function( ) - if connections > maxconnections then + if _readlistlen >= _maxselectlen or _sendlistlen >= _maxselectlen then + handler.pause( ) out_put( "server.lua: refused new client connection: server full" ) return false end local client, err = accept( socket ) -- try to accept if client then local ip, clientport = client:getpeername( ) - client:settimeout( 0 ) local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx ) -- wrap new client socket if err then -- error while wrapping ssl socket return false end connections = connections + 1 out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport)) - return dispatch( handler ) + if dispatch and not sslctx then -- SSL connections will notify onconnect when handshake completes + return dispatch( handler ); + end + return; elseif err then -- maybe timeout or something else out_put( "server.lua: error with new client connection: ", tostring(err) ) return false @@ -243,6 +260,14 @@ end wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx ) -- this function wraps a client to a handler object + if socket:getfd() >= _maxfd then + out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent + socket:close( ) -- Should we send some kind of error here? + if server then + server.pause( ) + end + return nil, nil, "fd-too-large" + end socket:settimeout( 0 ) --// local import of socket methods //-- @@ -317,22 +342,25 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport end return false, "setoption not implemented"; end - handler.close = function( self, forced ) + handler.force_close = function ( self, err ) + if bufferqueuelen ~= 0 then + out_put("server.lua: discarding unwritten data for ", tostring(ip), ":", tostring(clientport)) + bufferqueuelen = 0; + end + return self:close(err); + end + handler.close = function( self, err ) if not handler then return true; end _readlistlen = removesocket( _readlist, socket, _readlistlen ) _readtimes[ handler ] = nil if bufferqueuelen ~= 0 then - if not ( forced or fatalerror ) then - handler.sendbuffer( ) - if bufferqueuelen ~= 0 then -- try again... - if handler then - handler.write = nil -- ... but no further writing allowed - end - toclose = true - return false + handler.sendbuffer() -- Try now to send any outstanding data + if bufferqueuelen ~= 0 then -- Still not empty, so we'll try again later + if handler then + handler.write = nil -- ... but no further writing allowed end - else - send( socket, table_concat( bufferqueue, "", 1, bufferqueuelen ), 1, bufferlen ) -- forced send + toclose = true + return false end end if socket then @@ -347,7 +375,12 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport if handler then _writetimes[ handler ] = nil _closelist[ handler ] = nil + local _handler = handler; handler = nil + if disconnect then + disconnect(_handler, err or false); + disconnect = nil + end end if server then server.remove( ) @@ -365,7 +398,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport return clientport end local write = function( self, data ) - bufferlen = bufferlen + string_len( data ) + bufferlen = bufferlen + #data if bufferlen > maxsendlen then _closelist[ handler ] = "send buffer exceeded" -- cannot close the client at the moment, have to wait to the end of the cycle handler.write = idfalse -- dont write anymore @@ -447,10 +480,9 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport local buffer, err, part = receive( socket, pattern ) -- receive buffer with "pattern" if not err or (err == "wantread" or err == "timeout") then -- received something local buffer = buffer or part or "" - local len = string_len( buffer ) + local len = #buffer if len > maxreadlen then - disconnect( handler, "receive buffer exceeded" ) - handler:close( true ) + handler:close( "receive buffer exceeded" ) return false end local count = len * STAT_UNIT @@ -462,24 +494,24 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport else -- connections was closed or fatal error out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) ) fatalerror = true - disconnect( handler, err ) - _ = handler and handler:close( ) + _ = handler and handler:force_close( err ) return false end end local _sendbuffer = function( ) -- this function sends data local succ, err, byte, buffer, count; - local count; if socket then buffer = table_concat( bufferqueue, "", 1, bufferqueuelen ) succ, err, byte = send( socket, buffer, 1, bufferlen ) count = ( succ or byte or 0 ) * STAT_UNIT sendtraffic = sendtraffic + count _sendtraffic = _sendtraffic + count - _ = _cleanqueue and clean( bufferqueue ) + for i = bufferqueuelen,1,-1 do + bufferqueue[ i ] = nil + end --out_put( "server.lua: sended '", buffer, "', bytes: ", tostring(succ), ", error: ", tostring(err), ", part: ", tostring(byte), ", to: ", tostring(ip), ":", tostring(clientport) ) else - succ, err, count = false, "closed", 0; + succ, err, count = false, "unexpected close", 0; end if succ then -- sending succesful bufferqueuelen = 0 @@ -490,7 +522,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport drain(handler) end _ = needtls and handler:starttls(nil) - _ = toclose and handler:close( ) + _ = toclose and handler:force_close( ) return true elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write buffer = string_sub( buffer, byte + 1, bufferlen ) -- new buffer @@ -502,8 +534,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport else -- connection was closed during sending or fatal error out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) ) fatalerror = true - disconnect( handler, err ) - _ = handler and handler:close( ) + _ = handler and handler:force_close( err ) return false end end @@ -511,10 +542,8 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport -- Set the sslctx local handshake; function handler.set_sslctx(self, new_sslctx) - ssl = true sslctx = new_sslctx; - local wrote - local read + local read, wrote handshake = coroutine_wrap( function( client ) -- create handshake coroutine local err for i = 1, _maxsslhandshake do @@ -527,108 +556,93 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport handler.readbuffer = _readbuffer -- when handshake is done, replace the handshake function with regular functions handler.sendbuffer = _sendbuffer _ = status and status( handler, "ssl-handshake-complete" ) + if self.autostart_ssl and listeners.onconnect then + listeners.onconnect(self); + end _readlistlen = addsocket(_readlist, client, _readlistlen) return true else - out_put( "server.lua: error during ssl handshake: ", tostring(err) ) - if err == "wantwrite" and not wrote then + if err == "wantwrite" then _sendlistlen = addsocket(_sendlist, client, _sendlistlen) wrote = true - elseif err == "wantread" and not read then + elseif err == "wantread" then _readlistlen = addsocket(_readlist, client, _readlistlen) read = true else break; end - --coroutine_yield( handler, nil, err ) -- handshake not finished - coroutine_yield( ) + err = nil; + coroutine_yield( ) -- handshake not finished end end - disconnect( handler, "ssl handshake failed" ) - _ = handler and handler:close( true ) -- forced disconnect - return false -- handshake failed + out_put( "server.lua: ssl handshake error: ", tostring(err or "handshake too long") ) + _ = handler and handler:force_close("ssl handshake failed") + return false, err -- handshake failed end ) end if luasec then - if sslctx then -- ssl? - handler:set_sslctx(sslctx); - out_put("server.lua: ", "starting ssl handshake") - local err - socket, err = ssl_wrap( socket, sslctx ) -- wrap socket - if err then - out_put( "server.lua: ssl error: ", tostring(err) ) - --mem_free( ) - return nil, nil, err -- fatal error + handler.starttls = function( self, _sslctx) + if _sslctx then + handler:set_sslctx(_sslctx); end - socket:settimeout( 0 ) - handler.readbuffer = handshake - handler.sendbuffer = handshake - handshake( socket ) -- do handshake + if bufferqueuelen > 0 then + out_put "server.lua: we need to do tls, but delaying until send buffer empty" + needtls = true + return + end + out_put( "server.lua: attempting to start tls on " .. tostring( socket ) ) + local oldsocket, err = socket + socket, err = ssl_wrap( socket, sslctx ) -- wrap socket if not socket then - return nil, nil, "ssl handshake failed"; + out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") ) + return nil, err -- fatal error end - else - local sslctx; - handler.starttls = function( self, _sslctx) - if _sslctx then - sslctx = _sslctx; - handler:set_sslctx(sslctx); - end - if bufferqueuelen > 0 then - out_put "server.lua: we need to do tls, but delaying until send buffer empty" - needtls = true - return - end - out_put( "server.lua: attempting to start tls on " .. tostring( socket ) ) - local oldsocket, err = socket - socket, err = ssl_wrap( socket, sslctx ) -- wrap socket - --out_put( "server.lua: sslwrapped socket is " .. tostring( socket ) ) - if err then - out_put( "server.lua: error while starting tls on client: ", tostring(err) ) - return nil, err -- fatal error - end - socket:settimeout( 0 ) - - -- add the new socket to our system - - send = socket.send - receive = socket.receive - shutdown = id - - _socketlist[ socket ] = handler - _readlistlen = addsocket(_readlist, socket, _readlistlen) + socket:settimeout( 0 ) - -- remove traces of the old socket + -- add the new socket to our system + send = socket.send + receive = socket.receive + shutdown = id + _socketlist[ socket ] = handler + _readlistlen = addsocket(_readlist, socket, _readlistlen) - _readlistlen = removesocket( _readlist, oldsocket, _readlistlen ) - _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen ) - _socketlist[ oldsocket ] = nil + -- remove traces of the old socket + _readlistlen = removesocket( _readlist, oldsocket, _readlistlen ) + _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen ) + _socketlist[ oldsocket ] = nil - handler.starttls = nil - needtls = nil + handler.starttls = nil + needtls = nil - -- Secure now - ssl = true + -- Secure now (if handshake fails connection will close) + ssl = true - handler.readbuffer = handshake - handler.sendbuffer = handshake - handshake( socket ) -- do handshake - end - handler.readbuffer = _readbuffer - handler.sendbuffer = _sendbuffer + handler.readbuffer = handshake + handler.sendbuffer = handshake + return handshake( socket ) -- do handshake end - else - handler.readbuffer = _readbuffer - handler.sendbuffer = _sendbuffer end + + handler.readbuffer = _readbuffer + handler.sendbuffer = _sendbuffer send = socket.send receive = socket.receive shutdown = ( ssl and id ) or socket.shutdown _socketlist[ socket ] = handler _readlistlen = addsocket(_readlist, socket, _readlistlen) + + if sslctx and luasec then + out_put "server.lua: auto-starting ssl negotiation..." + handler.autostart_ssl = true; + local ok, err = handler:starttls(sslctx); + if ok == false then + return nil, nil, err + end + end + return handler, socket end @@ -681,7 +695,7 @@ local function link(sender, receiver, buffersize) sender_locked = nil; end end - + local _readbuffer = sender.readbuffer; function sender.readbuffer() _readbuffer(); @@ -701,45 +715,45 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function end if type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then err = "invalid port" - elseif _server[ port ] then - err = "listeners on port '" .. port .. "' already exist" + elseif _server[ addr..":"..port ] then + err = "listeners on '[" .. addr .. "]:" .. port .. "' already exist" elseif sslctx and not luasec then err = "luasec not found" end if err then - out_error( "server.lua, port ", port, ": ", err ) + out_error( "server.lua, [", addr, "]:", port, ": ", err ) return nil, err end addr = addr or "*" - local server, err = socket_bind( addr, port ) + local server, err = socket_bind( addr, port, _tcpbacklog ) if err then - out_error( "server.lua, port ", port, ": ", err ) + out_error( "server.lua, [", addr, "]:", port, ": ", err ) return nil, err end - local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, _maxclientsperserver ) -- wrap new server socket + local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx ) -- wrap new server socket if not handler then server:close( ) return nil, err end server:settimeout( 0 ) _readlistlen = addsocket(_readlist, server, _readlistlen) - _server[ port ] = handler + _server[ addr..":"..port ] = handler _socketlist[ server ] = handler - out_put( "server.lua: new "..(sslctx and "ssl " or "").."server listener on '", addr, ":", port, "'" ) + out_put( "server.lua: new "..(sslctx and "ssl " or "").."server listener on '[", addr, "]:", port, "'" ) return handler end -getserver = function ( port ) - return _server[ port ]; +getserver = function ( addr, port ) + return _server[ addr..":"..port ]; end -removeserver = function( port ) - local handler = _server[ port ] +removeserver = function( addr, port ) + local handler = _server[ addr..":"..port ] if not handler then - return nil, "no server found on port '" .. tostring( port ) .. "'" + return nil, "no server found on '[" .. addr .. "]:" .. tostring( port ) .. "'" end handler:close( ) - _server[ port ] = nil + _server[ addr..":"..port ] = nil return true end @@ -760,23 +774,36 @@ closeall = function( ) end getsettings = function( ) - return _selecttimeout, _sleeptime, _maxsendlen, _maxreadlen, _checkinterval, _sendtimeout, _readtimeout, _cleanqueue, _maxclientsperserver, _maxsslhandshake + return { + select_timeout = _selecttimeout; + select_sleep_time = _sleeptime; + tcp_backlog = _tcpbacklog; + max_send_buffer_size = _maxsendlen; + max_receive_buffer_size = _maxreadlen; + select_idle_check_interval = _checkinterval; + send_timeout = _sendtimeout; + read_timeout = _readtimeout; + max_connections = _maxselectlen; + max_ssl_handshake_roundtrips = _maxsslhandshake; + highest_allowed_fd = _maxfd; + } end changesettings = function( new ) if type( new ) ~= "table" then return nil, "invalid settings table" end - _selecttimeout = tonumber( new.timeout ) or _selecttimeout - _sleeptime = tonumber( new.sleeptime ) or _sleeptime - _maxsendlen = tonumber( new.maxsendlen ) or _maxsendlen - _maxreadlen = tonumber( new.maxreadlen ) or _maxreadlen - _checkinterval = tonumber( new.checkinterval ) or _checkinterval - _sendtimeout = tonumber( new.sendtimeout ) or _sendtimeout - _readtimeout = tonumber( new.readtimeout ) or _readtimeout - _cleanqueue = new.cleanqueue - _maxclientsperserver = new._maxclientsperserver or _maxclientsperserver - _maxsslhandshake = new._maxsslhandshake or _maxsslhandshake + _selecttimeout = tonumber( new.select_timeout ) or _selecttimeout + _sleeptime = tonumber( new.select_sleep_time ) or _sleeptime + _maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen + _maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen + _checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval + _tcpbacklog = tonumber( new.tcp_backlog ) or _tcpbacklog + _sendtimeout = tonumber( new.send_timeout ) or _sendtimeout + _readtimeout = tonumber( new.read_timeout ) or _readtimeout + _maxselectlen = new.max_connections or _maxselectlen + _maxsslhandshake = new.max_ssl_handshake_roundtrips or _maxsslhandshake + _maxfd = new.highest_allowed_fd or _maxfd return true end @@ -795,7 +822,7 @@ end local quitting; -setquitting = function (quit) +local function setquitting(quit) quitting = not not quit; end @@ -825,10 +852,32 @@ loop = function(once) -- this is the main loop of the program end for handler, err in pairs( _closelist ) do handler.disconnect( )( handler, err ) - handler:close( true ) -- forced disconnect + handler:force_close() -- forced disconnect + _closelist[ handler ] = nil; end - clean( _closelist ) _currenttime = luasocket_gettime( ) + + -- Check for socket timeouts + local difftime = os_difftime( _currenttime - _starttime ) + if difftime > _checkinterval then + _starttime = _currenttime + for handler, timestamp in pairs( _writetimes ) do + if os_difftime( _currenttime - timestamp ) > _sendtimeout then + handler.disconnect( )( handler, "send timeout" ) + handler:force_close() -- forced disconnect + end + end + for handler, timestamp in pairs( _readtimes ) do + if os_difftime( _currenttime - timestamp ) > _readtimeout then + if not(handler.onreadtimeout) or handler:onreadtimeout() ~= true then + handler.disconnect( )( handler, "read timeout" ) + handler:close( ) -- forced disconnect? + end + end + end + end + + -- Fire timers if _currenttime - _timer >= math_min(next_timer_time, 1) then next_timer_time = math_huge; for i = 1, _timerlistlen do @@ -839,14 +888,15 @@ loop = function(once) -- this is the main loop of the program else next_timer_time = next_timer_time - (_currenttime - _timer); end - socket_sleep( _sleeptime ) -- wait some time - --collectgarbage( ) + + -- wait some time (0 by default) + socket_sleep( _sleeptime ) until quitting; if once and quitting == "once" then quitting = nil; return; end return "quitting" end -step = function () +local function step() return loop(true); end @@ -857,18 +907,22 @@ end --// EXPERIMENTAL //-- local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx ) - local handler = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx ) + local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx ) + if not handler then return nil, err end _socketlist[ socket ] = handler - _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) - if listeners.onconnect then - -- When socket is writeable, call onconnect - local _sendbuffer = handler.sendbuffer; - handler.sendbuffer = function () - handler.sendbuffer = _sendbuffer; - listeners.onconnect(handler); - -- If there was data with the incoming packet, handle it now. - if #handler:bufferqueue() > 0 then - return _sendbuffer(); + if not sslctx then + _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) + if listeners.onconnect then + -- When socket is writeable, call onconnect + local _sendbuffer = handler.sendbuffer; + handler.sendbuffer = function () + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ); + handler.sendbuffer = _sendbuffer; + listeners.onconnect(handler); + -- If there was data with the incoming packet, handle it now. + if #handler:bufferqueue() > 0 then + return _sendbuffer(); + end end end end @@ -883,9 +937,9 @@ local addclient = function( address, port, listeners, pattern, sslctx ) client:settimeout( 0 ) _, err = client:connect( address, port ) if err then -- try again - local handler = wrapclient( client, address, port, listeners ) + return wrapclient( client, address, port, listeners, pattern, sslctx ) else - wrapconnection( nil, listeners, client, address, port, "clientport", pattern, sslctx ) + return wrapconnection( nil, listeners, client, address, port, "clientport", pattern, sslctx ) end end @@ -900,28 +954,6 @@ use "setmetatable" ( _writetimes, { __mode = "k" } ) _timer = luasocket_gettime( ) _starttime = luasocket_gettime( ) -addtimer( function( ) - local difftime = os_difftime( _currenttime - _starttime ) - if difftime > _checkinterval then - _starttime = _currenttime - for handler, timestamp in pairs( _writetimes ) do - if os_difftime( _currenttime - timestamp ) > _sendtimeout then - --_writetimes[ handler ] = nil - handler.disconnect( )( handler, "send timeout" ) - handler:close( true ) -- forced disconnect - end - end - for handler, timestamp in pairs( _readtimes ) do - if os_difftime( _currenttime - timestamp ) > _readtimeout then - --_readtimes[ handler ] = nil - handler.disconnect( )( handler, "read timeout" ) - handler:close( ) -- forced disconnect? - end - end - end - end -) - local function setlogger(new_logger) local old_logger = log; if new_logger then @@ -933,15 +965,16 @@ end ----------------------------------// PUBLIC INTERFACE //-- return { + _addtimer = addtimer, addclient = addclient, wrapclient = wrapclient, - + loop = loop, link = link, + step = step, stats = stats, closeall = closeall, - addtimer = addtimer, addserver = addserver, getserver = getserver, setlogger = setlogger, diff --git a/net/xmppclient_listener.lua b/net/xmppclient_listener.lua deleted file mode 100644 index 4cc90cbf..00000000 --- a/net/xmppclient_listener.lua +++ /dev/null @@ -1,179 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - - - -local logger = require "logger"; -local log = logger.init("xmppclient_listener"); -local new_xmpp_stream = require "util.xmppstream".new; - -local connlisteners_register = require "net.connlisteners".register; - -local sessionmanager = require "core.sessionmanager"; -local sm_new_session, sm_destroy_session = sessionmanager.new_session, sessionmanager.destroy_session; -local sm_streamopened = sessionmanager.streamopened; -local sm_streamclosed = sessionmanager.streamclosed; -local st = require "util.stanza"; -local xpcall = xpcall; -local tostring = tostring; -local type = type; -local traceback = debug.traceback; - -local config = require "core.configmanager"; -local opt_keepalives = config.get("*", "core", "tcp_keepalives"); - -local stream_callbacks = { default_ns = "jabber:client", - streamopened = sm_streamopened, streamclosed = sm_streamclosed, handlestanza = core_process_stanza }; - -local xmlns_xmpp_streams = "urn:ietf:params:xml:ns:xmpp-streams"; - -function stream_callbacks.error(session, error, data) - if error == "no-stream" then - session.log("debug", "Invalid opening stream header"); - session:close("invalid-namespace"); - elseif error == "parse-error" then - (session.log or log)("debug", "Client XML parse error: %s", tostring(data)); - session:close("not-well-formed"); - elseif error == "stream-error" then - local condition, text = "undefined-condition"; - for child in data:children() do - if child.attr.xmlns == xmlns_xmpp_streams then - if child.name ~= "text" then - condition = child.name; - else - text = child:get_text(); - end - if condition ~= "undefined-condition" and text then - break; - end - end - end - text = condition .. (text and (" ("..text..")") or ""); - session.log("info", "Session closed by remote with error: %s", text); - session:close(nil, text); - end -end - -local function handleerr(err) log("error", "Traceback[c2s]: %s: %s", tostring(err), traceback()); end -function stream_callbacks.handlestanza(session, stanza) - stanza = session.filter("stanzas/in", stanza); - if stanza then - return xpcall(function () return core_process_stanza(session, stanza) end, handleerr); - end -end - -local sessions = {}; -local xmppclient = { default_port = 5222, default_mode = "*a" }; - --- These are session methods -- - -local stream_xmlns_attr = {xmlns='urn:ietf:params:xml:ns:xmpp-streams'}; -local default_stream_attr = { ["xmlns:stream"] = "http://etherx.jabber.org/streams", xmlns = stream_callbacks.default_ns, version = "1.0", id = "" }; -local function session_close(session, reason) - local log = session.log or log; - if session.conn then - if session.notopen then - session.send("<?xml version='1.0'?>"); - session.send(st.stanza("stream:stream", default_stream_attr):top_tag()); - end - if reason then - if type(reason) == "string" then -- assume stream error - log("info", "Disconnecting client, <stream:error> is: %s", reason); - session.send(st.stanza("stream:error"):tag(reason, {xmlns = 'urn:ietf:params:xml:ns:xmpp-streams' })); - elseif type(reason) == "table" then - if reason.condition then - local stanza = st.stanza("stream:error"):tag(reason.condition, stream_xmlns_attr):up(); - if reason.text then - stanza:tag("text", stream_xmlns_attr):text(reason.text):up(); - end - if reason.extra then - stanza:add_child(reason.extra); - end - log("info", "Disconnecting client, <stream:error> is: %s", tostring(stanza)); - session.send(stanza); - elseif reason.name then -- a stanza - log("info", "Disconnecting client, <stream:error> is: %s", tostring(reason)); - session.send(reason); - end - end - end - session.send("</stream:stream>"); - session.conn:close(); - xmppclient.ondisconnect(session.conn, (reason and (reason.text or reason.condition)) or reason or "session closed"); - end -end - - --- End of session methods -- - -function xmppclient.onconnect(conn) - local session = sm_new_session(conn); - sessions[conn] = session; - - session.log("info", "Client connected"); - - -- Client is using legacy SSL (otherwise mod_tls sets this flag) - if conn:ssl() then - session.secure = true; - end - - if opt_keepalives ~= nil then - conn:setoption("keepalive", opt_keepalives); - end - - session.close = session_close; - - local stream = new_xmpp_stream(session, stream_callbacks); - session.stream = stream; - - session.notopen = true; - - function session.reset_stream() - session.notopen = true; - session.stream:reset(); - end - - local filter = session.filter; - function session.data(data) - data = filter("bytes/in", data); - if data then - local ok, err = stream:feed(data); - if ok then return; end - log("debug", "Received invalid XML (%s) %d bytes: %s", tostring(err), #data, data:sub(1, 300):gsub("[\r\n]+", " "):gsub("[%z\1-\31]", "_")); - session:close("not-well-formed"); - end - end - - local handlestanza = stream_callbacks.handlestanza; - function session.dispatch_stanza(session, stanza) - return handlestanza(session, stanza); - end -end - -function xmppclient.onincoming(conn, data) - local session = sessions[conn]; - if session then - session.data(data); - end -end - -function xmppclient.ondisconnect(conn, err) - local session = sessions[conn]; - if session then - (session.log or log)("info", "Client disconnected: %s", err); - sm_destroy_session(session, err); - sessions[conn] = nil; - session = nil; - end -end - -function xmppclient.associate_session(conn, session) - sessions[conn] = session; -end - -connlisteners_register("xmppclient", xmppclient); diff --git a/net/xmppcomponent_listener.lua b/net/xmppcomponent_listener.lua deleted file mode 100644 index 90293559..00000000 --- a/net/xmppcomponent_listener.lua +++ /dev/null @@ -1,220 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - - -local hosts = _G.hosts; - -local t_concat = table.concat; -local tostring = tostring; -local type = type; -local pairs = pairs; - -local lxp = require "lxp"; -local logger = require "util.logger"; -local config = require "core.configmanager"; -local connlisteners = require "net.connlisteners"; -local uuid_gen = require "util.uuid".generate; -local jid_split = require "util.jid".split; -local sha1 = require "util.hashes".sha1; -local st = require "util.stanza"; -local new_xmpp_stream = require "util.xmppstream".new; - -local sessions = {}; - -local log = logger.init("componentlistener"); - -local component_listener = { default_port = 5347; default_mode = "*a"; default_interface = config.get("*", "core", "component_interface") or "127.0.0.1" }; - -local xmlns_component = 'jabber:component:accept'; - ---- Callbacks/data for xmppstream to handle streams for us --- - -local stream_callbacks = { default_ns = xmlns_component }; - -local xmlns_xmpp_streams = "urn:ietf:params:xml:ns:xmpp-streams"; - -function stream_callbacks.error(session, error, data, data2) - if session.destroyed then return; end - log("warn", "Error processing component stream: "..tostring(error)); - if error == "no-stream" then - session:close("invalid-namespace"); - elseif error == "parse-error" then - session.log("warn", "External component %s XML parse error: %s", tostring(session.host), tostring(data)); - session:close("not-well-formed"); - elseif error == "stream-error" then - local condition, text = "undefined-condition"; - for child in data:children() do - if child.attr.xmlns == xmlns_xmpp_streams then - if child.name ~= "text" then - condition = child.name; - else - text = child:get_text(); - end - if condition ~= "undefined-condition" and text then - break; - end - end - end - text = condition .. (text and (" ("..text..")") or ""); - session.log("info", "Session closed by remote with error: %s", text); - session:close(nil, text); - end -end - -function stream_callbacks.streamopened(session, attr) - if config.get(attr.to, "core", "component_module") ~= "component" then - -- Trying to act as a component domain which - -- hasn't been configured - session:close{ condition = "host-unknown", text = tostring(attr.to).." does not match any configured external components" }; - return; - end - - -- Note that we don't create the internal component - -- until after the external component auths successfully - - session.host = attr.to; - session.streamid = uuid_gen(); - session.notopen = nil; - - session.send(st.stanza("stream:stream", { xmlns=xmlns_component, - ["xmlns:stream"]='http://etherx.jabber.org/streams', id=session.streamid, from=session.host }):top_tag()); - -end - -function stream_callbacks.streamclosed(session) - session.log("debug", "Received </stream:stream>"); - session:close(); -end - -local core_process_stanza = core_process_stanza; - -function stream_callbacks.handlestanza(session, stanza) - -- Namespaces are icky. - if not stanza.attr.xmlns and stanza.name == "handshake" then - stanza.attr.xmlns = xmlns_component; - end - if not stanza.attr.xmlns or stanza.attr.xmlns == "jabber:client" then - local from = stanza.attr.from; - if from then - if session.component_validate_from then - local _, domain = jid_split(stanza.attr.from); - if domain ~= session.host then - -- Return error - session.log("warn", "Component sent stanza with missing or invalid 'from' address"); - session:close{ - condition = "invalid-from"; - text = "Component tried to send from address <"..tostring(from) - .."> which is not in domain <"..tostring(session.host)..">"; - }; - return; - end - end - else - stanza.attr.from = session.host; - end - if not stanza.attr.to then - session.log("warn", "Rejecting stanza with no 'to' address"); - session.send(st.error_reply(stanza, "modify", "bad-request", "Components MUST specify a 'to' address on stanzas")); - return; - end - end - return core_process_stanza(session, stanza); -end - ---- Closing a component connection -local stream_xmlns_attr = {xmlns='urn:ietf:params:xml:ns:xmpp-streams'}; -local default_stream_attr = { ["xmlns:stream"] = "http://etherx.jabber.org/streams", xmlns = stream_callbacks.default_ns, version = "1.0", id = "" }; -local function session_close(session, reason) - if session.destroyed then return; end - local log = session.log or log; - if session.conn then - if session.notopen then - session.send("<?xml version='1.0'?>"); - session.send(st.stanza("stream:stream", default_stream_attr):top_tag()); - end - if reason then - if type(reason) == "string" then -- assume stream error - log("info", "Disconnecting component, <stream:error> is: %s", reason); - session.send(st.stanza("stream:error"):tag(reason, {xmlns = 'urn:ietf:params:xml:ns:xmpp-streams' })); - elseif type(reason) == "table" then - if reason.condition then - local stanza = st.stanza("stream:error"):tag(reason.condition, stream_xmlns_attr):up(); - if reason.text then - stanza:tag("text", stream_xmlns_attr):text(reason.text):up(); - end - if reason.extra then - stanza:add_child(reason.extra); - end - log("info", "Disconnecting component, <stream:error> is: %s", tostring(stanza)); - session.send(stanza); - elseif reason.name then -- a stanza - log("info", "Disconnecting component, <stream:error> is: %s", tostring(reason)); - session.send(reason); - end - end - end - session.send("</stream:stream>"); - session.conn:close(); - component_listener.ondisconnect(session.conn, "stream error"); - end -end - ---- Component connlistener -function component_listener.onconnect(conn) - local _send = conn.write; - local session = { type = "component", conn = conn, send = function (data) return _send(conn, tostring(data)); end }; - - -- Logging functions -- - local conn_name = "jcp"..tostring(conn):match("[a-f0-9]+$"); - session.log = logger.init(conn_name); - session.close = session_close; - - session.log("info", "Incoming Jabber component connection"); - - local stream = new_xmpp_stream(session, stream_callbacks); - session.stream = stream; - - session.notopen = true; - - function session.reset_stream() - session.notopen = true; - session.stream:reset(); - end - - function session.data(conn, data) - local ok, err = stream:feed(data); - if ok then return; end - log("debug", "Received invalid XML (%s) %d bytes: %s", tostring(err), #data, data:sub(1, 300):gsub("[\r\n]+", " "):gsub("[%z\1-\31]", "_")); - session:close("not-well-formed"); - end - - session.dispatch_stanza = stream_callbacks.handlestanza; - - sessions[conn] = session; -end -function component_listener.onincoming(conn, data) - local session = sessions[conn]; - session.data(conn, data); -end -function component_listener.ondisconnect(conn, err) - local session = sessions[conn]; - if session then - (session.log or log)("info", "component disconnected: %s (%s)", tostring(session.host), tostring(err)); - if session.on_destroy then session:on_destroy(err); end - sessions[conn] = nil; - for k in pairs(session) do - if k ~= "log" and k ~= "close" then - session[k] = nil; - end - end - session.destroyed = true; - session = nil; - end -end - -connlisteners.register('xmppcomponent', component_listener); diff --git a/net/xmppserver_listener.lua b/net/xmppserver_listener.lua deleted file mode 100644 index 3af0b962..00000000 --- a/net/xmppserver_listener.lua +++ /dev/null @@ -1,209 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - - -local tostring = tostring; -local type = type; -local xpcall = xpcall; -local s_format = string.format; -local traceback = debug.traceback; - -local logger = require "logger"; -local log = logger.init("xmppserver_listener"); -local st = require "util.stanza"; -local connlisteners_register = require "net.connlisteners".register; -local new_xmpp_stream = require "util.xmppstream".new; -local s2s_new_incoming = require "core.s2smanager".new_incoming; -local s2s_streamopened = require "core.s2smanager".streamopened; -local s2s_streamclosed = require "core.s2smanager".streamclosed; -local s2s_destroy_session = require "core.s2smanager".destroy_session; -local s2s_attempt_connect = require "core.s2smanager".attempt_connection; -local stream_callbacks = { default_ns = "jabber:server", - streamopened = s2s_streamopened, streamclosed = s2s_streamclosed, handlestanza = core_process_stanza }; - -local xmlns_xmpp_streams = "urn:ietf:params:xml:ns:xmpp-streams"; - -function stream_callbacks.error(session, error, data) - if error == "no-stream" then - session:close("invalid-namespace"); - elseif error == "parse-error" then - session.log("debug", "Server-to-server XML parse error: %s", tostring(error)); - session:close("not-well-formed"); - elseif error == "stream-error" then - local condition, text = "undefined-condition"; - for child in data:children() do - if child.attr.xmlns == xmlns_xmpp_streams then - if child.name ~= "text" then - condition = child.name; - else - text = child:get_text(); - end - if condition ~= "undefined-condition" and text then - break; - end - end - end - text = condition .. (text and (" ("..text..")") or ""); - session.log("info", "Session closed by remote with error: %s", text); - session:close(nil, text); - end -end - -local function handleerr(err) log("error", "Traceback[s2s]: %s: %s", tostring(err), traceback()); end -function stream_callbacks.handlestanza(session, stanza) - if stanza.attr.xmlns == "jabber:client" then --COMPAT: Prosody pre-0.6.2 may send jabber:client - stanza.attr.xmlns = nil; - end - stanza = session.filter("stanzas/in", stanza); - if stanza then - return xpcall(function () return core_process_stanza(session, stanza) end, handleerr); - end -end - -local sessions = {}; -local xmppserver = { default_port = 5269, default_mode = "*a" }; - --- These are session methods -- - -local stream_xmlns_attr = {xmlns='urn:ietf:params:xml:ns:xmpp-streams'}; -local default_stream_attr = { ["xmlns:stream"] = "http://etherx.jabber.org/streams", xmlns = stream_callbacks.default_ns, version = "1.0", id = "" }; -local function session_close(session, reason, remote_reason) - local log = session.log or log; - if session.conn then - if session.notopen then - session.sends2s("<?xml version='1.0'?>"); - session.sends2s(st.stanza("stream:stream", default_stream_attr):top_tag()); - end - if reason then - if type(reason) == "string" then -- assume stream error - log("info", "Disconnecting %s[%s], <stream:error> is: %s", session.host or "(unknown host)", session.type, reason); - session.sends2s(st.stanza("stream:error"):tag(reason, {xmlns = 'urn:ietf:params:xml:ns:xmpp-streams' })); - elseif type(reason) == "table" then - if reason.condition then - local stanza = st.stanza("stream:error"):tag(reason.condition, stream_xmlns_attr):up(); - if reason.text then - stanza:tag("text", stream_xmlns_attr):text(reason.text):up(); - end - if reason.extra then - stanza:add_child(reason.extra); - end - log("info", "Disconnecting %s[%s], <stream:error> is: %s", session.host or "(unknown host)", session.type, tostring(stanza)); - session.sends2s(stanza); - elseif reason.name then -- a stanza - log("info", "Disconnecting %s->%s[%s], <stream:error> is: %s", session.from_host or "(unknown host)", session.to_host or "(unknown host)", session.type, tostring(reason)); - session.sends2s(reason); - end - end - end - session.sends2s("</stream:stream>"); - if session.notopen or not session.conn:close() then - session.conn:close(true); -- Force FIXME: timer? - end - session.conn:close(); - xmppserver.ondisconnect(session.conn, remote_reason or (reason and (reason.text or reason.condition)) or reason or "stream closed"); - end -end - - --- End of session methods -- - -local function initialize_session(session) - local stream = new_xmpp_stream(session, stream_callbacks); - session.stream = stream; - - session.notopen = true; - - function session.reset_stream() - session.notopen = true; - session.stream:reset(); - end - - local filter = session.filter; - function session.data(data) - data = filter("bytes/in", data); - if data then - local ok, err = stream:feed(data); - if ok then return; end - (session.log or log)("warn", "Received invalid XML: %s", data); - (session.log or log)("warn", "Problem was: %s", err); - session:close("not-well-formed"); - end - end - - session.close = session_close; - local handlestanza = stream_callbacks.handlestanza; - function session.dispatch_stanza(session, stanza) - return handlestanza(session, stanza); - end -end - -function xmppserver.onconnect(conn) - if not sessions[conn] then -- May be an existing outgoing session - local session = s2s_new_incoming(conn); - sessions[conn] = session; - - -- Logging functions -- - local conn_name = "s2sin"..tostring(conn):match("[a-f0-9]+$"); - session.log = logger.init(conn_name); - - session.log("info", "Incoming s2s connection"); - - initialize_session(session); - end -end - -function xmppserver.onincoming(conn, data) - local session = sessions[conn]; - if session then - session.data(data); - end -end - -function xmppserver.onstatus(conn, status) - if status == "ssl-handshake-complete" then - local session = sessions[conn]; - if session and session.direction == "outgoing" then - local to_host, from_host = session.to_host, session.from_host; - session.log("debug", "Sending stream header..."); - session.sends2s(s_format([[<stream:stream xmlns='jabber:server' xmlns:db='jabber:server:dialback' xmlns:stream='http://etherx.jabber.org/streams' from='%s' to='%s' version='1.0'>]], from_host, to_host)); - end - end -end - -function xmppserver.ondisconnect(conn, err) - local session = sessions[conn]; - if session then - if err and err ~= "closed" and session.srv_hosts then - (session.log or log)("debug", "s2s connection attempt failed: %s", err); - if s2s_attempt_connect(session, err) then - (session.log or log)("debug", "...so we're going to try another target"); - return; -- Session lives for now - end - end - (session.log or log)("info", "s2s disconnected: %s->%s (%s)", tostring(session.from_host), tostring(session.to_host), tostring(err or "closed")); - s2s_destroy_session(session, err); - sessions[conn] = nil; - session = nil; - end -end - -function xmppserver.register_outgoing(conn, session) - session.direction = "outgoing"; - sessions[conn] = session; - - initialize_session(session); -end - -connlisteners_register("xmppserver", xmppserver); - - --- We need to perform some initialisation when a connection is created --- We also need to perform that same initialisation at other points (SASL, TLS, ...) - --- ...and we need to handle data --- ...and record all sessions associated with connections diff --git a/plugins/adhoc/adhoc.lib.lua b/plugins/adhoc/adhoc.lib.lua index 0cb4efe1..b544ddc8 100644 --- a/plugins/adhoc/adhoc.lib.lua +++ b/plugins/adhoc/adhoc.lib.lua @@ -12,7 +12,7 @@ local states = {} local _M = {}; -function _cmdtag(desc, status, sessionid, action) +local function _cmdtag(desc, status, sessionid, action) local cmd = st.stanza("command", { xmlns = xmlns_cmd, node = desc.node, status = status }); if sessionid then cmd.attr.sessionid = sessionid; end if action then cmd.attr.action = action; end @@ -34,7 +34,7 @@ function _M.handle_cmd(command, origin, stanza) local data, state = command:handler(dataIn, states[sessionid]); states[sessionid] = state; - local stanza = st.reply(stanza); + local cmdtag; if data.status == "completed" then states[sessionid] = nil; cmdtag = command:cmdtag("completed", sessionid); @@ -43,11 +43,12 @@ function _M.handle_cmd(command, origin, stanza) cmdtag = command:cmdtag("canceled", sessionid); elseif data.status == "error" then states[sessionid] = nil; - stanza = st.error_reply(stanza, data.error.type, data.error.condition, data.error.message); - origin.send(stanza); + local reply = st.error_reply(stanza, data.error.type, data.error.condition, data.error.message); + origin.send(reply); return true; else cmdtag = command:cmdtag("executing", sessionid); + data.actions = data.actions or { "complete" }; end for name, content in pairs(data) do @@ -57,14 +58,14 @@ function _M.handle_cmd(command, origin, stanza) cmdtag:tag("note", {type="warn"}):text(content):up(); elseif name == "error" then cmdtag:tag("note", {type="error"}):text(content.message):up(); - elseif name =="actions" then - local actions = st.stanza("actions"); + elseif name == "actions" then + local actions = st.stanza("actions", { execute = content.default }); for _, action in ipairs(content) do if (action == "prev") or (action == "next") or (action == "complete") then actions:tag(action):up(); else - module:log("error", 'Command "'..command.name.. - '" at node "'..command.node..'" provided an invalid action "'..action..'"'); + module:log("error", "Command %q at node %q provided an invalid action %q", + command.name, command.node, action); end end cmdtag:add_child(actions); @@ -76,8 +77,9 @@ function _M.handle_cmd(command, origin, stanza) cmdtag:add_child(content); end end - stanza:add_child(cmdtag); - origin.send(stanza); + local reply = st.reply(stanza); + reply:add_child(cmdtag); + origin.send(reply); return true; end diff --git a/plugins/adhoc/mod_adhoc.lua b/plugins/adhoc/mod_adhoc.lua index 20c0f2be..f3e7f520 100644 --- a/plugins/adhoc/mod_adhoc.lua +++ b/plugins/adhoc/mod_adhoc.lua @@ -1,81 +1,86 @@ -- Copyright (C) 2009 Thilo Cestonaro --- Copyright (C) 2009-2010 Florian Zeitz +-- Copyright (C) 2009-2011 Florian Zeitz -- -- This file is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- local st = require "util.stanza"; +local keys = require "util.iterators".keys; +local array_collect = require "util.array".collect; local is_admin = require "core.usermanager".is_admin; +local jid_split = require "util.jid".split; local adhoc_handle_cmd = module:require "adhoc".handle_cmd; local xmlns_cmd = "http://jabber.org/protocol/commands"; -local xmlns_disco = "http://jabber.org/protocol/disco"; local commands = {}; module:add_feature(xmlns_cmd); -module:hook("iq/host/"..xmlns_disco.."#info:query", function (event) - local origin, stanza = event.origin, event.stanza; - local node = stanza.tags[1].attr.node; - if stanza.attr.type == "get" and node then - if commands[node] then - local privileged = is_admin(stanza.attr.from, stanza.attr.to); - if (commands[node].permission == "admin" and privileged) - or (commands[node].permission == "user") then - reply = st.reply(stanza); - reply:tag("query", { xmlns = xmlns_disco.."#info", - node = node }); - reply:tag("identity", { name = commands[node].name, - category = "automation", type = "command-node" }):up(); - reply:tag("feature", { var = xmlns_cmd }):up(); - reply:tag("feature", { var = "jabber:x:data" }):up(); - else - reply = st.error_reply(stanza, "auth", "forbidden", "This item is not available to you"); - end - origin.send(reply); - return true; - elseif node == xmlns_cmd then - reply = st.reply(stanza); - reply:tag("query", { xmlns = xmlns_disco.."#info", - node = node }); - reply:tag("identity", { name = "Ad-Hoc Commands", - category = "automation", type = "command-list" }):up(); - origin.send(reply); - return true; - +module:hook("host-disco-info-node", function (event) + local stanza, origin, reply, node = event.stanza, event.origin, event.reply, event.node; + if commands[node] then + local from = stanza.attr.from; + local privileged = is_admin(from, stanza.attr.to); + local global_admin = is_admin(from); + local username, hostname = jid_split(from); + local command = commands[node]; + if (command.permission == "admin" and privileged) + or (command.permission == "global_admin" and global_admin) + or (command.permission == "local_user" and hostname == module.host) + or (command.permission == "user") then + reply:tag("identity", { name = command.name, + category = "automation", type = "command-node" }):up(); + reply:tag("feature", { var = xmlns_cmd }):up(); + reply:tag("feature", { var = "jabber:x:data" }):up(); + event.exists = true; + else + return origin.send(st.error_reply(stanza, "auth", "forbidden", "This item is not available to you")); end + elseif node == xmlns_cmd then + reply:tag("identity", { name = "Ad-Hoc Commands", + category = "automation", type = "command-list" }):up(); + event.exists = true; end end); -module:hook("iq/host/"..xmlns_disco.."#items:query", function (event) - local origin, stanza = event.origin, event.stanza; - if stanza.attr.type == "get" and stanza.tags[1].attr.node - and stanza.tags[1].attr.node == xmlns_cmd then - local privileged = is_admin(stanza.attr.from, stanza.attr.to); - reply = st.reply(stanza); - reply:tag("query", { xmlns = xmlns_disco.."#items", - node = xmlns_cmd }); - for node, command in pairs(commands) do - if (command.permission == "admin" and privileged) - or (command.permission == "user") then - reply:tag("item", { name = command.name, - node = node, jid = module:get_host() }); - reply:up(); - end +module:hook("host-disco-items-node", function (event) + local stanza, origin, reply, node = event.stanza, event.origin, event.reply, event.node; + if node ~= xmlns_cmd then + return; + end + + local from = stanza.attr.from; + local admin = is_admin(from, stanza.attr.to); + local global_admin = is_admin(from); + local username, hostname = jid_split(from); + local nodes = array_collect(keys(commands)):sort(); + for _, node in ipairs(nodes) do + local command = commands[node]; + if (command.permission == "admin" and admin) + or (command.permission == "global_admin" and global_admin) + or (command.permission == "local_user" and hostname == module.host) + or (command.permission == "user") then + reply:tag("item", { name = command.name, + node = node, jid = module:get_host() }); + reply:up(); end - origin.send(reply); - return true; end -end, 500); + event.exists = true; +end); module:hook("iq/host/"..xmlns_cmd..":command", function (event) local origin, stanza = event.origin, event.stanza; if stanza.attr.type == "set" then local node = stanza.tags[1].attr.node - if commands[node] then - local privileged = is_admin(stanza.attr.from, stanza.attr.to); - if commands[node].permission == "admin" - and not privileged then + local command = commands[node]; + if command then + local from = stanza.attr.from; + local admin = is_admin(from, stanza.attr.to); + local global_admin = is_admin(from); + local username, hostname = jid_split(from); + if (command.permission == "admin" and not admin) + or (command.permission == "global_admin" and not global_admin) + or (command.permission == "local_user" and hostname ~= module.host) then origin.send(st.error_reply(stanza, "auth", "forbidden", "You don't have permission to execute this command"):up() :add_child(commands[node]:cmdtag("canceled") :tag("note", {type="error"}):text("You don't have permission to execute this command"))); @@ -87,19 +92,14 @@ module:hook("iq/host/"..xmlns_cmd..":command", function (event) end end, 500); -local function handle_item_added(item) +local function adhoc_added(event) + local item = event.item; commands[item.node] = item; end -module:hook("item-added/adhoc", function (event) - return handle_item_added(event.item); -end, 500); - -module:hook("item-removed/adhoc", function (event) +local function adhoc_removed(event) commands[event.item.node] = nil; -end, 500); - --- Pick up any items that are already added -for _, item in ipairs(module:get_host_items("adhoc")) do - handle_item_added(item); end + +module:handle_items("adhoc", adhoc_added, adhoc_removed); +module:handle_items("adhoc-provider", adhoc_added, adhoc_removed); diff --git a/plugins/mod_admin_adhoc.lua b/plugins/mod_admin_adhoc.lua index 984ae5ea..d5aaa0c4 100644 --- a/plugins/mod_admin_adhoc.lua +++ b/plugins/mod_admin_adhoc.lua @@ -1,4 +1,4 @@ --- Copyright (C) 2009-2010 Florian Zeitz +-- Copyright (C) 2009-2011 Florian Zeitz -- -- This file is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. @@ -10,179 +10,153 @@ local prosody = _G.prosody; local hosts = prosody.hosts; local t_concat = table.concat; -require "util.iterators"; +local module_host = module:get_host(); + +local keys = require "util.iterators".keys; local usermanager_user_exists = require "core.usermanager".user_exists; local usermanager_create_user = require "core.usermanager".create_user; +local usermanager_delete_user = require "core.usermanager".delete_user; local usermanager_get_password = require "core.usermanager".get_password; local usermanager_set_password = require "core.usermanager".set_password; -local is_admin = require "core.usermanager".is_admin; +local hostmanager_activate = require "core.hostmanager".activate; +local hostmanager_deactivate = require "core.hostmanager".deactivate; local rm_load_roster = require "core.rostermanager".load_roster; -local st, jid, uuid = require "util.stanza", require "util.jid", require "util.uuid"; +local st, jid = require "util.stanza", require "util.jid"; local timer_add_task = require "util.timer".add_task; local dataforms_new = require "util.dataforms".new; local array = require "util.array"; local modulemanager = require "modulemanager"; +local core_post_stanza = prosody.core_post_stanza; +local adhoc_simple = require "util.adhoc".new_simple_form; +local adhoc_initial = require "util.adhoc".new_initial_data_form; +module:depends("adhoc"); local adhoc_new = module:require "adhoc".new; -function add_user_command_handler(self, data, state) - local add_user_layout = dataforms_new{ - title = "Adding a User"; - instructions = "Fill out this form to add a user."; +local function generate_error_message(errors) + local errmsg = {}; + for name, err in pairs(errors) do + errmsg[#errmsg + 1] = name .. ": " .. err; + end + return { status = "completed", error = { message = t_concat(errmsg, "\n") } }; +end - { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; - { name = "accountjid", type = "jid-single", required = true, label = "The Jabber ID for the account to be added" }; - { name = "password", type = "text-private", label = "The password for this account" }; - { name = "password-verify", type = "text-private", label = "Retype password" }; - }; +-- Adding a new user +local add_user_layout = dataforms_new{ + title = "Adding a User"; + instructions = "Fill out this form to add a user."; - if state then - if data.action == "cancel" then - return { status = "canceled" }; - end - local fields = add_user_layout:data(data.form); - if not fields.accountjid then - return { status = "completed", error = { message = "You need to specify a JID." } }; - end - local username, host, resource = jid.split(fields.accountjid); - if data.to ~= host then - return { status = "completed", error = { message = "Trying to add a user on " .. host .. " but command was sent to " .. data.to}}; - end - if (fields["password"] == fields["password-verify"]) and username and host then - if usermanager_user_exists(username, host) then - return { status = "completed", error = { message = "Account already exists" } }; + { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; + { name = "accountjid", type = "jid-single", required = true, label = "The Jabber ID for the account to be added" }; + { name = "password", type = "text-private", label = "The password for this account" }; + { name = "password-verify", type = "text-private", label = "Retype password" }; +}; + +local add_user_command_handler = adhoc_simple(add_user_layout, function(fields, err) + if err then + return generate_error_message(err); + end + local username, host, resource = jid.split(fields.accountjid); + if module_host ~= host then + return { status = "completed", error = { message = "Trying to add a user on " .. host .. " but command was sent to " .. module_host}}; + end + if (fields["password"] == fields["password-verify"]) and username and host then + if usermanager_user_exists(username, host) then + return { status = "completed", error = { message = "Account already exists" } }; + else + if usermanager_create_user(username, fields.password, host) then + module:log("info", "Created new account %s@%s", username, host); + return { status = "completed", info = "Account successfully created" }; else - if usermanager_create_user(username, fields.password, host) then - module:log("info", "Created new account " .. username.."@"..host); - return { status = "completed", info = "Account successfully created" }; - else - return { status = "completed", error = { message = "Failed to write data to disk" } }; - end + return { status = "completed", error = { message = "Failed to write data to disk" } }; end - else - module:log("debug", (fields.accountjid or "<nil>") .. " " .. (fields.password or "<nil>") .. " " - .. (fields["password-verify"] or "<nil>")); - return { status = "completed", error = { message = "Invalid data.\nPassword mismatch, or empty username" } }; end else - return { status = "executing", form = add_user_layout }, "executing"; + module:log("debug", "Invalid data, password mismatch or empty username while creating account for %s", fields.accountjid or "<nil>"); + return { status = "completed", error = { message = "Invalid data.\nPassword mismatch, or empty username" } }; end -end +end); -function change_user_password_command_handler(self, data, state) - local change_user_password_layout = dataforms_new{ - title = "Changing a User Password"; - instructions = "Fill out this form to change a user's password."; +-- Changing a user's password +local change_user_password_layout = dataforms_new{ + title = "Changing a User Password"; + instructions = "Fill out this form to change a user's password."; - { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; - { name = "accountjid", type = "jid-single", required = true, label = "The Jabber ID for this account" }; - { name = "password", type = "text-private", required = true, label = "The password for this account" }; - }; + { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; + { name = "accountjid", type = "jid-single", required = true, label = "The Jabber ID for this account" }; + { name = "password", type = "text-private", required = true, label = "The password for this account" }; +}; - if state then - if data.action == "cancel" then - return { status = "canceled" }; - end - local fields = change_user_password_layout:data(data.form); - if not fields.accountjid or fields.accountjid == "" or not fields.password then - return { status = "completed", error = { message = "Please specify username and password" } }; - end - local username, host, resource = jid.split(fields.accountjid); - if data.to ~= host then - return { status = "completed", error = { message = "Trying to change the password of a user on " .. host .. " but command was sent to " .. data.to}}; - end - if usermanager_user_exists(username, host) and usermanager_set_password(username, fields.password, host) then - return { status = "completed", info = "Password successfully changed" }; - else - return { status = "completed", error = { message = "User does not exist" } }; - end +local change_user_password_command_handler = adhoc_simple(change_user_password_layout, function(fields, err) + if err then + return generate_error_message(err); + end + local username, host, resource = jid.split(fields.accountjid); + if module_host ~= host then + return { status = "completed", error = { message = "Trying to change the password of a user on " .. host .. " but command was sent to " .. module_host}}; + end + if usermanager_user_exists(username, host) and usermanager_set_password(username, fields.password, host) then + return { status = "completed", info = "Password successfully changed" }; + else + return { status = "completed", error = { message = "User does not exist" } }; + end +end); + +-- Reloading the config +local function config_reload_handler(self, data, state) + local ok, err = prosody.reload_config(); + if ok then + return { status = "completed", info = "Configuration reloaded (modules may need to be reloaded for this to have an effect)" }; else - return { status = "executing", form = change_user_password_layout }, "executing"; + return { status = "completed", error = { message = "Failed to reload config: " .. tostring(err) } }; end end -function delete_user_command_handler(self, data, state) - local delete_user_layout = dataforms_new{ - title = "Deleting a User"; - instructions = "Fill out this form to delete a user."; +-- Deleting a user's account +local delete_user_layout = dataforms_new{ + title = "Deleting a User"; + instructions = "Fill out this form to delete a user."; - { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; - { name = "accountjids", type = "jid-multi", label = "The Jabber ID(s) to delete" }; - }; + { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; + { name = "accountjids", type = "jid-multi", label = "The Jabber ID(s) to delete" }; +}; - if state then - if data.action == "cancel" then - return { status = "canceled" }; - end - local fields = delete_user_layout:data(data.form); - local failed = {}; - local succeeded = {}; - for _, aJID in ipairs(fields.accountjids) do - local username, host, resource = jid.split(aJID); - if (host == data.to) and usermanager_user_exists(username, host) and disconnect_user(aJID) and usermanager_create_user(username, nil, host) then - module:log("debug", "User " .. aJID .. " has been deleted"); - succeeded[#succeeded+1] = aJID; - else - module:log("debug", "Tried to delete non-existant user "..aJID); - failed[#failed+1] = aJID; - end +local delete_user_command_handler = adhoc_simple(delete_user_layout, function(fields, err) + if err then + return generate_error_message(err); + end + local failed = {}; + local succeeded = {}; + for _, aJID in ipairs(fields.accountjids) do + local username, host, resource = jid.split(aJID); + if (host == module_host) and usermanager_user_exists(username, host) and usermanager_delete_user(username, host) then + module:log("debug", "User %s has been deleted", aJID); + succeeded[#succeeded+1] = aJID; + else + module:log("debug", "Tried to delete non-existant user %s", aJID); + failed[#failed+1] = aJID; end - return {status = "completed", info = (#succeeded ~= 0 and - "The following accounts were successfully deleted:\n"..t_concat(succeeded, "\n").."\n" or "").. - (#failed ~= 0 and - "The following accounts could not be deleted:\n"..t_concat(failed, "\n") or "") }; - else - return { status = "executing", form = delete_user_layout }, "executing"; end -end - -function disconnect_user(match_jid) + return {status = "completed", info = (#succeeded ~= 0 and + "The following accounts were successfully deleted:\n"..t_concat(succeeded, "\n").."\n" or "").. + (#failed ~= 0 and + "The following accounts could not be deleted:\n"..t_concat(failed, "\n") or "") }; +end); + +-- Ending a user's session +local function disconnect_user(match_jid) local node, hostname, givenResource = jid.split(match_jid); local host = hosts[hostname]; local sessions = host.sessions[node] and host.sessions[node].sessions; for resource, session in pairs(sessions or {}) do if not givenResource or (resource == givenResource) then - module:log("debug", "Disconnecting "..node.."@"..hostname.."/"..resource); + module:log("debug", "Disconnecting %s@%s/%s", node, hostname, resource); session:close(); end end return true; end -function end_user_session_handler(self, data, state) - local end_user_session_layout = dataforms_new{ - title = "Ending a User Session"; - instructions = "Fill out this form to end a user's session."; - - { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; - { name = "accountjids", type = "jid-multi", label = "The Jabber ID(s) for which to end sessions" }; - }; - - if state then - if data.action == "cancel" then - return { status = "canceled" }; - end - - local fields = end_user_session_layout:data(data.form); - local failed = {}; - local succeeded = {}; - for _, aJID in ipairs(fields.accountjids) do - local username, host, resource = jid.split(aJID); - if (host == data.to) and usermanager_user_exists(username, host) and disconnect_user(aJID) then - succeeded[#succeeded+1] = aJID; - else - failed[#failed+1] = aJID; - end - end - return {status = "completed", info = (#succeeded ~= 0 and - "The following accounts were successfully disconnected:\n"..t_concat(succeeded, "\n").."\n" or "").. - (#failed ~= 0 and - "The following accounts could not be disconnected:\n"..t_concat(failed, "\n") or "") }; - else - return { status = "executing", form = end_user_session_layout }, "executing"; - end -end - local end_user_session_layout = dataforms_new{ title = "Ending a User Session"; instructions = "Fill out this form to end a user's session."; @@ -191,298 +165,374 @@ local end_user_session_layout = dataforms_new{ { name = "accountjids", type = "jid-multi", label = "The Jabber ID(s) for which to end sessions" }; }; +local end_user_session_handler = adhoc_simple(end_user_session_layout, function(fields, err) + if err then + return generate_error_message(err); + end + local failed = {}; + local succeeded = {}; + for _, aJID in ipairs(fields.accountjids) do + local username, host, resource = jid.split(aJID); + if (host == module_host) and usermanager_user_exists(username, host) and disconnect_user(aJID) then + succeeded[#succeeded+1] = aJID; + else + failed[#failed+1] = aJID; + end + end + return {status = "completed", info = (#succeeded ~= 0 and + "The following accounts were successfully disconnected:\n"..t_concat(succeeded, "\n").."\n" or "").. + (#failed ~= 0 and + "The following accounts could not be disconnected:\n"..t_concat(failed, "\n") or "") }; +end); -function get_user_password_handler(self, data, state) - local get_user_password_layout = dataforms_new{ - title = "Getting User's Password"; - instructions = "Fill out this form to get a user's password."; +-- Getting a user's password +local get_user_password_layout = dataforms_new{ + title = "Getting User's Password"; + instructions = "Fill out this form to get a user's password."; - { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; - { name = "accountjid", type = "jid-single", required = true, label = "The Jabber ID for which to retrieve the password" }; - }; + { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; + { name = "accountjid", type = "jid-single", required = true, label = "The Jabber ID for which to retrieve the password" }; +}; - local get_user_password_result_layout = dataforms_new{ - { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; - { name = "accountjid", type = "jid-single", label = "JID" }; - { name = "password", type = "text-single", label = "Password" }; - }; +local get_user_password_result_layout = dataforms_new{ + { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; + { name = "accountjid", type = "jid-single", label = "JID" }; + { name = "password", type = "text-single", label = "Password" }; +}; - if state then - if data.action == "cancel" then - return { status = "canceled" }; - end - local fields = get_user_password_layout:data(data.form); - if not fields.accountjid then - return { status = "completed", error = { message = "Please specify a JID." } }; - end - local user, host, resource = jid.split(fields.accountjid); - local accountjid = ""; - local password = ""; - if host ~= data.to then - return { status = "completed", error = { message = "Tried to get password for a user on " .. host .. " but command was sent to " .. data.to } }; - elseif usermanager_user_exists(user, host) then - accountjid = fields.accountjid; - password = usermanager_get_password(user, host); - else - return { status = "completed", error = { message = "User does not exist" } }; - end - return { status = "completed", result = { layout = get_user_password_result_layout, values = {accountjid = accountjid, password = password} } }; +local get_user_password_handler = adhoc_simple(get_user_password_layout, function(fields, err) + if err then + return generate_error_message(err); + end + local user, host, resource = jid.split(fields.accountjid); + local accountjid = ""; + local password = ""; + if host ~= module_host then + return { status = "completed", error = { message = "Tried to get password for a user on " .. host .. " but command was sent to " .. module_host } }; + elseif usermanager_user_exists(user, host) then + accountjid = fields.accountjid; + password = usermanager_get_password(user, host); else - return { status = "executing", form = get_user_password_layout }, "executing"; + return { status = "completed", error = { message = "User does not exist" } }; end -end + return { status = "completed", result = { layout = get_user_password_result_layout, values = {accountjid = accountjid, password = password} } }; +end); -function get_user_roster_handler(self, data, state) - local get_user_roster_layout = dataforms_new{ - { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; - { name = "accountjid", type = "jid-single", required = true, label = "The Jabber ID for which to retrieve the roster" }; - }; - - local get_user_roster_result_layout = dataforms_new{ - { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; - { name = "accountjid", type = "jid-single", label = "This is the roster for" }; - { name = "roster", type = "text-multi", label = "Roster XML" }; - }; - - if state then - if data.action == "cancel" then - return { status = "canceled" }; - end +-- Getting a user's roster +local get_user_roster_layout = dataforms_new{ + { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; + { name = "accountjid", type = "jid-single", required = true, label = "The Jabber ID for which to retrieve the roster" }; +}; - local fields = get_user_roster_layout:data(data.form); +local get_user_roster_result_layout = dataforms_new{ + { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; + { name = "accountjid", type = "jid-single", label = "This is the roster for" }; + { name = "roster", type = "text-multi", label = "Roster XML" }; +}; - if not fields.accountjid then - return { status = "completed", error = { message = "Please specify a JID" } }; - end +local get_user_roster_handler = adhoc_simple(get_user_roster_layout, function(fields, err) + if err then + return generate_error_message(err); + end - local user, host, resource = jid.split(fields.accountjid); - if host ~= data.to then - return { status = "completed", error = { message = "Tried to get roster for a user on " .. host .. " but command was sent to " .. data.to } }; - elseif not usermanager_user_exists(user, host) then - return { status = "completed", error = { message = "User does not exist" } }; - end - local roster = rm_load_roster(user, host); - - local query = st.stanza("query", { xmlns = "jabber:iq:roster" }); - for jid in pairs(roster) do - if jid ~= "pending" and jid then - query:tag("item", { - jid = jid, - subscription = roster[jid].subscription, - ask = roster[jid].ask, - name = roster[jid].name, - }); - for group in pairs(roster[jid].groups) do - query:tag("group"):text(group):up(); - end - query:up(); + local user, host, resource = jid.split(fields.accountjid); + if host ~= module_host then + return { status = "completed", error = { message = "Tried to get roster for a user on " .. host .. " but command was sent to " .. module_host } }; + elseif not usermanager_user_exists(user, host) then + return { status = "completed", error = { message = "User does not exist" } }; + end + local roster = rm_load_roster(user, host); + + local query = st.stanza("query", { xmlns = "jabber:iq:roster" }); + for jid in pairs(roster) do + if jid ~= "pending" and jid then + query:tag("item", { + jid = jid, + subscription = roster[jid].subscription, + ask = roster[jid].ask, + name = roster[jid].name, + }); + for group in pairs(roster[jid].groups) do + query:tag("group"):text(group):up(); end + query:up(); end - - local query_text = query:__tostring(); -- TODO: Use upcoming pretty_print() function - query_text = query_text:gsub("><", ">\n<"); - - local result = get_user_roster_result_layout:form({ accountjid = user.."@"..host, roster = query_text }, "result"); - result:add_child(query); - return { status = "completed", other = result }; - else - return { status = "executing", form = get_user_roster_layout }, "executing"; end -end -function get_user_stats_handler(self, data, state) - local get_user_stats_layout = dataforms_new{ - title = "Get User Statistics"; - instructions = "Fill out this form to gather user statistics."; + local query_text = tostring(query):gsub("><", ">\n<"); - { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; - { name = "accountjid", type = "jid-single", required = true, label = "The Jabber ID for statistics" }; - }; + local result = get_user_roster_result_layout:form({ accountjid = user.."@"..host, roster = query_text }, "result"); + result:add_child(query); + return { status = "completed", other = result }; +end); - local get_user_stats_result_layout = dataforms_new{ - { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; - { name = "ipaddresses", type = "text-multi", label = "IP Addresses" }; - { name = "rostersize", type = "text-single", label = "Roster size" }; - { name = "onlineresources", type = "text-multi", label = "Online Resources" }; - }; +-- Getting user statistics +local get_user_stats_layout = dataforms_new{ + title = "Get User Statistics"; + instructions = "Fill out this form to gather user statistics."; - if state then - if data.action == "cancel" then - return { status = "canceled" }; - end + { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; + { name = "accountjid", type = "jid-single", required = true, label = "The Jabber ID for statistics" }; +}; - local fields = get_user_stats_layout:data(data.form); +local get_user_stats_result_layout = dataforms_new{ + { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; + { name = "ipaddresses", type = "text-multi", label = "IP Addresses" }; + { name = "rostersize", type = "text-single", label = "Roster size" }; + { name = "onlineresources", type = "text-multi", label = "Online Resources" }; +}; - if not fields.accountjid then - return { status = "completed", error = { message = "Please specify a JID." } }; - end +local get_user_stats_handler = adhoc_simple(get_user_stats_layout, function(fields, err) + if err then + return generate_error_message(err); + end - local user, host, resource = jid.split(fields.accountjid); - if host ~= data.to then - return { status = "completed", error = { message = "Tried to get stats for a user on " .. host .. " but command was sent to " .. data.to } }; - elseif not usermanager_user_exists(user, host) then - return { status = "completed", error = { message = "User does not exist" } }; - end - local roster = rm_load_roster(user, host); - local rostersize = 0; - local IPs = ""; - local resources = ""; - for jid in pairs(roster) do - if jid ~= "pending" and jid then - rostersize = rostersize + 1; - end - end - for resource, session in pairs((hosts[host].sessions[user] and hosts[host].sessions[user].sessions) or {}) do - resources = resources .. "\n" .. resource; - IPs = IPs .. "\n" .. session.ip; + local user, host, resource = jid.split(fields.accountjid); + if host ~= module_host then + return { status = "completed", error = { message = "Tried to get stats for a user on " .. host .. " but command was sent to " .. module_host } }; + elseif not usermanager_user_exists(user, host) then + return { status = "completed", error = { message = "User does not exist" } }; + end + local roster = rm_load_roster(user, host); + local rostersize = 0; + local IPs = ""; + local resources = ""; + for jid in pairs(roster) do + if jid ~= "pending" and jid then + rostersize = rostersize + 1; end - return { status = "completed", result = {layout = get_user_stats_result_layout, values = {ipaddresses = IPs, rostersize = tostring(rostersize), - onlineresources = resources}} }; - else - return { status = "executing", form = get_user_stats_layout }, "executing"; end -end - -function get_online_users_command_handler(self, data, state) - local get_online_users_layout = dataforms_new{ - title = "Getting List of Online Users"; - instructions = "How many users should be returned at most?"; + for resource, session in pairs((hosts[host].sessions[user] and hosts[host].sessions[user].sessions) or {}) do + resources = resources .. "\n" .. resource; + IPs = IPs .. "\n" .. session.ip; + end + return { status = "completed", result = {layout = get_user_stats_result_layout, values = {ipaddresses = IPs, rostersize = tostring(rostersize), + onlineresources = resources}} }; +end); - { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; - { name = "max_items", type = "list-single", label = "Maximum number of users", - value = { "25", "50", "75", "100", "150", "200", "all" } }; - { name = "details", type = "boolean", label = "Show details" }; - }; +-- Getting a list of online users +local get_online_users_layout = dataforms_new{ + title = "Getting List of Online Users"; + instructions = "How many users should be returned at most?"; - local get_online_users_result_layout = dataforms_new{ - { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; - { name = "onlineuserjids", type = "text-multi", label = "The list of all online users" }; - }; + { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; + { name = "max_items", type = "list-single", label = "Maximum number of users", + value = { "25", "50", "75", "100", "150", "200", "all" } }; + { name = "details", type = "boolean", label = "Show details" }; +}; - if state then - if data.action == "cancel" then - return { status = "canceled" }; - end +local get_online_users_result_layout = dataforms_new{ + { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; + { name = "onlineuserjids", type = "text-multi", label = "The list of all online users" }; +}; - local fields = get_online_users_layout:data(data.form); +local get_online_users_command_handler = adhoc_simple(get_online_users_layout, function(fields, err) + if err then + return generate_error_message(err); + end - local max_items = nil - if fields.max_items ~= "all" then - max_items = tonumber(fields.max_items); - end - local count = 0; - local users = {}; - for username, user in pairs(hosts[data.to].sessions or {}) do - if (max_items ~= nil) and (count >= max_items) then - break; - end - users[#users+1] = username.."@"..data.to; - count = count + 1; - if fields.details then - for resource, session in pairs(user.sessions or {}) do - local status, priority = "unavailable", tostring(session.priority or "-"); - if session.presence then - status = session.presence:child_with_name("show"); - if status then - status = status:get_text() or "[invalid!]"; - else - status = "available"; - end + local max_items = nil + if fields.max_items ~= "all" then + max_items = tonumber(fields.max_items); + end + local count = 0; + local users = {}; + for username, user in pairs(hosts[module_host].sessions or {}) do + if (max_items ~= nil) and (count >= max_items) then + break; + end + users[#users+1] = username.."@"..module_host; + count = count + 1; + if fields.details then + for resource, session in pairs(user.sessions or {}) do + local status, priority = "unavailable", tostring(session.priority or "-"); + if session.presence then + status = session.presence:child_with_name("show"); + if status then + status = status:get_text() or "[invalid!]"; + else + status = "available"; end - users[#users+1] = " - "..resource..": "..status.."("..priority..")"; end + users[#users+1] = " - "..resource..": "..status.."("..priority..")"; end end - return { status = "completed", result = {layout = get_online_users_result_layout, values = {onlineuserjids=t_concat(users, "\n")}} }; - else - return { status = "executing", form = get_online_users_layout }, "executing"; end + return { status = "completed", result = {layout = get_online_users_result_layout, values = {onlineuserjids=t_concat(users, "\n")}} }; +end); + +-- Getting a list of loaded modules +local list_modules_result = dataforms_new { + title = "List of loaded modules"; + + { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/modules#list" }; + { name = "modules", type = "text-multi", label = "The following modules are loaded:" }; +}; + +local function list_modules_handler(self, data, state) + local modules = array.collect(keys(hosts[module_host].modules)):sort():concat("\n"); + return { status = "completed", result = { layout = list_modules_result; values = { modules = modules } } }; end -function list_modules_handler(self, data, state) - local result = dataforms_new { - title = "List of loaded modules"; +-- Loading a module +local load_module_layout = dataforms_new { + title = "Load module"; + instructions = "Specify the module to be loaded"; - { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/modules#list" }; - { name = "modules", type = "text-multi", label = "The following modules are loaded:" }; - }; + { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/modules#load" }; + { name = "module", type = "text-single", required = true, label = "Module to be loaded:"}; +}; - local modules = array.collect(keys(hosts[data.to].modules)):sort():concat("\n"); +local load_module_handler = adhoc_simple(load_module_layout, function(fields, err) + if err then + return generate_error_message(err); + end + if modulemanager.is_loaded(module_host, fields.module) then + return { status = "completed", info = "Module already loaded" }; + end + local ok, err = modulemanager.load(module_host, fields.module); + if ok then + return { status = "completed", info = 'Module "'..fields.module..'" successfully loaded on host "'..module_host..'".' }; + else + return { status = "completed", error = { message = 'Failed to load module "'..fields.module..'" on host "'..module_host.. + '". Error was: "'..tostring(err or "<unspecified>")..'"' } }; + end +end); - return { status = "completed", result = { layout = result; values = { modules = modules } } }; -end +-- Globally loading a module +local globally_load_module_layout = dataforms_new { + title = "Globally load module"; + instructions = "Specify the module to be loaded on all hosts"; -function load_module_handler(self, data, state) - local layout = dataforms_new { - title = "Load module"; - instructions = "Specify the module to be loaded"; + { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/modules#global-load" }; + { name = "module", type = "text-single", required = true, label = "Module to globally load:"}; +}; - { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/modules#load" }; - { name = "module", type = "text-single", required = true, label = "Module to be loaded:"}; - }; - if state then - if data.action == "cancel" then - return { status = "canceled" }; - end - local fields = layout:data(data.form); - if (not fields.module) or (fields.module == "") then - return { status = "completed", error = { - message = "Please specify a module." - } }; - end - if modulemanager.is_loaded(data.to, fields.module) then - return { status = "completed", info = "Module already loaded" }; +local globally_load_module_handler = adhoc_simple(globally_load_module_layout, function(fields, err) + local ok_list, err_list = {}, {}; + + if err then + return generate_error_message(err); + end + + local ok, err = modulemanager.load(module_host, fields.module); + if ok then + ok_list[#ok_list + 1] = module_host; + else + err_list[#err_list + 1] = module_host .. " (Error: " .. tostring(err) .. ")"; + end + + -- Is this a global module? + if modulemanager.is_loaded("*", fields.module) and not modulemanager.is_loaded(module_host, fields.module) then + return { status = "completed", info = 'Global module '..fields.module..' loaded.' }; + end + + -- This is either a shared or "normal" module, load it on all other hosts + for host_name, host in pairs(hosts) do + if host_name ~= module_host and host.type == "local" then + local ok, err = modulemanager.load(host_name, fields.module); + if ok then + ok_list[#ok_list + 1] = host_name; + else + err_list[#err_list + 1] = host_name .. " (Error: " .. tostring(err) .. ")"; + end end - local ok, err = modulemanager.load(data.to, fields.module); + end + + local info = (#ok_list > 0 and ("The module "..fields.module.." was successfully loaded onto the hosts:\n"..t_concat(ok_list, "\n")) or "") + .. ((#ok_list > 0 and #err_list > 0) and "\n" or "") .. + (#err_list > 0 and ("Failed to load the module "..fields.module.." onto the hosts:\n"..t_concat(err_list, "\n")) or ""); + return { status = "completed", info = info }; +end); + +-- Reloading modules +local reload_modules_layout = dataforms_new { + title = "Reload modules"; + instructions = "Select the modules to be reloaded"; + + { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/modules#reload" }; + { name = "modules", type = "list-multi", required = true, label = "Modules to be reloaded:"}; +}; + +local reload_modules_handler = adhoc_initial(reload_modules_layout, function() + return { modules = array.collect(keys(hosts[module_host].modules)):sort() }; +end, function(fields, err) + if err then + return generate_error_message(err); + end + local ok_list, err_list = {}, {}; + for _, module in ipairs(fields.modules) do + local ok, err = modulemanager.reload(module_host, module); if ok then - return { status = "completed", info = 'Module "'..fields.module..'" successfully loaded on host "'..data.to..'".' }; + ok_list[#ok_list + 1] = module; else - return { status = "completed", error = { message = 'Failed to load module "'..fields.module..'" on host "'..data.to.. - '". Error was: "'..tostring(err or "<unspecified>")..'"' } }; + err_list[#err_list + 1] = module .. "(Error: " .. tostring(err) .. ")"; end - else - local modules = array.collect(keys(hosts[data.to].modules)):sort(); - return { status = "executing", form = layout }, "executing"; end -end + local info = (#ok_list > 0 and ("The following modules were successfully reloaded on host "..module_host..":\n"..t_concat(ok_list, "\n")) or "") + .. ((#ok_list > 0 and #err_list > 0) and "\n" or "") .. + (#err_list > 0 and ("Failed to reload the following modules on host "..module_host..":\n"..t_concat(err_list, "\n")) or ""); + return { status = "completed", info = info }; +end); + +-- Globally reloading a module +local globally_reload_module_layout = dataforms_new { + title = "Globally reload module"; + instructions = "Specify the module to reload on all hosts"; + + { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/modules#global-reload" }; + { name = "module", type = "list-single", required = true, label = "Module to globally reload:"}; +}; -function reload_modules_handler(self, data, state) - local layout = dataforms_new { - title = "Reload modules"; - instructions = "Select the modules to be reloaded"; +local globally_reload_module_handler = adhoc_initial(globally_reload_module_layout, function() + local loaded_modules = array(keys(modulemanager.get_modules("*"))); + for _, host in pairs(hosts) do + loaded_modules:append(array(keys(host.modules))); + end + loaded_modules = array(set.new(loaded_modules):items()):sort(); + return { module = loaded_modules }; +end, function(fields, err) + local is_global = false; - { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/modules#reload" }; - { name = "modules", type = "list-multi", required = true, label = "Modules to be reloaded:"}; - }; - if state then - if data.action == "cancel" then - return { status = "canceled" }; - end - local fields = layout:data(data.form); - if #fields.modules == 0 then - return { status = "completed", error = { - message = "Please specify a module. (This means your client misbehaved, as this field is required)" - } }; + if err then + return generate_error_message(err); + end + + if modulemanager.is_loaded("*", fields.module) then + local ok, err = modulemanager.reload("*", fields.module); + if not ok then + return { status = "completed", info = 'Global module '..fields.module..' failed to reload: '..err }; end - local ok_list, err_list = {}, {}; - for _, module in ipairs(fields.modules) do - local ok, err = modulemanager.reload(data.to, module); + is_global = true; + end + + local ok_list, err_list = {}, {}; + for host_name, host in pairs(hosts) do + if modulemanager.is_loaded(host_name, fields.module) then + local ok, err = modulemanager.reload(host_name, fields.module); if ok then - ok_list[#ok_list + 1] = module; + ok_list[#ok_list + 1] = host_name; else - err_list[#err_list + 1] = module .. "(Error: " .. tostring(err) .. ")"; + err_list[#err_list + 1] = host_name .. " (Error: " .. tostring(err) .. ")"; end end - local info = (#ok_list > 0 and ("The following modules were successfully reloaded on host "..data.to..":\n"..t_concat(ok_list, "\n")) or "").. - (#err_list > 0 and ("Failed to reload the following modules on host "..data.to..":\n"..t_concat(err_list, "\n")) or ""); - return { status = "completed", info = info }; - else - local modules = array.collect(keys(hosts[data.to].modules)):sort(); - return { status = "executing", form = { layout = layout; values = { modules = modules } } }, "executing"; end -end -function send_to_online(message, server) + if #ok_list == 0 and #err_list == 0 then + if is_global then + return { status = "completed", info = 'Successfully reloaded global module '..fields.module }; + else + return { status = "completed", info = 'Module '..fields.module..' not loaded on any host.' }; + end + end + + local info = (#ok_list > 0 and ("The module "..fields.module.." was successfully reloaded on the hosts:\n"..t_concat(ok_list, "\n")) or "") + .. ((#ok_list > 0 and #err_list > 0) and "\n" or "") .. + (#err_list > 0 and ("Failed to reload the module "..fields.module.." on the hosts:\n"..t_concat(err_list, "\n")) or ""); + return { status = "completed", info = info }; +end); + +local function send_to_online(message, server) if server then sessions = { [server] = hosts[server] }; else @@ -502,108 +552,208 @@ function send_to_online(message, server) return c; end -function shut_down_service_handler(self, data, state) - local shut_down_service_layout = dataforms_new{ - title = "Shutting Down the Service"; - instructions = "Fill out this form to shut down the service."; - - { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; - { name = "delay", type = "list-single", label = "Time delay before shutting down", - value = { {label = "30 seconds", value = "30"}, - {label = "60 seconds", value = "60"}, - {label = "90 seconds", value = "90"}, - {label = "2 minutes", value = "120"}, - {label = "3 minutes", value = "180"}, - {label = "4 minutes", value = "240"}, - {label = "5 minutes", value = "300"}, - }; +-- Shutting down the service +local shut_down_service_layout = dataforms_new{ + title = "Shutting Down the Service"; + instructions = "Fill out this form to shut down the service."; + + { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; + { name = "delay", type = "list-single", label = "Time delay before shutting down", + value = { {label = "30 seconds", value = "30"}, + {label = "60 seconds", value = "60"}, + {label = "90 seconds", value = "90"}, + {label = "2 minutes", value = "120"}, + {label = "3 minutes", value = "180"}, + {label = "4 minutes", value = "240"}, + {label = "5 minutes", value = "300"}, }; - { name = "announcement", type = "text-multi", label = "Announcement" }; }; + { name = "announcement", type = "text-multi", label = "Announcement" }; +}; - if state then - if data.action == "cancel" then - return { status = "canceled" }; - end +local shut_down_service_handler = adhoc_simple(shut_down_service_layout, function(fields, err) + if err then + return generate_error_message(err); + end - local fields = shut_down_service_layout:data(data.form); + if fields.announcement and #fields.announcement > 0 then + local message = st.message({type = "headline"}, fields.announcement):up() + :tag("subject"):text("Server is shutting down"); + send_to_online(message); + end - if fields.announcement and #fields.announcement > 0 then - local message = st.message({type = "headline"}, fields.announcement):up() - :tag("subject"):text("Server is shutting down"); - send_to_online(message); - end + timer_add_task(tonumber(fields.delay or "5"), function(time) prosody.shutdown("Shutdown by adhoc command") end); - timer_add_task(tonumber(fields.delay or "5"), prosody.shutdown); + return { status = "completed", info = "Server is about to shut down" }; +end); - return { status = "completed", info = "Server is about to shut down" }; - else - return { status = "executing", form = shut_down_service_layout }, "executing"; - end +-- Unloading modules +local unload_modules_layout = dataforms_new { + title = "Unload modules"; + instructions = "Select the modules to be unloaded"; - return true; -end - -function unload_modules_handler(self, data, state) - local layout = dataforms_new { - title = "Unload modules"; - instructions = "Select the modules to be unloaded"; + { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/modules#unload" }; + { name = "modules", type = "list-multi", required = true, label = "Modules to be unloaded:"}; +}; - { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/modules#unload" }; - { name = "modules", type = "list-multi", required = true, label = "Modules to be unloaded:"}; - }; - if state then - if data.action == "cancel" then - return { status = "canceled" }; +local unload_modules_handler = adhoc_initial(unload_modules_layout, function() + return { modules = array.collect(keys(hosts[module_host].modules)):sort() }; +end, function(fields, err) + if err then + return generate_error_message(err); + end + local ok_list, err_list = {}, {}; + for _, module in ipairs(fields.modules) do + local ok, err = modulemanager.unload(module_host, module); + if ok then + ok_list[#ok_list + 1] = module; + else + err_list[#err_list + 1] = module .. "(Error: " .. tostring(err) .. ")"; end - local fields = layout:data(data.form); - if #fields.modules == 0 then - return { status = "completed", error = { - message = "Please specify a module. (This means your client misbehaved, as this field is required)" - } }; + end + local info = (#ok_list > 0 and ("The following modules were successfully unloaded on host "..module_host..":\n"..t_concat(ok_list, "\n")) or "") + .. ((#ok_list > 0 and #err_list > 0) and "\n" or "") .. + (#err_list > 0 and ("Failed to unload the following modules on host "..module_host..":\n"..t_concat(err_list, "\n")) or ""); + return { status = "completed", info = info }; +end); + +-- Globally unloading a module +local globally_unload_module_layout = dataforms_new { + title = "Globally unload module"; + instructions = "Specify a module to unload on all hosts"; + + { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/modules#global-unload" }; + { name = "module", type = "list-single", required = true, label = "Module to globally unload:"}; +}; + +local globally_unload_module_handler = adhoc_initial(globally_unload_module_layout, function() + local loaded_modules = array(keys(modulemanager.get_modules("*"))); + for _, host in pairs(hosts) do + loaded_modules:append(array(keys(host.modules))); + end + loaded_modules = array(set.new(loaded_modules):items()):sort(); + return { module = loaded_modules }; +end, function(fields, err) + local is_global = false; + if err then + return generate_error_message(err); + end + + if modulemanager.is_loaded("*", fields.module) then + local ok, err = modulemanager.unload("*", fields.module); + if not ok then + return { status = "completed", info = 'Global module '..fields.module..' failed to unload: '..err }; end - local ok_list, err_list = {}, {}; - for _, module in ipairs(fields.modules) do - local ok, err = modulemanager.unload(data.to, module); + is_global = true; + end + + local ok_list, err_list = {}, {}; + for host_name, host in pairs(hosts) do + if modulemanager.is_loaded(host_name, fields.module) then + local ok, err = modulemanager.unload(host_name, fields.module); if ok then - ok_list[#ok_list + 1] = module; + ok_list[#ok_list + 1] = host_name; else - err_list[#err_list + 1] = module .. "(Error: " .. tostring(err) .. ")"; + err_list[#err_list + 1] = host_name .. " (Error: " .. tostring(err) .. ")"; end end - local info = (#ok_list > 0 and ("The following modules were successfully unloaded on host "..data.to..":\n"..t_concat(ok_list, "\n")) or "").. - (#err_list > 0 and ("Failed to unload the following modules on host "..data.to..":\n"..t_concat(err_list, "\n")) or ""); - return { status = "completed", info = info }; + end + + if #ok_list == 0 and #err_list == 0 then + if is_global then + return { status = "completed", info = 'Successfully unloaded global module '..fields.module }; + else + return { status = "completed", info = 'Module '..fields.module..' not loaded on any host.' }; + end + end + + local info = (#ok_list > 0 and ("The module "..fields.module.." was successfully unloaded on the hosts:\n"..t_concat(ok_list, "\n")) or "") + .. ((#ok_list > 0 and #err_list > 0) and "\n" or "") .. + (#err_list > 0 and ("Failed to unload the module "..fields.module.." on the hosts:\n"..t_concat(err_list, "\n")) or ""); + return { status = "completed", info = info }; +end); + +-- Activating a host +local activate_host_layout = dataforms_new { + title = "Activate host"; + instructions = ""; + + { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/hosts#activate" }; + { name = "host", type = "text-single", required = true, label = "Host:"}; +}; + +local activate_host_handler = adhoc_simple(activate_host_layout, function(fields, err) + if err then + return generate_error_message(err); + end + local ok, err = hostmanager_activate(fields.host); + + if ok then + return { status = "completed", info = fields.host .. " activated" }; else - local modules = array.collect(keys(hosts[data.to].modules)):sort(); - return { status = "executing", form = { layout = layout; values = { modules = modules } } }, "executing"; + return { status = "canceled", error = err } end -end +end); + +-- Deactivating a host +local deactivate_host_layout = dataforms_new { + title = "Deactivate host"; + instructions = ""; + + { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/hosts#activate" }; + { name = "host", type = "text-single", required = true, label = "Host:"}; +}; + +local deactivate_host_handler = adhoc_simple(deactivate_host_layout, function(fields, err) + if err then + return generate_error_message(err); + end + local ok, err = hostmanager_deactivate(fields.host); + + if ok then + return { status = "completed", info = fields.host .. " deactivated" }; + else + return { status = "canceled", error = err } + end +end); + local add_user_desc = adhoc_new("Add User", "http://jabber.org/protocol/admin#add-user", add_user_command_handler, "admin"); local change_user_password_desc = adhoc_new("Change User Password", "http://jabber.org/protocol/admin#change-user-password", change_user_password_command_handler, "admin"); +local config_reload_desc = adhoc_new("Reload configuration", "http://prosody.im/protocol/config#reload", config_reload_handler, "global_admin"); local delete_user_desc = adhoc_new("Delete User", "http://jabber.org/protocol/admin#delete-user", delete_user_command_handler, "admin"); local end_user_session_desc = adhoc_new("End User Session", "http://jabber.org/protocol/admin#end-user-session", end_user_session_handler, "admin"); local get_user_password_desc = adhoc_new("Get User Password", "http://jabber.org/protocol/admin#get-user-password", get_user_password_handler, "admin"); local get_user_roster_desc = adhoc_new("Get User Roster","http://jabber.org/protocol/admin#get-user-roster", get_user_roster_handler, "admin"); local get_user_stats_desc = adhoc_new("Get User Statistics","http://jabber.org/protocol/admin#user-stats", get_user_stats_handler, "admin"); -local get_online_users_desc = adhoc_new("Get List of Online Users", "http://jabber.org/protocol/admin#get-online-users", get_online_users_command_handler, "admin"); +local get_online_users_desc = adhoc_new("Get List of Online Users", "http://jabber.org/protocol/admin#get-online-users-list", get_online_users_command_handler, "admin"); local list_modules_desc = adhoc_new("List loaded modules", "http://prosody.im/protocol/modules#list", list_modules_handler, "admin"); local load_module_desc = adhoc_new("Load module", "http://prosody.im/protocol/modules#load", load_module_handler, "admin"); +local globally_load_module_desc = adhoc_new("Globally load module", "http://prosody.im/protocol/modules#global-load", globally_load_module_handler, "global_admin"); local reload_modules_desc = adhoc_new("Reload modules", "http://prosody.im/protocol/modules#reload", reload_modules_handler, "admin"); -local shut_down_service_desc = adhoc_new("Shut Down Service", "http://jabber.org/protocol/admin#shutdown", shut_down_service_handler, "admin"); +local globally_reload_module_desc = adhoc_new("Globally reload module", "http://prosody.im/protocol/modules#global-reload", globally_reload_module_handler, "global_admin"); +local shut_down_service_desc = adhoc_new("Shut Down Service", "http://jabber.org/protocol/admin#shutdown", shut_down_service_handler, "global_admin"); local unload_modules_desc = adhoc_new("Unload modules", "http://prosody.im/protocol/modules#unload", unload_modules_handler, "admin"); - -module:add_item("adhoc", add_user_desc); -module:add_item("adhoc", change_user_password_desc); -module:add_item("adhoc", delete_user_desc); -module:add_item("adhoc", end_user_session_desc); -module:add_item("adhoc", get_user_password_desc); -module:add_item("adhoc", get_user_roster_desc); -module:add_item("adhoc", get_user_stats_desc); -module:add_item("adhoc", get_online_users_desc); -module:add_item("adhoc", list_modules_desc); -module:add_item("adhoc", load_module_desc); -module:add_item("adhoc", reload_modules_desc); -module:add_item("adhoc", shut_down_service_desc); -module:add_item("adhoc", unload_modules_desc); +local globally_unload_module_desc = adhoc_new("Globally unload module", "http://prosody.im/protocol/modules#global-unload", globally_unload_module_handler, "global_admin"); +local activate_host_desc = adhoc_new("Activate host", "http://prosody.im/protocol/hosts#activate", activate_host_handler, "global_admin"); +local deactivate_host_desc = adhoc_new("Deactivate host", "http://prosody.im/protocol/hosts#deactivate", deactivate_host_handler, "global_admin"); + +module:provides("adhoc", add_user_desc); +module:provides("adhoc", change_user_password_desc); +module:provides("adhoc", config_reload_desc); +module:provides("adhoc", delete_user_desc); +module:provides("adhoc", end_user_session_desc); +module:provides("adhoc", get_user_password_desc); +module:provides("adhoc", get_user_roster_desc); +module:provides("adhoc", get_user_stats_desc); +module:provides("adhoc", get_online_users_desc); +module:provides("adhoc", list_modules_desc); +module:provides("adhoc", load_module_desc); +module:provides("adhoc", globally_load_module_desc); +module:provides("adhoc", reload_modules_desc); +module:provides("adhoc", globally_reload_module_desc); +module:provides("adhoc", shut_down_service_desc); +module:provides("adhoc", unload_modules_desc); +module:provides("adhoc", globally_unload_module_desc); +module:provides("adhoc", activate_host_desc); +module:provides("adhoc", deactivate_host_desc); diff --git a/plugins/mod_admin_telnet.lua b/plugins/mod_admin_telnet.lua index da40f57e..e13d27c2 100644 --- a/plugins/mod_admin_telnet.lua +++ b/plugins/mod_admin_telnet.lua @@ -1,38 +1,45 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- -module.host = "*"; +module:set_global(); + +local hostmanager = require "core.hostmanager"; +local modulemanager = require "core.modulemanager"; +local s2smanager = require "core.s2smanager"; +local portmanager = require "core.portmanager"; local _G = _G; local prosody = _G.prosody; local hosts = prosody.hosts; -local connlisteners_register = require "net.connlisteners".register; -local console_listener = { default_port = 5582; default_mode = "*l"; default_interface = "127.0.0.1" }; +local console_listener = { default_port = 5582; default_mode = "*a"; interface = "127.0.0.1" }; -require "util.iterators"; -local jid_bare = require "util.jid".bare; +local iterators = require "util.iterators"; +local keys, values = iterators.keys, iterators.values; +local jid = require "util.jid"; +local jid_bare, jid_split = jid.bare, jid.split; local set, array = require "util.set", require "util.array"; local cert_verify_identity = require "util.x509".verify_identity; +local envload = require "util.envload".envload; +local envloadfile = require "util.envload".envloadfile; -local commands = {}; -local def_env = {}; +local commands = module:shared("commands") +local def_env = module:shared("env"); local default_env_mt = { __index = def_env }; - -prosody.console = { commands = commands, env = def_env }; +local core_post_stanza = prosody.core_post_stanza; local function redirect_output(_G, session) local env = setmetatable({ print = session.print }, { __index = function (t, k) return rawget(_G, k); end }); env.dofile = function(name) - local f, err = loadfile(name); + local f, err = envloadfile(name, env); if not f then return f, err; end - return setfenv(f, env)(); + return f(); end; return env; end @@ -53,17 +60,75 @@ function console:new_session(conn) disconnect = function () conn:close(); end; }; session.env = setmetatable({}, default_env_mt); - + -- Load up environment with helper objects for name, t in pairs(def_env) do if type(t) == "table" then session.env[name] = setmetatable({ session = session }, { __index = t }); end end - + return session; end +function console:process_line(session, line) + local useglobalenv; + + if line:match("^>") then + line = line:gsub("^>", ""); + useglobalenv = true; + elseif line == "\004" then + commands["bye"](session, line); + return; + else + local command = line:match("^%w+") or line:match("%p"); + if commands[command] then + commands[command](session, line); + return; + end + end + + session.env._ = line; + + local chunkname = "=console"; + local env = (useglobalenv and redirect_output(_G, session)) or session.env or nil + local chunk, err = envload("return "..line, chunkname, env); + if not chunk then + chunk, err = envload(line, chunkname, env); + if not chunk then + err = err:gsub("^%[string .-%]:%d+: ", ""); + err = err:gsub("^:%d+: ", ""); + err = err:gsub("'<eof>'", "the end of the line"); + session.print("Sorry, I couldn't understand that... "..err); + return; + end + end + + local ranok, taskok, message = pcall(chunk); + + if not (ranok or message or useglobalenv) and commands[line:lower()] then + commands[line:lower()](session, line); + return; + end + + if not ranok then + session.print("Fatal error while running command, it did not complete"); + session.print("Error: "..taskok); + return; + end + + if not message then + session.print("Result: "..tostring(taskok)); + return; + elseif (not taskok) and message then + session.print("Command completed with a problem"); + session.print("Message: "..tostring(message)); + return; + end + + session.print("OK: "..tostring(message)); +end + local sessions = {}; function console_listener.onconnect(conn) @@ -77,68 +142,17 @@ end function console_listener.onincoming(conn, data) local session = sessions[conn]; - -- Handle data - (function(session, data) - local useglobalenv; - - if data:match("^>") then - data = data:gsub("^>", ""); - useglobalenv = true; - elseif data == "\004" then - commands["bye"](session, data); - return; - else - local command = data:lower(); - command = data:match("^%w+") or data:match("%p"); - if commands[command] then - commands[command](session, data); - return; - end - end + local partial = session.partial_data; + if partial then + data = partial..data; + end - session.env._ = data; - - local chunkname = "=console"; - local chunk, err = loadstring("return "..data, chunkname); - if not chunk then - chunk, err = loadstring(data, chunkname); - if not chunk then - err = err:gsub("^%[string .-%]:%d+: ", ""); - err = err:gsub("^:%d+: ", ""); - err = err:gsub("'<eof>'", "the end of the line"); - session.print("Sorry, I couldn't understand that... "..err); - return; - end - end - - setfenv(chunk, (useglobalenv and redirect_output(_G, session)) or session.env or nil); - - local ranok, taskok, message = pcall(chunk); - - if not (ranok or message or useglobalenv) and commands[data:lower()] then - commands[data:lower()](session, data); - return; - end - - if not ranok then - session.print("Fatal error while running command, it did not complete"); - session.print("Error: "..taskok); - return; - end - - if not message then - session.print("Result: "..tostring(taskok)); - return; - elseif (not taskok) and message then - session.print("Command completed with a problem"); - session.print("Message: "..tostring(message)); - return; - end - - session.print("OK: "..tostring(message)); - end)(session, data); - - session.send(string.char(0)); + for line in data:gmatch("[^\n]*[\n\004]") do + if session.closed then return end + console:process_line(session, line); + session.send(string.char(0)); + end + session.partial_data = data:match("[^\n]+$"); end function console_listener.ondisconnect(conn, err) @@ -149,13 +163,12 @@ function console_listener.ondisconnect(conn, err) end end -connlisteners_register('console', console_listener); - -- Console commands -- -- These are simple commands, not valid standalone in Lua function commands.bye(session) session.print("See you! :)"); + session.closed = true; session.disconnect(); end commands.quit, commands.exit = commands.bye, commands.bye; @@ -190,7 +203,10 @@ function commands.help(session, data) print [[s2s - Commands to manage sessions between this server and others]] print [[module - Commands to load/reload/unload modules/plugins]] print [[host - Commands to activate, deactivate and list virtual hosts]] + print [[user - Commands to create and delete users, and change their passwords]] print [[server - Uptime, version, shutting down, etc.]] + print [[port - Commands to manage ports the server is listening on]] + print [[dns - Commands to manage and inspect the internal DNS resolver]] print [[config - Reloading the configuration, etc.]] print [[console - Help regarding the console itself]] elseif section == "c2s" then @@ -201,6 +217,7 @@ function commands.help(session, data) elseif section == "s2s" then print [[s2s:show(domain) - Show all s2s connections for the given domain (or all if no domain given)]] print [[s2s:close(from, to) - Close a connection from one domain to another]] + print [[s2s:closeall(host) - Close all the incoming/outgoing s2s sessions to specified host]] elseif section == "module" then print [[module:load(module, host) - Load the specified module on the specified host (or all hosts if none given)]] print [[module:reload(module, host) - The same, but unloads and loads the module (saving state if the module supports it)]] @@ -210,10 +227,25 @@ function commands.help(session, data) print [[host:activate(hostname) - Activates the specified host]] print [[host:deactivate(hostname) - Disconnects all clients on this host and deactivates]] print [[host:list() - List the currently-activated hosts]] + elseif section == "user" then + print [[user:create(jid, password) - Create the specified user account]] + print [[user:password(jid, password) - Set the password for the specified user account]] + print [[user:delete(jid) - Permanently remove the specified user account]] + print [[user:list(hostname, pattern) - List users on the specified host, optionally filtering with a pattern]] elseif section == "server" then print [[server:version() - Show the server's version number]] print [[server:uptime() - Show how long the server has been running]] + print [[server:memory() - Show details about the server's memory usage]] print [[server:shutdown(reason) - Shut down the server, with an optional reason to be broadcast to all connections]] + elseif section == "port" then + print [[port:list() - Lists all network ports prosody currently listens on]] + print [[port:close(port, interface) - Close a port]] + elseif section == "dns" then + print [[dns:lookup(name, type, class) - Do a DNS lookup]] + print [[dns:addnameserver(nameserver) - Add a nameserver to the list]] + print [[dns:setnameserver(nameserver) - Replace the list of name servers with the supplied one]] + print [[dns:purge() - Clear the DNS cache]] + print [[dns:cache() - Show cached records]] elseif section == "config" then print [[config:reload() - Reload the server configuration. Modules may need to be reloaded for changes to take effect.]] elseif section == "console" then @@ -268,6 +300,26 @@ function def_env.server:shutdown(reason) return true, "Shutdown initiated"; end +local function human(kb) + local unit = "K"; + if kb > 1024 then + kb, unit = kb/1024, "M"; + end + return ("%0.2f%sB"):format(kb, unit); +end + +function def_env.server:memory() + if not pposix.meminfo then + return true, "Lua is using "..collectgarbage("count"); + end + local mem, lua_mem = pposix.meminfo(), collectgarbage("count"); + local print = self.session.print; + print("Process: "..human((mem.allocated+mem.allocated_mmap)/1024)); + print(" Used: "..human(mem.used/1024).." ("..human(lua_mem).." by Lua)"); + print(" Free: "..human(mem.unused/1024).." ("..human(mem.returnable/1024).." returnable)"); + return true, "OK"; +end + def_env.module = {}; local function get_hosts_set(hosts, module) @@ -281,39 +333,49 @@ local function get_hosts_set(hosts, module) return set.new { hosts }; elseif hosts == nil then local mm = require "modulemanager"; - return set.new(array.collect(keys(prosody.hosts))) - / function (host) return prosody.hosts[host].type == "local" or module and mm.is_loaded(host, module); end; + local hosts_set = set.new(array.collect(keys(prosody.hosts))) + / function (host) return (prosody.hosts[host].type == "local" or module and mm.is_loaded(host, module)) and host or nil; end; + if module and mm.get_module("*", module) then + hosts_set:add("*"); + end + return hosts_set; end end function def_env.module:load(name, hosts, config) local mm = require "modulemanager"; - + hosts = get_hosts_set(hosts); - + -- Load the module for each host - local ok, err, count = true, nil, 0; + local ok, err, count, mod = true, nil, 0, nil; for host in hosts do if (not mm.is_loaded(host, name)) then - ok, err = mm.load(host, name, config); - if not ok then + mod, err = mm.load(host, name, config); + if not mod then ok = false; + if err == "global-module-already-loaded" then + if count > 0 then + ok, err, count = true, nil, 1; + end + break; + end self.session.print(err or "Unknown error loading module"); else count = count + 1; - self.session.print("Loaded for "..host); + self.session.print("Loaded for "..mod.module.host); end end end - - return ok, (ok and "Module loaded onto "..count.." host"..(count ~= 1 and "s" or "")) or ("Last error: "..tostring(err)); + + return ok, (ok and "Module loaded onto "..count.." host"..(count ~= 1 and "s" or "")) or ("Last error: "..tostring(err)); end function def_env.module:unload(name, hosts) local mm = require "modulemanager"; hosts = get_hosts_set(hosts, name); - + -- Unload the module for each host local ok, err, count = true, nil, 0; for host in hosts do @@ -334,11 +396,15 @@ end function def_env.module:reload(name, hosts) local mm = require "modulemanager"; - hosts = get_hosts_set(hosts, name); - + hosts = array.collect(get_hosts_set(hosts, name)):sort(function (a, b) + if a == "*" then return true + elseif b == "*" then return false + else return a < b; end + end); + -- Reload the module for each host local ok, err, count = true, nil, 0; - for host in hosts do + for _, host in ipairs(hosts) do if mm.is_loaded(host, name) then ok, err = mm.reload(host, name); if not ok then @@ -359,6 +425,7 @@ end function def_env.module:list(hosts) if hosts == nil then hosts = array.collect(keys(prosody.hosts)); + table.insert(hosts, 1, "*"); end if type(hosts) == "string" then hosts = { hosts }; @@ -366,11 +433,11 @@ function def_env.module:list(hosts) if type(hosts) ~= "table" then return false, "Please supply a host or a list of hosts you would like to see"; end - + local print = self.session.print; for _, host in ipairs(hosts) do - print(host..":"); - local modules = array.collect(keys(prosody.hosts[host] and prosody.hosts[host].modules or {})):sort(); + print((host == "*" and "Global" or host)..":"); + local modules = array.collect(keys(modulemanager.get_modules(host) or {})):sort(); if #modules == 0 then if prosody.hosts[host] then print(" No modules loaded"); @@ -416,6 +483,25 @@ end function def_env.hosts:add(name) end +local function session_flags(session, line) + line = line or {}; + if session.cert_identity_status == "valid" then + line[#line+1] = "(secure)"; + elseif session.secure then + line[#line+1] = "(encrypted)"; + end + if session.compressed then + line[#line+1] = "(compressed)"; + end + if session.smacks then + line[#line+1] = "(sm)"; + end + if session.ip and session.ip:match(":") then + line[#line+1] = "(IPv6)"; + end + return table.concat(line, " "); +end + def_env.c2s = {}; local function show_c2s(callback) @@ -429,6 +515,16 @@ local function show_c2s(callback) end end +function def_env.c2s:count(match_jid) + local count = 0; + show_c2s(function (jid, session) + if (not match_jid) or jid:match(match_jid) then + count = count + 1; + end + end); + return true, "Total: "..count.." clients"; +end + function def_env.c2s:show(match_jid) local print, count = self.session.print, 0; local curr_host; @@ -441,15 +537,10 @@ function def_env.c2s:show(match_jid) count = count + 1; local status, priority = "unavailable", tostring(session.priority or "-"); if session.presence then - status = session.presence:child_with_name("show"); - if status then - status = status:get_text() or "[invalid!]"; - else - status = "available"; - end + status = session.presence:get_child_text("show") or "available"; end - print(" "..jid.." - "..status.."("..priority..")"); - end + print(session_flags(session, { " "..jid.." - "..status.."("..priority..")" })); + end end); return true, "Total: "..count.." clients"; end @@ -460,7 +551,7 @@ function def_env.c2s:show_insecure(match_jid) if ((not match_jid) or jid:match(match_jid)) and not session.secure then count = count + 1; print(jid); - end + end end); return true, "Total: "..count.." insecure client connections"; end @@ -471,13 +562,13 @@ function def_env.c2s:show_secure(match_jid) if ((not match_jid) or jid:match(match_jid)) and session.secure then count = count + 1; print(jid); - end + end end); return true, "Total: "..count.." secure client connections"; end function def_env.c2s:close(match_jid) - local print, count = self.session.print, 0; + local count = 0; show_c2s(function (jid, session) if jid == match_jid or jid_bare(jid) == match_jid then count = count + 1; @@ -487,78 +578,80 @@ function def_env.c2s:close(match_jid) return true, "Total: "..count.." sessions closed"; end + def_env.s2s = {}; function def_env.s2s:show(match_jid) - local _print = self.session.print; local print = self.session.print; - + local count_in, count_out = 0,0; - - for host, host_session in pairs(hosts) do - print = function (...) _print(host); _print(...); print = _print; end - for remotehost, session in pairs(host_session.s2sout) do - if (not match_jid) or remotehost:match(match_jid) or host:match(match_jid) then - count_out = count_out + 1; - print(" "..host.." -> "..remotehost..(session.cert_identity_status == "valid" and " (secure)" or "")..(session.secure and " (encrypted)" or "")..(session.compressed and " (compressed)" or "")); - if session.sendq then - print(" There are "..#session.sendq.." queued outgoing stanzas for this connection"); - end - if session.type == "s2sout_unauthed" then - if session.connecting then - print(" Connection not yet established"); - if not session.srv_hosts then - if not session.conn then - print(" We do not yet have a DNS answer for this host's SRV records"); - else - print(" This host has no SRV records, using A record instead"); - end - elseif session.srv_choice then - print(" We are on SRV record "..session.srv_choice.." of "..#session.srv_hosts); - local srv_choice = session.srv_hosts[session.srv_choice]; - print(" Using "..(srv_choice.target or ".")..":"..(srv_choice.port or 5269)); + local s2s_list = { }; + + local s2s_sessions = module:shared"/*/s2s/sessions"; + for _, session in pairs(s2s_sessions) do + local remotehost, localhost, direction; + if session.direction == "outgoing" then + direction = "->"; + count_out = count_out + 1; + remotehost, localhost = session.to_host or "?", session.from_host or "?"; + else + direction = "<-"; + count_in = count_in + 1; + remotehost, localhost = session.from_host or "?", session.to_host or "?"; + end + local sess_lines = { l = localhost, r = remotehost, + session_flags(session, { "", direction, remotehost or "?", + "["..session.type..tostring(session):match("[a-f0-9]*$").."]" })}; + + if (not match_jid) or remotehost:match(match_jid) or localhost:match(match_jid) then + table.insert(s2s_list, sess_lines); + local print = function (s) table.insert(sess_lines, " "..s); end + if session.sendq then + print("There are "..#session.sendq.." queued outgoing stanzas for this connection"); + end + if session.type == "s2sout_unauthed" then + if session.connecting then + print("Connection not yet established"); + if not session.srv_hosts then + if not session.conn then + print("We do not yet have a DNS answer for this host's SRV records"); + else + print("This host has no SRV records, using A record instead"); end - elseif session.notopen then - print(" The <stream> has not yet been opened"); - elseif not session.dialback_key then - print(" Dialback has not been initiated yet"); - elseif session.dialback_key then - print(" Dialback has been requested, but no result received"); + elseif session.srv_choice then + print("We are on SRV record "..session.srv_choice.." of "..#session.srv_hosts); + local srv_choice = session.srv_hosts[session.srv_choice]; + print("Using "..(srv_choice.target or ".")..":"..(srv_choice.port or 5269)); end + elseif session.notopen then + print("The <stream> has not yet been opened"); + elseif not session.dialback_key then + print("Dialback has not been initiated yet"); + elseif session.dialback_key then + print("Dialback has been requested, but no result received"); end end - end - local subhost_filter = function (h) - return (match_jid and h:match(match_jid)); - end - for session in pairs(incoming_s2s) do - if session.to_host == host and ((not match_jid) or host:match(match_jid) - or (session.from_host and session.from_host:match(match_jid)) - -- Pft! is what I say to list comprehensions - or (session.hosts and #array.collect(keys(session.hosts)):filter(subhost_filter)>0)) then - count_in = count_in + 1; - print(" "..host.." <- "..(session.from_host or "(unknown)")..(session.cert_identity_status == "valid" and " (secure)" or "")..(session.secure and " (encrypted)" or "")..(session.compressed and " (compressed)" or "")); - if session.type == "s2sin_unauthed" then - print(" Connection not yet authenticated"); - end + if session.type == "s2sin_unauthed" then + print("Connection not yet authenticated"); + elseif session.type == "s2sin" then for name in pairs(session.hosts) do if name ~= session.from_host then - print(" also hosts "..tostring(name)); + print("also hosts "..tostring(name)); end end end end - - print = _print; end - - for session in pairs(incoming_s2s) do - if not session.to_host and ((not match_jid) or session.from_host and session.from_host:match(match_jid)) then - count_in = count_in + 1; - print("Other incoming s2s connections"); - print(" (unknown) <- "..(session.from_host or "(unknown)")); - end + + -- Sort by local host, then remote host + table.sort(s2s_list, function(a,b) + if a.l == b.l then return a.r < b.r; end + return a.l < b.l; + end); + local lasthost; + for _, sess_lines in ipairs(s2s_list) do + if sess_lines.l ~= lasthost then print(sess_lines.l); lasthost=sess_lines.l end + for _, line in ipairs(sess_lines) do print(line); end end - return true, "Total: "..count_out.." outgoing, "..count_in.." incoming connections"; end @@ -573,31 +666,41 @@ local function print_subject(print, subject) end end +-- As much as it pains me to use the 0-based depths that OpenSSL does, +-- I think there's going to be more confusion among operators if we +-- break from that. +local function print_errors(print, errors) + for depth, t in pairs(errors) do + print( + (" %d: %s"):format( + depth-1, + table.concat(t, "\n| ") + ) + ); + end +end + function def_env.s2s:showcert(domain) local ser = require "util.serialization".serialize; local print = self.session.print; - local domain_sessions = set.new(array.collect(keys(incoming_s2s))) - /function(session) return session.from_host == domain; end; - for local_host in values(prosody.hosts) do - local s2sout = local_host.s2sout; - if s2sout and s2sout[domain] then - domain_sessions:add(s2sout[domain]); - end - end + local s2s_sessions = module:shared"/*/s2s/sessions"; + local domain_sessions = set.new(array.collect(values(s2s_sessions))) + /function(session) return (session.to_host == domain or session.from_host == domain) and session or nil; end; local cert_set = {}; for session in domain_sessions do local conn = session.conn; conn = conn and conn:socket(); - if not conn.getpeercertificate then + if not conn.getpeerchain then if conn.dohandshake then error("This version of LuaSec does not support certificate viewing"); end else local cert = conn:getpeercertificate(); if cert then + local certs = conn:getpeerchain(); local digest = cert:digest("sha1"); if not cert_set[digest] then - local chain_valid, chain_err = conn:getpeerchainvalid(); + local chain_valid, chain_errors = conn:getpeerverification(); cert_set[digest] = { { from = session.from_host, @@ -605,8 +708,8 @@ function def_env.s2s:showcert(domain) direction = session.direction }; chain_valid = chain_valid; - chain_err = chain_err; - cert = cert; + chain_errors = chain_errors; + certs = certs; }; else table.insert(cert_set[digest], { @@ -620,22 +723,22 @@ function def_env.s2s:showcert(domain) end local domain_certs = array.collect(values(cert_set)); -- Phew. We now have a array of unique certificates presented by domain. - local print = self.session.print; local n_certs = #domain_certs; - + if n_certs == 0 then return "No certificates found for "..domain; end - + local function _capitalize_and_colon(byte) return string.upper(byte)..":"; end local function pretty_fingerprint(hash) return hash:gsub("..", _capitalize_and_colon):sub(1, -2); end - + for cert_info in values(domain_certs) do - local cert = cert_info.cert; + local certs = cert_info.certs; + local cert = certs[1]; print("---") print("Fingerprint (SHA1): "..pretty_fingerprint(cert:digest("sha1"))); print(""); @@ -649,9 +752,15 @@ function def_env.s2s:showcert(domain) end end print(""); - local chain_valid, err = cert_info.chain_valid, cert_info.chain_err; + local chain_valid, errors = cert_info.chain_valid, cert_info.chain_errors; local valid_identity = cert_verify_identity(domain, "xmpp-server", cert); - print("Trusted certificate: "..(chain_valid and "Yes" or ("No ("..err..")"))); + if chain_valid then + print("Trusted certificate: Yes"); + else + print("Trusted certificate: No"); + print_errors(print, errors); + end + print(""); print("Issuer: "); print_subject(print, cert:issuer()); print(""); @@ -667,46 +776,42 @@ end function def_env.s2s:close(from, to) local print, count = self.session.print, 0; - - if not (from and to) then + local s2s_sessions = module:shared"/*/s2s/sessions"; + + local match_id; + if from and not to then + match_id, from = from; + elseif not to then return false, "Syntax: s2s:close('from', 'to') - Closes all s2s sessions from 'from' to 'to'"; elseif from == to then return false, "Both from and to are the same... you can't do that :)"; end - - if hosts[from] and not hosts[to] then - -- Is an outgoing connection - local session = hosts[from].s2sout[to]; - if not session then - print("No outgoing connection from "..from.." to "..to) - else + + for _, session in pairs(s2s_sessions) do + local id = session.type..tostring(session):match("[a-f0-9]+$"); + if (match_id and match_id == id) + or (session.from_host == from and session.to_host == to) then + print(("Closing connection from %s to %s [%s]"):format(session.from_host, session.to_host, id)); (session.close or s2smanager.destroy_session)(session); - count = count + 1; - print("Closed outgoing session from "..from.." to "..to); + count = count + 1 ; end - elseif hosts[to] and not hosts[from] then - -- Is an incoming connection - for session in pairs(incoming_s2s) do - if session.to_host == to and session.from_host == from then - (session.close or s2smanager.destroy_session)(session); - count = count + 1; end - end - - if count == 0 then - print("No incoming connections from "..from.." to "..to); - else - print("Closed "..count.." incoming session"..((count == 1 and "") or "s").." from "..from.." to "..to); - end - elseif hosts[to] and hosts[from] then - return false, "Both of the hostnames you specified are local, there are no s2s sessions to close"; - else - return false, "Neither of the hostnames you specified are being used on this server"; - end - return true, "Closed "..count.." s2s session"..((count == 1 and "") or "s"); end +function def_env.s2s:closeall(host) + local count = 0; + local s2s_sessions = module:shared"/*/s2s/sessions"; + for _,session in pairs(s2s_sessions) do + if not host or session.from_host == host or session.to_host == host then + session:close(); + count = count + 1; + end + end + if count == 0 then return false, "No sessions to close."; + else return true, "Closed "..count.." s2s session"..((count == 1 and "") or "s"); end +end + def_env.host = {}; def_env.hosts = def_env.host; function def_env.host:activate(hostname, config) @@ -726,34 +831,232 @@ function def_env.host:list() return true, i.." hosts"; end +def_env.port = {}; + +function def_env.port:list() + local print = self.session.print; + local services = portmanager.get_active_services().data; + local ordered_services, n_ports = {}, 0; + for service, interfaces in pairs(services) do + table.insert(ordered_services, service); + end + table.sort(ordered_services); + for _, service in ipairs(ordered_services) do + local ports_list = {}; + for interface, ports in pairs(services[service]) do + for port in pairs(ports) do + table.insert(ports_list, "["..interface.."]:"..port); + end + end + n_ports = n_ports + #ports_list; + print(service..": "..table.concat(ports_list, ", ")); + end + return true, #ordered_services.." services listening on "..n_ports.." ports"; +end + +function def_env.port:close(close_port, close_interface) + close_port = assert(tonumber(close_port), "Invalid port number"); + local n_closed = 0; + local services = portmanager.get_active_services().data; + for service, interfaces in pairs(services) do + for interface, ports in pairs(interfaces) do + if not close_interface or close_interface == interface then + if ports[close_port] then + self.session.print("Closing ["..interface.."]:"..close_port.."..."); + local ok, err = portmanager.close(interface, close_port) + if not ok then + self.session.print("Failed to close "..interface.." "..close_port..": "..err); + else + n_closed = n_closed + 1; + end + end + end + end + end + return true, "Closed "..n_closed.." ports"; +end + +def_env.muc = {}; + +local console_room_mt = { + __index = function (self, k) return self.room[k]; end; + __tostring = function (self) + return "MUC room <"..self.room.jid..">"; + end; +}; + +local function check_muc(jid) + local room_name, host = jid_split(jid); + if not hosts[host] then + return nil, "No such host: "..host; + elseif not hosts[host].modules.muc then + return nil, "Host '"..host.."' is not a MUC service"; + end + return room_name, host; +end + +function def_env.muc:create(room_jid) + local room, host = check_muc(room_jid); + if not room then return nil, host end + if hosts[host].modules.muc.rooms[room_jid] then return nil, "Room exists already" end + return hosts[host].modules.muc.create_room(room_jid); +end + +function def_env.muc:room(room_jid) + local room_name, host = check_muc(room_jid); + local room_obj = hosts[host].modules.muc.rooms[room_jid]; + if not room_obj then + return nil, "No such room: "..room_jid; + end + return setmetatable({ room = room_obj }, console_room_mt); +end + +local um = require"core.usermanager"; + +def_env.user = {}; +function def_env.user:create(jid, password) + local username, host = jid_split(jid); + if not hosts[host] then + return nil, "No such host: "..host; + elseif um.user_exists(username, host) then + return nil, "User exists"; + end + local ok, err = um.create_user(username, password, host); + if ok then + return true, "User created"; + else + return nil, "Could not create user: "..err; + end +end + +function def_env.user:delete(jid) + local username, host = jid_split(jid); + if not hosts[host] then + return nil, "No such host: "..host; + elseif not um.user_exists(username, host) then + return nil, "No such user"; + end + local ok, err = um.delete_user(username, host); + if ok then + return true, "User deleted"; + else + return nil, "Could not delete user: "..err; + end +end + +function def_env.user:password(jid, password) + local username, host = jid_split(jid); + if not hosts[host] then + return nil, "No such host: "..host; + elseif not um.user_exists(username, host) then + return nil, "No such user"; + end + local ok, err = um.set_password(username, password, host); + if ok then + return true, "User password changed"; + else + return nil, "Could not change password for user: "..err; + end +end + +function def_env.user:list(host, pat) + if not host then + return nil, "No host given"; + elseif not hosts[host] then + return nil, "No such host"; + end + local print = self.session.print; + local total, matches = 0, 0; + for user in um.users(host) do + if not pat or user:match(pat) then + print(user.."@"..host); + matches = matches + 1; + end + total = total + 1; + end + return true, "Showing "..(pat and (matches.." of ") or "all " )..total.." users"; +end + +def_env.xmpp = {}; + +local st = require "util.stanza"; +function def_env.xmpp:ping(localhost, remotehost) + if hosts[localhost] then + core_post_stanza(hosts[localhost], + st.iq{ from=localhost, to=remotehost, type="get", id="ping" } + :tag("ping", {xmlns="urn:xmpp:ping"})); + return true, "Sent ping"; + else + return nil, "No such host"; + end +end + +def_env.dns = {}; +local adns = require"net.adns"; +local dns = require"net.dns"; + +function def_env.dns:lookup(name, typ, class) + local ret = "Query sent"; + local print = self.session.print; + local function handler(...) + ret = "Got response"; + print(...); + end + adns.lookup(handler, name, typ, class); + return true, ret; +end + +function def_env.dns:addnameserver(...) + dns.addnameserver(...) + return true +end + +function def_env.dns:setnameserver(...) + dns.setnameserver(...) + return true +end + +function def_env.dns:purge() + dns.purge() + return true +end + +function def_env.dns:cache() + return true, "Cache:\n"..tostring(dns.cache()) +end + ------------- function printbanner(session) - local option = config.get("*", "core", "console_banner"); -if option == nil or option == "full" or option == "graphic" then -session.print [[ - ____ \ / _ - | _ \ _ __ ___ ___ _-_ __| |_ _ + local option = module:get_option("console_banner"); + if option == nil or option == "full" or option == "graphic" then + session.print [[ + ____ \ / _ + | _ \ _ __ ___ ___ _-_ __| |_ _ | |_) | '__/ _ \/ __|/ _ \ / _` | | | | | __/| | | (_) \__ \ |_| | (_| | |_| | |_| |_| \___/|___/\___/ \__,_|\__, | - A study in simplicity |___/ + A study in simplicity |___/ ]] -end -if option == nil or option == "short" or option == "full" then -session.print("Welcome to the Prosody administration console. For a list of commands, type: help"); -session.print("You may find more help on using this console in our online documentation at "); -session.print("http://prosody.im/doc/console\n"); -end -if option and option ~= "short" and option ~= "full" and option ~= "graphic" then - if type(option) == "string" then - session.print(option) - elseif type(option) == "function" then - setfenv(option, redirect_output(_G, session)); - pcall(option, session); end -end + if option == nil or option == "short" or option == "full" then + session.print("Welcome to the Prosody administration console. For a list of commands, type: help"); + session.print("You may find more help on using this console in our online documentation at "); + session.print("http://prosody.im/doc/console\n"); + end + if option and option ~= "short" and option ~= "full" and option ~= "graphic" then + if type(option) == "string" then + session.print(option) + elseif type(option) == "function" then + module:log("warn", "Using functions as value for the console_banner option is no longer supported"); + end + end end -prosody.net_activate_ports("console", "console", {5582}, "tcp"); +module:provides("net", { + name = "console"; + listener = console_listener; + default_port = 5582; + private = true; +}); diff --git a/plugins/mod_announce.lua b/plugins/mod_announce.lua index 77555bec..9327556c 100644 --- a/plugins/mod_announce.lua +++ b/plugins/mod_announce.lua @@ -1,13 +1,14 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- local st, jid = require "util.stanza", require "util.jid"; +local hosts = prosody.hosts; local is_admin = require "core.usermanager".is_admin; function send_to_online(message, host) @@ -25,7 +26,7 @@ function send_to_online(message, host) for username in pairs(host_session.sessions) do c = c + 1; message.attr.to = username.."@"..hostname; - core_post_stanza(host_session, message); + module:send(message); end end end @@ -38,22 +39,22 @@ end function handle_announcement(event) local origin, stanza = event.origin, event.stanza; local node, host, resource = jid.split(stanza.attr.to); - + if resource ~= "announce/online" then return; -- Not an announcement end - + if not is_admin(stanza.attr.from) then -- Not an admin? Not allowed! module:log("warn", "Non-admin '%s' tried to send server announcement", stanza.attr.from); return; end - + module:log("info", "Sending server announcement to all online users"); local message = st.clone(stanza); message.attr.type = "headline"; message.attr.from = host; - + local c = send_to_online(message, host); module:log("info", "Announcement sent to %d online users", c); return true; @@ -82,13 +83,13 @@ function announce_handler(self, data, state) module:log("info", "Sending server announcement to all online users"); local message = st.message({type = "headline"}, fields.announcement):up() :tag("subject"):text(fields.subject or "Announcement"); - + local count = send_to_online(message, data.to); - + module:log("info", "Announcement sent to %d online users", count); return { status = "completed", info = ("Announcement sent to %d online users"):format(count) }; else - return { status = "executing", form = announce_layout }, "executing"; + return { status = "executing", actions = {"next", "complete", default = "complete"}, form = announce_layout }, "executing"; end return true; @@ -96,5 +97,5 @@ end local adhoc_new = module:require "adhoc".new; local announce_desc = adhoc_new("Send Announcement to Online Users", "http://jabber.org/protocol/admin#announce", announce_handler, "admin"); -module:add_item("adhoc", announce_desc); +module:provides("adhoc", announce_desc); diff --git a/plugins/mod_auth_anonymous.lua b/plugins/mod_auth_anonymous.lua index 9d0896e5..c877d532 100644 --- a/plugins/mod_auth_anonymous.lua +++ b/plugins/mod_auth_anonymous.lua @@ -6,63 +6,66 @@ -- COPYING file in the source package for more information. -- -local log = require "util.logger".init("auth_anonymous"); local new_sasl = require "util.sasl".new; local datamanager = require "util.datamanager"; +local hosts = prosody.hosts; -function new_default_provider(host) - local provider = { name = "anonymous" }; +-- define auth provider +local provider = {}; - function provider.test_password(username, password) - return nil, "Password based auth not supported."; - end +function provider.test_password(username, password) + return nil, "Password based auth not supported."; +end - function provider.get_password(username) - return nil, "Password not available."; - end +function provider.get_password(username) + return nil, "Password not available."; +end - function provider.set_password(username, password) - return nil, "Password based auth not supported."; - end +function provider.set_password(username, password) + return nil, "Password based auth not supported."; +end - function provider.user_exists(username) - return nil, "Only anonymous users are supported."; -- FIXME check if anonymous user is connected? - end +function provider.user_exists(username) + return nil, "Only anonymous users are supported."; -- FIXME check if anonymous user is connected? +end - function provider.create_user(username, password) - return nil, "Account creation/modification not supported."; - end +function provider.create_user(username, password) + return nil, "Account creation/modification not supported."; +end - function provider.get_sasl_handler() - local realm = module:get_option("sasl_realm") or module.host; - local anonymous_authentication_profile = { - anonymous = function(sasl, username, realm) - return true; -- for normal usage you should always return true here - end - }; - return new_sasl(realm, anonymous_authentication_profile); - end +function provider.get_sasl_handler() + local anonymous_authentication_profile = { + anonymous = function(sasl, username, realm) + return true; -- for normal usage you should always return true here + end + }; + return new_sasl(module.host, anonymous_authentication_profile); +end - return provider; +function provider.users() + return next, hosts[host].sessions, nil; end +-- datamanager callback to disable writes local function dm_callback(username, host, datastore, data) if host == module.host then return false; end return username, host, datastore, data; end -local host = hosts[module.host]; -local _saved_disallow_s2s = host.disallow_s2s; + +if not module:get_option_boolean("allow_anonymous_s2s", false) then + module:hook("route/remote", function (event) + return false; -- Block outgoing s2s from anonymous users + end, 300); +end + function module.load() - _saved_disallow_s2s = host.disallow_s2s; - host.disallow_s2s = module:get_option("disallow_s2s") ~= false; datamanager.add_callback(dm_callback); end function module.unload() - host.disallow_s2s = _saved_disallow_s2s; datamanager.remove_callback(dm_callback); end -module:add_item("auth-provider", new_default_provider(module.host)); +module:provides("auth", provider); diff --git a/plugins/mod_auth_cyrus.lua b/plugins/mod_auth_cyrus.lua index ed3d5408..7668f8c4 100644 --- a/plugins/mod_auth_cyrus.lua +++ b/plugins/mod_auth_cyrus.lua @@ -14,6 +14,7 @@ local cyrus_service_realm = module:get_option("cyrus_service_realm"); local cyrus_service_name = module:get_option("cyrus_service_name"); local cyrus_application_name = module:get_option("cyrus_application_name"); local require_provisioning = module:get_option("cyrus_require_provisioning") or false; +local host_fqdn = module:get_option("cyrus_server_fqdn"); prosody.unlock_globals(); --FIXME: Figure out why this is needed and -- why cyrussasl isn't caught by the sandbox @@ -23,50 +24,61 @@ local new_sasl = function(realm) return cyrus_new( cyrus_service_realm or realm, cyrus_service_name or "xmpp", - cyrus_application_name or "prosody" + cyrus_application_name or "prosody", + host_fqdn ); end -function new_default_provider(host) - local provider = { name = "cyrus" }; - log("debug", "initializing default authentication provider for host '%s'", host); - - function provider.test_password(username, password) - return nil, "Legacy auth not supported with Cyrus SASL."; - end - - function provider.get_password(username) - return nil, "Passwords unavailable for Cyrus SASL."; +do -- diagnostic + local list; + for mechanism in pairs(new_sasl(module.host):mechanisms()) do + list = (not(list) and mechanism) or (list..", "..mechanism); end - - function provider.set_password(username, password) - return nil, "Passwords unavailable for Cyrus SASL."; + if not list then + module:log("error", "No Cyrus SASL mechanisms available"); + else + module:log("debug", "Available Cyrus SASL mechanisms: %s", list); end +end - function provider.user_exists(username) - if require_provisioning then - return usermanager_user_exists(username, module.host); - end - return true; - end +local host = module.host; + +-- define auth provider +local provider = {}; +log("debug", "initializing default authentication provider for host '%s'", host); + +function provider.test_password(username, password) + return nil, "Legacy auth not supported with Cyrus SASL."; +end - function provider.create_user(username, password) - return nil, "Account creation/modification not available with Cyrus SASL."; +function provider.get_password(username) + return nil, "Passwords unavailable for Cyrus SASL."; +end + +function provider.set_password(username, password) + return nil, "Passwords unavailable for Cyrus SASL."; +end + +function provider.user_exists(username) + if require_provisioning then + return usermanager_user_exists(username, host); end + return true; +end + +function provider.create_user(username, password) + return nil, "Account creation/modification not available with Cyrus SASL."; +end - function provider.get_sasl_handler() - local realm = module:get_option("sasl_realm") or module.host; - local handler = new_sasl(realm); - if require_provisioning then - function handler.require_provisioning(username) - return usermanager_user_exists(username, module.host); - end +function provider.get_sasl_handler() + local handler = new_sasl(host); + if require_provisioning then + function handler.require_provisioning(username) + return usermanager_user_exists(username, host); end - return handler; end - - return provider; + return handler; end -module:add_item("auth-provider", new_default_provider(module.host)); +module:provides("auth", provider); diff --git a/plugins/mod_auth_internal_hashed.lua b/plugins/mod_auth_internal_hashed.lua index ec8da9ab..fb87bb9f 100644 --- a/plugins/mod_auth_internal_hashed.lua +++ b/plugins/mod_auth_internal_hashed.lua @@ -7,24 +7,15 @@ -- COPYING file in the source package for more information. -- -local datamanager = require "util.datamanager"; -local log = require "util.logger".init("auth_internal_hashed"); -local type = type; -local error = error; -local ipairs = ipairs; -local hashes = require "util.hashes"; -local jid_bare = require "util.jid".bare; local getAuthenticationDatabaseSHA1 = require "util.sasl.scram".getAuthenticationDatabaseSHA1; -local config = require "core.configmanager"; local usermanager = require "core.usermanager"; local generate_uuid = require "util.uuid".generate; local new_sasl = require "util.sasl".new; -local nodeprep = require "util.encodings".stringprep.nodeprep; -local hosts = hosts; --- COMPAT w/old trunk: remove these two lines before 0.8 release -local hmac_sha1 = require "util.hmac".sha1; -local sha1 = require "util.hashes".sha1; +local log = module._log; +local host = module.host; + +local accounts = module:open_store("accounts"); local to_hex; do @@ -47,135 +38,113 @@ do end -local prosody = _G.prosody; - -- Default; can be set per-user local iteration_count = 4096; -function new_hashpass_provider(host) - local provider = { name = "internal_hashed" }; - log("debug", "initializing hashpass authentication provider for host '%s'", host); +-- define auth provider +local provider = {}; - function provider.test_password(username, password) - local credentials = datamanager.load(username, host, "accounts") or {}; - - if credentials.password ~= nil and string.len(credentials.password) ~= 0 then - if credentials.password ~= password then - return nil, "Auth failed. Provided password is incorrect."; - end +function provider.test_password(username, password) + log("debug", "test password for user '%s'", username); + local credentials = accounts:get(username) or {}; - if provider.set_password(username, credentials.password) == nil then - return nil, "Auth failed. Could not set hashed password from plaintext."; - else - return true; - end + if credentials.password ~= nil and string.len(credentials.password) ~= 0 then + if credentials.password ~= password then + return nil, "Auth failed. Provided password is incorrect."; end - if credentials.iteration_count == nil or credentials.salt == nil or string.len(credentials.salt) == 0 then - return nil, "Auth failed. Stored salt and iteration count information is not complete."; - end - - -- convert hexpass to stored_key and server_key - -- COMPAT w/old trunk: remove before 0.8 release - if credentials.hashpass then - local salted_password = from_hex(credentials.hashpass); - credentials.stored_key = sha1(hmac_sha1(salted_password, "Client Key"), true); - credentials.server_key = to_hex(hmac_sha1(salted_password, "Server Key")); - credentials.hashpass = nil - datamanager.store(username, host, "accounts", credentials); - end - - local valid, stored_key, server_key = getAuthenticationDatabaseSHA1(password, credentials.salt, credentials.iteration_count); - - local stored_key_hex = to_hex(stored_key); - local server_key_hex = to_hex(server_key); - - if valid and stored_key_hex == credentials.stored_key and server_key_hex == credentials.server_key then - return true; + if provider.set_password(username, credentials.password) == nil then + return nil, "Auth failed. Could not set hashed password from plaintext."; else - return nil, "Auth failed. Invalid username, password, or password hash information."; + return true; end end - function provider.set_password(username, password) - local account = datamanager.load(username, host, "accounts"); - if account then - account.salt = account.salt or generate_uuid(); - account.iteration_count = account.iteration_count or iteration_count; - local valid, stored_key, server_key = getAuthenticationDatabaseSHA1(password, account.salt, account.iteration_count); - local stored_key_hex = to_hex(stored_key); - local server_key_hex = to_hex(server_key); - - account.stored_key = stored_key_hex - account.server_key = server_key_hex - - account.password = nil; - return datamanager.store(username, host, "accounts", account); - end - return nil, "Account not available."; + if credentials.iteration_count == nil or credentials.salt == nil or string.len(credentials.salt) == 0 then + return nil, "Auth failed. Stored salt and iteration count information is not complete."; end - function provider.user_exists(username) - local account = datamanager.load(username, host, "accounts"); - if not account then - log("debug", "account not found for username '%s' at host '%s'", username, module.host); - return nil, "Auth failed. Invalid username"; - end + local valid, stored_key, server_key = getAuthenticationDatabaseSHA1(password, credentials.salt, credentials.iteration_count); + + local stored_key_hex = to_hex(stored_key); + local server_key_hex = to_hex(server_key); + + if valid and stored_key_hex == credentials.stored_key and server_key_hex == credentials.server_key then return true; + else + return nil, "Auth failed. Invalid username, password, or password hash information."; end +end - function provider.create_user(username, password) - if password == nil then - return datamanager.store(username, host, "accounts", {}); - end - local salt = generate_uuid(); - local valid, stored_key, server_key = getAuthenticationDatabaseSHA1(password, salt, iteration_count); +function provider.set_password(username, password) + log("debug", "set_password for username '%s'", username); + local account = accounts:get(username); + if account then + account.salt = account.salt or generate_uuid(); + account.iteration_count = account.iteration_count or iteration_count; + local valid, stored_key, server_key = getAuthenticationDatabaseSHA1(password, account.salt, account.iteration_count); local stored_key_hex = to_hex(stored_key); local server_key_hex = to_hex(server_key); - return datamanager.store(username, host, "accounts", {stored_key = stored_key_hex, server_key = server_key_hex, salt = salt, iteration_count = iteration_count}); + + account.stored_key = stored_key_hex + account.server_key = server_key_hex + + account.password = nil; + return accounts:set(username, account); + end + return nil, "Account not available."; +end + +function provider.user_exists(username) + local account = accounts:get(username); + if not account then + log("debug", "account not found for username '%s'", username); + return nil, "Auth failed. Invalid username"; + end + return true; +end + +function provider.users() + return accounts:users(); +end + +function provider.create_user(username, password) + if password == nil then + return accounts:set(username, {}); end + local salt = generate_uuid(); + local valid, stored_key, server_key = getAuthenticationDatabaseSHA1(password, salt, iteration_count); + local stored_key_hex = to_hex(stored_key); + local server_key_hex = to_hex(server_key); + return accounts:set(username, {stored_key = stored_key_hex, server_key = server_key_hex, salt = salt, iteration_count = iteration_count}); +end + +function provider.delete_user(username) + return accounts:set(username, nil); +end - function provider.get_sasl_handler() - local realm = module:get_option("sasl_realm") or module.host; - local testpass_authentication_profile = { - plain_test = function(sasl, username, password, realm) - local prepped_username = nodeprep(username); - if not prepped_username then - log("debug", "NODEprep failed on username: %s", username); - return "", nil; - end - return usermanager.test_password(prepped_username, realm, password), true; - end, - scram_sha_1 = function(sasl, username, realm) - local credentials = datamanager.load(username, host, "accounts"); +function provider.get_sasl_handler() + local testpass_authentication_profile = { + plain_test = function(sasl, username, password, realm) + return usermanager.test_password(username, realm, password), true; + end, + scram_sha_1 = function(sasl, username, realm) + local credentials = accounts:get(username); + if not credentials then return; end + if credentials.password then + usermanager.set_password(username, credentials.password, host); + credentials = accounts:get(username); if not credentials then return; end - if credentials.password then - usermanager.set_password(username, credentials.password, host); - credentials = datamanager.load(username, host, "accounts"); - if not credentials then return; end - end - - -- convert hexpass to stored_key and server_key - -- COMPAT w/old trunk: remove before 0.8 release - if credentials.hashpass then - local salted_password = from_hex(credentials.hashpass); - credentials.stored_key = sha1(hmac_sha1(salted_password, "Client Key"), true); - credentials.server_key = to_hex(hmac_sha1(salted_password, "Server Key")); - credentials.hashpass = nil - datamanager.store(username, host, "accounts", credentials); - end - - local stored_key, server_key, iteration_count, salt = credentials.stored_key, credentials.server_key, credentials.iteration_count, credentials.salt; - stored_key = stored_key and from_hex(stored_key); - server_key = server_key and from_hex(server_key); - return stored_key, server_key, iteration_count, salt, true; end - }; - return new_sasl(realm, testpass_authentication_profile); - end - - return provider; + + local stored_key, server_key, iteration_count, salt = credentials.stored_key, credentials.server_key, credentials.iteration_count, credentials.salt; + stored_key = stored_key and from_hex(stored_key); + server_key = server_key and from_hex(server_key); + return stored_key, server_key, iteration_count, salt, true; + end + }; + return new_sasl(host, testpass_authentication_profile); end -module:add_item("auth-provider", new_hashpass_provider(module.host)); +module:provides("auth", provider); diff --git a/plugins/mod_auth_internal_plain.lua b/plugins/mod_auth_internal_plain.lua index 3721781b..db528432 100644 --- a/plugins/mod_auth_internal_plain.lua +++ b/plugins/mod_auth_internal_plain.lua @@ -6,84 +6,76 @@ -- COPYING file in the source package for more information. -- -local datamanager = require "util.datamanager"; -local log = require "util.logger".init("auth_internal_plain"); -local type = type; -local error = error; -local ipairs = ipairs; -local hashes = require "util.hashes"; -local jid_bare = require "util.jid".bare; -local config = require "core.configmanager"; local usermanager = require "core.usermanager"; local new_sasl = require "util.sasl".new; -local nodeprep = require "util.encodings".stringprep.nodeprep; -local hosts = hosts; -local prosody = _G.prosody; +local log = module._log; +local host = module.host; -function new_default_provider(host) - local provider = { name = "internal_plain" }; - log("debug", "initializing default authentication provider for host '%s'", host); +local accounts = module:open_store("accounts"); - function provider.test_password(username, password) - log("debug", "test password '%s' for user %s at host %s", password, username, module.host); - local credentials = datamanager.load(username, host, "accounts") or {}; - - if password == credentials.password then - return true; - else - return nil, "Auth failed. Invalid username or password."; - end - end +-- define auth provider +local provider = {}; - function provider.get_password(username) - log("debug", "get_password for username '%s' at host '%s'", username, module.host); - return (datamanager.load(username, host, "accounts") or {}).password; - end - - function provider.set_password(username, password) - local account = datamanager.load(username, host, "accounts"); - if account then - account.password = password; - return datamanager.store(username, host, "accounts", account); - end - return nil, "Account not available."; - end +function provider.test_password(username, password) + log("debug", "test password for user '%s'", username); + local credentials = accounts:get(username) or {}; - function provider.user_exists(username) - local account = datamanager.load(username, host, "accounts"); - if not account then - log("debug", "account not found for username '%s' at host '%s'", username, module.host); - return nil, "Auth failed. Invalid username"; - end + if password == credentials.password then return true; + else + return nil, "Auth failed. Invalid username or password."; end +end + +function provider.get_password(username) + log("debug", "get_password for username '%s'", username); + return (accounts:get(username) or {}).password; +end - function provider.create_user(username, password) - return datamanager.store(username, host, "accounts", {password = password}); +function provider.set_password(username, password) + log("debug", "set_password for username '%s'", username); + local account = accounts:get(username); + if account then + account.password = password; + return accounts:set(username, account); end + return nil, "Account not available."; +end - function provider.get_sasl_handler() - local realm = module:get_option("sasl_realm") or module.host; - local getpass_authentication_profile = { - plain = function(sasl, username, realm) - local prepped_username = nodeprep(username); - if not prepped_username then - log("debug", "NODEprep failed on username: %s", username); - return "", nil; - end - local password = usermanager.get_password(prepped_username, realm); - if not password then - return "", nil; - end - return password, true; - end - }; - return new_sasl(realm, getpass_authentication_profile); +function provider.user_exists(username) + local account = accounts:get(username); + if not account then + log("debug", "account not found for username '%s'", username); + return nil, "Auth failed. Invalid username"; end - - return provider; + return true; +end + +function provider.users() + return accounts:users(); +end + +function provider.create_user(username, password) + return accounts:set(username, {password = password}); +end + +function provider.delete_user(username) + return accounts:set(username, nil); +end + +function provider.get_sasl_handler() + local getpass_authentication_profile = { + plain = function(sasl, username, realm) + local password = usermanager.get_password(username, realm); + if not password then + return "", nil; + end + return password, true; + end + }; + return new_sasl(host, getpass_authentication_profile); end -module:add_item("auth-provider", new_default_provider(module.host)); +module:provides("auth", provider); diff --git a/plugins/mod_bosh.lua b/plugins/mod_bosh.lua index 58254169..ca67db73 100644 --- a/plugins/mod_bosh.lua +++ b/plugins/mod_bosh.lua @@ -1,26 +1,27 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- -module.host = "*" -- Global module +module:set_global(); -- Global module local hosts = _G.hosts; -local lxp = require "lxp"; local new_xmpp_stream = require "util.xmppstream".new; -local httpserver = require "net.httpserver"; local sm = require "core.sessionmanager"; local sm_destroy_session = sm.destroy_session; local new_uuid = require "util.uuid".generate; local fire_event = prosody.events.fire_event; -local core_process_stanza = core_process_stanza; +local core_process_stanza = prosody.core_process_stanza; local st = require "util.stanza"; local logger = require "util.logger"; local log = logger.init("mod_bosh"); -local timer = require "util.timer"; +local initialize_filters = require "util.filters".initialize; +local math_min = math.min; +local xpcall, tostring, type = xpcall, tostring, type; +local traceback = debug.traceback; local xmlns_streams = "http://etherx.jabber.org/streams"; local xmlns_xmpp_streams = "urn:ietf:params:xml:ns:xmpp-streams"; @@ -29,36 +30,23 @@ local xmlns_bosh = "http://jabber.org/protocol/httpbind"; -- (hard-coded into a local stream_callbacks = { stream_ns = xmlns_bosh, stream_tag = "body", default_ns = "jabber:client" }; -local BOSH_DEFAULT_HOLD = tonumber(module:get_option("bosh_default_hold")) or 1; -local BOSH_DEFAULT_INACTIVITY = tonumber(module:get_option("bosh_max_inactivity")) or 60; -local BOSH_DEFAULT_POLLING = tonumber(module:get_option("bosh_max_polling")) or 5; -local BOSH_DEFAULT_REQUESTS = tonumber(module:get_option("bosh_max_requests")) or 2; +local BOSH_DEFAULT_HOLD = module:get_option_number("bosh_default_hold", 1); +local BOSH_DEFAULT_INACTIVITY = module:get_option_number("bosh_max_inactivity", 60); +local BOSH_DEFAULT_POLLING = module:get_option_number("bosh_max_polling", 5); +local BOSH_DEFAULT_REQUESTS = module:get_option_number("bosh_max_requests", 2); +local bosh_max_wait = module:get_option_number("bosh_max_wait", 120); local consider_bosh_secure = module:get_option_boolean("consider_bosh_secure"); +local cross_domain = module:get_option("cross_domain_bosh", false); -local default_headers = { ["Content-Type"] = "text/xml; charset=utf-8" }; - -local cross_domain = module:get_option("cross_domain_bosh"); -if cross_domain then - default_headers["Access-Control-Allow-Methods"] = "GET, POST, OPTIONS"; - default_headers["Access-Control-Allow-Headers"] = "Content-Type"; - default_headers["Access-Control-Max-Age"] = "7200"; - - if cross_domain == true then - default_headers["Access-Control-Allow-Origin"] = "*"; - elseif type(cross_domain) == "table" then - cross_domain = table.concat(cross_domain, ", "); - end - if type(cross_domain) == "string" then - default_headers["Access-Control-Allow-Origin"] = cross_domain; - end -end +if cross_domain == true then cross_domain = "*"; end +if type(cross_domain) == "table" then cross_domain = table.concat(cross_domain, ", "); end local trusted_proxies = module:get_option_set("trusted_proxies", {"127.0.0.1"})._items; local function get_ip_from_request(request) - local ip = request.handler:ip(); - local forwarded_for = request.headers["x-forwarded-for"]; + local ip = request.conn:ip(); + local forwarded_for = request.headers.x_forwarded_for; if forwarded_for then forwarded_for = forwarded_for..", "..ip; for forwarded_ip in forwarded_for:gmatch("[^%s,]+") do @@ -73,61 +61,91 @@ end local t_insert, t_remove, t_concat = table.insert, table.remove, table.concat; local os_time = os.time; -local sessions = {}; -local inactive_sessions = {}; -- Sessions which have no open requests +-- All sessions, and sessions that have no requests open +local sessions, inactive_sessions = module:shared("sessions", "inactive_sessions"); -- Used to respond to idle sessions (those with waiting requests) -local waiting_requests = {}; +local waiting_requests = module:shared("waiting_requests"); function on_destroy_request(request) + log("debug", "Request destroyed: %s", tostring(request)); waiting_requests[request] = nil; - local session = sessions[request.sid]; + local session = sessions[request.context.sid]; if session then local requests = session.requests; - for i,r in ipairs(requests) do + for i, r in ipairs(requests) do if r == request then t_remove(requests, i); break; end end - + -- If this session now has no requests open, mark it as inactive - if #requests == 0 and session.bosh_max_inactive and not inactive_sessions[session] then - inactive_sessions[session] = os_time(); - (session.log or log)("debug", "BOSH session marked as inactive at %d", inactive_sessions[session]); + local max_inactive = session.bosh_max_inactive; + if max_inactive and #requests == 0 then + inactive_sessions[session] = os_time() + max_inactive; + (session.log or log)("debug", "BOSH session marked as inactive (for %ds)", max_inactive); end end end -function handle_request(method, body, request) - if (not body) or request.method ~= "POST" then - if request.method == "OPTIONS" then - local headers = {}; - for k,v in pairs(default_headers) do headers[k] = v; end - headers["Content-Type"] = nil; - return { headers = headers, body = "" }; - else - return "<html><body>You really don't look like a BOSH client to me... what do you want?</body></html>"; - end +local function set_cross_domain_headers(response) + local headers = response.headers; + headers.access_control_allow_methods = "GET, POST, OPTIONS"; + headers.access_control_allow_headers = "Content-Type"; + headers.access_control_max_age = "7200"; + headers.access_control_allow_origin = cross_domain; + return response; +end + +function handle_OPTIONS(event) + if cross_domain and event.request.headers.origin then + set_cross_domain_headers(event.response); end - if not method then - log("debug", "Request %s suffered error %s", tostring(request.id), body); - return; + return ""; +end + +function handle_POST(event) + log("debug", "Handling new request %s: %s\n----------", tostring(event.request), tostring(event.request.body)); + + local request, response = event.request, event.response; + response.on_destroy = on_destroy_request; + local body = request.body; + + local context = { request = request, response = response, notopen = true }; + local stream = new_xmpp_stream(context, stream_callbacks); + response.context = context; + + local headers = response.headers; + headers.content_type = "text/xml; charset=utf-8"; + + if cross_domain and event.request.headers.origin then + set_cross_domain_headers(response); end - --log("debug", "Handling new request %s: %s\n----------", request.id, tostring(body)); - request.notopen = true; - request.log = log; - request.on_destroy = on_destroy_request; - - local stream = new_xmpp_stream(request, stream_callbacks); + -- stream:feed() calls the stream_callbacks, so all stanzas in -- the body are processed in this next line before it returns. - stream:feed(body); - - local session = sessions[request.sid]; + -- In particular, the streamopened() stream callback is where + -- much of the session logic happens, because it's where we first + -- get to see the 'sid' of this request. + if not stream:feed(body) then + module:log("warn", "Error parsing BOSH payload") + return 400; + end + + -- Stanzas (if any) in the request have now been processed, and + -- we take care of the high-level BOSH logic here, including + -- giving a response or putting the request "on hold". + local session = sessions[context.sid]; if session then + -- Session was marked as inactive, since we have + -- a request open now, unmark it + if inactive_sessions[session] and #session.requests > 0 then + inactive_sessions[session] = nil; + end + local r = session.requests; - log("debug", "Session %s has %d out of %d requests open", request.sid, #r, session.bosh_hold); - log("debug", "and there are %d things in the send_buffer", #session.send_buffer); + log("debug", "Session %s has %d out of %d requests open", context.sid, #r, session.bosh_hold); + log("debug", "and there are %d things in the send_buffer:", #session.send_buffer); if #r > session.bosh_hold then -- We are holding too many requests, send what's in the buffer, log("debug", "We are holding too many requests, so..."); @@ -146,23 +164,25 @@ function handle_request(method, body, request) session.send_buffer = {}; session.send(resp); end - - if not request.destroyed then + + if not response.finished then -- We're keeping this request open, to respond later log("debug", "Have nothing to say, so leaving request unanswered for now"); if session.bosh_wait then - request.reply_before = os_time() + session.bosh_wait; - waiting_requests[request] = true; - end - if inactive_sessions[session] then - -- Session was marked as inactive, since we have - -- a request open now, unmark it - inactive_sessions[session] = nil; + waiting_requests[response] = os_time() + session.bosh_wait; end end - - return true; -- Inform httpserver we shall reply later + + if session.bosh_terminate then + session.log("debug", "Closing session with %d requests open", #session.requests); + session:close(); + return nil; + else + return true; -- Inform http server we shall reply later + end end + module:log("warn", "Unable to associate request with a session (incomplete request?)"); + return 400; end @@ -172,10 +192,10 @@ local stream_xmlns_attr = { xmlns = "urn:ietf:params:xml:ns:xmpp-streams" }; local function bosh_close_stream(session, reason) (session.log or log)("info", "BOSH client disconnected"); - + local close_reply = st.stanza("body", { xmlns = xmlns_bosh, type = "terminate", - ["xmlns:streams"] = xmlns_streams }); - + ["xmlns:stream"] = xmlns_streams }); + if reason then close_reply.attr.condition = "remote-stream-error"; @@ -199,115 +219,105 @@ local function bosh_close_stream(session, reason) log("info", "Disconnecting client, <stream:error> is: %s", tostring(close_reply)); end - local session_close_response = { headers = default_headers, body = tostring(close_reply) }; - - --FIXME: Quite sure we shouldn't reply to all requests with the error + local response_body = tostring(close_reply); for _, held_request in ipairs(session.requests) do - held_request:send(session_close_response); - held_request:destroy(); + held_request:send(response_body); end - sessions[session.sid] = nil; + sessions[session.sid] = nil; + inactive_sessions[session] = nil; sm_destroy_session(session); end -function stream_callbacks.streamopened(request, attr) - log("debug", "BOSH body open (sid: %s)", attr.sid); - local sid = attr.sid +-- Handle the <body> tag in the request payload. +function stream_callbacks.streamopened(context, attr) + local request, response = context.request, context.response; + local sid = attr.sid; + log("debug", "BOSH body open (sid: %s)", sid or "<none>"); if not sid then -- New session request - request.notopen = nil; -- Signals that we accept this opening tag - + context.notopen = nil; -- Signals that we accept this opening tag + -- TODO: Sanity checks here (rid, to, known host, etc.) if not hosts[attr.to] then -- Unknown host log("debug", "BOSH client tried to connect to unknown host: %s", tostring(attr.to)); local close_reply = st.stanza("body", { xmlns = xmlns_bosh, type = "terminate", - ["xmlns:streams"] = xmlns_streams, condition = "host-unknown" }); - request:send(tostring(close_reply)); + ["xmlns:stream"] = xmlns_streams, condition = "host-unknown" }); + response:send(tostring(close_reply)); return; end - + -- New session sid = new_uuid(); local session = { - type = "c2s_unauthed", conn = {}, sid = sid, rid = tonumber(attr.rid), host = attr.to, - bosh_version = attr.ver, bosh_wait = attr.wait, streamid = sid, + type = "c2s_unauthed", conn = {}, sid = sid, rid = tonumber(attr.rid)-1, host = attr.to, + bosh_version = attr.ver, bosh_wait = math_min(attr.wait, bosh_max_wait), streamid = sid, bosh_hold = BOSH_DEFAULT_HOLD, bosh_max_inactive = BOSH_DEFAULT_INACTIVITY, requests = { }, send_buffer = {}, reset_stream = bosh_reset_stream, - close = bosh_close_stream, dispatch_stanza = core_process_stanza, + close = bosh_close_stream, dispatch_stanza = core_process_stanza, notopen = true, log = logger.init("bosh"..sid), secure = consider_bosh_secure or request.secure, ip = get_ip_from_request(request); }; sessions[sid] = session; - + + local filter = initialize_filters(session); + session.log("debug", "BOSH session created for request from %s", session.ip); log("info", "New BOSH session, assigned it sid '%s'", sid); - local r, send_buffer = session.requests, session.send_buffer; - local response = { headers = default_headers } + + -- Send creation response + local creating_session = true; + + local r = session.requests; function session.send(s) -- We need to ensure that outgoing stanzas have the jabber:client xmlns if s.attr and not s.attr.xmlns then s = st.clone(s); s.attr.xmlns = "jabber:client"; end + s = filter("stanzas/out", s); --log("debug", "Sending BOSH data: %s", tostring(s)); + t_insert(session.send_buffer, tostring(s)); + local oldest_request = r[1]; - if oldest_request then + if oldest_request and not session.bosh_processing then log("debug", "We have an open request, so sending on that"); - response.body = t_concat{"<body xmlns='http://jabber.org/protocol/httpbind' sid='", sid, "' xmlns:stream = 'http://etherx.jabber.org/streams'>", tostring(s), "</body>" }; - oldest_request:send(response); - --log("debug", "Sent"); - if oldest_request.stayopen then - if #r>1 then - -- Move front request to back - t_insert(r, oldest_request); - t_remove(r, 1); - end - else - log("debug", "Destroying the request now..."); - oldest_request:destroy(); + local body_attr = { xmlns = "http://jabber.org/protocol/httpbind", + ["xmlns:stream"] = "http://etherx.jabber.org/streams"; + type = session.bosh_terminate and "terminate" or nil; + sid = sid; + }; + if creating_session then + creating_session = nil; + body_attr.inactivity = tostring(BOSH_DEFAULT_INACTIVITY); + body_attr.polling = tostring(BOSH_DEFAULT_POLLING); + body_attr.requests = tostring(BOSH_DEFAULT_REQUESTS); + body_attr.wait = tostring(session.bosh_wait); + body_attr.hold = tostring(session.bosh_hold); + body_attr.authid = sid; + body_attr.secure = "true"; + body_attr.ver = '1.6'; + body_attr.from = session.host; + body_attr["xmlns:xmpp"] = "urn:xmpp:xbosh"; + body_attr["xmpp:version"] = "1.0"; end - elseif s ~= "" then - log("debug", "Saved to send buffer because there are %d open requests", #r); - -- Hmm, no requests are open :( - t_insert(session.send_buffer, tostring(s)); - log("debug", "There are now %d things in the send_buffer", #session.send_buffer); + oldest_request:send(st.stanza("body", body_attr):top_tag()..t_concat(session.send_buffer).."</body>"); + session.send_buffer = {}; end + return true; end - - -- Send creation response - - local features = st.stanza("stream:features"); - hosts[session.host].events.fire_event("stream-features", { origin = session, features = features }); - fire_event("stream-features", session, features); - --xmpp:version='1.0' xmlns:xmpp='urn:xmpp:xbosh' - local response = st.stanza("body", { xmlns = xmlns_bosh, - wait = attr.wait, - inactivity = tostring(BOSH_DEFAULT_INACTIVITY), - polling = tostring(BOSH_DEFAULT_POLLING), - requests = tostring(BOSH_DEFAULT_REQUESTS), - hold = tostring(session.bosh_hold), - sid = sid, authid = sid, - ver = '1.6', from = session.host, - secure = 'true', ["xmpp:version"] = "1.0", - ["xmlns:xmpp"] = "urn:xmpp:xbosh", - ["xmlns:stream"] = "http://etherx.jabber.org/streams" - }):add_child(features); - request:send{ headers = default_headers, body = tostring(response) }; - request.sid = sid; - return; end - + local session = sessions[sid]; if not session then -- Unknown sid log("info", "Client tried to use sid '%s' which we don't know about", sid); - request:send{ headers = default_headers, body = tostring(st.stanza("body", { xmlns = xmlns_bosh, type = "terminate", condition = "item-not-found" })) }; - request.notopen = nil; + response:send(tostring(st.stanza("body", { xmlns = xmlns_bosh, type = "terminate", condition = "item-not-found" }))); + context.notopen = nil; return; end - + if session.rid then local rid = tonumber(attr.rid); local diff = rid - session.rid; @@ -315,56 +325,71 @@ function stream_callbacks.streamopened(request, attr) session.log("warn", "rid too large (means a request was lost). Last rid: %d New rid: %s", session.rid, attr.rid); elseif diff <= 0 then -- Repeated, ignore - session.log("debug", "rid repeated (on request %s), ignoring: %s (diff %d)", request.id, session.rid, diff); - request.notopen = nil; - request.ignore = true; - request.sid = sid; - t_insert(session.requests, request); + session.log("debug", "rid repeated, ignoring: %s (diff %d)", session.rid, diff); + context.notopen = nil; + context.ignore = true; + context.sid = sid; + t_insert(session.requests, response); return; end session.rid = rid; end - + if attr.type == "terminate" then - -- Client wants to end this session - session:close(); - request.notopen = nil; - return; + -- Client wants to end this session, which we'll do + -- after processing any stanzas in this request + session.bosh_terminate = true; end - + + context.notopen = nil; -- Signals that we accept this opening tag + t_insert(session.requests, response); + context.sid = sid; + session.bosh_processing = true; -- Used to suppress replies until processing of this request is done + if session.notopen then local features = st.stanza("stream:features"); hosts[session.host].events.fire_event("stream-features", { origin = session, features = features }); - fire_event("stream-features", session, features); session.send(features); session.notopen = nil; end - - request.notopen = nil; -- Signals that we accept this opening tag - t_insert(session.requests, request); - request.sid = sid; end -function stream_callbacks.handlestanza(request, stanza) - if request.ignore then return; end +local function handleerr(err) log("error", "Traceback[bosh]: %s", traceback(tostring(err), 2)); end +function stream_callbacks.handlestanza(context, stanza) + if context.ignore then return; end log("debug", "BOSH stanza received: %s\n", stanza:top_tag()); - local session = sessions[request.sid]; + local session = sessions[context.sid]; if session then if stanza.attr.xmlns == xmlns_bosh then stanza.attr.xmlns = nil; end - core_process_stanza(session, stanza); + stanza = session.filter("stanzas/in", stanza); + if stanza then + return xpcall(function () return core_process_stanza(session, stanza) end, handleerr); + end + end +end + +function stream_callbacks.streamclosed(context) + local session = sessions[context.sid]; + if session then + session.bosh_processing = false; + if #session.send_buffer > 0 then + session.send(""); + end end end -function stream_callbacks.error(request, error) +function stream_callbacks.error(context, error) log("debug", "Error parsing BOSH request payload; %s", error); - if not request.sid then - request:send({ headers = default_headers, status = "400 Bad Request" }); + if not context.sid then + local response = context.response; + response.status_code = 400; + response:send(); return; end - - local session = sessions[request.sid]; + + local session = sessions[context.sid]; if error == "stream-error" then -- Remote stream error, we close normally session:close(); else @@ -372,35 +397,31 @@ function stream_callbacks.error(request, error) end end -local dead_sessions = {}; +local dead_sessions = module:shared("dead_sessions"); function on_timer() -- log("debug", "Checking for requests soon to timeout..."); -- Identify requests timing out within the next few seconds local now = os_time() + 3; - for request in pairs(waiting_requests) do - if request.reply_before <= now then - log("debug", "%s was soon to timeout, sending empty response", request.id); + for request, reply_before in pairs(waiting_requests) do + if reply_before <= now then + log("debug", "%s was soon to timeout (at %d, now %d), sending empty response", tostring(request), reply_before, now); -- Send empty response to let the -- client know we're still here if request.conn then - sessions[request.sid].send(""); + sessions[request.context.sid].send(""); end end end - + now = now - 3; local n_dead_sessions = 0; - for session, inactive_since in pairs(inactive_sessions) do - if session.bosh_max_inactive then - if now - inactive_since > session.bosh_max_inactive then - (session.log or log)("debug", "BOSH client inactive too long, destroying session at %d", now); - sessions[session.sid] = nil; - inactive_sessions[session] = nil; - n_dead_sessions = n_dead_sessions + 1; - dead_sessions[n_dead_sessions] = session; - end - else + for session, close_after in pairs(inactive_sessions) do + if close_after < now then + (session.log or log)("debug", "BOSH client inactive too long, destroying session at %d", now); + sessions[session.sid] = nil; inactive_sessions[session] = nil; + n_dead_sessions = n_dead_sessions + 1; + dead_sessions[n_dead_sessions] = session; end end @@ -411,15 +432,30 @@ function on_timer() end return 1; end +module:add_timer(1, on_timer); -local function setup() - local ports = module:get_option("bosh_ports") or { 5280 }; - httpserver.new_from_config(ports, handle_request, { base = "http-bind" }); - timer.add_task(1, on_timer); -end -if prosody.start_time then -- already started - setup(); -else - prosody.events.add_handler("server-started", setup); +local GET_response = { + headers = { + content_type = "text/html"; + }; + body = [[<html><body> + <p>It works! Now point your BOSH client to this URL to connect to Prosody.</p> + <p>For more information see <a href="http://prosody.im/doc/setting_up_bosh">Prosody: Setting up BOSH</a>.</p> + </body></html>]]; +}; + +function module.add_host(module) + module:depends("http"); + module:provides("http", { + default_path = "/http-bind"; + route = { + ["GET"] = GET_response; + ["GET /"] = GET_response; + ["OPTIONS"] = handle_OPTIONS; + ["OPTIONS /"] = handle_OPTIONS; + ["POST"] = handle_POST; + ["POST /"] = handle_POST; + }; + }); end diff --git a/plugins/mod_c2s.lua b/plugins/mod_c2s.lua new file mode 100644 index 00000000..1fb8dcf5 --- /dev/null +++ b/plugins/mod_c2s.lua @@ -0,0 +1,332 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +module:set_global(); + +local add_task = require "util.timer".add_task; +local new_xmpp_stream = require "util.xmppstream".new; +local nameprep = require "util.encodings".stringprep.nameprep; +local sessionmanager = require "core.sessionmanager"; +local st = require "util.stanza"; +local sm_new_session, sm_destroy_session = sessionmanager.new_session, sessionmanager.destroy_session; +local uuid_generate = require "util.uuid".generate; +local runner = require "util.async".runner; + +local xpcall, tostring, type = xpcall, tostring, type; +local t_insert, t_remove = table.insert, table.remove; + +local xmlns_xmpp_streams = "urn:ietf:params:xml:ns:xmpp-streams"; + +local log = module._log; + +local c2s_timeout = module:get_option_number("c2s_timeout"); +local stream_close_timeout = module:get_option_number("c2s_close_timeout", 5); +local opt_keepalives = module:get_option_boolean("c2s_tcp_keepalives", module:get_option_boolean("tcp_keepalives", true)); + +local sessions = module:shared("sessions"); +local core_process_stanza = prosody.core_process_stanza; +local hosts = prosody.hosts; + +local stream_callbacks = { default_ns = "jabber:client" }; +local listener = {}; +local runner_callbacks = {}; + +--- Stream events handlers +local stream_xmlns_attr = {xmlns='urn:ietf:params:xml:ns:xmpp-streams'}; +local default_stream_attr = { ["xmlns:stream"] = "http://etherx.jabber.org/streams", xmlns = stream_callbacks.default_ns, version = "1.0", id = "" }; + +function stream_callbacks.streamopened(session, attr) + local send = session.send; + session.host = nameprep(attr.to); + if not session.host then + session:close{ condition = "improper-addressing", + text = "A valid 'to' attribute is required on stream headers" }; + return; + end + session.version = tonumber(attr.version) or 0; + session.streamid = uuid_generate(); + (session.log or session)("debug", "Client sent opening <stream:stream> to %s", session.host); + + if not hosts[session.host] or not hosts[session.host].modules.c2s then + -- We don't serve this host... + session:close{ condition = "host-unknown", text = "This server does not serve "..tostring(session.host)}; + return; + end + + send("<?xml version='1.0'?>"..st.stanza("stream:stream", { + xmlns = 'jabber:client', ["xmlns:stream"] = 'http://etherx.jabber.org/streams'; + id = session.streamid, from = session.host, version = '1.0', ["xml:lang"] = 'en' }):top_tag()); + + (session.log or log)("debug", "Sent reply <stream:stream> to client"); + session.notopen = nil; + + -- If session.secure is *false* (not nil) then it means we /were/ encrypting + -- since we now have a new stream header, session is secured + if session.secure == false then + session.secure = true; + + local sock = session.conn:socket(); + if sock.info then + local info = sock:info(); + (session.log or log)("info", "Stream encrypted (%s with %s)", info.protocol, info.cipher); + session.compressed = info.compression; + else + (session.log or log)("info", "Stream encrypted"); + session.compressed = sock.compression and sock:compression(); --COMPAT mw/luasec-hg + end + end + + local features = st.stanza("stream:features"); + hosts[session.host].events.fire_event("stream-features", { origin = session, features = features }); + send(features); +end + +function stream_callbacks.streamclosed(session) + session.log("debug", "Received </stream:stream>"); + session:close(false); +end + +function stream_callbacks.error(session, error, data) + if error == "no-stream" then + session.log("debug", "Invalid opening stream header"); + session:close("invalid-namespace"); + elseif error == "parse-error" then + (session.log or log)("debug", "Client XML parse error: %s", tostring(data)); + session:close("not-well-formed"); + elseif error == "stream-error" then + local condition, text = "undefined-condition"; + for child in data:children() do + if child.attr.xmlns == xmlns_xmpp_streams then + if child.name ~= "text" then + condition = child.name; + else + text = child:get_text(); + end + if condition ~= "undefined-condition" and text then + break; + end + end + end + text = condition .. (text and (" ("..text..")") or ""); + session.log("info", "Session closed by remote with error: %s", text); + session:close(nil, text); + end +end + +function stream_callbacks.handlestanza(session, stanza) + stanza = session.filter("stanzas/in", stanza); + session.thread:run(stanza); +end + +--- Session methods +local function session_close(session, reason) + local log = session.log or log; + if session.conn then + if session.notopen then + session.send("<?xml version='1.0'?>"); + session.send(st.stanza("stream:stream", default_stream_attr):top_tag()); + end + if reason then -- nil == no err, initiated by us, false == initiated by client + local stream_error = st.stanza("stream:error"); + if type(reason) == "string" then -- assume stream error + stream_error:tag(reason, {xmlns = 'urn:ietf:params:xml:ns:xmpp-streams' }); + elseif type(reason) == "table" then + if reason.condition then + stream_error:tag(reason.condition, stream_xmlns_attr):up(); + if reason.text then + stream_error:tag("text", stream_xmlns_attr):text(reason.text):up(); + end + if reason.extra then + stream_error:add_child(reason.extra); + end + elseif reason.name then -- a stanza + stream_error = reason; + end + end + stream_error = tostring(stream_error); + log("debug", "Disconnecting client, <stream:error> is: %s", stream_error); + session.send(stream_error); + end + + session.send("</stream:stream>"); + function session.send() return false; end + + local reason = (reason and (reason.name or reason.text or reason.condition)) or reason; + session.log("debug", "c2s stream for %s closed: %s", session.full_jid or ("<"..session.ip..">"), reason or "session closed"); + + -- Authenticated incoming stream may still be sending us stanzas, so wait for </stream:stream> from remote + local conn = session.conn; + if reason == nil and not session.notopen and session.type == "c2s" then + -- Grace time to process data from authenticated cleanly-closed stream + add_task(stream_close_timeout, function () + if not session.destroyed then + session.log("warn", "Failed to receive a stream close response, closing connection anyway..."); + sm_destroy_session(session, reason); + conn:close(); + end + end); + else + sm_destroy_session(session, reason); + conn:close(); + end + end +end + +module:hook_global("user-deleted", function(event) + local username, host = event.username, event.host; + local user = hosts[host].sessions[username]; + if user and user.sessions then + for jid, session in pairs(user.sessions) do + session:close{ condition = "not-authorized", text = "Account deleted" }; + end + end +end, 200); + +function runner_callbacks:ready() + self.data.conn:resume(); +end + +function runner_callbacks:waiting() + self.data.conn:pause(); +end + +function runner_callbacks:error(err) + (self.data.log or log)("error", "Traceback[c2s]: %s", err); +end + +--- Port listener +function listener.onconnect(conn) + local session = sm_new_session(conn); + sessions[conn] = session; + + session.log("info", "Client connected"); + + -- Client is using legacy SSL (otherwise mod_tls sets this flag) + if conn:ssl() then + session.secure = true; + + -- Check if TLS compression is used + local sock = conn:socket(); + if sock.info then + session.compressed = sock:info"compression"; + elseif sock.compression then + session.compressed = sock:compression(); --COMPAT mw/luasec-hg + end + end + + if opt_keepalives then + conn:setoption("keepalive", opt_keepalives); + end + + session.close = session_close; + + local stream = new_xmpp_stream(session, stream_callbacks); + session.stream = stream; + session.notopen = true; + + function session.reset_stream() + session.notopen = true; + session.stream:reset(); + end + + session.thread = runner(function (stanza) + core_process_stanza(session, stanza); + end, runner_callbacks, session); + + local filter = session.filter; + function session.data(data) + -- Parse the data, which will store stanzas in session.pending_stanzas + if data then + data = filter("bytes/in", data); + if data then + local ok, err = stream:feed(data); + if not ok then + log("debug", "Received invalid XML (%s) %d bytes: %s", tostring(err), #data, data:sub(1, 300):gsub("[\r\n]+", " "):gsub("[%z\1-\31]", "_")); + session:close("not-well-formed"); + end + end + end + end + + if c2s_timeout then + add_task(c2s_timeout, function () + if session.type == "c2s_unauthed" then + session:close("connection-timeout"); + end + end); + end + + session.dispatch_stanza = stream_callbacks.handlestanza; +end + +function listener.onincoming(conn, data) + local session = sessions[conn]; + if session then + session.data(data); + end +end + +function listener.ondisconnect(conn, err) + local session = sessions[conn]; + if session then + (session.log or log)("info", "Client disconnected: %s", err or "connection closed"); + sm_destroy_session(session, err); + sessions[conn] = nil; + end +end + +function listener.onreadtimeout(conn) + local session = sessions[conn]; + if session then + return (hosts[session.host] or prosody).events.fire_event("c2s-read-timeout", { session = session }); + end +end + +local function keepalive(event) + return event.session.send(' '); +end + +function listener.associate_session(conn, session) + sessions[conn] = session; +end + +function module.add_host(module) + module:hook("c2s-read-timeout", keepalive, -1); +end + +module:hook("c2s-read-timeout", keepalive, -1); + +module:hook("server-stopping", function(event) + local reason = event.reason; + for _, session in pairs(sessions) do + session:close{ condition = "system-shutdown", text = reason }; + end +end, 1000); + + + +module:provides("net", { + name = "c2s"; + listener = listener; + default_port = 5222; + encryption = "starttls"; + multiplex = { + pattern = "^<.*:stream.*%sxmlns%s*=%s*(['\"])jabber:client%1.*>"; + }; +}); + +module:provides("net", { + name = "legacy_ssl"; + listener = listener; + encryption = "ssl"; + multiplex = { + pattern = "^<.*:stream.*%sxmlns%s*=%s*(['\"])jabber:client%1.*>"; + }; +}); + + diff --git a/plugins/mod_component.lua b/plugins/mod_component.lua index fda271dd..3eaacb8e 100644 --- a/plugins/mod_component.lua +++ b/plugins/mod_component.lua @@ -1,100 +1,324 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- -if module:get_host_type() ~= "component" then - error("Don't load mod_component manually, it should be for a component, please see http://prosody.im/doc/components", 0); -end - -local hosts = _G.hosts; +module:set_global(); local t_concat = table.concat; +local xpcall, tostring, type = xpcall, tostring, type; +local traceback = debug.traceback; +local logger = require "util.logger"; local sha1 = require "util.hashes".sha1; local st = require "util.stanza"; +local jid_split = require "util.jid".split; +local new_xmpp_stream = require "util.xmppstream".new; +local uuid_gen = require "util.uuid".generate; + +local core_process_stanza = prosody.core_process_stanza; +local hosts = prosody.hosts; + local log = module._log; -local main_session, send; +local sessions = module:shared("sessions"); + +function module.add_host(module) + if module:get_host_type() ~= "component" then + error("Don't load mod_component manually, it should be for a component, please see http://prosody.im/doc/components", 0); + end + + local env = module.environment; + env.connected = false; + + local send; -local function on_destroy(session, err) - if main_session == session then - main_session = nil; + local function on_destroy(session, err) + env.connected = false; send = nil; session.on_destroy = nil; end + + -- Handle authentication attempts by component + local function handle_component_auth(event) + local session, stanza = event.origin, event.stanza; + + if session.type ~= "component_unauthed" then return; end + + if (not session.host) or #stanza.tags > 0 then + (session.log or log)("warn", "Invalid component handshake for host: %s", session.host); + session:close("not-authorized"); + return true; + end + + local secret = module:get_option("component_secret"); + if not secret then + (session.log or log)("warn", "Component attempted to identify as %s, but component_secret is not set", session.host); + session:close("not-authorized"); + return true; + end + + local supplied_token = t_concat(stanza); + local calculated_token = sha1(session.streamid..secret, true); + if supplied_token:lower() ~= calculated_token:lower() then + module:log("info", "Component authentication failed for %s", session.host); + session:close{ condition = "not-authorized", text = "Given token does not match calculated token" }; + return true; + end + + if env.connected then + module:log("error", "Second component attempted to connect, denying connection"); + session:close{ condition = "conflict", text = "Component already connected" }; + return true; + end + + env.connected = true; + send = session.send; + session.on_destroy = on_destroy; + session.component_validate_from = module:get_option_boolean("validate_from_addresses", true); + session.type = "component"; + module:log("info", "External component successfully authenticated"); + session.send(st.stanza("handshake")); + + return true; + end + module:hook("stanza/jabber:component:accept:handshake", handle_component_auth); + + -- Handle stanzas addressed to this component + local function handle_stanza(event) + local stanza = event.stanza; + if send then + stanza.attr.xmlns = nil; + send(stanza); + else + if stanza.name == "iq" and stanza.attr.type == "get" and stanza.attr.to == module.host then + local query = stanza.tags[1]; + local node = query.attr.node; + if query.name == "query" and query.attr.xmlns == "http://jabber.org/protocol/disco#info" and (not node or node == "") then + local name = module:get_option_string("name"); + if name then + event.origin.send(st.reply(stanza):tag("query", { xmlns = "http://jabber.org/protocol/disco#info" }) + :tag("identity", { category = "component", type = "generic", name = module:get_option_string("name", "Prosody") })) + return true; + end + end + end + module:log("warn", "Component not connected, bouncing error for: %s", stanza:top_tag()); + if stanza.attr.type ~= "error" and stanza.attr.type ~= "result" then + event.origin.send(st.error_reply(stanza, "wait", "service-unavailable", "Component unavailable")); + end + end + return true; + end + + module:hook("iq/bare", handle_stanza, -1); + module:hook("message/bare", handle_stanza, -1); + module:hook("presence/bare", handle_stanza, -1); + module:hook("iq/full", handle_stanza, -1); + module:hook("message/full", handle_stanza, -1); + module:hook("presence/full", handle_stanza, -1); + module:hook("iq/host", handle_stanza, -1); + module:hook("message/host", handle_stanza, -1); + module:hook("presence/host", handle_stanza, -1); +end + +--- Network and stream part --- + +local xmlns_component = 'jabber:component:accept'; + +local listener = {}; + +--- Callbacks/data for xmppstream to handle streams for us --- + +local stream_callbacks = { default_ns = xmlns_component }; + +local xmlns_xmpp_streams = "urn:ietf:params:xml:ns:xmpp-streams"; + +function stream_callbacks.error(session, error, data, data2) + if session.destroyed then return; end + module:log("warn", "Error processing component stream: %s", tostring(error)); + if error == "no-stream" then + session:close("invalid-namespace"); + elseif error == "parse-error" then + session.log("warn", "External component %s XML parse error: %s", tostring(session.host), tostring(data)); + session:close("not-well-formed"); + elseif error == "stream-error" then + local condition, text = "undefined-condition"; + for child in data:children() do + if child.attr.xmlns == xmlns_xmpp_streams then + if child.name ~= "text" then + condition = child.name; + else + text = child:get_text(); + end + if condition ~= "undefined-condition" and text then + break; + end + end + end + text = condition .. (text and (" ("..text..")") or ""); + session.log("info", "Session closed by remote with error: %s", text); + session:close(nil, text); + end +end + +function stream_callbacks.streamopened(session, attr) + if not hosts[attr.to] or not hosts[attr.to].modules.component then + session:close{ condition = "host-unknown", text = tostring(attr.to).." does not match any configured external components" }; + return; + end + session.host = attr.to; + session.streamid = uuid_gen(); + session.notopen = nil; + -- Return stream header + session.send("<?xml version='1.0'?>"); + session.send(st.stanza("stream:stream", { xmlns=xmlns_component, + ["xmlns:stream"]='http://etherx.jabber.org/streams', id=session.streamid, from=session.host }):top_tag()); +end + +function stream_callbacks.streamclosed(session) + session.log("debug", "Received </stream:stream>"); + session:close(); end -local function handle_stanza(event) - local stanza = event.stanza; - if send then - stanza.attr.xmlns = nil; - send(stanza); - else - log("warn", "Stanza being handled by default component; bouncing error for: %s", stanza:top_tag()); - if stanza.attr.type ~= "error" and stanza.attr.type ~= "result" then - event.origin.send(st.error_reply(stanza, "wait", "service-unavailable", "Component unavailable")); +local function handleerr(err) log("error", "Traceback[component]: %s", traceback(tostring(err), 2)); end +function stream_callbacks.handlestanza(session, stanza) + -- Namespaces are icky. + if not stanza.attr.xmlns and stanza.name == "handshake" then + stanza.attr.xmlns = xmlns_component; + end + if not stanza.attr.xmlns or stanza.attr.xmlns == "jabber:client" then + local from = stanza.attr.from; + if from then + if session.component_validate_from then + local _, domain = jid_split(stanza.attr.from); + if domain ~= session.host then + -- Return error + session.log("warn", "Component sent stanza with missing or invalid 'from' address"); + session:close{ + condition = "invalid-from"; + text = "Component tried to send from address <"..tostring(from) + .."> which is not in domain <"..tostring(session.host)..">"; + }; + return; + end + end + else + stanza.attr.from = session.host; -- COMPAT: Strictly we shouldn't allow this + end + if not stanza.attr.to then + session.log("warn", "Rejecting stanza with no 'to' address"); + session.send(st.error_reply(stanza, "modify", "bad-request", "Components MUST specify a 'to' address on stanzas")); + return; end end - return true; + + if stanza then + return xpcall(function () return core_process_stanza(session, stanza) end, handleerr); + end end -module:hook("iq/bare", handle_stanza, -1); -module:hook("message/bare", handle_stanza, -1); -module:hook("presence/bare", handle_stanza, -1); -module:hook("iq/full", handle_stanza, -1); -module:hook("message/full", handle_stanza, -1); -module:hook("presence/full", handle_stanza, -1); -module:hook("iq/host", handle_stanza, -1); -module:hook("message/host", handle_stanza, -1); -module:hook("presence/host", handle_stanza, -1); - ---- Handle authentication attempts by components -function handle_component_auth(event) - local session, stanza = event.origin, event.stanza; - - if session.type ~= "component" then return; end - if main_session == session then return; end - - if (not session.host) or #stanza.tags > 0 then - (session.log or log)("warn", "Invalid component handshake for host: %s", session.host); - session:close("not-authorized"); - return true; +--- Closing a component connection +local stream_xmlns_attr = {xmlns='urn:ietf:params:xml:ns:xmpp-streams'}; +local default_stream_attr = { ["xmlns:stream"] = "http://etherx.jabber.org/streams", xmlns = stream_callbacks.default_ns, version = "1.0", id = "" }; +local function session_close(session, reason) + if session.destroyed then return; end + if session.conn then + if session.notopen then + session.send("<?xml version='1.0'?>"); + session.send(st.stanza("stream:stream", default_stream_attr):top_tag()); + end + if reason then + if type(reason) == "string" then -- assume stream error + module:log("info", "Disconnecting component, <stream:error> is: %s", reason); + session.send(st.stanza("stream:error"):tag(reason, {xmlns = 'urn:ietf:params:xml:ns:xmpp-streams' })); + elseif type(reason) == "table" then + if reason.condition then + local stanza = st.stanza("stream:error"):tag(reason.condition, stream_xmlns_attr):up(); + if reason.text then + stanza:tag("text", stream_xmlns_attr):text(reason.text):up(); + end + if reason.extra then + stanza:add_child(reason.extra); + end + module:log("info", "Disconnecting component, <stream:error> is: %s", tostring(stanza)); + session.send(stanza); + elseif reason.name then -- a stanza + module:log("info", "Disconnecting component, <stream:error> is: %s", tostring(reason)); + session.send(reason); + end + end + end + session.send("</stream:stream>"); + session.conn:close(); + listener.ondisconnect(session.conn, "stream error"); end - - local secret = module:get_option("component_secret"); - if not secret then - (session.log or log)("warn", "Component attempted to identify as %s, but component_secret is not set", session.host); - session:close("not-authorized"); - return true; +end + +--- Component connlistener + +function listener.onconnect(conn) + local _send = conn.write; + local session = { type = "component_unauthed", conn = conn, send = function (data) return _send(conn, tostring(data)); end }; + + -- Logging functions -- + local conn_name = "jcp"..tostring(session):match("[a-f0-9]+$"); + session.log = logger.init(conn_name); + session.close = session_close; + + session.log("info", "Incoming Jabber component connection"); + + local stream = new_xmpp_stream(session, stream_callbacks); + session.stream = stream; + + session.notopen = true; + + function session.reset_stream() + session.notopen = true; + session.stream:reset(); end - - local supplied_token = t_concat(stanza); - local calculated_token = sha1(session.streamid..secret, true); - if supplied_token:lower() ~= calculated_token:lower() then - log("info", "Component authentication failed for %s", session.host); - session:close{ condition = "not-authorized", text = "Given token does not match calculated token" }; - return true; + + function session.data(conn, data) + local ok, err = stream:feed(data); + if ok then return; end + module:log("debug", "Received invalid XML (%s) %d bytes: %s", tostring(err), #data, data:sub(1, 300):gsub("[\r\n]+", " "):gsub("[%z\1-\31]", "_")); + session:close("not-well-formed"); end - - -- If component not already created for this host, create one now - if not main_session then - send = session.send; - main_session = session; - session.on_destroy = on_destroy; - session.component_validate_from = module:get_option_boolean("validate_from_addresses") ~= false; - log("info", "Component successfully authenticated: %s", session.host); - session.send(st.stanza("handshake")); - else -- TODO: Implement stanza distribution - log("error", "Multiple components bound to the same address, first one wins: %s", session.host); - session:close{ condition = "conflict", text = "Component already connected" }; + + session.dispatch_stanza = stream_callbacks.handlestanza; + + sessions[conn] = session; +end +function listener.onincoming(conn, data) + local session = sessions[conn]; + session.data(conn, data); +end +function listener.ondisconnect(conn, err) + local session = sessions[conn]; + if session then + (session.log or log)("info", "component disconnected: %s (%s)", tostring(session.host), tostring(err)); + if session.on_destroy then session:on_destroy(err); end + sessions[conn] = nil; + for k in pairs(session) do + if k ~= "log" and k ~= "close" then + session[k] = nil; + end + end + session.destroyed = true; + session = nil; end - - return true; end -module:hook("stanza/jabber:component:accept:handshake", handle_component_auth); +module:provides("net", { + name = "component"; + private = true; + listener = listener; + default_port = 5347; + multiplex = { + pattern = "^<.*:stream.*%sxmlns%s*=%s*(['\"])jabber:component:accept%1.*>"; + }; +}); diff --git a/plugins/mod_compression.lua b/plugins/mod_compression.lua index 82403016..f44e8a6d 100644 --- a/plugins/mod_compression.lua +++ b/plugins/mod_compression.lua @@ -1,6 +1,6 @@ -- Prosody IM --- Copyright (C) 2009 Tobias Markmann --- +-- Copyright (C) 2009-2012 Tobias Markmann +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -16,12 +16,8 @@ local xmlns_stream = "http://etherx.jabber.org/streams"; local compression_stream_feature = st.stanza("compression", {xmlns=xmlns_compression_feature}):tag("method"):text("zlib"):up(); local add_filter = require "util.filters".add_filter; -local compression_level = module:get_option("compression_level"); --- if not defined assume admin wants best compression -if compression_level == nil then compression_level = 9 end; - +local compression_level = module:get_option_number("compression_level", 7); -compression_level = tonumber(compression_level); if not compression_level or compression_level < 1 or compression_level > 9 then module:log("warn", "Invalid compression level in config: %s", tostring(compression_level)); module:log("warn", "Module loading aborted. Compression won't be available."); @@ -30,7 +26,7 @@ end module:hook("stream-features", function(event) local origin, features = event.origin, event.features; - if not origin.compressed then + if not origin.compressed and (origin.type == "c2s" or origin.type == "s2sin" or origin.type == "s2sout") then -- FIXME only advertise compression support when TLS layer has no compression enabled features:add_child(compression_stream_feature); end @@ -39,7 +35,7 @@ end); module:hook("s2s-stream-features", function(event) local origin, features = event.origin, event.features; -- FIXME only advertise compression support when TLS layer has no compression enabled - if not origin.compressed then + if not origin.compressed and (origin.type == "c2s" or origin.type == "s2sin" or origin.type == "s2sout") then features:add_child(compression_stream_feature); end end); @@ -47,7 +43,7 @@ end); -- Hook to activate compression if remote server supports it. module:hook_stanza(xmlns_stream, "features", function (session, stanza) - if not session.compressed then + if not session.compressed and (session.type == "c2s" or session.type == "s2sin" or session.type == "s2sout") then -- does remote server support compression? local comp_st = stanza:child_with_name("compression"); if comp_st then @@ -107,7 +103,7 @@ local function setup_compression(session, deflate_stream) return; end return compressed; - end); + end); end -- setup decompression for a stream @@ -129,26 +125,23 @@ end module:hook("stanza/http://jabber.org/protocol/compress:compressed", function(event) local session = event.origin; - + if session.type == "s2sout_unauthed" or session.type == "s2sout" then session.log("debug", "Activating compression...") -- create deflate and inflate streams local deflate_stream = get_deflate_stream(session); if not deflate_stream then return true; end - + local inflate_stream = get_inflate_stream(session); if not inflate_stream then return true; end - + -- setup compression for session.w setup_compression(session, deflate_stream); - + -- setup decompression for session.data setup_decompression(session, inflate_stream); session:reset_stream(); - local default_stream_attr = {xmlns = "jabber:server", ["xmlns:stream"] = "http://etherx.jabber.org/streams", - ["xmlns:db"] = 'jabber:server:dialback', version = "1.0", to = session.to_host, from = session.from_host}; - session.sends2s("<?xml version='1.0'?>"); - session.sends2s(st.stanza("stream:stream", default_stream_attr):top_tag()); + session:open_stream(session.from_host, session.to_host); session.compressed = true; return true; end @@ -165,29 +158,29 @@ module:hook("stanza/http://jabber.org/protocol/compress:compress", function(even session.log("debug", "Client tried to establish another compression layer."); return true; end - + -- checking if the compression method is supported local method = stanza:child_with_name("method"); method = method and (method[1] or ""); if method == "zlib" then session.log("debug", "zlib compression enabled."); - + -- create deflate and inflate streams local deflate_stream = get_deflate_stream(session); if not deflate_stream then return true; end - + local inflate_stream = get_inflate_stream(session); if not inflate_stream then return true; end - + (session.sends2s or session.send)(st.stanza("compressed", {xmlns=xmlns_compression_protocol})); session:reset_stream(); - + -- setup compression for session.w setup_compression(session, deflate_stream); - + -- setup decompression for session.data setup_decompression(session, inflate_stream); - + session.compressed = true; elseif method then session.log("debug", "%s compression selected, but we don't support it.", tostring(method)); diff --git a/plugins/mod_dialback.lua b/plugins/mod_dialback.lua index a8923e27..8d2bbd8f 100644 --- a/plugins/mod_dialback.lua +++ b/plugins/mod_dialback.lua @@ -1,39 +1,54 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- - local hosts = _G.hosts; -local send_s2s = require "core.s2smanager".send_to_host; -local s2s_make_authenticated = require "core.s2smanager".make_authenticated; -local s2s_initiate_dialback = require "core.s2smanager".initiate_dialback; -local s2s_verify_dialback = require "core.s2smanager".verify_dialback; -local s2s_destroy_session = require "core.s2smanager".destroy_session; local log = module._log; local st = require "util.stanza"; +local sha256_hash = require "util.hashes".sha256; +local nameprep = require "util.encodings".stringprep.nameprep; local xmlns_stream = "http://etherx.jabber.org/streams"; -local xmlns_dialback = "jabber:server:dialback"; local dialback_requests = setmetatable({}, { __mode = 'v' }); +function generate_dialback(id, to, from) + return sha256_hash(id..to..from..hosts[from].dialback_secret, true); +end + +function initiate_dialback(session) + -- generate dialback key + session.dialback_key = generate_dialback(session.streamid, session.to_host, session.from_host); + session.sends2s(st.stanza("db:result", { from = session.from_host, to = session.to_host }):text(session.dialback_key)); + session.log("debug", "sent dialback key on outgoing s2s stream"); +end + +function verify_dialback(id, to, from, key) + return key == generate_dialback(id, to, from); +end + module:hook("stanza/jabber:server:dialback:verify", function(event) local origin, stanza = event.origin, event.stanza; - + if origin.type == "s2sin_unauthed" or origin.type == "s2sin" then -- We are being asked to verify the key, to ensure it was generated by us origin.log("debug", "verifying that dialback key is ours..."); local attr = stanza.attr; + if attr.type then + module:log("warn", "Ignoring incoming session from %s claiming a dialback key for %s is %s", + origin.from_host or "(unknown)", attr.from or "(unknown)", attr.type); + return true; + end -- COMPAT: Grr, ejabberd breaks this one too?? it is black and white in XEP-220 example 34 --if attr.from ~= origin.to_host then error("invalid-from"); end local type; - if s2s_verify_dialback(attr.id, attr.from, attr.to, stanza[1]) then + if verify_dialback(attr.id, attr.from, attr.to, stanza[1]) then type = "valid" else type = "invalid" @@ -47,62 +62,68 @@ end); module:hook("stanza/jabber:server:dialback:result", function(event) local origin, stanza = event.origin, event.stanza; - + if origin.type == "s2sin_unauthed" or origin.type == "s2sin" then -- he wants to be identified through dialback -- We need to check the key with the Authoritative server local attr = stanza.attr; - origin.hosts[attr.from] = { dialback_key = stanza[1] }; - - if not hosts[attr.to] then + local to, from = nameprep(attr.to), nameprep(attr.from); + + if not hosts[to] then -- Not a host that we serve - origin.log("info", "%s tried to connect to %s, which we don't serve", attr.from, attr.to); + origin.log("warn", "%s tried to connect to %s, which we don't serve", from, to); origin:close("host-unknown"); return true; + elseif not from then + origin:close("improper-addressing"); end - - dialback_requests[attr.from] = origin; - + + origin.hosts[from] = { dialback_key = stanza[1] }; + + dialback_requests[from.."/"..origin.streamid] = origin; + + -- COMPAT: ejabberd, gmail and perhaps others do not always set 'to' and 'from' + -- on streams. We fill in the session's to/from here instead. if not origin.from_host then - -- Just used for friendlier logging - origin.from_host = attr.from; + origin.from_host = from; end if not origin.to_host then - -- Just used for friendlier logging - origin.to_host = attr.to; + origin.to_host = to; end - - origin.log("debug", "asking %s if key %s belongs to them", attr.from, stanza[1]); - send_s2s(attr.to, attr.from, - st.stanza("db:verify", { from = attr.to, to = attr.from, id = origin.streamid }):text(stanza[1])); + + origin.log("debug", "asking %s if key %s belongs to them", from, stanza[1]); + module:fire_event("route/remote", { + from_host = to, to_host = from; + stanza = st.stanza("db:verify", { from = to, to = from, id = origin.streamid }):text(stanza[1]); + }); return true; end end); module:hook("stanza/jabber:server:dialback:verify", function(event) local origin, stanza = event.origin, event.stanza; - + if origin.type == "s2sout_unauthed" or origin.type == "s2sout" then local attr = stanza.attr; - local dialback_verifying = dialback_requests[attr.from]; - if dialback_verifying then + local dialback_verifying = dialback_requests[attr.from.."/"..(attr.id or "")]; + if dialback_verifying and attr.from == origin.to_host then local valid; if attr.type == "valid" then - s2s_make_authenticated(dialback_verifying, attr.from); + module:fire_event("s2s-authenticated", { session = dialback_verifying, host = attr.from }); valid = "valid"; else -- Warn the original connection that is was not verified successfully - log("warn", "authoritative server for "..(attr.from or "(unknown)").." denied the key"); + log("warn", "authoritative server for %s denied the key", attr.from or "(unknown)"); valid = "invalid"; end - if not dialback_verifying.sends2s then + if dialback_verifying.destroyed then log("warn", "Incoming s2s session %s was closed in the meantime, so we can't notify it of the db result", tostring(dialback_verifying):match("%w+$")); else dialback_verifying.sends2s( st.stanza("db:result", { from = attr.to, to = attr.from, id = attr.id, type = valid }) :text(dialback_verifying.hosts[attr.from].dialback_key)); end - dialback_requests[attr.from] = nil; + dialback_requests[attr.from.."/"..(attr.id or "")] = nil; end return true; end @@ -110,10 +131,10 @@ end); module:hook("stanza/jabber:server:dialback:result", function(event) local origin, stanza = event.origin, event.stanza; - + if origin.type == "s2sout_unauthed" or origin.type == "s2sout" then -- Remote server is telling us whether we passed dialback - + local attr = stanza.attr; if not hosts[attr.to] then origin:close("host-unknown"); @@ -124,9 +145,9 @@ module:hook("stanza/jabber:server:dialback:result", function(event) return true; end if stanza.attr.type == "valid" then - s2s_make_authenticated(origin, attr.from); + module:fire_event("s2s-authenticated", { session = origin, host = attr.from }); else - s2s_destroy_session(origin) + origin:close("not-authorized", "dialback authentication failed"); end return true; end @@ -135,19 +156,26 @@ end); module:hook_stanza("urn:ietf:params:xml:ns:xmpp-sasl", "failure", function (origin, stanza) if origin.external_auth == "failed" then module:log("debug", "SASL EXTERNAL failed, falling back to dialback"); - s2s_initiate_dialback(origin); + initiate_dialback(origin); return true; end end, 100); module:hook_stanza(xmlns_stream, "features", function (origin, stanza) if not origin.external_auth or origin.external_auth == "failed" then - s2s_initiate_dialback(origin); + module:log("debug", "Initiating dialback..."); + initiate_dialback(origin); return true; end end, 100); +module:hook("s2sout-authenticate-legacy", function (event) + module:log("debug", "Initiating dialback..."); + initiate_dialback(event.origin); + return true; +end, 100); + -- Offer dialback to incoming hosts module:hook("s2s-stream-features", function (data) - data.features:tag("dialback", { xmlns='urn:xmpp:features:dialback' }):tag("optional"):up():up(); + data.features:tag("dialback", { xmlns='urn:xmpp:features:dialback' }):up(); end); diff --git a/plugins/mod_disco.lua b/plugins/mod_disco.lua index 907ca753..61749580 100644 --- a/plugins/mod_disco.lua +++ b/plugins/mod_disco.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -32,7 +32,9 @@ do -- validate disco_items end end -module:add_identity("server", "im", "Prosody"); -- FIXME should be in the non-existing mod_router +if module:get_host_type() == "local" then + module:add_identity("server", "im", module:get_option_string("name", "Prosody")); -- FIXME should be in the non-existing mod_router +end module:add_feature("http://jabber.org/protocol/disco#info"); module:add_feature("http://jabber.org/protocol/disco#items"); @@ -54,6 +56,12 @@ local function build_server_disco_info() done[feature] = true; end end + for _,extension in ipairs(module:get_host_items("extension")) do + if not done[extension] then + query:add_child(extension); + done[extension] = true; + end + end _cached_server_disco_info = query; _cached_server_caps_hash = calculate_hash(query); _cached_server_caps_feature = st.stanza("c", { @@ -81,15 +89,28 @@ end module:hook("item-added/identity", clear_disco_cache); module:hook("item-added/feature", clear_disco_cache); +module:hook("item-added/extension", clear_disco_cache); module:hook("item-removed/identity", clear_disco_cache); module:hook("item-removed/feature", clear_disco_cache); +module:hook("item-removed/extension", clear_disco_cache); -- Handle disco requests to the server module:hook("iq/host/http://jabber.org/protocol/disco#info:query", function(event) local origin, stanza = event.origin, event.stanza; if stanza.attr.type ~= "get" then return; end local node = stanza.tags[1].attr.node; - if node and node ~= "" and node ~= "http://prosody.im#"..get_server_caps_hash() then return; end -- TODO fire event? + if node and node ~= "" and node ~= "http://prosody.im#"..get_server_caps_hash() then + local reply = st.reply(stanza):tag('query', {xmlns='http://jabber.org/protocol/disco#info', node=node}); + local event = { origin = origin, stanza = stanza, reply = reply, node = node, exists = false}; + local ret = module:fire_event("host-disco-info-node", event); + if ret ~= nil then return ret; end + if event.exists then + origin.send(reply); + else + origin.send(st.error_reply(stanza, "cancel", "item-not-found", "Node does not exist")); + end + return true; + end local reply_query = get_server_disco_info(); reply_query.node = node; local reply = st.reply(stanza):add_child(reply_query); @@ -100,11 +121,23 @@ module:hook("iq/host/http://jabber.org/protocol/disco#items:query", function(eve local origin, stanza = event.origin, event.stanza; if stanza.attr.type ~= "get" then return; end local node = stanza.tags[1].attr.node; - if node and node ~= "" then return; end -- TODO fire event? - + if node and node ~= "" then + local reply = st.reply(stanza):tag('query', {xmlns='http://jabber.org/protocol/disco#items', node=node}); + local event = { origin = origin, stanza = stanza, reply = reply, node = node, exists = false}; + local ret = module:fire_event("host-disco-items-node", event); + if ret ~= nil then return ret; end + if event.exists then + origin.send(reply); + else + origin.send(st.error_reply(stanza, "cancel", "item-not-found", "Node does not exist")); + end + return true; + end local reply = st.reply(stanza):query("http://jabber.org/protocol/disco#items"); - for jid in pairs(get_children(module.host)) do - reply:tag("item", {jid = jid}):up(); + local ret = module:fire_event("host-disco-items", { origin = origin, stanza = stanza, reply = reply }); + if ret ~= nil then return ret; end + for jid, name in pairs(get_children(module.host)) do + reply:tag("item", {jid = jid, name = name~=true and name or nil}):up(); end for _, item in ipairs(disco_items) do reply:tag("item", {jid=item[1], name=item[2]}):up(); @@ -125,12 +158,24 @@ module:hook("iq/bare/http://jabber.org/protocol/disco#info:query", function(even local origin, stanza = event.origin, event.stanza; if stanza.attr.type ~= "get" then return; end local node = stanza.tags[1].attr.node; - if node and node ~= "" then return; end -- TODO fire event? local username = jid_split(stanza.attr.to) or origin.username; if not stanza.attr.to or is_contact_subscribed(username, module.host, jid_bare(stanza.attr.from)) then + if node and node ~= "" then + local reply = st.reply(stanza):tag('query', {xmlns='http://jabber.org/protocol/disco#info', node=node}); + if not reply.attr.from then reply.attr.from = origin.username.."@"..origin.host; end -- COMPAT To satisfy Psi when querying own account + local event = { origin = origin, stanza = stanza, reply = reply, node = node, exists = false}; + local ret = module:fire_event("account-disco-info-node", event); + if ret ~= nil then return ret; end + if event.exists then + origin.send(reply); + else + origin.send(st.error_reply(stanza, "cancel", "item-not-found", "Node does not exist")); + end + return true; + end local reply = st.reply(stanza):tag('query', {xmlns='http://jabber.org/protocol/disco#info'}); if not reply.attr.from then reply.attr.from = origin.username.."@"..origin.host; end -- COMPAT To satisfy Psi when querying own account - module:fire_event("account-disco-info", { origin = origin, stanza = reply }); + module:fire_event("account-disco-info", { origin = origin, reply = reply }); origin.send(reply); return true; end @@ -139,12 +184,24 @@ module:hook("iq/bare/http://jabber.org/protocol/disco#items:query", function(eve local origin, stanza = event.origin, event.stanza; if stanza.attr.type ~= "get" then return; end local node = stanza.tags[1].attr.node; - if node and node ~= "" then return; end -- TODO fire event? local username = jid_split(stanza.attr.to) or origin.username; if not stanza.attr.to or is_contact_subscribed(username, module.host, jid_bare(stanza.attr.from)) then + if node and node ~= "" then + local reply = st.reply(stanza):tag('query', {xmlns='http://jabber.org/protocol/disco#items', node=node}); + if not reply.attr.from then reply.attr.from = origin.username.."@"..origin.host; end -- COMPAT To satisfy Psi when querying own account + local event = { origin = origin, stanza = stanza, reply = reply, node = node, exists = false}; + local ret = module:fire_event("account-disco-items-node", event); + if ret ~= nil then return ret; end + if event.exists then + origin.send(reply); + else + origin.send(st.error_reply(stanza, "cancel", "item-not-found", "Node does not exist")); + end + return true; + end local reply = st.reply(stanza):tag('query', {xmlns='http://jabber.org/protocol/disco#items'}); if not reply.attr.from then reply.attr.from = origin.username.."@"..origin.host; end -- COMPAT To satisfy Psi when querying own account - module:fire_event("account-disco-items", { origin = origin, stanza = reply }); + module:fire_event("account-disco-items", { origin = origin, stanza = stanza, reply = reply }); origin.send(reply); return true; end diff --git a/plugins/mod_groups.lua b/plugins/mod_groups.lua index 7a876f1d..be1a5508 100644 --- a/plugins/mod_groups.lua +++ b/plugins/mod_groups.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -13,15 +13,17 @@ local members; local groups_file; local jid, datamanager = require "util.jid", require "util.datamanager"; -local jid_bare, jid_prep = jid.bare, jid.prep; +local jid_prep = jid.prep; local module_host = module:get_host(); -function inject_roster_contacts(username, host, roster) +function inject_roster_contacts(event) + local username, host= event.username, event.host; --module:log("debug", "Injecting group members to roster"); local bare_jid = username.."@"..host; if not members[bare_jid] and not members[false] then return; end -- Not a member of any groups - + + local roster = event.roster; local function import_jids_to_roster(group_name) for jid in pairs(groups[group_name]) do -- Add them to roster @@ -48,7 +50,7 @@ function inject_roster_contacts(username, host, roster) import_jids_to_roster(group_name); end end - + -- Import public groups if members[false] then for _, group_name in ipairs(members[false]) do @@ -56,7 +58,7 @@ function inject_roster_contacts(username, host, roster) import_jids_to_roster(group_name); end end - + if roster[false] then roster[false].version = true; end @@ -80,12 +82,12 @@ function remove_virtual_contacts(username, host, datastore, data) end function module.load() - groups_file = config.get(module:get_host(), "core", "groups_file"); + groups_file = module:get_option_string("groups_file"); if not groups_file then return; end - + module:hook("roster-load", inject_roster_contacts); datamanager.add_callback(remove_virtual_contacts); - + groups = { default = {} }; members = { }; local curr_group = "default"; @@ -121,3 +123,8 @@ end function module.unload() datamanager.remove_callback(remove_virtual_contacts); end + +-- Public for other modules to access +function group_contains(group_name, jid) + return groups[group_name][jid]; +end diff --git a/plugins/mod_http.lua b/plugins/mod_http.lua new file mode 100644 index 00000000..95933da5 --- /dev/null +++ b/plugins/mod_http.lua @@ -0,0 +1,146 @@ +-- Prosody IM +-- Copyright (C) 2008-2012 Matthew Wild +-- Copyright (C) 2008-2012 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +module:set_global(); +module:depends("http_errors"); + +local portmanager = require "core.portmanager"; +local moduleapi = require "core.moduleapi"; +local url_parse = require "socket.url".parse; +local url_build = require "socket.url".build; + +local server = require "net.http.server"; + +server.set_default_host(module:get_option_string("http_default_host")); + +local function normalize_path(path) + if path:sub(-1,-1) == "/" then path = path:sub(1, -2); end + if path:sub(1,1) ~= "/" then path = "/"..path; end + return path; +end + +local function get_http_event(host, app_path, key) + local method, path = key:match("^(%S+)%s+(.+)$"); + if not method then -- No path specified, default to "" (base path) + method, path = key, ""; + end + if method:sub(1,1) == "/" then + return nil; + end + if app_path == "/" and path:sub(1,1) == "/" then + app_path = ""; + end + return method:upper().." "..host..app_path..path; +end + +local function get_base_path(host_module, app_name, default_app_path) + return (normalize_path(host_module:get_option("http_paths", {})[app_name] -- Host + or module:get_option("http_paths", {})[app_name] -- Global + or default_app_path)) -- Default + :gsub("%$(%w+)", { host = module.host }); +end + +local ports_by_scheme = { http = 80, https = 443, }; + +-- Helper to deduce a module's external URL +function moduleapi.http_url(module, app_name, default_path) + app_name = app_name or (module.name:gsub("^http_", "")); + local external_url = url_parse(module:get_option_string("http_external_url")) or {}; + local services = portmanager.get_active_services(); + local http_services = services:get("https") or services:get("http") or {}; + for interface, ports in pairs(http_services) do + for port, services in pairs(ports) do + local url = { + scheme = (external_url.scheme or services[1].service.name); + host = (external_url.host or module:get_option_string("http_host", module.host)); + port = tonumber(external_url.port) or port or 80; + path = normalize_path(external_url.path or "/").. + (get_base_path(module, app_name, default_path or "/"..app_name):sub(2)); + } + if ports_by_scheme[url.scheme] == url.port then url.port = nil end + return url_build(url); + end + end +end + +function module.add_host(module) + local host = module:get_option_string("http_host", module.host); + local apps = {}; + module.environment.apps = apps; + local function http_app_added(event) + local app_name = event.item.name; + local default_app_path = event.item.default_path or "/"..app_name; + local app_path = get_base_path(module, app_name, default_app_path); + if not app_name then + -- TODO: Link to docs + module:log("error", "HTTP app has no 'name', add one or use module:provides('http', app)"); + return; + end + apps[app_name] = apps[app_name] or {}; + local app_handlers = apps[app_name]; + for key, handler in pairs(event.item.route or {}) do + local event_name = get_http_event(host, app_path, key); + if event_name then + if type(handler) ~= "function" then + local data = handler; + handler = function () return data; end + elseif event_name:sub(-2, -1) == "/*" then + local base_path_len = #event_name:match("/.+$"); + local _handler = handler; + handler = function (event) + local path = event.request.path:sub(base_path_len); + return _handler(event, path); + end; + end + if not app_handlers[event_name] then + app_handlers[event_name] = handler; + module:hook_object_event(server, event_name, handler); + else + module:log("warn", "App %s added handler twice for '%s', ignoring", app_name, event_name); + end + else + module:log("error", "Invalid route in %s, %q. See http://prosody.im/doc/developers/http#routes", app_name, key); + end + end + end + + local function http_app_removed(event) + local app_handlers = apps[event.item.name]; + apps[event.item.name] = nil; + for event, handler in pairs(app_handlers) do + module:unhook_object_event(server, event, handler); + end + end + + module:handle_items("http-provider", http_app_added, http_app_removed); + + server.add_host(host); + function module.unload() + server.remove_host(host); + end +end + +module:provides("net", { + name = "http"; + listener = server.listener; + default_port = 5280; + multiplex = { + pattern = "^[A-Z]"; + }; +}); + +module:provides("net", { + name = "https"; + listener = server.listener; + default_port = 5281; + encryption = "ssl"; + ssl_config = { verify = "none" }; + multiplex = { + pattern = "^[A-Z]"; + }; +}); diff --git a/plugins/mod_http_errors.lua b/plugins/mod_http_errors.lua new file mode 100644 index 00000000..0c37e104 --- /dev/null +++ b/plugins/mod_http_errors.lua @@ -0,0 +1,75 @@ +module:set_global(); + +local server = require "net.http.server"; +local codes = require "net.http.codes"; + +local show_private = module:get_option_boolean("http_errors_detailed", false); +local always_serve = module:get_option_boolean("http_errors_always_show", true); +local default_message = { module:get_option_string("http_errors_default_message", "That's all I know.") }; +local default_messages = { + [400] = { "What kind of request do you call that??" }; + [403] = { "You're not allowed to do that." }; + [404] = { "Whatever you were looking for is not here. %"; + "Where did you put it?", "It's behind you.", "Keep looking." }; + [500] = { "% Check your error log for more info."; + "Gremlins.", "It broke.", "Don't look at me." }; +}; + +local messages = setmetatable(module:get_option("http_errors_messages", {}), { __index = default_messages }); + +local html = [[ +<!DOCTYPE html> +<html> +<head> + <meta charset="utf-8"> + <style> + body{ + margin-top:14%; + text-align:center; + background-color:#F8F8F8; + font-family:sans-serif; + } + h1{ + font-size:xx-large; + } + p{ + font-size:x-large; + } + p+p { font-size: large; font-family: courier } + </style> +</head> +<body> + <h1>$title</h1> + <p>$message</p> + <p>$extra</p> +</body> +</html>]]; +html = html:gsub("%s%s+", ""); + +local entities = { + ["<"] = "<", [">"] = ">", ["&"] = "&", + ["'"] = "'", ["\""] = """, ["\n"] = "<br/>", +}; + +local function tohtml(plain) + return (plain:gsub("[<>&'\"\n]", entities)); + +end + +local function get_page(code, extra) + local message = messages[code]; + if always_serve or message then + message = message or default_message; + return (html:gsub("$(%a+)", { + title = rawget(codes, code) or ("Code "..tostring(code)); + message = message[1]:gsub("%%", function () + return message[math.random(2, math.max(#message,2))]; + end); + extra = tohtml(extra or ""); + })); + end +end + +module:hook_object_event(server, "http-error", function (event) + return get_page(event.code, (show_private and event.private_message) or event.message); +end); diff --git a/plugins/mod_http_files.lua b/plugins/mod_http_files.lua new file mode 100644 index 00000000..dd04853b --- /dev/null +++ b/plugins/mod_http_files.lua @@ -0,0 +1,153 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +module:depends("http"); +local server = require"net.http.server"; +local lfs = require "lfs"; + +local os_date = os.date; +local open = io.open; +local stat = lfs.attributes; +local build_path = require"socket.url".build_path; + +local base_path = module:get_option_string("http_files_dir", module:get_option_string("http_path")); +local dir_indices = module:get_option("http_index_files", { "index.html", "index.htm" }); +local directory_index = module:get_option_boolean("http_dir_listing"); + +local mime_map = module:shared("/*/http_files/mime").types; +if not mime_map then + mime_map = { + html = "text/html", htm = "text/html", + xml = "application/xml", + txt = "text/plain", + css = "text/css", + js = "application/javascript", + png = "image/png", + gif = "image/gif", + jpeg = "image/jpeg", jpg = "image/jpeg", + svg = "image/svg+xml", + }; + module:shared("/*/http_files/mime").types = mime_map; + + local mime_types, err = open(module:get_option_string("mime_types_file", "/etc/mime.types"),"r"); + if mime_types then + local mime_data = mime_types:read("*a"); + mime_types:close(); + setmetatable(mime_map, { + __index = function(t, ext) + local typ = mime_data:match("\n(%S+)[^\n]*%s"..(ext:lower()).."%s") or "application/octet-stream"; + t[ext] = typ; + return typ; + end + }); + end +end + +local cache = setmetatable({}, { __mode = "kv" }); -- Let the garbage collector have it if it wants to. + +function serve(opts) + if type(opts) ~= "table" then -- assume path string + opts = { path = opts }; + end + local base_path = opts.path; + local dir_indices = opts.index_files or dir_indices; + local directory_index = opts.directory_index; + local function serve_file(event, path) + local request, response = event.request, event.response; + local orig_path = request.path; + local full_path = base_path .. (path and "/"..path or ""); + local attr = stat(full_path); + if not attr then + return 404; + end + + local request_headers, response_headers = request.headers, response.headers; + + local last_modified = os_date('!%a, %d %b %Y %H:%M:%S GMT', attr.modification); + response_headers.last_modified = last_modified; + + local etag = ("%02x-%x-%x-%x"):format(attr.dev or 0, attr.ino or 0, attr.size or 0, attr.modification or 0); + response_headers.etag = etag; + + local if_none_match = request_headers.if_none_match + local if_modified_since = request_headers.if_modified_since; + if etag == if_none_match + or (not if_none_match and last_modified == if_modified_since) then + return 304; + end + + local data = cache[orig_path]; + if data and data.etag == etag then + response_headers.content_type = data.content_type; + data = data.data; + elseif attr.mode == "directory" and path then + if full_path:sub(-1) ~= "/" then + local path = { is_absolute = true, is_directory = true }; + for dir in orig_path:gmatch("[^/]+") do path[#path+1]=dir; end + response_headers.location = build_path(path); + return 301; + end + for i=1,#dir_indices do + if stat(full_path..dir_indices[i], "mode") == "file" then + return serve_file(event, path..dir_indices[i]); + end + end + + if directory_index then + data = server._events.fire_event("directory-index", { path = request.path, full_path = full_path }); + end + if not data then + return 403; + end + cache[orig_path] = { data = data, content_type = mime_map.html; etag = etag; }; + response_headers.content_type = mime_map.html; + + else + local f, err = open(full_path, "rb"); + if f then + data, err = f:read("*a"); + f:close(); + end + if not data then + module:log("debug", "Could not open or read %s. Error was %s", full_path, err); + return 403; + end + local ext = full_path:match("%.([^./]+)$"); + local content_type = ext and mime_map[ext]; + cache[orig_path] = { data = data; content_type = content_type; etag = etag }; + response_headers.content_type = content_type; + end + + return response:send(data); + end + + return serve_file; +end + +function wrap_route(routes) + for route,handler in pairs(routes) do + if type(handler) ~= "function" then + routes[route] = serve(handler); + end + end + return routes; +end + +if base_path then + module:provides("http", { + route = { + ["GET /*"] = serve { + path = base_path; + directory_index = directory_index; + } + }; + }); +else + module:log("debug", "http_files_dir not set, assuming use by some other module"); +end + diff --git a/plugins/mod_httpserver.lua b/plugins/mod_httpserver.lua deleted file mode 100644 index 654aff06..00000000 --- a/plugins/mod_httpserver.lua +++ /dev/null @@ -1,97 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - - -local httpserver = require "net.httpserver"; -local lfs = require "lfs"; - -local open = io.open; -local t_concat = table.concat; -local stat = lfs.attributes; - -local http_base = config.get("*", "core", "http_path") or "www_files"; - -local response_400 = { status = "400 Bad Request", body = "<h1>Bad Request</h1>Sorry, we didn't understand your request :(" }; -local response_403 = { status = "403 Forbidden", body = "<h1>Forbidden</h1>You don't have permission to view the contents of this directory :(" }; -local response_404 = { status = "404 Not Found", body = "<h1>Page Not Found</h1>Sorry, we couldn't find what you were looking for :(" }; - --- TODO: Should we read this from /etc/mime.types if it exists? (startup time...?) -local mime_map = { - html = "text/html"; - htm = "text/html"; - xml = "text/xml"; - xsl = "text/xml"; - txt = "text/plain; charset=utf-8"; - js = "text/javascript"; - css = "text/css"; -}; - -local function preprocess_path(path) - if path:sub(1,1) ~= "/" then - path = "/"..path; - end - local level = 0; - for component in path:gmatch("([^/]+)/") do - if component == ".." then - level = level - 1; - elseif component ~= "." then - level = level + 1; - end - if level < 0 then - return nil; - end - end - return path; -end - -function serve_file(path) - local full_path = http_base..path; - if stat(full_path, "mode") == "directory" then - if stat(full_path.."/index.html", "mode") == "file" then - return serve_file(path.."/index.html"); - end - return response_403; - end - local f, err = open(full_path, "rb"); - if not f then return response_404; end - local data = f:read("*a"); - f:close(); - if not data then - return response_403; - end - local ext = path:match("%.([^.]*)$"); - local mime = mime_map[ext]; -- Content-Type should be nil when not known - return { - headers = { ["Content-Type"] = mime; }; - body = data; - }; -end - -local function handle_file_request(method, body, request) - local path = preprocess_path(request.url.path); - if not path then return response_400; end - path = path:gsub("^/[^/]+", ""); -- Strip /files/ - return serve_file(path); -end - -local function handle_default_request(method, body, request) - local path = preprocess_path(request.url.path); - if not path then return response_400; end - return serve_file(path); -end - -local function setup() - local ports = config.get(module.host, "core", "http_ports") or { 5280 }; - httpserver.set_default_handler(handle_default_request); - httpserver.new_from_config(ports, handle_file_request, { base = "files" }); -end -if prosody.start_time then -- already started - setup(); -else - prosody.events.add_handler("server-started", setup); -end diff --git a/plugins/mod_iq.lua b/plugins/mod_iq.lua index 484a1f8f..c6d62e85 100644 --- a/plugins/mod_iq.lua +++ b/plugins/mod_iq.lua @@ -1,17 +1,15 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- local st = require "util.stanza"; -local jid_split = require "util.jid".split; -local full_sessions = full_sessions; -local bare_sessions = bare_sessions; +local full_sessions = prosody.full_sessions; if module:get_host_type() == "local" then module:hook("iq/full", function(data) @@ -19,10 +17,7 @@ if module:get_host_type() == "local" then local origin, stanza = data.origin, data.stanza; local session = full_sessions[stanza.attr.to]; - if session then - -- TODO fire post processing event - session.send(stanza); - else -- resource not online + if not (session and session.send(stanza)) then if stanza.attr.type == "get" or stanza.attr.type == "set" then origin.send(st.error_reply(stanza, "cancel", "service-unavailable")); end @@ -33,15 +28,16 @@ end module:hook("iq/bare", function(data) -- IQ to bare JID recieved - local origin, stanza = data.origin, data.stanza; + local stanza = data.stanza; local type = stanza.attr.type; -- TODO fire post processing events if type == "get" or type == "set" then local child = stanza.tags[1]; - local ret = module:fire_event("iq/bare/"..child.attr.xmlns..":"..child.name, data); + local xmlns = child.attr.xmlns or "jabber:client"; + local ret = module:fire_event("iq/bare/"..xmlns..":"..child.name, data); if ret ~= nil then return ret; end - return module:fire_event("iq-"..type.."/bare/"..child.attr.xmlns..":"..child.name, data); + return module:fire_event("iq-"..type.."/bare/"..xmlns..":"..child.name, data); else return module:fire_event("iq-"..type.."/bare/"..stanza.attr.id, data); end @@ -49,14 +45,15 @@ end); module:hook("iq/self", function(data) -- IQ to self JID recieved - local origin, stanza = data.origin, data.stanza; + local stanza = data.stanza; local type = stanza.attr.type; if type == "get" or type == "set" then local child = stanza.tags[1]; - local ret = module:fire_event("iq/self/"..child.attr.xmlns..":"..child.name, data); + local xmlns = child.attr.xmlns or "jabber:client"; + local ret = module:fire_event("iq/self/"..xmlns..":"..child.name, data); if ret ~= nil then return ret; end - return module:fire_event("iq-"..type.."/self/"..child.attr.xmlns..":"..child.name, data); + return module:fire_event("iq-"..type.."/self/"..xmlns..":"..child.name, data); else return module:fire_event("iq-"..type.."/self/"..stanza.attr.id, data); end @@ -64,14 +61,15 @@ end); module:hook("iq/host", function(data) -- IQ to a local host recieved - local origin, stanza = data.origin, data.stanza; + local stanza = data.stanza; local type = stanza.attr.type; if type == "get" or type == "set" then local child = stanza.tags[1]; - local ret = module:fire_event("iq/host/"..child.attr.xmlns..":"..child.name, data); + local xmlns = child.attr.xmlns or "jabber:client"; + local ret = module:fire_event("iq/host/"..xmlns..":"..child.name, data); if ret ~= nil then return ret; end - return module:fire_event("iq-"..type.."/host/"..child.attr.xmlns..":"..child.name, data); + return module:fire_event("iq-"..type.."/host/"..xmlns..":"..child.name, data); else return module:fire_event("iq-"..type.."/host/"..stanza.attr.id, data); end diff --git a/plugins/mod_lastactivity.lua b/plugins/mod_lastactivity.lua index 11053709..fabf07b4 100644 --- a/plugins/mod_lastactivity.lua +++ b/plugins/mod_lastactivity.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- diff --git a/plugins/mod_legacyauth.lua b/plugins/mod_legacyauth.lua index 47a8c0ab..cb5ce0d3 100644 --- a/plugins/mod_legacyauth.lua +++ b/plugins/mod_legacyauth.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -11,7 +11,9 @@ local st = require "util.stanza"; local t_concat = table.concat; -local secure_auth_only = module:get_option("c2s_require_encryption") or module:get_option("require_encryption"); +local secure_auth_only = module:get_option("c2s_require_encryption") + or module:get_option("require_encryption") + or not(module:get_option("allow_unencrypted_plain_auth")); local sessionmanager = require "core.sessionmanager"; local usermanager = require "core.usermanager"; @@ -33,7 +35,7 @@ module:hook("stanza/iq/jabber:iq:auth:query", function(event) local session, stanza = event.origin, event.stanza; if session.type ~= "c2s_unauthed" then - session.send(st.error_reply(stanza, "cancel", "service-unavailable", "Legacy authentication is only allowed for unauthenticated client connections.")); + (session.sends2s or session.send)(st.error_reply(stanza, "cancel", "service-unavailable", "Legacy authentication is only allowed for unauthenticated client connections.")); return true; end @@ -41,7 +43,7 @@ module:hook("stanza/iq/jabber:iq:auth:query", function(event) session.send(st.error_reply(stanza, "modify", "not-acceptable", "Encryption (SSL or TLS) is required to connect to this server")); return true; end - + local username = stanza.tags[1]:child_with_name("username"); local password = stanza.tags[1]:child_with_name("password"); local resource = stanza.tags[1]:child_with_name("resource"); @@ -55,7 +57,10 @@ module:hook("stanza/iq/jabber:iq:auth:query", function(event) username, password, resource = t_concat(username), t_concat(password), t_concat(resource); username = nodeprep(username); resource = resourceprep(resource) - local reply = st.reply(stanza); + if not (username and resource) then + session.send(st.error_reply(stanza, "modify", "bad-request")); + return true; + end if usermanager.test_password(username, session.host, password) then -- Authentication successful! local success, err = sessionmanager.make_authenticated(session, username); diff --git a/plugins/mod_message.lua b/plugins/mod_message.lua index df317532..fc337db0 100644 --- a/plugins/mod_message.lua +++ b/plugins/mod_message.lua @@ -1,24 +1,23 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- -local full_sessions = full_sessions; -local bare_sessions = bare_sessions; +local full_sessions = prosody.full_sessions; +local bare_sessions = prosody.bare_sessions; local st = require "util.stanza"; local jid_bare = require "util.jid".bare; local jid_split = require "util.jid".split; local user_exists = require "core.usermanager".user_exists; -local t_insert = table.insert; local function process_to_bare(bare, origin, stanza) local user = bare_sessions[bare]; - + local t = stanza.attr.type; if t == "error" then -- discard @@ -36,10 +35,13 @@ local function process_to_bare(bare, origin, stanza) if user then -- some resources are connected local recipients = user.top_resources; if recipients then + local sent; for i=1,#recipients do - recipients[i].send(stanza); + sent = recipients[i].send(stanza) or sent; + end + if sent then + return true; end - return true; end end -- no resources are online @@ -64,11 +66,9 @@ end module:hook("message/full", function(data) -- message to full JID recieved local origin, stanza = data.origin, data.stanza; - + local session = full_sessions[stanza.attr.to]; - if session then - -- TODO fire post processing event - session.send(stanza); + if session and session.send(stanza) then return true; else -- resource not online return process_to_bare(jid_bare(stanza.attr.to), origin, stanza); diff --git a/plugins/mod_motd.lua b/plugins/mod_motd.lua index f323e606..1e2ee395 100644 --- a/plugins/mod_motd.lua +++ b/plugins/mod_motd.lua @@ -2,24 +2,29 @@ -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain -- Copyright (C) 2010 Jeff Mitchell --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- local host = module:get_host(); -local motd_text = module:get_option("motd_text") or "MOTD: (blank)"; -local motd_jid = module:get_option("motd_jid") or host; +local motd_text = module:get_option_string("motd_text"); +local motd_jid = module:get_option_string("motd_jid", host); + +if not motd_text then return; end local st = require "util.stanza"; -module:hook("resource-bind", - function (event) - local session = event.session; - local motd_stanza = - st.message({ to = session.username..'@'..session.host, from = motd_jid }) - :tag("body"):text(motd_text); - core_route_stanza(hosts[host], motd_stanza); - module:log("debug", "MOTD send to user %s@%s", session.username, session.host); +motd_text = motd_text:gsub("^%s*(.-)%s*$", "%1"):gsub("\n%s+", "\n"); -- Strip indentation from the config -end); +module:hook("presence/bare", function (event) + local session, stanza = event.origin, event.stanza; + if session.username and not session.presence + and not stanza.attr.type and not stanza.attr.to then + local motd_stanza = + st.message({ to = session.full_jid, from = motd_jid }) + :tag("body"):text(motd_text); + module:send(motd_stanza); + module:log("debug", "MOTD send to user %s", session.full_jid); + end +end, 1); diff --git a/plugins/mod_net_multiplex.lua b/plugins/mod_net_multiplex.lua new file mode 100644 index 00000000..d666b907 --- /dev/null +++ b/plugins/mod_net_multiplex.lua @@ -0,0 +1,70 @@ +module:set_global(); + +local max_buffer_len = module:get_option_number("multiplex_buffer_size", 1024); + +local portmanager = require "core.portmanager"; + +local available_services = {}; + +local function add_service(service) + local multiplex_pattern = service.multiplex and service.multiplex.pattern; + if multiplex_pattern then + module:log("debug", "Adding multiplex service %q with pattern %q", service.name, multiplex_pattern); + available_services[service] = multiplex_pattern; + else + module:log("debug", "Service %q is not multiplex-capable", service.name); + end +end +module:hook("service-added", function (event) add_service(event.service); end); +module:hook("service-removed", function (event) available_services[event.service] = nil; end); + +for service_name, services in pairs(portmanager.get_registered_services()) do + for i, service in ipairs(services) do + add_service(service); + end +end + +local buffers = {}; + +local listener = { default_mode = "*a" }; + +function listener.onconnect() +end + +function listener.onincoming(conn, data) + if not data then return; end + local buf = buffers[conn]; + buffers[conn] = nil; + buf = buf and buf..data or data; + for service, multiplex_pattern in pairs(available_services) do + if buf:match(multiplex_pattern) then + module:log("debug", "Routing incoming connection to %s", service.name); + local listener = service.listener; + conn:setlistener(listener); + local onconnect = listener.onconnect; + if onconnect then onconnect(conn) end + return listener.onincoming(conn, buf); + end + end + if #buf > max_buffer_len then -- Give up + conn:close(); + else + buffers[conn] = buf; + end +end + +function listener.ondisconnect(conn, err) + buffers[conn] = nil; -- warn if no buffer? +end + +module:provides("net", { + name = "multiplex"; + config_prefix = ""; + listener = listener; +}); + +module:provides("net", { + name = "multiplex_ssl"; + config_prefix = "ssl"; + listener = listener; +}); diff --git a/plugins/mod_offline.lua b/plugins/mod_offline.lua index 1ac62f94..c168711b 100644 --- a/plugins/mod_offline.lua +++ b/plugins/mod_offline.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2009 Matthew Wild -- Copyright (C) 2008-2009 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -24,11 +24,11 @@ module:hook("message/offline/handle", function(event) else node, host = origin.username, origin.host; end - + stanza.attr.stamp, stanza.attr.stamp_legacy = datetime.datetime(), datetime.legacy(); local result = datamanager.list_append(node, host, "offline", st.preserialize(stanza)); stanza.attr.stamp, stanza.attr.stamp_legacy = nil, nil; - + return result; end); diff --git a/plugins/mod_pep.lua b/plugins/mod_pep.lua index 9ff6cac2..752cd28c 100644 --- a/plugins/mod_pep.lua +++ b/plugins/mod_pep.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -10,13 +10,12 @@ local jid_bare = require "util.jid".bare; local jid_split = require "util.jid".split; local st = require "util.stanza"; -local hosts = hosts; -local user_exists = require "core.usermanager".user_exists; local is_contact_subscribed = require "core.rostermanager".is_contact_subscribed; -local pairs, ipairs = pairs, ipairs; +local pairs = pairs; local next = next; local type = type; local calculate_hash = require "util.caps".calculate_hash; +local core_post_stanza = prosody.core_post_stanza; local NULL = {}; local data = {}; @@ -32,7 +31,7 @@ module.restore = function(state) hash_map = state.hash_map or {}; end -module:add_identity("pubsub", "pep", "Prosody"); +module:add_identity("pubsub", "pep", module:get_option_string("name", "Prosody")); module:add_feature("http://jabber.org/protocol/pubsub#publish"); local function subscription_presence(user_bare, recipient) @@ -63,7 +62,7 @@ local function publish(session, node, id, item) end else if not user_data then user_data = {}; data[bare] = user_data; end - user_data[node] = {id or "1", item}; + user_data[node] = {id, item}; end -- broadcast @@ -124,7 +123,7 @@ module:hook("presence/bare", function(event) local recipient = stanza.attr.from; local current = recipients[user] and recipients[user][recipient]; local hash = get_caps_hash_from_presence(stanza, current); - if current == hash then return; end + if current == hash or (current and current == hash_map[hash]) then return; end if not hash then if recipients[user] then recipients[user][recipient] = nil; end else @@ -136,8 +135,9 @@ module:hook("presence/bare", function(event) recipients[user][recipient] = hash; local from_bare = origin.type == "c2s" and origin.username.."@"..origin.host; if self or origin.type ~= "c2s" or (recipients[from_bare] and recipients[from_bare][origin.full_jid]) ~= hash then + -- COMPAT from ~= stanza.attr.to because OneTeam and Asterisk 1.8 can't deal with missing from attribute origin.send( - st.stanza("iq", {from=stanza.attr.to, to=stanza.attr.from, id="disco", type="get"}) + st.stanza("iq", {from=user, to=stanza.attr.from, id="disco", type="get"}) :query("http://jabber.org/protocol/disco#info") ); end @@ -169,7 +169,8 @@ module:hook("iq/bare/http://jabber.org/protocol/pubsub:pubsub", function(event) local node = payload.attr.node; payload = payload.tags[1]; if payload and payload.name == "item" then -- <item> - local id = payload.attr.id; + local id = payload.attr.id or "1"; + payload.attr.id = id; session.send(st.reply(stanza)); publish(session, node, id, st.clone(payload)); return true; @@ -262,19 +263,19 @@ module:hook("iq-result/bare/disco", function(event) end); module:hook("account-disco-info", function(event) - local stanza = event.stanza; - stanza:tag('identity', {category='pubsub', type='pep'}):up(); - stanza:tag('feature', {var='http://jabber.org/protocol/pubsub#publish'}):up(); + local reply = event.reply; + reply:tag('identity', {category='pubsub', type='pep'}):up(); + reply:tag('feature', {var='http://jabber.org/protocol/pubsub#publish'}):up(); end); module:hook("account-disco-items", function(event) - local stanza = event.stanza; - local bare = stanza.attr.to; + local reply = event.reply; + local bare = reply.attr.to; local user_data = data[bare]; if user_data then for node, _ in pairs(user_data) do - stanza:tag('item', {jid=bare, node=node}):up(); -- TODO we need to handle queries to these nodes + reply:tag('item', {jid=bare, node=node}):up(); -- TODO we need to handle queries to these nodes end end end); diff --git a/plugins/mod_ping.lua b/plugins/mod_ping.lua index c0ba6189..eddb92d2 100644 --- a/plugins/mod_ping.lua +++ b/plugins/mod_ping.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -22,8 +22,10 @@ module:hook("iq/host/urn:xmpp:ping:ping", ping_handler); -- Ad-hoc command +local datetime = require "util.datetime".datetime; + function ping_command_handler (self, data, state) - local now = os.date("%Y-%m-%dT%X"); + local now = datetime(); return { info = "Pong\n"..now, status = "completed" }; end diff --git a/plugins/mod_posix.lua b/plugins/mod_posix.lua index d229c1b8..7a6ccd94 100644 --- a/plugins/mod_posix.lua +++ b/plugins/mod_posix.lua @@ -1,16 +1,18 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- -local want_pposix_version = "0.3.5"; +local want_pposix_version = "0.3.6"; local pposix = assert(require "util.pposix"); -if pposix._VERSION ~= want_pposix_version then module:log("warn", "Unknown version (%s) of binary pposix module, expected %s", tostring(pposix._VERSION), want_pposix_version); end +if pposix._VERSION ~= want_pposix_version then + module:log("warn", "Unknown version (%s) of binary pposix module, expected %s. Perhaps you need to recompile?", tostring(pposix._VERSION), want_pposix_version); +end local signal = select(2, pcall(require, "util.signal")); if type(signal) == "string" then @@ -22,7 +24,7 @@ local stat = lfs.attributes; local prosody = _G.prosody; -module.host = "*"; -- we're a global module +module:set_global(); -- we're a global module local umask = module:get_option("umask") or "027"; pposix.umask(umask); @@ -34,19 +36,19 @@ module:hook("server-started", function () if gid then local success, msg = pposix.setgid(gid); if success then - module:log("debug", "Changed group to "..gid.." successfully."); + module:log("debug", "Changed group to %s successfully.", gid); else - module:log("error", "Failed to change group to "..gid..". Error: "..msg); - prosody.shutdown("Failed to change group to "..gid); + module:log("error", "Failed to change group to %s. Error: %s", gid, msg); + prosody.shutdown("Failed to change group to %s", gid); end end if uid then local success, msg = pposix.setuid(uid); if success then - module:log("debug", "Changed user to "..uid.." successfully."); + module:log("debug", "Changed user to %s successfully.", uid); else - module:log("error", "Failed to change user to "..uid..". Error: "..msg); - prosody.shutdown("Failed to change user to "..uid); + module:log("error", "Failed to change user to %s. Error: %s", uid, msg); + prosody.shutdown("Failed to change user to %s", uid); end end end); @@ -112,15 +114,15 @@ end local syslog_opened; function syslog_sink_maker(config) if not syslog_opened then - pposix.syslog_open("prosody"); + pposix.syslog_open("prosody", module:get_option_string("syslog_facility")); syslog_opened = true; end local syslog, format = pposix.syslog_log, string.format; return function (name, level, message, ...) if ... then - syslog(level, format(message, ...)); + syslog(level, name, format(message, ...)); else - syslog(level, message); + syslog(level, name, message); end end; end @@ -136,8 +138,17 @@ if daemonize == nil then end end +local function remove_log_sinks() + local lm = require "core.loggingmanager"; + lm.register_sink_type("console", nil); + lm.register_sink_type("stdout", nil); + lm.reload_logging(); +end + if daemonize then local function daemonize_server() + module:log("info", "Prosody is about to detach from the console, disabling further console output"); + remove_log_sinks(); local ok, ret = pposix.daemonize(); if not ok then module:log("error", "Failed to daemonize: %s", ret); @@ -172,7 +183,7 @@ if signal.signal then prosody.reload_config(); prosody.reopen_logfiles(); end); - + signal.signal("SIGINT", function () module:log("info", "Received SIGINT"); prosody.unlock_globals(); diff --git a/plugins/mod_presence.lua b/plugins/mod_presence.lua index 61239c9a..2899bd7e 100644 --- a/plugins/mod_presence.lua +++ b/plugins/mod_presence.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -9,15 +9,19 @@ local log = module._log; local require = require; -local pairs, ipairs = pairs, ipairs; +local pairs = pairs; local t_concat, t_insert = table.concat, table.insert; local s_find = string.find; local tonumber = tonumber; +local core_post_stanza = prosody.core_post_stanza; local st = require "util.stanza"; local jid_split = require "util.jid".split; local jid_bare = require "util.jid".bare; -local hosts = hosts; +local datetime = require "util.datetime"; +local hosts = prosody.hosts; +local bare_sessions = prosody.bare_sessions; +local full_sessions = prosody.full_sessions; local NULL = {}; local rostermanager = require "core.rostermanager"; @@ -115,8 +119,8 @@ function handle_normal_presence(origin, stanza) end if priority >= 0 then - local event = { origin = origin } - module:fire_event('message/offline/broadcast', event); + local event = { origin = origin } + module:fire_event('message/offline/broadcast', event); end end if stanza.attr.type == "unavailable" then @@ -134,6 +138,7 @@ function handle_normal_presence(origin, stanza) end else origin.presence = stanza; + stanza:tag("delay", { xmlns = "urn:xmpp:delay", from = host, stamp = datetime.datetime() }):up(); if origin.priority ~= priority then origin.priority = priority; recalc_resource_map(user); @@ -160,7 +165,7 @@ function send_presence_of_available_resources(user, host, jid, recipient_session end end end - log("debug", "broadcasted presence of "..count.." resources from "..user.."@"..host.." to "..jid); + log("debug", "broadcasted presence of %d resources from %s@%s to %s", count, user, host, jid); return count; end @@ -169,7 +174,7 @@ function handle_outbound_presence_subscriptions_and_probes(origin, stanza, from_ if to_bare == from_bare then return; end -- No self contacts local st_from, st_to = stanza.attr.from, stanza.attr.to; stanza.attr.from, stanza.attr.to = from_bare, to_bare; - log("debug", "outbound presence "..stanza.attr.type.." from "..from_bare.." for "..to_bare); + log("debug", "outbound presence %s from %s for %s", stanza.attr.type, from_bare, to_bare); if stanza.attr.type == "probe" then stanza.attr.from, stanza.attr.to = st_from, st_to; return; @@ -197,12 +202,21 @@ function handle_outbound_presence_subscriptions_and_probes(origin, stanza, from_ core_post_stanza(origin, stanza); send_presence_of_available_resources(node, host, to_bare, origin); elseif stanza.attr.type == "unsubscribed" then - -- 1. route stanza - -- 2. roster push (subscription = none or to) - if rostermanager.unsubscribed(node, host, to_bare) then - rostermanager.roster_push(node, host, to_bare); + -- 1. send unavailable + -- 2. route stanza + -- 3. roster push (subscription = from or both) + local success, pending_in, subscribed = rostermanager.unsubscribed(node, host, to_bare); + if success then + if subscribed then + rostermanager.roster_push(node, host, to_bare); + end + core_post_stanza(origin, stanza); + if subscribed then + send_presence_of_available_resources(node, host, to_bare, origin, st.presence({ type = "unavailable" })); + end end - core_post_stanza(origin, stanza); + else + origin.send(st.error_reply(stanza, "modify", "bad-request", "Invalid presence type")); end stanza.attr.from, stanza.attr.to = st_from, st_to; return true; @@ -212,8 +226,8 @@ function handle_inbound_presence_subscriptions_and_probes(origin, stanza, from_b local node, host = jid_split(to_bare); local st_from, st_to = stanza.attr.from, stanza.attr.to; stanza.attr.from, stanza.attr.to = from_bare, to_bare; - log("debug", "inbound presence "..stanza.attr.type.." from "..from_bare.." for "..to_bare); - + log("debug", "inbound presence %s from %s for %s", stanza.attr.type, from_bare, to_bare); + if stanza.attr.type == "probe" then local result, err = rostermanager.is_contact_subscribed(node, host, from_bare); if result then @@ -253,7 +267,9 @@ function handle_inbound_presence_subscriptions_and_probes(origin, stanza, from_b sessionmanager.send_to_interested_resources(node, host, stanza); rostermanager.roster_push(node, host, from_bare); end - end -- discard any other type + else + origin.send(st.error_reply(stanza, "modify", "bad-request", "Invalid presence type")); + end stanza.attr.from, stanza.attr.to = st_from, st_to; return true; end @@ -296,7 +312,7 @@ module:hook("presence/bare", function(data) if t ~= nil and t ~= "unavailable" and t ~= "error" then -- check for subscriptions and probes sent to bare JID return handle_inbound_presence_subscriptions_and_probes(origin, stanza, jid_bare(stanza.attr.from), jid_bare(stanza.attr.to)); end - + local user = bare_sessions[to]; if user then for _, session in pairs(user.sessions) do @@ -307,6 +323,8 @@ module:hook("presence/bare", function(data) end -- no resources not online, discard elseif not t or t == "unavailable" then handle_normal_presence(origin, stanza); + else + origin.send(st.error_reply(stanza, "modify", "bad-request", "Invalid presence type")); end return true; end); @@ -328,8 +346,8 @@ module:hook("presence/full", function(data) end); module:hook("presence/host", function(data) -- inbound presence to the host - local origin, stanza = data.origin, data.stanza; - + local stanza = data.stanza; + local from_bare = jid_bare(stanza.attr.from); local t = stanza.attr.type; if t == "probe" then @@ -346,13 +364,15 @@ module:hook("resource-unbind", function(event) -- Send unavailable presence if session.presence then local pres = st.presence{ type = "unavailable" }; - if not(err) or err == "closed" then err = "connection closed"; end - pres:tag("status"):text("Disconnected: "..err):up(); + if err then + pres:tag("status"):text("Disconnected: "..err):up(); + end session:dispatch_stanza(pres); elseif session.directed then local pres = st.presence{ type = "unavailable", from = session.full_jid }; - if not(err) or err == "closed" then err = "connection closed"; end - pres:tag("status"):text("Disconnected: "..err):up(); + if err then + pres:tag("status"):text("Disconnected: "..err):up(); + end for jid in pairs(session.directed) do pres.attr.to = jid; core_post_stanza(session, pres, true); diff --git a/plugins/mod_privacy.lua b/plugins/mod_privacy.lua index d5842e26..aaa8e383 100644 --- a/plugins/mod_privacy.lua +++ b/plugins/mod_privacy.lua @@ -2,23 +2,23 @@ -- Copyright (C) 2009-2010 Matthew Wild -- Copyright (C) 2009-2010 Waqas Hussain -- Copyright (C) 2009 Thilo Cestonaro --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- module:add_feature("jabber:iq:privacy"); -local prosody = prosody; local st = require "util.stanza"; -local datamanager = require "util.datamanager"; -local bare_sessions, full_sessions = bare_sessions, full_sessions; +local bare_sessions, full_sessions = prosody.bare_sessions, prosody.full_sessions; local util_Jid = require "util.jid"; local jid_bare = util_Jid.bare; local jid_split, jid_join = util_Jid.split, util_Jid.join; local load_roster = require "core.rostermanager".load_roster; local to_number = tonumber; +local privacy_storage = module:open_store(); + function isListUsed(origin, name, privacy_lists) local user = bare_sessions[origin.username.."@"..origin.host]; if user then @@ -45,28 +45,6 @@ function isAnotherSessionUsingDefaultList(origin) end end -function sendUnavailable(origin, to, from) ---[[ example unavailable presence stanza -<presence from="node@host/resource" type="unavailable" to="node@host" > - <status>Logged out</status> -</presence> -]]-- - local presence = st.presence({from=from, type="unavailable"}); - presence:tag("status"):text("Logged out"); - - local node, host = jid_bare(to); - local bare = node .. "@" .. host; - - local user = bare_sessions[bare]; - if user then - for resource, session in pairs(user.sessions) do - presence.attr.to = session.full_jid; - module:log("debug", "send unavailable to: %s; from: %s", tostring(presence.attr.to), tostring(presence.attr.from)); - origin.send(presence); - end - end -end - function declineList(privacy_lists, origin, stanza, which) if which == "default" then if isAnotherSessionUsingDefaultList(origin) then @@ -123,9 +101,9 @@ function deleteList(privacy_lists, origin, stanza, name) return {"modify", "bad-request", "Not existing list specifed to be deleted."}; end -function createOrReplaceList (privacy_lists, origin, stanza, name, entries, roster) +function createOrReplaceList (privacy_lists, origin, stanza, name, entries) local bare_jid = origin.username.."@"..origin.host; - + if privacy_lists.lists == nil then privacy_lists.lists = {}; end @@ -141,14 +119,14 @@ function createOrReplaceList (privacy_lists, origin, stanza, name, entries, rost if to_number(item.attr.order) == nil or to_number(item.attr.order) < 0 or orderCheck[item.attr.order] ~= nil then return {"modify", "bad-request", "Order attribute not valid."}; end - + if item.attr.type ~= nil and item.attr.type ~= "jid" and item.attr.type ~= "subscription" and item.attr.type ~= "group" then return {"modify", "bad-request", "Type attribute not valid."}; end - + local tmp = {}; orderCheck[item.attr.order] = true; - + tmp["type"] = item.attr.type; tmp["value"] = item.attr.value; tmp["action"] = item.attr.action; @@ -157,13 +135,13 @@ function createOrReplaceList (privacy_lists, origin, stanza, name, entries, rost tmp["presence-out"] = false; tmp["message"] = false; tmp["iq"] = false; - + if #item.tags > 0 then for _,tag in ipairs(item.tags) do tmp[tag.name] = true; end end - + if tmp.type == "subscription" then if tmp.value ~= "both" and tmp.value ~= "to" and @@ -172,13 +150,13 @@ function createOrReplaceList (privacy_lists, origin, stanza, name, entries, rost return {"cancel", "bad-request", "Subscription value must be both, to, from or none."}; end end - + if tmp.action ~= "deny" and tmp.action ~= "allow" then return {"cancel", "bad-request", "Action must be either deny or allow."}; end list.items[#list.items + 1] = tmp; end - + table.sort(list, function(a, b) return a.order < b.order; end); origin.send(st.reply(stanza)); @@ -229,18 +207,18 @@ function getList(privacy_lists, origin, stanza, name) return {"cancel", "item-not-found", "Unknown list specified."}; end end - + origin.send(reply); return true; end module:hook("iq/bare/jabber:iq:privacy:query", function(data) local origin, stanza = data.origin, data.stanza; - + if stanza.attr.to == nil then -- only service requests to own bare JID local query = stanza.tags[1]; -- the query element local valid = false; - local privacy_lists = datamanager.load(origin.username, origin.host, "privacy") or { lists = {} }; + local privacy_lists = privacy_storage:get(origin.username) or { lists = {} }; if privacy_lists.lists[1] then -- Code to migrate from old privacy lists format, remove in 0.8 module:log("info", "Upgrading format of stored privacy lists for %s@%s", origin.username, origin.host); @@ -295,7 +273,7 @@ module:hook("iq/bare/jabber:iq:privacy:query", function(data) end origin.send(st.error_reply(stanza, valid[1], valid[2], valid[3])); else - datamanager.store(origin.username, origin.host, "privacy", privacy_lists); + privacy_storage:set(origin.username, privacy_lists); end return true; end @@ -303,16 +281,16 @@ end); function checkIfNeedToBeBlocked(e, session) local origin, stanza = e.origin, e.stanza; - local privacy_lists = datamanager.load(session.username, session.host, "privacy") or {}; + local privacy_lists = privacy_storage:get(session.username) or {}; local bare_jid = session.username.."@"..session.host; local to = stanza.attr.to or bare_jid; local from = stanza.attr.from; - + local is_to_user = bare_jid == jid_bare(to); local is_from_user = bare_jid == jid_bare(from); - + --module:log("debug", "stanza: %s, to: %s, from: %s", tostring(stanza.name), tostring(to), tostring(from)); - + if privacy_lists.lists == nil or not (session.activePrivacyList or privacy_lists.default) then @@ -322,8 +300,7 @@ function checkIfNeedToBeBlocked(e, session) --module:log("debug", "Not blocking communications between user's resources"); return; -- from one of a user's resource to another => HANDS OFF! end - - local item; + local listname = session.activePrivacyList; if listname == nil then listname = privacy_lists.default; -- no active list selected, use default list @@ -390,6 +367,10 @@ function checkIfNeedToBeBlocked(e, session) end if apply then if block then + -- drop and not bounce groupchat messages, otherwise users will get kicked + if stanza.attr.type == "groupchat" then + return true; + end module:log("debug", "stanza blocked: %s, to: %s, from: %s", tostring(stanza.name), tostring(to), tostring(from)); if stanza.name == "message" then origin.send(st.error_reply(stanza, "cancel", "service-unavailable")); @@ -414,7 +395,6 @@ function preCheckIncoming(e) end if resource == nil then local prio = 0; - local session_; if bare_sessions[node.."@"..host] ~= nil then for resource, session_ in pairs(bare_sessions[node.."@"..host].sessions) do if session_.priority ~= nil and session_.priority > prio then diff --git a/plugins/mod_private.lua b/plugins/mod_private.lua index f1ebe786..446a80b2 100644 --- a/plugins/mod_private.lua +++ b/plugins/mod_private.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -9,8 +9,7 @@ local st = require "util.stanza" -local jid_split = require "util.jid".split; -local datamanager = require "util.datamanager" +local private_storage = module:open_store(); module:add_feature("jabber:iq:private"); @@ -21,7 +20,7 @@ module:hook("iq/self/jabber:iq:private:query", function(event) if #query.tags == 1 then local tag = query.tags[1]; local key = tag.name..":"..tag.attr.xmlns; - local data, err = datamanager.load(origin.username, origin.host, "private"); + local data, err = private_storage:get(origin.username); if err then origin.send(st.error_reply(stanza, "wait", "internal-server-error")); return true; @@ -40,7 +39,7 @@ module:hook("iq/self/jabber:iq:private:query", function(event) data[key] = st.preserialize(tag); end -- TODO delete datastore if empty - if datamanager.store(origin.username, origin.host, "private", data) then + if private_storage:set(origin.username, data) then origin.send(st.reply(stanza)); else origin.send(st.error_reply(stanza, "wait", "internal-server-error")); diff --git a/plugins/mod_proxy65.lua b/plugins/mod_proxy65.lua index 5b490730..2ed9faac 100644 --- a/plugins/mod_proxy65.lua +++ b/plugins/mod_proxy65.lua @@ -1,109 +1,80 @@ +-- Prosody IM +-- Copyright (C) 2008-2011 Matthew Wild +-- Copyright (C) 2008-2011 Waqas Hussain -- Copyright (C) 2009 Thilo Cestonaro --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- ---[[ -* to restart the proxy in the console: e.g. -module:unload("proxy65"); -> server.removeserver(<proxy65_port>); -module:load("proxy65", <proxy65_jid>); -]]-- +module:set_global(); -local module = module; -local tostring = tostring; -local jid_split, jid_join, jid_compare = require "util.jid".split, require "util.jid".join, require "util.jid".compare; +local jid_compare, jid_prep = require "util.jid".compare, require "util.jid".prep; local st = require "util.stanza"; -local connlisteners = require "net.connlisteners"; local sha1 = require "util.hashes".sha1; +local b64 = require "util.encodings".base64.encode; local server = require "net.server"; +local portmanager = require "core.portmanager"; -local host, name = module:get_host(), "SOCKS5 Bytestreams Service"; -local sessions, transfers, replies_cache = {}, {}, {}; - -local proxy_port = module:get_option("proxy65_port") or 5000; -local proxy_interface = module:get_option("proxy65_interface") or "*"; -local proxy_address = module:get_option("proxy65_address") or (proxy_interface ~= "*" and proxy_interface) or host; -local proxy_acl = module:get_option("proxy65_acl"); +local sessions, transfers = module:shared("sessions", "transfers"); local max_buffer_size = 4096; -local connlistener = { default_port = proxy_port, default_interface = proxy_interface, default_mode = "*a" }; +local listener = {}; -function connlistener.onincoming(conn, data) +function listener.onincoming(conn, data) local session = sessions[conn] or {}; - - if session.setup == nil and data ~= nil and data:byte(1) == 0x05 and #data > 2 then - local nmethods = data:byte(2); - local methods = data:sub(3); - local supported = false; - for i=1, nmethods, 1 do - if(methods:byte(i) == 0x00) then -- 0x00 == method: NO AUTH - supported = true; - break; - end - end - if(supported) then - module:log("debug", "new session found ... ") - session.setup = true; - sessions[conn] = session; - conn:write(string.char(5, 0)); - end + + local transfer = transfers[session.sha]; + if transfer and transfer.activated then -- copy data between initiator and target + local initiator, target = transfer.initiator, transfer.target; + (conn == initiator and target or initiator):write(data); return; - end - if session.setup then - if session.sha ~= nil and transfers[session.sha] ~= nil then - local sha = session.sha; - if transfers[sha].activated == true and transfers[sha].target ~= nil then - if transfers[sha].initiator == conn then - transfers[sha].target:write(data); - else - transfers[sha].initiator:write(data); - end + end -- FIXME server.link should be doing this? + + if not session.greeting_done then + local nmethods = data:byte(2) or 0; + if data:byte(1) == 0x05 and nmethods > 0 and #data == 2 + nmethods then -- check if we have all the data + if data:find("%z") then -- 0x00 = 'No authentication' is supported + session.greeting_done = true; + sessions[conn] = session; + conn:write("\5\0"); -- send (SOCKS version 5, No authentication) + module:log("debug", "SOCKS5 greeting complete"); return; end - end - if data ~= nil and #data == 0x2F and -- 40 == length of SHA1 HASH, and 7 other bytes => 47 => 0x2F - data:byte(1) == 0x05 and -- SOCKS5 has 5 in first byte - data:byte(2) == 0x01 and -- CMD must be 1 - data:byte(3) == 0x00 and -- RSV must be 0 - data:byte(4) == 0x03 and -- ATYP must be 3 - data:byte(5) == 40 and -- SHA1 HASH length must be 40 (0x28) - data:byte(-2) == 0x00 and -- PORT must be 0, size 2 byte - data:byte(-1) == 0x00 - then - local sha = data:sub(6, 45); -- second param is not count! it's the ending index (included!) - if transfers[sha] == nil then + end -- else error, unexpected input + conn:write("\5\255"); -- send (SOCKS version 5, no acceptable method) + conn:close(); + module:log("debug", "Invalid SOCKS5 greeting recieved: '%s'", b64(data)); + else -- connection request + --local head = string.char( 0x05, 0x01, 0x00, 0x03, 40 ); -- ( VER=5=SOCKS5, CMD=1=CONNECT, RSV=0=RESERVED, ATYP=3=DOMAIMNAME, SHA-1 size ) + if #data == 47 and data:sub(1,5) == "\5\1\0\3\40" and data:sub(-2) == "\0\0" then + local sha = data:sub(6, 45); + conn:pause(); + conn:write("\5\0\0\3\40" .. sha .. "\0\0"); -- VER, REP, RSV, ATYP, BND.ADDR (sha), BND.PORT (2 Byte) + if not transfers[sha] then transfers[sha] = {}; - transfers[sha].activated = false; transfers[sha].target = conn; session.sha = sha; - module:log("debug", "target connected ... "); - elseif transfers[sha].target ~= nil then + module:log("debug", "SOCKS5 target connected for session %s", sha); + else -- transfers[sha].target ~= nil transfers[sha].initiator = conn; session.sha = sha; - module:log("debug", "initiator connected ... "); + module:log("debug", "SOCKS5 initiator connected for session %s", sha); server.link(conn, transfers[sha].target, max_buffer_size); server.link(transfers[sha].target, conn, max_buffer_size); end - conn:write(string.char(5, 0, 0, 3, #sha) .. sha .. string.char(0, 0)); -- VER, REP, RSV, ATYP, BND.ADDR (sha), BND.PORT (2 Byte) - conn:lock_read(true) - else - module:log("warn", "Neither data transfer nor initial connect of a participator of a transfer.") - conn:close(); - end - else - if data ~= nil then - module:log("warn", "unknown connection with no authentication data -> closing it"); + else -- error, unexpected input + conn:write("\5\1\0\3\0\0\0"); -- VER, REP, RSV, ATYP, BND.ADDR (sha), BND.PORT (2 Byte) conn:close(); + module:log("debug", "Invalid SOCKS5 negotiation recieved: '%s'", b64(data)); end end end -function connlistener.ondisconnect(conn, err) +function listener.ondisconnect(conn, err) local session = sessions[conn]; if session then - if session.sha and transfers[session.sha] then + if transfers[session.sha] then local initiator, target = transfers[session.sha].initiator, transfers[session.sha].target; if initiator == conn and target ~= nil then target:close(); @@ -117,131 +88,105 @@ function connlistener.ondisconnect(conn, err) end end -module:add_identity("proxy", "bytestreams", name); -module:add_feature("http://jabber.org/protocol/bytestreams"); - -module:hook("iq-get/host/http://jabber.org/protocol/disco#info:query", function(event) - local origin, stanza = event.origin, event.stanza; - local reply = replies_cache.disco_info; - if reply == nil then - reply = st.iq({type='result', from=host}):query("http://jabber.org/protocol/disco#info") - :tag("identity", {category='proxy', type='bytestreams', name=name}):up() - :tag("feature", {var="http://jabber.org/protocol/bytestreams"}); - replies_cache.disco_info = reply; - end +function module.add_host(module) + local host, name = module:get_host(), module:get_option_string("name", "SOCKS5 Bytestreams Service"); - reply.attr.id = stanza.attr.id; - reply.attr.to = stanza.attr.from; - origin.send(reply); - return true; -end, -1); - -module:hook("iq-get/host/http://jabber.org/protocol/disco#items:query", function(event) - local origin, stanza = event.origin, event.stanza; - local reply = replies_cache.disco_items; - if reply == nil then - reply = st.iq({type='result', from=host}):query("http://jabber.org/protocol/disco#items"); - replies_cache.disco_items = reply; - end - - reply.attr.id = stanza.attr.id; - reply.attr.to = stanza.attr.from; - origin.send(reply); - return true; -end, -1); - -module:hook("iq-get/host/http://jabber.org/protocol/bytestreams:query", function(event) - local origin, stanza = event.origin, event.stanza; - local reply = replies_cache.stream_host; - local err_reply = replies_cache.stream_host_err; - local sid = stanza.tags[1].attr.sid; - local allow = false; - local jid = stanza.attr.from; - - if proxy_acl and #proxy_acl > 0 then - for _, acl in ipairs(proxy_acl) do - if jid_compare(jid, acl) then allow = true; end - end - else - allow = true; + local proxy_address = module:get_option("proxy65_address", host); + local proxy_port = next(portmanager.get_active_services():search("proxy65", nil)[1] or {}); + local proxy_acl = module:get_option("proxy65_acl"); + + -- COMPAT w/pre-0.9 where proxy65_port was specified in the components section of the config + local legacy_config = module:get_option_number("proxy65_port"); + if legacy_config then + module:log("warn", "proxy65_port is deprecated, please put proxy65_ports = { %d } into the global section instead", legacy_config); end - if allow == true then - if reply == nil then - reply = st.iq({type="result", from=host}) - :query("http://jabber.org/protocol/bytestreams") - :tag("streamhost", {jid=host, host=proxy_address, port=proxy_port}); - replies_cache.stream_host = reply; + + module:add_identity("proxy", "bytestreams", name); + module:add_feature("http://jabber.org/protocol/bytestreams"); + + module:hook("iq-get/host/http://jabber.org/protocol/disco#info:query", function(event) + local origin, stanza = event.origin, event.stanza; + if not stanza.tags[1].attr.node then + origin.send(st.reply(stanza):query("http://jabber.org/protocol/disco#info") + :tag("identity", {category='proxy', type='bytestreams', name=name}):up() + :tag("feature", {var="http://jabber.org/protocol/bytestreams"}) ); + return true; end - else - module:log("warn", "Denying use of proxy for %s", tostring(jid)); - if err_reply == nil then - err_reply = st.iq({type="error", from=host}) - :query("http://jabber.org/protocol/bytestreams") - :tag("error", {code='403', type='auth'}) - :tag("forbidden", {xmlns='urn:ietf:params:xml:ns:xmpp-stanzas'}); - replies_cache.stream_host_err = err_reply; + end, -1); + + module:hook("iq-get/host/http://jabber.org/protocol/disco#items:query", function(event) + local origin, stanza = event.origin, event.stanza; + if not stanza.tags[1].attr.node then + origin.send(st.reply(stanza):query("http://jabber.org/protocol/disco#items")); + return true; end - reply = err_reply; - end - reply.attr.id = stanza.attr.id; - reply.attr.to = stanza.attr.from; - reply.tags[1].attr.sid = sid; - origin.send(reply); - return true; -end); - -module.unload = function() - connlisteners.deregister(module.host .. ':proxy65'); -end + end, -1); -local function set_activation(stanza) - local to, reply; - local from = stanza.attr.from; - local query = stanza.tags[1]; - local sid = query.attr.sid; - if query.tags[1] and query.tags[1].name == "activate" then - to = query.tags[1][1]; - end - if from ~= nil and to ~= nil and sid ~= nil then - reply = st.iq({type="result", from=host, to=from}); - reply.attr.id = stanza.attr.id; - end - return reply, from, to, sid; -end + module:hook("iq-get/host/http://jabber.org/protocol/bytestreams:query", function(event) + local origin, stanza = event.origin, event.stanza; -module:hook("iq-set/host/http://jabber.org/protocol/bytestreams:query", function(event) - local origin, stanza = event.origin, event.stanza; - - module:log("debug", "Received activation request from %s", stanza.attr.from); - local reply, from, to, sid = set_activation(stanza); - if reply ~= nil and from ~= nil and to ~= nil and sid ~= nil then - local sha = sha1(sid .. from .. to, true); - if transfers[sha] == nil then - module:log("error", "transfers[sha]: nil"); - elseif(transfers[sha] ~= nil and transfers[sha].initiator ~= nil and transfers[sha].target ~= nil) then - origin.send(reply); - transfers[sha].activated = true; - transfers[sha].target:lock_read(false); - transfers[sha].initiator:lock_read(false); - else - module:log("debug", "Both parties were not yet connected"); - local message = "Neither party is connected to the proxy"; - if transfers[sha].initiator then - message = "The recipient is not connected to the proxy"; - elseif transfers[sha].target then - message = "The sender (you) is not connected to the proxy"; + -- check ACL + while proxy_acl and #proxy_acl > 0 do -- using 'while' instead of 'if' so we can break out of it + local jid = stanza.attr.from; + local allow; + for _, acl in ipairs(proxy_acl) do + if jid_compare(jid, acl) then allow = true; break; end end - origin.send(st.error_reply(stanza, "cancel", "not-allowed", message)); + if allow then break; end + module:log("warn", "Denying use of proxy for %s", tostring(stanza.attr.from)); + origin.send(st.error_reply(stanza, "auth", "forbidden")); + return true; end - return true; - else - module:log("error", "activation failed: sid: %s, initiator: %s, target: %s", tostring(sid), tostring(from), tostring(to)); - end -end); -if not connlisteners.register(module.host .. ':proxy65', connlistener) then - module:log("error", "mod_proxy65: Could not establish a connection listener. Check your configuration please."); - module:log("error", "Possibly two proxy65 components are configured to share the same port."); + local sid = stanza.tags[1].attr.sid; + origin.send(st.reply(stanza):tag("query", {xmlns="http://jabber.org/protocol/bytestreams", sid=sid}) + :tag("streamhost", {jid=host, host=proxy_address, port=proxy_port})); + return true; + end); + + module:hook("iq-set/host/http://jabber.org/protocol/bytestreams:query", function(event) + local origin, stanza = event.origin, event.stanza; + + local query = stanza.tags[1]; + local sid = query.attr.sid; + local from = stanza.attr.from; + local to = query:get_child_text("activate"); + local prepped_to = jid_prep(to); + + local info = "sid: "..tostring(sid)..", initiator: "..tostring(from)..", target: "..tostring(prepped_to or to); + if prepped_to and sid then + local sha = sha1(sid .. from .. prepped_to, true); + if not transfers[sha] then + module:log("debug", "Activation request has unknown session id; activation failed (%s)", info); + origin.send(st.error_reply(stanza, "modify", "item-not-found")); + elseif not transfers[sha].initiator then + module:log("debug", "The sender was not connected to the proxy; activation failed (%s)", info); + origin.send(st.error_reply(stanza, "cancel", "not-allowed", "The sender (you) is not connected to the proxy")); + --elseif not transfers[sha].target then -- can't happen, as target is set when a transfer object is created + -- module:log("debug", "The recipient was not connected to the proxy; activation failed (%s)", info); + -- origin.send(st.error_reply(stanza, "cancel", "not-allowed", "The recipient is not connected to the proxy")); + else -- if transfers[sha].initiator ~= nil and transfers[sha].target ~= nil then + module:log("debug", "Transfer activated (%s)", info); + transfers[sha].activated = true; + transfers[sha].target:resume(); + transfers[sha].initiator:resume(); + origin.send(st.reply(stanza)); + end + elseif to and sid then + module:log("debug", "Malformed activation jid; activation failed (%s)", info); + origin.send(st.error_reply(stanza, "modify", "jid-malformed")); + else + module:log("debug", "Bad request; activation failed (%s)", info); + origin.send(st.error_reply(stanza, "modify", "bad-request")); + end + return true; + end); end -connlisteners.start(module.host .. ':proxy65'); +module:provides("net", { + default_port = 5000; + listener = listener; + multiplex = { + pattern = "^\5"; + }; +}); diff --git a/plugins/mod_pubsub.lua b/plugins/mod_pubsub.lua deleted file mode 100644 index 90c19815..00000000 --- a/plugins/mod_pubsub.lua +++ /dev/null @@ -1,369 +0,0 @@ -local pubsub = require "util.pubsub"; -local st = require "util.stanza"; -local jid_bare = require "util.jid".bare; -local uuid_generate = require "util.uuid".generate; - -require "core.modulemanager".load(module.host, "iq"); - -local xmlns_pubsub = "http://jabber.org/protocol/pubsub"; -local xmlns_pubsub_errors = "http://jabber.org/protocol/pubsub#errors"; -local xmlns_pubsub_event = "http://jabber.org/protocol/pubsub#event"; - -local autocreate_on_publish = module:get_option_boolean("autocreate_on_publish", false); -local autocreate_on_subscribe = module:get_option_boolean("autocreate_on_subscribe", false); - -local service; - -local handlers = {}; - -function handle_pubsub_iq(event) - local origin, stanza = event.origin, event.stanza; - local pubsub = stanza.tags[1]; - local action = pubsub.tags[1]; - local handler = handlers[stanza.attr.type.."_"..action.name]; - if handler then - handler(origin, stanza, action); - return true; - end -end - -local pubsub_errors = { - ["conflict"] = { "cancel", "conflict" }; - ["invalid-jid"] = { "modify", "bad-request", nil, "invalid-jid" }; - ["item-not-found"] = { "cancel", "item-not-found" }; - ["not-subscribed"] = { "modify", "unexpected-request", nil, "not-subscribed" }; - ["forbidden"] = { "cancel", "forbidden" }; -}; -function pubsub_error_reply(stanza, error) - local e = pubsub_errors[error]; - local reply = st.error_reply(stanza, unpack(e, 1, 3)); - if e[4] then - reply:tag(e[4], { xmlns = xmlns_pubsub_errors }):up(); - end - return reply; -end - -function handlers.get_items(origin, stanza, items) - local node = items.attr.node; - local item = items:get_child("item"); - local id = item and item.attr.id; - - local ok, results = service:get_items(node, stanza.attr.from, id); - if not ok then - return origin.send(pubsub_error_reply(stanza, results)); - end - - local data = st.stanza("items", { node = node }); - for _, entry in pairs(results) do - data:add_child(entry); - end - if data then - reply = st.reply(stanza) - :tag("pubsub", { xmlns = xmlns_pubsub }) - :add_child(data); - else - reply = pubsub_error_reply(stanza, "item-not-found"); - end - return origin.send(reply); -end - -function handlers.get_subscriptions(origin, stanza, subscriptions) - local node = subscriptions.attr.node; - local ok, ret = service:get_subscriptions(node, stanza.attr.from, stanza.attr.from); - if not ok then - return origin.send(pubsub_error_reply(stanza, ret)); - end - local reply = st.reply(stanza) - :tag("subscriptions", { xmlns = xmlns_pubsub }); - for _, sub in ipairs(ret) do - reply:tag("subscription", { node = sub.node, jid = sub.jid, subscription = 'subscribed' }):up(); - end - return origin.send(reply); -end - -function handlers.set_create(origin, stanza, create) - local node = create.attr.node; - local ok, ret, reply; - if node then - ok, ret = service:create(node, stanza.attr.from); - if ok then - reply = st.reply(stanza); - else - reply = pubsub_error_reply(stanza, ret); - end - else - repeat - node = uuid_generate(); - ok, ret = service:create(node, stanza.attr.from); - until ok or ret ~= "conflict"; - if ok then - reply = st.reply(stanza) - :tag("pubsub", { xmlns = xmlns_pubsub }) - :tag("create", { node = node }); - else - reply = pubsub_error_reply(stanza, ret); - end - end - return origin.send(reply); -end - -function handlers.set_subscribe(origin, stanza, subscribe) - local node, jid = subscribe.attr.node, subscribe.attr.jid; - if jid_bare(jid) ~= jid_bare(stanza.attr.from) then - return origin.send(pubsub_error_reply(stanza, "invalid-jid")); - end - local ok, ret = service:add_subscription(node, stanza.attr.from, jid); - local reply; - if ok then - reply = st.reply(stanza) - :tag("pubsub", { xmlns = xmlns_pubsub }) - :tag("subscription", { - node = node, - jid = jid, - subscription = "subscribed" - }); - else - reply = pubsub_error_reply(stanza, ret); - end - return origin.send(reply); -end - -function handlers.set_unsubscribe(origin, stanza, unsubscribe) - local node, jid = unsubscribe.attr.node, unsubscribe.attr.jid; - if jid_bare(jid) ~= jid_bare(stanza.attr.from) then - return origin.send(pubsub_error_reply(stanza, "invalid-jid")); - end - local ok, ret = service:remove_subscription(node, stanza.attr.from, jid); - local reply; - if ok then - reply = st.reply(stanza); - else - reply = pubsub_error_reply(stanza, ret); - end - return origin.send(reply); -end - -function handlers.set_publish(origin, stanza, publish) - local node = publish.attr.node; - local item = publish:get_child("item"); - local id = (item and item.attr.id) or uuid_generate(); - local ok, ret = service:publish(node, stanza.attr.from, id, item); - local reply; - if ok then - reply = st.reply(stanza) - :tag("pubsub", { xmlns = xmlns_pubsub }) - :tag("publish", { node = node }) - :tag("item", { id = id }); - else - reply = pubsub_error_reply(stanza, ret); - end - return origin.send(reply); -end - -function handlers.set_retract(origin, stanza, retract) - local node, notify = retract.attr.node, retract.attr.notify; - notify = (notify == "1") or (notify == "true"); - local item = retract:get_child("item"); - local id = item and item.attr.id - local reply, notifier; - if notify then - notifier = st.stanza("retract", { id = id }); - end - local ok, ret = service:retract(node, stanza.attr.from, id, notifier); - if ok then - reply = st.reply(stanza); - else - reply = pubsub_error_reply(stanza, ret); - end - return origin.send(reply); -end - -function simple_broadcast(node, jids, item) - item = st.clone(item); - item.attr.xmlns = nil; -- Clear the pubsub namespace - local message = st.message({ from = module.host, type = "headline" }) - :tag("event", { xmlns = xmlns_pubsub_event }) - :tag("items", { node = node }) - :add_child(item); - for jid in pairs(jids) do - module:log("debug", "Sending notification to %s", jid); - message.attr.to = jid; - core_post_stanza(hosts[module.host], message); - end -end - -module:hook("iq/host/http://jabber.org/protocol/pubsub:pubsub", handle_pubsub_iq); - -local disco_info; - -local feature_map = { - create = { "create-nodes", autocreate_on_publish and "instant-nodes", "item-ids" }; - retract = { "delete-items", "retract-items" }; - publish = { "publish" }; - get_items = { "retrieve-items" }; - add_subscription = { "subscribe" }; - get_subscriptions = { "retrieve-subscriptions" }; -}; - -local function add_disco_features_from_service(disco, service) - for method, features in pairs(feature_map) do - if service[method] then - for _, feature in ipairs(features) do - if feature then - disco:tag("feature", { var = xmlns_pubsub.."#"..feature }):up(); - end - end - end - end - for affiliation in pairs(service.config.capabilities) do - if affiliation ~= "none" and affiliation ~= "owner" then - disco:tag("feature", { var = xmlns_pubsub.."#"..affiliation.."-affiliation" }):up(); - end - end -end - -local function build_disco_info(service) - local disco_info = st.stanza("query", { xmlns = "http://jabber.org/protocol/disco#info" }) - :tag("identity", { category = "pubsub", type = "service" }):up() - :tag("feature", { var = "http://jabber.org/protocol/pubsub" }):up(); - add_disco_features_from_service(disco_info, service); - return disco_info; -end - -module:hook("iq-get/host/http://jabber.org/protocol/disco#info:query", function (event) - local origin, stanza = event.origin, event.stanza; - local node = stanza.tags[1].attr.node; - if not node then - return origin.send(st.reply(stanza):add_child(disco_info)); - else - local ok, ret = service:get_nodes(stanza.attr.from); - if ok and not ret[node] then - ok, ret = false, "item-not-found"; - end - if not ok then - return origin.send(pubsub_error_reply(stanza, ret)); - end - local reply = st.reply(stanza) - :tag("query", { xmlns = "http://jabber.org/protocol/disco#info", node = node }) - :tag("identity", { category = "pubsub", type = "leaf" }); - return origin.send(reply); - end -end); - -local function handle_disco_items_on_node(event) - local stanza, origin = event.stanza, event.origin; - local query = stanza.tags[1]; - local node = query.attr.node; - local ok, ret = service:get_items(node, stanza.attr.from); - if not ok then - return origin.send(pubsub_error_reply(stanza, ret)); - end - - local reply = st.reply(stanza) - :tag("query", { xmlns = "http://jabber.org/protocol/disco#items", node = node }); - - for id, item in pairs(ret) do - reply:tag("item", { jid = module.host, name = id }):up(); - end - - return origin.send(reply); -end - - -module:hook("iq-get/host/http://jabber.org/protocol/disco#items:query", function (event) - if event.stanza.tags[1].attr.node then - return handle_disco_items_on_node(event); - end - local ok, ret = service:get_nodes(event.stanza.attr.from); - if not ok then - event.origin.send(pubsub_error_reply(stanza, ret)); - else - local reply = st.reply(event.stanza) - :tag("query", { xmlns = "http://jabber.org/protocol/disco#items" }); - for node, node_obj in pairs(ret) do - reply:tag("item", { jid = module.host, node = node, name = node_obj.config.name }):up(); - end - event.origin.send(reply); - end - return true; -end); - -local admin_aff = module:get_option_string("default_admin_affiliation", "owner"); -local function get_affiliation(jid) - local bare_jid = jid_bare(jid); - if bare_jid == module.host or usermanager.is_admin(bare_jid, module.host) then - return admin_aff; - end -end - -function set_service(new_service) - service = new_service; - module.environment.service = service; - disco_info = build_disco_info(service); -end - -function module.save() - return { service = service }; -end - -function module.restore(data) - set_service(data.service); -end - -set_service(pubsub.new({ - capabilities = { - none = { - create = false; - publish = false; - retract = false; - get_nodes = true; - - subscribe = true; - unsubscribe = true; - get_subscription = true; - get_subscriptions = true; - get_items = true; - - subscribe_other = false; - unsubscribe_other = false; - get_subscription_other = false; - get_subscriptions_other = false; - - be_subscribed = true; - be_unsubscribed = true; - - set_affiliation = false; - }; - owner = { - create = true; - publish = true; - retract = true; - get_nodes = true; - - subscribe = true; - unsubscribe = true; - get_subscription = true; - get_subscriptions = true; - get_items = true; - - - subscribe_other = true; - unsubscribe_other = true; - get_subscription_other = true; - get_subscriptions_other = true; - - be_subscribed = true; - be_unsubscribed = true; - - set_affiliation = true; - }; - }; - - autocreate_on_publish = autocreate_on_publish; - autocreate_on_subscribe = autocreate_on_subscribe; - - broadcaster = simple_broadcast; - get_affiliation = get_affiliation; - - normalize_jid = jid_bare; -})); diff --git a/plugins/mod_pubsub/mod_pubsub.lua b/plugins/mod_pubsub/mod_pubsub.lua new file mode 100644 index 00000000..81a66f8b --- /dev/null +++ b/plugins/mod_pubsub/mod_pubsub.lua @@ -0,0 +1,229 @@ +local pubsub = require "util.pubsub"; +local st = require "util.stanza"; +local jid_bare = require "util.jid".bare; +local usermanager = require "core.usermanager"; + +local xmlns_pubsub = "http://jabber.org/protocol/pubsub"; +local xmlns_pubsub_event = "http://jabber.org/protocol/pubsub#event"; +local xmlns_pubsub_owner = "http://jabber.org/protocol/pubsub#owner"; + +local autocreate_on_publish = module:get_option_boolean("autocreate_on_publish", false); +local autocreate_on_subscribe = module:get_option_boolean("autocreate_on_subscribe", false); +local pubsub_disco_name = module:get_option("name"); +if type(pubsub_disco_name) ~= "string" then pubsub_disco_name = "Prosody PubSub Service"; end + +local service; + +local lib_pubsub = module:require "pubsub"; +local handlers = lib_pubsub.handlers; +local pubsub_error_reply = lib_pubsub.pubsub_error_reply; + +module:depends("disco"); +module:add_identity("pubsub", "service", pubsub_disco_name); +module:add_feature("http://jabber.org/protocol/pubsub"); + +function handle_pubsub_iq(event) + local origin, stanza = event.origin, event.stanza; + local pubsub = stanza.tags[1]; + local action = pubsub.tags[1]; + if not action then + return origin.send(st.error_reply(stanza, "cancel", "bad-request")); + end + local handler = handlers[stanza.attr.type.."_"..action.name]; + if handler then + handler(origin, stanza, action, service); + return true; + end +end + +function simple_broadcast(kind, node, jids, item) + if item then + item = st.clone(item); + item.attr.xmlns = nil; -- Clear the pubsub namespace + end + local message = st.message({ from = module.host, type = "headline" }) + :tag("event", { xmlns = xmlns_pubsub_event }) + :tag(kind, { node = node }) + :add_child(item); + for jid in pairs(jids) do + module:log("debug", "Sending notification to %s", jid); + message.attr.to = jid; + module:send(message); + end +end + +module:hook("iq/host/"..xmlns_pubsub..":pubsub", handle_pubsub_iq); +module:hook("iq/host/"..xmlns_pubsub_owner..":pubsub", handle_pubsub_iq); + +local feature_map = { + create = { "create-nodes", "instant-nodes", "item-ids" }; + retract = { "delete-items", "retract-items" }; + purge = { "purge-nodes" }; + publish = { "publish", autocreate_on_publish and "auto-create" }; + delete = { "delete-nodes" }; + get_items = { "retrieve-items" }; + add_subscription = { "subscribe" }; + get_subscriptions = { "retrieve-subscriptions" }; +}; + +local function add_disco_features_from_service(service) + for method, features in pairs(feature_map) do + if service[method] then + for _, feature in ipairs(features) do + if feature then + module:add_feature(xmlns_pubsub.."#"..feature); + end + end + end + end + for affiliation in pairs(service.config.capabilities) do + if affiliation ~= "none" and affiliation ~= "owner" then + module:add_feature(xmlns_pubsub.."#"..affiliation.."-affiliation"); + end + end +end + +module:hook("host-disco-info-node", function (event) + local stanza, origin, reply, node = event.stanza, event.origin, event.reply, event.node; + local ok, ret = service:get_nodes(stanza.attr.from); + if ok and not ret[node] then + return; + end + if not ok then + return origin.send(pubsub_error_reply(stanza, ret)); + end + event.exists = true; + reply:tag("identity", { category = "pubsub", type = "leaf" }); +end); + +module:hook("host-disco-items-node", function (event) + local stanza, origin, reply, node = event.stanza, event.origin, event.reply, event.node; + local ok, ret = service:get_items(node, stanza.attr.from); + if not ok then + return origin.send(pubsub_error_reply(stanza, ret)); + end + + for id, item in pairs(ret) do + reply:tag("item", { jid = module.host, name = id }):up(); + end + event.exists = true; +end); + + +module:hook("host-disco-items", function (event) + local stanza, origin, reply = event.stanza, event.origin, event.reply; + local ok, ret = service:get_nodes(event.stanza.attr.from); + if not ok then + return origin.send(pubsub_error_reply(event.stanza, ret)); + end + for node, node_obj in pairs(ret) do + reply:tag("item", { jid = module.host, node = node, name = node_obj.config.name }):up(); + end +end); + +local admin_aff = module:get_option_string("default_admin_affiliation", "owner"); +local function get_affiliation(jid) + local bare_jid = jid_bare(jid); + if bare_jid == module.host or usermanager.is_admin(bare_jid, module.host) then + return admin_aff; + end +end + +function set_service(new_service) + service = new_service; + module.environment.service = service; + add_disco_features_from_service(service); +end + +function module.save() + return { service = service }; +end + +function module.restore(data) + set_service(data.service); +end + +function module.load() + if module.reloading then return; end + + set_service(pubsub.new({ + capabilities = { + none = { + create = false; + publish = false; + retract = false; + get_nodes = true; + + subscribe = true; + unsubscribe = true; + get_subscription = true; + get_subscriptions = true; + get_items = true; + + subscribe_other = false; + unsubscribe_other = false; + get_subscription_other = false; + get_subscriptions_other = false; + + be_subscribed = true; + be_unsubscribed = true; + + set_affiliation = false; + }; + publisher = { + create = false; + publish = true; + retract = true; + get_nodes = true; + + subscribe = true; + unsubscribe = true; + get_subscription = true; + get_subscriptions = true; + get_items = true; + + subscribe_other = false; + unsubscribe_other = false; + get_subscription_other = false; + get_subscriptions_other = false; + + be_subscribed = true; + be_unsubscribed = true; + + set_affiliation = false; + }; + owner = { + create = true; + publish = true; + retract = true; + delete = true; + get_nodes = true; + + subscribe = true; + unsubscribe = true; + get_subscription = true; + get_subscriptions = true; + get_items = true; + + + subscribe_other = true; + unsubscribe_other = true; + get_subscription_other = true; + get_subscriptions_other = true; + + be_subscribed = true; + be_unsubscribed = true; + + set_affiliation = true; + }; + }; + + autocreate_on_publish = autocreate_on_publish; + autocreate_on_subscribe = autocreate_on_subscribe; + + broadcaster = simple_broadcast; + get_affiliation = get_affiliation; + + normalize_jid = jid_bare; + })); +end diff --git a/plugins/mod_pubsub/pubsub.lib.lua b/plugins/mod_pubsub/pubsub.lib.lua new file mode 100644 index 00000000..2b015e34 --- /dev/null +++ b/plugins/mod_pubsub/pubsub.lib.lua @@ -0,0 +1,225 @@ +local st = require "util.stanza"; +local uuid_generate = require "util.uuid".generate; + +local xmlns_pubsub = "http://jabber.org/protocol/pubsub"; +local xmlns_pubsub_errors = "http://jabber.org/protocol/pubsub#errors"; + +local _M = {}; + +local handlers = {}; +_M.handlers = handlers; + +local pubsub_errors = { + ["conflict"] = { "cancel", "conflict" }; + ["invalid-jid"] = { "modify", "bad-request", nil, "invalid-jid" }; + ["jid-required"] = { "modify", "bad-request", nil, "jid-required" }; + ["nodeid-required"] = { "modify", "bad-request", nil, "nodeid-required" }; + ["item-not-found"] = { "cancel", "item-not-found" }; + ["not-subscribed"] = { "modify", "unexpected-request", nil, "not-subscribed" }; + ["forbidden"] = { "cancel", "forbidden" }; +}; +local function pubsub_error_reply(stanza, error) + local e = pubsub_errors[error]; + local reply = st.error_reply(stanza, unpack(e, 1, 3)); + if e[4] then + reply:tag(e[4], { xmlns = xmlns_pubsub_errors }):up(); + end + return reply; +end +_M.pubsub_error_reply = pubsub_error_reply; + +function handlers.get_items(origin, stanza, items, service) + local node = items.attr.node; + local item = items:get_child("item"); + local id = item and item.attr.id; + + if not node then + return origin.send(pubsub_error_reply(stanza, "nodeid-required")); + end + local ok, results = service:get_items(node, stanza.attr.from, id); + if not ok then + return origin.send(pubsub_error_reply(stanza, results)); + end + + local data = st.stanza("items", { node = node }); + for _, entry in pairs(results) do + data:add_child(entry); + end + local reply; + if data then + reply = st.reply(stanza) + :tag("pubsub", { xmlns = xmlns_pubsub }) + :add_child(data); + else + reply = pubsub_error_reply(stanza, "item-not-found"); + end + return origin.send(reply); +end + +function handlers.get_subscriptions(origin, stanza, subscriptions, service) + local node = subscriptions.attr.node; + local ok, ret = service:get_subscriptions(node, stanza.attr.from, stanza.attr.from); + if not ok then + return origin.send(pubsub_error_reply(stanza, ret)); + end + local reply = st.reply(stanza) + :tag("pubsub", { xmlns = xmlns_pubsub }) + :tag("subscriptions"); + for _, sub in ipairs(ret) do + reply:tag("subscription", { node = sub.node, jid = sub.jid, subscription = 'subscribed' }):up(); + end + return origin.send(reply); +end + +function handlers.set_create(origin, stanza, create, service) + local node = create.attr.node; + local ok, ret, reply; + if node then + ok, ret = service:create(node, stanza.attr.from); + if ok then + reply = st.reply(stanza); + else + reply = pubsub_error_reply(stanza, ret); + end + else + repeat + node = uuid_generate(); + ok, ret = service:create(node, stanza.attr.from); + until ok or ret ~= "conflict"; + if ok then + reply = st.reply(stanza) + :tag("pubsub", { xmlns = xmlns_pubsub }) + :tag("create", { node = node }); + else + reply = pubsub_error_reply(stanza, ret); + end + end + return origin.send(reply); +end + +function handlers.set_delete(origin, stanza, delete, service) + local node = delete.attr.node; + + local reply, notifier; + if not node then + return origin.send(pubsub_error_reply(stanza, "nodeid-required")); + end + local ok, ret = service:delete(node, stanza.attr.from); + if ok then + reply = st.reply(stanza); + else + reply = pubsub_error_reply(stanza, ret); + end + return origin.send(reply); +end + +function handlers.set_subscribe(origin, stanza, subscribe, service) + local node, jid = subscribe.attr.node, subscribe.attr.jid; + if not (node and jid) then + return origin.send(pubsub_error_reply(stanza, jid and "nodeid-required" or "invalid-jid")); + end + --[[ + local options_tag, options = stanza.tags[1]:get_child("options"), nil; + if options_tag then + options = options_form:data(options_tag.tags[1]); + end + --]] + local options_tag, options; -- FIXME + local ok, ret = service:add_subscription(node, stanza.attr.from, jid, options); + local reply; + if ok then + reply = st.reply(stanza) + :tag("pubsub", { xmlns = xmlns_pubsub }) + :tag("subscription", { + node = node, + jid = jid, + subscription = "subscribed" + }):up(); + if options_tag then + reply:add_child(options_tag); + end + else + reply = pubsub_error_reply(stanza, ret); + end + origin.send(reply); +end + +function handlers.set_unsubscribe(origin, stanza, unsubscribe, service) + local node, jid = unsubscribe.attr.node, unsubscribe.attr.jid; + if not (node and jid) then + return origin.send(pubsub_error_reply(stanza, jid and "nodeid-required" or "invalid-jid")); + end + local ok, ret = service:remove_subscription(node, stanza.attr.from, jid); + local reply; + if ok then + reply = st.reply(stanza); + else + reply = pubsub_error_reply(stanza, ret); + end + return origin.send(reply); +end + +function handlers.set_publish(origin, stanza, publish, service) + local node = publish.attr.node; + if not node then + return origin.send(pubsub_error_reply(stanza, "nodeid-required")); + end + local item = publish:get_child("item"); + local id = (item and item.attr.id); + if not id then + id = uuid_generate(); + if item then + item.attr.id = id; + end + end + local ok, ret = service:publish(node, stanza.attr.from, id, item); + local reply; + if ok then + reply = st.reply(stanza) + :tag("pubsub", { xmlns = xmlns_pubsub }) + :tag("publish", { node = node }) + :tag("item", { id = id }); + else + reply = pubsub_error_reply(stanza, ret); + end + return origin.send(reply); +end + +function handlers.set_retract(origin, stanza, retract, service) + local node, notify = retract.attr.node, retract.attr.notify; + notify = (notify == "1") or (notify == "true"); + local item = retract:get_child("item"); + local id = item and item.attr.id + if not (node and id) then + return origin.send(pubsub_error_reply(stanza, node and "item-not-found" or "nodeid-required")); + end + local reply, notifier; + if notify then + notifier = st.stanza("retract", { id = id }); + end + local ok, ret = service:retract(node, stanza.attr.from, id, notifier); + if ok then + reply = st.reply(stanza); + else + reply = pubsub_error_reply(stanza, ret); + end + return origin.send(reply); +end + +function handlers.set_purge(origin, stanza, purge, service) + local node, notify = purge.attr.node, purge.attr.notify; + notify = (notify == "1") or (notify == "true"); + local reply; + if not node then + return origin.send(pubsub_error_reply(stanza, "nodeid-required")); + end + local ok, ret = service:purge(node, stanza.attr.from, notify); + if ok then + reply = st.reply(stanza); + else + reply = pubsub_error_reply(stanza, ret); + end + return origin.send(reply); +end + +return _M; diff --git a/plugins/mod_register.lua b/plugins/mod_register.lua index 1df73297..e537e903 100644 --- a/plugins/mod_register.lua +++ b/plugins/mod_register.lua @@ -1,24 +1,88 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- -local hosts = _G.hosts; local st = require "util.stanza"; -local datamanager = require "util.datamanager"; +local dataform_new = require "util.dataforms".new; local usermanager_user_exists = require "core.usermanager".user_exists; local usermanager_create_user = require "core.usermanager".create_user; local usermanager_set_password = require "core.usermanager".set_password; +local usermanager_delete_user = require "core.usermanager".delete_user; local os_time = os.time; local nodeprep = require "util.encodings".stringprep.nodeprep; +local jid_bare = require "util.jid".bare; + +local compat = module:get_option_boolean("registration_compat", true); +local allow_registration = module:get_option_boolean("allow_registration", false); +local additional_fields = module:get_option("additional_registration_fields", {}); + +local account_details = module:open_store("account_details"); + +local field_map = { + username = { name = "username", type = "text-single", label = "Username", required = true }; + password = { name = "password", type = "text-private", label = "Password", required = true }; + nick = { name = "nick", type = "text-single", label = "Nickname" }; + name = { name = "name", type = "text-single", label = "Full Name" }; + first = { name = "first", type = "text-single", label = "Given Name" }; + last = { name = "last", type = "text-single", label = "Family Name" }; + email = { name = "email", type = "text-single", label = "Email" }; + address = { name = "address", type = "text-single", label = "Street" }; + city = { name = "city", type = "text-single", label = "City" }; + state = { name = "state", type = "text-single", label = "State" }; + zip = { name = "zip", type = "text-single", label = "Postal code" }; + phone = { name = "phone", type = "text-single", label = "Telephone number" }; + url = { name = "url", type = "text-single", label = "Webpage" }; + date = { name = "date", type = "text-single", label = "Birth date" }; +}; + +local registration_form = dataform_new{ + title = "Creating a new account"; + instructions = "Choose a username and password for use with this service."; + + field_map.username; + field_map.password; +}; + +local registration_query = st.stanza("query", {xmlns = "jabber:iq:register"}) + :tag("instructions"):text("Choose a username and password for use with this service."):up() + :tag("username"):up() + :tag("password"):up(); + +for _, field in ipairs(additional_fields) do + if type(field) == "table" then + registration_form[#registration_form + 1] = field; + else + if field:match("%+$") then + field = field:sub(1, #field - 1); + field_map[field].required = true; + end + + registration_form[#registration_form + 1] = field_map[field]; + registration_query:tag(field):up(); + end +end +registration_query:add_child(registration_form:form()); module:add_feature("jabber:iq:register"); -module:hook("iq/self/jabber:iq:register:query", function(event) +local register_stream_feature = st.stanza("register", {xmlns="http://jabber.org/features/iq-register"}):up(); +module:hook("stream-features", function(event) + local session, features = event.origin, event.features; + + -- Advertise registration to unauthorized clients only. + if not(allow_registration) or session.type ~= "c2s_unauthed" then + return + end + + features:add_child(register_stream_feature); +end); + +local function handle_registration_stanza(event) local session, stanza = event.origin, event.stanza; local query = stanza.tags[1]; @@ -31,45 +95,29 @@ module:hook("iq/self/jabber:iq:register:query", function(event) session.send(reply); else -- stanza.attr.type == "set" if query.tags[1] and query.tags[1].name == "remove" then - -- TODO delete user auth data, send iq response, kick all user resources with a <not-authorized/>, delete all user data local username, host = session.username, session.host; - --session.send(st.error_reply(stanza, "cancel", "not-allowed")); - --return; - usermanager_set_password(username, nil, host); -- Disable account - -- FIXME the disabling currently allows a different user to recreate the account - -- we should add an in-memory account block mode when we have threading - session.send(st.reply(stanza)); - local roster = session.roster; - for _, session in pairs(hosts[host].sessions[username].sessions) do -- disconnect all resources - session:close({condition = "not-authorized", text = "Account deleted"}); + + local old_session_close = session.close; + session.close = function(session, ...) + session.send(st.reply(stanza)); + return old_session_close(session, ...); end - -- TODO datamanager should be able to delete all user data itself - datamanager.store(username, host, "vcard", nil); - datamanager.store(username, host, "private", nil); - datamanager.list_store(username, host, "offline", nil); - local bare = username.."@"..host; - for jid, item in pairs(roster) do - if jid and jid ~= "pending" then - if item.subscription == "both" or item.subscription == "from" or (roster.pending and roster.pending[jid]) then - core_post_stanza(hosts[host], st.presence({type="unsubscribed", from=bare, to=jid})); - end - if item.subscription == "both" or item.subscription == "to" or item.ask then - core_post_stanza(hosts[host], st.presence({type="unsubscribe", from=bare, to=jid})); - end - end + + local ok, err = usermanager_delete_user(username, host); + + if not ok then + module:log("debug", "Removing user account %s@%s failed: %s", username, host, err); + session.close = old_session_close; + session.send(st.error_reply(stanza, "cancel", "service-unavailable", err)); + return true; end - datamanager.store(username, host, "roster", nil); - datamanager.store(username, host, "privacy", nil); - datamanager.store(username, host, "accounts", nil); -- delete accounts datastore at the end + module:log("info", "User removed their account: %s@%s", username, host); module:fire_event("user-deregistered", { username = username, host = host, source = "mod_register", session = session }); else - local username = query:child_with_name("username"); - local password = query:child_with_name("password"); + local username = nodeprep(query:get_child_text("username")); + local password = query:get_child_text("password"); if username and password then - -- FIXME shouldn't use table.concat - username = nodeprep(table.concat(username)); - password = table.concat(password); if username == session.username then if usermanager_set_password(username, password, session.host) then session.send(st.reply(stanza)); @@ -86,38 +134,66 @@ module:hook("iq/self/jabber:iq:register:query", function(event) end end return true; -end); +end -local recent_ips = {}; -local min_seconds_between_registrations = module:get_option("min_seconds_between_registrations"); -local whitelist_only = module:get_option("whitelist_registration_only"); -local whitelisted_ips = module:get_option("registration_whitelist") or { "127.0.0.1" }; -local blacklisted_ips = module:get_option("registration_blacklist") or {}; +module:hook("iq/self/jabber:iq:register:query", handle_registration_stanza); +if compat then + module:hook("iq/host/jabber:iq:register:query", function (event) + local session, stanza = event.origin, event.stanza; + if session.type == "c2s" and jid_bare(stanza.attr.to) == session.host then + return handle_registration_stanza(event); + end + end); +end + +local function parse_response(query) + local form = query:get_child("x", "jabber:x:data"); + if form then + return registration_form:data(form); + else + local data = {}; + local errors = {}; + for _, field in ipairs(registration_form) do + local name, required = field.name, field.required; + if field_map[name] then + data[name] = query:get_child_text(name); + if (not data[name] or #data[name] == 0) and required then + errors[name] = "Required value missing"; + end + end + end + if next(errors) then + return data, errors; + end + return data; + end +end -for _, ip in ipairs(whitelisted_ips) do whitelisted_ips[ip] = true; end -for _, ip in ipairs(blacklisted_ips) do blacklisted_ips[ip] = true; end +local recent_ips = {}; +local min_seconds_between_registrations = module:get_option_number("min_seconds_between_registrations"); +local whitelist_only = module:get_option_boolean("whitelist_registration_only"); +local whitelisted_ips = module:get_option_set("registration_whitelist", { "127.0.0.1" })._items; +local blacklisted_ips = module:get_option_set("registration_blacklist", {})._items; module:hook("stanza/iq/jabber:iq:register:query", function(event) local session, stanza = event.origin, event.stanza; - if module:get_option("allow_registration") == false or session.type ~= "c2s_unauthed" then + if not(allow_registration) or session.type ~= "c2s_unauthed" then session.send(st.error_reply(stanza, "cancel", "service-unavailable")); else local query = stanza.tags[1]; if stanza.attr.type == "get" then local reply = st.reply(stanza); - reply:tag("query", {xmlns = "jabber:iq:register"}) - :tag("instructions"):text("Choose a username and password for use with this service."):up() - :tag("username"):up() - :tag("password"):up(); + reply:add_child(registration_query); session.send(reply); elseif stanza.attr.type == "set" then if query.tags[1] and query.tags[1].name == "remove" then session.send(st.error_reply(stanza, "auth", "registration-required")); else - local username = query:child_with_name("username"); - local password = query:child_with_name("password"); - if username and password then + local data, errors = parse_response(query); + if errors then + session.send(st.error_reply(stanza, "modify", "not-acceptable")); + else -- Check that the user is not blacklisted or registering too often if not session.ip then module:log("debug", "User's IP not known; can't apply blacklist/whitelist"); @@ -130,7 +206,7 @@ module:hook("stanza/iq/jabber:iq:register:query", function(event) else local ip = recent_ips[session.ip]; ip.count = ip.count + 1; - + if os_time() - ip.time < min_seconds_between_registrations then ip.time = os_time(); session.send(st.error_reply(stanza, "wait", "not-acceptable")); @@ -139,32 +215,40 @@ module:hook("stanza/iq/jabber:iq:register:query", function(event) ip.time = os_time(); end end - -- FIXME shouldn't use table.concat - username = nodeprep(table.concat(username)); - password = table.concat(password); + local username, password = nodeprep(data.username), data.password; + data.username, data.password = nil, nil; local host = module.host; if not username or username == "" then session.send(st.error_reply(stanza, "modify", "not-acceptable", "The requested username is invalid.")); + return true; + end + local user = { username = username , host = host, allowed = true } + module:fire_event("user-registering", user); + if not user.allowed then + session.send(st.error_reply(stanza, "modify", "not-acceptable", "The requested username is forbidden.")); elseif usermanager_user_exists(username, host) then session.send(st.error_reply(stanza, "cancel", "conflict", "The requested username already exists.")); else + -- TODO unable to write file, file may be locked, etc, what's the correct error? + local error_reply = st.error_reply(stanza, "wait", "internal-server-error", "Failed to write data to disk."); if usermanager_create_user(username, password, host) then + if next(data) and not account_details:set(username, data) then + usermanager_delete_user(username, host); + session.send(error_reply); + return true; + end session.send(st.reply(stanza)); -- user created! module:log("info", "User account created: %s@%s", username, host); module:fire_event("user-registered", { username = username, host = host, source = "mod_register", session = session }); else - -- TODO unable to write file, file may be locked, etc, what's the correct error? - session.send(st.error_reply(stanza, "wait", "internal-server-error", "Failed to write data to disk.")); + session.send(error_reply); end end - else - session.send(st.error_reply(stanza, "modify", "not-acceptable")); end end end end return true; end); - diff --git a/plugins/mod_roster.lua b/plugins/mod_roster.lua index fe2eea71..56af5368 100644 --- a/plugins/mod_roster.lua +++ b/plugins/mod_roster.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -15,14 +15,15 @@ local t_concat = table.concat; local tonumber = tonumber; local pairs, ipairs = pairs, ipairs; +local rm_load_roster = require "core.rostermanager".load_roster; local rm_remove_from_roster = require "core.rostermanager".remove_from_roster; local rm_add_to_roster = require "core.rostermanager".add_to_roster; local rm_roster_push = require "core.rostermanager".roster_push; -local core_post_stanza = core_post_stanza; +local core_post_stanza = prosody.core_post_stanza; module:add_feature("jabber:iq:roster"); -local rosterver_stream_feature = st.stanza("ver", {xmlns="urn:xmpp:features:rosterver"}):tag("optional"):up(); +local rosterver_stream_feature = st.stanza("ver", {xmlns="urn:xmpp:features:rosterver"}); module:hook("stream-features", function(event) local origin, features = event.origin, event.features; if origin.username then @@ -35,10 +36,10 @@ module:hook("iq/self/jabber:iq:roster:query", function(event) if stanza.attr.type == "get" then local roster = st.reply(stanza); - + local client_ver = tonumber(stanza.tags[1].attr.ver); local server_ver = tonumber(session.roster[false].version or 1); - + if not (client_ver and server_ver) or client_ver ~= server_ver then roster:query("jabber:iq:roster"); -- Client does not support versioning, or has stale roster @@ -68,7 +69,6 @@ module:hook("iq/self/jabber:iq:roster:query", function(event) and query.tags[1].attr.jid ~= "pending" then local item = query.tags[1]; local from_node, from_host = jid_split(stanza.attr.from); - local from_bare = from_node and (from_node.."@"..from_host) or from_host; -- bare JID local jid = jid_prep(item.attr.jid); local node, host, resource = jid_split(jid); if not resource and host then @@ -137,3 +137,20 @@ module:hook("iq/self/jabber:iq:roster:query", function(event) end return true; end); + +module:hook_global("user-deleted", function(event) + local username, host = event.username, event.host; + if host ~= module.host then return end + local bare = username .. "@" .. host; + local roster = rm_load_roster(username, host); + for jid, item in pairs(roster) do + if jid and jid ~= "pending" then + if item.subscription == "both" or item.subscription == "from" or (roster.pending and roster.pending[jid]) then + module:send(st.presence({type="unsubscribed", from=bare, to=jid})); + end + if item.subscription == "both" or item.subscription == "to" or item.ask then + module:send(st.presence({type="unsubscribe", from=bare, to=jid})); + end + end + end +end, 300); diff --git a/plugins/mod_s2s/mod_s2s.lua b/plugins/mod_s2s/mod_s2s.lua new file mode 100644 index 00000000..1d03f3e4 --- /dev/null +++ b/plugins/mod_s2s/mod_s2s.lua @@ -0,0 +1,694 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +module:set_global(); + +local prosody = prosody; +local hosts = prosody.hosts; +local core_process_stanza = prosody.core_process_stanza; + +local tostring, type = tostring, type; +local t_insert = table.insert; +local xpcall, traceback = xpcall, debug.traceback; +local NULL = {}; + +local add_task = require "util.timer".add_task; +local st = require "util.stanza"; +local initialize_filters = require "util.filters".initialize; +local nameprep = require "util.encodings".stringprep.nameprep; +local new_xmpp_stream = require "util.xmppstream".new; +local s2s_new_incoming = require "core.s2smanager".new_incoming; +local s2s_new_outgoing = require "core.s2smanager".new_outgoing; +local s2s_destroy_session = require "core.s2smanager".destroy_session; +local uuid_gen = require "util.uuid".generate; +local cert_verify_identity = require "util.x509".verify_identity; +local fire_global_event = prosody.events.fire_event; + +local s2sout = module:require("s2sout"); + +local connect_timeout = module:get_option_number("s2s_timeout", 90); +local stream_close_timeout = module:get_option_number("s2s_close_timeout", 5); +local opt_keepalives = module:get_option_boolean("s2s_tcp_keepalives", module:get_option_boolean("tcp_keepalives", true)); +local secure_auth = module:get_option_boolean("s2s_secure_auth", false); -- One day... +local secure_domains, insecure_domains = + module:get_option_set("s2s_secure_domains", {})._items, module:get_option_set("s2s_insecure_domains", {})._items; +local require_encryption = module:get_option_boolean("s2s_require_encryption", false); + +local sessions = module:shared("sessions"); + +local log = module._log; + +--- Handle stanzas to remote domains + +local bouncy_stanzas = { message = true, presence = true, iq = true }; +local function bounce_sendq(session, reason) + local sendq = session.sendq; + if not sendq then return; end + session.log("info", "sending error replies for "..#sendq.." queued stanzas because of failed outgoing connection to "..tostring(session.to_host)); + local dummy = { + type = "s2sin"; + send = function(s) + (session.log or log)("error", "Replying to to an s2s error reply, please report this! Traceback: %s", traceback()); + end; + dummy = true; + }; + for i, data in ipairs(sendq) do + local reply = data[2]; + if reply and not(reply.attr.xmlns) and bouncy_stanzas[reply.name] then + reply.attr.type = "error"; + reply:tag("error", {type = "cancel"}) + :tag("remote-server-not-found", {xmlns = "urn:ietf:params:xml:ns:xmpp-stanzas"}):up(); + if reason then + reply:tag("text", {xmlns = "urn:ietf:params:xml:ns:xmpp-stanzas"}) + :text("Server-to-server connection failed: "..reason):up(); + end + core_process_stanza(dummy, reply); + end + sendq[i] = nil; + end + session.sendq = nil; +end + +-- Handles stanzas to existing s2s sessions +function route_to_existing_session(event) + local from_host, to_host, stanza = event.from_host, event.to_host, event.stanza; + if not hosts[from_host] then + log("warn", "Attempt to send stanza from %s - a host we don't serve", from_host); + return false; + end + if hosts[to_host] then + log("warn", "Attempt to route stanza to a remote %s - a host we do serve?!", from_host); + return false; + end + local host = hosts[from_host].s2sout[to_host]; + if host then + -- We have a connection to this host already + if host.type == "s2sout_unauthed" and (stanza.name ~= "db:verify" or not host.dialback_key) then + (host.log or log)("debug", "trying to send over unauthed s2sout to "..to_host); + + -- Queue stanza until we are able to send it + if host.sendq then t_insert(host.sendq, {tostring(stanza), stanza.attr.type ~= "error" and stanza.attr.type ~= "result" and st.reply(stanza)}); + else host.sendq = { {tostring(stanza), stanza.attr.type ~= "error" and stanza.attr.type ~= "result" and st.reply(stanza)} }; end + host.log("debug", "stanza [%s] queued ", stanza.name); + return true; + elseif host.type == "local" or host.type == "component" then + log("error", "Trying to send a stanza to ourselves??") + log("error", "Traceback: %s", traceback()); + log("error", "Stanza: %s", tostring(stanza)); + return false; + else + (host.log or log)("debug", "going to send stanza to "..to_host.." from "..from_host); + -- FIXME + if host.from_host ~= from_host then + log("error", "WARNING! This might, possibly, be a bug, but it might not..."); + log("error", "We are going to send from %s instead of %s", tostring(host.from_host), tostring(from_host)); + end + if host.sends2s(stanza) then + host.log("debug", "stanza sent over %s", host.type); + return true; + end + end + end +end + +-- Create a new outgoing session for a stanza +function route_to_new_session(event) + local from_host, to_host, stanza = event.from_host, event.to_host, event.stanza; + log("debug", "opening a new outgoing connection for this stanza"); + local host_session = s2s_new_outgoing(from_host, to_host); + + -- Store in buffer + host_session.bounce_sendq = bounce_sendq; + host_session.sendq = { {tostring(stanza), stanza.attr.type ~= "error" and stanza.attr.type ~= "result" and st.reply(stanza)} }; + log("debug", "stanza [%s] queued until connection complete", tostring(stanza.name)); + s2sout.initiate_connection(host_session); + if (not host_session.connecting) and (not host_session.conn) then + log("warn", "Connection to %s failed already, destroying session...", to_host); + s2s_destroy_session(host_session, "Connection failed"); + return false; + end + return true; +end + +local function keepalive(event) + return event.session.sends2s(' '); +end + +module:hook("s2s-read-timeout", keepalive, -1); + +function module.add_host(module) + if module:get_option_boolean("disallow_s2s", false) then + module:log("warn", "The 'disallow_s2s' config option is deprecated, please see http://prosody.im/doc/s2s#disabling"); + return nil, "This host has disallow_s2s set"; + end + module:hook("route/remote", route_to_existing_session, -1); + module:hook("route/remote", route_to_new_session, -10); + module:hook("s2s-authenticated", make_authenticated, -1); + module:hook("s2s-read-timeout", keepalive, -1); +end + +-- Stream is authorised, and ready for normal stanzas +function mark_connected(session) + local sendq, send = session.sendq, session.sends2s; + + local from, to = session.from_host, session.to_host; + + session.log("info", "%s s2s connection %s->%s complete", session.direction:gsub("^.", string.upper), from, to); + + local event_data = { session = session }; + if session.type == "s2sout" then + fire_global_event("s2sout-established", event_data); + hosts[from].events.fire_event("s2sout-established", event_data); + else + local host_session = hosts[to]; + session.send = function(stanza) + return host_session.events.fire_event("route/remote", { from_host = to, to_host = from, stanza = stanza }); + end; + + fire_global_event("s2sin-established", event_data); + hosts[to].events.fire_event("s2sin-established", event_data); + end + + if session.direction == "outgoing" then + if sendq then + session.log("debug", "sending %d queued stanzas across new outgoing connection to %s", #sendq, session.to_host); + for i, data in ipairs(sendq) do + send(data[1]); + sendq[i] = nil; + end + session.sendq = nil; + end + + session.ip_hosts = nil; + session.srv_hosts = nil; + end +end + +function make_authenticated(event) + local session, host = event.session, event.host; + if not session.secure then + if require_encryption or (secure_auth and not(insecure_domains[host])) or secure_domains[host] then + session:close({ + condition = "policy-violation", + text = "Encrypted server-to-server communication is required but was not " + ..((session.direction == "outgoing" and "offered") or "used") + }); + end + end + if hosts[host] then + session:close({ condition = "undefined-condition", text = "Attempt to authenticate as a host we serve" }); + end + if session.type == "s2sout_unauthed" then + session.type = "s2sout"; + elseif session.type == "s2sin_unauthed" then + session.type = "s2sin"; + if host then + if not session.hosts[host] then session.hosts[host] = {}; end + session.hosts[host].authed = true; + end + elseif session.type == "s2sin" and host then + if not session.hosts[host] then session.hosts[host] = {}; end + session.hosts[host].authed = true; + else + return false; + end + session.log("debug", "connection %s->%s is now authenticated for %s", session.from_host, session.to_host, host); + + mark_connected(session); + + return true; +end + +--- Helper to check that a session peer's certificate is valid +local function check_cert_status(session) + local host = session.direction == "outgoing" and session.to_host or session.from_host + local conn = session.conn:socket() + local cert + if conn.getpeercertificate then + cert = conn:getpeercertificate() + end + + if cert then + local chain_valid, errors; + if conn.getpeerverification then + chain_valid, errors = conn:getpeerverification(); + elseif conn.getpeerchainvalid then -- COMPAT mw/luasec-hg + chain_valid, errors = conn:getpeerchainvalid(); + errors = (not chain_valid) and { { errors } } or nil; + else + chain_valid, errors = false, { { "Chain verification not supported by this version of LuaSec" } }; + end + -- Is there any interest in printing out all/the number of errors here? + if not chain_valid then + (session.log or log)("debug", "certificate chain validation result: invalid"); + for depth, t in pairs(errors or NULL) do + (session.log or log)("debug", "certificate error(s) at depth %d: %s", depth-1, table.concat(t, ", ")) + end + session.cert_chain_status = "invalid"; + else + (session.log or log)("debug", "certificate chain validation result: valid"); + session.cert_chain_status = "valid"; + + -- We'll go ahead and verify the asserted identity if the + -- connecting server specified one. + if host then + if cert_verify_identity(host, "xmpp-server", cert) then + session.cert_identity_status = "valid" + else + session.cert_identity_status = "invalid" + end + (session.log or log)("debug", "certificate identity validation result: %s", session.cert_identity_status); + end + end + end + return module:fire_event("s2s-check-certificate", { host = host, session = session, cert = cert }); +end + +--- XMPP stream event handlers + +local stream_callbacks = { default_ns = "jabber:server", handlestanza = core_process_stanza }; + +local xmlns_xmpp_streams = "urn:ietf:params:xml:ns:xmpp-streams"; + +function stream_callbacks.streamopened(session, attr) + local send = session.sends2s; + + session.version = tonumber(attr.version) or 0; + + -- TODO: Rename session.secure to session.encrypted + if session.secure == false then + session.secure = true; + + local sock = session.conn:socket(); + if sock.info then + local info = sock:info(); + (session.log or log)("info", "Stream encrypted (%s with %s)", info.protocol, info.cipher); + session.compressed = info.compression; + else + (session.log or log)("info", "Stream encrypted"); + session.compressed = sock.compression and sock:compression(); --COMPAT mw/luasec-hg + end + end + + if session.direction == "incoming" then + -- Send a reply stream header + + -- Validate to/from + local to, from = nameprep(attr.to), nameprep(attr.from); + if not to and attr.to then -- COMPAT: Some servers do not reliably set 'to' (especially on stream restarts) + session:close({ condition = "improper-addressing", text = "Invalid 'to' address" }); + return; + end + if not from and attr.from then -- COMPAT: Some servers do not reliably set 'from' (especially on stream restarts) + session:close({ condition = "improper-addressing", text = "Invalid 'from' address" }); + return; + end + + -- Set session.[from/to]_host if they have not been set already and if + -- this session isn't already authenticated + if session.type == "s2sin_unauthed" and from and not session.from_host then + session.from_host = from; + elseif from ~= session.from_host then + session:close({ condition = "improper-addressing", text = "New stream 'from' attribute does not match original" }); + return; + end + if session.type == "s2sin_unauthed" and to and not session.to_host then + session.to_host = to; + elseif to ~= session.to_host then + session:close({ condition = "improper-addressing", text = "New stream 'to' attribute does not match original" }); + return; + end + + -- For convenience we'll put the sanitised values into these variables + to, from = session.to_host, session.from_host; + + session.streamid = uuid_gen(); + (session.log or log)("debug", "Incoming s2s received %s", st.stanza("stream:stream", attr):top_tag()); + if to then + if not hosts[to] then + -- Attempting to connect to a host we don't serve + session:close({ + condition = "host-unknown"; + text = "This host does not serve "..to + }); + return; + elseif not hosts[to].modules.s2s then + -- Attempting to connect to a host that disallows s2s + session:close({ + condition = "policy-violation"; + text = "Server-to-server communication is disabled for this host"; + }); + return; + end + end + + if hosts[from] then + session:close({ condition = "undefined-condition", text = "Attempt to connect from a host we serve" }); + return; + end + + if session.secure and not session.cert_chain_status then + if check_cert_status(session) == false then + return; + end + end + + session:open_stream(session.to_host, session.from_host) + if session.version >= 1.0 then + local features = st.stanza("stream:features"); + + if to then + hosts[to].events.fire_event("s2s-stream-features", { origin = session, features = features }); + else + (session.log or log)("warn", "No 'to' on stream header from %s means we can't offer any features", from or "unknown host"); + end + + log("debug", "Sending stream features: %s", tostring(features)); + send(features); + end + elseif session.direction == "outgoing" then + -- If we are just using the connection for verifying dialback keys, we won't try and auth it + if not attr.id then error("stream response did not give us a streamid!!!"); end + session.streamid = attr.id; + + if session.secure and not session.cert_chain_status then + if check_cert_status(session) == false then + return; + end + end + + -- Send unauthed buffer + -- (stanzas which are fine to send before dialback) + -- Note that this is *not* the stanza queue (which + -- we can only send if auth succeeds) :) + local send_buffer = session.send_buffer; + if send_buffer and #send_buffer > 0 then + log("debug", "Sending s2s send_buffer now..."); + for i, data in ipairs(send_buffer) do + session.sends2s(tostring(data)); + send_buffer[i] = nil; + end + end + session.send_buffer = nil; + + -- If server is pre-1.0, don't wait for features, just do dialback + if session.version < 1.0 then + if not session.dialback_verifying then + hosts[session.from_host].events.fire_event("s2sout-authenticate-legacy", { origin = session }); + else + mark_connected(session); + end + end + end + session.notopen = nil; +end + +function stream_callbacks.streamclosed(session) + (session.log or log)("debug", "Received </stream:stream>"); + session:close(false); +end + +function stream_callbacks.error(session, error, data) + if error == "no-stream" then + session:close("invalid-namespace"); + elseif error == "parse-error" then + session.log("debug", "Server-to-server XML parse error: %s", tostring(error)); + session:close("not-well-formed"); + elseif error == "stream-error" then + local condition, text = "undefined-condition"; + for child in data:children() do + if child.attr.xmlns == xmlns_xmpp_streams then + if child.name ~= "text" then + condition = child.name; + else + text = child:get_text(); + end + if condition ~= "undefined-condition" and text then + break; + end + end + end + text = condition .. (text and (" ("..text..")") or ""); + session.log("info", "Session closed by remote with error: %s", text); + session:close(nil, text); + end +end + +local function handleerr(err) log("error", "Traceback[s2s]: %s", traceback(tostring(err), 2)); end +function stream_callbacks.handlestanza(session, stanza) + if stanza.attr.xmlns == "jabber:client" then --COMPAT: Prosody pre-0.6.2 may send jabber:client + stanza.attr.xmlns = nil; + end + stanza = session.filter("stanzas/in", stanza); + if stanza then + return xpcall(function () return core_process_stanza(session, stanza) end, handleerr); + end +end + +local listener = {}; + +--- Session methods +local stream_xmlns_attr = {xmlns='urn:ietf:params:xml:ns:xmpp-streams'}; +local function session_close(session, reason, remote_reason) + local log = session.log or log; + if session.conn then + if session.notopen then + if session.direction == "incoming" then + session:open_stream(session.to_host, session.from_host); + else + session:open_stream(session.from_host, session.to_host); + end + end + if reason then -- nil == no err, initiated by us, false == initiated by remote + if type(reason) == "string" then -- assume stream error + log("debug", "Disconnecting %s[%s], <stream:error> is: %s", session.host or "(unknown host)", session.type, reason); + session.sends2s(st.stanza("stream:error"):tag(reason, {xmlns = 'urn:ietf:params:xml:ns:xmpp-streams' })); + elseif type(reason) == "table" then + if reason.condition then + local stanza = st.stanza("stream:error"):tag(reason.condition, stream_xmlns_attr):up(); + if reason.text then + stanza:tag("text", stream_xmlns_attr):text(reason.text):up(); + end + if reason.extra then + stanza:add_child(reason.extra); + end + log("debug", "Disconnecting %s[%s], <stream:error> is: %s", session.host or "(unknown host)", session.type, tostring(stanza)); + session.sends2s(stanza); + elseif reason.name then -- a stanza + log("debug", "Disconnecting %s->%s[%s], <stream:error> is: %s", session.from_host or "(unknown host)", session.to_host or "(unknown host)", session.type, tostring(reason)); + session.sends2s(reason); + end + end + end + + session.sends2s("</stream:stream>"); + function session.sends2s() return false; end + + local reason = remote_reason or (reason and (reason.text or reason.condition)) or reason; + session.log("info", "%s s2s stream %s->%s closed: %s", session.direction:gsub("^.", string.upper), session.from_host or "(unknown host)", session.to_host or "(unknown host)", reason or "stream closed"); + + -- Authenticated incoming stream may still be sending us stanzas, so wait for </stream:stream> from remote + local conn = session.conn; + if reason == nil and not session.notopen and session.type == "s2sin" then + add_task(stream_close_timeout, function () + if not session.destroyed then + session.log("warn", "Failed to receive a stream close response, closing connection anyway..."); + s2s_destroy_session(session, reason); + conn:close(); + end + end); + else + s2s_destroy_session(session, reason); + conn:close(); -- Close immediately, as this is an outgoing connection or is not authed + end + end +end + +function session_open_stream(session, from, to) + local attr = { + ["xmlns:stream"] = 'http://etherx.jabber.org/streams', + xmlns = 'jabber:server', + version = session.version and (session.version > 0 and "1.0" or nil), + ["xml:lang"] = 'en', + id = session.streamid, + from = from, to = to, + } + if not from or (hosts[from] and hosts[from].modules.dialback) then + attr["xmlns:db"] = 'jabber:server:dialback'; + end + + session.sends2s("<?xml version='1.0'?>"); + session.sends2s(st.stanza("stream:stream", attr):top_tag()); + return true; +end + +-- Session initialization logic shared by incoming and outgoing +local function initialize_session(session) + local stream = new_xmpp_stream(session, stream_callbacks); + session.stream = stream; + + session.notopen = true; + + function session.reset_stream() + session.notopen = true; + session.stream:reset(); + end + + session.open_stream = session_open_stream; + + local filter = session.filter; + function session.data(data) + data = filter("bytes/in", data); + if data then + local ok, err = stream:feed(data); + if ok then return; end + (session.log or log)("warn", "Received invalid XML: %s", data); + (session.log or log)("warn", "Problem was: %s", err); + session:close("not-well-formed"); + end + end + + session.close = session_close; + + local handlestanza = stream_callbacks.handlestanza; + function session.dispatch_stanza(session, stanza) + return handlestanza(session, stanza); + end + + add_task(connect_timeout, function () + if session.type == "s2sin" or session.type == "s2sout" then + return; -- Ok, we're connected + elseif session.type == "s2s_destroyed" then + return; -- Session already destroyed + end + -- Not connected, need to close session and clean up + (session.log or log)("debug", "Destroying incomplete session %s->%s due to inactivity", + session.from_host or "(unknown)", session.to_host or "(unknown)"); + session:close("connection-timeout"); + end); +end + +function listener.onconnect(conn) + conn:setoption("keepalive", opt_keepalives); + local session = sessions[conn]; + if not session then -- New incoming connection + session = s2s_new_incoming(conn); + sessions[conn] = session; + session.log("debug", "Incoming s2s connection"); + + local filter = initialize_filters(session); + local w = conn.write; + session.sends2s = function (t) + log("debug", "sending: %s", t.top_tag and t:top_tag() or t:match("^([^>]*>?)")); + if t.name then + t = filter("stanzas/out", t); + end + if t then + t = filter("bytes/out", tostring(t)); + if t then + return w(conn, t); + end + end + end + + initialize_session(session); + else -- Outgoing session connected + session:open_stream(session.from_host, session.to_host); + end + session.ip = conn:ip(); +end + +function listener.onincoming(conn, data) + local session = sessions[conn]; + if session then + session.data(data); + end +end + +function listener.onstatus(conn, status) + if status == "ssl-handshake-complete" then + local session = sessions[conn]; + if session and session.direction == "outgoing" then + session.log("debug", "Sending stream header..."); + session:open_stream(session.from_host, session.to_host); + end + end +end + +function listener.ondisconnect(conn, err) + local session = sessions[conn]; + if session then + sessions[conn] = nil; + if err and session.direction == "outgoing" and session.notopen then + (session.log or log)("debug", "s2s connection attempt failed: %s", err); + if s2sout.attempt_connection(session, err) then + return; -- Session lives for now + end + end + (session.log or log)("debug", "s2s disconnected: %s->%s (%s)", tostring(session.from_host), tostring(session.to_host), tostring(err or "connection closed")); + s2s_destroy_session(session, err); + end +end + +function listener.onreadtimeout(conn) + local session = sessions[conn]; + if session then + return (hosts[session.host] or prosody).events.fire_event("s2s-read-timeout", { session = session }); + end +end + +function listener.register_outgoing(conn, session) + session.direction = "outgoing"; + sessions[conn] = session; + initialize_session(session); +end + +function check_auth_policy(event) + local host, session = event.host, event.session; + local must_secure = secure_auth; + + if not must_secure and secure_domains[host] then + must_secure = true; + elseif must_secure and insecure_domains[host] then + must_secure = false; + end + + if must_secure and (session.cert_chain_status ~= "valid" or session.cert_identity_status ~= "valid") then + module:log("warn", "Forbidding insecure connection to/from %s", host); + if session.direction == "incoming" then + session:close({ condition = "not-authorized", text = "Your server's certificate is invalid, expired, or not trusted by "..session.to_host }); + else -- Close outgoing connections without warning + session:close(false); + end + return false; + end +end + +module:hook("s2s-check-certificate", check_auth_policy, -1); + +s2sout.set_listener(listener); + +module:hook("server-stopping", function(event) + local reason = event.reason; + for _, session in pairs(sessions) do + session:close{ condition = "system-shutdown", text = reason }; + end +end,500); + + + +module:provides("net", { + name = "s2s"; + listener = listener; + default_port = 5269; + encryption = "starttls"; + multiplex = { + pattern = "^<.*:stream.*%sxmlns%s*=%s*(['\"])jabber:server%1.*>"; + }; +}); + diff --git a/plugins/mod_s2s/s2sout.lib.lua b/plugins/mod_s2s/s2sout.lib.lua new file mode 100644 index 00000000..ec8ea4d4 --- /dev/null +++ b/plugins/mod_s2s/s2sout.lib.lua @@ -0,0 +1,352 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +--- Module containing all the logic for connecting to a remote server + +local portmanager = require "core.portmanager"; +local wrapclient = require "net.server".wrapclient; +local initialize_filters = require "util.filters".initialize; +local idna_to_ascii = require "util.encodings".idna.to_ascii; +local new_ip = require "util.ip".new_ip; +local rfc6724_dest = require "util.rfc6724".destination; +local socket = require "socket"; +local adns = require "net.adns"; +local dns = require "net.dns"; +local t_insert, t_sort, ipairs = table.insert, table.sort, ipairs; +local local_addresses = require "util.net".local_addresses; + +local s2s_destroy_session = require "core.s2smanager".destroy_session; + +local log = module._log; + +local sources = {}; +local has_ipv4, has_ipv6; + +local dns_timeout = module:get_option_number("dns_timeout", 15); +dns.settimeout(dns_timeout); +local max_dns_depth = module:get_option_number("dns_max_depth", 3); + +local s2sout = {}; + +local s2s_listener; + + +function s2sout.set_listener(listener) + s2s_listener = listener; +end + +local function compare_srv_priorities(a,b) + return a.priority < b.priority or (a.priority == b.priority and a.weight > b.weight); +end + +function s2sout.initiate_connection(host_session) + initialize_filters(host_session); + host_session.version = 1; + + -- Kick the connection attempting machine into life + if not s2sout.attempt_connection(host_session) then + -- Intentionally not returning here, the + -- session is needed, connected or not + s2s_destroy_session(host_session); + end + + if not host_session.sends2s then + -- A sends2s which buffers data (until the stream is opened) + -- note that data in this buffer will be sent before the stream is authed + -- and will not be ack'd in any way, successful or otherwise + local buffer; + function host_session.sends2s(data) + if not buffer then + buffer = {}; + host_session.send_buffer = buffer; + end + log("debug", "Buffering data on unconnected s2sout to %s", tostring(host_session.to_host)); + buffer[#buffer+1] = data; + log("debug", "Buffered item %d: %s", #buffer, tostring(data)); + end + end +end + +function s2sout.attempt_connection(host_session, err) + local to_host = host_session.to_host; + local connect_host, connect_port = to_host and idna_to_ascii(to_host), 5269; + + if not connect_host then + return false; + end + + if not err then -- This is our first attempt + log("debug", "First attempt to connect to %s, starting with SRV lookup...", to_host); + host_session.connecting = true; + local handle; + handle = adns.lookup(function (answer) + handle = nil; + host_session.connecting = nil; + if answer and #answer > 0 then + log("debug", "%s has SRV records, handling...", to_host); + local srv_hosts = { answer = answer }; + host_session.srv_hosts = srv_hosts; + for _, record in ipairs(answer) do + t_insert(srv_hosts, record.srv); + end + if #srv_hosts == 1 and srv_hosts[1].target == "." then + log("debug", "%s does not provide a XMPP service", to_host); + s2s_destroy_session(host_session, err); -- Nothing to see here + return; + end + t_sort(srv_hosts, compare_srv_priorities); + + local srv_choice = srv_hosts[1]; + host_session.srv_choice = 1; + if srv_choice then + connect_host, connect_port = srv_choice.target or to_host, srv_choice.port or connect_port; + log("debug", "Best record found, will connect to %s:%d", connect_host, connect_port); + end + else + log("debug", "%s has no SRV records, falling back to A/AAAA", to_host); + end + -- Try with SRV, or just the plain hostname if no SRV + local ok, err = s2sout.try_connect(host_session, connect_host, connect_port); + if not ok then + if not s2sout.attempt_connection(host_session, err) then + -- No more attempts will be made + s2s_destroy_session(host_session, err); + end + end + end, "_xmpp-server._tcp."..connect_host..".", "SRV"); + + return true; -- Attempt in progress + elseif host_session.ip_hosts then + return s2sout.try_connect(host_session, connect_host, connect_port, err); + elseif host_session.srv_hosts and #host_session.srv_hosts > host_session.srv_choice then -- Not our first attempt, and we also have SRV + host_session.srv_choice = host_session.srv_choice + 1; + local srv_choice = host_session.srv_hosts[host_session.srv_choice]; + connect_host, connect_port = srv_choice.target or to_host, srv_choice.port or connect_port; + host_session.log("info", "Connection failed (%s). Attempt #%d: This time to %s:%d", tostring(err), host_session.srv_choice, connect_host, connect_port); + else + host_session.log("info", "Failed in all attempts to connect to %s", tostring(host_session.to_host)); + -- We're out of options + return false; + end + + if not (connect_host and connect_port) then + -- Likely we couldn't resolve DNS + log("warn", "Hmm, we're without a host (%s) and port (%s) to connect to for %s, giving up :(", tostring(connect_host), tostring(connect_port), tostring(to_host)); + return false; + end + + return s2sout.try_connect(host_session, connect_host, connect_port); +end + +function s2sout.try_next_ip(host_session) + host_session.connecting = nil; + host_session.ip_choice = host_session.ip_choice + 1; + local ip = host_session.ip_hosts[host_session.ip_choice]; + local ok, err= s2sout.make_connect(host_session, ip.ip, ip.port); + if not ok then + if not s2sout.attempt_connection(host_session, err or "closed") then + err = err and (": "..err) or ""; + s2s_destroy_session(host_session, "Connection failed"..err); + end + end +end + +function s2sout.try_connect(host_session, connect_host, connect_port, err) + host_session.connecting = true; + + if not err then + local IPs = {}; + host_session.ip_hosts = IPs; + local handle4, handle6; + local have_other_result = not(has_ipv4) or not(has_ipv6) or false; + + if has_ipv4 then + handle4 = adns.lookup(function (reply, err) + handle4 = nil; + + -- COMPAT: This is a compromise for all you CNAME-(ab)users :) + if not (reply and reply[#reply] and reply[#reply].a) then + local count = max_dns_depth; + reply = dns.peek(connect_host, "CNAME", "IN"); + while count > 0 and reply and reply[#reply] and not reply[#reply].a and reply[#reply].cname do + log("debug", "Looking up %s (DNS depth is %d)", tostring(reply[#reply].cname), count); + reply = dns.peek(reply[#reply].cname, "A", "IN") or dns.peek(reply[#reply].cname, "CNAME", "IN"); + count = count - 1; + end + end + -- end of CNAME resolving + + if reply and reply[#reply] and reply[#reply].a then + for _, ip in ipairs(reply) do + log("debug", "DNS reply for %s gives us %s", connect_host, ip.a); + IPs[#IPs+1] = new_ip(ip.a, "IPv4"); + end + end + + if have_other_result then + if #IPs > 0 then + rfc6724_dest(host_session.ip_hosts, sources); + for i = 1, #IPs do + IPs[i] = {ip = IPs[i], port = connect_port}; + end + host_session.ip_choice = 0; + s2sout.try_next_ip(host_session); + else + log("debug", "DNS lookup failed to get a response for %s", connect_host); + host_session.ip_hosts = nil; + if not s2sout.attempt_connection(host_session, "name resolution failed") then -- Retry if we can + log("debug", "No other records to try for %s - destroying", host_session.to_host); + err = err and (": "..err) or ""; + s2s_destroy_session(host_session, "DNS resolution failed"..err); -- End of the line, we can't + end + end + else + have_other_result = true; + end + end, connect_host, "A", "IN"); + else + have_other_result = true; + end + + if has_ipv6 then + handle6 = adns.lookup(function (reply, err) + handle6 = nil; + + if reply and reply[#reply] and reply[#reply].aaaa then + for _, ip in ipairs(reply) do + log("debug", "DNS reply for %s gives us %s", connect_host, ip.aaaa); + IPs[#IPs+1] = new_ip(ip.aaaa, "IPv6"); + end + end + + if have_other_result then + if #IPs > 0 then + rfc6724_dest(host_session.ip_hosts, sources); + for i = 1, #IPs do + IPs[i] = {ip = IPs[i], port = connect_port}; + end + host_session.ip_choice = 0; + s2sout.try_next_ip(host_session); + else + log("debug", "DNS lookup failed to get a response for %s", connect_host); + host_session.ip_hosts = nil; + if not s2sout.attempt_connection(host_session, "name resolution failed") then -- Retry if we can + log("debug", "No other records to try for %s - destroying", host_session.to_host); + err = err and (": "..err) or ""; + s2s_destroy_session(host_session, "DNS resolution failed"..err); -- End of the line, we can't + end + end + else + have_other_result = true; + end + end, connect_host, "AAAA", "IN"); + else + have_other_result = true; + end + return true; + elseif host_session.ip_hosts and #host_session.ip_hosts > host_session.ip_choice then -- Not our first attempt, and we also have IPs left to try + s2sout.try_next_ip(host_session); + else + host_session.ip_hosts = nil; + if not s2sout.attempt_connection(host_session, "out of IP addresses") then -- Retry if we can + log("debug", "No other records to try for %s - destroying", host_session.to_host); + err = err and (": "..err) or ""; + s2s_destroy_session(host_session, "Connecting failed"..err); -- End of the line, we can't + return false; + end + end + + return true; +end + +function s2sout.make_connect(host_session, connect_host, connect_port) + (host_session.log or log)("debug", "Beginning new connection attempt to %s ([%s]:%d)", host_session.to_host, connect_host.addr, connect_port); + + -- Reset secure flag in case this is another + -- connection attempt after a failed STARTTLS + host_session.secure = nil; + + local conn, handler; + local proto = connect_host.proto; + if proto == "IPv4" then + conn, handler = socket.tcp(); + elseif proto == "IPv6" and socket.tcp6 then + conn, handler = socket.tcp6(); + else + handler = "Unsupported protocol: "..tostring(proto); + end + + if not conn then + log("warn", "Failed to create outgoing connection, system error: %s", handler); + return false, handler; + end + + conn:settimeout(0); + local success, err = conn:connect(connect_host.addr, connect_port); + if not success and err ~= "timeout" then + log("warn", "s2s connect() to %s (%s:%d) failed: %s", host_session.to_host, connect_host.addr, connect_port, err); + return false, err; + end + + conn = wrapclient(conn, connect_host.addr, connect_port, s2s_listener, "*a"); + host_session.conn = conn; + + local filter = initialize_filters(host_session); + local w, log = conn.write, host_session.log; + host_session.sends2s = function (t) + log("debug", "sending: %s", (t.top_tag and t:top_tag()) or t:match("^[^>]*>?")); + if t.name then + t = filter("stanzas/out", t); + end + if t then + t = filter("bytes/out", tostring(t)); + if t then + return w(conn, tostring(t)); + end + end + end + + -- Register this outgoing connection so that xmppserver_listener knows about it + -- otherwise it will assume it is a new incoming connection + s2s_listener.register_outgoing(conn, host_session); + + log("debug", "Connection attempt in progress..."); + return true; +end + +module:hook_global("service-added", function (event) + if event.name ~= "s2s" then return end + + local s2s_sources = portmanager.get_active_services():get("s2s"); + if not s2s_sources then + module:log("warn", "s2s not listening on any ports, outgoing connections may fail"); + return; + end + for source, _ in pairs(s2s_sources) do + if source == "*" or source == "0.0.0.0" then + for _, addr in ipairs(local_addresses("ipv4", true)) do + sources[#sources + 1] = new_ip(addr, "IPv4"); + end + elseif source == "::" then + for _, addr in ipairs(local_addresses("ipv6", true)) do + sources[#sources + 1] = new_ip(addr, "IPv6"); + end + else + sources[#sources + 1] = new_ip(source, (source:find(":") and "IPv6") or "IPv4"); + end + end + for i = 1,#sources do + if sources[i].proto == "IPv6" then + has_ipv6 = true; + elseif sources[i].proto == "IPv4" then + has_ipv4 = true; + end + end +end); + +return s2sout; diff --git a/plugins/mod_saslauth.lua b/plugins/mod_saslauth.lua index 422bc187..f24eacf8 100644 --- a/plugins/mod_saslauth.lua +++ b/plugins/mod_saslauth.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -11,12 +11,10 @@ local st = require "util.stanza"; local sm_bind_resource = require "core.sessionmanager".bind_resource; local sm_make_authenticated = require "core.sessionmanager".make_authenticated; -local s2s_make_authenticated = require "core.s2smanager".make_authenticated; local base64 = require "util.encodings".base64; local cert_verify_identity = require "util.x509".verify_identity; -local nodeprep = require "util.encodings".stringprep.nodeprep; local usermanager_get_sasl_handler = require "core.usermanager".get_sasl_handler; local tostring = tostring; @@ -27,7 +25,6 @@ local log = module._log; local xmlns_sasl ='urn:ietf:params:xml:ns:xmpp-sasl'; local xmlns_bind ='urn:ietf:params:xml:ns:xmpp-bind'; -local xmlns_stanzas ='urn:ietf:params:xml:ns:xmpp-stanzas'; local function build_reply(status, ret, err_msg) local reply = st.stanza(status, {xmlns = xmlns_sasl}); @@ -48,16 +45,17 @@ end local function handle_status(session, status, ret, err_msg) if status == "failure" then + module:fire_event("authentication-failure", { session = session, condition = ret, text = err_msg }); session.sasl_handler = session.sasl_handler:clean_clone(); elseif status == "success" then - local username = nodeprep(session.sasl_handler.username); - local ok, err = sm_make_authenticated(session, session.sasl_handler.username); if ok then + module:fire_event("authentication-success", { session = session }); session.sasl_handler = nil; session:reset_stream(); else module:log("warn", "SASL succeeded but username was invalid"); + module:fire_event("authentication-failure", { session = session, condition = "not-authorized", text = err }); session.sasl_handler = session.sasl_handler:clean_clone(); return "failure", "not-authorized", "User authenticated successfully, but username was invalid"; end @@ -89,13 +87,9 @@ module:hook_stanza(xmlns_sasl, "success", function (session, stanza) module:log("debug", "SASL EXTERNAL with %s succeeded", session.to_host); session.external_auth = "succeeded" session:reset_stream(); + session:open_stream(session.from_host, session.to_host); - local default_stream_attr = {xmlns = "jabber:server", ["xmlns:stream"] = "http://etherx.jabber.org/streams", - ["xmlns:db"] = 'jabber:server:dialback', version = "1.0", to = session.to_host, from = session.from_host}; - session.sends2s("<?xml version='1.0'?>"); - session.sends2s(st.stanza("stream:stream", default_stream_attr):top_tag()); - - s2s_make_authenticated(session, session.to_host); + module:fire_event("s2s-authenticated", { session = session, host = session.to_host }); return true; end) @@ -189,8 +183,10 @@ local function s2s_external_auth(session, stanza) session.from_host = text; end session.sends2s(build_reply("success")) - module:log("info", "Accepting SASL EXTERNAL identity from %s", text or session.from_host); - s2s_make_authenticated(session, text or session.from_host) + + local domain = text ~= "" and text or session.from_host; + module:log("info", "Accepting SASL EXTERNAL identity from %s", domain); + module:fire_event("s2s-authenticated", { session = session, host = domain }); session:reset_stream(); return true end @@ -207,7 +203,7 @@ module:hook("stanza/urn:ietf:params:xml:ns:xmpp-sasl:auth", function(event) session.sasl_handler = nil; -- allow starting a new SASL negotiation before completing an old one end if not session.sasl_handler then - session.sasl_handler = usermanager_get_sasl_handler(module.host); + session.sasl_handler = usermanager_get_sasl_handler(module.host, session); end local mechanism = stanza.attr.mechanism; if not session.secure and (secure_auth_only or (mechanism == "PLAIN" and not allow_unencrypted_plain_auth)) then @@ -245,7 +241,7 @@ module:hook("stream-features", function(event) if secure_auth_only and not origin.secure then return; end - origin.sasl_handler = usermanager_get_sasl_handler(module.host); + origin.sasl_handler = usermanager_get_sasl_handler(module.host, origin); if origin.secure then -- check wether LuaSec has the nifty binding to the function needed for tls-unique -- FIXME: would be nice to have this check only once and not for every socket @@ -256,13 +252,13 @@ module:hook("stream-features", function(event) origin.sasl_handler["userdata"] = origin.conn:socket(); end end - features:tag("mechanisms", mechanisms_attr); + local mechanisms = st.stanza("mechanisms", mechanisms_attr); for mechanism in pairs(origin.sasl_handler:mechanisms()) do if mechanism ~= "PLAIN" or origin.secure or allow_unencrypted_plain_auth then - features:tag("mechanism"):text(mechanism):up(); + mechanisms:tag("mechanism"):text(mechanism):up(); end end - features:up(); + if mechanisms[1] then features:add_child(mechanisms); end else features:tag("bind", bind_attr):tag("required"):up():up(); features:tag("session", xmpp_session_attr):tag("optional"):up():up(); diff --git a/plugins/mod_storage_internal.lua b/plugins/mod_storage_internal.lua new file mode 100644 index 00000000..972ecbee --- /dev/null +++ b/plugins/mod_storage_internal.lua @@ -0,0 +1,31 @@ +local datamanager = require "core.storagemanager".olddm; + +local host = module.host; + +local driver = {}; +local driver_mt = { __index = driver }; + +function driver:open(store, typ) + return setmetatable({ store = store, type = typ }, driver_mt); +end +function driver:get(user) + return datamanager.load(user, host, self.store); +end + +function driver:set(user, data) + return datamanager.store(user, host, self.store, data); +end + +function driver:stores(username) + return datamanager.stores(username, host); +end + +function driver:users() + return datamanager.users(host, self.store, self.type); +end + +function driver:purge(user) + return datamanager.purge(user, host); +end + +module:provides("storage", driver); diff --git a/plugins/mod_storage_none.lua b/plugins/mod_storage_none.lua new file mode 100644 index 00000000..8f2d2f56 --- /dev/null +++ b/plugins/mod_storage_none.lua @@ -0,0 +1,23 @@ +local driver = {}; +local driver_mt = { __index = driver }; + +function driver:open(store) + return setmetatable({ store = store }, driver_mt); +end +function driver:get(user) + return {}; +end + +function driver:set(user, data) + return nil, "Storage disabled"; +end + +function driver:stores(username) + return { "roster" }; +end + +function driver:purge(user) + return true; +end + +module:provides("storage", driver); diff --git a/plugins/mod_storage_sql.lua b/plugins/mod_storage_sql.lua index 53f1ea0a..1f453d42 100644 --- a/plugins/mod_storage_sql.lua +++ b/plugins/mod_storage_sql.lua @@ -25,42 +25,154 @@ local tonumber = tonumber; local pairs = pairs; local next = next; local setmetatable = setmetatable; +local xpcall = xpcall; local json = require "util.json"; +local build_url = require"socket.url".build; -local connection = ...; +local DBI; +local connection; local host,user,store = module.host; local params = module:get_option("sql"); -do -- process options to get a db connection - local DBI = require "DBI"; +local dburi; +local connections = module:shared "/*/sql/connection-cache"; + +local function db2uri(params) + return build_url{ + scheme = params.driver, + user = params.username, + password = params.password, + host = params.host, + port = params.port, + path = params.database, + }; +end + + +local resolve_relative_path = require "core.configmanager".resolve_relative_path; + +local function test_connection() + if not connection then return nil; end + if connection:ping() then + return true; + else + module:log("debug", "Database connection closed"); + connection = nil; + connections[dburi] = nil; + end +end +local function connect() + if not test_connection() then + prosody.unlock_globals(); + local dbh, err = DBI.Connect( + params.driver, params.database, + params.username, params.password, + params.host, params.port + ); + prosody.lock_globals(); + if not dbh then + module:log("debug", "Database connection failed: %s", tostring(err)); + return nil, err; + end + module:log("debug", "Successfully connected to database"); + dbh:autocommit(false); -- don't commit automatically + connection = dbh; + + connections[dburi] = dbh; + end + return connection; +end + +local function create_table() + if not module:get_option("sql_manage_tables", true) then + return; + end + local create_sql = "CREATE TABLE `prosody` (`host` TEXT, `user` TEXT, `store` TEXT, `key` TEXT, `type` TEXT, `value` TEXT);"; + if params.driver == "PostgreSQL" then + create_sql = create_sql:gsub("`", "\""); + elseif params.driver == "MySQL" then + create_sql = create_sql:gsub("`value` TEXT", "`value` MEDIUMTEXT"); + end + + local stmt, err = connection:prepare(create_sql); + if stmt then + local ok = stmt:execute(); + local commit_ok = connection:commit(); + if ok and commit_ok then + module:log("info", "Initialized new %s database with prosody table", params.driver); + local index_sql = "CREATE INDEX `prosody_index` ON `prosody` (`host`, `user`, `store`, `key`)"; + if params.driver == "PostgreSQL" then + index_sql = index_sql:gsub("`", "\""); + elseif params.driver == "MySQL" then + index_sql = index_sql:gsub("`([,)])", "`(20)%1"); + end + local stmt, err = connection:prepare(index_sql); + local ok, commit_ok, commit_err; + if stmt then + ok, err = stmt:execute(); + commit_ok, commit_err = connection:commit(); + end + if not(ok and commit_ok) then + module:log("warn", "Failed to create index (%s), lookups may not be optimised", err or commit_err); + end + elseif params.driver == "MySQL" then -- COMPAT: Upgrade tables from 0.8.0 + -- Failed to create, but check existing MySQL table here + local stmt = connection:prepare("SHOW COLUMNS FROM prosody WHERE Field='value' and Type='text'"); + local ok = stmt:execute(); + local commit_ok = connection:commit(); + if ok and commit_ok then + if stmt:rowcount() > 0 then + module:log("info", "Upgrading database schema..."); + local stmt = connection:prepare("ALTER TABLE prosody MODIFY COLUMN `value` MEDIUMTEXT"); + local ok, err = stmt:execute(); + local commit_ok = connection:commit(); + if ok and commit_ok then + module:log("info", "Database table automatically upgraded"); + else + module:log("error", "Failed to upgrade database schema (%s), please see " + .."http://prosody.im/doc/mysql for help", + err or "unknown error"); + end + end + repeat until not stmt:fetch(); + end + end + elseif params.driver ~= "SQLite3" then -- SQLite normally fails to prepare for existing table + module:log("warn", "Prosody was not able to automatically check/create the database table (%s), " + .."see http://prosody.im/doc/modules/mod_storage_sql#table_management for help.", + err or "unknown error"); + end +end - params = params or { driver = "SQLite3", database = "prosody.sqlite" }; - assert(params.driver and params.database, "invalid params"); - +do -- process options to get a db connection + local ok; prosody.unlock_globals(); - local dbh, err = DBI.Connect( - params.driver, params.database, - params.username, params.password, - params.host, params.port - ); + ok, DBI = pcall(require, "DBI"); + if not ok then + package.loaded["DBI"] = {}; + module:log("error", "Failed to load the LuaDBI library for accessing SQL databases: %s", DBI); + module:log("error", "More information on installing LuaDBI can be found at http://prosody.im/doc/depends#luadbi"); + end prosody.lock_globals(); - assert(dbh, err); - - dbh:autocommit(false); -- don't commit automatically - connection = dbh; - - if params.driver == "SQLite3" then -- auto initialize - local stmt = assert(connection:prepare("SELECT COUNT(*) FROM `sqlite_master` WHERE `type`='table' AND `name`='Prosody';")); - local ok = assert(stmt:execute()); - local count = stmt:fetch()[1]; - if count == 0 then - local stmt = assert(connection:prepare("CREATE TABLE `Prosody` (`host` TEXT, `user` TEXT, `store` TEXT, `key` TEXT, `type` TEXT, `value` TEXT);")); - assert(stmt:execute()); - module:log("debug", "Initialized new SQLite3 database"); - end - assert(connection:commit()); - --print("===", json.encode()) + if not ok or not DBI.Connect then + return; -- Halt loading of this module end + + params = params or { driver = "SQLite3" }; + + if params.driver == "SQLite3" then + params.database = resolve_relative_path(prosody.paths.data or ".", params.database or "prosody.sqlite"); + end + + assert(params.driver and params.database, "Both the SQL driver and the database need to be specified"); + + dburi = db2uri(params); + connection = connections[dburi]; + + assert(connect()); + + -- Automatically create table, ignore failure (table probably already exists) + create_table(); end local function serialize(value) @@ -85,19 +197,24 @@ local function deserialize(t, value) end end -local function getsql(sql, ...) +local function dosql(sql, ...) if params.driver == "PostgreSQL" then sql = sql:gsub("`", "\""); end -- do prepared statement stuff local stmt, err = connection:prepare(sql); + if not stmt and not test_connection() then error("connection failed"); end if not stmt then module:log("error", "QUERY FAILED: %s %s", err, debug.traceback()); return nil, err; end -- run query - local ok, err = stmt:execute(host or "", user or "", store or "", ...); + local ok, err = stmt:execute(...); + if not ok and not test_connection() then error("connection failed"); end if not ok then return nil, err; end - + return stmt; end +local function getsql(sql, ...) + return dosql(sql, host or "", user or "", store or "", ...); +end local function setsql(sql, ...) local stmt, err = getsql(sql, ...); if not stmt then return stmt, err; end @@ -107,21 +224,19 @@ local function transact(...) -- ... end local function rollback(...) - connection:rollback(); -- FIXME check for rollback error? + if connection then connection:rollback(); end -- FIXME check for rollback error? return ...; end local function commit(...) - if not connection:commit() then return nil, "SQL commit failed"; end + local success,err = connection:commit(); + if not success then return nil, "SQL commit failed: "..tostring(err); end return ...; end -local keyval_store = {}; -keyval_store.__index = keyval_store; -function keyval_store:get(username) - user,store = username,self.store; - local stmt, err = getsql("SELECT * FROM `Prosody` WHERE `host`=? AND `user`=? AND `store`=?"); - if not stmt then return nil, err; end - +local function keyval_store_get() + local stmt, err = getsql("SELECT * FROM `prosody` WHERE `host`=? AND `user`=? AND `store`=?"); + if not stmt then return rollback(nil, err); end + local haveany; local result = {}; for row in stmt:rows(true) do @@ -138,18 +253,17 @@ function keyval_store:get(username) end return commit(haveany and result or nil); end -function keyval_store:set(username, data) - user,store = username,self.store; - -- start transaction - local affected, err = setsql("DELETE FROM `Prosody` WHERE `host`=? AND `user`=? AND `store`=?"); - +local function keyval_store_set(data) + local affected, err = setsql("DELETE FROM `prosody` WHERE `host`=? AND `user`=? AND `store`=?"); + if not affected then return rollback(affected, err); end + if data and next(data) ~= nil then local extradata = {}; for key, value in pairs(data) do if type(key) == "string" and key ~= "" then local t, value = serialize(value); if not t then return rollback(t, value); end - local ok, err = setsql("INSERT INTO `Prosody` (`host`,`user`,`store`,`key`,`type`,`value`) VALUES (?,?,?,?,?,?)", key, t, value); + local ok, err = setsql("INSERT INTO `prosody` (`host`,`user`,`store`,`key`,`type`,`value`) VALUES (?,?,?,?,?,?)", key, t, value); if not ok then return rollback(ok, err); end else extradata[key] = value; @@ -158,20 +272,49 @@ function keyval_store:set(username, data) if next(extradata) ~= nil then local t, extradata = serialize(extradata); if not t then return rollback(t, extradata); end - local ok, err = setsql("INSERT INTO `Prosody` (`host`,`user`,`store`,`key`,`type`,`value`) VALUES (?,?,?,?,?,?)", "", t, extradata); + local ok, err = setsql("INSERT INTO `prosody` (`host`,`user`,`store`,`key`,`type`,`value`) VALUES (?,?,?,?,?,?)", "", t, extradata); if not ok then return rollback(ok, err); end end end return commit(true); end -local map_store = {}; -map_store.__index = map_store; -function map_store:get(username, key) +local keyval_store = {}; +keyval_store.__index = keyval_store; +function keyval_store:get(username) user,store = username,self.store; - local stmt, err = getsql("SELECT * FROM `Prosody` WHERE `host`=? AND `user`=? AND `store`=? AND `key`=?", key or ""); - if not stmt then return nil, err; end - + if not connection and not connect() then return nil, "Unable to connect to database"; end + local success, ret, err = xpcall(keyval_store_get, debug.traceback); + if not connection and connect() then + success, ret, err = xpcall(keyval_store_get, debug.traceback); + end + if success then return ret, err; else return rollback(nil, ret); end +end +function keyval_store:set(username, data) + user,store = username,self.store; + if not connection and not connect() then return nil, "Unable to connect to database"; end + local success, ret, err = xpcall(function() return keyval_store_set(data); end, debug.traceback); + if not connection and connect() then + success, ret, err = xpcall(function() return keyval_store_set(data); end, debug.traceback); + end + if success then return ret, err; else return rollback(nil, ret); end +end +function keyval_store:users() + local stmt, err = dosql("SELECT DISTINCT `user` FROM `prosody` WHERE `host`=? AND `store`=?", host, self.store); + if not stmt then + return rollback(nil, err); + end + local next = stmt:rows(); + return commit(function() + local row = next(); + return row and row[1]; + end); +end + +local function map_store_get(key) + local stmt, err = getsql("SELECT * FROM `prosody` WHERE `host`=? AND `user`=? AND `store`=? AND `key`=?", key or ""); + if not stmt then return rollback(nil, err); end + local haveany; local result = {}; for row in stmt:rows(true) do @@ -188,16 +331,15 @@ function map_store:get(username, key) end return commit(haveany and result[key] or nil); end -function map_store:set(username, key, data) - user,store = username,self.store; - -- start transaction - local affected, err = setsql("DELETE FROM `Prosody` WHERE `host`=? AND `user`=? AND `store`=? AND `key`=?", key or ""); - +local function map_store_set(key, data) + local affected, err = setsql("DELETE FROM `prosody` WHERE `host`=? AND `user`=? AND `store`=? AND `key`=?", key or ""); + if not affected then return rollback(affected, err); end + if data and next(data) ~= nil then if type(key) == "string" and key ~= "" then local t, value = serialize(data); if not t then return rollback(t, value); end - local ok, err = setsql("INSERT INTO `Prosody` (`host`,`user`,`store`,`key`,`type`,`value`) VALUES (?,?,?,?,?,?)", key, t, value); + local ok, err = setsql("INSERT INTO `prosody` (`host`,`user`,`store`,`key`,`type`,`value`) VALUES (?,?,?,?,?,?)", key, t, value); if not ok then return rollback(ok, err); end else -- TODO non-string keys @@ -206,23 +348,36 @@ function map_store:set(username, key, data) return commit(true); end +local map_store = {}; +map_store.__index = map_store; +function map_store:get(username, key) + user,store = username,self.store; + local success, ret, err = xpcall(function() return map_store_get(key); end, debug.traceback); + if success then return ret, err; else return rollback(nil, ret); end +end +function map_store:set(username, key, data) + user,store = username,self.store; + local success, ret, err = xpcall(function() return map_store_set(key, data); end, debug.traceback); + if success then return ret, err; else return rollback(nil, ret); end +end + local list_store = {}; list_store.__index = list_store; function list_store:scan(username, from, to, jid, typ) user,store = username,self.store; - + local cols = {"from", "to", "jid", "typ"}; local vals = { from , to , jid , typ }; local stmt, err; - local query = "SELECT * FROM `ProsodyArchive` WHERE `host`=? AND `user`=? AND `store`=?"; - + local query = "SELECT * FROM `prosodyarchive` WHERE `host`=? AND `user`=? AND `store`=?"; + query = query.." ORDER BY time"; - --local stmt, err = getsql("SELECT * FROM `Prosody` WHERE `host`=? AND `user`=? AND `store`=? AND `key`=?", key or ""); - + --local stmt, err = getsql("SELECT * FROM `prosody` WHERE `host`=? AND `user`=? AND `store`=? AND `key`=?", key or ""); + return nil, "not-implemented" end -local driver = { name = "sql" }; +local driver = {}; function driver:open(store, typ) if not typ then -- default key-value store @@ -231,4 +386,29 @@ function driver:open(store, typ) return nil, "unsupported-store"; end -module:add_item("data-driver", driver); +function driver:stores(username) + local sql = "SELECT DISTINCT `store` FROM `prosody` WHERE `host`=? AND `user`" .. + (username == true and "!=?" or "=?"); + if username == true or not username then + username = ""; + end + local stmt, err = dosql(sql, host, username); + if not stmt then + return rollback(nil, err); + end + local next = stmt:rows(); + return commit(function() + local row = next(); + return row and row[1]; + end); +end + +function driver:purge(username) + local stmt, err = dosql("DELETE FROM `prosody` WHERE `host`=? AND `user`=?", host, username); + if not stmt then return rollback(stmt, err); end + local changed, err = stmt:affected(); + if not changed then return rollback(changed, err); end + return commit(true, changed); +end + +module:provides("storage", driver); diff --git a/plugins/mod_storage_sql2.lua b/plugins/mod_storage_sql2.lua new file mode 100644 index 00000000..3c5f9d20 --- /dev/null +++ b/plugins/mod_storage_sql2.lua @@ -0,0 +1,370 @@ + +local json = require "util.json"; +local xml_parse = require "util.xml".parse; +local uuid = require "util.uuid"; +local resolve_relative_path = require "core.configmanager".resolve_relative_path; + +local stanza_mt = require"util.stanza".stanza_mt; +local getmetatable = getmetatable; +local t_concat = table.concat; +local function is_stanza(x) return getmetatable(x) == stanza_mt; end + +local noop = function() end +local unpack = unpack +local function iterator(result) + return function(result) + local row = result(); + if row ~= nil then + return unpack(row); + end + end, result, nil; +end + +local mod_sql = module:require("sql"); +local params = module:get_option("sql"); + +local engine; -- TODO create engine + +local function create_table() + local Table,Column,Index = mod_sql.Table,mod_sql.Column,mod_sql.Index; + --[[ + local ProsodyTable = Table { + name="prosody"; + Column { name="host", type="TEXT", nullable=false }; + Column { name="user", type="TEXT", nullable=false }; + Column { name="store", type="TEXT", nullable=false }; + Column { name="key", type="TEXT", nullable=false }; + Column { name="type", type="TEXT", nullable=false }; + Column { name="value", type="TEXT", nullable=false }; + Index { name="prosody_index", "host", "user", "store", "key" }; + }; + engine:transaction(function() + ProsodyTable:create(engine); + end);]] + if not module:get_option("sql_manage_tables", true) then + return; + end + + local create_sql = "CREATE TABLE `prosody` (`host` TEXT, `user` TEXT, `store` TEXT, `key` TEXT, `type` TEXT, `value` TEXT);"; + if params.driver == "PostgreSQL" then + create_sql = create_sql:gsub("`", "\""); + elseif params.driver == "MySQL" then + create_sql = create_sql:gsub("`value` TEXT", "`value` MEDIUMTEXT") + :gsub(";$", " CHARACTER SET 'utf8' COLLATE 'utf8_bin';"); + end + + local index_sql = "CREATE INDEX `prosody_index` ON `prosody` (`host`, `user`, `store`, `key`)"; + if params.driver == "PostgreSQL" then + index_sql = index_sql:gsub("`", "\""); + elseif params.driver == "MySQL" then + index_sql = index_sql:gsub("`([,)])", "`(20)%1"); + end + + local success,err = engine:transaction(function() + engine:execute(create_sql); + engine:execute(index_sql); + end); + if not success then -- so we failed to create + if params.driver == "MySQL" then + success,err = engine:transaction(function() + local result = engine:execute("SHOW COLUMNS FROM prosody WHERE Field='value' and Type='text'"); + if result:rowcount() > 0 then + module:log("info", "Upgrading database schema..."); + engine:execute("ALTER TABLE prosody MODIFY COLUMN `value` MEDIUMTEXT"); + module:log("info", "Database table automatically upgraded"); + end + return true; + end); + if not success then + module:log("error", "Failed to check/upgrade database schema (%s), please see " + .."http://prosody.im/doc/mysql for help", + err or "unknown error"); + end + end + end + local ProsodyArchiveTable = Table { + name="prosodyarchive"; + Column { name="sort_id", type="INTEGER PRIMARY KEY AUTOINCREMENT", nullable=false }; + Column { name="host", type="TEXT", nullable=false }; + Column { name="user", type="TEXT", nullable=false }; + Column { name="store", type="TEXT", nullable=false }; + Column { name="key", type="TEXT", nullable=false }; -- item id + Column { name="when", type="INTEGER", nullable=false }; -- timestamp + Column { name="with", type="TEXT", nullable=false }; -- related id + Column { name="type", type="TEXT", nullable=false }; + Column { name="value", type=params.driver == "MySQL" and "MEDIUMTEXT" or "TEXT", nullable=false }; + Index { name="prosodyarchive_index", "host", "user", "store", "key" }; + }; + engine:transaction(function() + ProsodyArchiveTable:create(engine); + end); +end +local function set_encoding() + if params.driver == "SQLite3" then return end + local set_names_query = "SET NAMES 'utf8';"; + if params.driver == "MySQL" then + set_names_query = set_names_query:gsub(";$", " COLLATE 'utf8_bin';"); + end + local success,err = engine:transaction(function() return engine:execute(set_names_query); end); + if not success then + module:log("error", "Failed to set database connection encoding to UTF8: %s", err); + return; + end + if params.driver == "MySQL" then + -- COMPAT w/pre-0.9: Upgrade tables to UTF-8 if not already + local check_encoding_query = "SELECT `COLUMN_NAME`,`COLUMN_TYPE` FROM `information_schema`.`columns` WHERE `TABLE_NAME`='prosody' AND ( `CHARACTER_SET_NAME`!='utf8' OR `COLLATION_NAME`!='utf8_bin' );"; + local success,err = engine:transaction(function() + local result = engine:execute(check_encoding_query); + local n_bad_columns = result:rowcount(); + if n_bad_columns > 0 then + module:log("warn", "Found %d columns in prosody table requiring encoding change, updating now...", n_bad_columns); + local fix_column_query1 = "ALTER TABLE `prosody` CHANGE `%s` `%s` BLOB;"; + local fix_column_query2 = "ALTER TABLE `prosody` CHANGE `%s` `%s` %s CHARACTER SET 'utf8' COLLATE 'utf8_bin';"; + for row in result:rows() do + local column_name, column_type = unpack(row); + engine:execute(fix_column_query1:format(column_name, column_name)); + engine:execute(fix_column_query2:format(column_name, column_name, column_type)); + end + module:log("info", "Database encoding upgrade complete!"); + end + end); + local success,err = engine:transaction(function() return engine:execute(check_encoding_query); end); + if not success then + module:log("error", "Failed to check/upgrade database encoding: %s", err or "unknown error"); + end + end +end + +do -- process options to get a db connection + params = params or { driver = "SQLite3" }; + + if params.driver == "SQLite3" then + params.database = resolve_relative_path(prosody.paths.data or ".", params.database or "prosody.sqlite"); + end + + assert(params.driver and params.database, "Both the SQL driver and the database need to be specified"); + + --local dburi = db2uri(params); + engine = mod_sql:create_engine(params); + + -- Encoding mess + set_encoding(); + + -- Automatically create table, ignore failure (table probably already exists) + create_table(); +end + +local function serialize(value) + local t = type(value); + if t == "string" or t == "boolean" or t == "number" then + return t, tostring(value); + elseif is_stanza(value) then + return "xml", tostring(value); + elseif t == "table" then + local value,err = json.encode(value); + if value then return "json", value; end + return nil, err; + end + return nil, "Unhandled value type: "..t; +end +local function deserialize(t, value) + if t == "string" then return value; + elseif t == "boolean" then + if value == "true" then return true; + elseif value == "false" then return false; end + elseif t == "number" then return tonumber(value); + elseif t == "json" then + return json.decode(value); + elseif t == "xml" then + return xml_parse(value); + end +end + +local host = module.host; +local user, store; + +local function keyval_store_get() + local haveany; + local result = {}; + for row in engine:select("SELECT `key`,`type`,`value` FROM `prosody` WHERE `host`=? AND `user`=? AND `store`=?", host, user or "", store) do + haveany = true; + local k = row[1]; + local v = deserialize(row[2], row[3]); + if k and v then + if k ~= "" then result[k] = v; elseif type(v) == "table" then + for a,b in pairs(v) do + result[a] = b; + end + end + end + end + if haveany then + return result; + end +end +local function keyval_store_set(data) + engine:delete("DELETE FROM `prosody` WHERE `host`=? AND `user`=? AND `store`=?", host, user or "", store); + + if data and next(data) ~= nil then + local extradata = {}; + for key, value in pairs(data) do + if type(key) == "string" and key ~= "" then + local t, value = serialize(value); + assert(t, value); + engine:insert("INSERT INTO `prosody` (`host`,`user`,`store`,`key`,`type`,`value`) VALUES (?,?,?,?,?,?)", host, user or "", store, key, t, value); + else + extradata[key] = value; + end + end + if next(extradata) ~= nil then + local t, extradata = serialize(extradata); + assert(t, extradata); + engine:insert("INSERT INTO `prosody` (`host`,`user`,`store`,`key`,`type`,`value`) VALUES (?,?,?,?,?,?)", host, user or "", store, "", t, extradata); + end + end + return true; +end + +local keyval_store = {}; +keyval_store.__index = keyval_store; +function keyval_store:get(username) + user,store = username,self.store; + return select(2, engine:transaction(keyval_store_get)); +end +function keyval_store:set(username, data) + user,store = username,self.store; + return engine:transaction(function() + return keyval_store_set(data); + end); +end +function keyval_store:users() + local ok, result = engine:transaction(function() + return engine:select("SELECT DISTINCT `user` FROM `prosody` WHERE `host`=? AND `store`=?", host, self.store); + end); + if not ok then return ok, result end + return iterator(result); +end + +local archive_store = {} +archive_store.__index = archive_store +function archive_store:append(username, when, with, value) + local user,store = username,self.store; + return engine:transaction(function() + local key = uuid.generate(); + local t, value = serialize(value); + engine:insert("INSERT INTO `prosodyarchive` (`host`, `user`, `store`, `when`, `with`, `key`, `type`, `value`) VALUES (?,?,?,?,?,?,?,?)", host, user or "", store, when, with, key, t, value); + return key; + end); +end +function archive_store:find(username, query) + query = query or {}; + local user,store = username,self.store; + local total; + local ok, result = engine:transaction(function() + local sql_query = "SELECT `key`, `type`, `value`, `when` FROM `prosodyarchive` WHERE %s ORDER BY `sort_id` %s%s;"; + local args = { host, user or "", store, }; + local where = { "`host` = ?", "`user` = ?", "`store` = ?", }; + + -- Time range, inclusive + if query.start then + args[#args+1] = query.start + where[#where+1] = "`when` >= ?" + end + if query["end"] then + args[#args+1] = query["end"]; + if query.start then + where[#where] = "`when` BETWEEN ? AND ?" -- is this inclusive? + else + where[#where+1] = "`when` >= ?" + end + end + + -- Related name + if query.with then + where[#where+1] = "`with` = ?"; + args[#args+1] = query.with + end + + -- Unique id + if query.key then + where[#where+1] = "`key` = ?"; + args[#args+1] = query.key + end + + -- Total matching + if query.total then + local stats = engine:select(sql_query:gsub("^(SELECT).-(FROM)", "%1 COUNT(*) %2"):format(t_concat(where, " AND "), "DESC", ""), unpack(args)); + if stats then + local _total = stats() + total = _total and _total[1]; + end + if query.limit == 0 then -- Skip the real query + return noop, total; + end + end + + -- Before or after specific item, exclusive + if query.after then + where[#where+1] = "`sort_id` > (SELECT `sort_id` FROM `prosodyarchive` WHERE `key` = ? LIMIT 1)" + args[#args+1] = query.after + end + if query.before then + where[#where+1] = "`sort_id` < (SELECT `sort_id` FROM `prosodyarchive` WHERE `key` = ? LIMIT 1)" + args[#args+1] = query.before + end + + if query.limit then + args[#args+1] = query.limit; + end + + sql_query = sql_query:format(t_concat(where, " AND "), query.reverse and "DESC" or "ASC", query.limit and " LIMIT ?" or ""); + module:log("debug", sql_query); + return engine:select(sql_query, unpack(args)); + end); + if not ok then return ok, result end + return function() + local row = result(); + if row ~= nil then + return row[1], deserialize(row[2], row[3]), row[4]; + end + end, total; +end + +local stores = { + keyval = keyval_store; + archive = archive_store; +}; + +local driver = {}; + +function driver:open(store, typ) + local store_mt = stores[typ or "keyval"]; + if store_mt then + return setmetatable({ store = store }, store_mt); + end + return nil, "unsupported-store"; +end + +function driver:stores(username) + local sql = "SELECT DISTINCT `store` FROM `prosody` WHERE `host`=? AND `user`" .. + (username == true and "!=?" or "=?"); + if username == true or not username then + username = ""; + end + local ok, result = engine:transaction(function() + return engine:select(sql, host, username); + end); + if not ok then return ok, result end + return iterator(result); +end + +function driver:purge(username) + return engine:transaction(function() + local stmt,err = engine:delete("DELETE FROM `prosody` WHERE `host`=? AND `user`=?", host, username); + return true,err; + end); +end + +module:provides("storage", driver); + + diff --git a/plugins/mod_storage_sql_ejabberd.lua b/plugins/mod_storage_sql_ejabberd.lua deleted file mode 100644 index 74763c92..00000000 --- a/plugins/mod_storage_sql_ejabberd.lua +++ /dev/null @@ -1,232 +0,0 @@ - -local setmetatable = setmetatable; -local error = error; -local unpack = unpack; -local module = module; -local tostring = tostring; -local pairs, next = pairs, next; -local prosody = prosody; -local assert = assert; -local require = require; -local st = require "util.stanza"; -local DBI = require "DBI"; - --- connect to db -local params = module:get_option("sql_ejabberd") or error("No sql_ejabberd config option"); -local database; -do - module:log("debug", "Opening database: %s", "dbi:"..params.driver..":"..params.database); - prosody.unlock_globals(); - local dbh, err = DBI.Connect( - params.driver, params.database, - params.username, params.password, - params.host, params.port - ); - prosody.lock_globals(); - assert(dbh, err); - dbh:autocommit(true); - database = dbh; -end - --- initialize db -local ejabberd_init = module:require("ejabberd_init"); -ejabberd_init.init(database); - -local sqlcache = {}; -local function prepare(sql) - module:log("debug", "query: %s", sql); - local err; - local r = sqlcache[sql]; - if not r then - r, err = database:prepare(sql); - if not r then error("Unable to prepare SQL statement: "..err); end - sqlcache[sql] = r; - end - return r; -end - -local _parse_xml = module:require("xmlparse"); -local function parse_xml(str) - local s = _parse_xml(str); - if s and not s.gsub then - return st.preserialize(s); - end -end -local function unparse_xml(s) - return tostring(st.deserialize(s)); -end - - -local handlers = {}; - -handlers.accounts = { - get = function(self, user) - local select = self:query("select password from users where username=? and host=?", user, self.host); - local row = select and select:fetch(); - if row then return { password = row[1] }; end - end; - set = function(self, user, data) - if data and data.password then - return self:modify("update users set password=? where username=? and host=?", data.password, user, self.host) - or self:modify("insert into users (username, host, password) values (?, ?, ?)", user, self.host, data.password); - else - return self:modify("delete from users where username=? and host=?", user, self.host); - end - end; -}; -handlers.vcard = { - get = function(self, user) - local select = self:query("select vcard from vcard where username=? and host=?", user, self.host); - local row = select and select:fetch(); - if row then return parse_xml(row[1]); end - end; - set = function(self, user, data) - if data then - data = unparse_xml(data); - return self:modify("update vcard set vcard=? where username=? and host=?", data, user, self.host) - or self:modify("insert into vcard (username, host, vcard) values (?, ?, ?)", user, self.host, data); - else - return self:modify("delete from vcard where username=? and host=?", user, self.host); - end - end; -}; -handlers.private = { - get = function(self, user) - local select = self:query("select namespace,data from private_storage where username=? and host=?", user, self.host); - if select then - local data = {}; - for row in select:rows() do - data[row[1]] = parse_xml(row[2]); - end - return data; - end - end; - set = function(self, user, data) - if data then - self:modify("delete from private_storage where username=? and host=?", user, self.host); - for namespace,text in pairs(data) do - self:modify("insert into private_storage (username, host, namespace, data) values (?, ?, ?, ?)", user, self.host, namespace, unparse_xml(text)); - end - return true; - else - return self:modify("delete from private_storage where username=? and host=?", user, self.host); - end - end; - -- TODO map_set, map_get -}; -local subscription_map = { N = "none", B = "both", F = "from", T = "to" }; -local subscription_map_reverse = { none = "N", both = "B", from = "F", to = "T" }; -handlers.roster = { - get = function(self, user) - local select = self:query("select jid,nick,subscription,ask,server,subscribe,type from rosterusers where username=?", user); - if select then - local roster = { pending = {} }; - for row in select:rows() do - local jid,nick,subscription,ask,server,subscribe,typ = unpack(row); - local item = { groups = {} }; - if nick == "" then nick = nil; end - item.nick = nick; - item.subscription = subscription_map[subscription]; - if ask == "N" then ask = nil; - elseif ask == "O" then ask = "subscribe" - elseif ask == "I" then roster.pending[jid] = true; ask = nil; - elseif ask == "B" then roster.pending[jid] = true; ask = "subscribe"; - else module:log("debug", "bad roster_item.ask: %s", ask); ask = nil; end - item.ask = ask; - roster[jid] = item; - end - - select = self:query("select jid,grp from rostergroups where username=?", user); - if select then - for row in select:rows() do - local jid,grp = unpack(row); - if roster[jid] then roster[jid].groups[grp] = true; end - end - end - select = self:query("select version from roster_version where username=?", user); - local row = select and select:fetch(); - if row then - roster[false] = { version = row[1]; }; - end - return roster; - end - end; - set = function(self, user, data) - if data and next(data) ~= nil then - self:modify("delete from rosterusers where username=?", user); - self:modify("delete from rostergroups where username=?", user); - self:modify("delete from roster_version where username=?", user); - local done = {}; - local pending = data.pending or {}; - for jid,item in pairs(data) do - if jid and jid ~= "pending" then - local subscription = subscription_map_reverse[item.subscription]; - local ask; - if pending[jid] then - if item.ask then ask = "B"; else ask = "I"; end - else - if item.ask then ask = "O"; else ask = "N"; end - end - local r = self:modify("insert into rosterusers (username,jid,nick,subscription,ask,askmessage,server,subscribe) values (?, ?, ?, ?, ?, '', '', '')", user, jid, item.nick or "", subscription, ask); - if not r then module:log("debug", "--- :( %s", tostring(r)); end - done[jid] = true; - for group in pairs(item.groups) do - self:modify("insert into rostergroups (username,jid,grp) values (?, ?, ?)", user, jid, group); - end - end - end - for jid in pairs(pending) do - if not done[jid] then - self:modify("insert into rosterusers (username,jid,nick,subscription,ask,askmessage,server,subscribe) values (?, ?, ?, ?, ?. ''. ''. '')", user, jid, "", "N", "I"); - end - end - local version = data[false] and data[false].version; - if version then - self:modify("insert into roster_version (username,version) values (?, ?)", user, version); - end - return true; - else - self:modify("delete from rosterusers where username=?", user); - self:modify("delete from rostergroups where username=?", user); - self:modify("delete from roster_version where username=?", user); - end - end; -}; - ------------------------------ -local driver = {}; -driver.__index = driver; - -function driver:query(sql, ...) - local stmt,err = prepare(sql); - if not stmt then - module:log("error", "Failed to prepare SQL [[%s]], error: %s", sql, err); - return nil, err; - end - local ok, err = stmt:execute(...); - if not ok then - module:log("error", "Failed to execute SQL [[%s]], error: %s", sql, err); - return nil, err; - end - return stmt; -end -function driver:modify(sql, ...) - local stmt, err = self:query(sql, ...); - if stmt and stmt:affected() > 0 then return stmt; end - return nil, err; -end - -function driver:open(datastore, typ) - local instance = setmetatable({ host = module.host, datastore = datastore }, self); - local handler = handlers[datastore]; - if not handler then return nil; end - for key,val in pairs(handler) do - instance[key] = val; - end - if instance.init then instance:init(); end - return instance; -end - ------------------------------ - -module:add_item("data-driver", driver); diff --git a/plugins/mod_time.lua b/plugins/mod_time.lua index cb69ebe7..ae7da916 100644 --- a/plugins/mod_time.lua +++ b/plugins/mod_time.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- diff --git a/plugins/mod_tls.lua b/plugins/mod_tls.lua index c79227e1..bab2202e 100644 --- a/plugins/mod_tls.lua +++ b/plugins/mod_tls.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -23,19 +23,47 @@ local s2s_feature = st.stanza("starttls", starttls_attr); if secure_auth_only then c2s_feature:tag("required"):up(); end if secure_s2s_only then s2s_feature:tag("required"):up(); end -local global_ssl_ctx = prosody.global_ssl_ctx; - +local hosts = prosody.hosts; local host = hosts[module.host]; +local ssl_ctx_c2s, ssl_ctx_s2sout, ssl_ctx_s2sin; +do + local function get_ssl_cfg(typ) + local cfg_key = (typ and typ.."_" or "").."ssl"; + local ssl_config = config.rawget(module.host, cfg_key); + if not ssl_config then + local base_host = module.host:match("%.(.*)"); + ssl_config = config.get(base_host, cfg_key); + end + return ssl_config or typ and get_ssl_cfg(); + end + + local ssl_config, err = get_ssl_cfg("c2s"); + ssl_ctx_c2s, err = create_context(host.host, "server", ssl_config); -- for incoming client connections + if err then module:log("error", "Error creating context for c2s: %s", err); end + + ssl_config = get_ssl_cfg("s2s"); + ssl_ctx_s2sin, err = create_context(host.host, "server", ssl_config); -- for incoming server connections + ssl_ctx_s2sout = create_context(host.host, "client", ssl_config); -- for outgoing server connections + if err then module:log("error", "Error creating context for s2s: %s", err); end -- Both would have the same issue +end + local function can_do_tls(session) + if not session.conn.starttls then + return false; + elseif session.ssl_ctx then + return true; + end if session.type == "c2s_unauthed" then - return session.conn.starttls and host.ssl_ctx_in; + session.ssl_ctx = ssl_ctx_c2s; elseif session.type == "s2sin_unauthed" and allow_s2s_tls then - return session.conn.starttls and host.ssl_ctx_in; + session.ssl_ctx = ssl_ctx_s2sin; elseif session.direction == "outgoing" and allow_s2s_tls then - return session.conn.starttls and host.ssl_ctx; + session.ssl_ctx = ssl_ctx_s2sout; + else + return false; end - return false; + return session.ssl_ctx; end -- Hook <starttls/> @@ -44,10 +72,8 @@ module:hook("stanza/urn:ietf:params:xml:ns:xmpp-tls:starttls", function(event) if can_do_tls(origin) then (origin.sends2s or origin.send)(starttls_proceed); origin:reset_stream(); - local host = origin.to_host or origin.host; - local ssl_ctx = host and hosts[host].ssl_ctx_in or global_ssl_ctx; - origin.conn:starttls(ssl_ctx); - origin.log("info", "TLS negotiation started for %s...", origin.type); + origin.conn:starttls(origin.ssl_ctx); + origin.log("debug", "TLS negotiation started for %s...", origin.type); origin.secure = false; else origin.log("warn", "Attempt to start TLS, but TLS is not available on this %s connection", origin.type); @@ -75,7 +101,7 @@ end); module:hook_stanza("http://etherx.jabber.org/streams", "features", function (session, stanza) module:log("debug", "Received features element"); if can_do_tls(session) and stanza:child_with_ns(xmlns_starttls) then - module:log("%s is offering TLS, taking up the offer...", session.to_host); + module:log("debug", "%s is offering TLS, taking up the offer...", session.to_host); session.sends2s("<starttls xmlns='"..xmlns_starttls.."'/>"); return true; end @@ -84,24 +110,7 @@ end, 500); module:hook_stanza(xmlns_starttls, "proceed", function (session, stanza) module:log("debug", "Proceeding with TLS on s2sout..."); session:reset_stream(); - local ssl_ctx = session.from_host and hosts[session.from_host].ssl_ctx or global_ssl_ctx; - session.conn:starttls(ssl_ctx); + session.conn:starttls(session.ssl_ctx); session.secure = false; return true; end); - -function module.load() - local global_ssl_config = config.get("*", "core", "ssl"); - local ssl_config = config.get(module.host, "core", "ssl"); - local base_host = module.host:match("%.(.*)"); - if ssl_config == global_ssl_config and hosts[base_host] then - ssl_config = config.get(base_host, "core", "ssl"); - end - host.ssl_ctx = create_context(host.host, "client", ssl_config); -- for outgoing connections - host.ssl_ctx_in = create_context(host.host, "server", ssl_config); -- for incoming connections -end - -function module.unload() - host.ssl_ctx = nil; - host.ssl_ctx_in = nil; -end diff --git a/plugins/mod_uptime.lua b/plugins/mod_uptime.lua index 52b33c74..2e369b16 100644 --- a/plugins/mod_uptime.lua +++ b/plugins/mod_uptime.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -9,7 +9,7 @@ local st = require "util.stanza"; local start_time = prosody.start_time; -prosody.events.add_handler("server-started", function() start_time = prosody.start_time end); +module:hook_global("server-started", function() start_time = prosody.start_time end); -- XEP-0012: Last activity module:add_feature("jabber:iq:last"); diff --git a/plugins/mod_vcard.lua b/plugins/mod_vcard.lua index e2f1dfb8..72f92ef7 100644 --- a/plugins/mod_vcard.lua +++ b/plugins/mod_vcard.lua @@ -1,14 +1,15 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- local st = require "util.stanza" local jid_split = require "util.jid".split; -local datamanager = require "util.datamanager" + +local vcards = module:open_store(); module:add_feature("vcard-temp"); @@ -19,9 +20,9 @@ local function handle_vcard(event) local vCard; if to then local node, host = jid_split(to); - vCard = st.deserialize(datamanager.load(node, host, "vcard")); -- load vCard for user or server + vCard = st.deserialize(vcards:get(node)); -- load vCard for user or server else - vCard = st.deserialize(datamanager.load(session.username, session.host, "vcard"));-- load user's own vCard + vCard = st.deserialize(vcards:get(session.username));-- load user's own vCard end if vCard then session.send(st.reply(stanza):add_child(vCard)); -- send vCard! @@ -30,7 +31,7 @@ local function handle_vcard(event) end else if not to then - if datamanager.store(session.username, session.host, "vcard", st.preserialize(stanza.tags[1])) then + if vcards:set(session.username, st.preserialize(stanza.tags[1])) then session.send(st.reply(stanza)); else -- TODO unable to write file, file may be locked, etc, what's the correct error? @@ -46,13 +47,8 @@ end module:hook("iq/bare/vcard-temp:vCard", handle_vcard); module:hook("iq/host/vcard-temp:vCard", handle_vcard); --- COMPAT: https://support.process-one.net/browse/EJAB-1045 -if module:get_option("vcard_compatibility") then - module:hook("iq/full", function(data) - local stanza = data.stanza; - local payload = stanza.tags[1]; - if stanza.attr.type == "get" and payload.name == "vCard" and payload.attr.xmlns == "vcard-temp" then - return handle_vcard(data); - end - end, 1); +-- COMPAT w/0.8 +if module:get_option("vcard_compatibility") ~= nil then + module:log("error", "The vcard_compatibility option has been removed, see".. + "mod_compat_vcard in prosody-modules if you still need this."); end diff --git a/plugins/mod_version.lua b/plugins/mod_version.lua index 52d8d290..be244beb 100644 --- a/plugins/mod_version.lua +++ b/plugins/mod_version.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -21,7 +21,7 @@ if not module:get_option("hide_os_type") then version = "Windows"; else local os_version_command = module:get_option("os_version_command"); - local ok pposix = pcall(require, "pposix"); + local ok, pposix = pcall(require, "util.pposix"); if not os_version_command and (ok and pposix and pposix.uname) then version = pposix.uname().sysname; end diff --git a/plugins/mod_watchregistrations.lua b/plugins/mod_watchregistrations.lua index ac1e6302..b7be5daf 100644 --- a/plugins/mod_watchregistrations.lua +++ b/plugins/mod_watchregistrations.lua @@ -1,32 +1,30 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- local host = module:get_host(); +local jid_prep = require "util.jid".prep; -local registration_watchers = module:get_option("registration_watchers") - or module:get_option("admins") or {}; - -local registration_alert = module:get_option("registration_notification") or "User $username just registered on $host from $ip"; +local registration_watchers = module:get_option_set("registration_watchers", module:get_option("admins", {})) / jid_prep; +local registration_notification = module:get_option("registration_notification", "User $username just registered on $host from $ip"); local st = require "util.stanza"; -module:hook("user-registered", - function (user) - module:log("debug", "Notifying of new registration"); - local message = st.message{ type = "chat", from = host } - :tag("body") - :text(registration_alert:gsub("%$(%w+)", - function (v) return user[v] or user.session and user.session[v] or nil; end)); - - for _, jid in ipairs(registration_watchers) do - module:log("debug", "Notifying %s", jid); - message.attr.to = jid; - core_route_stanza(hosts[host], message); - end - end); +module:hook("user-registered", function (user) + module:log("debug", "Notifying of new registration"); + local message = st.message{ type = "chat", from = host } + :tag("body") + :text(registration_notification:gsub("%$(%w+)", function (v) + return user[v] or user.session and user.session[v] or nil; + end)); + for jid in registration_watchers do + module:log("debug", "Notifying %s", jid); + message.attr.to = jid; + module:send(message); + end +end); diff --git a/plugins/mod_welcome.lua b/plugins/mod_welcome.lua index 8f9cca2a..9c0c821b 100644 --- a/plugins/mod_welcome.lua +++ b/plugins/mod_welcome.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -16,6 +16,6 @@ module:hook("user-registered", local welcome_stanza = st.message({ to = user.username.."@"..user.host, from = host }) :tag("body"):text(welcome_text:gsub("$(%w+)", user)); - core_route_stanza(hosts[host], welcome_stanza); + module:send(welcome_stanza); module:log("debug", "Welcomed user %s@%s", user.username, user.host); end); diff --git a/plugins/muc/mod_muc.lua b/plugins/muc/mod_muc.lua index 329b9270..cb967c90 100644 --- a/plugins/muc/mod_muc.lua +++ b/plugins/muc/mod_muc.lua @@ -1,11 +1,12 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- +local array = require "util.array"; if module:get_host_type() ~= "component" then error("MUC should be loaded as a component, please see http://prosody.im/doc/components", 0); @@ -16,33 +17,53 @@ local muc_name = module:get_option("name"); if type(muc_name) ~= "string" then muc_name = "Prosody Chatrooms"; end local restrict_room_creation = module:get_option("restrict_room_creation"); if restrict_room_creation then - if restrict_room_creation == true then + if restrict_room_creation == true then restrict_room_creation = "admin"; elseif restrict_room_creation ~= "admin" and restrict_room_creation ~= "local" then restrict_room_creation = nil; end end -local muc_new_room = module:require "muc".new_room; +local lock_rooms = module:get_option_boolean("muc_room_locking", false); +local lock_room_timeout = module:get_option_number("muc_room_lock_timeout", 300); + +local muclib = module:require "muc"; +local muc_new_room = muclib.new_room; local jid_split = require "util.jid".split; local jid_bare = require "util.jid".bare; local st = require "util.stanza"; local uuid_gen = require "util.uuid".generate; -local datamanager = require "util.datamanager"; local um_is_admin = require "core.usermanager".is_admin; +local hosts = prosody.hosts; rooms = {}; local rooms = rooms; -local persistent_rooms = datamanager.load(nil, muc_host, "persistent") or {}; -local component = hosts[module.host]; +local persistent_rooms_storage = module:open_store("persistent"); +local persistent_rooms = persistent_rooms_storage:get() or {}; +local room_configs = module:open_store("config"); -- Configurable options -local max_history_messages = module:get_option_number("max_history_messages"); +muclib.set_max_history_length(module:get_option_number("max_history_messages")); + +module:depends("disco"); +module:add_identity("conference", "text", muc_name); +module:add_feature("http://jabber.org/protocol/muc"); local function is_admin(jid) return um_is_admin(jid, module.host); end -local function room_route_stanza(room, stanza) core_post_stanza(component, stanza); end +local _set_affiliation = muc_new_room.room_mt.set_affiliation; +local _get_affiliation = muc_new_room.room_mt.get_affiliation; +function muclib.room_mt:get_affiliation(jid) + if is_admin(jid) then return "owner"; end + return _get_affiliation(self, jid); +end +function muclib.room_mt:set_affiliation(actor, jid, affiliation, callback, reason) + if is_admin(jid) then return nil, "modify", "not-acceptable"; end + return _set_affiliation(self, actor, jid, affiliation, callback, reason); +end + +local function room_route_stanza(room, stanza) module:send(stanza); end local function room_save(room, forced) local node = jid_split(room.jid); persistent_rooms[room.jid] = room._data.persistent; @@ -54,59 +75,74 @@ local function room_save(room, forced) _data = room._data; _affiliations = room._affiliations; }; - datamanager.store(node, muc_host, "config", data); + room_configs:set(node, data); room._data.history = history; elseif forced then - datamanager.store(node, muc_host, "config", nil); + room_configs:set(node, nil); + if not next(room._occupants) then -- Room empty + rooms[room.jid] = nil; + end end - if forced then datamanager.store(nil, muc_host, "persistent", persistent_rooms); end + if forced then persistent_rooms_storage:set(nil, persistent_rooms); end end -for jid in pairs(persistent_rooms) do - local node = jid_split(jid); - local data = datamanager.load(node, muc_host, "config") or {}; - local room = muc_new_room(jid, { - history_length = max_history_messages; - }); - room._data = data._data; - room._data.history_length = max_history_messages; --TODO: Need to allow per-room with a global limit - room._affiliations = data._affiliations; +function create_room(jid) + local room = muc_new_room(jid); room.route_stanza = room_route_stanza; room.save = room_save; rooms[jid] = room; + if lock_rooms then + room.locked = true; + if lock_room_timeout and lock_room_timeout > 0 then + module:add_timer(lock_room_timeout, function () + if room.locked then + room:destroy(); -- Not unlocked in time + end + end); + end + end + module:fire_event("muc-room-created", { room = room }); + return room; +end + +local persistent_errors = false; +for jid in pairs(persistent_rooms) do + local node = jid_split(jid); + local data = room_configs:get(node); + if data then + local room = create_room(jid); + room._data = data._data; + room._affiliations = data._affiliations; + else -- missing room data + persistent_rooms[jid] = nil; + module:log("error", "Missing data for room '%s', removing from persistent room list", jid); + persistent_errors = true; + end end +if persistent_errors then persistent_rooms_storage:set(nil, persistent_rooms); end -local host_room = muc_new_room(muc_host, { - history_length = max_history_messages; -}); +local host_room = muc_new_room(muc_host); host_room.route_stanza = room_route_stanza; host_room.save = room_save; -local function get_disco_info(stanza) - return st.iq({type='result', id=stanza.attr.id, from=muc_host, to=stanza.attr.from}):query("http://jabber.org/protocol/disco#info") - :tag("identity", {category='conference', type='text', name=muc_name}):up() - :tag("feature", {var="http://jabber.org/protocol/muc"}); -- TODO cache disco reply -end -local function get_disco_items(stanza) - local reply = st.iq({type='result', id=stanza.attr.id, from=muc_host, to=stanza.attr.from}):query("http://jabber.org/protocol/disco#items"); +module:hook("host-disco-items", function(event) + local reply = event.reply; + module:log("debug", "host-disco-items called"); for jid, room in pairs(rooms) do - if not room:is_hidden() then + if not room:get_hidden() then reply:tag("item", {jid=jid, name=room:get_name()}):up(); end end - return reply; -- TODO cache disco reply -end +end); -local function handle_to_domain(origin, stanza) +local function handle_to_domain(event) + local origin, stanza = event.origin, event.stanza; local type = stanza.attr.type; if type == "error" or type == "result" then return; end if stanza.name == "iq" and type == "get" then local xmlns = stanza.tags[1].attr.xmlns; - if xmlns == "http://jabber.org/protocol/disco#info" then - origin.send(get_disco_info(stanza)); - elseif xmlns == "http://jabber.org/protocol/disco#items" then - origin.send(get_disco_items(stanza)); - elseif xmlns == "http://jabber.org/protocol/muc#unique" then + local node = stanza.tags[1].attr.node; + if xmlns == "http://jabber.org/protocol/muc#unique" then origin.send(st.reply(stanza):tag("unique", {xmlns = xmlns}):text(uuid_gen())); -- FIXME Random UUIDs can theoretically have collisions else origin.send(st.error_reply(stanza, "cancel", "service-unavailable")); -- TODO disco/etc @@ -115,40 +151,32 @@ local function handle_to_domain(origin, stanza) host_room:handle_stanza(origin, stanza); --origin.send(st.error_reply(stanza, "cancel", "service-unavailable", "The muc server doesn't deal with messages and presence directed at it")); end + return true; end function stanza_handler(event) local origin, stanza = event.origin, event.stanza; - local to_node, to_host, to_resource = jid_split(stanza.attr.to); - if to_node then - local bare = to_node.."@"..to_host; - if to_host == muc_host or bare == muc_host then - local room = rooms[bare]; - if not room then - if not(restrict_room_creation) or - (restrict_room_creation == "admin" and is_admin(stanza.attr.from)) or - (restrict_room_creation == "local" and select(2, jid_split(stanza.attr.from)) == module.host:gsub("^[^%.]+%.", "")) then - room = muc_new_room(bare, { - history_length = max_history_messages; - }); - room.route_stanza = room_route_stanza; - room.save = room_save; - rooms[bare] = room; - end - end - if room then - room:handle_stanza(origin, stanza); - if not next(room._occupants) and not persistent_rooms[room.jid] then -- empty, non-persistent room - rooms[bare] = nil; -- discard room - end - else - origin.send(st.error_reply(stanza, "cancel", "not-allowed")); - end - else --[[not for us?]] end - return true; + local bare = jid_bare(stanza.attr.to); + local room = rooms[bare]; + if not room then + if stanza.name ~= "presence" then + origin.send(st.error_reply(stanza, "cancel", "item-not-found")); + return true; + end + if not(restrict_room_creation) or + (restrict_room_creation == "admin" and is_admin(stanza.attr.from)) or + (restrict_room_creation == "local" and select(2, jid_split(stanza.attr.from)) == module.host:gsub("^[^%.]+%.", "")) then + room = create_room(bare); + end + end + if room then + room:handle_stanza(origin, stanza); + if not next(room._occupants) and not persistent_rooms[room.jid] then -- empty, non-persistent room + rooms[bare] = nil; -- discard room + end + else + origin.send(st.error_reply(stanza, "cancel", "not-allowed")); end - -- to the main muc domain - handle_to_domain(origin, stanza); return true; end module:hook("iq/bare", stanza_handler, -1); @@ -157,31 +185,92 @@ module:hook("presence/bare", stanza_handler, -1); module:hook("iq/full", stanza_handler, -1); module:hook("message/full", stanza_handler, -1); module:hook("presence/full", stanza_handler, -1); -module:hook("iq/host", stanza_handler, -1); -module:hook("message/host", stanza_handler, -1); -module:hook("presence/host", stanza_handler, -1); +module:hook("iq/host", handle_to_domain, -1); +module:hook("message/host", handle_to_domain, -1); +module:hook("presence/host", handle_to_domain, -1); hosts[module.host].send = function(stanza) -- FIXME do a generic fix if stanza.attr.type == "result" or stanza.attr.type == "error" then - core_post_stanza(component, stanza); + module:send(stanza); else error("component.send only supports result and error stanzas at the moment"); end end -prosody.hosts[module:get_host()].muc = { rooms = rooms }; +hosts[module:get_host()].muc = { rooms = rooms }; +local saved = false; module.save = function() + saved = true; return {rooms = rooms}; end module.restore = function(data) for jid, oldroom in pairs(data.rooms or {}) do - local room = muc_new_room(jid); + local room = create_room(jid); room._jid_nick = oldroom._jid_nick; room._occupants = oldroom._occupants; room._data = oldroom._data; room._affiliations = oldroom._affiliations; - room.route_stanza = room_route_stanza; - room.save = room_save; - rooms[jid] = room; end - prosody.hosts[module:get_host()].muc = { rooms = rooms }; + hosts[module:get_host()].muc = { rooms = rooms }; end + +function shutdown_room(room, stanza) + for nick, occupant in pairs(room._occupants) do + stanza.attr.from = nick; + for jid in pairs(occupant.sessions) do + stanza.attr.to = jid; + room:_route_stanza(stanza); + room._jid_nick[jid] = nil; + end + room._occupants[nick] = nil; + end +end +function shutdown_component() + if not saved then + local stanza = st.presence({type = "unavailable"}) + :tag("x", {xmlns = "http://jabber.org/protocol/muc#user"}) + :tag("item", { affiliation='none', role='none' }):up() + :tag("status", { code = "332"}):up(); + for roomjid, room in pairs(rooms) do + shutdown_room(room, stanza); + end + shutdown_room(host_room, stanza); + end +end +module.unload = shutdown_component; +module:hook_global("server-stopping", shutdown_component); + +-- Ad-hoc commands +module:depends("adhoc") +local t_concat = table.concat; +local keys = require "util.iterators".keys; +local adhoc_new = module:require "adhoc".new; +local adhoc_initial = require "util.adhoc".new_initial_data_form; +local dataforms_new = require "util.dataforms".new; + +local destroy_rooms_layout = dataforms_new { + title = "Destroy rooms"; + instructions = "Select the rooms to destroy"; + + { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/muc#destroy" }; + { name = "rooms", type = "list-multi", required = true, label = "Rooms to destroy:"}; +}; + +local destroy_rooms_handler = adhoc_initial(destroy_rooms_layout, function() + return { rooms = array.collect(keys(rooms)):sort() }; +end, function(fields, errors) + if errors then + local errmsg = {}; + for name, err in pairs(errors) do + errmsg[#errmsg + 1] = name .. ": " .. err; + end + return { status = "completed", error = { message = t_concat(errmsg, "\n") } }; + end + for _, room in ipairs(fields.rooms) do + rooms[room]:destroy(); + rooms[room] = nil; + end + return { status = "completed", info = "The following rooms were destroyed:\n"..t_concat(fields.rooms, "\n") }; +end); +local destroy_rooms_desc = adhoc_new("Destroy Rooms", "http://prosody.im/protocol/muc#destroy", destroy_rooms_handler, "admin"); + +module:provides("adhoc", destroy_rooms_desc); diff --git a/plugins/muc/muc.lib.lua b/plugins/muc/muc.lib.lua index f3e2dd52..0565d692 100644 --- a/plugins/muc/muc.lib.lua +++ b/plugins/muc/muc.lib.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -9,7 +9,6 @@ local select = select; local pairs, ipairs = pairs, ipairs; -local datamanager = require "util.datamanager"; local datetime = require "util.datetime"; local dataform = require "util.dataforms"; @@ -19,38 +18,25 @@ local jid_bare = require "util.jid".bare; local jid_prep = require "util.jid".prep; local st = require "util.stanza"; local log = require "util.logger".init("mod_muc"); -local multitable_new = require "util.multitable".new; local t_insert, t_remove = table.insert, table.remove; local setmetatable = setmetatable; local base64 = require "util.encodings".base64; local md5 = require "util.hashes".md5; local muc_domain = nil; --module:get_host(); -local default_history_length = 20; +local default_history_length, max_history_length = 20, math.huge; ------------ -local function filter_xmlns_from_array(array, filters) - local count = 0; - for i=#array,1,-1 do - local attr = array[i].attr; - if filters[attr and attr.xmlns] then - t_remove(array, i); - count = count + 1; - end - end - return count; -end -local function filter_xmlns_from_stanza(stanza, filters) - if filters then - if filter_xmlns_from_array(stanza.tags, filters) ~= 0 then - return stanza, filter_xmlns_from_array(stanza, filters); - end +local presence_filters = {["http://jabber.org/protocol/muc"]=true;["http://jabber.org/protocol/muc#user"]=true}; +local function presence_filter(tag) + if presence_filters[tag.attr.xmlns] then + return nil; end - return stanza, 0; + return tag; end -local presence_filters = {["http://jabber.org/protocol/muc"]=true;["http://jabber.org/protocol/muc#user"]=true}; + local function get_filtered_presence(stanza) - return filter_xmlns_from_stanza(st.clone(stanza):reset(), presence_filters); + return st.clone(stanza):maptags(presence_filter); end local kickable_error_conditions = { ["gone"] = true; @@ -74,30 +60,23 @@ local function is_kickable_error(stanza) local cond = get_error_condition(stanza); return kickable_error_conditions[cond] and cond; end -local function getUsingPath(stanza, path, getText) - local tag = stanza; - for _, name in ipairs(path) do - if type(tag) ~= 'table' then return; end - tag = tag:child_with_name(name); - end - if tag and getText then tag = table.concat(tag); end - return tag; -end -local function getTag(stanza, path) return getUsingPath(stanza, path); end -local function getText(stanza, path) return getUsingPath(stanza, path, true); end ----------- local room_mt = {}; room_mt.__index = room_mt; +function room_mt:__tostring() + return "MUC room ("..self.jid..")"; +end + function room_mt:get_default_role(affiliation) if affiliation == "owner" or affiliation == "admin" then return "moderator"; elseif affiliation == "member" then return "participant"; elseif not affiliation then - if not self:is_members_only() then - return self:is_moderated() and "visitor" or "participant"; + if not self:get_members_only() then + return self:get_moderated() and "visitor" or "participant"; end end end @@ -133,7 +112,6 @@ function room_mt:broadcast_message(stanza, historic) stanza = st.clone(stanza); stanza.attr.to = ""; local stamp = datetime.datetime(); - local chars = #tostring(stanza); stanza:tag("delay", {xmlns = "urn:xmpp:delay", from = muc_domain, stamp = stamp}):up(); -- XEP-0203 stanza:tag("x", {xmlns = "jabber:x:delay", from = muc_domain, stamp = datetime.legacy()}):up(); -- XEP-0091 (deprecated) local entry = { stanza = stanza, stamp = stamp }; @@ -169,10 +147,10 @@ function room_mt:send_history(to, stanza) if history then local x_tag = stanza and stanza:get_child("x", "http://jabber.org/protocol/muc"); local history_tag = x_tag and x_tag:get_child("history", "http://jabber.org/protocol/muc"); - + local maxchars = history_tag and tonumber(history_tag.attr.maxchars); if maxchars then maxchars = math.floor(maxchars); end - + local maxstanzas = math.floor(history_tag and tonumber(history_tag.attr.maxstanzas) or #history); if not history_tag then maxstanzas = 20; end @@ -185,8 +163,7 @@ function room_mt:send_history(to, stanza) local n = 0; local charcount = 0; - local stanzacount = 0; - + for i=#history,1,-1 do local entry = history[i]; if maxchars then @@ -213,18 +190,20 @@ function room_mt:send_history(to, stanza) end function room_mt:get_disco_info(stanza) + local count = 0; for _ in pairs(self._occupants) do count = count + 1; end return st.reply(stanza):query("http://jabber.org/protocol/disco#info") :tag("identity", {category="conference", type="text", name=self:get_name()}):up() :tag("feature", {var="http://jabber.org/protocol/muc"}):up() :tag("feature", {var=self:get_password() and "muc_passwordprotected" or "muc_unsecured"}):up() - :tag("feature", {var=self:is_moderated() and "muc_moderated" or "muc_unmoderated"}):up() - :tag("feature", {var=self:is_members_only() and "muc_membersonly" or "muc_open"}):up() - :tag("feature", {var=self:is_persistent() and "muc_persistent" or "muc_temporary"}):up() - :tag("feature", {var=self:is_hidden() and "muc_hidden" or "muc_public"}):up() + :tag("feature", {var=self:get_moderated() and "muc_moderated" or "muc_unmoderated"}):up() + :tag("feature", {var=self:get_members_only() and "muc_membersonly" or "muc_open"}):up() + :tag("feature", {var=self:get_persistent() and "muc_persistent" or "muc_temporary"}):up() + :tag("feature", {var=self:get_hidden() and "muc_hidden" or "muc_public"}):up() :tag("feature", {var=self._data.whois ~= "anyone" and "muc_semianonymous" or "muc_nonanonymous"}):up() :add_child(dataform.new({ { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/muc#roominfo" }, - { name = "muc#roominfo_description", label = "Description"} + { name = "muc#roominfo_description", label = "Description"}, + { name = "muc#roominfo_occupants", label = "Number of occupants", value = tostring(count) } }):form({["muc#roominfo_description"] = self:get_description()}, 'result')) ; end @@ -236,7 +215,6 @@ function room_mt:get_disco_items(stanza) return reply; end function room_mt:set_subject(current_nick, subject) - -- TODO check nick's authority if subject == "" then subject = nil; end self._data['subject'] = subject; self._data['subject_from'] = current_nick; @@ -294,7 +272,7 @@ function room_mt:set_moderated(moderated) if self.save then self:save(true); end end end -function room_mt:is_moderated() +function room_mt:get_moderated() return self._data.moderated; end function room_mt:set_members_only(members_only) @@ -304,7 +282,7 @@ function room_mt:set_members_only(members_only) if self.save then self:save(true); end end end -function room_mt:is_members_only() +function room_mt:get_members_only() return self._data.members_only; end function room_mt:set_persistent(persistent) @@ -314,7 +292,7 @@ function room_mt:set_persistent(persistent) if self.save then self:save(true); end end end -function room_mt:is_persistent() +function room_mt:get_persistent() return self._data.persistent; end function room_mt:set_hidden(hidden) @@ -324,9 +302,74 @@ function room_mt:set_hidden(hidden) if self.save then self:save(true); end end end -function room_mt:is_hidden() +function room_mt:get_hidden() return self._data.hidden; end +function room_mt:get_public() + return not self:get_hidden(); +end +function room_mt:set_public(public) + return self:set_hidden(not public); +end +function room_mt:set_changesubject(changesubject) + changesubject = changesubject and true or nil; + if self._data.changesubject ~= changesubject then + self._data.changesubject = changesubject; + if self.save then self:save(true); end + end +end +function room_mt:get_changesubject() + return self._data.changesubject; +end +function room_mt:get_historylength() + return self._data.history_length or default_history_length; +end +function room_mt:set_historylength(length) + length = math.min(tonumber(length) or default_history_length, max_history_length or math.huge); + if length == default_history_length then + length = nil; + end + self._data.history_length = length; +end + + +local valid_whois = { moderators = true, anyone = true }; + +function room_mt:set_whois(whois) + if valid_whois[whois] and self._data.whois ~= whois then + self._data.whois = whois; + if self.save then self:save(true); end + end +end + +function room_mt:get_whois() + return self._data.whois; +end + +local function construct_stanza_id(room, stanza) + local from_jid, to_nick = stanza.attr.from, stanza.attr.to; + local from_nick = room._jid_nick[from_jid]; + local occupant = room._occupants[to_nick]; + local to_jid = occupant.jid; + + return from_nick, to_jid, base64.encode(to_jid.."\0"..stanza.attr.id.."\0"..md5(from_jid)); +end +local function deconstruct_stanza_id(room, stanza) + local from_jid_possiblybare, to_nick = stanza.attr.from, stanza.attr.to; + local from_jid, id, to_jid_hash = (base64.decode(stanza.attr.id) or ""):match("^(.+)%z(.*)%z(.+)$"); + local from_nick = room._jid_nick[from_jid]; + + if not(from_nick) then return; end + if not(from_jid_possiblybare == from_jid or from_jid_possiblybare == jid_bare(from_jid)) then return; end + + local occupant = room._occupants[to_nick]; + for to_jid in pairs(occupant and occupant.sessions or {}) do + if md5(to_jid) == to_jid_hash then + return from_nick, to_jid, id; + end + end +end + function room_mt:handle_to_occupant(origin, stanza) -- PM, vCards, etc local from, to = stanza.attr.from, stanza.attr.to; @@ -346,6 +389,7 @@ function room_mt:handle_to_occupant(origin, stanza) -- PM, vCards, etc elseif type == "unavailable" then -- unavailable if current_nick then log("debug", "%s leaving %s", current_nick, room); + self._jid_nick[from] = nil; local occupant = self._occupants[current_nick]; local new_jid = next(occupant.sessions); if new_jid == from then new_jid = next(occupant.sessions, new_jid); end @@ -356,7 +400,7 @@ function room_mt:handle_to_occupant(origin, stanza) -- PM, vCards, etc pr.attr.to = from; pr:tag("x", {xmlns='http://jabber.org/protocol/muc#user'}) :tag("item", {affiliation=occupant.affiliation or "none", role='none'}):up() - :tag("status", {code='110'}); + :tag("status", {code='110'}):up(); self:_route_stanza(pr); if jid ~= new_jid then pr = st.clone(occupant.sessions[new_jid]) @@ -370,7 +414,6 @@ function room_mt:handle_to_occupant(origin, stanza) -- PM, vCards, etc self:broadcast_presence(pr, from); self._occupants[current_nick] = nil; end - self._jid_nick[from] = nil; end elseif not type then -- available if current_nick then @@ -437,6 +480,12 @@ function room_mt:handle_to_occupant(origin, stanza) -- PM, vCards, etc log("debug", "%s joining as %s", from, to); if not next(self._affiliations) then -- new room, no owners self._affiliations[jid_bare(from)] = "owner"; + if self.locked and not stanza:get_child("x", "http://jabber.org/protocol/muc") then + self.locked = nil; -- Older groupchat protocol doesn't lock + end + elseif self.locked then -- Deny entry + origin.send(st.error_reply(stanza, "cancel", "item-not-found")); + return; end local affiliation = self:get_affiliation(from); local role = self:get_default_role(affiliation) @@ -454,10 +503,13 @@ function room_mt:handle_to_occupant(origin, stanza) -- PM, vCards, etc if not is_merge then self:broadcast_except_nick(pr, to); end - pr:tag("status", {code='110'}); + pr:tag("status", {code='110'}):up(); if self._data.whois == 'anyone' then pr:tag("status", {code='100'}):up(); end + if self.locked then + pr:tag("status", {code='201'}):up(); + end pr.attr.to = from; self:_route_stanza(pr); self:send_history(from, stanza); @@ -478,25 +530,14 @@ function room_mt:handle_to_occupant(origin, stanza) -- PM, vCards, etc end end elseif not current_nick then -- not in room - if type == "error" or type == "result" then - local id = stanza.name == "iq" and stanza.attr.id and base64.decode(stanza.attr.id); - local _nick, _id, _hash = (id or ""):match("^(.+)%z(.*)%z(.+)$"); - local occupant = self._occupants[stanza.attr.to]; - if occupant and _nick and self._jid_nick[_nick] and _id and _hash then - local id, _to = stanza.attr.id; - for jid in pairs(occupant.sessions) do - if md5(jid) == _hash then - _to = jid; - break; - end - end - if _to then - stanza.attr.to, stanza.attr.from, stanza.attr.id = _to, self._jid_nick[_nick], _id; - self:_route_stanza(stanza); - stanza.attr.to, stanza.attr.from, stanza.attr.id = to, from, id; - end + if (type == "error" or type == "result") and stanza.name == "iq" then + local id = stanza.attr.id; + stanza.attr.from, stanza.attr.to, stanza.attr.id = deconstruct_stanza_id(self, stanza); + if stanza.attr.id then + self:_route_stanza(stanza); end - else + stanza.attr.from, stanza.attr.to, stanza.attr.id = from, to, id; + elseif type ~= "error" then origin.send(st.error_reply(stanza, "cancel", "not-acceptable")); end elseif stanza.name == "message" and type == "groupchat" then -- groupchat messages not allowed in PM @@ -508,16 +549,28 @@ function room_mt:handle_to_occupant(origin, stanza) -- PM, vCards, etc local o_data = self._occupants[to]; if o_data then log("debug", "%s sent private stanza to %s (%s)", from, to, o_data.jid); - local jid = o_data.jid; - local bare = jid_bare(jid); - stanza.attr.to, stanza.attr.from = jid, current_nick; - local id = stanza.attr.id; - if stanza.name=='iq' and type=='get' and stanza.tags[1].attr.xmlns == 'vcard-temp' and bare ~= jid then - stanza.attr.to = bare; - stanza.attr.id = base64.encode(jid.."\0"..id.."\0"..md5(from)); + if stanza.name == "iq" then + local id = stanza.attr.id; + if stanza.attr.type == "get" or stanza.attr.type == "set" then + stanza.attr.from, stanza.attr.to, stanza.attr.id = construct_stanza_id(self, stanza); + else + stanza.attr.from, stanza.attr.to, stanza.attr.id = deconstruct_stanza_id(self, stanza); + end + if type == 'get' and stanza.tags[1].attr.xmlns == 'vcard-temp' then + stanza.attr.to = jid_bare(stanza.attr.to); + end + if stanza.attr.id then + self:_route_stanza(stanza); + end + stanza.attr.from, stanza.attr.to, stanza.attr.id = from, to, id; + else -- message + stanza.attr.from = current_nick; + for jid in pairs(o_data.sessions) do + stanza.attr.to = jid; + self:_route_stanza(stanza); + end + stanza.attr.from, stanza.attr.to = from, to; end - self:_route_stanza(stanza); - stanza.attr.to, stanza.attr.from, stanza.attr.id = to, from, id; elseif type ~= "error" and type ~= "result" then -- recipient not in room origin.send(st.error_reply(stanza, "cancel", "item-not-found", "Recipient not in room")); end @@ -526,15 +579,14 @@ end function room_mt:send_form(origin, stanza) origin.send(st.reply(stanza):query("http://jabber.org/protocol/muc#owner") - :add_child(self:get_form_layout():form()) + :add_child(self:get_form_layout(stanza.attr.from):form()) ); end -function room_mt:get_form_layout() - local title = "Configuration for "..self.jid; - return dataform.new({ - title = title, - instructions = title, +function room_mt:get_form_layout(actor) + local form = dataform.new({ + title = "Configuration for "..self.jid, + instructions = "Complete and submit this form to configure the room.", { name = 'FORM_TYPE', type = 'hidden', @@ -556,13 +608,19 @@ function room_mt:get_form_layout() name = 'muc#roomconfig_persistentroom', type = 'boolean', label = 'Make Room Persistent?', - value = self:is_persistent() + value = self:get_persistent() }, { name = 'muc#roomconfig_publicroom', type = 'boolean', label = 'Make Room Publicly Searchable?', - value = not self:is_hidden() + value = not self:get_hidden() + }, + { + name = 'muc#roomconfig_changesubject', + type = 'boolean', + label = 'Allow Occupants to Change Subject?', + value = self:get_changesubject() }, { name = 'muc#roomconfig_whois', @@ -583,22 +641,24 @@ function room_mt:get_form_layout() name = 'muc#roomconfig_moderatedroom', type = 'boolean', label = 'Make Room Moderated?', - value = self:is_moderated() + value = self:get_moderated() }, { name = 'muc#roomconfig_membersonly', type = 'boolean', label = 'Make Room Members-Only?', - value = self:is_members_only() + value = self:get_members_only() + }, + { + name = 'muc#roomconfig_historylength', + type = 'text-single', + label = 'Maximum Number of History Messages Returned by Room', + value = tostring(self:get_historylength()) } }); + return module:fire_event("muc-config-form", { room = self, actor = actor, form = form }) or form; end -local valid_whois = { - moderators = true, - anyone = true, -} - function room_mt:process_form(origin, stanza) local query = stanza.tags[1]; local form; @@ -607,69 +667,50 @@ function room_mt:process_form(origin, stanza) if form.attr.type == "cancel" then origin.send(st.reply(stanza)); return; end if form.attr.type ~= "submit" then origin.send(st.error_reply(stanza, "cancel", "bad-request", "Not a submitted form")); return; end - local fields = self:get_form_layout():data(form); + local fields = self:get_form_layout(stanza.attr.from):data(form); if fields.FORM_TYPE ~= "http://jabber.org/protocol/muc#roomconfig" then origin.send(st.error_reply(stanza, "cancel", "bad-request", "Form is not of type room configuration")); return; end - local dirty = false - local name = fields['muc#roomconfig_roomname']; - if name ~= self:get_name() then - self:set_name(name); - end + local changed = {}; - local description = fields['muc#roomconfig_roomdesc']; - if description ~= self:get_description() then - self:set_description(description); + local function handle_option(name, field, allowed) + local new = fields[field]; + if new == nil then return; end + if allowed and not allowed[new] then return; end + if new == self["get_"..name](self) then return; end + changed[name] = true; + self["set_"..name](self, new); end - local persistent = fields['muc#roomconfig_persistentroom']; - dirty = dirty or (self:is_persistent() ~= persistent) - module:log("debug", "persistent=%s", tostring(persistent)); - - local moderated = fields['muc#roomconfig_moderatedroom']; - dirty = dirty or (self:is_moderated() ~= moderated) - module:log("debug", "moderated=%s", tostring(moderated)); - - local membersonly = fields['muc#roomconfig_membersonly']; - dirty = dirty or (self:is_members_only() ~= membersonly) - module:log("debug", "membersonly=%s", tostring(membersonly)); - - local public = fields['muc#roomconfig_publicroom']; - dirty = dirty or (self:is_hidden() ~= (not public and true or nil)) - - local whois = fields['muc#roomconfig_whois']; - if not valid_whois[whois] then - origin.send(st.error_reply(stanza, 'cancel', 'bad-request', "Invalid value for 'whois'")); - return; - end - local whois_changed = self._data.whois ~= whois - self._data.whois = whois - module:log('debug', 'whois=%s', whois) + local event = { room = self, fields = fields, changed = changed, stanza = stanza, origin = origin, update_option = handle_option }; + module:fire_event("muc-config-submitted", event); - local password = fields['muc#roomconfig_roomsecret']; - if self:get_password() ~= password then - self:set_password(password); - end - self:set_moderated(moderated); - self:set_members_only(membersonly); - self:set_persistent(persistent); - self:set_hidden(not public); + handle_option("name", "muc#roomconfig_roomname"); + handle_option("description", "muc#roomconfig_roomdesc"); + handle_option("persistent", "muc#roomconfig_persistentroom"); + handle_option("moderated", "muc#roomconfig_moderatedroom"); + handle_option("members_only", "muc#roomconfig_membersonly"); + handle_option("public", "muc#roomconfig_publicroom"); + handle_option("changesubject", "muc#roomconfig_changesubject"); + handle_option("historylength", "muc#roomconfig_historylength"); + handle_option("whois", "muc#roomconfig_whois", valid_whois); + handle_option("password", "muc#roomconfig_roomsecret"); if self.save then self:save(true); end + if self.locked then + module:fire_event("muc-room-unlocked", { room = self }); + self.locked = nil; + end origin.send(st.reply(stanza)); - if dirty or whois_changed then + if next(changed) then local msg = st.message({type='groupchat', from=self.jid}) :tag('x', {xmlns='http://jabber.org/protocol/muc#user'}):up() - - if dirty then - msg.tags[1]:tag('status', {code = '104'}):up(); - end - if whois_changed then - local code = (whois == 'moderators') and "173" or "172"; + :tag('status', {code = '104'}):up(); + if changed.whois then + local code = (self:get_whois() == 'moderators') and "173" or "172"; msg.tags[1]:tag('status', {code = code}):up(); end - self:broadcast_message(msg, false) end end @@ -691,15 +732,16 @@ function room_mt:destroy(newjid, reason, password) self._occupants[nick] = nil; end self:set_persistent(false); + module:fire_event("muc-room-destroyed", { room = self }); end function room_mt:handle_to_room(origin, stanza) -- presence changes and groupchat messages, along with disco/etc local type = stanza.attr.type; local xmlns = stanza.tags[1] and stanza.tags[1].attr.xmlns; if stanza.name == "iq" then - if xmlns == "http://jabber.org/protocol/disco#info" and type == "get" then + if xmlns == "http://jabber.org/protocol/disco#info" and type == "get" and not stanza.tags[1].attr.node then origin.send(self:get_disco_info(stanza)); - elseif xmlns == "http://jabber.org/protocol/disco#items" and type == "get" then + elseif xmlns == "http://jabber.org/protocol/disco#items" and type == "get" and not stanza.tags[1].attr.node then origin.send(self:get_disco_items(stanza)); elseif xmlns == "http://jabber.org/protocol/muc#admin" then local actor = stanza.attr.from; @@ -777,13 +819,13 @@ function room_mt:handle_to_room(origin, stanza) -- presence changes and groupcha end elseif xmlns == "http://jabber.org/protocol/muc#owner" and (type == "get" or type == "set") and stanza.tags[1].name == "query" then if self:get_affiliation(stanza.attr.from) ~= "owner" then - origin.send(st.error_reply(stanza, "auth", "forbidden")); + origin.send(st.error_reply(stanza, "auth", "forbidden", "Only owners can configure rooms")); elseif stanza.attr.type == "get" then self:send_form(origin, stanza); elseif stanza.attr.type == "set" then local child = stanza.tags[1].tags[1]; if not child then - origin.send(st.error_reply(stanza, "auth", "bad-request")); + origin.send(st.error_reply(stanza, "modify", "bad-request")); elseif child.name == "destroy" then local newjid = child.attr.jid; local reason, password; @@ -804,27 +846,27 @@ function room_mt:handle_to_room(origin, stanza) -- presence changes and groupcha origin.send(st.error_reply(stanza, "cancel", "service-unavailable")); end elseif stanza.name == "message" and type == "groupchat" then - local from, to = stanza.attr.from, stanza.attr.to; - local room = jid_bare(to); + local from = stanza.attr.from; local current_nick = self._jid_nick[from]; local occupant = self._occupants[current_nick]; if not occupant then -- not in room origin.send(st.error_reply(stanza, "cancel", "not-acceptable")); elseif occupant.role == "visitor" then - origin.send(st.error_reply(stanza, "cancel", "forbidden")); + origin.send(st.error_reply(stanza, "auth", "forbidden")); else local from = stanza.attr.from; stanza.attr.from = current_nick; - local subject = getText(stanza, {"subject"}); + local subject = stanza:get_child_text("subject"); if subject then - if occupant.role == "moderator" then - self:set_subject(current_nick, subject); -- TODO use broadcast_message_stanza + if occupant.role == "moderator" or + ( self._data.changesubject and occupant.role == "participant" ) then -- and participant + self:set_subject(current_nick, subject); else stanza.attr.from = from; - origin.send(st.error_reply(stanza, "cancel", "forbidden")); + origin.send(st.error_reply(stanza, "auth", "forbidden")); end else - self:broadcast_message(stanza, true); + self:broadcast_message(stanza, self:get_historylength() > 0 and stanza:get_child("body")); end stanza.attr.from = from; end @@ -842,8 +884,8 @@ function room_mt:handle_to_room(origin, stanza) -- presence changes and groupcha elseif type ~= "error" and type ~= "result" then origin.send(st.error_reply(stanza, "cancel", "service-unavailable")); end - elseif stanza.name == "message" and not stanza.attr.type and #stanza.tags == 1 and self._jid_nick[stanza.attr.from] - and stanza.tags[1].name == "x" and stanza.tags[1].attr.xmlns == "http://jabber.org/protocol/muc#user" then + elseif stanza.name == "message" and not(type == "chat" or type == "error" or type == "groupchat" or type == "headline") and #stanza.tags == 1 + and self._jid_nick[stanza.attr.from] and stanza.tags[1].name == "x" and stanza.tags[1].attr.xmlns == "http://jabber.org/protocol/muc#user" then local x = stanza.tags[1]; local payload = (#x.tags == 1 and x.tags[1]); if payload and payload.name == "invite" and payload.attr.to then @@ -866,7 +908,7 @@ function room_mt:handle_to_room(origin, stanza) -- presence changes and groupcha :tag('body') -- Add a plain message for clients which don't support invites :text(_from..' invited you to the room '.._to..(_reason and (' ('.._reason..')') or "")) :up(); - if self:is_members_only() and not self:get_affiliation(_invitee) then + if self:get_members_only() and not self:get_affiliation(_invitee) then log("debug", "%s invited %s into members only room %s, granting membership", _from, _invitee, _to); self:set_affiliation(_from, _invitee, "member", nil, "Invited by " .. self._jid_nick[_from]) end @@ -907,8 +949,25 @@ function room_mt:set_affiliation(actor, jid, affiliation, callback, reason) if affiliation and affiliation ~= "outcast" and affiliation ~= "owner" and affiliation ~= "admin" and affiliation ~= "member" then return nil, "modify", "not-acceptable"; end - if self:get_affiliation(actor) ~= "owner" then return nil, "cancel", "not-allowed"; end - if jid_bare(actor) == jid then return nil, "cancel", "not-allowed"; end + if actor ~= true then + local actor_affiliation = self:get_affiliation(actor); + local target_affiliation = self:get_affiliation(jid); + if target_affiliation == affiliation then -- no change, shortcut + if callback then callback(); end + return true; + end + if actor_affiliation ~= "owner" then + if affiliation == "owner" or affiliation == "admin" or actor_affiliation ~= "admin" or target_affiliation == "owner" or target_affiliation == "admin" then + return nil, "cancel", "not-allowed"; + end + elseif target_affiliation == "owner" and jid_bare(actor) == jid then -- self change + local is_last = true; + for j, aff in pairs(self._affiliations) do if j ~= jid and aff == "owner" then is_last = false; break; end end + if is_last then + return nil, "cancel", "conflict"; + end + end + end self._affiliations[jid] = affiliation; local role = self:get_default_role(affiliation); local x = st.stanza("x", {xmlns = "http://jabber.org/protocol/muc#user"}) @@ -960,11 +1019,12 @@ function room_mt:get_role(nick) return session and session.role or nil; end function room_mt:can_set_role(actor_jid, occupant_jid, role) - local actor = self._occupants[self._jid_nick[actor_jid]]; local occupant = self._occupants[occupant_jid]; - - if not occupant or not actor then return nil, "modify", "not-acceptable"; end + if not occupant or not actor_jid then return nil, "modify", "not-acceptable"; end + if actor_jid == true then return true; end + + local actor = self._occupants[self._jid_nick[actor_jid]]; if actor.role == "moderator" then if occupant.affiliation ~= "owner" and occupant.affiliation ~= "admin" then if actor.affiliation == "owner" or actor.affiliation == "admin" then @@ -1061,10 +1121,17 @@ function _M.new_room(jid, config) _occupants = {}; _data = { whois = 'moderators'; - history_length = (config and config.history_length); + history_length = math.min((config and config.history_length) + or default_history_length, max_history_length); }; _affiliations = {}; }, room_mt); end +function _M.set_max_history_length(_max_history_length) + max_history_length = _max_history_length or math.huge; +end + +_M.room_mt = room_mt; + return _M; diff --git a/plugins/sql.lib.lua b/plugins/sql.lib.lua new file mode 100644 index 00000000..005ee45d --- /dev/null +++ b/plugins/sql.lib.lua @@ -0,0 +1,9 @@ +local cache = module:shared("/*/sql.lib/util.sql"); + +if not cache._M then + prosody.unlock_globals(); + cache._M = require "util.sql"; + prosody.lock_globals(); +end + +return cache._M; diff --git a/plugins/storage/ejabberd_init.lib.lua b/plugins/storage/ejabberd_init.lib.lua deleted file mode 100644 index 91f8563b..00000000 --- a/plugins/storage/ejabberd_init.lib.lua +++ /dev/null @@ -1,252 +0,0 @@ - -local t_concat = table.concat; -local t_insert = table.insert; -local pairs = pairs; -local DBI = require "DBI"; - -local sqlite = true; -local q = {}; - -local function set(key, val) --- t_insert(q, "SET "..key.."="..val..";\n") -end -local function create_table(name, fields) - t_insert(q, "CREATE TABLE ".."IF NOT EXISTS "..name.." (\n"); - for _, field in pairs(fields) do - t_insert(q, "\t"); - field = t_concat(field, " "); - if sqlite then - if field:lower():match("^primary key *%(") then field = field:gsub("%(%d+%)", ""); end - end - t_insert(q, field); - if _ ~= #fields then t_insert(q, ",\n"); end - t_insert(q, "\n"); - end - if sqlite then - t_insert(q, ");\n"); - else - t_insert(q, ") CHARACTER SET utf8;\n"); - end -end -local function create_index(name, index) - --t_insert(q, "CREATE INDEX "..name.." ON "..index..";\n"); -end -local function create_unique_index(name, index) - --t_insert(q, "CREATE UNIQUE INDEX "..name.." ON "..index..";\n"); -end -local function insert(target, value) - t_insert(q, "INSERT INTO "..target.."\nVALUES "..value..";\n"); -end -local function foreign_key(name, fkey, fname, fcol) - t_insert(q, "ALTER TABLE `"..name.."` ADD FOREIGN KEY (`"..fkey.."`) REFERENCES `"..fname.."` (`"..fcol.."`) ON DELETE CASCADE;\n"); -end - -function build_query() - q = {}; - set('table_type', 'InnoDB'); - create_table('hosts', { - {'clusterid','integer','NOT','NULL'}; - {'host','varchar(250)','NOT','NULL','PRIMARY','KEY'}; - {'config','text','NOT','NULL'}; - }); - insert("hosts (clusterid, host, config)", "(1, 'localhost', '')"); - create_table('users', { - {'host','varchar(250)','NOT','NULL'}; - {'username','varchar(250)','NOT','NULL'}; - {'password','text','NOT','NULL'}; - {'created_at','timestamp','NOT','NULL','DEFAULT','CURRENT_TIMESTAMP'}; - {'PRIMARY','KEY','(host, username)'}; - }); - create_table('last', { - {'host','varchar(250)','NOT','NULL'}; - {'username','varchar(250)','NOT','NULL'}; - {'seconds','text','NOT','NULL'}; - {'state','text','NOT','NULL'}; - {'PRIMARY','KEY','(host, username)'}; - }); - create_table('rosterusers', { - {'host','varchar(250)','NOT','NULL'}; - {'username','varchar(250)','NOT','NULL'}; - {'jid','varchar(250)','NOT','NULL'}; - {'nick','text','NOT','NULL'}; - {'subscription','character(1)','NOT','NULL'}; - {'ask','character(1)','NOT','NULL'}; - {'askmessage','text','NOT','NULL'}; - {'server','character(1)','NOT','NULL'}; - {'subscribe','text','NOT','NULL'}; - {'type','text'}; - {'created_at','timestamp','NOT','NULL','DEFAULT','CURRENT_TIMESTAMP'}; - {'PRIMARY','KEY','(host(75), username(75), jid(75))'}; - }); - create_index('i_rosteru_username', 'rosterusers(username)'); - create_index('i_rosteru_jid', 'rosterusers(jid)'); - create_table('rostergroups', { - {'host','varchar(250)','NOT','NULL'}; - {'username','varchar(250)','NOT','NULL'}; - {'jid','varchar(250)','NOT','NULL'}; - {'grp','text','NOT','NULL'}; - {'PRIMARY','KEY','(host(75), username(75), jid(75))'}; - }); - --[[create_table('spool', { - {'host','varchar(250)','NOT','NULL'}; - {'username','varchar(250)','NOT','NULL'}; - {'xml','text','NOT','NULL'}; - {'seq','BIGINT','UNSIGNED','NOT','NULL','AUTO_INCREMENT','UNIQUE'}; - {'created_at','timestamp','NOT','NULL','DEFAULT','CURRENT_TIMESTAMP'}; - {'PRIMARY','KEY','(host, username, seq)'}; - });]] - create_table('vcard', { - {'host','varchar(250)','NOT','NULL'}; - {'username','varchar(250)','NOT','NULL'}; - {'vcard','text','NOT','NULL'}; - {'created_at','timestamp','NOT','NULL','DEFAULT','CURRENT_TIMESTAMP'}; - {'PRIMARY','KEY','(host, username)'}; - }); - create_table('vcard_search', { - {'host','varchar(250)','NOT','NULL'}; - {'username','varchar(250)','NOT','NULL'}; - {'lusername','varchar(250)','NOT','NULL'}; - {'fn','text','NOT','NULL'}; - {'lfn','varchar(250)','NOT','NULL'}; - {'family','text','NOT','NULL'}; - {'lfamily','varchar(250)','NOT','NULL'}; - {'given','text','NOT','NULL'}; - {'lgiven','varchar(250)','NOT','NULL'}; - {'middle','text','NOT','NULL'}; - {'lmiddle','varchar(250)','NOT','NULL'}; - {'nickname','text','NOT','NULL'}; - {'lnickname','varchar(250)','NOT','NULL'}; - {'bday','text','NOT','NULL'}; - {'lbday','varchar(250)','NOT','NULL'}; - {'ctry','text','NOT','NULL'}; - {'lctry','varchar(250)','NOT','NULL'}; - {'locality','text','NOT','NULL'}; - {'llocality','varchar(250)','NOT','NULL'}; - {'email','text','NOT','NULL'}; - {'lemail','varchar(250)','NOT','NULL'}; - {'orgname','text','NOT','NULL'}; - {'lorgname','varchar(250)','NOT','NULL'}; - {'orgunit','text','NOT','NULL'}; - {'lorgunit','varchar(250)','NOT','NULL'}; - {'PRIMARY','KEY','(host, lusername)'}; - }); - create_index('i_vcard_search_lfn ', 'vcard_search(lfn)'); - create_index('i_vcard_search_lfamily ', 'vcard_search(lfamily)'); - create_index('i_vcard_search_lgiven ', 'vcard_search(lgiven)'); - create_index('i_vcard_search_lmiddle ', 'vcard_search(lmiddle)'); - create_index('i_vcard_search_lnickname', 'vcard_search(lnickname)'); - create_index('i_vcard_search_lbday ', 'vcard_search(lbday)'); - create_index('i_vcard_search_lctry ', 'vcard_search(lctry)'); - create_index('i_vcard_search_llocality', 'vcard_search(llocality)'); - create_index('i_vcard_search_lemail ', 'vcard_search(lemail)'); - create_index('i_vcard_search_lorgname ', 'vcard_search(lorgname)'); - create_index('i_vcard_search_lorgunit ', 'vcard_search(lorgunit)'); - create_table('privacy_default_list', { - {'host','varchar(250)','NOT','NULL'}; - {'username','varchar(250)'}; - {'name','varchar(250)','NOT','NULL'}; - {'PRIMARY','KEY','(host, username)'}; - }); - --[[create_table('privacy_list', { - {'host','varchar(250)','NOT','NULL'}; - {'username','varchar(250)','NOT','NULL'}; - {'name','varchar(250)','NOT','NULL'}; - {'id','BIGINT','UNSIGNED','NOT','NULL','AUTO_INCREMENT','UNIQUE'}; - {'created_at','timestamp','NOT','NULL','DEFAULT','CURRENT_TIMESTAMP'}; - {'PRIMARY','KEY','(host, username, name)'}; - });]] - create_table('privacy_list_data', { - {'id','bigint'}; - {'t','character(1)','NOT','NULL'}; - {'value','text','NOT','NULL'}; - {'action','character(1)','NOT','NULL'}; - {'ord','NUMERIC','NOT','NULL'}; - {'match_all','boolean','NOT','NULL'}; - {'match_iq','boolean','NOT','NULL'}; - {'match_message','boolean','NOT','NULL'}; - {'match_presence_in','boolean','NOT','NULL'}; - {'match_presence_out','boolean','NOT','NULL'}; - }); - create_table('private_storage', { - {'host','varchar(250)','NOT','NULL'}; - {'username','varchar(250)','NOT','NULL'}; - {'namespace','varchar(250)','NOT','NULL'}; - {'data','text','NOT','NULL'}; - {'created_at','timestamp','NOT','NULL','DEFAULT','CURRENT_TIMESTAMP'}; - {'PRIMARY','KEY','(host(75), username(75), namespace(75))'}; - }); - create_index('i_private_storage_username USING BTREE', 'private_storage(username)'); - create_table('roster_version', { - {'username','varchar(250)','PRIMARY','KEY'}; - {'version','text','NOT','NULL'}; - }); - --[[create_table('pubsub_node', { - {'host','text'}; - {'node','text'}; - {'parent','text'}; - {'type','text'}; - {'nodeid','bigint','auto_increment','primary','key'}; - }); - create_index('i_pubsub_node_parent', 'pubsub_node(parent(120))'); - create_unique_index('i_pubsub_node_tuple', 'pubsub_node(host(20), node(120))'); - create_table('pubsub_node_option', { - {'nodeid','bigint'}; - {'name','text'}; - {'val','text'}; - }); - create_index('i_pubsub_node_option_nodeid', 'pubsub_node_option(nodeid)'); - foreign_key('pubsub_node_option', 'nodeid', 'pubsub_node', 'nodeid'); - create_table('pubsub_node_owner', { - {'nodeid','bigint'}; - {'owner','text'}; - }); - create_index('i_pubsub_node_owner_nodeid', 'pubsub_node_owner(nodeid)'); - foreign_key('pubsub_node_owner', 'nodeid', 'pubsub_node', 'nodeid'); - create_table('pubsub_state', { - {'nodeid','bigint'}; - {'jid','text'}; - {'affiliation','character(1)'}; - {'subscriptions','text'}; - {'stateid','bigint','auto_increment','primary','key'}; - }); - create_index('i_pubsub_state_jid', 'pubsub_state(jid(60))'); - create_unique_index('i_pubsub_state_tuple', 'pubsub_state(nodeid, jid(60))'); - foreign_key('pubsub_state', 'nodeid', 'pubsub_node', 'nodeid'); - create_table('pubsub_item', { - {'nodeid','bigint'}; - {'itemid','text'}; - {'publisher','text'}; - {'creation','text'}; - {'modification','text'}; - {'payload','text'}; - }); - create_index('i_pubsub_item_itemid', 'pubsub_item(itemid(36))'); - create_unique_index('i_pubsub_item_tuple', 'pubsub_item(nodeid, itemid(36))'); - foreign_key('pubsub_item', 'nodeid', 'pubsub_node', 'nodeid'); - create_table('pubsub_subscription_opt', { - {'subid','text'}; - {'opt_name','varchar(32)'}; - {'opt_value','text'}; - }); - create_unique_index('i_pubsub_subscription_opt', 'pubsub_subscription_opt(subid(32), opt_name(32))');]] - return t_concat(q); -end - -local function init(dbh) - local q = build_query(); - for statement in q:gmatch("[^;]*;") do - statement = statement:gsub("\n", ""):gsub("\t", " "); - if sqlite then - statement = statement:gsub("AUTO_INCREMENT", "AUTOINCREMENT"); - statement = statement:gsub("auto_increment", "autoincrement"); - end - local result, err = DBI.Do(dbh, statement); - if not result then - print("X", result, err); - print("Y", statement); - end - end -end - -local _M = { init = init }; -return _M; diff --git a/plugins/storage/ejabberdstore.lib.lua b/plugins/storage/ejabberdstore.lib.lua deleted file mode 100644 index 7e8592a8..00000000 --- a/plugins/storage/ejabberdstore.lib.lua +++ /dev/null @@ -1,190 +0,0 @@ -
-local handlers = {};
-
-handlers.accounts = {
- get = function(self, user)
- local select = self:query("select password from users where username=?", user);
- local row = select and select:fetch();
- if row then return { password = row[1] }; end
- end;
- set = function(self, user, data)
- if data and data.password then
- return self:modify("update users set password=? where username=?", data.password, user)
- or self:modify("insert into users (username, password) values (?, ?)", user, data.password);
- else
- return self:modify("delete from users where username=?", user);
- end
- end;
-};
-handlers.vcard = {
- get = function(self, user)
- local select = self:query("select vcard from vcard where username=?", user);
- local row = select and select:fetch();
- if row then return parse_xml(row[1]); end
- end;
- set = function(self, user, data)
- if data then
- data = unparse_xml(data);
- return self:modify("update vcard set vcard=? where username=?", data, user)
- or self:modify("insert into vcard (username, vcard) values (?, ?)", user, data);
- else
- return self:modify("delete from vcard where username=?", user);
- end
- end;
-};
-handlers.private = {
- get = function(self, user)
- local select = self:query("select namespace,data from private_storage where username=?", user);
- if select then
- local data = {};
- for row in select:rows() do
- data[row[1]] = parse_xml(row[2]);
- end
- return data;
- end
- end;
- set = function(self, user, data)
- if data then
- self:modify("delete from private_storage where username=?", user);
- for namespace,text in pairs(data) do
- self:modify("insert into private_storage (username, namespace, data) values (?, ?, ?)", user, namespace, unparse_xml(text));
- end
- return true;
- else
- return self:modify("delete from private_storage where username=?", user);
- end
- end;
- -- TODO map_set, map_get
-};
-local subscription_map = { N = "none", B = "both", F = "from", T = "to" };
-local subscription_map_reverse = { none = "N", both = "B", from = "F", to = "T" };
-handlers.roster = {
- get = function(self, user)
- local select = self:query("select jid,nick,subscription,ask,server,subscribe,type from rosterusers where username=?", user);
- if select then
- local roster = { pending = {} };
- for row in select:rows() do
- local jid,nick,subscription,ask,server,subscribe,typ = unpack(row);
- local item = { groups = {} };
- if nick == "" then nick = nil; end
- item.nick = nick;
- item.subscription = subscription_map[subscription];
- if ask == "N" then ask = nil;
- elseif ask == "O" then ask = "subscribe"
- elseif ask == "I" then roster.pending[jid] = true; ask = nil;
- elseif ask == "B" then roster.pending[jid] = true; ask = "subscribe";
- else module:log("debug", "bad roster_item.ask: %s", ask); ask = nil; end
- item.ask = ask;
- roster[jid] = item;
- end
-
- select = self:query("select jid,grp from rostergroups where username=?", user);
- if select then
- for row in select:rows() do
- local jid,grp = unpack(rows);
- if roster[jid] then roster[jid].groups[grp] = true; end
- end
- end
- select = self:query("select version from roster_version where username=?", user);
- local row = select and select:fetch();
- if row then
- roster[false] = { version = row[1]; };
- end
- return roster;
- end
- end;
- set = function(self, user, data)
- if data and next(data) ~= nil then
- self:modify("delete from rosterusers where username=?", user);
- self:modify("delete from rostergroups where username=?", user);
- self:modify("delete from roster_version where username=?", user);
- local done = {};
- local pending = data.pending or {};
- for jid,item in pairs(data) do
- if jid and jid ~= "pending" then
- local subscription = subscription_map_reverse[item.subscription];
- local ask;
- if pending[jid] then
- if item.ask then ask = "B"; else ask = "I"; end
- else
- if item.ask then ask = "O"; else ask = "N"; end
- end
- local r = self:modify("insert into rosterusers (username,jid,nick,subscription,ask,askmessage,server,subscribe) values (?, ?, ?, ?, ?, '', '', '')", user, jid, item.nick or "", subscription, ask);
- if not r then module:log("debug", "--- :( %s", tostring(r)); end
- done[jid] = true;
- for group in pairs(item.groups) do
- self:modify("insert into rostergroups (username,jid,grp) values (?, ?, ?)", user, jid, group);
- end
- end
- end
- for jid in pairs(pending) do
- if not done[jid] then
- self:modify("insert into rosterusers (username,jid,nick,subscription,ask,askmessage,server,subscribe) values (?, ?, ?, ?, ?. ''. ''. '')", user, jid, "", "N", "I");
- end
- end
- local version = data[false] and data[false].version;
- if version then
- self:modify("insert into roster_version (username,version) values (?, ?)", user, version);
- end
- return true;
- else
- self:modify("delete from rosterusers where username=?", user);
- self:modify("delete from rostergroups where username=?", user);
- self:modify("delete from roster_version where username=?", user);
- end
- end;
-};
-
------------------------------
-local driver = {};
-driver.__index = driver;
-
-function driver:prepare(sql)
- module:log("debug", "query: %s", sql);
- local err;
- if not self.sqlcache then self.sqlcache = {}; end
- local r = self.sqlcache[sql];
- if r then return r; end
- r, err = self.database:prepare(sql);
- if not r then error("Unable to prepare SQL statement: "..err); end
- self.sqlcache[sql] = r;
- return r;
-end
-
-function driver:query(sql, ...)
- local stmt = self:prepare(sql);
- if stmt:execute(...) then return stmt; end
-end
-function driver:modify(sql, ...)
- local stmt = self:query(sql, ...);
- if stmt and stmt:affected() > 0 then return stmt; end
-end
-
-function driver:open(host, datastore, typ)
- local cache_key = host.." "..datastore;
- if self.ds_cache[cache_key] then return self.ds_cache[cache_key]; end
- local instance = setmetatable({}, self);
- instance.host = host;
- instance.datastore = datastore;
- local handler = handlers[datastore];
- if not handler then return nil; end
- for key,val in pairs(handler) do
- instance[key] = val;
- end
- if instance.init then instance:init(); end
- self.ds_cache[cache_key] = instance;
- return instance;
-end
-
------------------------------
-local _M = {};
-
-function _M.new(dbtype, dbname, ...)
- local instance = setmetatable({}, driver);
- instance.__index = instance;
- instance.database = get_database(dbtype, dbname, ...);
- instance.ds_cache = {};
- return instance;
-end
-
-return _M;
diff --git a/plugins/storage/mod_xep0227.lua b/plugins/storage/mod_xep0227.lua index b6d2e627..5d07a2ea 100644 --- a/plugins/storage/mod_xep0227.lua +++ b/plugins/storage/mod_xep0227.lua @@ -8,7 +8,7 @@ local os_remove = os.remove; local io_open = io.open; local st = require "util.stanza"; -local parse_xml_real = module:require("xmlparse"); +local parse_xml_real = require "util.xml".parse; local function getXml(user, host) local jid = user.."@"..host; @@ -160,4 +160,4 @@ function driver:open(host, datastore, typ) return instance; end -module:add_item("data-driver", driver); +module:provides("storage", driver); diff --git a/plugins/storage/sqlbasic.lib.lua b/plugins/storage/sqlbasic.lib.lua deleted file mode 100644 index f1202287..00000000 --- a/plugins/storage/sqlbasic.lib.lua +++ /dev/null @@ -1,97 +0,0 @@ - --- Basic SQL driver --- This driver stores data as simple key-values - -local ser = require "util.serialization".serialize; -local deser = function(data) - module:log("debug", "deser: %s", tostring(data)); - if not data then return nil; end - local f = loadstring("return "..data); - if not f then return nil; end - setfenv(f, {}); - local s, d = pcall(f); - if not s then return nil; end - return d; -end; - -local driver = {}; -driver.__index = driver; - -driver.item_table = "item"; -driver.list_table = "list"; - -function driver:prepare(sql) - module:log("debug", "query: %s", sql); - local err; - if not self.sqlcache then self.sqlcache = {}; end - local r = self.sqlcache[sql]; - if r then return r; end - r, err = self.connection:prepare(sql); - if not r then error("Unable to prepare SQL statement: "..err); end - self.sqlcache[sql] = r; - return r; -end - -function driver:load(username, host, datastore) - local select = self:prepare("select data from "..self.item_table.." where username=? and host=? and datastore=?"); - select:execute(username, host, datastore); - local row = select:fetch(); - return row and deser(row[1]) or nil; -end - -function driver:store(username, host, datastore, data) - if not data or next(data) == nil then - local delete = self:prepare("delete from "..self.item_table.." where username=? and host=? and datastore=?"); - delete:execute(username, host, datastore); - return true; - else - local d = self:load(username, host, datastore); - if d then -- update - local update = self:prepare("update "..self.item_table.." set data=? where username=? and host=? and datastore=?"); - return update:execute(ser(data), username, host, datastore); - else -- insert - local insert = self:prepare("insert into "..self.item_table.." values (?, ?, ?, ?)"); - return insert:execute(username, host, datastore, ser(data)); - end - end -end - -function driver:list_append(username, host, datastore, data) - if not data then return; end - local insert = self:prepare("insert into "..self.list_table.." values (?, ?, ?, ?)"); - return insert:execute(username, host, datastore, ser(data)); -end - -function driver:list_store(username, host, datastore, data) - -- remove existing data - local delete = self:prepare("delete from "..self.list_table.." where username=? and host=? and datastore=?"); - delete:execute(username, host, datastore); - if data and next(data) ~= nil then - -- add data - for _, d in ipairs(data) do - self:list_append(username, host, datastore, ser(d)); - end - end - return true; -end - -function driver:list_load(username, host, datastore) - local select = self:prepare("select data from "..self.list_table.." where username=? and host=? and datastore=?"); - select:execute(username, host, datastore); - local r = {}; - for row in select:rows() do - table.insert(r, deser(row[1])); - end - return r; -end - -local _M = {}; -function _M.new(dbtype, dbname, ...) - local d = {}; - setmetatable(d, driver); - local dbh = get_database(dbtype, dbname, ...); - --d:set_connection(dbh); - d.connection = dbh; - return d; -end -return _M; diff --git a/plugins/storage/xmlparse.lib.lua b/plugins/storage/xmlparse.lib.lua deleted file mode 100644 index 91063995..00000000 --- a/plugins/storage/xmlparse.lib.lua +++ /dev/null @@ -1,56 +0,0 @@ -
-local st = require "util.stanza";
-
--- XML parser
-local parse_xml = (function()
- local entity_map = setmetatable({
- ["amp"] = "&";
- ["gt"] = ">";
- ["lt"] = "<";
- ["apos"] = "'";
- ["quot"] = "\"";
- }, {__index = function(_, s)
- if s:sub(1,1) == "#" then
- if s:sub(2,2) == "x" then
- return string.char(tonumber(s:sub(3), 16));
- else
- return string.char(tonumber(s:sub(2)));
- end
- end
- end
- });
- local function xml_unescape(str)
- return (str:gsub("&(.-);", entity_map));
- end
- local function parse_tag(s)
- local name,sattr=(s):gmatch("([^%s]+)(.*)")();
- local attr = {};
- for a,b in (sattr):gmatch("([^=%s]+)=['\"]([^'\"]*)['\"]") do attr[a] = xml_unescape(b); end
- return name, attr;
- end
- return function(xml)
- local stanza = st.stanza("root");
- local regexp = "<([^>]*)>([^<]*)";
- for elem, text in xml:gmatch(regexp) do
- if elem:sub(1,1) == "!" or elem:sub(1,1) == "?" then -- neglect comments and processing-instructions
- elseif elem:sub(1,1) == "/" then -- end tag
- elem = elem:sub(2);
- stanza:up(); -- TODO check for start-end tag name match
- elseif elem:sub(-1,-1) == "/" then -- empty tag
- elem = elem:sub(1,-2);
- local name,attr = parse_tag(elem);
- stanza:tag(name, attr):up();
- else -- start tag
- local name,attr = parse_tag(elem);
- stanza:tag(name, attr);
- end
- if #text ~= 0 then -- text
- stanza:text(xml_unescape(text));
- end
- end
- return stanza.tags[1];
- end
-end)();
--- end of XML parser
-
-return parse_xml;
@@ -18,10 +18,22 @@ CFG_DATADIR=os.getenv("PROSODY_DATADIR"); -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- +local function is_relative(path) + local path_sep = package.config:sub(1,1); + return ((path_sep == "/" and path:sub(1,1) ~= "/") + or (path_sep == "\\" and (path:sub(1,1) ~= "/" and path:sub(2,3) ~= ":\\"))) +end + -- Tell Lua where to find our libraries if CFG_SOURCEDIR then - package.path = CFG_SOURCEDIR.."/?.lua;"..package.path; - package.cpath = CFG_SOURCEDIR.."/?.so;"..package.cpath; + local function filter_relative_paths(path) + if is_relative(path) then return ""; end + end + local function sanitise_paths(paths) + return (paths:gsub("[^;]+;?", filter_relative_paths):gsub(";;+", ";")); + end + package.path = sanitise_paths(CFG_SOURCEDIR.."/?.lua;"..package.path); + package.cpath = sanitise_paths(CFG_SOURCEDIR.."/?.so;"..package.cpath); end -- Substitute ~ with path to home directory in data path @@ -32,8 +44,8 @@ if CFG_DATADIR then end -- Global 'prosody' object -prosody = { events = require "util.events".new(); }; -local prosody = prosody; +local prosody = { events = require "util.events".new(); }; +_G.prosody = prosody; -- Check dependencies local dependencies = require "util.dependencies"; @@ -58,6 +70,8 @@ function read_config() if CFG_CONFIGDIR then table.insert(filenames, CFG_CONFIGDIR.."/"..arg[2]); end + elseif os.getenv("PROSODY_CONFIG") then -- Passed by prosodyctl + table.insert(filenames, os.getenv("PROSODY_CONFIG")); else for _, format in ipairs(config.parsers()) do table.insert(filenames, (CFG_CONFIGDIR or ".").."/prosody.cfg."..format); @@ -115,20 +129,41 @@ function log_dependency_warnings() dependencies.log_warnings(); end +function sanity_check() + for host, host_config in pairs(config.getconfig()) do + if host ~= "*" + and host_config.enabled ~= false + and not host_config.component_module then + return; + end + end + log("error", "No enabled VirtualHost entries found in the config file."); + log("error", "At least one active host is required for Prosody to function. Exiting..."); + os.exit(1); +end + function sandbox_require() -- Replace require() with one that doesn't pollute _G, required -- for neat sandboxing of modules local _realG = _G; local _real_require = require; + if not getfenv then + -- FIXME: This is a hack to replace getfenv() in Lua 5.2 + function getfenv(f) return debug.getupvalue(debug.getinfo(f or 1).func, 1); end + end function require(...) local curr_env = getfenv(2); - local curr_env_mt = getmetatable(getfenv(2)); + local curr_env_mt = getmetatable(curr_env); local _realG_mt = getmetatable(_realG); if curr_env_mt and curr_env_mt.__index and not curr_env_mt.__newindex and _realG_mt then - local old_newindex + local old_newindex, old_index; old_newindex, _realG_mt.__newindex = _realG_mt.__newindex, curr_env; + old_index, _realG_mt.__index = _realG_mt.__index, function (_G, k) + return rawget(curr_env, k); + end; local ret = _real_require(...); _realG_mt.__newindex = old_newindex; + _realG_mt.__index = old_index; return ret; end return _real_require(...); @@ -163,6 +198,7 @@ function set_function_metatable() end function init_global_state() + -- COMPAT: These globals are deprecated bare_sessions = {}; full_sessions = {}; hosts = {}; @@ -171,9 +207,16 @@ function init_global_state() prosody.full_sessions = full_sessions; prosody.hosts = hosts; - prosody.paths = { source = CFG_SOURCEDIR, config = CFG_CONFIGDIR, - plugins = CFG_PLUGINDIR, data = CFG_DATADIR }; - + local data_path = config.get("*", "data_path") or CFG_DATADIR or "data"; + local custom_plugin_paths = config.get("*", "plugin_paths"); + if custom_plugin_paths then + local path_sep = package.config:sub(3,3); + -- path1;path2;path3;defaultpath... + CFG_PLUGINDIR = table.concat(custom_plugin_paths, path_sep)..path_sep..(CFG_PLUGINDIR or "plugins"); + end + prosody.paths = { source = CFG_SOURCEDIR, config = CFG_CONFIGDIR or ".", + plugins = CFG_PLUGINDIR or "plugins", data = data_path }; + prosody.arg = _G.arg; prosody.platform = "unknown"; @@ -188,6 +231,11 @@ function init_global_state() prosody.installed = true; end + if prosody.installed then + -- Change working directory to data path. + require "lfs".chdir(data_path); + end + -- Function to reload the config file function prosody.reload_config() log("info", "Reloading configuration file"); @@ -216,67 +264,6 @@ function init_global_state() prosody.events.fire_event("server-stopping", {reason = reason}); server.setquitting(true); end - - -- Load SSL settings from config, and create a ctx table - local certmanager = require "core.certmanager"; - local global_ssl_ctx = certmanager.create_context("*", "server"); - prosody.global_ssl_ctx = global_ssl_ctx; - - local cl = require "net.connlisteners"; - function prosody.net_activate_ports(option, listener, default, conntype) - conntype = conntype or (global_ssl_ctx and "tls") or "tcp"; - local ports_option = option and option.."_ports" or "ports"; - if not cl.get(listener) then return; end - local ports = config.get("*", "core", ports_option) or default; - if type(ports) == "number" then ports = {ports} end; - - if type(ports) ~= "table" then - log("error", "core."..ports_option.." is not a table"); - else - for _, port in ipairs(ports) do - port = tonumber(port); - if type(port) ~= "number" then - log("error", "Non-numeric "..ports_option..": "..tostring(port)); - else - local ok, err = cl.start(listener, { - ssl = conntype == "ssl" and global_ssl_ctx, - port = port, - interface = (option and config.get("*", "core", option.."_interface")) - or cl.get(listener).default_interface - or config.get("*", "core", "interface"), - type = conntype - }); - if not ok then - local friendly_message = err; - if err:match(" in use") then - if port == 5222 or port == 5223 or port == 5269 then - friendly_message = "check that Prosody or another XMPP server is " - .."not already running and using this port"; - elseif port == 80 or port == 81 then - friendly_message = "check that a HTTP server is not already using " - .."this port"; - elseif port == 5280 then - friendly_message = "check that Prosody or a BOSH connection manager " - .."is not already running"; - else - friendly_message = "this port is in use by another application"; - end - elseif err:match("permission") then - friendly_message = "Prosody does not have sufficient privileges to use this port"; - elseif err == "no ssl context" then - if not config.get("*", "core", "ssl") then - friendly_message = "there is no 'ssl' config under Host \"*\" which is " - .."require for legacy SSL ports"; - else - friendly_message = "initializing SSL support failed, see previous log entries"; - end - end - log("error", "Failed to open server port %d, %s", port, friendly_message); - end - end - end - end - end end function read_version() @@ -297,14 +284,15 @@ function load_secondary_libraries() --- Load and initialise core modules require "util.import" require "util.xmppstream" - require "core.rostermanager" + require "core.stanza_router" require "core.hostmanager" + require "core.portmanager" require "core.modulemanager" require "core.usermanager" + require "core.rostermanager" require "core.sessionmanager" - require "core.stanza_router" package.loaded['core.componentmanager'] = setmetatable({},{__index=function() - log("warn", "componentmanager is deprecated: %s", debug.traceback():match("\n[^\n]*\n[\s\t]*([^\n]*)")); + log("warn", "componentmanager is deprecated: %s", debug.traceback():match("\n[^\n]*\n[ \t]*([^\n]*)")); return function() end end}); @@ -325,15 +313,11 @@ function load_secondary_libraries() if remdebug then remdebug.engine.start() end ]] - require "net.connlisteners"; - require "util.stanza" require "util.jid" end function init_data_store() - local data_path = config.get("*", "core", "data_path") or CFG_DATADIR or "data"; - require "util.datamanager".set_data_path(data_path); require "core.storagemanager"; end @@ -341,20 +325,6 @@ function prepare_to_start() log("info", "Prosody is using the %s backend for connection handling", server.get_backend()); -- Signal to modules that we are ready to start prosody.events.fire_event("server-starting"); - - -- start listening on sockets - if config.get("*", "core", "ports") then - prosody.net_activate_ports(nil, "multiplex", {5222, 5269}); - if config.get("*", "core", "ssl_ports") then - prosody.net_activate_ports("ssl", "multiplex", {5223}, "ssl"); - end - else - prosody.net_activate_ports("c2s", "xmppclient", {5222}); - prosody.net_activate_ports("s2s", "xmppserver", {5269}); - prosody.net_activate_ports("component", "xmppcomponent", {5347}, "tcp"); - prosody.net_activate_ports("legacy_ssl", "xmppclient", {}, "ssl"); - end - prosody.start_time = os.time(); end @@ -401,43 +371,6 @@ end function cleanup() log("info", "Shutdown status: Cleaning up"); prosody.events.fire_event("server-cleanup"); - - -- Ok, we're quitting I know, but we - -- need to do some tidying before we go :) - server.setquitting(false); - - log("info", "Shutdown status: Closing all active sessions"); - for hostname, host in pairs(hosts) do - log("debug", "Shutdown status: Closing client connections for %s", hostname) - if host.sessions then - local reason = { condition = "system-shutdown", text = "Server is shutting down" }; - if prosody.shutdown_reason then - reason.text = reason.text..": "..prosody.shutdown_reason; - end - for username, user in pairs(host.sessions) do - for resource, session in pairs(user.sessions) do - log("debug", "Closing connection for %s@%s/%s", username, hostname, resource); - session:close(reason); - end - end - end - - log("debug", "Shutdown status: Closing outgoing s2s connections from %s", hostname); - if host.s2sout then - for remotehost, session in pairs(host.s2sout) do - if session.close then - session:close("system-shutdown"); - else - log("warn", "Unable to close outgoing s2s session to %s, no session:close()?!", remotehost); - end - end - end - end - - log("info", "Shutdown status: Closing all server connections"); - server.closeall(); - - server.setquitting(true); end -- Are you ready? :) @@ -445,6 +378,7 @@ end -- previous steps to have already been performed read_config(); init_logging(); +sanity_check(); sandbox_require(); set_function_metatable(); load_libraries(); diff --git a/prosody.cfg.lua.dist b/prosody.cfg.lua.dist index 3134acb6..1d11a658 100644 --- a/prosody.cfg.lua.dist +++ b/prosody.cfg.lua.dist @@ -4,7 +4,7 @@ -- website at http://prosody.im/doc/configure -- -- Tip: You can check that the syntax of this file is correct --- when you have finished by running: luac -p prosody.cfg.lua +-- when you have finished by running: prosodyctl check config -- If there are any errors, it will let you know what and where -- they are, otherwise it will keep quiet. -- @@ -24,7 +24,7 @@ admins = { } -- Enable use of libevent for better performance under high load -- For more information see: http://prosody.im/doc/libevent ---use_libevent = true; +--use_libevent = true -- This is the list of modules Prosody will load on startup. -- It looks for mod_modulename.lua in the plugins folder, so make sure that exists too. @@ -41,63 +41,108 @@ modules_enabled = { -- Not essential, but recommended "private"; -- Private XML storage (for room bookmarks, etc.) "vcard"; -- Allow users to set vCards + + -- These are commented by default as they have a performance impact --"privacy"; -- Support privacy lists --"compression"; -- Stream compression -- Nice to have - "legacyauth"; -- Legacy authentication. Only used by some old clients and bots. "version"; -- Replies to server version requests "uptime"; -- Report how long server has been running "time"; -- Let others know the time here on this server "ping"; -- Replies to XMPP pings with pongs "pep"; -- Enables users to publish their mood, activity, playing music and more "register"; -- Allow users to register on this server using a client and change passwords - "adhoc"; -- Support for "ad-hoc commands" that can be executed with an XMPP client -- Admin interfaces "admin_adhoc"; -- Allows administration via an XMPP client that supports ad-hoc commands --"admin_telnet"; -- Opens telnet console interface on localhost port 5582 + + -- HTTP modules + --"bosh"; -- Enable BOSH clients, aka "Jabber over HTTP" + --"http_files"; -- Serve static files from a directory over HTTP -- Other specific functionality --"posix"; -- POSIX functionality, sends server to background, enables syslog, etc. - --"bosh"; -- Enable BOSH clients, aka "Jabber over HTTP" - --"httpserver"; -- Serve static files from a directory over HTTP --"groups"; -- Shared roster support --"announce"; -- Send announcement to all online users --"welcome"; -- Welcome users who register accounts --"watchregistrations"; -- Alert admins of registrations -}; + --"motd"; -- Send a message to users when they log in + --"legacyauth"; -- Legacy authentication. Only used by some old clients and bots. +} --- These modules are auto-loaded, should you --- (for some mad reason) want to disable --- them then uncomment them below +-- These modules are auto-loaded, but should you want +-- to disable them then uncomment them here: modules_disabled = { - -- "presence"; - -- "message"; - -- "iq"; -}; + -- "offline"; -- Store offline messages + -- "c2s"; -- Handle client connections + -- "s2s"; -- Handle server-to-server connections +} -- Disable account creation by default, for security -- For more information see http://prosody.im/doc/creating_accounts -allow_registration = false; +allow_registration = false -- These are the SSL/TLS-related settings. If you don't want -- to use SSL/TLS, you may comment or remove this ssl = { key = "certs/localhost.key"; - certificate = "certs/localhost.cert"; + certificate = "certs/localhost.crt"; } --- Require encryption on client/server connections? ---c2s_require_encryption = false ---s2s_require_encryption = false +-- Force clients to use encrypted connections? This option will +-- prevent clients from authenticating unless they are using encryption. + +c2s_require_encryption = true + +-- Force certificate authentication for server-to-server connections? +-- This provides ideal security, but requires servers you communicate +-- with to support encryption AND present valid, trusted certificates. +-- NOTE: Your version of LuaSec must support certificate verification! +-- For more information see http://prosody.im/doc/s2s#security + +s2s_secure_auth = false + +-- Many servers don't support encryption or have invalid or self-signed +-- certificates. You can list domains here that will not be required to +-- authenticate using certificates. They will be authenticated using DNS. + +--s2s_insecure_domains = { "gmail.com" } + +-- Even if you leave s2s_secure_auth disabled, you can still require valid +-- certificates for some domains by specifying a list here. + +--s2s_secure_domains = { "jabber.org" } + +-- Select the authentication backend to use. The 'internal' providers +-- use Prosody's configured data storage to store the authentication data. +-- To allow Prosody to offer secure authentication mechanisms to clients, the +-- default provider stores passwords in plaintext. If you do not trust your +-- server please see http://prosody.im/doc/modules/mod_auth_internal_hashed +-- for information about using the hashed backend. + +authentication = "internal_plain" + +-- Select the storage backend to use. By default Prosody uses flat files +-- in its configured data directory, but it also supports more backends +-- through modules. An "sql" backend is included by default, but requires +-- additional dependencies. See http://prosody.im/doc/storage for more info. + +--storage = "sql" -- Default is "internal" + +-- For the "sql" backend, you can uncomment *one* of the below to configure: +--sql = { driver = "SQLite3", database = "prosody.sqlite" } -- Default. 'database' is the filename. +--sql = { driver = "MySQL", database = "prosody", username = "prosody", password = "secret", host = "localhost" } +--sql = { driver = "PostgreSQL", database = "prosody", username = "prosody", password = "secret", host = "localhost" } -- Logging configuration -- For advanced logging see http://prosody.im/doc/logging log = { - info = "prosody.log"; -- Change info to debug for verbose logging + info = "prosody.log"; -- Change 'info' to 'debug' for verbose logging error = "prosody.err"; -- "*syslog"; -- Uncomment this for logging to syslog + -- "*console"; -- Log to the console, useful for debugging with daemonize=false } ----------- Virtual hosts ----------- @@ -18,10 +18,22 @@ CFG_DATADIR=os.getenv("PROSODY_DATADIR"); -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- +local function is_relative(path) + local path_sep = package.config:sub(1,1); + return ((path_sep == "/" and path:sub(1,1) ~= "/") + or (path_sep == "\\" and (path:sub(1,1) ~= "/" and path:sub(2,3) ~= ":\\"))) +end + -- Tell Lua where to find our libraries if CFG_SOURCEDIR then - package.path = CFG_SOURCEDIR.."/?.lua;"..package.path; - package.cpath = CFG_SOURCEDIR.."/?.so;"..package.cpath; + local function filter_relative_paths(path) + if is_relative(path) then return ""; end + end + local function sanitise_paths(paths) + return (paths:gsub("[^;]+;?", filter_relative_paths):gsub(";;+", ";")); + end + package.path = sanitise_paths(CFG_SOURCEDIR.."/?.lua;"..package.path); + package.cpath = sanitise_paths(CFG_SOURCEDIR.."/?.so;"..package.cpath); end -- Substitute ~ with path to home directory in data path @@ -32,14 +44,16 @@ if CFG_DATADIR then end -- Global 'prosody' object -prosody = { +local prosody = { hosts = {}; events = require "util.events".new(); platform = "posix"; lock_globals = function () end; unlock_globals = function () end; + installed = CFG_SOURCEDIR ~= nil; + core_post_stanza = function () end; -- TODO: mod_router! }; -local prosody = prosody; +_G.prosody = prosody; local dependencies = require "util.dependencies"; if not dependencies.check_dependencies() then @@ -48,16 +62,17 @@ end config = require "core.configmanager" +local ENV_CONFIG; do local filenames = {}; local filename; if arg[1] == "--config" and arg[2] then table.insert(filenames, arg[2]); - table.remove(arg, 1); table.remove(arg, 1); if CFG_CONFIGDIR then table.insert(filenames, CFG_CONFIGDIR.."/"..arg[2]); end + table.remove(arg, 1); table.remove(arg, 1); else for _, format in ipairs(config.parsers()) do table.insert(filenames, (CFG_CONFIGDIR or ".").."/prosody.cfg."..format); @@ -68,6 +83,7 @@ do local file = io.open(filename); if file then file:close(); + ENV_CONFIG = filename; CFG_CONFIGDIR = filename:match("^(.*)[\\/][^\\/]*$"); break; end @@ -94,20 +110,32 @@ do os.exit(1); end end -local original_logging_config = config.get("*", "core", "log"); -config.set("*", "core", "log", { { levels = { min="info" }, to = "console" } }); +local original_logging_config = config.get("*", "log"); +config.set("*", "log", { { levels = { min="info" }, to = "console" } }); + +local data_path = config.get("*", "data_path") or CFG_DATADIR or "data"; +local custom_plugin_paths = config.get("*", "plugin_paths"); +if custom_plugin_paths then + local path_sep = package.config:sub(3,3); + -- path1;path2;path3;defaultpath... + CFG_PLUGINDIR = table.concat(custom_plugin_paths, path_sep)..path_sep..(CFG_PLUGINDIR or "plugins"); +end +prosody.paths = { source = CFG_SOURCEDIR, config = CFG_CONFIGDIR, + plugins = CFG_PLUGINDIR or "plugins", data = data_path }; + +if prosody.installed then + -- Change working directory to data path. + require "lfs".chdir(data_path); +end require "core.loggingmanager" dependencies.log_warnings(); -local data_path = config.get("*", "core", "data_path") or CFG_DATADIR or "data"; -require "util.datamanager".set_data_path(data_path); - -- Switch away from root and into the prosody user -- local switched_user, current_uid; -local want_pposix_version = "0.3.5"; +local want_pposix_version = "0.3.6"; local ok, pposix = pcall(require, "util.pposix"); if ok and pposix then @@ -115,8 +143,8 @@ if ok and pposix then current_uid = pposix.getuid(); if current_uid == 0 then -- We haz root! - local desired_user = config.get("*", "core", "prosody_user") or "prosody"; - local desired_group = config.get("*", "core", "prosody_group") or desired_user; + local desired_user = config.get("*", "prosody_user") or "prosody"; + local desired_group = config.get("*", "prosody_group") or desired_user; local ok, err = pposix.setgid(desired_group); if ok then ok, err = pposix.initgroups(desired_user); @@ -135,11 +163,14 @@ if ok and pposix then end -- Set our umask to protect data files - pposix.umask(config.get("*", "core", "umask") or "027"); + pposix.umask(config.get("*", "umask") or "027"); + pposix.setenv("HOME", data_path); + pposix.setenv("PROSODY_CONFIG", ENV_CONFIG); else print("Error: Unable to load pposix module. Check that Prosody is installed correctly.") print("For more help send the below error to us through http://prosody.im/discuss"); print(tostring(pposix)) + os.exit(1); end local function test_writeable(filename) @@ -186,6 +217,7 @@ local error_messages = setmetatable({ ["invalid-hostname"] = "The given hostname is invalid"; ["no-password"] = "No password was supplied"; ["no-such-user"] = "The given user does not exist on the server"; + ["no-such-host"] = "The given hostname does not exist in the config"; ["unable-to-save-data"] = "Unable to store, perhaps you don't have permission?"; ["no-pidfile"] = "There is no 'pidfile' option in the configuration file, see http://prosody.im/doc/prosodyctl#pidfile for help"; ["no-posix"] = "The mod_posix module is not enabled in the Prosody config file, see http://prosody.im/doc/prosodyctl for more info"; @@ -199,6 +231,7 @@ local function make_host(hostname) return { type = "local", events = prosody.events, + modules = {}, users = require "core.usermanager".new_null_provider(hostname) }; end @@ -207,104 +240,46 @@ for hostname, config in pairs(config.getconfig()) do hosts[hostname] = make_host(hostname); end -require "core.modulemanager" +local modulemanager = require "core.modulemanager" -require "util.prosodyctl" +local prosodyctl = require "util.prosodyctl" require "socket" ----------------------- -function show_message(msg, ...) - print(msg:format(...)); -end - -function show_warning(msg, ...) - print(msg:format(...)); -end - -function show_usage(usage, desc) - print("Usage: "..arg[0].." "..usage); - if desc then - print(" "..desc); - end -end - -local function getchar(n) - local stty_ret = os.execute("stty raw -echo 2>/dev/null"); - local ok, char; - if stty_ret == 0 then - ok, char = pcall(io.read, n or 1); - os.execute("stty sane"); - else - ok, char = pcall(io.read, "*l"); - if ok then - char = char:sub(1, n or 1); + -- FIXME: Duplicate code waiting for util.startup +function read_version() + -- Try to determine version + local version_file = io.open((CFG_SOURCEDIR or ".").."/prosody.version"); + if version_file then + prosody.version = version_file:read("*a"):gsub("%s*$", ""); + version_file:close(); + if #prosody.version == 12 and prosody.version:match("^[a-f0-9]+$") then + prosody.version = "hg:"..prosody.version; end - end - if ok then - return char; - end -end - -local function getpass() - local stty_ret = os.execute("stty -echo 2>/dev/null"); - if stty_ret ~= 0 then - io.write("\027[08m"); -- ANSI 'hidden' text attribute - end - local ok, pass = pcall(io.read, "*l"); - if stty_ret == 0 then - os.execute("stty sane"); else - io.write("\027[00m"); - end - io.write("\n"); - if ok then - return pass; + prosody.version = "unknown"; end end -function show_yesno(prompt) - io.write(prompt, " "); - local choice = getchar():lower(); - io.write("\n"); - if not choice:match("%a") then - choice = prompt:match("%[.-(%U).-%]$"); - if not choice then return nil; end - end - return (choice == "y"); -end - -local function read_password() - local password; - while true do - io.write("Enter new password: "); - password = getpass(); - if not password then - show_message("No password - cancelled"); - return; - end - io.write("Retype new password: "); - if getpass() ~= password then - if not show_yesno [=[Passwords did not match, try again? [Y/n]]=] then - return; - end - else - break; - end - end - return password; -end +local show_message, show_warning = prosodyctl.show_message, prosodyctl.show_warning; +local show_usage = prosodyctl.show_usage; +local getchar, getpass = prosodyctl.getchar, prosodyctl.getpass; +local show_yesno = prosodyctl.show_yesno; +local show_prompt = prosodyctl.show_prompt; +local read_password = prosodyctl.read_password; -local prosodyctl_timeout = (config.get("*", "core", "prosodyctl_timeout") or 5) * 2; +local prosodyctl_timeout = (config.get("*", "prosodyctl_timeout") or 5) * 2; ----------------------- local commands = {}; local command = arg[1]; function commands.adduser(arg) + local jid_split = require "util.jid".split; if not arg[1] or arg[1] == "--help" then show_usage([[adduser JID]], [[Create the specified user account in Prosody]]); return 1; end - local user, host = arg[1]:match("([^@]+)@(.+)"); + local user, host = jid_split(arg[1]); if not user and host then show_message [[Failed to understand JID, please supply the JID you want to create]] show_usage [[adduser user@host]] @@ -339,11 +314,12 @@ function commands.adduser(arg) end function commands.passwd(arg) + local jid_split = require "util.jid".split; if not arg[1] or arg[1] == "--help" then show_usage([[passwd JID]], [[Set the password for the specified user account in Prosody]]); return 1; end - local user, host = arg[1]:match("([^@]+)@(.+)"); + local user, host = jid_split(arg[1]); if not user and host then show_message [[Failed to understand JID, please supply the JID you want to set the password for]] show_usage [[passwd user@host]] @@ -378,11 +354,12 @@ function commands.passwd(arg) end function commands.deluser(arg) + local jid_split = require "util.jid".split; if not arg[1] or arg[1] == "--help" then show_usage([[deluser JID]], [[Permanently remove the specified user account from Prosody]]); return 1; end - local user, host = arg[1]:match("([^@]+)@(.+)"); + local user, host = jid_split(arg[1]); if not user and host then show_message [[Failed to understand JID, please supply the JID you want to set the password for]] show_usage [[passwd user@host]] @@ -405,7 +382,7 @@ function commands.deluser(arg) return 1; end - local ok, msg = prosodyctl.passwd { user = user, host = host }; + local ok, msg = prosodyctl.deluser { user = user, host = host }; if ok then return 0; end @@ -437,7 +414,7 @@ function commands.start(arg) local ok, ret = prosodyctl.start(); if ok then - if config.get("*", "core", "daemonize") ~= false then + if config.get("*", "daemonize") ~= false then local i=1; while true do local ok, running = prosodyctl.isrunning(); @@ -541,6 +518,81 @@ function commands.restart(arg) return commands.start(arg); end +function commands.about(arg) + read_version(); + if arg[1] == "--help" then + show_usage([[about]], [[Show information about this Prosody installation]]); + return 1; + end + + local array = require "util.array"; + local keys = require "util.iterators".keys; + + print("Prosody "..(prosody.version or "(unknown version)")); + print(""); + print("# Prosody directories"); + print("Data directory: ", CFG_DATADIR or "./"); + print("Plugin directory:", CFG_PLUGINDIR or "./"); + print("Config directory:", CFG_CONFIGDIR or "./"); + print("Source directory:", CFG_SOURCEDIR or "./"); + print(""); + print("# Lua environment"); + print("Lua version: ", _G._VERSION); + print(""); + print("Lua module search paths:"); + for path in package.path:gmatch("[^;]+") do + print(" "..path); + end + print(""); + print("Lua C module search paths:"); + for path in package.cpath:gmatch("[^;]+") do + print(" "..path); + end + print(""); + local luarocks_status = (pcall(require, "luarocks.loader") and "Installed ("..(luarocks.cfg.program_version or "2.x+")..")") + or (pcall(require, "luarocks.require") and "Installed (1.x)") + or "Not installed"; + print("LuaRocks: ", luarocks_status); + print(""); + print("# Lua module versions"); + local module_versions, longest_name = {}, 8; + for name, module in pairs(package.loaded) do + if type(module) == "table" and rawget(module, "_VERSION") + and name ~= "_G" and not name:match("%.") then + if #name > longest_name then + longest_name = #name; + end + module_versions[name] = module._VERSION; + end + end + local sorted_keys = array.collect(keys(module_versions)):sort(); + for _, name in ipairs(array.collect(keys(module_versions)):sort()) do + print(name..":"..string.rep(" ", longest_name-#name), module_versions[name]); + end + print(""); +end + +function commands.reload(arg) + if arg[1] == "--help" then + show_usage([[reload]], [[Reload Prosody's configuration and re-open log files]]); + return 1; + end + + if not prosodyctl.isrunning() then + show_message("Prosody is not running"); + return 1; + end + + local ok, ret = prosodyctl.reload(); + if ok then + + show_message("Prosody log files re-opened and config file reloaded. You may need to reload modules for some changes to take effect."); + return 0; + end + + show_message(error_messages[ret]); + return 1; +end -- ejabberdctl compatibility function commands.register(arg) @@ -594,6 +646,466 @@ function commands.unregister(arg) return 1; end +local openssl; +local lfs; + +local cert_commands = {}; + +local function ask_overwrite(filename) + return lfs.attributes(filename) and not show_yesno("Overwrite "..filename .. "?"); +end + +function cert_commands.config(arg) + if #arg >= 1 and arg[1] ~= "--help" then + local conf_filename = (CFG_DATADIR or "./certs") .. "/" .. arg[1] .. ".cnf"; + if ask_overwrite(conf_filename) then + return nil, conf_filename; + end + local conf = openssl.config.new(); + conf:from_prosody(hosts, config, arg); + show_message("Please provide details to include in the certificate config file."); + show_message("Leave the field empty to use the default value or '.' to exclude the field.") + for i, k in ipairs(openssl._DN_order) do + local v = conf.distinguished_name[k]; + if v then + local nv; + if k == "commonName" then + v = arg[1] + elseif k == "emailAddress" then + v = "xmpp@" .. arg[1]; + elseif k == "countryName" then + local tld = arg[1]:match"%.([a-z]+)$"; + if tld and #tld == 2 and tld ~= "uk" then + v = tld:upper(); + end + end + nv = show_prompt(("%s (%s):"):format(k, nv or v)); + nv = (not nv or nv == "") and v or nv; + if nv:find"[\192-\252][\128-\191]+" then + conf.req.string_mask = "utf8only" + end + conf.distinguished_name[k] = nv ~= "." and nv or nil; + end + end + local conf_file = io.open(conf_filename, "w"); + conf_file:write(conf:serialize()); + conf_file:close(); + print(""); + show_message("Config written to " .. conf_filename); + return nil, conf_filename; + else + show_usage("cert config HOSTNAME [HOSTNAME+]", "Builds a certificate config file covering the supplied hostname(s)") + end +end + +function cert_commands.key(arg) + if #arg >= 1 and arg[1] ~= "--help" then + local key_filename = (CFG_DATADIR or "./certs") .. "/" .. arg[1] .. ".key"; + if ask_overwrite(key_filename) then + return nil, key_filename; + end + os.remove(key_filename); -- This file, if it exists is unlikely to have write permissions + local key_size = tonumber(arg[2] or show_prompt("Choose key size (2048):") or 2048); + local old_umask = pposix.umask("0377"); + if openssl.genrsa{out=key_filename, key_size} then + os.execute(("chmod 400 '%s'"):format(key_filename)); + show_message("Key written to ".. key_filename); + pposix.umask(old_umask); + return nil, key_filename; + end + show_message("There was a problem, see OpenSSL output"); + else + show_usage("cert key HOSTNAME <bits>", "Generates a RSA key named HOSTNAME.key\n " + .."Prompts for a key size if none given") + end +end + +function cert_commands.request(arg) + if #arg >= 1 and arg[1] ~= "--help" then + local req_filename = (CFG_DATADIR or "./certs") .. "/" .. arg[1] .. ".req"; + if ask_overwrite(req_filename) then + return nil, req_filename; + end + local _, key_filename = cert_commands.key({arg[1]}); + local _, conf_filename = cert_commands.config(arg); + if openssl.req{new=true, key=key_filename, utf8=true, config=conf_filename, out=req_filename} then + show_message("Certificate request written to ".. req_filename); + else + show_message("There was a problem, see OpenSSL output"); + end + else + show_usage("cert request HOSTNAME [HOSTNAME+]", "Generates a certificate request for the supplied hostname(s)") + end +end + +function cert_commands.generate(arg) + if #arg >= 1 and arg[1] ~= "--help" then + local cert_filename = (CFG_DATADIR or "./certs") .. "/" .. arg[1] .. ".crt"; + if ask_overwrite(cert_filename) then + return nil, cert_filename; + end + local _, key_filename = cert_commands.key({arg[1]}); + local _, conf_filename = cert_commands.config(arg); + local ret; + if key_filename and conf_filename and cert_filename + and openssl.req{new=true, x509=true, nodes=true, key=key_filename, + days=365, sha1=true, utf8=true, config=conf_filename, out=cert_filename} then + show_message("Certificate written to ".. cert_filename); + else + show_message("There was a problem, see OpenSSL output"); + end + else + show_usage("cert generate HOSTNAME [HOSTNAME+]", "Generates a self-signed certificate for the current hostname(s)") + end +end + +function commands.cert(arg) + if #arg >= 1 and arg[1] ~= "--help" then + openssl = require "util.openssl"; + lfs = require "lfs"; + local subcmd = table.remove(arg, 1); + if type(cert_commands[subcmd]) == "function" then + if not arg[1] then + show_message"You need to supply at least one hostname" + arg = { "--help" }; + end + if arg[1] ~= "--help" and not hosts[arg[1]] then + show_message(error_messages["no-such-host"]); + return + end + return cert_commands[subcmd](arg); + end + end + show_usage("cert config|request|generate|key", "Helpers for generating X.509 certificates and keys.") +end + +function commands.check(arg) + if arg[1] == "--help" then + show_usage([[check]], [[Perform basic checks on your Prosody installation]]); + return 1; + end + local what = table.remove(arg, 1); + local array, set = require "util.array", require "util.set"; + local it = require "util.iterators"; + local ok = true; + if not what or what == "config" then + print("Checking config..."); + local known_global_options = set.new({ + "pidfile", "log", "plugin_paths", "prosody_user", "prosody_group", "daemonize", + "umask", "prosodyctl_timeout", "use_ipv6", "use_libevent", "network_settings" + }); + local config = config.getconfig(); + -- Check that we have any global options (caused by putting a host at the top) + if it.count(it.filter("log", pairs(config["*"]))) == 0 then + ok = false; + print(""); + print(" No global options defined. Perhaps you have put a host definition at the top") + print(" of the config file? They should be at the bottom, see http://prosody.im/doc/configure#overview"); + end + -- Check for global options under hosts + local global_options = set.new(it.to_array(it.keys(config["*"]))); + for host, options in it.filter("*", pairs(config)) do + local host_options = set.new(it.to_array(it.keys(options))); + local misplaced_options = set.intersection(host_options, known_global_options); + for name in pairs(options) do + if name:match("^interfaces?") + or name:match("_ports?$") or name:match("_interfaces?$") + or name:match("_ssl$") then + misplaced_options:add(name); + end + end + if not misplaced_options:empty() then + ok = false; + print(""); + local n = it.count(misplaced_options); + print(" You have "..n.." option"..(n>1 and "s " or " ").."set under "..host.." that should be"); + print(" in the global section of the config file, above any VirtualHost or Component definitions,") + print(" see http://prosody.im/doc/configure#overview for more information.") + print(""); + print(" You need to move the following option"..(n>1 and "s" or "")..": "..table.concat(it.to_array(misplaced_options), ", ")); + end + local subdomain = host:match("^[^.]+"); + if not(host_options:contains("component_module")) and (subdomain == "jabber" or subdomain == "xmpp" + or subdomain == "chat" or subdomain == "im") then + print(""); + print(" Suggestion: If "..host.. " is a new host with no real users yet, consider renaming it now to"); + print(" "..host:gsub("^[^.]+%.", "")..". You can use SRV records to redirect XMPP clients and servers to "..host.."."); + print(" For more information see: http://prosody.im/doc/dns"); + end + end + + print("Done.\n"); + end + if not what or what == "dns" then + local dns = require "net.dns"; + local idna = require "util.encodings".idna; + local ip = require "util.ip"; + local c2s_ports = set.new(config.get("*", "c2s_ports") or {5222}); + local s2s_ports = set.new(config.get("*", "s2s_ports") or {5269}); + + local c2s_srv_required, s2s_srv_required; + if not c2s_ports:contains(5222) then + c2s_srv_required = true; + end + if not s2s_ports:contains(5269) then + s2s_srv_required = true; + end + + local problem_hosts = set.new(); + + local external_addresses, internal_addresses = set.new(), set.new(); + + local fqdn = socket.dns.tohostname(socket.dns.gethostname()); + if fqdn then + local res = dns.lookup(idna.to_ascii(fqdn), "A"); + if res then + for _, record in ipairs(res) do + external_addresses:add(record.a); + end + end + local res = dns.lookup(idna.to_ascii(fqdn), "AAAA"); + if res then + for _, record in ipairs(res) do + external_addresses:add(record.aaaa); + end + end + end + + local local_addresses = require"util.net".local_addresses() or {}; + + for addr in it.values(local_addresses) do + if not ip.new_ip(addr).private then + external_addresses:add(addr); + else + internal_addresses:add(addr); + end + end + + if external_addresses:empty() then + print(""); + print(" Failed to determine the external addresses of this server. Checks may be inaccurate."); + c2s_srv_required, s2s_srv_required = true, true; + end + + local v6_supported = not not socket.tcp6; + + for host, host_options in it.filter("*", pairs(config.getconfig())) do + local all_targets_ok, some_targets_ok = true, false; + + local is_component = not not host_options.component_module; + print("Checking DNS for "..(is_component and "component" or "host").." "..host.."..."); + local target_hosts = set.new(); + if not is_component then + local res = dns.lookup("_xmpp-client._tcp."..idna.to_ascii(host)..".", "SRV"); + if res then + for _, record in ipairs(res) do + target_hosts:add(record.srv.target); + if not c2s_ports:contains(record.srv.port) then + print(" SRV target "..record.srv.target.." contains unknown client port: "..record.srv.port); + end + end + else + if c2s_srv_required then + print(" No _xmpp-client SRV record found for "..host..", but it looks like you need one."); + all_targst_ok = false; + else + target_hosts:add(host); + end + end + end + local res = dns.lookup("_xmpp-server._tcp."..idna.to_ascii(host)..".", "SRV"); + if res then + for _, record in ipairs(res) do + target_hosts:add(record.srv.target); + if not s2s_ports:contains(record.srv.port) then + print(" SRV target "..record.srv.target.." contains unknown server port: "..record.srv.port); + end + end + else + if s2s_srv_required then + print(" No _xmpp-server SRV record found for "..host..", but it looks like you need one."); + all_targets_ok = false; + else + target_hosts:add(host); + end + end + if target_hosts:empty() then + target_hosts:add(host); + end + + if target_hosts:contains("localhost") then + print(" Target 'localhost' cannot be accessed from other servers"); + target_hosts:remove("localhost"); + end + + local modules = set.new(it.to_array(it.values(host_options.modules_enabled))) + + set.new(it.to_array(it.values(config.get("*", "modules_enabled")))) + + set.new({ config.get(host, "component_module") }); + + if modules:contains("proxy65") then + local proxy65_target = config.get(host, "proxy65_address") or host; + local A, AAAA = dns.lookup(idna.to_ascii(proxy65_target), "A"), dns.lookup(idna.to_ascii(proxy65_target), "AAAA"); + local prob = {}; + if not A then + table.insert(prob, "A"); + end + if v6_supported and not AAAA then + table.insert(prob, "AAAA"); + end + if #prob > 0 then + print(" File transfer proxy "..proxy65_target.." has no "..table.concat(prob, "/").." record. Create one or set 'proxy65_address' to the correct host/IP."); + end + end + + for host in target_hosts do + local host_ok_v4, host_ok_v6; + local res = dns.lookup(idna.to_ascii(host), "A"); + if res then + for _, record in ipairs(res) do + if external_addresses:contains(record.a) then + some_targets_ok = true; + host_ok_v4 = true; + elseif internal_addresses:contains(record.a) then + host_ok_v4 = true; + some_targets_ok = true; + print(" "..host.." A record points to internal address, external connections might fail"); + else + print(" "..host.." A record points to unknown address "..record.a); + all_targets_ok = false; + end + end + end + local res = dns.lookup(idna.to_ascii(host), "AAAA"); + if res then + for _, record in ipairs(res) do + if external_addresses:contains(record.aaaa) then + some_targets_ok = true; + host_ok_v6 = true; + elseif internal_addresses:contains(record.aaaa) then + host_ok_v6 = true; + some_targets_ok = true; + print(" "..host.." AAAA record points to internal address, external connections might fail"); + else + print(" "..host.." AAAA record points to unknown address "..record.aaaa); + all_targets_ok = false; + end + end + end + + local bad_protos = {} + if not host_ok_v4 then + table.insert(bad_protos, "IPv4"); + end + if not host_ok_v6 then + table.insert(bad_protos, "IPv6"); + end + if #bad_protos > 0 then + print(" Host "..host.." does not seem to resolve to this server ("..table.concat(bad_protos, "/")..")"); + end + if host_ok_v6 and not v6_supported then + print(" Host "..host.." has AAAA records, but your version of LuaSocket does not support IPv6."); + print(" Please see http://prosody.im/doc/ipv6 for more information."); + end + end + if not all_targets_ok then + print(" "..(some_targets_ok and "Only some" or "No").." targets for "..host.." appear to resolve to this server."); + if is_component then + print(" DNS records are necessary if you want users on other servers to access this component."); + end + problem_hosts:add(host); + end + print(""); + end + if not problem_hosts:empty() then + print(""); + print("For more information about DNS configuration please see http://prosody.im/doc/dns"); + print(""); + ok = false; + end + end + if not what or what == "certs" then + local cert_ok; + print"Checking certificates..." + local x509_verify_identity = require"util.x509".verify_identity; + local ssl = dependencies.softreq"ssl"; + -- local datetime_parse = require"util.datetime".parse_x509; + local load_cert = ssl and ssl.x509 and ssl.x509.load; + -- or ssl.cert_from_pem + if not ssl then + print("LuaSec not available, can't perform certificate checks") + if what == "certs" then cert_ok = false end + elseif not load_cert then + print("This version of LuaSec (" .. ssl._VERSION .. ") does not support certificate checking"); + cert_ok = false + else + for host in pairs(hosts) do + if host ~= "*" then -- Should check global certs too. + print("Checking certificate for "..host); + -- First, let's find out what certificate this host uses. + local ssl_config = config.rawget(host, "ssl"); + if not ssl_config then + local base_host = host:match("%.(.*)"); + ssl_config = config.get(base_host, "ssl"); + end + if not ssl_config then + print(" No 'ssl' option defined for "..host) + cert_ok = false + elseif not ssl_config.certificate then + print(" No 'certificate' set in ssl option for "..host) + cert_ok = false + elseif not ssl_config.key then + print(" No 'key' set in ssl option for "..host) + cert_ok = false + else + local key, err = io.open(ssl_config.key); -- Permissions check only + if not key then + print(" Could not open "..ssl_config.key..": "..err); + cert_ok = false + else + key:close(); + end + local cert_fh, err = io.open(ssl_config.certificate); -- Load the file. + if not cert_fh then + print(" Could not open "..ssl_config.certificate..": "..err); + cert_ok = false + else + print(" Certificate: "..ssl_config.certificate) + local cert = load_cert(cert_fh:read"*a"); cert_fh = cert_fh:close(); + if not cert:validat(os.time()) then + print(" Certificate has expired.") + cert_ok = false + end + if config.get(host, "component_module") == nil + and not x509_verify_identity(host, "_xmpp-client", cert) then + print(" Not vaild for client connections to "..host..".") + cert_ok = false + end + if (not (config.get(name, "anonymous_login") + or config.get(name, "authentication") == "anonymous")) + and not x509_verify_identity(host, "_xmpp-client", cert) then + print(" Not vaild for server-to-server connections to "..host..".") + cert_ok = false + end + end + end + end + end + if cert_ok == false then + print("") + print("For more information about certificates please see http://prosody.im/doc/certificates"); + ok = false + end + end + print("") + end + if not ok then + print("Problems found, see above."); + else + print("All checks passed, congratulations!"); + end + return ok and 0 or 2; +end + --------------------- if command and command:match("^mod_") then -- Is a command in a module @@ -644,7 +1156,7 @@ if not commands[command] then -- Show help for all commands print("Where COMMAND may be one of:\n"); local hidden_commands = require "util.set".new{ "register", "unregister", "addplugin" }; - local commands_order = { "adduser", "passwd", "deluser", "start", "stop", "restart" }; + local commands_order = { "adduser", "passwd", "deluser", "start", "stop", "restart", "reload", "about" }; local done = {}; diff --git a/tests/modulemanager_option_conversion.lua b/tests/modulemanager_option_conversion.lua index 7dceeaed..100dbe83 100644 --- a/tests/modulemanager_option_conversion.lua +++ b/tests/modulemanager_option_conversion.lua @@ -18,7 +18,7 @@ function test_value(value, returns) assert(module:get_option_number("opt") == returns.number, "number doesn't match"); assert(module:get_option_string("opt") == returns.string, "string doesn't match"); assert(module:get_option_boolean("opt") == returns.boolean, "boolean doesn't match"); - + if type(returns.array) == "table" then local target_array, returned_array = returns.array, module:get_option_array("opt"); assert(#target_array == #returned_array, "array length doesn't match"); @@ -28,7 +28,7 @@ function test_value(value, returns) else assert(module:get_option_array("opt") == returns.array, "array is returned (not nil)"); end - + if type(returns.set) == "table" then local target_items, returned_items = set.new(returns.set), module:get_option_set("opt"); assert(target_items == returned_items, "set doesn't match"); diff --git a/tests/test.lua b/tests/test.lua index ae5b24f0..f7475a80 100644 --- a/tests/test.lua +++ b/tests/test.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -9,15 +9,18 @@ function run_all_tests() + package.loaded["net.connlisteners"] = { get = function () return {} end }; dotest "util.jid" dotest "util.multitable" - dotest "core.modulemanager" + dotest "util.rfc6724" + dotest "util.http" dotest "core.stanza_router" dotest "core.s2smanager" dotest "core.configmanager" + dotest "util.ip" dotest "util.stanza" dotest "util.sasl.scram" - + dosingletest("test_sasl.lua", "latin1toutf8"); end @@ -84,12 +87,12 @@ function dosingletest(testname, fname) print("WARNING: ", "Failed to initialise tests for "..testname, err); return; end - + if type(tests[fname]) ~= "function" then error(testname.." has no test '"..fname.."'", 0); end - - + + local line_hook, line_info = new_line_coverage_monitor(testname); debug.sethook(line_hook, "l") local success, ret = pcall(tests[fname]); @@ -131,17 +134,23 @@ function dotest(unitname) print("WARNING: ", "Failed to load module: "..unitname, err); return; end - + local oldmodule, old_M = _fakeG.module, _fakeG._M; - _fakeG.module = function () _M = _G end + _fakeG.module = function () _M = unit end setfenv(chunk, unit); - local success, err = pcall(chunk); + local success, ret = pcall(chunk); _fakeG.module, _fakeG._M = oldmodule, old_M; if not success then print("WARNING: ", "Failed to initialise module: "..unitname, err); return; end - + + if type(ret) == "table" then + for k,v in pairs(ret) do + unit[k] = v; + end + end + for name, f in pairs(unit) do local test = rawget(tests, name); if type(f) ~= "function" then @@ -188,11 +197,11 @@ end function new_line_coverage_monitor(file) local lines_hit, funcs_hit = {}, {}; local total_lines, covered_lines = 0, 0; - + for line in io.lines(file) do total_lines = total_lines + 1; end - + return function (event, line) -- Line hook if not lines_hit[line] then local info = debug.getinfo(2, "fSL") diff --git a/tests/test_core_configmanager.lua b/tests/test_core_configmanager.lua index 132dfc74..5bd469c6 100644 --- a/tests/test_core_configmanager.lua +++ b/tests/test_core_configmanager.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -9,27 +9,23 @@ function get(get, config) - config.set("example.com", "test", "testkey", 123); - assert_equal(get("example.com", "test", "testkey"), 123, "Retrieving a set key"); - - config.set("*", "test", "testkey1", 321); - assert_equal(get("*", "test", "testkey1"), 321, "Retrieving a set global key"); - assert_equal(get("example.com", "test", "testkey1"), 321, "Retrieving a set key of undefined host, of which only a globally set one exists"); - - config.set("example.com", "test", ""); -- Creates example.com host in config - assert_equal(get("example.com", "test", "testkey1"), 321, "Retrieving a set key, of which only a globally set one exists"); - + config.set("example.com", "testkey", 123); + assert_equal(get("example.com", "testkey"), 123, "Retrieving a set key"); + + config.set("*", "testkey1", 321); + assert_equal(get("*", "testkey1"), 321, "Retrieving a set global key"); + assert_equal(get("example.com", "testkey1"), 321, "Retrieving a set key of undefined host, of which only a globally set one exists"); + + config.set("example.com", ""); -- Creates example.com host in config + assert_equal(get("example.com", "testkey1"), 321, "Retrieving a set key, of which only a globally set one exists"); + assert_equal(get(), nil, "No parameters to get()"); assert_equal(get("undefined host"), nil, "Getting for undefined host"); - assert_equal(get("undefined host", "undefined section"), nil, "Getting for undefined host & section"); - assert_equal(get("undefined host", "undefined section", "undefined key"), nil, "Getting for undefined host & section & key"); - - assert_equal(get("example.com", "undefined section", "testkey"), nil, "Defined host, undefined section"); + assert_equal(get("undefined host", "undefined key"), nil, "Getting for undefined host & key"); end function set(set, u) - assert_equal(set("*"), false, "Set with no section/key"); - assert_equal(set("*", "set_test"), false, "Set with no key"); + assert_equal(set("*"), false, "Set with no key"); assert_equal(set("*", "set_test", "testkey"), true, "Setting a nil global value"); assert_equal(set("*", "set_test", "testkey", 123), true, "Setting a global value"); diff --git a/tests/test_core_modulemanager.lua b/tests/test_core_modulemanager.lua deleted file mode 100644 index 9498875a..00000000 --- a/tests/test_core_modulemanager.lua +++ /dev/null @@ -1,48 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - -local config = require "core.configmanager"; -local helpers = require "util.helpers"; -local set = require "util.set"; - -function load_modules_for_host(load_modules_for_host, mm) - local test_num = 0; - local function test_load(global_modules_enabled, global_modules_disabled, host_modules_enabled, host_modules_disabled, expected_modules) - test_num = test_num + 1; - -- Prepare - hosts = { ["example.com"] = {} }; - config.set("*", "core", "modules_enabled", global_modules_enabled); - config.set("*", "core", "modules_disabled", global_modules_disabled); - config.set("example.com", "core", "modules_enabled", host_modules_enabled); - config.set("example.com", "core", "modules_disabled", host_modules_disabled); - - expected_modules = set.new(expected_modules); - expected_modules:add_list(helpers.get_upvalue(load_modules_for_host, "autoload_modules")); - - local loaded_modules = set.new(); - function mm.load(host, module) - assert_equal(host, "example.com", test_num..": Host isn't example.com but "..tostring(host)); - assert_equal(expected_modules:contains(module), true, test_num..": Loading unexpected module '"..tostring(module).."'"); - loaded_modules:add(module); - end - load_modules_for_host("example.com"); - assert_equal((expected_modules - loaded_modules):empty(), true, test_num..": Not all modules loaded: "..tostring(expected_modules - loaded_modules)); - end - - test_load({ "one", "two", "three" }, nil, nil, nil, { "one", "two", "three" }); - test_load({ "one", "two", "three" }, {}, nil, nil, { "one", "two", "three" }); - test_load({ "one", "two", "three" }, { "two" }, nil, nil, { "one", "three" }); - test_load({ "one", "two", "three" }, { "three" }, nil, nil, { "one", "two" }); - test_load({ "one", "two", "three" }, nil, nil, { "three" }, { "one", "two" }); - test_load({ "one", "two", "three" }, nil, { "three" }, { "three" }, { "one", "two", "three" }); - - test_load({ "one", "two" }, nil, { "three" }, nil, { "one", "two", "three" }); - test_load({ "one", "two", "three" }, nil, { "three" }, nil, { "one", "two", "three" }); - test_load({ "one", "two", "three" }, { "three" }, { "three" }, nil, { "one", "two", "three" }); - test_load({ "one", "two" }, { "three" }, { "three" }, nil, { "one", "two", "three" }); -end diff --git a/tests/test_core_s2smanager.lua b/tests/test_core_s2smanager.lua index b49c7da6..d2dbf830 100644 --- a/tests/test_core_s2smanager.lua +++ b/tests/test_core_s2smanager.lua @@ -1,11 +1,14 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- +env = { + prosody = { events = require "util.events".new() }; +}; function compare_srv_priorities(csp) local r1 = { priority = 10, weight = 0 } @@ -13,7 +16,7 @@ function compare_srv_priorities(csp) local r3 = { priority = 1000, weight = 2 } local r4 = { priority = 1000, weight = 2 } local r5 = { priority = 1000, weight = 5 } - + assert_equal(csp(r1, r1), false); assert_equal(csp(r1, r2), true); assert_equal(csp(r1, r3), true); diff --git a/tests/test_core_stanza_router.lua b/tests/test_core_stanza_router.lua index 0a93694f..ca6b78fc 100644 --- a/tests/test_core_stanza_router.lua +++ b/tests/test_core_stanza_router.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -14,7 +14,7 @@ function core_process_stanza(core_process_stanza, u) local s2sin_session = { from_host = "remotehost", to_host = "localhost", type = "s2sin", hosts = { ["remotehost"] = { authed = true } } } local local_host_session = { host = "localhost", type = "local", s2sout = { ["remotehost"] = s2sout_session } } local local_user_session = { username = "user", host = "localhost", resource = "resource", full_jid = "user@localhost/resource", type = "c2s" } - + _G.prosody.hosts["localhost"] = local_host_session; _G.prosody.full_sessions["user@localhost/resource"] = local_user_session; _G.prosody.bare_sessions["user@localhost"] = { sessions = { resource = local_user_session } }; @@ -23,15 +23,15 @@ function core_process_stanza(core_process_stanza, u) local function test_message_full_jid() local env = testlib_new_env(); local msg = stanza.stanza("message", { to = "user@localhost/resource", type = "chat" }):tag("body"):text("Hello world"); - + local target_routed; - + function env.core_post_stanza(p_origin, p_stanza) assert_equal(p_origin, local_user_session, "origin of routed stanza is not correct"); assert_equal(p_stanza, msg, "routed stanza is not correct one: "..p_stanza:pretty_print()); target_routed = true; end - + env.hosts = hosts; env.prosody = { hosts = hosts }; setfenv(core_process_stanza, env); @@ -42,9 +42,9 @@ function core_process_stanza(core_process_stanza, u) local function test_message_bare_jid() local env = testlib_new_env(); local msg = stanza.stanza("message", { to = "user@localhost", type = "chat" }):tag("body"):text("Hello world"); - + local target_routed; - + function env.core_post_stanza(p_origin, p_stanza) assert_equal(p_origin, local_user_session, "origin of routed stanza is not correct"); assert_equal(p_stanza, msg, "routed stanza is not correct one: "..p_stanza:pretty_print()); @@ -60,9 +60,9 @@ function core_process_stanza(core_process_stanza, u) local function test_message_no_to() local env = testlib_new_env(); local msg = stanza.stanza("message", { type = "chat" }):tag("body"):text("Hello world"); - + local target_handled; - + function env.core_post_stanza(p_origin, p_stanza) assert_equal(p_origin, local_user_session, "origin of handled stanza is not correct"); assert_equal(p_stanza, msg, "handled stanza is not correct one: "..p_stanza:pretty_print()); @@ -78,9 +78,9 @@ function core_process_stanza(core_process_stanza, u) local function test_message_to_remote_bare() local env = testlib_new_env(); local msg = stanza.stanza("message", { to = "user@remotehost", type = "chat" }):tag("body"):text("Hello world"); - + local target_routed; - + function env.core_route_stanza(p_origin, p_stanza) assert_equal(p_origin, local_user_session, "origin of handled stanza is not correct"); assert_equal(p_stanza, msg, "handled stanza is not correct one: "..p_stanza:pretty_print()); @@ -88,7 +88,7 @@ function core_process_stanza(core_process_stanza, u) end function env.core_post_stanza(...) env.core_route_stanza(...); end - + env.hosts = hosts; setfenv(core_process_stanza, env); assert_equal(core_process_stanza(local_user_session, msg), nil, "core_process_stanza returned incorrect value"); @@ -98,9 +98,9 @@ function core_process_stanza(core_process_stanza, u) local function test_message_to_remote_server() local env = testlib_new_env(); local msg = stanza.stanza("message", { to = "remotehost", type = "chat" }):tag("body"):text("Hello world"); - + local target_routed; - + function env.core_route_stanza(p_origin, p_stanza) assert_equal(p_origin, local_user_session, "origin of handled stanza is not correct"); assert_equal(p_stanza, msg, "handled stanza is not correct one: "..p_stanza:pretty_print()); @@ -123,9 +123,9 @@ function core_process_stanza(core_process_stanza, u) local function test_iq_to_remote_server() local env = testlib_new_env(); local msg = stanza.stanza("iq", { to = "remotehost", type = "get", id = "id" }):tag("body"):text("Hello world"); - + local target_routed; - + function env.core_route_stanza(p_origin, p_stanza) assert_equal(p_origin, local_user_session, "origin of handled stanza is not correct"); assert_equal(p_stanza, msg, "handled stanza is not correct one: "..p_stanza:pretty_print()); @@ -145,9 +145,9 @@ function core_process_stanza(core_process_stanza, u) local function test_iq_error_to_local_user() local env = testlib_new_env(); local msg = stanza.stanza("iq", { to = "user@localhost/resource", from = "user@remotehost", type = "error", id = "id" }):tag("error", { type = 'cancel' }):tag("item-not-found", { xmlns='urn:ietf:params:xml:ns:xmpp-stanzas' }); - + local target_routed; - + function env.core_route_stanza(p_origin, p_stanza) assert_equal(p_origin, s2sin_session, "origin of handled stanza is not correct"); assert_equal(p_stanza, msg, "handled stanza is not correct one: "..p_stanza:pretty_print()); @@ -167,9 +167,9 @@ function core_process_stanza(core_process_stanza, u) local function test_iq_to_local_bare() local env = testlib_new_env(); local msg = stanza.stanza("iq", { to = "user@localhost", from = "user@localhost", type = "get", id = "id" }):tag("ping", { xmlns = "urn:xmpp:ping:0" }); - + local target_handled; - + function env.core_post_stanza(p_origin, p_stanza) assert_equal(p_origin, local_user_session, "origin of handled stanza is not correct"); assert_equal(p_stanza, msg, "handled stanza is not correct one: "..p_stanza:pretty_print()); @@ -209,11 +209,11 @@ function core_route_stanza(core_route_stanza) local msg2 = stanza.stanza("iq", { to = "user@localhost/foo", from = "user@localhost", type = "error" }):tag("ping", { xmlns = "urn:xmpp:ping:0" }); --package.loaded["core.usermanager"] = { user_exists = function (user, host) print("RAR!") return true or user == "user" and host == "localhost" and true; end }; local target_handled, target_replied; - + function env.core_post_stanza(p_origin, p_stanza) target_handled = true; end - + function local_user_session.send(data) --print("Replying with: ", tostring(data)); --print(debug.traceback()) diff --git a/tests/test_sasl.lua b/tests/test_sasl.lua index 8bd1a2ff..dd63c5a0 100644 --- a/tests/test_sasl.lua +++ b/tests/test_sasl.lua @@ -1,18 +1,11 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- - ---- WARNING! --- --- This file contains a mix of encodings below. --- Many editors will unquestioningly convert these for you. --- Please be careful :( (I recommend Scite) ---------------------------------- - local gmatch = string.gmatch; local t_concat, t_insert = table.concat, table.insert; local to_byte, to_char = string.byte, string.char; @@ -37,9 +30,9 @@ function latin1toutf8() local function assert_utf8(latin, utf8) assert_equal(_latin1toutf8(latin), utf8, "Incorrect UTF8 from Latin1: "..tostring(latin)); end - + assert_utf8("", "") assert_utf8("test", "test") assert_utf8(nil, nil) - assert_utf8("foobar.råkat.se", "foobar.rÃ¥kat.se") + assert_utf8("foobar.r\229kat.se", "foobar.r\195\165kat.se") end diff --git a/tests/test_util_http.lua b/tests/test_util_http.lua new file mode 100644 index 00000000..a195df6b --- /dev/null +++ b/tests/test_util_http.lua @@ -0,0 +1,37 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +function urlencode(urlencode) + assert_equal(urlencode("helloworld123"), "helloworld123", "Normal characters not escaped"); + assert_equal(urlencode("hello world"), "hello%20world", "Spaces escaped"); + assert_equal(urlencode("This & that = something"), "This%20%26%20that%20%3d%20something", "Important URL chars escaped"); +end + +function urldecode(urldecode) + assert_equal("helloworld123", urldecode("helloworld123"), "Normal characters not escaped"); + assert_equal("hello world", urldecode("hello%20world"), "Spaces escaped"); + assert_equal("This & that = something", urldecode("This%20%26%20that%20%3d%20something"), "Important URL chars escaped"); + assert_equal("This & that = something", urldecode("This%20%26%20that%20%3D%20something"), "Important URL chars escaped"); +end + +function formencode(formencode) + assert_equal(formencode({ { name = "one", value = "1"}, { name = "two", value = "2" } }), "one=1&two=2", "Form encoded"); + assert_equal(formencode({ { name = "one two", value = "1"}, { name = "two one&", value = "2" } }), "one+two=1&two+one%26=2", "Form encoded"); +end + +function formdecode(formdecode) + local t = formdecode("one=1&two=2"); + assert_table(t[1]); + assert_equal(t[1].name, "one"); assert_equal(t[1].value, "1"); + assert_table(t[2]); + assert_equal(t[2].name, "two"); assert_equal(t[2].value, "2"); + + local t = formdecode("one+two=1&two+one%26=2"); + assert_equal(t[1].name, "one two"); assert_equal(t[1].value, "1"); + assert_equal(t[2].name, "two one&"); assert_equal(t[2].value, "2"); +end diff --git a/tests/test_util_ip.lua b/tests/test_util_ip.lua new file mode 100644 index 00000000..0ded1123 --- /dev/null +++ b/tests/test_util_ip.lua @@ -0,0 +1,89 @@ + +function match(match, _M) + local _ = _M.new_ip; + local ip = _"10.20.30.40"; + assert_equal(match(ip, _"10.0.0.0", 8), true); + assert_equal(match(ip, _"10.0.0.0", 16), false); + assert_equal(match(ip, _"10.0.0.0", 24), false); + assert_equal(match(ip, _"10.0.0.0", 32), false); + + assert_equal(match(ip, _"10.20.0.0", 8), true); + assert_equal(match(ip, _"10.20.0.0", 16), true); + assert_equal(match(ip, _"10.20.0.0", 24), false); + assert_equal(match(ip, _"10.20.0.0", 32), false); + + assert_equal(match(ip, _"0.0.0.0", 32), false); + assert_equal(match(ip, _"0.0.0.0", 0), true); + assert_equal(match(ip, _"0.0.0.0"), false); + + assert_equal(match(ip, _"10.0.0.0", 255), false, "excessive number of bits"); + assert_equal(match(ip, _"10.0.0.0", -8), true, "negative number of bits"); + assert_equal(match(ip, _"10.0.0.0", -32), true, "negative number of bits"); + assert_equal(match(ip, _"10.0.0.0", 0), true, "zero bits"); + assert_equal(match(ip, _"10.0.0.0"), false, "no specified number of bits (differing ip)"); + assert_equal(match(ip, _"10.20.30.40"), true, "no specified number of bits (same ip)"); + + assert_equal(match(_"127.0.0.1", _"127.0.0.1"), true, "simple ip"); + + assert_equal(match(_"8.8.8.8", _"8.8.0.0", 16), true); + assert_equal(match(_"8.8.4.4", _"8.8.0.0", 16), true); +end + +function parse_cidr(parse_cidr, _M) + local new_ip = _M.new_ip; + + assert_equal(new_ip"0.0.0.0", new_ip"0.0.0.0") + + local function assert_cidr(cidr, ip, bits) + local parsed_ip, parsed_bits = parse_cidr(cidr); + assert_equal(new_ip(ip), parsed_ip, cidr.." parsed ip is "..ip); + assert_equal(bits, parsed_bits, cidr.." parsed bits is "..tostring(bits)); + end + assert_cidr("0.0.0.0", "0.0.0.0", nil); + assert_cidr("127.0.0.1", "127.0.0.1", nil); + assert_cidr("127.0.0.1/0", "127.0.0.1", 0); + assert_cidr("127.0.0.1/8", "127.0.0.1", 8); + assert_cidr("127.0.0.1/32", "127.0.0.1", 32); + assert_cidr("127.0.0.1/256", "127.0.0.1", 256); + assert_cidr("::/48", "::", 48); +end + +function new_ip(new_ip) + local v4, v6 = "IPv4", "IPv6"; + local function assert_proto(s, proto) + local ip = new_ip(s); + if proto then + assert_equal(ip and ip.proto, proto, "protocol is correct for "..("%q"):format(s)); + else + assert_equal(ip, nil, "address is invalid"); + end + end + assert_proto("127.0.0.1", v4); + assert_proto("::1", v6); + assert_proto("", nil); + assert_proto("abc", nil); + assert_proto(" ", nil); +end + +function commonPrefixLength(cpl, _M) + local new_ip = _M.new_ip; + local function assert_cpl6(a, b, len, v4) + local ipa, ipb = new_ip(a), new_ip(b); + if v4 then len = len+96; end + assert_equal(cpl(ipa, ipb), len, "common prefix length of "..a.." and "..b.." is "..len); + assert_equal(cpl(ipb, ipa), len, "common prefix length of "..b.." and "..a.." is "..len); + end + local function assert_cpl4(a, b, len) + return assert_cpl6(a, b, len, "IPv4"); + end + assert_cpl4("0.0.0.0", "0.0.0.0", 32); + assert_cpl4("255.255.255.255", "0.0.0.0", 0); + assert_cpl4("255.255.255.255", "255.255.0.0", 16); + assert_cpl4("255.255.255.255", "255.255.255.255", 32); + assert_cpl4("255.255.255.255", "255.255.255.255", 32); + + assert_cpl6("::1", "::1", 128); + assert_cpl6("abcd::1", "abcd::1", 128); + assert_cpl6("abcd::abcd", "abcd::", 112); + assert_cpl6("abcd::abcd", "abcd::abcd:abcd", 96); +end diff --git a/tests/test_util_jid.lua b/tests/test_util_jid.lua index a817e644..02a90c3b 100644 --- a/tests/test_util_jid.lua +++ b/tests/test_util_jid.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- diff --git a/tests/test_util_multitable.lua b/tests/test_util_multitable.lua index ed10b128..71a83450 100644 --- a/tests/test_util_multitable.lua +++ b/tests/test_util_multitable.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -41,7 +41,7 @@ function get(get, multitable) end mt = multitable.new(); - + local trigger1, trigger2, trigger3 = {}, {}, {}; local item1, item2, item3 = {}, {}, {}; @@ -51,12 +51,12 @@ function get(get, multitable) mt:add(1, 2, 3, item1); assert_has_all("Has item1 for 1, 2, 3", mt:get(1, 2, 3), item1); - + -- Doesn't support nil --[[ mt:add(nil, item1); mt:add(nil, item2); mt:add(nil, item3); - + assert_has_all("Has all items with (nil)", mt:get(nil), item1, item2, item3); ]] end diff --git a/tests/test_util_rfc6724.lua b/tests/test_util_rfc6724.lua new file mode 100644 index 00000000..bb73e921 --- /dev/null +++ b/tests/test_util_rfc6724.lua @@ -0,0 +1,97 @@ +-- Prosody IM +-- Copyright (C) 2011-2013 Florian Zeitz +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +function source(source) + local new_ip = require"util.ip".new_ip; + assert_equal(source(new_ip("2001:db8:1::1", "IPv6"), + {new_ip("2001:db8:3::1", "IPv6"), new_ip("fe80::1", "IPv6")}).addr, + "2001:db8:3::1", + "prefer appropriate scope"); + assert_equal(source(new_ip("ff05::1", "IPv6"), + {new_ip("2001:db8:3::1", "IPv6"), new_ip("fe80::1", "IPv6")}).addr, + "2001:db8:3::1", + "prefer appropriate scope"); + assert_equal(source(new_ip("2001:db8:1::1", "IPv6"), + {new_ip("2001:db8:1::1", "IPv6"), new_ip("2001:db8:2::1", "IPv6")}).addr, + "2001:db8:1::1", + "prefer same address"); -- "2001:db8:1::1" should be marked "deprecated" here, we don't handle that right now + assert_equal(source(new_ip("fe80::1", "IPv6"), + {new_ip("fe80::2", "IPv6"), new_ip("2001:db8:1::1", "IPv6")}).addr, + "fe80::2", + "prefer appropriate scope"); -- "fe80::2" should be marked "deprecated" here, we don't handle that right now + assert_equal(source(new_ip("2001:db8:1::1", "IPv6"), + {new_ip("2001:db8:1::2", "IPv6"), new_ip("2001:db8:3::2", "IPv6")}).addr, + "2001:db8:1::2", + "longest matching prefix"); +--[[ "2001:db8:1::2" should be a care-of address and "2001:db8:3::2" a home address, we can't handle this and would fail + assert_equal(source(new_ip("2001:db8:1::1", "IPv6"), + {new_ip("2001:db8:1::2", "IPv6"), new_ip("2001:db8:3::2", "IPv6")}).addr, + "2001:db8:3::2", + "prefer home address"); +]] + assert_equal(source(new_ip("2002:c633:6401::1", "IPv6"), + {new_ip("2002:c633:6401::d5e3:7953:13eb:22e8", "IPv6"), new_ip("2001:db8:1::2", "IPv6")}).addr, + "2002:c633:6401::d5e3:7953:13eb:22e8", + "prefer matching label"); -- "2002:c633:6401::d5e3:7953:13eb:22e8" should be marked "temporary" here, we don't handle that right now + assert_equal(source(new_ip("2001:db8:1::d5e3:0:0:1", "IPv6"), + {new_ip("2001:db8:1::2", "IPv6"), new_ip("2001:db8:1::d5e3:7953:13eb:22e8", "IPv6")}).addr, + "2001:db8:1::d5e3:7953:13eb:22e8", + "prefer temporary address") -- "2001:db8:1::2" should be marked "public" and "2001:db8:1::d5e3:7953:13eb:22e8" should be marked "temporary" here, we don't handle that right now +end + +function destination(dest) + local order; + local new_ip = require"util.ip".new_ip; + order = dest({new_ip("2001:db8:1::1", "IPv6"), new_ip("198.51.100.121", "IPv4")}, + {new_ip("2001:db8:1::2", "IPv6"), new_ip("fe80::1", "IPv6"), new_ip("169.254.13.78", "IPv4")}) + assert_equal(order[1].addr, "2001:db8:1::1", "prefer matching scope"); + assert_equal(order[2].addr, "198.51.100.121", "prefer matching scope"); + + order = dest({new_ip("2001:db8:1::1", "IPv6"), new_ip("198.51.100.121", "IPv4")}, + {new_ip("fe80::1", "IPv6"), new_ip("198.51.100.117", "IPv4")}) + assert_equal(order[1].addr, "198.51.100.121", "prefer matching scope"); + assert_equal(order[2].addr, "2001:db8:1::1", "prefer matching scope"); + + order = dest({new_ip("2001:db8:1::1", "IPv6"), new_ip("10.1.2.3", "IPv4")}, + {new_ip("2001:db8:1::2", "IPv6"), new_ip("fe80::1", "IPv6"), new_ip("10.1.2.4", "IPv4")}) + assert_equal(order[1].addr, "2001:db8:1::1", "prefer higher precedence"); + assert_equal(order[2].addr, "10.1.2.3", "prefer higher precedence"); + + order = dest({new_ip("2001:db8:1::1", "IPv6"), new_ip("fe80::1", "IPv6")}, + {new_ip("2001:db8:1::2", "IPv6"), new_ip("fe80::2", "IPv6")}) + assert_equal(order[1].addr, "fe80::1", "prefer smaller scope"); + assert_equal(order[2].addr, "2001:db8:1::1", "prefer smaller scope"); + +--[[ "2001:db8:1::2" and "fe80::2" should be marked "care-of address", while "2001:db8:3::1" should be marked "home address", we can't currently handle this and would fail the test + order = dest({new_ip("2001:db8:1::1", "IPv6"), new_ip("fe80::1", "IPv6")}, + {new_ip("2001:db8:1::2", "IPv6"), new_ip("2001:db8:3::1", "IPv6"), new_ip("fe80::2", "IPv6")}) + assert_equal(order[1].addr, "2001:db8:1::1", "prefer home address"); + assert_equal(order[2].addr, "fe80::1", "prefer home address"); +]] + +--[[ "fe80::2" should be marked "deprecated", we can't currently handle this and would fail the test + order = dest({new_ip("2001:db8:1::1", "IPv6"), new_ip("fe80::1", "IPv6")}, + {new_ip("2001:db8:1::2", "IPv6"), new_ip("fe80::2", "IPv6")}) + assert_equal(order[1].addr, "2001:db8:1::1", "avoid deprecated addresses"); + assert_equal(order[2].addr, "fe80::1", "avoid deprecated addresses"); +]] + + order = dest({new_ip("2001:db8:1::1", "IPv6"), new_ip("2001:db8:3ffe::1", "IPv6")}, + {new_ip("2001:db8:1::2", "IPv6"), new_ip("2001:db8:3f44::2", "IPv6"), new_ip("fe80::2", "IPv6")}) + assert_equal(order[1].addr, "2001:db8:1::1", "longest matching prefix"); + assert_equal(order[2].addr, "2001:db8:3ffe::1", "longest matching prefix"); + + order = dest({new_ip("2002:c633:6401::1", "IPv6"), new_ip("2001:db8:1::1", "IPv6")}, + {new_ip("2002:c633:6401::2", "IPv6"), new_ip("fe80::2", "IPv6")}) + assert_equal(order[1].addr, "2002:c633:6401::1", "prefer matching label"); + assert_equal(order[2].addr, "2001:db8:1::1", "prefer matching label"); + + order = dest({new_ip("2002:c633:6401::1", "IPv6"), new_ip("2001:db8:1::1", "IPv6")}, + {new_ip("2002:c633:6401::2", "IPv6"), new_ip("2001:db8:1::2", "IPv6"), new_ip("fe80::2", "IPv6")}) + assert_equal(order[1].addr, "2001:db8:1::1", "prefer higher precedence"); + assert_equal(order[2].addr, "2002:c633:6401::1", "prefer higher precedence"); +end diff --git a/tests/test_util_sasl_scram.lua b/tests/test_util_sasl_scram.lua index aeae8748..bc89829f 100644 --- a/tests/test_util_sasl_scram.lua +++ b/tests/test_util_sasl_scram.lua @@ -1,6 +1,6 @@ -local hmac_sha1 = require "util.hmac".sha1; +local hmac_sha1 = require "util.hashes".hmac_sha1; local function toHex(s) return s and (s:gsub(".", function (c) return ("%02x"):format(c:byte()); end)); end diff --git a/tests/test_util_stanza.lua b/tests/test_util_stanza.lua index fce26f3a..20cc8274 100644 --- a/tests/test_util_stanza.lua +++ b/tests/test_util_stanza.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -18,7 +18,7 @@ end function deserialize(deserialize, st) local stanza = st.stanza("message", { a = "a" }); - + local stanza2 = deserialize(st.preserialize(stanza)); assert_is(stanza2 and stanza.name, "deserialize returns a stanza"); assert_table(stanza2.attr, "Deserialized stanza has attributes"); diff --git a/tests/util/logger.lua b/tests/util/logger.lua index 35facd4e..c133e332 100644 --- a/tests/util/logger.lua +++ b/tests/util/logger.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- diff --git a/tools/ejabberd2prosody.lua b/tools/ejabberd2prosody.lua index a231abd8..0a6736d7 100755 --- a/tools/ejabberd2prosody.lua +++ b/tools/ejabberd2prosody.lua @@ -2,7 +2,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -11,30 +11,40 @@ package.path = package.path ..";../?.lua"; -if arg[0]:match("^./") then - package.path = package.path .. ";"..arg[0]:gsub("/ejabberd2prosody.lua$", "/?.lua"); +local my_name = arg[0]; +if my_name:match("[/\\]") then + package.path = package.path..";"..my_name:gsub("[^/\\]+$", "../?.lua"); + package.path = package.path..";"..my_name:gsub("[^/\\]+$", "?.lua"); + package.cpath = package.cpath..";"..my_name:gsub("[^/\\]+$", "../?.so"); end -require "erlparse"; +local erlparse = require "erlparse"; prosody = {}; +package.loaded["util.logger"] = {init = function() return function() end; end} local serialize = require "util.serialization".serialize; local st = require "util.stanza"; -package.loaded["util.logger"] = {init = function() return function() end; end} local dm = require "util.datamanager" dm.set_data_path("data"); function build_stanza(tuple, stanza) + assert(type(tuple) == "table", "XML node is of unexpected type: "..type(tuple)); if tuple[1] == "xmlelement" then + assert(type(tuple[2]) == "string", "element name has type: "..type(tuple[2])); + assert(type(tuple[3]) == "table", "element attribute array has type: "..type(tuple[3])); + assert(type(tuple[4]) == "table", "element children array has type: "..type(tuple[4])); local name = tuple[2]; local attr = {}; - for _, a in ipairs(tuple[3]) do attr[a[1]] = a[2]; end + for _, a in ipairs(tuple[3]) do + if type(a[1]) == "string" and type(a[2]) == "string" then attr[a[1]] = a[2]; end + end local up; if stanza then stanza:tag(name, attr); up = true; else stanza = st.stanza(name, attr); end for _, a in ipairs(tuple[4]) do build_stanza(a, stanza); end if up then stanza:up(); else return stanza end elseif tuple[1] == "xmlcdata" then + assert(type(tuple[2]) == "string", "XML CDATA has unexpected type: "..type(tuple[2])); stanza:text(tuple[2]); else error("unknown element type: "..serialize(tuple)); @@ -78,6 +88,70 @@ function offline_msg(node, host, t, stanza) local ret, err = dm.list_append(node, host, "offline", st.preserialize(stanza)); print("["..(err or "success").."] offline: " ..node.."@"..host.." - "..os.date("!%Y-%m-%dT%H:%M:%SZ", t)); end +function privacy(node, host, default, lists) + local privacy = { lists = {} }; + local count = 0; + if default then privacy.default = default; end + for _, inlist in ipairs(lists) do + local name, items = inlist[1], inlist[2]; + local list = { name = name; items = {}; }; + local orders = {}; + for _, item in pairs(items) do + repeat + if item[1] ~= "listitem" then print("[error] privacy: unhandled item: "..tostring(item[1])); break; end + local _type, value = item[2], item[3]; + if _type == "jid" then + if type(value) ~= "table" then print("[error] privacy: jid value is not valid: "..tostring(value)); break; end + local _node, _host, _resource = value[1], value[2], value[3]; + if (type(_node) == "table") then _node = nil; end + if (type(_host) == "table") then _host = nil; end + if (type(_resource) == "table") then _resource = nil; end + value = (_node and _node.."@".._host or _host)..(_resource and "/".._resource or ""); + elseif _type == "none" then + _type = nil; + value = nil; + elseif _type == "group" then + if type(value) ~= "string" then print("[error] privacy: group value is not string: "..tostring(value)); break; end + elseif _type == "subscription" then + if value~="both" and value~="from" and value~="to" and value~="none" then + print("[error] privacy: subscription value is invalid: "..tostring(value)); break; + end + else print("[error] privacy: invalid item type: "..tostring(_type)); break; end + local action = item[4]; + if action ~= "allow" and action ~= "deny" then print("[error] privacy: unhandled action: "..tostring(action)); break; end + local order = item[5]; + if type(order) ~= "number" or order<0 then print("[error] privacy: order is not numeric: "..tostring(order)); break; end + if orders[order] then print("[error] privacy: duplicate order value: "..tostring(order)); break; end + orders[order] = true; + local match_all = item[6]; + local match_iq = item[7]; + local match_message = item[8]; + local match_presence_in = item[9]; + local match_presence_out = item[10]; + list.items[#list.items+1] = { + type = _type; + value = value; + action = action; + order = order; + message = match_message == "true"; + iq = match_iq == "true"; + ["presence-in"] = match_presence_in == "true"; + ["presence-out"] = match_presence_out == "true"; + }; + until true; + end + table.sort(list.items, function(a, b) return a.order < b.order; end); + if privacy.lists[list.name] then print("[warn] duplicate privacy list: "..tostring(list.name)); end + privacy.lists[list.name] = list; + count = count + 1; + end + if default and not privacy.lists[default] then + if default == "none" then privacy.default = nil; + else print("[warn] default privacy list doesn't exist: "..tostring(default)); end + end + local ret, err = dm.store(node, host, "privacy", privacy); + print("["..(err or "success").."] privacy: " ..node.."@"..host.." - "..count.." list(s)"); +end local filters = { @@ -119,6 +193,9 @@ local filters = { offline_msg = function(tuple) offline_msg(tuple[2][1], tuple[2][2], build_time(tuple[3]), build_stanza(tuple[7])); end; + privacy = function(tuple) + privacy(tuple[2][1], tuple[2][2], tuple[3], tuple[4]); + end; config = function(tuple) if tuple[2] == "hosts" then local output = io.output(); io.output("prosody.cfg.lua"); @@ -155,10 +232,10 @@ local help = "/? -? ? /h -h /help -help --help"; if not arg or help:find(arg, 1, true) then print([[ejabberd db dump importer for Prosody - Usage: ejabberd2prosody.lua filename.txt + Usage: ]]..my_name..[[ filename.txt The file can be generated from ejabberd using: - sudo ./bin/ejabberdctl dump filename.txt + sudo ejabberdctl dump filename.txt Note: The path of ejabberdctl depends on your ejabberd installation, and ejabberd needs to be running for ejabberdctl to work.]]); os.exit(1); diff --git a/tools/ejabberdsql2prosody.lua b/tools/ejabberdsql2prosody.lua index 958cf0e2..930d5a6b 100644 --- a/tools/ejabberdsql2prosody.lua +++ b/tools/ejabberdsql2prosody.lua @@ -2,7 +2,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -10,8 +10,17 @@ prosody = {}; package.path = package.path ..";../?.lua"; + +local my_name = arg[0]; +if my_name:match("[/\\]") then + package.path = package.path..";"..my_name:gsub("[^/\\]+$", "../?.lua"); + package.cpath = package.cpath..";"..my_name:gsub("[^/\\]+$", "../?.so"); +end + + local serialize = require "util.serialization".serialize; local st = require "util.stanza"; +local parse_xml = require "util.xml".parse; package.loaded["util.logger"] = {init = function() return function() end; end} local dm = require "util.datamanager" dm.set_data_path("data"); @@ -21,12 +30,16 @@ function parseFile(filename) local file = nil; local last = nil; +local line = 1; local function read(expected) local ch; if last then ch = last; last = nil; - else ch = file:read(1); end - if expected and ch ~= expected then error("expected: "..expected.."; got: "..(ch or "nil")); end + else + ch = file:read(1); + if ch == "\n" then line = line + 1; end + end + if expected and ch ~= expected then error("expected: "..expected.."; got: "..(ch or "nil").." on line "..line); end return ch; end local function pushback(ch) @@ -125,7 +138,12 @@ local function readInsert() end end local tname = readTableName(); - for ch in ("` VALUES "):gmatch(".") do read(ch); end -- expect this + read("`"); read(" ") -- expect this + if peek() == "(" then -- skip column list + repeat until read() == ")"; + read(" "); + end + for ch in ("VALUES "):gmatch(".") do read(ch); end -- expect this local tuples = readTuples(); read(";"); read("\n"); return tname, tuples; @@ -158,58 +176,6 @@ return readFile(filename); ------ end --- XML parser -local parse_xml = (function() - local entity_map = setmetatable({ - ["amp"] = "&"; - ["gt"] = ">"; - ["lt"] = "<"; - ["apos"] = "'"; - ["quot"] = "\""; - }, {__index = function(_, s) - if s:sub(1,1) == "#" then - if s:sub(2,2) == "x" then - return string.char(tonumber(s:sub(3), 16)); - else - return string.char(tonumber(s:sub(2))); - end - end - end - }); - local function xml_unescape(str) - return (str:gsub("&(.-);", entity_map)); - end - local function parse_tag(s) - local name,sattr=(s):gmatch("([^%s]+)(.*)")(); - local attr = {}; - for a,b in (sattr):gmatch("([^=%s]+)=['\"]([^'\"]*)['\"]") do attr[a] = xml_unescape(b); end - return name, attr; - end - return function(xml) - local stanza = st.stanza("root"); - local regexp = "<([^>]*)>([^<]*)"; - for elem, text in xml:gmatch(regexp) do - if elem:sub(1,1) == "!" or elem:sub(1,1) == "?" then -- neglect comments and processing-instructions - elseif elem:sub(1,1) == "/" then -- end tag - elem = elem:sub(2); - stanza:up(); -- TODO check for start-end tag name match - elseif elem:sub(-1,-1) == "/" then -- empty tag - elem = elem:sub(1,-2); - local name,attr = parse_tag(elem); - stanza:tag(name, attr):up(); - else -- start tag - local name,attr = parse_tag(elem); - stanza:tag(name, attr); - end - if #text ~= 0 then -- text - stanza:text(xml_unescape(text)); - end - end - return stanza.tags[1]; - end -end)(); --- end of XML parser - local arg, host = ...; local help = "/? -? ? /h -h /help -help --help"; if not(arg and host) or help:find(arg, 1, true) then diff --git a/tools/erlparse.lua b/tools/erlparse.lua index dc3a2f94..25c38bcf 100644 --- a/tools/erlparse.lua +++ b/tools/erlparse.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -51,7 +51,7 @@ local function isSpace(ch) return ch <= _space; end -local escapes = {["\\b"]="\b", ["\\d"]="\d", ["\\e"]="\e", ["\\f"]="\f", ["\\n"]="\n", ["\\r"]="\r", ["\\s"]="\s", ["\\t"]="\t", ["\\v"]="\v", ["\\\""]="\"", ["\\'"]="'", ["\\\\"]="\\"}; +local escapes = {["\\b"]="\b", ["\\d"]="\127", ["\\e"]="\27", ["\\f"]="\f", ["\\n"]="\n", ["\\r"]="\r", ["\\s"]=" ", ["\\t"]="\t", ["\\v"]="\v", ["\\\""]="\"", ["\\'"]="'", ["\\\\"]="\\"}; local function readString() read("\""); -- skip quote local slash = nil; @@ -95,6 +95,12 @@ local function readNumber() while isNumeric(peek()) do num[#num+1] = read(); end + if peek() == "." then + num[#num+1] = read(); + while isNumeric(peek()) do + num[#num+1] = read(); + end + end return tonumber(t_concat(num)); end local readItem = nil; diff --git a/tools/jabberd14sql2prosody.lua b/tools/jabberd14sql2prosody.lua new file mode 100644 index 00000000..03376b30 --- /dev/null +++ b/tools/jabberd14sql2prosody.lua @@ -0,0 +1,654 @@ +#!/usr/bin/env lua + + +do + + +local _parse_sql_actions = { [0] = + 0, 1, 0, 1, 1, 2, 0, 2, 2, 0, 9, 2, 0, 10, 2, 0, 11, 2, 0, 13, + 2, 1, 2, 2, 1, 6, 3, 0, 3, 4, 3, 0, 3, 5, 3, 0, 3, 7, 3, 0, + 3, 8, 3, 0, 3, 12, 4, 0, 2, 3, 7, 4, 0, 3, 8, 11 +}; + +local _parse_sql_trans_keys = { [0] = + 0, 0, 45, 45, 10, 10, 42, 42, 10, 42, 10, 47, 82, 82, + 69, 69, 65, 65, 84, 84, 69, 69, 32, 32, 68, 84, 65, + 65, 84, 84, 65, 65, 66, 66, 65, 65, 83, 83, 69, 69, + 9, 47, 9, 96, 45, 45, 10, 10, 42, 42, 10, 42, 10, 47, + 10, 96, 10, 96, 9, 47, 9, 59, 45, 45, 10, 10, 42, + 42, 10, 42, 10, 47, 65, 65, 66, 66, 76, 76, 69, 69, + 32, 32, 73, 96, 70, 70, 32, 32, 78, 78, 79, 79, 84, 84, + 32, 32, 69, 69, 88, 88, 73, 73, 83, 83, 84, 84, 83, + 83, 32, 32, 96, 96, 10, 96, 10, 96, 32, 32, 40, 40, + 10, 10, 32, 41, 32, 32, 75, 96, 69, 69, 89, 89, 32, 32, + 96, 96, 10, 96, 10, 96, 10, 10, 82, 82, 73, 73, 77, + 77, 65, 65, 82, 82, 89, 89, 32, 32, 75, 75, 69, 69, + 89, 89, 32, 32, 78, 78, 73, 73, 81, 81, 85, 85, 69, 69, + 32, 32, 75, 75, 10, 96, 10, 96, 10, 10, 10, 59, 10, + 59, 82, 82, 79, 79, 80, 80, 32, 32, 84, 84, 65, 65, + 66, 66, 76, 76, 69, 69, 32, 32, 73, 73, 70, 70, 32, 32, + 69, 69, 88, 88, 73, 73, 83, 83, 84, 84, 83, 83, 32, + 32, 96, 96, 10, 96, 10, 96, 59, 59, 78, 78, 83, 83, + 69, 69, 82, 82, 84, 84, 32, 32, 73, 73, 78, 78, 84, 84, + 79, 79, 32, 32, 96, 96, 10, 96, 10, 96, 32, 32, 40, + 86, 10, 41, 32, 32, 86, 86, 65, 65, 76, 76, 85, 85, + 69, 69, 83, 83, 32, 32, 40, 40, 39, 78, 10, 92, 10, 92, + 41, 44, 44, 59, 32, 78, 48, 57, 41, 57, 48, 57, 41, + 57, 85, 85, 76, 76, 76, 76, 34, 116, 79, 79, 67, 67, + 75, 75, 32, 32, 84, 84, 65, 65, 66, 66, 76, 76, 69, 69, + 83, 83, 32, 32, 96, 96, 10, 96, 10, 96, 32, 32, 87, + 87, 82, 82, 73, 73, 84, 84, 69, 69, 69, 69, 84, 84, + 32, 32, 10, 59, 10, 59, 78, 83, 76, 76, 79, 79, 67, 67, + 75, 75, 32, 32, 84, 84, 65, 65, 66, 66, 76, 76, 69, + 69, 83, 83, 69, 69, 9, 85, 0 +}; + +local _parse_sql_key_spans = { [0] = + 0, 1, 1, 1, 33, 38, 1, 1, 1, 1, 1, 1, 17, 1, 1, 1, 1, 1, 1, 1, + 39, 88, 1, 1, 1, 33, 38, 87, 87, 39, 51, 1, 1, 1, 33, 38, 1, 1, 1, 1, + 1, 24, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 87, 87, 1, 1, + 1, 10, 1, 22, 1, 1, 1, 1, 87, 87, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 87, 87, 1, 50, 50, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 87, 87, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 87, 87, 1, 47, 32, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 40, 83, 83, 4, 16, 47, 10, 17, 10, 17, 1, 1, 1, 83, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 87, 87, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 50, 50, 6, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 77 +}; + +local _parse_sql_index_offsets = { [0] = + 0, 0, 2, 4, 6, 40, 79, 81, 83, 85, 87, 89, 91, 109, 111, 113, 115, 117, 119, 121, + 123, 163, 252, 254, 256, 258, 292, 331, 419, 507, 547, 599, 601, 603, 605, 639, 678, 680, 682, 684, + 686, 688, 713, 715, 717, 719, 721, 723, 725, 727, 729, 731, 733, 735, 737, 739, 741, 829, 917, 919, + 921, 923, 934, 936, 959, 961, 963, 965, 967, 1055, 1143, 1145, 1147, 1149, 1151, 1153, 1155, 1157, 1159, 1161, + 1163, 1165, 1167, 1169, 1171, 1173, 1175, 1177, 1179, 1181, 1269, 1357, 1359, 1410, 1461, 1463, 1465, 1467, 1469, 1471, + 1473, 1475, 1477, 1479, 1481, 1483, 1485, 1487, 1489, 1491, 1493, 1495, 1497, 1499, 1501, 1503, 1591, 1679, 1681, 1683, + 1685, 1687, 1689, 1691, 1693, 1695, 1697, 1699, 1701, 1703, 1705, 1793, 1881, 1883, 1931, 1964, 1966, 1968, 1970, 1972, + 1974, 1976, 1978, 1980, 1982, 2023, 2107, 2191, 2196, 2213, 2261, 2272, 2290, 2301, 2319, 2321, 2323, 2325, 2409, 2411, + 2413, 2415, 2417, 2419, 2421, 2423, 2425, 2427, 2429, 2431, 2433, 2521, 2609, 2611, 2613, 2615, 2617, 2619, 2621, 2623, + 2625, 2627, 2678, 2729, 2736, 2738, 2740, 2742, 2744, 2746, 2748, 2750, 2752, 2754, 2756, 2758, 2760 +}; + +local _parse_sql_indicies = { [0] = + 0, 1, 2, 0, 3, 1, 4, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 5, 3, + 4, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 5, 3, 3, 3, 3, 6, 3, 7, + 1, 8, 1, 9, 1, 10, 1, 11, 1, 12, 1, 13, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 14, 1, 15, 1, 16, 1, 17, 1, 18, 1, 19, 1, 20, + 1, 21, 1, 22, 23, 22, 22, 22, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 22, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 24, + 1, 25, 1, 22, 23, 22, 22, 22, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 22, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 24, + 1, 25, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 26, 1, 27, 1, 23, 27, 28, 1, 29, 28, + 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, + 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 30, 28, 29, 28, 28, 28, 28, 28, 28, 28, + 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, + 28, 28, 28, 28, 30, 28, 28, 28, 28, 22, 28, 32, 31, 31, 31, 31, 31, 31, 31, 31, + 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, + 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, + 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, + 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 1, 31, 32, + 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, + 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, + 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, + 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, + 31, 31, 31, 31, 31, 33, 31, 34, 35, 34, 34, 34, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 34, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 36, 1, 37, 1, 34, 35, 34, 34, 34, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 34, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 36, 1, 37, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 6, 1, 38, + 1, 35, 38, 39, 1, 40, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, + 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 41, 39, 40, + 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, + 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 41, 39, 39, 39, 39, 34, 39, 42, 1, + 43, 1, 44, 1, 45, 1, 46, 1, 47, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 48, 1, 49, 1, 50, 1, 51, 1, 52, + 1, 53, 1, 54, 1, 55, 1, 56, 1, 57, 1, 58, 1, 59, 1, 60, 1, 61, 1, 48, + 1, 63, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, + 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, + 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, + 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, 62, + 62, 62, 62, 62, 62, 62, 62, 1, 62, 65, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 66, 64, 67, 1, 68, + 1, 69, 1, 70, 1, 1, 1, 1, 1, 1, 1, 1, 71, 1, 72, 1, 73, 1, 1, 1, + 1, 74, 1, 1, 1, 1, 75, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 76, 1, 77, + 1, 78, 1, 79, 1, 80, 1, 82, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, + 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, + 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, + 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, + 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 1, 81, 82, 81, 81, 81, 81, + 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, + 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, + 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, + 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, + 81, 83, 81, 69, 83, 84, 1, 85, 1, 86, 1, 87, 1, 88, 1, 89, 1, 90, 1, 91, + 1, 92, 1, 93, 1, 83, 1, 94, 1, 95, 1, 96, 1, 97, 1, 98, 1, 99, 1, 73, + 1, 101, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 1, 100, 103, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, + 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, + 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, + 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, + 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 102, 104, 102, 105, 83, 106, + 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, + 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, + 71, 71, 71, 71, 71, 71, 71, 71, 107, 71, 108, 71, 71, 71, 71, 71, 71, 71, 71, 71, + 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, + 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 71, 107, + 71, 109, 1, 110, 1, 111, 1, 112, 1, 113, 1, 114, 1, 115, 1, 116, 1, 117, 1, 118, + 1, 119, 1, 120, 1, 121, 1, 122, 1, 123, 1, 124, 1, 125, 1, 126, 1, 127, 1, 128, + 1, 129, 1, 131, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, + 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, + 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, + 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, + 130, 130, 130, 130, 130, 130, 130, 130, 130, 1, 130, 131, 130, 130, 130, 130, 130, 130, 130, 130, + 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, + 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, + 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, + 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 130, 132, 130, 6, + 1, 133, 1, 134, 1, 135, 1, 136, 1, 137, 1, 138, 1, 139, 1, 140, 1, 141, 1, 142, + 1, 143, 1, 144, 1, 146, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, + 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, + 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, + 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, + 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 145, 1, 145, 148, 147, 147, 147, 147, 147, 147, + 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, + 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, + 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, + 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 149, + 147, 150, 1, 151, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 152, 1, 153, 151, 151, 151, 151, 151, 151, 151, 151, + 151, 151, 151, 151, 151, 151, 151, 151, 151, 151, 151, 151, 151, 151, 151, 151, 151, 151, 151, 151, + 151, 151, 154, 151, 155, 1, 152, 1, 156, 1, 157, 1, 158, 1, 159, 1, 160, 1, 161, 1, + 162, 1, 163, 1, 1, 1, 1, 1, 164, 1, 1, 165, 165, 165, 165, 165, 165, 165, 165, 165, + 165, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 166, 1, 168, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, + 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 169, 167, 167, 167, 167, 167, 167, 167, + 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, + 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, 167, + 167, 167, 167, 167, 167, 170, 167, 172, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, + 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 173, 171, 171, 171, + 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, + 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, 171, + 171, 171, 171, 171, 171, 171, 171, 171, 171, 174, 171, 175, 1, 1, 176, 1, 161, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 177, 1, 178, 1, 1, 1, 1, 1, 1, + 163, 1, 1, 1, 1, 1, 164, 1, 1, 165, 165, 165, 165, 165, 165, 165, 165, 165, 165, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 166, + 1, 179, 179, 179, 179, 179, 179, 179, 179, 179, 179, 1, 180, 1, 1, 181, 1, 182, 1, 179, + 179, 179, 179, 179, 179, 179, 179, 179, 179, 1, 183, 183, 183, 183, 183, 183, 183, 183, 183, 183, + 1, 180, 1, 1, 181, 1, 1, 1, 183, 183, 183, 183, 183, 183, 183, 183, 183, 183, 1, 184, + 1, 185, 1, 186, 1, 171, 1, 1, 171, 1, 171, 1, 1, 1, 1, 1, 1, 1, 1, 171, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 171, 1, 171, 1, 1, 171, 1, 1, 171, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 171, 1, 1, 1, 171, 1, 171, 1, 187, 1, 188, 1, 189, 1, 190, 1, 191, 1, 192, + 1, 193, 1, 194, 1, 195, 1, 196, 1, 197, 1, 198, 1, 200, 199, 199, 199, 199, 199, 199, + 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, + 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, + 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, + 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 1, + 199, 200, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, + 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, + 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, + 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, 199, + 199, 199, 199, 199, 199, 199, 199, 201, 199, 202, 1, 203, 1, 204, 1, 205, 1, 206, 1, 132, + 1, 207, 1, 208, 1, 209, 1, 210, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 211, 209, 2, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, + 209, 209, 209, 209, 209, 209, 209, 211, 209, 212, 1, 1, 1, 1, 213, 1, 214, 1, 215, 1, + 216, 1, 217, 1, 218, 1, 219, 1, 220, 1, 221, 1, 222, 1, 223, 1, 132, 1, 127, 1, + 6, 2, 6, 6, 6, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 6, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 224, 1, 225, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 6, 1, 1, 1, 1, 1, 1, 1, 226, 227, + 1, 1, 1, 1, 228, 1, 1, 229, 1, 1, 1, 1, 1, 1, 230, 1, 231, 1, 0 +}; + +local _parse_sql_trans_targs = { [0] = + 2, 0, 196, 4, 4, 5, 196, 7, 8, 9, 10, 11, 12, 13, 36, 14, 15, 16, 17, 18, + 19, 20, 21, 21, 22, 24, 27, 23, 25, 25, 26, 28, 28, 29, 30, 30, 31, 33, 32, 34, + 34, 35, 37, 38, 39, 40, 41, 42, 56, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, + 54, 55, 57, 57, 57, 57, 58, 59, 60, 61, 62, 92, 63, 64, 71, 82, 89, 65, 66, 67, + 68, 69, 69, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 83, 84, 85, 86, 87, 88, + 90, 90, 90, 90, 91, 70, 92, 93, 196, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 116, 117, 119, 120, 121, 122, 123, 124, 125, + 126, 127, 128, 129, 130, 131, 131, 131, 131, 132, 133, 134, 137, 134, 135, 136, 138, 139, 140, 141, + 142, 143, 144, 145, 150, 151, 154, 146, 146, 147, 157, 146, 146, 147, 157, 148, 149, 196, 144, 151, + 148, 149, 152, 153, 155, 156, 147, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, + 171, 172, 173, 174, 175, 176, 177, 179, 180, 181, 181, 182, 184, 195, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 1, 3, 6, 94, 118, 158, 178, 183 +}; + +local _parse_sql_trans_actions = { [0] = + 1, 0, 3, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 3, 1, 1, 1, 1, 1, 3, 1, 1, 3, 1, 1, 3, 1, 1, 1, 1, + 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 5, 20, 1, 3, 30, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 5, 20, 1, 3, 26, 3, 3, 1, 23, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 5, 20, 1, 3, 42, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, + 1, 1, 11, 1, 5, 5, 1, 5, 20, 46, 5, 1, 3, 34, 1, 14, 1, 17, 1, 1, + 51, 38, 1, 1, 1, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 +}; + +local parse_sql_start = 196; +local parse_sql_first_final = 196; +local parse_sql_error = 0; + +local parse_sql_en_main = 196; + + + + +local _sql_unescapes = setmetatable({ + ["\\0"] = "\0"; + ["\\'"] = "'"; + ["\\\""] = "\""; + ["\\b"] = "\b"; + ["\\n"] = "\n"; + ["\\r"] = "\r"; + ["\\t"] = "\t"; + ["\\Z"] = "\26"; + ["\\\\"] = "\\"; + ["\\%"] = "%"; + ["\\_"] = "_"; +},{ __index = function(t, s) assert(false, "Unknown escape sequences: "..s); end }); + +function parse_sql(data, h) + local p = 1; + local pe = #data + 1; + local cs; + + local pos_char, pos_line = 1, 1; + + local mark, token; + local table_name, columns, value_lists, value_list, value_count; + + + cs = parse_sql_start; + +-- ragel flat exec + + local testEof = false; + local _slen = 0; + local _trans = 0; + local _keys = 0; + local _inds = 0; + local _acts = 0; + local _nacts = 0; + local _tempval = 0; + local _goto_level = 0; + local _resume = 10; + local _eof_trans = 15; + local _again = 20; + local _test_eof = 30; + local _out = 40; + + while true do -- goto loop + local _continue = false; + repeat + local _trigger_goto = false; + if _goto_level <= 0 then + +-- noEnd + if p == pe then + _goto_level = _test_eof; + _continue = true; break; + end + + +-- errState != 0 + if cs == 0 then + _goto_level = _out; + _continue = true; break; + end + end -- _goto_level <= 0 + + if _goto_level <= _resume then + _keys = cs * 2; -- LOCATE_TRANS + _inds = _parse_sql_index_offsets[cs]; + _slen = _parse_sql_key_spans[cs]; + + if _slen > 0 and + _parse_sql_trans_keys[_keys] <= data:byte(p) and + data:byte(p) <= _parse_sql_trans_keys[_keys + 1] then + _trans = _parse_sql_indicies[ _inds + data:byte(p) - _parse_sql_trans_keys[_keys] ]; + else _trans =_parse_sql_indicies[ _inds + _slen ]; end + + cs = _parse_sql_trans_targs[_trans]; + + if _parse_sql_trans_actions[_trans] ~= 0 then + _acts = _parse_sql_trans_actions[_trans]; + _nacts = _parse_sql_actions[_acts]; + _acts = _acts + 1; + + while _nacts > 0 do + _nacts = _nacts - 1; + _acts = _acts + 1; + _tempval = _parse_sql_actions[_acts - 1]; + + -- start action switch + if _tempval == 0 then --4 FROM_STATE_ACTION_SWITCH +-- line 34 "sql.rl" -- end of line directive + pos_char = pos_char + 1; -- ACTION + elseif _tempval == 1 then --4 FROM_STATE_ACTION_SWITCH +-- line 35 "sql.rl" -- end of line directive + pos_line = pos_line + 1; pos_char = 1; -- ACTION + elseif _tempval == 2 then --4 FROM_STATE_ACTION_SWITCH +-- line 38 "sql.rl" -- end of line directive + mark = p; -- ACTION + elseif _tempval == 3 then --4 FROM_STATE_ACTION_SWITCH +-- line 39 "sql.rl" -- end of line directive + token = data:sub(mark, p-1); -- ACTION + elseif _tempval == 4 then --4 FROM_STATE_ACTION_SWITCH +-- line 52 "sql.rl" -- end of line directive + table.insert(columns, token); columns[#columns] = token; -- ACTION + elseif _tempval == 5 then --4 FROM_STATE_ACTION_SWITCH +-- line 58 "sql.rl" -- end of line directive + table_name,columns = token,{}; -- ACTION + elseif _tempval == 6 then --4 FROM_STATE_ACTION_SWITCH +-- line 59 "sql.rl" -- end of line directive + h.create(table_name, columns); -- ACTION + elseif _tempval == 7 then --4 FROM_STATE_ACTION_SWITCH +-- line 65 "sql.rl" -- end of line directive + + value_count = value_count + 1; value_list[value_count] = token:gsub("\\.", _sql_unescapes); + -- ACTION + elseif _tempval == 8 then --4 FROM_STATE_ACTION_SWITCH +-- line 68 "sql.rl" -- end of line directive + value_count = value_count + 1; value_list[value_count] = tonumber(token); -- ACTION + elseif _tempval == 9 then --4 FROM_STATE_ACTION_SWITCH +-- line 69 "sql.rl" -- end of line directive + value_count = value_count + 1; -- ACTION + elseif _tempval == 10 then --4 FROM_STATE_ACTION_SWITCH +-- line 71 "sql.rl" -- end of line directive + value_list,value_count = {},0; -- ACTION + elseif _tempval == 11 then --4 FROM_STATE_ACTION_SWITCH +-- line 71 "sql.rl" -- end of line directive + table.insert(value_lists, value_list); -- ACTION + elseif _tempval == 12 then --4 FROM_STATE_ACTION_SWITCH +-- line 74 "sql.rl" -- end of line directive + table_name,value_lists = token,{}; -- ACTION + elseif _tempval == 13 then --4 FROM_STATE_ACTION_SWITCH +-- line 75 "sql.rl" -- end of line directive + h.insert(table_name, value_lists); -- ACTION + end +-- line 355 "sql.lua" -- end of line directive + -- end action switch + end -- while _nacts + end + + if _trigger_goto then _continue = true; break; end + end -- endif + + if _goto_level <= _again then + if cs == 0 then + _goto_level = _out; + _continue = true; break; + end + p = p + 1; + if p ~= pe then + _goto_level = _resume; + _continue = true; break; + end + end -- _goto_level <= _again + + if _goto_level <= _test_eof then + end -- _goto_level <= _test_eof + + if _goto_level <= _out then break; end + _continue = true; + until true; + if not _continue then break; end + end -- endif _goto_level <= out + + -- end of execute block + + + if cs < parse_sql_first_final then + print("parse_sql: there was an error, line "..pos_line.." column "..pos_char); + else + print("Success. EOF at line "..pos_line.." column "..pos_char) + end +end + +end + +-- import modules +package.path = package.path.."..\?.lua;"; + +local my_name = arg[0]; +if my_name:match("[/\\]") then + package.path = package.path..";"..my_name:gsub("[^/\\]+$", "../?.lua"); + package.cpath = package.cpath..";"..my_name:gsub("[^/\\]+$", "../?.so"); +end + + +-- ugly workaround for getting datamanager to work outside of prosody :( +prosody = { }; +prosody.platform = "unknown"; +if os.getenv("WINDIR") then + prosody.platform = "windows"; +elseif package.config:sub(1,1) == "/" then + prosody.platform = "_posix"; +end +package.loaded["util.logger"] = {init = function() return function() end; end} + +local dm = require "util.datamanager"; +dm.set_data_path("data"); + +local datetime = require "util.datetime"; + +local st = require "util.stanza"; +local parse_xml = require "util.xml".parse; + +function store_password(username, host, password) + -- create or update account for username@host + local ret, err = dm.store(username, host, "accounts", {password = password}); + print("["..(err or "success").."] stored account: "..username.."@"..host.." = "..password); +end + +function store_vcard(username, host, stanza) + -- create or update vCard for username@host + local ret, err = dm.store(username, host, "vcard", st.preserialize(stanza)); + print("["..(err or "success").."] stored vCard: "..username.."@"..host); +end + +function store_roster(username, host, roster_items) + -- fetch current roster-table for username@host if he already has one + local roster = dm.load(username, host, "roster") or {}; + -- merge imported roster-items with loaded roster + for item_tag in roster_items:childtags() do + -- jid for this roster-item + local item_jid = item_tag.attr.jid + -- validate item stanzas + if (item_tag.name == "item") and (item_jid ~= "") then + -- prepare roster item + -- TODO: is the subscription attribute optional? + local item = {subscription = item_tag.attr.subscription, groups = {}}; + -- optional: give roster item a real name + if item_tag.attr.name then + item.name = item_tag.attr.name; + end + -- optional: iterate over group stanzas inside item stanza + for group_tag in item_tag:childtags() do + local group_name = group_tag:get_text(); + if (group_tag.name == "group") and (group_name ~= "") then + item.groups[group_name] = true; + else + print("[error] invalid group stanza: "..group_tag:pretty_print()); + end + end + -- store item in roster + roster[item_jid] = item; + print("[success] roster entry: " ..username.."@"..host.." - "..item_jid); + else + print("[error] invalid roster stanza: " ..item_tag:pretty_print()); + end + + end + -- store merged roster-table + local ret, err = dm.store(username, host, "roster", roster); + print("["..(err or "success").."] stored roster: " ..username.."@"..host); +end + +function store_subscription_request(username, host, presence_stanza) + local from_bare = presence_stanza.attr.from; + + -- fetch current roster-table for username@host if he already has one + local roster = dm.load(username, host, "roster") or {}; + + local item = roster[from_bare]; + if item and (item.subscription == "from" or item.subscription == "both") then + return; -- already subscribed, do nothing + end + + -- add to table of pending subscriptions + if not roster.pending then roster.pending = {}; end + roster.pending[from_bare] = true; + + -- store updated roster-table + local ret, err = dm.store(username, host, "roster", roster); + print("["..(err or "success").."] stored subscription request: " ..username.."@"..host.." - "..from_bare); +end + +local os_date = os.date; +local os_time = os.time; +local os_difftime = os.difftime; +function datetime_parse(s) + if s then + local year, month, day, hour, min, sec, tzd; + year, month, day, hour, min, sec, tzd = s:match("^(%d%d%d%d)%-?(%d%d)%-?(%d%d)T(%d%d):(%d%d):(%d%d)%.?%d*([Z+%-]?.*)$"); + if year then + local time_offset = os_difftime(os_time(os_date("*t")), os_time(os_date("!*t"))); -- to deal with local timezone + local tzd_offset = 0; + if tzd ~= "" and tzd ~= "Z" then + local sign, h, m = tzd:match("([+%-])(%d%d):?(%d*)"); + if not sign then return; end + if #m ~= 2 then m = "0"; end + h, m = tonumber(h), tonumber(m); + tzd_offset = h * 60 * 60 + m * 60; + if sign == "-" then tzd_offset = -tzd_offset; end + end + sec = (sec + time_offset) - tzd_offset; + return os_time({year=year, month=month, day=day, hour=hour, min=min, sec=sec, isdst=false}); + end + end +end + +function store_offline_messages(username, host, stanza) + -- TODO: maybe use list_load(), append and list_store() instead + -- of constantly reopening the file with list_append()? + --for ch in offline_messages:childtags() do + --print("message :"..ch:pretty_print()); + stanza.attr.node = nil; + + local stamp = stanza:get_child("x", "jabber:x:delay"); + if not stamp or not stamp.attr.stamp then print(2) return; end + + for i=1,#stanza do if stanza[i] == stamp then table.remove(stanza, i); break; end end + for i=1,#stanza.tags do if stanza.tags[i] == stamp then table.remove(stanza.tags, i); break; end end + + local parsed_stamp = datetime_parse(stamp.attr.stamp); + if not parsed_stamp then print(1, stamp.attr.stamp) return; end + + stanza.attr.stamp, stanza.attr.stamp_legacy = datetime.datetime(parsed_stamp), datetime.legacy(parsed_stamp); + local ret, err = dm.list_append(username, host, "offline", st.preserialize(stanza)); + print("["..(err or "success").."] stored offline message: " ..username.."@"..host.." - "..stanza.attr.from); + --end +end + +-- load data +local arg = ...; +local help = "/? -? ? /h -h /help -help --help"; +if not arg or help:find(arg, 1, true) then + print([[XEP-227 importer for Prosody + + Usage: jabberd14sql2prosody.lua filename.sql +]]); + os.exit(1); +end +local f = io.open(arg); +local s = f:read("*a"); +f:close(); + +local table_count = 0; +local insert_count = 0; +local row_count = 0; +-- parse +parse_sql(s, { + create = function(table_name, columns) + --[[print(table_name);]] + table_count = table_count + 1; + end; + insert = function(table_name, value_lists) + --[[print(table_name, #value_lists);]] + insert_count = insert_count + 1; + row_count = row_count + #value_lists; + + for _,value_list in ipairs(value_lists) do + if table_name == "users" then + local user, realm, password = unpack(value_list); + store_password(user, realm, password); + elseif table_name == "roster" then + local user, realm, xml = unpack(value_list); + local stanza,err = parse_xml(xml); + if stanza then + store_roster(user, realm, stanza); + else + print("[error] roster: XML parsing failed for "..user.."@"..realm..": "..err); + end + elseif table_name == "vcard" then + local user, realm, name, email, nickname, birthday, photo, xml = unpack(value_list); + if xml then + local stanza,err = parse_xml(xml); + if stanza then + store_vcard(user, realm, stanza); + else + print("[error] vcard: XML parsing failed for "..user.."@"..realm..": "..err); + end + else + --print("[warn] vcard: NULL vCard for "..user.."@"..realm..": "..err); + end + elseif table_name == "storedsubscriptionrequests" then + local user, realm, fromjid, xml = unpack(value_list); + local stanza,err = parse_xml(xml); + if stanza then + store_subscription_request(user, realm, stanza); + else + print("[error] storedsubscriptionrequests: XML parsing failed for "..user.."@"..realm..": "..err); + end + elseif table_name == "messages" then + --local user, realm, node, correspondent, type, storetime, delivertime, subject, body, xml = unpack(value_list); + local user, realm, type, xml = value_list[1], value_list[2], value_list[5], value_list[10]; + if type == "offline" and xml ~= "" then + local stanza,err = parse_xml(xml); + if stanza then + store_offline_messages(user, realm, stanza); + else + print("[error] offline messages: XML parsing failed for "..user.."@"..realm..": "..err); + print(unpack(value_list)); + end + end + end + end + end; +}); + +print("table_count", table_count); +print("insert_count", insert_count); +print("row_count", row_count); + diff --git a/tools/migration/Makefile b/tools/migration/Makefile new file mode 100644 index 00000000..ae402bd2 --- /dev/null +++ b/tools/migration/Makefile @@ -0,0 +1,39 @@ + +include ../../config.unix + +BIN = $(DESTDIR)$(PREFIX)/bin +CONFIG = $(DESTDIR)$(SYSCONFDIR) +SOURCE = $(DESTDIR)$(PREFIX)/lib/prosody +DATA = $(DESTDIR)$(DATADIR) +MAN = $(DESTDIR)$(PREFIX)/share/man + +INSTALLEDSOURCE = $(PREFIX)/lib/prosody +INSTALLEDCONFIG = $(SYSCONFDIR) +INSTALLEDMODULES = $(PREFIX)/lib/prosody/modules +INSTALLEDDATA = $(DATADIR) + +SOURCE_FILES = migrator/*.lua + +all: prosody-migrator.install migrator.cfg.lua.install prosody-migrator.lua $(SOURCE_FILES) + +install: prosody-migrator.install migrator.cfg.lua.install + install -d $(BIN) $(CONFIG) $(SOURCE) $(SOURCE)/migrator + install -d $(MAN)/man1 + install -d $(SOURCE)/migrator + install -m755 ./prosody-migrator.install $(BIN)/prosody-migrator + install -m644 $(SOURCE_FILES) $(SOURCE)/migrator + test -e $(CONFIG)/migrator.cfg.lua || install -m644 migrator.cfg.lua.install $(CONFIG)/migrator.cfg.lua + +clean: + rm -f prosody-migrator.install + rm -f migrator.cfg.lua.install + +prosody-migrator.install: prosody-migrator.lua + sed "1s/\blua\b/$(RUNWITH)/; \ + s|^CFG_SOURCEDIR=.*;$$|CFG_SOURCEDIR='$(INSTALLEDSOURCE)';|; \ + s|^CFG_CONFIGDIR=.*;$$|CFG_CONFIGDIR='$(INSTALLEDCONFIG)';|;" \ + < prosody-migrator.lua > prosody-migrator.install + +migrator.cfg.lua.install: migrator.cfg.lua + sed "s|^local data_path = .*;$$|local data_path = '$(INSTALLEDDATA)';|;" \ + < migrator.cfg.lua > migrator.cfg.lua.install diff --git a/tools/migration/migrator.cfg.lua b/tools/migration/migrator.cfg.lua new file mode 100644 index 00000000..fa37f2a3 --- /dev/null +++ b/tools/migration/migrator.cfg.lua @@ -0,0 +1,26 @@ +local data_path = "../../data"; + +input { + type = "prosody_files"; + path = data_path; +} + +output { + type = "prosody_sql"; + driver = "SQLite3"; + database = data_path.."/prosody.sqlite"; +} + +--[[ + +input { + type = "prosody_files"; + path = data_path; +} +output { + type = "prosody_sql"; + driver = "SQLite3"; + database = data_path.."/prosody.sqlite"; +} + +]] diff --git a/tools/migration/migrator/jabberd14.lua b/tools/migration/migrator/jabberd14.lua new file mode 100644 index 00000000..2f0b0b78 --- /dev/null +++ b/tools/migration/migrator/jabberd14.lua @@ -0,0 +1,142 @@ + +local lfs = require "lfs"; +local st = require "util.stanza"; +local parse_xml = require "util.xml".parse; +local os_getenv = os.getenv; +local io_open = io.open; +local assert = assert; +local ipairs = ipairs; +local coroutine = coroutine; +local print = print; + +module "jabberd14" + +local function is_dir(path) return lfs.attributes(path, "mode") == "directory"; end +local function is_file(path) return lfs.attributes(path, "mode") == "file"; end +local function clean_path(path) + return path:gsub("\\", "/"):gsub("//+", "/"):gsub("^~", os_getenv("HOME") or "~"); +end + +local function load_xml(path) + local f, err = io_open(path); + if not f then return f, err; end + local data = f:read("*a"); + f:close(); + if not data then return; end + return parse_xml(data); +end + +local function load_spool_file(host, filename, path) + local xml = load_xml(path); + if not xml then return; end + + local register_element = xml:get_child("query", "jabber:iq:register"); + local username_element = register_element and register_element:get_child("username", "jabber:iq:register"); + local password_element = register_element and register_element:get_child("password", "jabber:iq:auth"); + local username = username_element and username_element:get_text(); + local password = password_element and password_element:get_text(); + if not username then + print("[warn] Missing /xdb/{jabber:iq:register}register/username> in file "..filename) + return; + elseif username..".xml" ~= filename then + print("[warn] Missing /xdb/{jabber:iq:register}register/username does not match filename "..filename); + return; + end + + local userdata = { + user = username; + host = host; + stores = {}; + }; + local stores = userdata.stores; + stores.accounts = { password = password }; + + for i=1,#xml.tags do + local tag = xml.tags[i]; + local xname = (tag.attr.xmlns or "")..":"..tag.name; + if tag.attr.j_private_flag == "1" and tag.attr.xmlns then + -- Private XML + stores.private = stores.private or {}; + tag.attr.j_private_flag = nil; + stores.private[tag.attr.xmlns] = st.preserialize(tag); + elseif xname == "jabber:iq:auth:password" then + if stores.accounts.password ~= tag:get_text() then + if password then + print("[warn] conflicting passwords") + else + stores.accounts.password = tag:get_text(); + end + end + elseif xname == "jabber:iq:register:query" then + -- already processed + elseif xname == "jabber:xdb:nslist:foo" then + -- ignore + elseif xname == "jabber:iq:auth:0k:zerok" then + -- ignore + elseif xname == "jabber:iq:roster:query" then + -- Roster + local roster = {}; + local subscription_types = { from = true, to = true, both = true, none = true }; + for _,item_element in ipairs(tag.tags) do + assert(item_element.name == "item"); + assert(item_element.attr.jid); + assert(subscription_types[item_element.attr.subscription]); + assert((item_element.attr.ask or "subscribe") == "subscribe") + if item_element.name == "item" then + local groups = {}; + for _,group_element in ipairs(item_element.tags) do + assert(group_element.name == "group"); + groups[group_element:get_text()] = true; + end + local item = { + name = item_element.attr.name; + subscription = item_element.attr.subscription; + ask = item_element.attr.ask; + groups = groups; + }; + roster[item_element.attr.jid] = item; + end + end + stores.roster = roster; + elseif xname == "jabber:iq:last:query" then + -- Last activity + elseif xname == "jabber:x:offline:foo" then + -- Offline messages + elseif xname == "vcard-temp:vCard" then + -- vCards + stores.vcard = st.preserialize(tag); + else + print("[warn] Unknown tag: "..xname); + end + end + return userdata; +end + +local function loop_over_users(path, host, cb) + for file in lfs.dir(path) do + if file:match("%.xml$") then + local user = load_spool_file(host, file, path.."/"..file); + if user then cb(user); end + end + end +end +local function loop_over_hosts(path, cb) + for host in lfs.dir(path) do + if host ~= "." and host ~= ".." and is_dir(path.."/"..host) then + loop_over_users(path.."/"..host, host, cb); + end + end +end + +function reader(input) + local path = clean_path(assert(input.path, "no input.path specified")); + assert(is_dir(path), "input.path is not a directory"); + + if input.host then + return coroutine.wrap(function() loop_over_users(input.path, input.host, coroutine.yield) end); + else + return coroutine.wrap(function() loop_over_hosts(input.path, coroutine.yield) end); + end +end + +return _M; diff --git a/tools/migration/migrator/mtools.lua b/tools/migration/migrator/mtools.lua new file mode 100644 index 00000000..e7b774bb --- /dev/null +++ b/tools/migration/migrator/mtools.lua @@ -0,0 +1,56 @@ + + +local print = print; +local t_insert = table.insert; +local t_sort = table.sort; + +module "mtools" + +function sorted(params) + + local reader = params.reader; -- iterator to get items from + local sorter = params.sorter; -- sorting function + local filter = params.filter; -- filter function + + local cache = {}; + for item in reader do + if filter then item = filter(item); end + if item then t_insert(cache, item); end + end + if sorter then + t_sort(cache, sorter); + end + local i = 0; + return function() + i = i + 1; + return cache[i]; + end; + +end + +function merged(reader, merger) + + local item1 = reader(); + local merged = { item1 }; + return function() + while true do + if not item1 then return nil; end + local item2 = reader(); + if not item2 then item1 = nil; return merged; end + if merger(item1, item2) then + --print("merged") + item1 = item2; + t_insert(merged, item1); + else + --print("unmerged", merged) + item1 = item2; + local tmp = merged; + merged = { item1 }; + return tmp; + end + end + end; + +end + +return _M; diff --git a/tools/migration/migrator/prosody_files.lua b/tools/migration/migrator/prosody_files.lua new file mode 100644 index 00000000..4462fb3e --- /dev/null +++ b/tools/migration/migrator/prosody_files.lua @@ -0,0 +1,138 @@ + +local print = print; +local assert = assert; +local setmetatable = setmetatable; +local tonumber = tonumber; +local char = string.char; +local coroutine = coroutine; +local lfs = require "lfs"; +local loadfile = loadfile; +local pcall = pcall; +local mtools = require "migrator.mtools"; +local next = next; +local pairs = pairs; +local json = require "util.json"; +local os_getenv = os.getenv; + +prosody = {}; +local dm = require "util.datamanager" + +module "prosody_files" + +local function is_dir(path) return lfs.attributes(path, "mode") == "directory"; end +local function is_file(path) return lfs.attributes(path, "mode") == "file"; end +local function clean_path(path) + return path:gsub("\\", "/"):gsub("//+", "/"):gsub("^~", os_getenv("HOME") or "~"); +end +local encode, decode; do + local urlcodes = setmetatable({}, { __index = function (t, k) t[k] = char(tonumber("0x"..k)); return t[k]; end }); + decode = function (s) return s and (s:gsub("+", " "):gsub("%%([a-fA-F0-9][a-fA-F0-9])", urlcodes)); end + encode = function (s) return s and (s:gsub("%W", function (c) return format("%%%02x", c:byte()); end)); end +end +local function decode_dir(x) + if x:gsub("%%%x%x", ""):gsub("[a-zA-Z0-9]", "") == "" then + return decode(x); + end +end +local function decode_file(x) + if x:match(".%.dat$") and x:gsub("%.dat$", ""):gsub("%%%x%x", ""):gsub("[a-zA-Z0-9]", "") == "" then + return decode(x:gsub("%.dat$", "")); + end +end +local function prosody_dir(path, ondir, onfile, ...) + for x in lfs.dir(path) do + local xpath = path.."/"..x; + if decode_dir(x) and is_dir(xpath) then + ondir(xpath, x, ...); + elseif decode_file(x) and is_file(xpath) then + onfile(xpath, x, ...); + end + end +end + +local function handle_root_file(path, name) + --print("root file: ", decode_file(name)) + coroutine.yield { user = nil, host = nil, store = decode_file(name) }; +end +local function handle_host_file(path, name, host) + --print("host file: ", decode_dir(host).."/"..decode_file(name)) + coroutine.yield { user = nil, host = decode_dir(host), store = decode_file(name) }; +end +local function handle_store_file(path, name, store, host) + --print("store file: ", decode_file(name).."@"..decode_dir(host).."/"..decode_dir(store)) + coroutine.yield { user = decode_file(name), host = decode_dir(host), store = decode_dir(store) }; +end +local function handle_host_store(path, name, host) + prosody_dir(path, function() end, handle_store_file, name, host); +end +local function handle_host_dir(path, name) + prosody_dir(path, handle_host_store, handle_host_file, name); +end +local function handle_root_dir(path) + prosody_dir(path, handle_host_dir, handle_root_file); +end + +local function decode_user(item) + local userdata = { + user = item[1].user; + host = item[1].host; + stores = {}; + }; + for i=1,#item do -- loop over stores + local result = {}; + local store = item[i]; + userdata.stores[store.store] = store.data; + store.user = nil; store.host = nil; store.store = nil; + end + return userdata; +end + +function reader(input) + local path = clean_path(assert(input.path, "no input.path specified")); + assert(is_dir(path), "input.path is not a directory"); + local iter = coroutine.wrap(function()handle_root_dir(path);end); + -- get per-user stores, sorted + local iter = mtools.sorted { + reader = function() + local x = iter(); + if x then + dm.set_data_path(path); + local err; + x.data, err = dm.load(x.user, x.host, x.store); + if x.data == nil and err then + error(("Error loading data at path %s for %s@%s (%s store)") + :format(path, x.user or "<nil>", x.host or "<nil>", x.store or "<nil>"), 0); + end + return x; + end + end; + sorter = function(a, b) + local a_host, a_user, a_store = a.host or "", a.user or "", a.store or ""; + local b_host, b_user, b_store = b.host or "", b.user or "", b.store or ""; + return a_host > b_host or (a_host==b_host and a_user > b_user) or (a_host==b_host and a_user==b_user and a_store > b_store); + end; + }; + -- merge stores to get users + iter = mtools.merged(iter, function(a, b) + return (a.host == b.host and a.user == b.user); + end); + + return function() + local x = iter(); + return x and decode_user(x); + end +end + +function writer(output) + local path = clean_path(assert(output.path, "no output.path specified")); + assert(is_dir(path), "output.path is not a directory"); + return function(item) + if not item then return; end -- end of input + dm.set_data_path(path); + for store, data in pairs(item.stores) do + assert(dm.store(item.user, item.host, store, data)); + end + end +end + +return _M; diff --git a/tools/migration/migrator/prosody_sql.lua b/tools/migration/migrator/prosody_sql.lua new file mode 100644 index 00000000..180ae910 --- /dev/null +++ b/tools/migration/migrator/prosody_sql.lua @@ -0,0 +1,200 @@ + +local assert = assert; +local have_DBI, DBI = pcall(require,"DBI"); +local print = print; +local type = type; +local next = next; +local pairs = pairs; +local t_sort = table.sort; +local json = require "util.json"; +local mtools = require "migrator.mtools"; +local tostring = tostring; +local tonumber = tonumber; + +if not have_DBI then + error("LuaDBI (required for SQL support) was not found, please see http://prosody.im/doc/depends#luadbi", 0); +end + +module "prosody_sql" + +local function create_table(connection, params) + local create_sql = "CREATE TABLE `prosody` (`host` TEXT, `user` TEXT, `store` TEXT, `key` TEXT, `type` TEXT, `value` TEXT);"; + if params.driver == "PostgreSQL" then + create_sql = create_sql:gsub("`", "\""); + elseif params.driver == "MySQL" then + create_sql = create_sql:gsub("`value` TEXT", "`value` MEDIUMTEXT"); + end + + local stmt = connection:prepare(create_sql); + if stmt then + local ok = stmt:execute(); + local commit_ok = connection:commit(); + if ok and commit_ok then + local index_sql = "CREATE INDEX `prosody_index` ON `prosody` (`host`, `user`, `store`, `key`)"; + if params.driver == "PostgreSQL" then + index_sql = index_sql:gsub("`", "\""); + elseif params.driver == "MySQL" then + index_sql = index_sql:gsub("`([,)])", "`(20)%1"); + end + local stmt, err = connection:prepare(index_sql); + local ok, commit_ok, commit_err; + if stmt then + ok, err = assert(stmt:execute()); + commit_ok, commit_err = assert(connection:commit()); + end + elseif params.driver == "MySQL" then -- COMPAT: Upgrade tables from 0.8.0 + -- Failed to create, but check existing MySQL table here + local stmt = connection:prepare("SHOW COLUMNS FROM prosody WHERE Field='value' and Type='text'"); + local ok = stmt:execute(); + local commit_ok = connection:commit(); + if ok and commit_ok then + if stmt:rowcount() > 0 then + local stmt = connection:prepare("ALTER TABLE prosody MODIFY COLUMN `value` MEDIUMTEXT"); + local ok = stmt:execute(); + local commit_ok = connection:commit(); + if ok and commit_ok then + print("Database table automatically upgraded"); + end + end + repeat until not stmt:fetch(); + end + end + end +end + +local function serialize(value) + local t = type(value); + if t == "string" or t == "boolean" or t == "number" then + return t, tostring(value); + elseif t == "table" then + local value,err = json.encode(value); + if value then return "json", value; end + return nil, err; + end + return nil, "Unhandled value type: "..t; +end +local function deserialize(t, value) + if t == "string" then return value; + elseif t == "boolean" then + if value == "true" then return true; + elseif value == "false" then return false; end + elseif t == "number" then return tonumber(value); + elseif t == "json" then + return json.decode(value); + end +end + +local function decode_user(item) + local userdata = { + user = item[1][1].user; + host = item[1][1].host; + stores = {}; + }; + for i=1,#item do -- loop over stores + local result = {}; + local store = item[i]; + for i=1,#store do -- loop over store data + local row = store[i]; + local k = row.key; + local v = deserialize(row.type, row.value); + if k and v then + if k ~= "" then result[k] = v; elseif type(v) == "table" then + for a,b in pairs(v) do + result[a] = b; + end + end + end + userdata.stores[store[1].store] = result; + end + end + return userdata; +end + +function reader(input) + local dbh = assert(DBI.Connect( + assert(input.driver, "no input.driver specified"), + assert(input.database, "no input.database specified"), + input.username, input.password, + input.host, input.port + )); + assert(dbh:ping()); + local stmt = assert(dbh:prepare("SELECT * FROM prosody")); + assert(stmt:execute()); + local keys = {"host", "user", "store", "key", "type", "value"}; + local f,s,val = stmt:rows(true); + -- get SQL rows, sorted + local iter = mtools.sorted { + reader = function() val = f(s, val); return val; end; + filter = function(x) + for i=1,#keys do + if not x[keys[i]] then return false; end -- TODO log error, missing field + end + if x.host == "" then x.host = nil; end + if x.user == "" then x.user = nil; end + if x.store == "" then x.store = nil; end + return x; + end; + sorter = function(a, b) + local a_host, a_user, a_store = a.host or "", a.user or "", a.store or ""; + local b_host, b_user, b_store = b.host or "", b.user or "", b.store or ""; + return a_host > b_host or (a_host==b_host and a_user > b_user) or (a_host==b_host and a_user==b_user and a_store > b_store); + end; + }; + -- merge rows to get stores + iter = mtools.merged(iter, function(a, b) + return (a.host == b.host and a.user == b.user and a.store == b.store); + end); + -- merge stores to get users + iter = mtools.merged(iter, function(a, b) + return (a[1].host == b[1].host and a[1].user == b[1].user); + end); + return function() + local x = iter(); + return x and decode_user(x); + end; +end + +function writer(output, iter) + local dbh = assert(DBI.Connect( + assert(output.driver, "no output.driver specified"), + assert(output.database, "no output.database specified"), + output.username, output.password, + output.host, output.port + )); + assert(dbh:ping()); + create_table(dbh, output); + local stmt = assert(dbh:prepare("SELECT * FROM prosody")); + assert(stmt:execute()); + local stmt = assert(dbh:prepare("DELETE FROM prosody")); + assert(stmt:execute()); + local insert_sql = "INSERT INTO `prosody` (`host`,`user`,`store`,`key`,`type`,`value`) VALUES (?,?,?,?,?,?)"; + if output.driver == "PostgreSQL" then + insert_sql = insert_sql:gsub("`", "\""); + end + local insert = assert(dbh:prepare(insert_sql)); + + return function(item) + if not item then assert(dbh:commit()) return dbh:close(); end -- end of input + local host = item.host or ""; + local user = item.user or ""; + for store, data in pairs(item.stores) do + -- TODO transactions + local extradata = {}; + for key, value in pairs(data) do + if type(key) == "string" and key ~= "" then + local t, value = assert(serialize(value)); + local ok, err = assert(insert:execute(host, user, store, key, t, value)); + else + extradata[key] = value; + end + end + if next(extradata) ~= nil then + local t, extradata = assert(serialize(extradata)); + local ok, err = assert(insert:execute(host, user, store, "", t, extradata)); + end + end + end; +end + + +return _M; diff --git a/tools/migration/prosody-migrator.lua b/tools/migration/prosody-migrator.lua new file mode 100644 index 00000000..b86e9892 --- /dev/null +++ b/tools/migration/prosody-migrator.lua @@ -0,0 +1,132 @@ +#!/usr/bin/env lua + +CFG_SOURCEDIR=os.getenv("PROSODY_SRCDIR"); +CFG_CONFIGDIR=os.getenv("PROSODY_CFGDIR"); + +-- Substitute ~ with path to home directory in paths +if CFG_CONFIGDIR then + CFG_CONFIGDIR = CFG_CONFIGDIR:gsub("^~", os.getenv("HOME")); +end + +if CFG_SOURCEDIR then + CFG_SOURCEDIR = CFG_SOURCEDIR:gsub("^~", os.getenv("HOME")); +end + +local default_config = (CFG_CONFIGDIR or ".").."/migrator.cfg.lua"; + +-- Command-line parsing +local options = {}; +local handled_opts = 0; +for i = 1, #arg do + if arg[i]:sub(1,2) == "--" then + local opt, val = arg[i]:match("([%w-]+)=?(.*)"); + if opt then + options[(opt:sub(3):gsub("%-", "_"))] = #val > 0 and val or true; + end + handled_opts = i; + else + break; + end +end +table.remove(arg, handled_opts); + +if CFG_SOURCEDIR then + package.path = CFG_SOURCEDIR.."/?.lua;"..package.path; + package.cpath = CFG_SOURCEDIR.."/?.so;"..package.cpath; +else + package.path = "../../?.lua;"..package.path + package.cpath = "../../?.so;"..package.cpath +end + +local envloadfile = require "util.envload".envloadfile; + +-- Load config file +local function loadfilein(file, env) + if loadin then + return loadin(env, io.open(file):read("*a")); + else + return envloadfile(file, env); + end +end + +local config_file = options.config or default_config; +local from_store = arg[1] or "input"; +local to_store = arg[2] or "output"; + +config = {}; +local config_env = setmetatable({}, { __index = function(t, k) return function(tbl) config[k] = tbl; end; end }); +local config_chunk, err = loadfilein(config_file, config_env); +if not config_chunk then + print("There was an error loading the config file, check the file exists"); + print("and that the syntax is correct:"); + print("", err); + os.exit(1); +end + +config_chunk(); + +local have_err; +if #arg > 0 and #arg ~= 2 then + have_err = true; + print("Error: Incorrect number of parameters supplied."); +end +if not config[from_store] then + have_err = true; + print("Error: Input store '"..from_store.."' not found in the config file."); +end +if not config[to_store] then + have_err = true; + print("Error: Output store '"..to_store.."' not found in the config file."); +end + +function load_store_handler(name) + local store_type = config[name].type; + if not store_type then + print("Error: "..name.." store type not specified in the config file"); + return false; + else + local ok, err = pcall(require, "migrator."..store_type); + if not ok then + if package.loaded["migrator."..store_type] then + print(("Error: Failed to initialize '%s' store:\n\t%s") + :format(name, err)); + else + print(("Error: Unrecognised store type for '%s': %s") + :format(from_store, store_type)); + end + return false; + end + end + return true; +end + +have_err = have_err or not(load_store_handler(from_store, "input") and load_store_handler(to_store, "output")); + +if have_err then + print(""); + print("Usage: "..arg[0].." FROM_STORE TO_STORE"); + print("If no stores are specified, 'input' and 'output' are used."); + print(""); + print("The available stores in your migrator config are:"); + print(""); + for store in pairs(config) do + print("", store); + end + print(""); + os.exit(1); +end + +local itype = config[from_store].type; +local otype = config[to_store].type; +local reader = require("migrator."..itype).reader(config[from_store]); +local writer = require("migrator."..otype).writer(config[to_store]); + +local json = require "util.json"; + +io.stderr:write("Migrating...\n"); +for x in reader do + --print(json.encode(x)) + writer(x); +end +writer(nil); -- close +io.stderr:write("Done!\n"); diff --git a/tools/openfire2prosody.lua b/tools/openfire2prosody.lua new file mode 100644 index 00000000..cd3e62e5 --- /dev/null +++ b/tools/openfire2prosody.lua @@ -0,0 +1,109 @@ +#!/usr/bin/env lua +-- Prosody IM +-- Copyright (C) 2008-2009 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +package.path = package.path..";../?.lua"; +package.cpath = package.cpath..";../?.so"; -- needed for util.pposix used in datamanager + +local my_name = arg[0]; +if my_name:match("[/\\]") then + package.path = package.path..";"..my_name:gsub("[^/\\]+$", "../?.lua"); + package.cpath = package.cpath..";"..my_name:gsub("[^/\\]+$", "../?.so"); +end + +-- ugly workaround for getting datamanager to work outside of prosody :( +prosody = { }; +prosody.platform = "unknown"; +if os.getenv("WINDIR") then + prosody.platform = "windows"; +elseif package.config:sub(1,1) == "/" then + prosody.platform = "posix"; +end + +local parse_xml = require "util.xml".parse; + +----------------------------------------------------------------------- + +package.loaded["util.logger"] = {init = function() return function() end; end} +local dm = require "util.datamanager" +dm.set_data_path("data"); + +local arg = ...; +local help = "/? -? ? /h -h /help -help --help"; +if not arg or help:find(arg, 1, true) then + print([[Openfire importer for Prosody + + Usage: openfire2prosody.lua filename.xml hostname + +]]); + os.exit(1); +end + +local host = select(2, ...) or "localhost"; + +local file = assert(io.open(arg)); +local data = assert(file:read("*a")); +file:close(); + +local xml = assert(parse_xml(data)); + +assert(xml.name == "Openfire", "The input file is not an Openfire XML export"); + +local substatus_mapping = { ["0"] = "none", ["1"] = "to", ["2"] = "from", ["3"] = "both" }; + +for _,tag in ipairs(xml.tags) do + if tag.name == "User" then + local username, password, roster; + + for _,tag in ipairs(tag.tags) do + if tag.name == "Username" then + username = tag:get_text(); + elseif tag.name == "Password" then + password = tag:get_text(); + elseif tag.name == "Roster" then + roster = {}; + local pending = {}; + for _,tag in ipairs(tag.tags) do + if tag.name == "Item" then + local jid = assert(tag.attr.jid, "Roster item has no JID"); + if tag.attr.substatus ~= "-1" then + local item = {}; + item.name = tag.attr.name; + item.subscription = assert(substatus_mapping[tag.attr.substatus], "invalid substatus"); + item.ask = tag.attr.askstatus == "0" and "subscribe" or nil; + + local groups = {}; + for _,tag in ipairs(tag) do + if tag.name == "Group" then + groups[tag:get_text()] = true; + end + end + item.groups = groups; + roster[jid] = item; + end + if tag.attr.recvstatus == "1" then pending[jid] = true; end + end + end + + if next(pending) then + roster[false] = { pending = pending }; + end + end + end + + assert(username and password, "No username or password"); + + local ret, err = dm.store(username, host, "accounts", {password = password}); + print("["..(err or "success").."] stored account: "..username.."@"..host.." = "..password); + + if roster then + local ret, err = dm.store(username, host, "roster", roster); + print("["..(err or "success").."] stored roster: "..username.."@"..host.." = "..password); + end + end +end + diff --git a/tools/xep227toprosody.lua b/tools/xep227toprosody.lua index 23e5948b..e95a9d59 100755 --- a/tools/xep227toprosody.lua +++ b/tools/xep227toprosody.lua @@ -3,7 +3,7 @@ -- Copyright (C) 2008-2009 Matthew Wild -- Copyright (C) 2008-2009 Waqas Hussain -- Copyright (C) 2010 Stefan Gehn --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -25,6 +25,12 @@ package.path = package.path..";../?.lua"; package.cpath = package.cpath..";../?.so"; -- needed for util.pposix used in datamanager +local my_name = arg[0]; +if my_name:match("[/\\]") then + package.path = package.path..";"..my_name:gsub("[^/\\]+$", "../?.lua"); + package.cpath = package.cpath..";"..my_name:gsub("[^/\\]+$", "../?.so"); +end + -- ugly workaround for getting datamanager to work outside of prosody :( prosody = { }; prosody.platform = "unknown"; @@ -64,11 +70,11 @@ function store_roster(username, host, roster_items) -- fetch current roster-table for username@host if he already has one local roster = dm.load(username, host, "roster") or {}; -- merge imported roster-items with loaded roster - for item_tag in roster_items:childtags() do + for item_tag in roster_items:childtags("item") do -- jid for this roster-item local item_jid = item_tag.attr.jid -- validate item stanzas - if (item_tag.name == "item") and (item_jid ~= "") then + if (item_jid ~= "") then -- prepare roster item -- TODO: is the subscription attribute optional? local item = {subscription = item_tag.attr.subscription, groups = {}}; @@ -77,9 +83,9 @@ function store_roster(username, host, roster_items) item.name = item_tag.attr.name; end -- optional: iterate over group stanzas inside item stanza - for group_tag in item_tag:childtags() do + for group_tag in item_tag:childtags("group") do local group_name = group_tag:get_text(); - if (group_tag.name == "group") and (group_name ~= "") then + if (group_name ~= "") then item.groups[group_name] = true; else print("[error] invalid group stanza: "..group_tag:pretty_print()); @@ -100,7 +106,7 @@ end function store_private(username, host, private_items) local private = dm.load(username, host, "private") or {}; - for ch in private_items:childtags() do + for _, ch in ipairs(private_items.tags) do --print("private :"..ch:pretty_print()); private[ch.name..":"..ch.attr.xmlns] = st.preserialize(ch); print("[success] private item: " ..username.."@"..host.." - "..ch.name); @@ -112,7 +118,7 @@ end function store_offline_messages(username, host, offline_messages) -- TODO: maybe use list_load(), append and list_store() instead -- of constantly reopening the file with list_append()? - for ch in offline_messages:childtags() do + for ch in offline_messages:childtags("message", "jabber:client") do --print("message :"..ch:pretty_print()); local ret, err = dm.list_append(username, host, "offline", st.preserialize(ch)); print("["..(err or "success").."] stored offline message: " ..username.."@"..host.." - "..ch.attr.from); diff --git a/util-src/Makefile b/util-src/Makefile index 1ca934ad..3a1ba3f2 100644 --- a/util-src/Makefile +++ b/util-src/Makefile @@ -9,9 +9,21 @@ OPENSSL_LIB?=crypto CC?=gcc CXX?=g++ LD?=gcc +CFLAGS+=-ggdb +.PHONY: all install clean .SUFFIXES: .c .o .so +all: encodings.so hashes.so net.so pposix.so signal.so + +install: encodings.so hashes.so net.so pposix.so signal.so + install *.so ../util/ + +clean: + rm -f *.o + rm -f *.so + rm -f ../util/*.so + encodings.so: encodings.o MACOSX_DEPLOYMENT_TARGET="10.3"; export MACOSX_DEPLOYMENT_TARGET; $(CC) -o $@ $< $(LDFLAGS) $(IDNA_LIBS) @@ -27,12 +39,3 @@ hashes.so: hashes.o MACOSX_DEPLOYMENT_TARGET="10.3"; export MACOSX_DEPLOYMENT_TARGET; $(LD) -o $@ $< $(LDFLAGS) -all: encodings.so hashes.so pposix.so signal.so - -install: encodings.so hashes.so pposix.so signal.so - install *.so ../util/ - -clean: - rm -f *.o - rm -f *.so - rm -f ../util/*.so diff --git a/util-src/encodings.c b/util-src/encodings.c index 2a4653fb..b9b6160a 100644 --- a/util-src/encodings.c +++ b/util-src/encodings.c @@ -117,55 +117,8 @@ static const luaL_Reg Reg_base64[] = }; /***************** STRINGPREP *****************/ -#ifndef USE_STRINGPREP_ICU -/****************** libidn ********************/ - -#include <stringprep.h> - -static int stringprep_prep(lua_State *L, const Stringprep_profile *profile) -{ - size_t len; - const char *s; - char string[1024]; - int ret; - if(!lua_isstring(L, 1)) { - lua_pushnil(L); - return 1; - } - s = lua_tolstring(L, 1, &len); - if (len >= 1024) { - lua_pushnil(L); - return 1; /* TODO return error message */ - } - strcpy(string, s); - ret = stringprep(string, 1024, (Stringprep_profile_flags)0, profile); - if (ret == STRINGPREP_OK) { - lua_pushstring(L, string); - return 1; - } else { - lua_pushnil(L); - return 1; /* TODO return error message */ - } -} - -#define MAKE_PREP_FUNC(myFunc, prep) \ -static int myFunc(lua_State *L) { return stringprep_prep(L, prep); } - -MAKE_PREP_FUNC(Lstringprep_nameprep, stringprep_nameprep) /** stringprep.nameprep(s) */ -MAKE_PREP_FUNC(Lstringprep_nodeprep, stringprep_xmpp_nodeprep) /** stringprep.nodeprep(s) */ -MAKE_PREP_FUNC(Lstringprep_resourceprep, stringprep_xmpp_resourceprep) /** stringprep.resourceprep(s) */ -MAKE_PREP_FUNC(Lstringprep_saslprep, stringprep_saslprep) /** stringprep.saslprep(s) */ - -static const luaL_Reg Reg_stringprep[] = -{ - { "nameprep", Lstringprep_nameprep }, - { "nodeprep", Lstringprep_nodeprep }, - { "resourceprep", Lstringprep_resourceprep }, - { "saslprep", Lstringprep_saslprep }, - { NULL, NULL } -}; +#ifdef USE_STRINGPREP_ICU -#else #include <unicode/usprep.h> #include <unicode/ustring.h> #include <unicode/utrace.h> @@ -192,13 +145,17 @@ static int icu_stringprep_prep(lua_State *L, const UStringPrepProfile *profile) return 1; } u_strFromUTF8(unprepped, 1024, &unprepped_len, input, input_len, &err); + if (U_FAILURE(err)) { + lua_pushnil(L); + return 1; + } prepped_len = usprep_prepare(profile, unprepped, unprepped_len, prepped, 1024, 0, NULL, &err); if (U_FAILURE(err)) { lua_pushnil(L); return 1; } else { u_strToUTF8(output, 1024, &output_len, prepped, prepped_len, &err); - if(output_len < 1024) + if (U_SUCCESS(err) && output_len < 1024) lua_pushlstring(L, output, output_len); else lua_pushnil(L); @@ -239,49 +196,58 @@ static const luaL_Reg Reg_stringprep[] = { "saslprep", Lstringprep_saslprep }, { NULL, NULL } }; -#endif +#else /* USE_STRINGPREP_ICU */ -/***************** IDNA *****************/ -#ifndef USE_STRINGPREP_ICU /****************** libidn ********************/ -#include <idna.h> -#include <idn-free.h> +#include <stringprep.h> -static int Lidna_to_ascii(lua_State *L) /** idna.to_ascii(s) */ +static int stringprep_prep(lua_State *L, const Stringprep_profile *profile) { size_t len; - const char *s = luaL_checklstring(L, 1, &len); - char* output = NULL; - int ret = idna_to_ascii_8z(s, &output, IDNA_USE_STD3_ASCII_RULES); - if (ret == IDNA_SUCCESS) { - lua_pushstring(L, output); - idn_free(output); + const char *s; + char string[1024]; + int ret; + if(!lua_isstring(L, 1)) { + lua_pushnil(L); return 1; - } else { + } + s = lua_tolstring(L, 1, &len); + if (len >= 1024) { lua_pushnil(L); - idn_free(output); return 1; /* TODO return error message */ } -} - -static int Lidna_to_unicode(lua_State *L) /** idna.to_unicode(s) */ -{ - size_t len; - const char *s = luaL_checklstring(L, 1, &len); - char* output = NULL; - int ret = idna_to_unicode_8z8z(s, &output, 0); - if (ret == IDNA_SUCCESS) { - lua_pushstring(L, output); - idn_free(output); + strcpy(string, s); + ret = stringprep(string, 1024, (Stringprep_profile_flags)0, profile); + if (ret == STRINGPREP_OK) { + lua_pushstring(L, string); return 1; } else { lua_pushnil(L); - idn_free(output); return 1; /* TODO return error message */ } } -#else + +#define MAKE_PREP_FUNC(myFunc, prep) \ +static int myFunc(lua_State *L) { return stringprep_prep(L, prep); } + +MAKE_PREP_FUNC(Lstringprep_nameprep, stringprep_nameprep) /** stringprep.nameprep(s) */ +MAKE_PREP_FUNC(Lstringprep_nodeprep, stringprep_xmpp_nodeprep) /** stringprep.nodeprep(s) */ +MAKE_PREP_FUNC(Lstringprep_resourceprep, stringprep_xmpp_resourceprep) /** stringprep.resourceprep(s) */ +MAKE_PREP_FUNC(Lstringprep_saslprep, stringprep_saslprep) /** stringprep.saslprep(s) */ + +static const luaL_Reg Reg_stringprep[] = +{ + { "nameprep", Lstringprep_nameprep }, + { "nodeprep", Lstringprep_nodeprep }, + { "resourceprep", Lstringprep_resourceprep }, + { "saslprep", Lstringprep_saslprep }, + { NULL, NULL } +}; +#endif + +/***************** IDNA *****************/ +#ifdef USE_STRINGPREP_ICU #include <unicode/ustdio.h> #include <unicode/uidna.h> /* IDNA2003 or IDNA2008 ? ? ? */ @@ -296,13 +262,18 @@ static int Lidna_to_ascii(lua_State *L) /** idna.to_ascii(s) */ char output[1024]; u_strFromUTF8(ustr, 1024, &ulen, s, len, &err); + if (U_FAILURE(err)) { + lua_pushnil(L); + return 1; + } + dest_len = uidna_IDNToASCII(ustr, ulen, dest, 1024, UIDNA_USE_STD3_RULES, NULL, &err); if (U_FAILURE(err)) { lua_pushnil(L); return 1; } else { u_strToUTF8(output, 1024, &output_len, dest, dest_len, &err); - if(output_len < 1024) + if (U_SUCCESS(err) && output_len < 1024) lua_pushlstring(L, output, output_len); else lua_pushnil(L); @@ -315,25 +286,70 @@ static int Lidna_to_unicode(lua_State *L) /** idna.to_unicode(s) */ size_t len; int32_t ulen, dest_len, output_len; const char *s = luaL_checklstring(L, 1, &len); - UChar* ustr; + UChar ustr[1024]; UErrorCode err = U_ZERO_ERROR; UChar dest[1024]; char output[1024]; u_strFromUTF8(ustr, 1024, &ulen, s, len, &err); + if (U_FAILURE(err)) { + lua_pushnil(L); + return 1; + } + dest_len = uidna_IDNToUnicode(ustr, ulen, dest, 1024, UIDNA_USE_STD3_RULES, NULL, &err); if (U_FAILURE(err)) { lua_pushnil(L); return 1; } else { u_strToUTF8(output, 1024, &output_len, dest, dest_len, &err); - if(output_len < 1024) + if (U_SUCCESS(err) && output_len < 1024) lua_pushlstring(L, output, output_len); else lua_pushnil(L); return 1; } } + +#else /* USE_STRINGPREP_ICU */ +/****************** libidn ********************/ + +#include <idna.h> +#include <idn-free.h> + +static int Lidna_to_ascii(lua_State *L) /** idna.to_ascii(s) */ +{ + size_t len; + const char *s = luaL_checklstring(L, 1, &len); + char* output = NULL; + int ret = idna_to_ascii_8z(s, &output, IDNA_USE_STD3_ASCII_RULES); + if (ret == IDNA_SUCCESS) { + lua_pushstring(L, output); + idn_free(output); + return 1; + } else { + lua_pushnil(L); + idn_free(output); + return 1; /* TODO return error message */ + } +} + +static int Lidna_to_unicode(lua_State *L) /** idna.to_unicode(s) */ +{ + size_t len; + const char *s = luaL_checklstring(L, 1, &len); + char* output = NULL; + int ret = idna_to_unicode_8z8z(s, &output, 0); + if (ret == IDNA_SUCCESS) { + lua_pushstring(L, output); + idn_free(output); + return 1; + } else { + lua_pushnil(L); + idn_free(output); + return 1; /* TODO return error message */ + } +} #endif static const luaL_Reg Reg_idna[] = diff --git a/util-src/hashes.c b/util-src/hashes.c index 33a9be89..33041e83 100644 --- a/util-src/hashes.c +++ b/util-src/hashes.c @@ -14,14 +14,24 @@ */ #include <string.h> +#include <stdlib.h> + +#ifdef _MSC_VER +typedef unsigned __int32 uint32_t; +#else +#include <inttypes.h> +#endif #include "lua.h" #include "lauxlib.h" #include <openssl/sha.h> #include <openssl/md5.h> -const char* hex_tab = "0123456789abcdef"; -void toHex(const char* in, int length, char* out) { +#define HMAC_IPAD 0x36363636 +#define HMAC_OPAD 0x5c5c5c5c + +const char *hex_tab = "0123456789abcdef"; +void toHex(const unsigned char *in, int length, unsigned char *out) { int i; for (i = 0; i < length; i++) { out[i*2] = hex_tab[(in[i] >> 4) & 0xF]; @@ -34,28 +44,161 @@ static int myFunc(lua_State *L) { \ size_t len; \ const char *s = luaL_checklstring(L, 1, &len); \ int hex_out = lua_toboolean(L, 2); \ - char hash[size]; \ - char result[size*2]; \ - func((const unsigned char*)s, len, (unsigned char*)hash); \ + unsigned char hash[size], result[size*2]; \ + func((const unsigned char*)s, len, hash); \ + if (hex_out) { \ + toHex(hash, size, result); \ + lua_pushlstring(L, (char*)result, size*2); \ + } else { \ + lua_pushlstring(L, (char*)hash, size);\ + } \ + return 1; \ +} + +MAKE_HASH_FUNCTION(Lsha1, SHA1, SHA_DIGEST_LENGTH) +MAKE_HASH_FUNCTION(Lsha224, SHA224, SHA224_DIGEST_LENGTH) +MAKE_HASH_FUNCTION(Lsha256, SHA256, SHA256_DIGEST_LENGTH) +MAKE_HASH_FUNCTION(Lsha384, SHA384, SHA384_DIGEST_LENGTH) +MAKE_HASH_FUNCTION(Lsha512, SHA512, SHA512_DIGEST_LENGTH) +MAKE_HASH_FUNCTION(Lmd5, MD5, MD5_DIGEST_LENGTH) + +struct hash_desc { + int (*Init)(void*); + int (*Update)(void*, const void *, size_t); + int (*Final)(unsigned char*, void*); + size_t digestLength; + void *ctx, *ctxo; +}; + +static void hmac(struct hash_desc *desc, const char *key, size_t key_len, + const char *msg, size_t msg_len, unsigned char *result) +{ + union xory { + unsigned char bytes[64]; + uint32_t quadbytes[16]; + }; + + int i; + unsigned char hashedKey[64]; /* Maximum used digest length */ + union xory k_ipad, k_opad; + + if (key_len > 64) { + desc->Init(desc->ctx); + desc->Update(desc->ctx, key, key_len); + desc->Final(hashedKey, desc->ctx); + key = (const char*)hashedKey; + key_len = desc->digestLength; + } + + memcpy(k_ipad.bytes, key, key_len); + memset(k_ipad.bytes + key_len, 0, 64 - key_len); + memcpy(k_opad.bytes, k_ipad.bytes, 64); + + for (i = 0; i < 16; i++) { + k_ipad.quadbytes[i] ^= HMAC_IPAD; + k_opad.quadbytes[i] ^= HMAC_OPAD; + } + + desc->Init(desc->ctx); + desc->Update(desc->ctx, k_ipad.bytes, 64); + desc->Init(desc->ctxo); + desc->Update(desc->ctxo, k_opad.bytes, 64); + desc->Update(desc->ctx, msg, msg_len); + desc->Final(result, desc->ctx); + desc->Update(desc->ctxo, result, desc->digestLength); + desc->Final(result, desc->ctxo); +} + +#define MAKE_HMAC_FUNCTION(myFunc, func, size, type) \ +static int myFunc(lua_State *L) { \ + type ctx, ctxo; \ + unsigned char hash[size], result[2*size]; \ + size_t key_len, msg_len; \ + const char *key = luaL_checklstring(L, 1, &key_len); \ + const char *msg = luaL_checklstring(L, 2, &msg_len); \ + const int hex_out = lua_toboolean(L, 3); \ + struct hash_desc desc; \ + desc.Init = (int (*)(void*))func##_Init; \ + desc.Update = (int (*)(void*, const void *, size_t))func##_Update; \ + desc.Final = (int (*)(unsigned char*, void*))func##_Final; \ + desc.digestLength = size; \ + desc.ctx = &ctx; \ + desc.ctxo = &ctxo; \ + hmac(&desc, key, key_len, msg, msg_len, hash); \ if (hex_out) { \ toHex(hash, size, result); \ - lua_pushlstring(L, result, size*2); \ + lua_pushlstring(L, (char*)result, size*2); \ } else { \ - lua_pushlstring(L, hash, size);\ + lua_pushlstring(L, (char*)hash, size); \ } \ return 1; \ } -MAKE_HASH_FUNCTION(Lsha1, SHA1, 20) -MAKE_HASH_FUNCTION(Lsha256, SHA256, 32) -MAKE_HASH_FUNCTION(Lmd5, MD5, 16) +MAKE_HMAC_FUNCTION(Lhmac_sha1, SHA1, SHA_DIGEST_LENGTH, SHA_CTX) +MAKE_HMAC_FUNCTION(Lhmac_sha256, SHA256, SHA256_DIGEST_LENGTH, SHA256_CTX) +MAKE_HMAC_FUNCTION(Lhmac_sha512, SHA512, SHA512_DIGEST_LENGTH, SHA512_CTX) +MAKE_HMAC_FUNCTION(Lhmac_md5, MD5, MD5_DIGEST_LENGTH, MD5_CTX) + +static int LscramHi(lua_State *L) { + union xory { + unsigned char bytes[SHA_DIGEST_LENGTH]; + uint32_t quadbytes[SHA_DIGEST_LENGTH/4]; + }; + int i; + SHA_CTX ctx, ctxo; + unsigned char Ust[SHA_DIGEST_LENGTH]; + union xory Und; + union xory res; + size_t str_len, salt_len; + struct hash_desc desc; + const char *str = luaL_checklstring(L, 1, &str_len); + const char *salt = luaL_checklstring(L, 2, &salt_len); + char *salt2; + const int iter = luaL_checkinteger(L, 3); + + desc.Init = (int (*)(void*))SHA1_Init; + desc.Update = (int (*)(void*, const void *, size_t))SHA1_Update; + desc.Final = (int (*)(unsigned char*, void*))SHA1_Final; + desc.digestLength = SHA_DIGEST_LENGTH; + desc.ctx = &ctx; + desc.ctxo = &ctxo; + + salt2 = malloc(salt_len + 4); + if (salt2 == NULL) + luaL_error(L, "Out of memory in scramHi"); + memcpy(salt2, salt, salt_len); + memcpy(salt2 + salt_len, "\0\0\0\1", 4); + hmac(&desc, str, str_len, salt2, salt_len + 4, Ust); + free(salt2); + + memcpy(res.bytes, Ust, sizeof(res)); + for (i = 1; i < iter; i++) { + int j; + hmac(&desc, str, str_len, (char*)Ust, sizeof(Ust), Und.bytes); + for (j = 0; j < SHA_DIGEST_LENGTH/4; j++) + res.quadbytes[j] ^= Und.quadbytes[j]; + memcpy(Ust, Und.bytes, sizeof(Ust)); + } + + lua_pushlstring(L, (char*)res.bytes, SHA_DIGEST_LENGTH); + + return 1; +} static const luaL_Reg Reg[] = { - { "sha1", Lsha1 }, - { "sha256", Lsha256 }, - { "md5", Lmd5 }, - { NULL, NULL } + { "sha1", Lsha1 }, + { "sha224", Lsha224 }, + { "sha256", Lsha256 }, + { "sha384", Lsha384 }, + { "sha512", Lsha512 }, + { "md5", Lmd5 }, + { "hmac_sha1", Lhmac_sha1 }, + { "hmac_sha256", Lhmac_sha256 }, + { "hmac_sha512", Lhmac_sha512 }, + { "hmac_md5", Lhmac_md5 }, + { "scram_Hi_sha1", LscramHi }, + { NULL, NULL } }; LUALIB_API int luaopen_util_hashes(lua_State *L) diff --git a/util-src/net.c b/util-src/net.c new file mode 100644 index 00000000..e307c628 --- /dev/null +++ b/util-src/net.c @@ -0,0 +1,117 @@ +/* Prosody IM +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- +-- Copyright (C) 2012 Paul Aurich +-- Copyright (C) 2013 Matthew Wild +-- Copyright (C) 2013 Florian Zeitz +-- +*/ + +#include <stddef.h> +#include <string.h> +#include <errno.h> + +#ifndef _WIN32 + #include <sys/ioctl.h> + #include <sys/types.h> + #include <sys/socket.h> + #include <net/if.h> + #include <ifaddrs.h> + #include <arpa/inet.h> + #include <netinet/in.h> +#endif + +#include <lua.h> +#include <lauxlib.h> + +/* Enumerate all locally configured IP addresses */ + +const char * const type_strings[] = { + "both", + "ipv4", + "ipv6", + NULL +}; + +static int lc_local_addresses(lua_State *L) +{ +#ifndef _WIN32 + /* Link-local IPv4 addresses; see RFC 3927 and RFC 5735 */ + const long ip4_linklocal = htonl(0xa9fe0000); /* 169.254.0.0 */ + const long ip4_mask = htonl(0xffff0000); + struct ifaddrs *addr = NULL, *a; +#endif + int n = 1; + int type = luaL_checkoption(L, 1, "both", type_strings); + const char link_local = lua_toboolean(L, 2); /* defaults to 0 (false) */ + const char ipv4 = (type == 0 || type == 1); + const char ipv6 = (type == 0 || type == 2); + +#ifndef _WIN32 + if (getifaddrs(&addr) < 0) { + lua_pushnil(L); + lua_pushfstring(L, "getifaddrs failed (%d): %s", errno, + strerror(errno)); + return 2; + } +#endif + lua_newtable(L); + +#ifndef _WIN32 + for (a = addr; a; a = a->ifa_next) { + int family; + char ipaddr[INET6_ADDRSTRLEN]; + const char *tmp = NULL; + + if (a->ifa_addr == NULL || a->ifa_flags & IFF_LOOPBACK) + continue; + + family = a->ifa_addr->sa_family; + + if (ipv4 && family == AF_INET) { + struct sockaddr_in *sa = (struct sockaddr_in *)a->ifa_addr; + if (!link_local &&((sa->sin_addr.s_addr & ip4_mask) == ip4_linklocal)) + continue; + tmp = inet_ntop(family, &sa->sin_addr, ipaddr, sizeof(ipaddr)); + } else if (ipv6 && family == AF_INET6) { + struct sockaddr_in6 *sa = (struct sockaddr_in6 *)a->ifa_addr; + if (!link_local && IN6_IS_ADDR_LINKLOCAL(&sa->sin6_addr)) + continue; + if (IN6_IS_ADDR_V4MAPPED(&sa->sin6_addr) || IN6_IS_ADDR_V4COMPAT(&sa->sin6_addr)) + continue; + tmp = inet_ntop(family, &sa->sin6_addr, ipaddr, sizeof(ipaddr)); + } + + if (tmp != NULL) { + lua_pushstring(L, tmp); + lua_rawseti(L, -2, n++); + } + /* TODO: Error reporting? */ + } + + freeifaddrs(addr); +#else + if (ipv4) { + lua_pushstring(L, "0.0.0.0"); + lua_rawseti(L, -2, n++); + } + if (ipv6) { + lua_pushstring(L, "::"); + lua_rawseti(L, -2, n++); + } +#endif + return 1; +} + +int luaopen_util_net(lua_State* L) +{ + luaL_Reg exports[] = { + { "local_addresses", lc_local_addresses }, + { NULL, NULL } + }; + + luaL_register(L, "net", exports); + return 1; +} diff --git a/util-src/pposix.c b/util-src/pposix.c index ffd21288..8ad167b7 100644 --- a/util-src/pposix.c +++ b/util-src/pposix.c @@ -13,7 +13,7 @@ * POSIX support functions for Lua */ -#define MODULE_VERSION "0.3.5" +#define MODULE_VERSION "0.3.6" #include <stdlib.h> #include <math.h> @@ -32,8 +32,19 @@ #include <string.h> #include <errno.h> #include "lua.h" +#include "lualib.h" #include "lauxlib.h" +#include <fcntl.h> +#if defined(__linux__) && defined(_GNU_SOURCE) +#include <linux/falloc.h> +#endif + +#if (defined(_SVID_SOURCE) && !defined(WITHOUT_MALLINFO)) + #include <malloc.h> + #define WITH_MALLINFO +#endif + /* Daemonization support */ static int lc_daemonize(lua_State *L) @@ -78,6 +89,10 @@ static int lc_daemonize(lua_State *L) close(0); close(1); close(2); + /* Make sure accidental use of FDs 0, 1, 2 don't cause weirdness */ + open("/dev/null", O_RDONLY); + open("/dev/null", O_WRONLY); + open("/dev/null", O_WRONLY); /* Final fork, use it wisely */ if(fork()) @@ -189,12 +204,13 @@ int level_constants[] = { }; int lc_syslog_log(lua_State* L) { - int level = luaL_checkoption(L, 1, "notice", level_strings); - level = level_constants[level]; + int level = level_constants[luaL_checkoption(L, 1, "notice", level_strings)]; - luaL_checkstring(L, 2); + if(lua_gettop(L) == 3) + syslog(level, "%s: %s", luaL_checkstring(L, 2), luaL_checkstring(L, 3)); + else + syslog(level, "%s", lua_tostring(L, 2)); - syslog(level, "%s", lua_tostring(L, 2)); return 0; } @@ -395,23 +411,27 @@ int lc_initgroups(lua_State* L) return 2; } ret = initgroups(lua_tostring(L, 1), gid); - switch(errno) + if(ret) + { + switch(errno) + { + case ENOMEM: + lua_pushnil(L); + lua_pushstring(L, "no-memory"); + break; + case EPERM: + lua_pushnil(L); + lua_pushstring(L, "permission-denied"); + break; + default: + lua_pushnil(L); + lua_pushstring(L, "unknown-error"); + } + } + else { - case 0: lua_pushboolean(L, 1); lua_pushnil(L); - break; - case ENOMEM: - lua_pushnil(L); - lua_pushstring(L, "no-memory"); - break; - case EPERM: - lua_pushnil(L); - lua_pushstring(L, "permission-denied"); - break; - default: - lua_pushnil(L); - lua_pushstring(L, "unknown-error"); } return 2; } @@ -465,51 +485,57 @@ int string2resource(const char *s) { if (!strcmp(s, "NPROC")) return RLIMIT_NPROC; if (!strcmp(s, "RSS")) return RLIMIT_RSS; #endif +#ifdef RLIMIT_NICE + if (!strcmp(s, "NICE")) return RLIMIT_NICE; +#endif return -1; } +unsigned long int arg_to_rlimit(lua_State* L, int idx, rlim_t current) { + switch(lua_type(L, idx)) { + case LUA_TSTRING: + if(strcmp(lua_tostring(L, idx), "unlimited") == 0) + return RLIM_INFINITY; + case LUA_TNUMBER: + return lua_tointeger(L, idx); + case LUA_TNONE: + case LUA_TNIL: + return current; + default: + return luaL_argerror(L, idx, "unexpected type"); + } +} + int lc_setrlimit(lua_State *L) { + struct rlimit lim; int arguments = lua_gettop(L); - int softlimit = -1; - int hardlimit = -1; - const char *resource = NULL; int rid = -1; if(arguments < 1 || arguments > 3) { lua_pushboolean(L, 0); lua_pushstring(L, "incorrect-arguments"); + return 2; } - resource = luaL_checkstring(L, 1); - softlimit = luaL_checkinteger(L, 2); - hardlimit = luaL_checkinteger(L, 3); + rid = string2resource(luaL_checkstring(L, 1)); + if (rid == -1) { + lua_pushboolean(L, 0); + lua_pushstring(L, "invalid-resource"); + return 2; + } - rid = string2resource(resource); - if (rid != -1) { - struct rlimit lim; - struct rlimit lim_current; - - if (softlimit < 0 || hardlimit < 0) { - if (getrlimit(rid, &lim_current)) { - lua_pushboolean(L, 0); - lua_pushstring(L, "getrlimit-failed"); - return 2; - } - } + /* Fetch current values to use as defaults */ + if (getrlimit(rid, &lim)) { + lua_pushboolean(L, 0); + lua_pushstring(L, "getrlimit-failed"); + return 2; + } - if (softlimit < 0) lim.rlim_cur = lim_current.rlim_cur; - else lim.rlim_cur = softlimit; - if (hardlimit < 0) lim.rlim_max = lim_current.rlim_max; - else lim.rlim_max = hardlimit; + lim.rlim_cur = arg_to_rlimit(L, 2, lim.rlim_cur); + lim.rlim_max = arg_to_rlimit(L, 3, lim.rlim_max); - if (setrlimit(rid, &lim)) { - lua_pushboolean(L, 0); - lua_pushstring(L, "setrlimit-failed"); - return 2; - } - } else { - /* Unsupported resoucrce. Sorry I'm pretty limited by POSIX standard. */ + if (setrlimit(rid, &lim)) { lua_pushboolean(L, 0); - lua_pushstring(L, "invalid-resource"); + lua_pushstring(L, "setrlimit-failed"); return 2; } lua_pushboolean(L, 1); @@ -528,6 +554,8 @@ int lc_getrlimit(lua_State *L) { return 2; } + + resource = luaL_checkstring(L, 1); rid = string2resource(resource); if (rid != -1) { @@ -543,8 +571,14 @@ int lc_getrlimit(lua_State *L) { return 2; } lua_pushboolean(L, 1); - lua_pushnumber(L, lim.rlim_cur); - lua_pushnumber(L, lim.rlim_max); + if(lim.rlim_cur == RLIM_INFINITY) + lua_pushstring(L, "unlimited"); + else + lua_pushnumber(L, lim.rlim_cur); + if(lim.rlim_max == RLIM_INFINITY) + lua_pushstring(L, "unlimited"); + else + lua_pushnumber(L, lim.rlim_max); return 3; } @@ -577,6 +611,111 @@ int lc_uname(lua_State* L) return 1; } +int lc_setenv(lua_State* L) +{ + const char *var = luaL_checkstring(L, 1); + const char *value; + + /* If the second argument is nil or nothing, unset the var */ + if(lua_isnoneornil(L, 2)) + { + if(unsetenv(var) != 0) + { + lua_pushnil(L); + lua_pushstring(L, strerror(errno)); + return 2; + } + lua_pushboolean(L, 1); + return 1; + } + + value = luaL_checkstring(L, 2); + + if(setenv(var, value, 1) != 0) + { + lua_pushnil(L); + lua_pushstring(L, strerror(errno)); + return 2; + } + + lua_pushboolean(L, 1); + return 1; +} + +#ifdef WITH_MALLINFO +int lc_meminfo(lua_State* L) +{ + struct mallinfo info = mallinfo(); + lua_newtable(L); + /* This is the total size of memory allocated with sbrk by malloc, in bytes. */ + lua_pushinteger(L, info.arena); + lua_setfield(L, -2, "allocated"); + /* This is the total size of memory allocated with mmap, in bytes. */ + lua_pushinteger(L, info.hblkhd); + lua_setfield(L, -2, "allocated_mmap"); + /* This is the total size of memory occupied by chunks handed out by malloc. */ + lua_pushinteger(L, info.uordblks); + lua_setfield(L, -2, "used"); + /* This is the total size of memory occupied by free (not in use) chunks. */ + lua_pushinteger(L, info.fordblks); + lua_setfield(L, -2, "unused"); + /* This is the size of the top-most releasable chunk that normally borders the + end of the heap (i.e., the high end of the virtual address space's data segment). */ + lua_pushinteger(L, info.keepcost); + lua_setfield(L, -2, "returnable"); + return 1; +} +#endif + +/* File handle extraction blatantly stolen from + * https://github.com/rrthomas/luaposix/blob/master/lposix.c#L631 + * */ + +#if _XOPEN_SOURCE >= 600 || _POSIX_C_SOURCE >= 200112L || defined(_GNU_SOURCE) +int lc_fallocate(lua_State* L) +{ + off_t offset, len; + FILE *f = *(FILE**) luaL_checkudata(L, 1, LUA_FILEHANDLE); + + offset = luaL_checkinteger(L, 2); + len = luaL_checkinteger(L, 3); + +#if defined(__linux__) && defined(_GNU_SOURCE) + if(fallocate(fileno(f), FALLOC_FL_KEEP_SIZE, offset, len) == 0) + { + lua_pushboolean(L, 1); + return 1; + } + + if(errno != ENOSYS && errno != EOPNOTSUPP) + { + lua_pushnil(L); + lua_pushstring(L, strerror(errno)); + return 2; + } +#else +#warning Only using posix_fallocate() fallback. +#warning Linux fallocate() is strongly recommended if available: recompile with -D_GNU_SOURCE +#warning Note that posix_fallocate() will still be used on filesystems that dont support fallocate() +#endif + + if(posix_fallocate(fileno(f), offset, len) == 0) + { + lua_pushboolean(L, 1); + return 1; + } + else + { + lua_pushnil(L); + lua_pushstring(L, strerror(errno)); + /* posix_fallocate() can leave a bunch of NULs at the end, so we cut that + * this assumes that offset == length of the file */ + ftruncate(fileno(f), offset); + return 2; + } +} +#endif + /* Register functions */ int luaopen_util_pposix(lua_State *L) @@ -608,6 +747,16 @@ int luaopen_util_pposix(lua_State *L) { "uname", lc_uname }, + { "setenv", lc_setenv }, + +#ifdef WITH_MALLINFO + { "meminfo", lc_meminfo }, +#endif + +#if _XOPEN_SOURCE >= 600 || _POSIX_C_SOURCE >= 200112L || defined(_GNU_SOURCE) + { "fallocate", lc_fallocate }, +#endif + { NULL, NULL } }; diff --git a/util-src/windows.c b/util-src/windows.c index e1e07608..121cc471 100644 --- a/util-src/windows.c +++ b/util-src/windows.c @@ -38,15 +38,16 @@ static int Lget_nameservers(lua_State *L) { } return 1; } else { - luaL_error(L, "DnsQueryConfig returned %d", status); - return 0; // unreachable, but prevents a compiler warning + lua_pushnil(L); + lua_pushfstring(L, "DnsQueryConfig returned %d", status); + return 2; } } -static void lassert(lua_State *L, BOOL test, char* string) { - if (!test) { - luaL_error(L, "%s: %d", string, GetLastError()); - } +static int lerror(lua_State *L, char* string) { + lua_pushnil(L); + lua_pushfstring(L, "%s: %d", string, GetLastError()); + return 2; } static int Lget_consolecolor(lua_State *L) { @@ -55,9 +56,9 @@ static int Lget_consolecolor(lua_State *L) { CONSOLE_SCREEN_BUFFER_INFO info; - lassert(L, console != INVALID_HANDLE_VALUE, "GetStdHandle"); - lassert(L, GetConsoleScreenBufferInfo(console, &info), "GetConsoleScreenBufferInfo"); - lassert(L, ReadConsoleOutputAttribute(console, &color, sizeof(WORD), info.dwCursorPosition, &read_len), "ReadConsoleOutputAttribute"); + if (console == INVALID_HANDLE_VALUE) return lerror(L, "GetStdHandle"); + if (!GetConsoleScreenBufferInfo(console, &info)) return lerror(L, "GetConsoleScreenBufferInfo"); + if (!ReadConsoleOutputAttribute(console, &color, sizeof(WORD), info.dwCursorPosition, &read_len)) return lerror(L, "ReadConsoleOutputAttribute"); lua_pushnumber(L, color); return 1; @@ -65,9 +66,10 @@ static int Lget_consolecolor(lua_State *L) { static int Lset_consolecolor(lua_State *L) { int color = luaL_checkint(L, 1); HWND console = GetStdHandle(STD_OUTPUT_HANDLE); - lassert(L, console != INVALID_HANDLE_VALUE, "GetStdHandle"); - lassert(L, SetConsoleTextAttribute(console, color), "SetConsoleTextAttribute"); - return 0; + if (console == INVALID_HANDLE_VALUE) return lerror(L, "GetStdHandle"); + if (!SetConsoleTextAttribute(console, color)) return lerror(L, "SetConsoleTextAttribute"); + lua_pushboolean(L, 1); + return 1; } static const luaL_Reg Reg[] = diff --git a/util/adhoc.lua b/util/adhoc.lua new file mode 100644 index 00000000..671e85cf --- /dev/null +++ b/util/adhoc.lua @@ -0,0 +1,31 @@ +local function new_simple_form(form, result_handler) + return function(self, data, state) + if state then + if data.action == "cancel" then + return { status = "canceled" }; + end + local fields, err = form:data(data.form); + return result_handler(fields, err, data); + else + return { status = "executing", actions = {"next", "complete", default = "complete"}, form = form }, "executing"; + end + end +end + +local function new_initial_data_form(form, initial_data, result_handler) + return function(self, data, state) + if state then + if data.action == "cancel" then + return { status = "canceled" }; + end + local fields, err = form:data(data.form); + return result_handler(fields, err, data); + else + return { status = "executing", actions = {"next", "complete", default = "complete"}, + form = { layout = form, values = initial_data() } }, "executing"; + end + end +end + +return { new_simple_form = new_simple_form, + new_initial_data_form = new_initial_data_form }; diff --git a/util/array.lua b/util/array.lua index 6c1f0460..9bf215af 100644 --- a/util/array.lua +++ b/util/array.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -9,12 +9,20 @@ local t_insert, t_sort, t_remove, t_concat = table.insert, table.sort, table.remove, table.concat; +local setmetatable = setmetatable; +local math_random = math.random; +local pairs, ipairs = pairs, ipairs; +local tostring = tostring; + local array = {}; local array_base = {}; local array_methods = {}; -local array_mt = { __index = array_methods, __tostring = function (array) return array:concat(", "); end }; +local array_mt = { __index = array_methods, __tostring = function (array) return "{"..array:concat(", ").."}"; end }; -local function new_array(_, t) +local function new_array(self, t, _s, _var) + if type(t) == "function" then -- Assume iterator + t = self.collect(t, _s, _var); + end return setmetatable(t or {}, array_mt); end @@ -25,6 +33,15 @@ end setmetatable(array, { __call = new_array }); +-- Read-only methods +function array_methods:random() + return self[math_random(1,#self)]; +end + +-- These methods can be called two ways: +-- array.method(existing_array, [params [, ...]]) -- Create new array for result +-- existing_array:method([params, ...]) -- Transform existing array into result +-- function array_base.map(outa, ina, func) for k,v in ipairs(ina) do outa[k] = func(v); @@ -42,13 +59,13 @@ function array_base.filter(outa, ina, func) write = write + 1; end end - + if inplace and write <= start_length then for i=write,start_length do outa[i] = nil; end end - + return outa; end @@ -60,15 +77,18 @@ function array_base.sort(outa, ina, ...) return outa; end ---- These methods only mutate -function array_methods:random() - return self[math.random(1,#self)]; +function array_base.pluck(outa, ina, key) + for i=1,#ina do + outa[i] = ina[i][key]; + end + return outa; end +--- These methods only mutate the array function array_methods:shuffle(outa, ina) local len = #self; for i=1,#self do - local r = math.random(i,len); + local r = math_random(i,len); self[i], self[r] = self[r], self[i]; end return self; @@ -91,18 +111,32 @@ function array_methods:append(array) return self; end -array_methods.push = table.insert; -array_methods.pop = table.remove; -array_methods.concat = table.concat; -array_methods.length = function (t) return #t; end +function array_methods:push(x) + t_insert(self, x); + return self; +end + +function array_methods:pop(x) + local v = self[x]; + t_remove(self, x); + return v; +end + +function array_methods:concat(sep) + return t_concat(array.map(self, tostring), sep); +end + +function array_methods:length() + return #self; +end --- These methods always create a new array function array.collect(f, s, var) - local t, var = {}; + local t = {}; while true do var = f(s, var); if var == nil then break; end - table.insert(t, var); + t_insert(t, var); end return setmetatable(t, array_mt); end diff --git a/util/async.lua b/util/async.lua new file mode 100644 index 00000000..968ec804 --- /dev/null +++ b/util/async.lua @@ -0,0 +1,158 @@ +local log = require "util.logger".init("util.async"); + +local function runner_continue(thread) + -- ASSUMPTION: runner is in 'waiting' state (but we don't have the runner to know for sure) + if coroutine.status(thread) ~= "suspended" then -- This should suffice + return false; + end + local ok, state, runner = coroutine.resume(thread); + if not ok then + local level = 0; + while debug.getinfo(thread, level, "") do level = level + 1; end + ok, runner = debug.getlocal(thread, level-1, 1); + local error_handler = runner.watchers.error; + if error_handler then error_handler(runner, debug.traceback(thread, state)); end + elseif state == "ready" then + -- If state is 'ready', it is our responsibility to update runner.state from 'waiting'. + -- We also have to :run(), because the queue might have further items that will not be + -- processed otherwise. FIXME: It's probably best to do this in a nexttick (0 timer). + runner.state = "ready"; + runner:run(); + end + return true; +end + +local function waiter(num) + local thread = coroutine.running(); + if not thread then + error("Not running in an async context, see http://prosody.im/doc/developers/async"); + end + num = num or 1; + local waiting; + return function () + if num == 0 then return; end -- already done + waiting = true; + coroutine.yield("wait"); + end, function () + num = num - 1; + if num == 0 and waiting then + runner_continue(thread); + elseif num < 0 then + error("done() called too many times"); + end + end; +end + +local function guarder() + local guards = {}; + return function (id, func) + local thread = coroutine.running(); + if not thread then + error("Not running in an async context, see http://prosody.im/doc/developers/async"); + end + local guard = guards[id]; + if not guard then + guard = {}; + guards[id] = guard; + log("debug", "New guard!"); + else + table.insert(guard, thread); + log("debug", "Guarded. %d threads waiting.", #guard) + coroutine.yield("wait"); + end + local function exit() + local next_waiting = table.remove(guard, 1); + if next_waiting then + log("debug", "guard: Executing next waiting thread (%d left)", #guard) + runner_continue(next_waiting); + else + log("debug", "Guard off duty.") + guards[id] = nil; + end + end + if func then + func(); + exit(); + return; + end + return exit; + end; +end + +local runner_mt = {}; +runner_mt.__index = runner_mt; + +local function runner_create_thread(func, self) + local thread = coroutine.create(function (self) + while true do + func(coroutine.yield("ready", self)); + end + end); + assert(coroutine.resume(thread, self)); -- Start it up, it will return instantly to wait for the first input + return thread; +end + +local empty_watchers = {}; +local function runner(func, watchers, data) + return setmetatable({ func = func, thread = false, state = "ready", notified_state = "ready", + queue = {}, watchers = watchers or empty_watchers, data = data } + , runner_mt); +end + +function runner_mt:run(input) + if input ~= nil then + table.insert(self.queue, input); + end + if self.state ~= "ready" then + return true, self.state, #self.queue; + end + + local q, thread = self.queue, self.thread; + if not thread or coroutine.status(thread) == "dead" then + thread = runner_create_thread(self.func, self); + self.thread = thread; + end + + local n, state, err = #q, self.state, nil; + self.state = "running"; + while n > 0 and state == "ready" do + local consumed; + for i = 1,n do + local input = q[i]; + local ok, new_state = coroutine.resume(thread, input); + if not ok then + consumed, state, err = i, "ready", debug.traceback(thread, new_state); + self.thread = nil; + break; + elseif new_state == "wait" then + consumed, state = i, "waiting"; + break; + end + end + if not consumed then consumed = n; end + if q[n+1] ~= nil then + n = #q; + end + for i = 1, n do + q[i] = q[consumed+i]; + end + n = #q; + end + self.state = state; + if err or state ~= self.notified_state then + if err then + state = "error" + else + self.notified_state = state; + end + local handler = self.watchers[state]; + if handler then handler(self, err); end + end + return true, state, n; +end + +function runner_mt:enqueue(input) + table.insert(self.queue, input); +end + +return { waiter = waiter, guarder = guarder, runner = runner }; diff --git a/util/broadcast.lua b/util/broadcast.lua deleted file mode 100644 index be17461d..00000000 --- a/util/broadcast.lua +++ /dev/null @@ -1,68 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - - -local ipairs, pairs, setmetatable, type = - ipairs, pairs, setmetatable, type; - -module "pubsub" - -local pubsub_node_mt = { __index = _M }; - -function new_node(name) - return setmetatable({ name = name, subscribers = {} }, pubsub_node_mt); -end - -function set_subscribers(node, subscribers_list, list_type) - local subscribers = node.subscribers; - - if list_type == "array" then - for _, jid in ipairs(subscribers_list) do - if not subscribers[jid] then - node:add_subscriber(jid); - end - end - elseif (not list_type) or list_type == "set" then - for jid in pairs(subscribers_list) do - if type(jid) == "string" then - node:add_subscriber(jid); - end - end - end -end - -function get_subscribers(node) - return node.subscribers; -end - -function publish(node, item, dispatcher, data) - local subscribers = node.subscribers; - for i = 1,#subscribers do - item.attr.to = subscribers[i]; - dispatcher(data, item); - end -end - -function add_subscriber(node, jid) - local subscribers = node.subscribers; - if not subscribers[jid] then - local space = #subscribers; - subscribers[space] = jid; - subscribers[jid] = space; - end -end - -function remove_subscriber(node, jid) - local subscribers = node.subscribers; - if subscribers[jid] then - subscribers[subscribers[jid]] = nil; - subscribers[jid] = nil; - end -end - -return _M; diff --git a/util/caps.lua b/util/caps.lua index a61e7403..4723b912 100644 --- a/util/caps.lua +++ b/util/caps.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- diff --git a/util/dataforms.lua b/util/dataforms.lua index ae745e03..b38d0e27 100644 --- a/util/dataforms.lua +++ b/util/dataforms.lua @@ -1,16 +1,17 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- local setmetatable = setmetatable; local pairs, ipairs = pairs, ipairs; -local tostring, type = tostring, type; +local tostring, type, next = tostring, type, next; local t_concat = table.concat; local st = require "util.stanza"; +local jid_prep = require "util.jid".prep; module "dataforms" @@ -37,7 +38,7 @@ function form_t.form(layout, data, formtype) form:tag("field", { type = field_type, var = field.name, label = field.label }); local value = (data and data[field.name]) or field.value; - + if value then -- Add value, depending on type if field_type == "hidden" then @@ -52,7 +53,7 @@ function form_t.form(layout, data, formtype) elseif field_type == "boolean" then form:tag("value"):text((value and "1") or "0"):up(); elseif field_type == "fixed" then - + form:tag("value"):text(value):up(); elseif field_type == "jid-multi" then for _, jid in ipairs(value) do form:tag("value"):text(jid):up(); @@ -92,11 +93,11 @@ function form_t.form(layout, data, formtype) end end end - + if field.required then form:tag("required"):up(); end - + -- Jump back up to list of fields form:up(); end @@ -107,30 +108,41 @@ local field_readers = {}; function form_t.data(layout, stanza) local data = {}; - - for field_tag in stanza:childtags() do - local field_type; - for n, field in ipairs(layout) do + local errors = {}; + + for _, field in ipairs(layout) do + local tag; + for field_tag in stanza:childtags() do if field.name == field_tag.attr.var then - field_type = field.type; + tag = field_tag; break; end end - - local reader = field_readers[field_type]; - if reader then - data[field_tag.attr.var] = reader(field_tag); + + if not tag then + if field.required then + errors[field.name] = "Required value missing"; + end + else + local reader = field_readers[field.type]; + if reader then + data[field.name], errors[field.name] = reader(tag, field.required); + end end - + end + if next(errors) then + return data, errors; end return data; end field_readers["text-single"] = - function (field_tag) - local value = field_tag:child_with_name("value"); - if value then - return value[1]; + function (field_tag, required) + local data = field_tag:get_child_text("value"); + if data and #data > 0 then + return data + elseif required then + return nil, "Required value missing"; end end @@ -138,64 +150,85 @@ field_readers["text-private"] = field_readers["text-single"]; field_readers["jid-single"] = - field_readers["text-single"]; + function (field_tag, required) + local raw_data = field_tag:get_child_text("value") + local data = jid_prep(raw_data); + if data and #data > 0 then + return data + elseif raw_data then + return nil, "Invalid JID: " .. raw_data; + elseif required then + return nil, "Required value missing"; + end + end field_readers["jid-multi"] = - function (field_tag) + function (field_tag, required) local result = {}; - for value_tag in field_tag:childtags() do - if value_tag.name == "value" then - result[#result+1] = value_tag[1]; + local err = {}; + for value_tag in field_tag:childtags("value") do + local raw_value = value_tag:get_text(); + local value = jid_prep(raw_value); + result[#result+1] = value; + if raw_value and not value then + err[#err+1] = ("Invalid JID: " .. raw_value); end end - return result; + if #result > 0 then + return result, (#err > 0 and t_concat(err, "\n") or nil); + elseif required then + return nil, "Required value missing"; + end end -field_readers["text-multi"] = - function (field_tag) +field_readers["list-multi"] = + function (field_tag, required) local result = {}; - for value_tag in field_tag:childtags() do - if value_tag.name == "value" then - result[#result+1] = value_tag[1]; - end + for value in field_tag:childtags("value") do + result[#result+1] = value:get_text(); + end + if #result > 0 then + return result; + elseif required then + return nil, "Required value missing"; + end + end + +field_readers["text-multi"] = + function (field_tag, required) + local data, err = field_readers["list-multi"](field_tag, required); + if data then + data = t_concat(data, "\n"); end - return t_concat(result, "\n"); + return data, err; end field_readers["list-single"] = field_readers["text-single"]; -field_readers["list-multi"] = - function (field_tag) - local result = {}; - for value_tag in field_tag:childtags() do - if value_tag.name == "value" then - result[#result+1] = value_tag[1]; - end - end - return result; - end +local boolean_values = { + ["1"] = true, ["true"] = true, + ["0"] = false, ["false"] = false, +}; field_readers["boolean"] = - function (field_tag) - local value = field_tag:child_with_name("value"); - if value then - if value[1] == "1" or value[1] == "true" then - return true; - else - return false; - end + function (field_tag, required) + local raw_value = field_tag:get_child_text("value"); + local value = boolean_values[raw_value ~= nil and raw_value]; + if value ~= nil then + return value; + elseif raw_value then + return nil, "Invalid boolean representation"; + elseif required then + return nil, "Required value missing"; end end field_readers["hidden"] = function (field_tag) - local value = field_tag:child_with_name("value"); - if value then - return value[1]; - end + return field_tag:get_child_text("value"); end - + return _M; diff --git a/util/datamanager.lua b/util/datamanager.lua index fbdfb581..4a4d62b3 100644 --- a/util/datamanager.lua +++ b/util/datamanager.lua @@ -1,34 +1,47 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- local format = string.format; -local setmetatable, type = setmetatable, type; -local pairs, ipairs = pairs, ipairs; +local setmetatable = setmetatable; +local ipairs = ipairs; local char = string.char; -local loadfile, setfenv, pcall = loadfile, setfenv, pcall; +local pcall = pcall; local log = require "util.logger".init("datamanager"); local io_open = io.open; local os_remove = os.remove; -local tostring, tonumber = tostring, tonumber; -local error = error; +local os_rename = os.rename; +local tonumber = tonumber; local next = next; local t_insert = table.insert; -local append = require "util.serialization".append; -local path_separator = "/"; if os.getenv("WINDIR") then path_separator = "\\" end +local t_concat = table.concat; +local envloadfile = require"util.envload".envloadfile; +local serialize = require "util.serialization".serialize; +local path_separator = assert ( package.config:match ( "^([^\n]+)" ) , "package.config not in standard form" ) -- Extract directory seperator from package.config (an undocumented string that comes with lua) local lfs = require "lfs"; -local raw_mkdir; +local prosody = prosody; -if prosody.platform == "posix" then - raw_mkdir = require "util.pposix".mkdir; -- Doesn't trample on umask -else - raw_mkdir = lfs.mkdir; -end +local raw_mkdir = lfs.mkdir; +local function fallocate(f, offset, len) + -- This assumes that current position == offset + local fake_data = (" "):rep(len); + local ok, msg = f:write(fake_data); + if not ok then + return ok, msg; + end + f:seek("set", offset); + return true; +end; +pcall(function() + local pposix = require "util.pposix"; + raw_mkdir = pposix.mkdir or raw_mkdir; -- Doesn't trample on umask + fallocate = pposix.fallocate or fallocate; +end); module "datamanager" @@ -56,7 +69,7 @@ local function mkdir(path) return path; end -local data_path = "data"; +local data_path = (prosody and prosody.paths and prosody.paths.data) or "."; local callbacks = {}; ------- API ------------- @@ -71,7 +84,7 @@ local function callback(username, host, datastore, data) username, host, datastore, data = f(username, host, datastore, data); if username == false then break; end end - + return username, host, datastore, data; end function add_callback(func) @@ -100,37 +113,68 @@ function getpath(username, host, datastore, ext, create) if username then if create then mkdir(mkdir(mkdir(data_path).."/"..host).."/"..datastore); end return format("%s/%s/%s/%s.%s", data_path, host, datastore, username, ext); - elseif host then + else if create then mkdir(mkdir(data_path).."/"..host); end return format("%s/%s/%s.%s", data_path, host, datastore, ext); - else - if create then mkdir(data_path); end - return format("%s/%s.%s", data_path, datastore, ext); end end function load(username, host, datastore) - local data, ret = loadfile(getpath(username, host, datastore)); + local data, ret = envloadfile(getpath(username, host, datastore), {}); if not data then local mode = lfs.attributes(getpath(username, host, datastore), "mode"); if not mode then - log("debug", "Failed to load "..datastore.." storage ('"..ret.."') for user: "..(username or "nil").."@"..(host or "nil")); + log("debug", "Assuming empty %s storage ('%s') for user: %s@%s", datastore, ret, username or "nil", host or "nil"); return nil; else -- file exists, but can't be read -- TODO more detailed error checking and logging? - log("error", "Failed to load "..datastore.." storage ('"..ret.."') for user: "..(username or "nil").."@"..(host or "nil")); + log("error", "Failed to load %s storage ('%s') for user: %s@%s", datastore, ret, username or "nil", host or "nil"); return nil, "Error reading storage"; end end - setfenv(data, {}); + local success, ret = pcall(data); if not success then - log("error", "Unable to load "..datastore.." storage ('"..ret.."') for user: "..(username or "nil").."@"..(host or "nil")); + log("error", "Unable to load %s storage ('%s') for user: %s@%s", datastore, ret, username or "nil", host or "nil"); return nil, "Error reading storage"; end return ret; end +local function atomic_store(filename, data) + local scratch = filename.."~"; + local f, ok, msg; + repeat + f, msg = io_open(scratch, "w"); + if not f then break end + + ok, msg = f:write(data); + if not ok then break end + + ok, msg = f:close(); + if not ok then break end + + return os_rename(scratch, filename); + until false; + + -- Cleanup + if f then f:close(); end + os_remove(scratch); + return nil, msg; +end + +if prosody.platform ~= "posix" then + -- os.rename does not overwrite existing files on Windows + -- TODO We could use Transactional NTFS on Vista and above + function atomic_store(filename, data) + local f, err = io_open(filename, "w"); + if not f then return f, err; end + local ok, msg = f:write(data); + if not ok then f:close(); return ok, msg; end + return f:close(); + end +end + function store(username, host, datastore, data) if not data then data = {}; @@ -142,20 +186,26 @@ function store(username, host, datastore, data) end -- save the datastore - local f, msg = io_open(getpath(username, host, datastore, nil, true), "w+"); - if not f then - log("error", "Unable to write to "..datastore.." storage ('"..msg.."') for user: "..(username or "nil").."@"..(host or "nil")); - return nil, "Error saving to storage"; - end - f:write("return "); - append(f, data); - f:close(); - if next(data) == nil then -- try to delete empty datastore - log("debug", "Removing empty %s datastore for user %s@%s", datastore, username or "nil", host or "nil"); - os_remove(getpath(username, host, datastore)); - end - -- we write data even when we are deleting because lua doesn't have a - -- platform independent way of checking for non-exisitng files + local d = "return " .. serialize(data) .. ";\n"; + local mkdir_cache_cleared; + repeat + local ok, msg = atomic_store(getpath(username, host, datastore, nil, true), d); + if not ok then + if not mkdir_cache_cleared then -- We may need to recreate a removed directory + _mkdir = {}; + mkdir_cache_cleared = true; + else + log("error", "Unable to write to %s storage ('%s') for user: %s@%s", datastore, msg, username or "nil", host or "nil"); + return nil, "Error saving to storage"; + end + end + if next(data) == nil then -- try to delete empty datastore + log("debug", "Removing empty %s datastore for user %s@%s", datastore, username or "nil", host or "nil"); + os_remove(getpath(username, host, datastore)); + end + -- we write data even when we are deleting because lua doesn't have a + -- platform independent way of checking for non-exisitng files + until ok; return true; end @@ -163,14 +213,24 @@ function list_append(username, host, datastore, data) if not data then return; end if callback(username, host, datastore) == false then return true; end -- save the datastore - local f, msg = io_open(getpath(username, host, datastore, "list", true), "a+"); + local f, msg = io_open(getpath(username, host, datastore, "list", true), "r+"); + if not f then + f, msg = io_open(getpath(username, host, datastore, "list", true), "w"); + end if not f then - log("error", "Unable to write to "..datastore.." storage ('"..msg.."') for user: "..(username or "nil").."@"..(host or "nil")); + log("error", "Unable to write to %s storage ('%s') for user: %s@%s", datastore, msg, username or "nil", host or "nil"); return; end - f:write("item("); - append(f, data); - f:write(");\n"); + local data = "item(" .. serialize(data) .. ");\n"; + local pos = f:seek("end"); + local ok, msg = fallocate(f, pos, #data); + f:seek("set", pos); + if ok then + f:write(data); + else + log("error", "Unable to write to %s storage ('%s') for user: %s@%s", datastore, msg, username or "nil", host or "nil"); + return ok, msg; + end f:close(); return true; end @@ -181,17 +241,15 @@ function list_store(username, host, datastore, data) end if callback(username, host, datastore) == false then return true; end -- save the datastore - local f, msg = io_open(getpath(username, host, datastore, "list", true), "w+"); - if not f then - log("error", "Unable to write to "..datastore.." storage ('"..msg.."') for user: "..(username or "nil").."@"..(host or "nil")); - return; + local d = {}; + for _, item in ipairs(data) do + d[#d+1] = "item(" .. serialize(item) .. ");\n"; end - for _, d in ipairs(data) do - f:write("item("); - append(f, d); - f:write(");\n"); + local ok, msg = atomic_store(getpath(username, host, datastore, "list", true), t_concat(d)); + if not ok then + log("error", "Unable to write to %s storage ('%s') for user: %s@%s", datastore, msg, username or "nil", host or "nil"); + return; end - f:close(); if next(data) == nil then -- try to delete empty datastore log("debug", "Removing empty %s datastore for user %s@%s", datastore, username or "nil", host or "nil"); os_remove(getpath(username, host, datastore, "list")); @@ -202,26 +260,108 @@ function list_store(username, host, datastore, data) end function list_load(username, host, datastore) - local data, ret = loadfile(getpath(username, host, datastore, "list")); + local items = {}; + local data, ret = envloadfile(getpath(username, host, datastore, "list"), {item = function(i) t_insert(items, i); end}); if not data then local mode = lfs.attributes(getpath(username, host, datastore, "list"), "mode"); if not mode then - log("debug", "Failed to load "..datastore.." storage ('"..ret.."') for user: "..(username or "nil").."@"..(host or "nil")); + log("debug", "Assuming empty %s storage ('%s') for user: %s@%s", datastore, ret, username or "nil", host or "nil"); return nil; else -- file exists, but can't be read -- TODO more detailed error checking and logging? - log("error", "Failed to load "..datastore.." storage ('"..ret.."') for user: "..(username or "nil").."@"..(host or "nil")); + log("error", "Failed to load %s storage ('%s') for user: %s@%s", datastore, ret, username or "nil", host or "nil"); return nil, "Error reading storage"; end end - local items = {}; - setfenv(data, {item = function(i) t_insert(items, i); end}); + local success, ret = pcall(data); if not success then - log("error", "Unable to load "..datastore.." storage ('"..ret.."') for user: "..(username or "nil").."@"..(host or "nil")); + log("error", "Unable to load %s storage ('%s') for user: %s@%s", datastore, ret, username or "nil", host or "nil"); return nil, "Error reading storage"; end return items; end +local type_map = { + keyval = "dat"; + list = "list"; +} + +function users(host, store, typ) + typ = type_map[typ or "keyval"]; + local store_dir = format("%s/%s/%s", data_path, encode(host), store); + + local mode, err = lfs.attributes(store_dir, "mode"); + if not mode then + return function() log("debug", err or (store_dir .. " does not exist")) end + end + local next, state = lfs.dir(store_dir); + return function(state) + for node in next, state do + local file, ext = node:match("^(.*)%.([dalist]+)$"); + if file and ext == typ then + return decode(file); + end + end + end, state; +end + +function stores(username, host, typ) + typ = type_map[typ or "keyval"]; + local store_dir = format("%s/%s/", data_path, encode(host)); + + local mode, err = lfs.attributes(store_dir, "mode"); + if not mode then + return function() log("debug", err or (store_dir .. " does not exist")) end + end + local next, state = lfs.dir(store_dir); + return function(state) + for node in next, state do + if not node:match"^%." then + if username == true then + if lfs.attributes(store_dir..node, "mode") == "directory" then + return decode(node); + end + elseif username then + local store = decode(node) + if lfs.attributes(getpath(username, host, store, typ), "mode") then + return store; + end + elseif lfs.attributes(node, "mode") == "file" then + local file, ext = node:match("^(.*)%.([dalist]+)$"); + if ext == typ then + return decode(file) + end + end + end + end + end, state; +end + +local function do_remove(path) + local ok, err = os_remove(path); + if not ok and lfs.attributes(path, "mode") then + return ok, err; + end + return true +end + +function purge(username, host) + local host_dir = format("%s/%s/", data_path, encode(host)); + local errs = {}; + for file in lfs.dir(host_dir) do + if lfs.attributes(host_dir..file, "mode") == "directory" then + local store = decode(file); + local ok, err = do_remove(getpath(username, host, store)); + if not ok then errs[#errs+1] = err; end + + local ok, err = do_remove(getpath(username, host, store, "list")); + if not ok then errs[#errs+1] = err; end + end + end + return #errs == 0, t_concat(errs, ", "); +end + +_M.path_decode = decode; +_M.path_encode = encode; return _M; diff --git a/util/datetime.lua b/util/datetime.lua index 301a49a5..dd596527 100644 --- a/util/datetime.lua +++ b/util/datetime.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -36,7 +36,7 @@ end function parse(s) if s then local year, month, day, hour, min, sec, tzd; - year, month, day, hour, min, sec, tzd = s:match("^(%d%d%d%d)-?(%d%d)-?(%d%d)T(%d%d):(%d%d):(%d%d)%.?%d*([Z+%-].*)$"); + year, month, day, hour, min, sec, tzd = s:match("^(%d%d%d%d)%-?(%d%d)%-?(%d%d)T(%d%d):(%d%d):(%d%d)%.?%d*([Z+%-]?.*)$"); if year then local time_offset = os_difftime(os_time(os_date("*t")), os_time(os_date("!*t"))); -- to deal with local timezone local tzd_offset = 0; @@ -49,7 +49,7 @@ function parse(s) if sign == "-" then tzd_offset = -tzd_offset; end end sec = (sec + time_offset) - tzd_offset; - return os_time({year=year, month=month, day=day, hour=hour, min=min, sec=sec}); + return os_time({year=year, month=month, day=day, hour=hour, min=min, sec=sec, isdst=false}); end end end diff --git a/util/debug.lua b/util/debug.lua new file mode 100644 index 00000000..91f691e1 --- /dev/null +++ b/util/debug.lua @@ -0,0 +1,199 @@ +-- Variables ending with these names will not +-- have their values printed ('password' includes +-- 'new_password', etc.) +local censored_names = { + password = true; + passwd = true; + pass = true; + pwd = true; +}; +local optimal_line_length = 65; + +local termcolours = require "util.termcolours"; +local getstring = termcolours.getstring; +local styles; +do + _ = termcolours.getstyle; + styles = { + boundary_padding = _("bright"); + filename = _("bright", "blue"); + level_num = _("green"); + funcname = _("yellow"); + location = _("yellow"); + }; +end +module("debugx", package.seeall); + +function get_locals_table(thread, level) + local locals = {}; + for local_num = 1, math.huge do + local name, value; + if thread then + name, value = debug.getlocal(thread, level, local_num); + else + name, value = debug.getlocal(level+1, local_num); + end + if not name then break; end + table.insert(locals, { name = name, value = value }); + end + return locals; +end + +function get_upvalues_table(func) + local upvalues = {}; + if func then + for upvalue_num = 1, math.huge do + local name, value = debug.getupvalue(func, upvalue_num); + if not name then break; end + table.insert(upvalues, { name = name, value = value }); + end + end + return upvalues; +end + +function string_from_var_table(var_table, max_line_len, indent_str) + local var_string = {}; + local col_pos = 0; + max_line_len = max_line_len or math.huge; + indent_str = "\n"..(indent_str or ""); + for _, var in ipairs(var_table) do + local name, value = var.name, var.value; + if name:sub(1,1) ~= "(" then + if type(value) == "string" then + if censored_names[name:match("%a+$")] then + value = "<hidden>"; + else + value = ("%q"):format(value); + end + else + value = tostring(value); + end + if #value > max_line_len then + value = value:sub(1, max_line_len-3).."…"; + end + local str = ("%s = %s"):format(name, tostring(value)); + col_pos = col_pos + #str; + if col_pos > max_line_len then + table.insert(var_string, indent_str); + col_pos = 0; + end + table.insert(var_string, str); + end + end + if #var_string == 0 then + return nil; + else + return "{ "..table.concat(var_string, ", "):gsub(indent_str..", ", indent_str).." }"; + end +end + +function get_traceback_table(thread, start_level) + local levels = {}; + for level = start_level, math.huge do + local info; + if thread then + info = debug.getinfo(thread, level); + else + info = debug.getinfo(level+1); + end + if not info then break; end + + levels[(level-start_level)+1] = { + level = level; + info = info; + locals = get_locals_table(thread, level+(thread and 0 or 1)); + upvalues = get_upvalues_table(info.func); + }; + end + return levels; +end + +function traceback(...) + local ok, ret = pcall(_traceback, ...); + if not ok then + return "Error in error handling: "..ret; + end + return ret; +end + +local function build_source_boundary_marker(last_source_desc) + local padding = string.rep("-", math.floor(((optimal_line_length - 6) - #last_source_desc)/2)); + return getstring(styles.boundary_padding, "v"..padding).." "..getstring(styles.filename, last_source_desc).." "..getstring(styles.boundary_padding, padding..(#last_source_desc%2==0 and "-v" or "v ")); +end + +function _traceback(thread, message, level) + + -- Lua manual says: debug.traceback ([thread,] [message [, level]]) + -- I fathom this to mean one of: + -- () + -- (thread) + -- (message, level) + -- (thread, message, level) + + if thread == nil then -- Defaults + thread, message, level = coroutine.running(), message, level; + elseif type(thread) == "string" then + thread, message, level = coroutine.running(), thread, message; + elseif type(thread) ~= "thread" then + return nil; -- debug.traceback() does this + end + + level = level or 0; + + message = message and (message.."\n") or ""; + + -- +3 counts for this function, and the pcall() and wrapper above us, the +1... I don't know. + local levels = get_traceback_table(thread, level+(thread == nil and 4 or 0)); + + local last_source_desc; + + local lines = {}; + for nlevel, level in ipairs(levels) do + local info = level.info; + local line = "..."; + local func_type = info.namewhat.." "; + local source_desc = (info.short_src == "[C]" and "C code") or info.short_src or "Unknown"; + if func_type == " " then func_type = ""; end; + if info.short_src == "[C]" then + line = "[ C ] "..func_type.."C function "..getstring(styles.location, (info.name and ("%q"):format(info.name) or "(unknown name)")); + elseif info.what == "main" then + line = "[Lua] "..getstring(styles.location, info.short_src.." line "..info.currentline); + else + local name = info.name or " "; + if name ~= " " then + name = ("%q"):format(name); + end + if func_type == "global " or func_type == "local " then + func_type = func_type.."function "; + end + line = "[Lua] "..getstring(styles.location, info.short_src.." line "..info.currentline).." in "..func_type..getstring(styles.funcname, name).." (defined on line "..info.linedefined..")"; + end + if source_desc ~= last_source_desc then -- Venturing into a new source, add marker for previous + last_source_desc = source_desc; + table.insert(lines, "\t "..build_source_boundary_marker(last_source_desc)); + end + nlevel = nlevel-1; + table.insert(lines, "\t"..(nlevel==0 and ">" or " ")..getstring(styles.level_num, "("..nlevel..") ")..line); + local npadding = (" "):rep(#tostring(nlevel)); + if level.locals then + local locals_str = string_from_var_table(level.locals, optimal_line_length, "\t "..npadding); + if locals_str then + table.insert(lines, "\t "..npadding.."Locals: "..locals_str); + end + end + local upvalues_str = string_from_var_table(level.upvalues, optimal_line_length, "\t "..npadding); + if upvalues_str then + table.insert(lines, "\t "..npadding.."Upvals: "..upvalues_str); + end + end + +-- table.insert(lines, "\t "..build_source_boundary_marker(last_source_desc)); + + return message.."stack traceback:\n"..table.concat(lines, "\n"); +end + +function use() + debug.traceback = traceback; +end + +return _M; diff --git a/util/dependencies.lua b/util/dependencies.lua index 9371521c..109a3332 100644 --- a/util/dependencies.lua +++ b/util/dependencies.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -35,11 +35,24 @@ function missingdep(name, sources, msg) print(""); end +-- COMPAT w/pre-0.8 Debian: The Debian config file used to use +-- util.ztact, which has been removed from Prosody in 0.8. This +-- is to log an error for people who still use it, so they can +-- update their configs. +package.preload["util.ztact"] = function () + if not package.loaded["core.loggingmanager"] then + error("util.ztact has been removed from Prosody and you need to fix your config " + .."file. More information can be found at http://prosody.im/doc/packagers#ztact", 0); + else + error("module 'util.ztact' has been deprecated in Prosody 0.8."); + end +end; + function check_dependencies() local fatal; - + local lxp = softreq "lxp" - + if not lxp then missingdep("luaexpat", { ["Debian/Ubuntu"] = "sudo apt-get install liblua5.1-expat0"; @@ -48,9 +61,9 @@ function check_dependencies() }); fatal = true; end - + local socket = softreq "socket" - + if not socket then missingdep("luasocket", { ["Debian/Ubuntu"] = "sudo apt-get install liblua5.1-socket2"; @@ -59,7 +72,7 @@ function check_dependencies() }); fatal = true; end - + local lfs, err = softreq "lfs" if not lfs then missingdep("luafilesystem", { @@ -69,9 +82,9 @@ function check_dependencies() }); fatal = true; end - + local ssl = softreq "ssl" - + if not ssl then missingdep("LuaSec", { ["Debian/Ubuntu"] = "http://prosody.im/download/start#debian_and_ubuntu"; @@ -79,7 +92,7 @@ function check_dependencies() ["Source"] = "http://www.inf.puc-rio.br/~brunoos/luasec/"; }, "SSL/TLS support will not be available"); end - + local encodings, err = softreq "util.encodings" if not encodings then if err:match("not found") then @@ -123,6 +136,14 @@ function log_warnings() log("error", "This version of LuaSec contains a known bug that causes disconnects, see http://prosody.im/doc/depends"); end end + if lxp then + if not pcall(lxp.new, { StartDoctypeDecl = false }) then + log("error", "The version of LuaExpat on your system leaves Prosody " + .."vulnerable to denial-of-service attacks. You should upgrade to " + .."LuaExpat 1.1.1 or higher as soon as possible. See " + .."http://prosody.im/doc/depends#luaexpat for more information."); + end + end end return _M; diff --git a/util/envload.lua b/util/envload.lua new file mode 100644 index 00000000..53e28348 --- /dev/null +++ b/util/envload.lua @@ -0,0 +1,34 @@ +-- Prosody IM +-- Copyright (C) 2008-2011 Florian Zeitz +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local load, loadstring, loadfile, setfenv = load, loadstring, loadfile, setfenv; +local envload; +local envloadfile; + +if setfenv then + function envload(code, source, env) + local f, err = loadstring(code, source); + if f and env then setfenv(f, env); end + return f, err; + end + + function envloadfile(file, env) + local f, err = loadfile(file); + if f and env then setfenv(f, env); end + return f, err; + end +else + function envload(code, source, env) + return load(code, source, nil, env); + end + + function envloadfile(file, env) + return loadfile(file, nil, env); + end +end + +return { envload = envload, envloadfile = envloadfile }; diff --git a/util/events.lua b/util/events.lua index 412acccd..40ca3913 100644 --- a/util/events.lua +++ b/util/events.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -60,11 +60,11 @@ function new() remove_handler(event, handler); end end; - local function fire_event(event, ...) - local h = handlers[event]; + local function fire_event(event_name, event_data) + local h = handlers[event_name]; if h then for i=1,#h do - local ret = h[i](...); + local ret = h[i](event_data); if ret ~= nil then return ret; end end end diff --git a/util/filters.lua b/util/filters.lua index d143666b..8a470011 100644 --- a/util/filters.lua +++ b/util/filters.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -16,7 +16,7 @@ function initialize(session) if not session.filters then local filters = {}; session.filters = filters; - + function session.filter(type, data) local filter_list = filters[type]; if filter_list then @@ -28,11 +28,11 @@ function initialize(session) return data; end end - + for i=1,#new_filter_hooks do new_filter_hooks[i](session); end - + return session.filter; end @@ -40,20 +40,20 @@ function add_filter(session, type, callback, priority) if not session.filters then initialize(session); end - + local filter_list = session.filters[type]; if not filter_list then filter_list = {}; session.filters[type] = filter_list; end - + priority = priority or 0; - + local i = 0; repeat i = i + 1; until not filter_list[i] or filter_list[filter_list[i]] >= priority; - + t_insert(filter_list, i, callback); filter_list[callback] = priority; end diff --git a/util/helpers.lua b/util/helpers.lua index 11356176..437a920c 100644 --- a/util/helpers.lua +++ b/util/helpers.lua @@ -1,17 +1,27 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- +local debug = require "util.debug"; + module("helpers", package.seeall); -- Helper functions for debugging local log = require "util.logger".init("util.debug"); +function log_host_events(host) + return log_events(prosody.hosts[host].events, host); +end + +function revert_log_host_events(host) + return revert_log_events(prosody.hosts[host].events); +end + function log_events(events, name, logger) local f = events.fire_event; if not f then @@ -28,7 +38,36 @@ function log_events(events, name, logger) end function revert_log_events(events) - events.fire_event, events[events.fire_event] = events[events.fire_event], nil; -- :) + events.fire_event, events[events.fire_event] = events[events.fire_event], nil; -- :)) +end + +function show_events(events, specific_event) + local event_handlers = events._handlers; + local events_array = {}; + local event_handler_arrays = {}; + for event in pairs(events._event_map) do + local handlers = event_handlers[event]; + if handlers and (event == specific_event or not specific_event) then + table.insert(events_array, event); + local handler_strings = {}; + for i, handler in ipairs(handlers) do + local upvals = debug.string_from_var_table(debug.get_upvalues_table(handler)); + handler_strings[i] = " "..i..": "..tostring(handler)..(upvals and ("\n "..upvals) or ""); + end + event_handler_arrays[event] = handler_strings; + end + end + table.sort(events_array); + local i = 1; + while i <= #events_array do + local handlers = event_handler_arrays[events_array[i]]; + for j=#handlers, 1, -1 do + table.insert(events_array, i+1, handlers[j]); + end + if i > 1 then events_array[i] = "\n"..events_array[i]; end + i = i + #handlers + 1 + end + return table.concat(events_array, "\n"); end function get_upvalue(f, get_name) diff --git a/util/hmac.lua b/util/hmac.lua index 6df6986e..2c4cc6ef 100644 --- a/util/hmac.lua +++ b/util/hmac.lua @@ -1,69 +1,15 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- -local hashes = require "util.hashes" - -local s_char = string.char; -local s_gsub = string.gsub; -local s_rep = string.rep; - -module "hmac" - -local xor_map = {0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;1;0;3;2;5;4;7;6;9;8;11;10;13;12;15;14;2;3;0;1;6;7;4;5;10;11;8;9;14;15;12;13;3;2;1;0;7;6;5;4;11;10;9;8;15;14;13;12;4;5;6;7;0;1;2;3;12;13;14;15;8;9;10;11;5;4;7;6;1;0;3;2;13;12;15;14;9;8;11;10;6;7;4;5;2;3;0;1;14;15;12;13;10;11;8;9;7;6;5;4;3;2;1;0;15;14;13;12;11;10;9;8;8;9;10;11;12;13;14;15;0;1;2;3;4;5;6;7;9;8;11;10;13;12;15;14;1;0;3;2;5;4;7;6;10;11;8;9;14;15;12;13;2;3;0;1;6;7;4;5;11;10;9;8;15;14;13;12;3;2;1;0;7;6;5;4;12;13;14;15;8;9;10;11;4;5;6;7;0;1;2;3;13;12;15;14;9;8;11;10;5;4;7;6;1;0;3;2;14;15;12;13;10;11;8;9;6;7;4;5;2;3;0;1;15;14;13;12;11;10;9;8;7;6;5;4;3;2;1;0;}; -local function xor(x, y) - local lowx, lowy = x % 16, y % 16; - local hix, hiy = (x - lowx) / 16, (y - lowy) / 16; - local lowr, hir = xor_map[lowx * 16 + lowy + 1], xor_map[hix * 16 + hiy + 1]; - local r = hir * 16 + lowr; - return r; -end -local opadc, ipadc = s_char(0x5c), s_char(0x36); -local ipad_map = {}; -local opad_map = {}; -for i=0,255 do - ipad_map[s_char(i)] = s_char(xor(0x36, i)); - opad_map[s_char(i)] = s_char(xor(0x5c, i)); -end - ---[[ -key - the key to use in the hash -message - the message to hash -hash - the hash function -blocksize - the blocksize for the hash function in bytes -hex - return raw hash or hexadecimal string ---]] -function hmac(key, message, hash, blocksize, hex) - if #key > blocksize then - key = hash(key) - end +-- COMPAT: Only for external pre-0.9 modules - local padding = blocksize - #key; - local ipad = s_gsub(key, ".", ipad_map)..s_rep(ipadc, padding); - local opad = s_gsub(key, ".", opad_map)..s_rep(opadc, padding); - - return hash(opad..hash(ipad..message), hex) -end - -function md5(key, message, hex) - return hmac(key, message, hashes.md5, 64, hex) -end - -function sha1(key, message, hex) - return hmac(key, message, hashes.sha1, 64, hex) -end - -function sha256(key, message, hex) - return hmac(key, message, hashes.sha256, 64, hex) -end +local hashes = require "util.hashes" -return _M +return { md5 = hashes.hmac_md5, + sha1 = hashes.hmac_sha1, + sha256 = hashes.hmac_sha256 }; diff --git a/util/http.lua b/util/http.lua new file mode 100644 index 00000000..f7259920 --- /dev/null +++ b/util/http.lua @@ -0,0 +1,64 @@ +-- Prosody IM +-- Copyright (C) 2013 Florian Zeitz +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local format, char = string.format, string.char; +local pairs, ipairs, tonumber = pairs, ipairs, tonumber; +local t_insert, t_concat = table.insert, table.concat; + +local function urlencode(s) + return s and (s:gsub("[^a-zA-Z0-9.~_-]", function (c) return format("%%%02x", c:byte()); end)); +end +local function urldecode(s) + return s and (s:gsub("%%(%x%x)", function (c) return char(tonumber(c,16)); end)); +end + +local function _formencodepart(s) + return s and (s:gsub("%W", function (c) + if c ~= " " then + return format("%%%02x", c:byte()); + else + return "+"; + end + end)); +end + +local function formencode(form) + local result = {}; + if form[1] then -- Array of ordered { name, value } + for _, field in ipairs(form) do + t_insert(result, _formencodepart(field.name).."=".._formencodepart(field.value)); + end + else -- Unordered map of name -> value + for name, value in pairs(form) do + t_insert(result, _formencodepart(name).."=".._formencodepart(value)); + end + end + return t_concat(result, "&"); +end + +local function formdecode(s) + if not s:match("=") then return urldecode(s); end + local r = {}; + for k, v in s:gmatch("([^=&]*)=([^&]*)") do + k, v = k:gsub("%+", "%%20"), v:gsub("%+", "%%20"); + k, v = urldecode(k), urldecode(v); + t_insert(r, { name = k, value = v }); + r[k] = v; + end + return r; +end + +local function contains_token(field, token) + field = ","..field:gsub("[ \t]", ""):lower()..","; + return field:find(","..token:lower()..",", 1, true) ~= nil; +end + +return { + urlencode = urlencode, urldecode = urldecode; + formencode = formencode, formdecode = formdecode; + contains_token = contains_token; +}; diff --git a/util/httpstream.lua b/util/httpstream.lua deleted file mode 100644 index bdc3fce7..00000000 --- a/util/httpstream.lua +++ /dev/null @@ -1,137 +0,0 @@ - -local coroutine = coroutine; -local tonumber = tonumber; - -local deadroutine = coroutine.create(function() end); -coroutine.resume(deadroutine); - -module("httpstream") - -local function parser(success_cb, parser_type, options_cb) - local data = coroutine.yield(); - local function readline() - local pos = data:find("\r\n", nil, true); - while not pos do - data = data..coroutine.yield(); - pos = data:find("\r\n", nil, true); - end - local r = data:sub(1, pos-1); - data = data:sub(pos+2); - return r; - end - local function readlength(n) - while #data < n do - data = data..coroutine.yield(); - end - local r = data:sub(1, n); - data = data:sub(n + 1); - return r; - end - local function readheaders() - local headers = {}; -- read headers - while true do - local line = readline(); - if line == "" then break; end -- headers done - local key, val = line:match("^([^%s:]+): *(.*)$"); - if not key then coroutine.yield("invalid-header-line"); end -- TODO handle multi-line and invalid headers - key = key:lower(); - headers[key] = headers[key] and headers[key]..","..val or val; - end - return headers; - end - - if not parser_type or parser_type == "server" then - while true do - -- read status line - local status_line = readline(); - local method, path, httpversion = status_line:match("^(%S+)%s+(%S+)%s+HTTP/(%S+)$"); - if not method then coroutine.yield("invalid-status-line"); end - path = path:gsub("^//+", "/"); -- TODO parse url more - local headers = readheaders(); - - -- read body - local len = tonumber(headers["content-length"]); - len = len or 0; -- TODO check for invalid len - local body = readlength(len); - - success_cb({ - method = method; - path = path; - httpversion = httpversion; - headers = headers; - body = body; - }); - end - elseif parser_type == "client" then - while true do - -- read status line - local status_line = readline(); - local httpversion, status_code, reason_phrase = status_line:match("^HTTP/(%S+)%s+(%d%d%d)%s+(.*)$"); - status_code = tonumber(status_code); - if not status_code then coroutine.yield("invalid-status-line"); end - local headers = readheaders(); - - -- read body - local have_body = not - ( (options_cb and options_cb().method == "HEAD") - or (status_code == 204 or status_code == 304 or status_code == 301) - or (status_code >= 100 and status_code < 200) ); - - local body; - if have_body then - local len = tonumber(headers["content-length"]); - if headers["transfer-encoding"] == "chunked" then - body = ""; - while true do - local chunk_size = readline():match("^%x+"); - if not chunk_size then coroutine.yield("invalid-chunk-size"); end - chunk_size = tonumber(chunk_size, 16) - if chunk_size == 0 then break; end - body = body..readlength(chunk_size); - if readline() ~= "" then coroutine.yield("invalid-chunk-ending"); end - end - local trailers = readheaders(); - elseif len then -- TODO check for invalid len - body = readlength(len); - else -- read to end - repeat - local newdata = coroutine.yield(); - data = data..newdata; - until newdata == ""; - body, data = data, ""; - end - end - - success_cb({ - code = status_code; - httpversion = httpversion; - headers = headers; - body = body; - -- COMPAT the properties below are deprecated - responseversion = httpversion; - responseheaders = headers; - }); - end - else coroutine.yield("unknown-parser-type"); end -end - -function new(success_cb, error_cb, parser_type, options_cb) - local co = coroutine.create(parser); - coroutine.resume(co, success_cb, parser_type, options_cb) - return { - feed = function(self, data) - if not data then - if parser_type == "client" then coroutine.resume(co, ""); end - co = deadroutine; - return error_cb(); - end - local success, result = coroutine.resume(co, data); - if result then - co = deadroutine; - return error_cb(result); - end - end; - }; -end - -return _M; diff --git a/util/import.lua b/util/import.lua index 81401e8b..174da0ca 100644 --- a/util/import.lua +++ b/util/import.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- diff --git a/util/ip.lua b/util/ip.lua new file mode 100644 index 00000000..d0ae07eb --- /dev/null +++ b/util/ip.lua @@ -0,0 +1,244 @@ +-- Prosody IM +-- Copyright (C) 2008-2011 Florian Zeitz +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local ip_methods = {}; +local ip_mt = { __index = function (ip, key) return (ip_methods[key])(ip); end, + __tostring = function (ip) return ip.addr; end, + __eq = function (ipA, ipB) return ipA.addr == ipB.addr; end}; +local hex2bits = { ["0"] = "0000", ["1"] = "0001", ["2"] = "0010", ["3"] = "0011", ["4"] = "0100", ["5"] = "0101", ["6"] = "0110", ["7"] = "0111", ["8"] = "1000", ["9"] = "1001", ["A"] = "1010", ["B"] = "1011", ["C"] = "1100", ["D"] = "1101", ["E"] = "1110", ["F"] = "1111" }; + +local function new_ip(ipStr, proto) + if not proto then + local sep = ipStr:match("^%x+(.)"); + if sep == ":" or (not(sep) and ipStr:sub(1,1) == ":") then + proto = "IPv6" + elseif sep == "." then + proto = "IPv4" + end + if not proto then + return nil, "invalid address"; + end + elseif proto ~= "IPv4" and proto ~= "IPv6" then + return nil, "invalid protocol"; + end + if proto == "IPv6" and ipStr:find('.', 1, true) then + local changed; + ipStr, changed = ipStr:gsub(":(%d+)%.(%d+)%.(%d+)%.(%d+)$", function(a,b,c,d) + return (":%04X:%04X"):format(a*256+b,c*256+d); + end); + if changed ~= 1 then return nil, "invalid-address"; end + end + + return setmetatable({ addr = ipStr, proto = proto }, ip_mt); +end + +local function toBits(ip) + local result = ""; + local fields = {}; + if ip.proto == "IPv4" then + ip = ip.toV4mapped; + end + ip = (ip.addr):upper(); + ip:gsub("([^:]*):?", function (c) fields[#fields + 1] = c end); + if not ip:match(":$") then fields[#fields] = nil; end + for i, field in ipairs(fields) do + if field:len() == 0 and i ~= 1 and i ~= #fields then + for i = 1, 16 * (9 - #fields) do + result = result .. "0"; + end + else + for i = 1, 4 - field:len() do + result = result .. "0000"; + end + for i = 1, field:len() do + result = result .. hex2bits[field:sub(i,i)]; + end + end + end + return result; +end + +local function commonPrefixLength(ipA, ipB) + ipA, ipB = toBits(ipA), toBits(ipB); + for i = 1, 128 do + if ipA:sub(i,i) ~= ipB:sub(i,i) then + return i-1; + end + end + return 128; +end + +local function v4scope(ip) + local fields = {}; + ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end); + -- Loopback: + if fields[1] == 127 then + return 0x2; + -- Link-local unicast: + elseif fields[1] == 169 and fields[2] == 254 then + return 0x2; + -- Global unicast: + else + return 0xE; + end +end + +local function v6scope(ip) + -- Loopback: + if ip:match("^[0:]*1$") then + return 0x2; + -- Link-local unicast: + elseif ip:match("^[Ff][Ee][89ABab]") then + return 0x2; + -- Site-local unicast: + elseif ip:match("^[Ff][Ee][CcDdEeFf]") then + return 0x5; + -- Multicast: + elseif ip:match("^[Ff][Ff]") then + return tonumber("0x"..ip:sub(4,4)); + -- Global unicast: + else + return 0xE; + end +end + +local function label(ip) + if commonPrefixLength(ip, new_ip("::1", "IPv6")) == 128 then + return 0; + elseif commonPrefixLength(ip, new_ip("2002::", "IPv6")) >= 16 then + return 2; + elseif commonPrefixLength(ip, new_ip("2001::", "IPv6")) >= 32 then + return 5; + elseif commonPrefixLength(ip, new_ip("fc00::", "IPv6")) >= 7 then + return 13; + elseif commonPrefixLength(ip, new_ip("fec0::", "IPv6")) >= 10 then + return 11; + elseif commonPrefixLength(ip, new_ip("3ffe::", "IPv6")) >= 16 then + return 12; + elseif commonPrefixLength(ip, new_ip("::", "IPv6")) >= 96 then + return 3; + elseif commonPrefixLength(ip, new_ip("::ffff:0:0", "IPv6")) >= 96 then + return 4; + else + return 1; + end +end + +local function precedence(ip) + if commonPrefixLength(ip, new_ip("::1", "IPv6")) == 128 then + return 50; + elseif commonPrefixLength(ip, new_ip("2002::", "IPv6")) >= 16 then + return 30; + elseif commonPrefixLength(ip, new_ip("2001::", "IPv6")) >= 32 then + return 5; + elseif commonPrefixLength(ip, new_ip("fc00::", "IPv6")) >= 7 then + return 3; + elseif commonPrefixLength(ip, new_ip("fec0::", "IPv6")) >= 10 then + return 1; + elseif commonPrefixLength(ip, new_ip("3ffe::", "IPv6")) >= 16 then + return 1; + elseif commonPrefixLength(ip, new_ip("::", "IPv6")) >= 96 then + return 1; + elseif commonPrefixLength(ip, new_ip("::ffff:0:0", "IPv6")) >= 96 then + return 35; + else + return 40; + end +end + +local function toV4mapped(ip) + local fields = {}; + local ret = "::ffff:"; + ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end); + ret = ret .. ("%02x"):format(fields[1]); + ret = ret .. ("%02x"):format(fields[2]); + ret = ret .. ":" + ret = ret .. ("%02x"):format(fields[3]); + ret = ret .. ("%02x"):format(fields[4]); + return new_ip(ret, "IPv6"); +end + +function ip_methods:toV4mapped() + if self.proto ~= "IPv4" then return nil, "No IPv4 address" end + local value = toV4mapped(self.addr); + self.toV4mapped = value; + return value; +end + +function ip_methods:label() + local value; + if self.proto == "IPv4" then + value = label(self.toV4mapped); + else + value = label(self); + end + self.label = value; + return value; +end + +function ip_methods:precedence() + local value; + if self.proto == "IPv4" then + value = precedence(self.toV4mapped); + else + value = precedence(self); + end + self.precedence = value; + return value; +end + +function ip_methods:scope() + local value; + if self.proto == "IPv4" then + value = v4scope(self.addr); + else + value = v6scope(self.addr); + end + self.scope = value; + return value; +end + +function ip_methods:private() + local private = self.scope ~= 0xE; + if not private and self.proto == "IPv4" then + local ip = self.addr; + local fields = {}; + ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end); + if fields[1] == 127 or fields[1] == 10 or (fields[1] == 192 and fields[2] == 168) + or (fields[1] == 172 and (fields[2] >= 16 or fields[2] <= 32)) then + private = true; + end + end + self.private = private; + return private; +end + +local function parse_cidr(cidr) + local bits; + local ip_len = cidr:find("/", 1, true); + if ip_len then + bits = tonumber(cidr:sub(ip_len+1, -1)); + cidr = cidr:sub(1, ip_len-1); + end + return new_ip(cidr), bits; +end + +local function match(ipA, ipB, bits) + local common_bits = commonPrefixLength(ipA, ipB); + if not bits then + return ipA == ipB; + end + if bits and ipB.proto == "IPv4" then + common_bits = common_bits - 96; -- v6 mapped addresses always share these bits + end + return common_bits >= bits; +end + +return {new_ip = new_ip, + commonPrefixLength = commonPrefixLength, + parse_cidr = parse_cidr, + match=match}; diff --git a/util/iterators.lua b/util/iterators.lua index dc692d64..aa9c3ec0 100644 --- a/util/iterators.lua +++ b/util/iterators.lua @@ -1,15 +1,21 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- --[[ Iterators ]]-- +local it = {}; + +local t_insert = table.insert; +local select, unpack, next = select, unpack, next; +local function pack(...) return { n = select("#", ...), ... }; end + -- Reverse an iterator -function reverse(f, s, var) +function it.reverse(f, s, var) local results = {}; -- First call the normal iterator @@ -17,9 +23,9 @@ function reverse(f, s, var) local ret = { f(s, var) }; var = ret[1]; if var == nil then break; end - table.insert(results, 1, ret); + t_insert(results, 1, ret); end - + -- Then return our reverse one local i,max = 0, #results; return function (results) @@ -34,12 +40,12 @@ end local function _keys_it(t, key) return (next(t, key)); end -function keys(t) +function it.keys(t) return _keys_it, t; end -- Iterate only over values in a table -function values(t) +function it.values(t) local key, val; return function (t) key, val = next(t, key); @@ -48,38 +54,37 @@ function values(t) end -- Given an iterator, iterate only over unique items -function unique(f, s, var) +function it.unique(f, s, var) local set = {}; - + return function () while true do - local ret = { f(s, var) }; + local ret = pack(f(s, var)); var = ret[1]; if var == nil then break; end if not set[var] then set[var] = true; - return var; + return unpack(ret, 1, ret.n); end end end; end --[[ Return the number of items an iterator returns ]]-- -function count(f, s, var) +function it.count(f, s, var) local x = 0; - + while true do - local ret = { f(s, var) }; - var = ret[1]; + var = f(s, var); if var == nil then break; end x = x + 1; end - + return x; end -- Return the first n items an iterator returns -function head(n, f, s, var) +function it.head(n, f, s, var) local c = 0; return function (s, var) if c >= n then @@ -91,7 +96,7 @@ function head(n, f, s, var) end -- Skip the first n items an iterator returns -function skip(n, f, s, var) +function it.skip(n, f, s, var) for i=1,n do var = f(s, var); end @@ -99,10 +104,10 @@ function skip(n, f, s, var) end -- Return the last n items an iterator returns -function tail(n, f, s, var) +function it.tail(n, f, s, var) local results, count = {}, 0; while true do - local ret = { f(s, var) }; + local ret = pack(f(s, var)); var = ret[1]; if var == nil then break; end results[(count%n)+1] = ret; @@ -115,26 +120,52 @@ function tail(n, f, s, var) return function () pos = pos + 1; if pos > n then return nil; end - return unpack(results[((count-1+pos)%n)+1]); + local ret = results[((count-1+pos)%n)+1]; + return unpack(ret, 1, ret.n); end - --return reverse(head(n, reverse(f, s, var))); + --return reverse(head(n, reverse(f, s, var))); -- ! +end + +function it.filter(filter, f, s, var) + if type(filter) ~= "function" then + local filter_value = filter; + function filter(x) return x ~= filter_value; end + end + return function (s, var) + local ret; + repeat ret = pack(f(s, var)); + var = ret[1]; + until var == nil or filter(unpack(ret, 1, ret.n)); + return unpack(ret, 1, ret.n); + end, s, var; +end + +local function _ripairs_iter(t, key) if key > 1 then return key-1, t[key-1]; end end +function it.ripairs(t) + return _ripairs_iter, t, #t+1; +end + +local function _range_iter(max, curr) if curr < max then return curr + 1; end end +function it.range(x, y) + if not y then x, y = 1, x; end -- Default to 1..x if y not given + return _range_iter, y, x-1; end -- Convert the values returned by an iterator to an array -function it2array(f, s, var) +function it.to_array(f, s, var) local t, var = {}; while true do var = f(s, var); if var == nil then break; end - table.insert(t, var); + t_insert(t, var); end return t; end -- Treat the return of an iterator as key,value pairs, -- and build a table -function it2table(f, s, var) - local t, var = {}; +function it.to_table(f, s, var) + local t, var2 = {}; while true do var, var2 = f(s, var); if var == nil then break; end @@ -142,3 +173,5 @@ function it2table(f, s, var) end return t; end + +return it; diff --git a/util/jid.lua b/util/jid.lua index 069817c6..0d9a864f 100644 --- a/util/jid.lua +++ b/util/jid.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -13,6 +13,16 @@ local nodeprep = require "util.encodings".stringprep.nodeprep; local nameprep = require "util.encodings".stringprep.nameprep; local resourceprep = require "util.encodings".stringprep.resourceprep; +local escapes = { + [" "] = "\\20"; ['"'] = "\\22"; + ["&"] = "\\26"; ["'"] = "\\27"; + ["/"] = "\\2f"; [":"] = "\\3a"; + ["<"] = "\\3c"; [">"] = "\\3e"; + ["@"] = "\\40"; ["\\"] = "\\5c"; +}; +local unescapes = {}; +for k,v in pairs(escapes) do unescapes[v] = k; end + module "jid" local function _split(jid) @@ -91,4 +101,7 @@ function compare(jid, acl) return false end +function escape(s) return s and (s:gsub(".", escapes)); end +function unescape(s) return s and (s:gsub("\\%x%x", unescapes)); end + return _M; diff --git a/util/json.lua b/util/json.lua index 40939bb4..a8a58afc 100644 --- a/util/json.lua +++ b/util/json.lua @@ -1,14 +1,24 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- local type = type; -local t_insert, t_concat, t_remove = table.insert, table.concat, table.remove; +local t_insert, t_concat, t_remove, t_sort = table.insert, table.concat, table.remove, table.sort; local s_char = string.char; local tostring, tonumber = tostring, tonumber; local pairs, ipairs = pairs, ipairs; local next = next; local error = error; -local newproxy, getmetatable = newproxy, getmetatable; +local newproxy, getmetatable, setmetatable = newproxy, getmetatable, setmetatable; local print = print; +local has_array, array = pcall(require, "util.array"); +local array_mt = has_array and getmetatable(array()) or {}; + --module("json") local json = {}; @@ -29,6 +39,19 @@ for i=0,31 do if not escapes[ch] then escapes[ch] = ("\\u%.4X"):format(i); end end +local function codepoint_to_utf8(code) + if code < 0x80 then return s_char(code); end + local bits0_6 = code % 64; + if code < 0x800 then + local bits6_5 = (code - bits0_6) / 64; + return s_char(0x80 + 0x40 + bits6_5, 0x80 + bits0_6); + end + local bits0_12 = code % 4096; + local bits6_6 = (bits0_12 - bits0_6) / 64; + local bits12_4 = (code - bits0_12) / 4096; + return s_char(0x80 + 0x40 + 0x20 + bits12_4, 0x80 + bits6_6, 0x80 + bits0_6); +end + local valid_types = { number = true, string = true, @@ -79,11 +102,25 @@ function tablesave(o, buffer) if next(__hash) ~= nil or next(hash) ~= nil or next(__array) == nil then t_insert(buffer, "{"); local mark = #buffer; - for k,v in pairs(hash) do - stringsave(k, buffer); - t_insert(buffer, ":"); - simplesave(v, buffer); - t_insert(buffer, ","); + if buffer.ordered then + local keys = {}; + for k in pairs(hash) do + t_insert(keys, k); + end + t_sort(keys); + for _,k in ipairs(keys) do + stringsave(k, buffer); + t_insert(buffer, ":"); + simplesave(hash[k], buffer); + t_insert(buffer, ","); + end + else + for k,v in pairs(hash) do + stringsave(k, buffer); + t_insert(buffer, ":"); + simplesave(v, buffer); + t_insert(buffer, ","); + end end if next(__hash) ~= nil then t_insert(buffer, "\"__hash\":["); @@ -116,7 +153,12 @@ function simplesave(o, buffer) elseif t == "string" then stringsave(o, buffer); elseif t == "table" then - tablesave(o, buffer); + local mt = getmetatable(o); + if mt == array_mt then + arraysave(o, buffer); + else + tablesave(o, buffer); + end elseif t == "boolean" then t_insert(buffer, (o and "true" or "false")); else @@ -129,214 +171,191 @@ function json.encode(obj) simplesave(obj, t); return t_concat(t); end +function json.encode_ordered(obj) + local t = { ordered = true }; + simplesave(obj, t); + return t_concat(t); +end +function json.encode_array(obj) + local t = {}; + arraysave(obj, t); + return t_concat(t); +end ----------------------------------- -function json.decode(json) - local pos = 1; - local current = {}; - local stack = {}; - local ch, peek; - local function next() - ch = json:sub(pos, pos); - pos = pos+1; - peek = json:sub(pos, pos); - return ch; - end - - local function skipwhitespace() - while ch and (ch == "\r" or ch == "\n" or ch == "\t" or ch == " ") do - next(); +local function _skip_whitespace(json, index) + return json:find("[^ \t\r\n]", index) or index; -- no need to check \r\n, we converted those to \t +end +local function _fixobject(obj) + local __array = obj.__array; + if __array then + obj.__array = nil; + for i,v in ipairs(__array) do + t_insert(obj, v); end end - local function skiplinecomment() - repeat next(); until not(ch) or ch == "\r" or ch == "\n"; - skipwhitespace(); - end - local function skipstarcomment() - next(); next(); -- skip '/', '*' - while peek and ch ~= "*" and peek ~= "/" do next(); end - if not peek then error("eof in star comment") end - next(); next(); -- skip '*', '/' - skipwhitespace(); - end - local function skipstuff() - while true do - skipwhitespace(); - if ch == "/" and peek == "*" then - skipstarcomment(); - elseif ch == "/" and peek == "*" then - skiplinecomment(); + local __hash = obj.__hash; + if __hash then + obj.__hash = nil; + local k; + for i,v in ipairs(__hash) do + if k ~= nil then + obj[k] = v; k = nil; else - return; + k = v; end end end - - local readvalue; - local function readarray() - local t = {}; - next(); -- skip '[' - skipstuff(); - if ch == "]" then next(); return t; end - t_insert(t, readvalue()); - while true do - skipstuff(); - if ch == "]" then next(); return t; end - if not ch then error("eof while reading array"); - elseif ch == "," then next(); - elseif ch then error("unexpected character in array, comma expected"); end - if not ch then error("eof while reading array"); end - t_insert(t, readvalue()); + return obj; +end +local _readvalue, _readstring; +local function _readobject(json, index) + local o = {}; + while true do + local key, val; + index = _skip_whitespace(json, index + 1); + if json:byte(index) ~= 0x22 then -- "\"" + if json:byte(index) == 0x7d then return o, index + 1; end -- "}" + return nil, "key expected"; end + key, index = _readstring(json, index); + if key == nil then return nil, index; end + index = _skip_whitespace(json, index); + if json:byte(index) ~= 0x3a then return nil, "colon expected"; end -- ":" + val, index = _readvalue(json, index + 1); + if val == nil then return nil, index; end + o[key] = val; + index = _skip_whitespace(json, index); + local b = json:byte(index); + if b == 0x7d then return _fixobject(o), index + 1; end -- "}" + if b ~= 0x2c then return nil, "object eof"; end -- "," end - - local function checkandskip(c) - local x = ch or "eof"; - if x ~= c then error("unexpected "..x..", '"..c.."' expected"); end - next(); - end - local function readliteral(lit, val) - for c in lit:gmatch(".") do - checkandskip(c); +end +local function _readarray(json, index) + local a = {}; + local oindex = index; + while true do + local val; + val, index = _readvalue(json, index + 1); + if val == nil then + if json:byte(oindex + 1) == 0x5d then return setmetatable(a, array_mt), oindex + 2; end -- "]" + return val, index; end - return val; + t_insert(a, val); + index = _skip_whitespace(json, index); + local b = json:byte(index); + if b == 0x5d then return setmetatable(a, array_mt), index + 1; end -- "]" + if b ~= 0x2c then return nil, "array eof"; end -- "," end - local function readstring() - local s = ""; - checkandskip("\""); - while ch do - while ch and ch ~= "\\" and ch ~= "\"" do - s = s..ch; next(); - end - if ch == "\\" then - next(); - if unescapes[ch] then - s = s..unescapes[ch]; - next(); - elseif ch == "u" then - local seq = ""; - for i=1,4 do - next(); - if not ch then error("unexpected eof in string"); end - if not ch:match("[0-9a-fA-F]") then error("invalid unicode escape sequence in string"); end - seq = seq..ch; - end - s = s..s.char(tonumber(seq, 16)); -- FIXME do proper utf-8 - next(); - else error("invalid escape sequence in string"); end - end - if ch == "\"" then - next(); - return s; - end - end - error("eof while reading string"); +end +local _unescape_error; +local function _unescape_surrogate_func(x) + local lead, trail = tonumber(x:sub(3, 6), 16), tonumber(x:sub(9, 12), 16); + local codepoint = lead * 0x400 + trail - 0x35FDC00; + local a = codepoint % 64; + codepoint = (codepoint - a) / 64; + local b = codepoint % 64; + codepoint = (codepoint - b) / 64; + local c = codepoint % 64; + codepoint = (codepoint - c) / 64; + return s_char(0xF0 + codepoint, 0x80 + c, 0x80 + b, 0x80 + a); +end +local function _unescape_func(x) + x = x:match("%x%x%x%x", 3); + if x then + --if x >= 0xD800 and x <= 0xDFFF then _unescape_error = true; end -- bad surrogate pair + return codepoint_to_utf8(tonumber(x, 16)); end - local function readnumber() - local s = ""; - if ch == "-" then - s = s..ch; next(); - if not ch:match("[0-9]") then error("number format error"); end - end - if ch == "0" then - s = s..ch; next(); - if ch:match("[0-9]") then error("number format error"); end - else - while ch and ch:match("[0-9]") do - s = s..ch; next(); - end - end - if ch == "." then - s = s..ch; next(); - if not ch:match("[0-9]") then error("number format error"); end - while ch and ch:match("[0-9]") do - s = s..ch; next(); - end - if ch == "e" or ch == "E" then - s = s..ch; next(); - if ch == "+" or ch == "-" then - s = s..ch; next(); - if not ch:match("[0-9]") then error("number format error"); end - while ch and ch:match("[0-9]") do - s = s..ch; next(); - end - end - end - end - return tonumber(s); + _unescape_error = true; +end +function _readstring(json, index) + index = index + 1; + local endindex = json:find("\"", index, true); + if endindex then + local s = json:sub(index, endindex - 1); + --if s:find("[%z-\31]") then return nil, "control char in string"; end + -- FIXME handle control characters + _unescape_error = nil; + --s = s:gsub("\\u[dD][89abAB]%x%x\\u[dD][cdefCDEF]%x%x", _unescape_surrogate_func); + -- FIXME handle escapes beyond BMP + s = s:gsub("\\u.?.?.?.?", _unescape_func); + if _unescape_error then return nil, "invalid escape"; end + return s, endindex + 1; end - local function readmember(t) - local k = readstring(); - checkandskip(":"); - t[k] = readvalue(); + return nil, "string eof"; +end +local function _readnumber(json, index) + local m = json:match("[0-9%.%-eE%+]+", index); -- FIXME do strict checking + return tonumber(m), index + #m; +end +local function _readnull(json, index) + local a, b, c = json:byte(index + 1, index + 3); + if a == 0x75 and b == 0x6c and c == 0x6c then + return null, index + 4; end - local function fixobject(obj) - local __array = obj.__array; - if __array then - obj.__array = nil; - for i,v in ipairs(__array) do - t_insert(obj, v); - end - end - local __hash = obj.__hash; - if __hash then - obj.__hash = nil; - local k; - for i,v in ipairs(__hash) do - if k ~= nil then - obj[k] = v; k = nil; - else - k = v; - end - end - end - return obj; + return nil, "null parse failed"; +end +local function _readtrue(json, index) + local a, b, c = json:byte(index + 1, index + 3); + if a == 0x72 and b == 0x75 and c == 0x65 then + return true, index + 4; end - local function readobject() - local t = {}; - next(); -- skip '{' - skipstuff(); - if ch == "}" then next(); return t; end - if not ch then error("eof while reading object"); end - readmember(t); - while true do - skipstuff(); - if ch == "}" then next(); return fixobject(t); end - if not ch then error("eof while reading object"); - elseif ch == "," then next(); - elseif ch then error("unexpected character in object, comma expected"); end - if not ch then error("eof while reading object"); end - readmember(t); - end + return nil, "true parse failed"; +end +local function _readfalse(json, index) + local a, b, c, d = json:byte(index + 1, index + 4); + if a == 0x61 and b == 0x6c and c == 0x73 and d == 0x65 then + return false, index + 5; end - - function readvalue() - skipstuff(); - while ch do - if ch == "{" then - return readobject(); - elseif ch == "[" then - return readarray(); - elseif ch == "\"" then - return readstring(); - elseif ch:match("[%-0-9%.]") then - return readnumber(); - elseif ch == "n" then - return readliteral("null", null); - elseif ch == "t" then - return readliteral("true", true); - elseif ch == "f" then - return readliteral("false", false); - else - error("invalid character at value start: "..ch); - end - end - error("eof while reading value"); + return nil, "false parse failed"; +end +function _readvalue(json, index) + index = _skip_whitespace(json, index); + local b = json:byte(index); + -- TODO try table lookup instead of if-else? + if b == 0x7B then -- "{" + return _readobject(json, index); + elseif b == 0x5B then -- "[" + return _readarray(json, index); + elseif b == 0x22 then -- "\"" + return _readstring(json, index); + elseif b ~= nil and b >= 0x30 and b <= 0x39 or b == 0x2d then -- "0"-"9" or "-" + return _readnumber(json, index); + elseif b == 0x6e then -- "n" + return _readnull(json, index); + elseif b == 0x74 then -- "t" + return _readtrue(json, index); + elseif b == 0x66 then -- "f" + return _readfalse(json, index); + else + return nil, "value expected"; end - next(); - return readvalue(); +end +local first_escape = { + ["\\\""] = "\\u0022"; + ["\\\\"] = "\\u005c"; + ["\\/" ] = "\\u002f"; + ["\\b" ] = "\\u0008"; + ["\\f" ] = "\\u000C"; + ["\\n" ] = "\\u000A"; + ["\\r" ] = "\\u000D"; + ["\\t" ] = "\\u0009"; + ["\\u" ] = "\\u"; +}; + +function json.decode(json) + json = json:gsub("\\.", first_escape) -- get rid of all escapes except \uXXXX, making string parsing much simpler + --:gsub("[\r\n]", "\t"); -- \r\n\t are equivalent, we care about none of them, and none of them can be in strings + + -- TODO do encoding verification + + local val, index = _readvalue(json, 1); + if val == nil then return val, index; end + if json:find("[^ \t\r\n]", index) then return nil, "garbage at eof"; end + + return val; end function json.test(object) diff --git a/util/logger.lua b/util/logger.lua index c3bf3992..cd0769f9 100644 --- a/util/logger.lua +++ b/util/logger.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -13,8 +13,7 @@ local ipairs, pairs, setmetatable = ipairs, pairs, setmetatable; module "logger" -local name_sinks, level_sinks = {}, {}; -local name_patterns = {}; +local level_sinks = {}; local make_logger; @@ -24,8 +23,6 @@ function init(name) local log_warn = make_logger(name, "warn"); local log_error = make_logger(name, "error"); - --name = nil; -- While this line is not commented, will automatically fill in file/line number info - local namelen = #name; return function (level, message, ...) if level == "debug" then return log_debug(message, ...); @@ -46,17 +43,7 @@ function make_logger(source_name, level) level_sinks[level] = level_handlers; end - local source_handlers = name_sinks[source_name]; - local logger = function (message, ...) - if source_handlers then - for i = 1,#source_handlers do - if source_handlers[i](source_name, level, message, ...) == false then - return; - end - end - end - for i = 1,#level_handlers do level_handlers[i](source_name, level, message, ...); end @@ -66,14 +53,12 @@ function make_logger(source_name, level) end function reset() - for k in pairs(name_sinks) do name_sinks[k] = nil; end for level, handler_list in pairs(level_sinks) do -- Clear all handlers for this level for i = 1, #handler_list do handler_list[i] = nil; end end - for k in pairs(name_patterns) do name_patterns[k] = nil; end end function add_level_sink(level, sink_function) @@ -84,22 +69,6 @@ function add_level_sink(level, sink_function) end end -function add_name_sink(name, sink_function, exclusive) - if not name_sinks[name] then - name_sinks[name] = { sink_function }; - else - name_sinks[name][#name_sinks[name] + 1] = sink_function; - end -end - -function add_name_pattern_sink(name_pattern, sink_function, exclusive) - if not name_patterns[name_pattern] then - name_patterns[name_pattern] = { sink_function }; - else - name_patterns[name_pattern][#name_patterns[name_pattern] + 1] = sink_function; - end -end - _M.new = make_logger; return _M; diff --git a/util/multitable.lua b/util/multitable.lua index 66b9bd8a..caf25118 100644 --- a/util/multitable.lua +++ b/util/multitable.lua @@ -1,17 +1,14 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- - - local select = select; local t_insert = table.insert; -local pairs = pairs; -local next = next; +local unpack, pairs, next, type = unpack, pairs, next, type; module "multitable" @@ -129,6 +126,41 @@ local function search_add(self, results, ...) return results; end +function iter(self, ...) + local query = { ... }; + local maxdepth = select("#", ...); + local stack = { self.data }; + local keys = { }; + local function it(self) + local depth = #stack; + local key = next(stack[depth], keys[depth]); + if key == nil then -- Go up the stack + stack[depth], keys[depth] = nil, nil; + if depth > 1 then + return it(self); + end + return; -- The end + else + keys[depth] = key; + end + local value = stack[depth][key]; + if query[depth] == nil or key == query[depth] then + if depth == maxdepth then -- Result + local result = {}; -- Collect keys forming path to result + for i = 1, depth do + result[i] = keys[i]; + end + result[depth+1] = value; + return unpack(result, 1, depth+1); + elseif type(value) == "table" then + t_insert(stack, value); -- Descend + end + end + return it(self); + end; + return it, self; +end + function new() return { data = {}; @@ -138,6 +170,7 @@ function new() remove = remove; search = search; search_add = search_add; + iter = iter; }; end diff --git a/util/openssl.lua b/util/openssl.lua new file mode 100644 index 00000000..ef3fba96 --- /dev/null +++ b/util/openssl.lua @@ -0,0 +1,172 @@ +local type, tostring, pairs, ipairs = type, tostring, pairs, ipairs; +local t_insert, t_concat = table.insert, table.concat; +local s_format = string.format; + +local oid_xmppaddr = "1.3.6.1.5.5.7.8.5"; -- [XMPP-CORE] +local oid_dnssrv = "1.3.6.1.5.5.7.8.7"; -- [SRV-ID] + +local idna_to_ascii = require "util.encodings".idna.to_ascii; + +local _M = {}; +local config = {}; +_M.config = config; + +local ssl_config = {}; +local ssl_config_mt = {__index=ssl_config}; + +function config.new() + return setmetatable({ + req = { + distinguished_name = "distinguished_name", + req_extensions = "v3_extensions", + x509_extensions = "v3_extensions", + prompt = "no", + }, + distinguished_name = { + countryName = "GB", + -- stateOrProvinceName = "", + localityName = "The Internet", + organizationName = "Your Organisation", + organizationalUnitName = "XMPP Department", + commonName = "example.com", + emailAddress = "xmpp@example.com", + }, + v3_extensions = { + basicConstraints = "CA:FALSE", + keyUsage = "digitalSignature,keyEncipherment", + extendedKeyUsage = "serverAuth,clientAuth", + subjectAltName = "@subject_alternative_name", + }, + subject_alternative_name = { + DNS = {}, + otherName = {}, + }, + }, ssl_config_mt); +end + +local DN_order = { + "countryName"; + "stateOrProvinceName"; + "localityName"; + "streetAddress"; + "organizationName"; + "organizationalUnitName"; + "commonName"; + "emailAddress"; +} +_M._DN_order = DN_order; +function ssl_config:serialize() + local s = ""; + for k, t in pairs(self) do + s = s .. ("[%s]\n"):format(k); + if k == "subject_alternative_name" then + for san, n in pairs(t) do + for i = 1,#n do + s = s .. s_format("%s.%d = %s\n", san, i -1, n[i]); + end + end + elseif k == "distinguished_name" then + for i=1,#DN_order do + local k = DN_order[i] + local v = t[k]; + if v then + s = s .. ("%s = %s\n"):format(k, v); + end + end + else + for k, v in pairs(t) do + s = s .. ("%s = %s\n"):format(k, v); + end + end + s = s .. "\n"; + end + return s; +end + +local function utf8string(s) + -- This is how we tell openssl not to encode UTF-8 strings as fake Latin1 + return s_format("FORMAT:UTF8,UTF8:%s", s); +end + +local function ia5string(s) + return s_format("IA5STRING:%s", s); +end + +_M.util = { + utf8string = utf8string, + ia5string = ia5string, +}; + +function ssl_config:add_dNSName(host) + t_insert(self.subject_alternative_name.DNS, idna_to_ascii(host)); +end + +function ssl_config:add_sRVName(host, service) + t_insert(self.subject_alternative_name.otherName, + s_format("%s;%s", oid_dnssrv, ia5string("_" .. service .."." .. idna_to_ascii(host)))); +end + +function ssl_config:add_xmppAddr(host) + t_insert(self.subject_alternative_name.otherName, + s_format("%s;%s", oid_xmppaddr, utf8string(host))); +end + +function ssl_config:from_prosody(hosts, config, certhosts) + -- TODO Decide if this should go elsewhere + local found_matching_hosts = false; + for i = 1,#certhosts do + local certhost = certhosts[i]; + for name in pairs(hosts) do + if name == certhost or name:sub(-1-#certhost) == "."..certhost then + found_matching_hosts = true; + self:add_dNSName(name); + --print(name .. "#component_module: " .. (config.get(name, "component_module") or "nil")); + if config.get(name, "component_module") == nil then + self:add_sRVName(name, "xmpp-client"); + end + --print(name .. "#anonymous_login: " .. tostring(config.get(name, "anonymous_login"))); + if not (config.get(name, "anonymous_login") or + config.get(name, "authentication") == "anonymous") then + self:add_sRVName(name, "xmpp-server"); + end + self:add_xmppAddr(name); + end + end + end + if not found_matching_hosts then + return nil, "no-matching-hosts"; + end +end + +do -- Lua to shell calls. + local function shell_escape(s) + return s:gsub("'",[['\'']]); + end + + local function serialize(f,o) + local r = {"openssl", f}; + for k,v in pairs(o) do + if type(k) == "string" then + t_insert(r, ("-%s"):format(k)); + if v ~= true then + t_insert(r, ("'%s'"):format(shell_escape(tostring(v)))); + end + end + end + for _,v in ipairs(o) do + t_insert(r, ("'%s'"):format(shell_escape(tostring(v)))); + end + return t_concat(r, " "); + end + + local os_execute = os.execute; + setmetatable(_M, { + __index=function(_,f) + return function(opts) + return 0 == os_execute(serialize(f, type(opts) == "table" and opts or {})); + end; + end; + }); +end + +return _M; diff --git a/util/pluginloader.lua b/util/pluginloader.lua index 31ab1e88..b894f527 100644 --- a/util/pluginloader.lua +++ b/util/pluginloader.lua @@ -1,58 +1,60 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- +local dir_sep, path_sep = package.config:match("^(%S+)%s(%S+)"); +local plugin_dir = {}; +for path in (CFG_PLUGINDIR or "./plugins/"):gsub("[/\\]", dir_sep):gmatch("[^"..path_sep.."]+") do + path = path..dir_sep; -- add path separator to path end + path = path:gsub(dir_sep..dir_sep.."+", dir_sep); -- coalesce multiple separaters + plugin_dir[#plugin_dir + 1] = path; +end -local plugin_dir = CFG_PLUGINDIR or "./plugins/"; - -local io_open, os_time = io.open, os.time; -local loadstring, pairs = loadstring, pairs; +local io_open = io.open; +local envload = require "util.envload".envload; module "pluginloader" -local function load_file(name) - local file, err = io_open(plugin_dir..name); - if not file then return file, err; end - local content = file:read("*a"); - file:close(); - return content, name; +function load_file(names) + local file, err, path; + for i=1,#plugin_dir do + for j=1,#names do + path = plugin_dir[i]..names[j]; + file, err = io_open(path); + if file then + local content = file:read("*a"); + file:close(); + return content, path; + end + end + end + return file, err; end -function load_resource(plugin, resource, loader) - local path, name = plugin:match("([^/]*)/?(.*)"); - if name == "" then - if not resource then - resource = "mod_"..plugin..".lua"; - end - loader = loader or load_file; - - local content, err = loader(plugin.."/"..resource); - if not content then content, err = loader(resource); end - -- TODO add support for packed plugins - - return content, err; - else - if not resource then - resource = "mod_"..name..".lua"; - end - loader = loader or load_file; +function load_resource(plugin, resource) + resource = resource or "mod_"..plugin..".lua"; - local content, err = loader(plugin.."/"..resource); - if not content then content, err = loader(path.."/"..resource); end - -- TODO add support for packed plugins - - return content, err; - end + local names = { + "mod_"..plugin.."/"..plugin.."/"..resource; -- mod_hello/hello/mod_hello.lua + "mod_"..plugin.."/"..resource; -- mod_hello/mod_hello.lua + plugin.."/"..resource; -- hello/mod_hello.lua + resource; -- mod_hello.lua + }; + + return load_file(names); end -function load_code(plugin, resource) +function load_code(plugin, resource, env) local content, err = load_resource(plugin, resource); if not content then return content, err; end - return loadstring(content, "@"..err); + local path = err; + local f, err = envload(content, "@"..path, env); + if not f then return f, err; end + return f, path; end return _M; diff --git a/util/prosodyctl.lua b/util/prosodyctl.lua index 40d21be8..fe862114 100644 --- a/util/prosodyctl.lua +++ b/util/prosodyctl.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -15,18 +15,119 @@ local usermanager = require "core.usermanager"; local signal = require "util.signal"; local set = require "util.set"; local lfs = require "lfs"; +local pcall = pcall; +local type = type; local nodeprep, nameprep = stringprep.nodeprep, stringprep.nameprep; local io, os = io, os; +local print = print; local tostring, tonumber = tostring, tonumber; local CFG_SOURCEDIR = _G.CFG_SOURCEDIR; +local _G = _G; local prosody = prosody; module "prosodyctl" +-- UI helpers +function show_message(msg, ...) + print(msg:format(...)); +end + +function show_warning(msg, ...) + print(msg:format(...)); +end + +function show_usage(usage, desc) + print("Usage: ".._G.arg[0].." "..usage); + if desc then + print(" "..desc); + end +end + +function getchar(n) + local stty_ret = os.execute("stty raw -echo 2>/dev/null"); + local ok, char; + if stty_ret == 0 then + ok, char = pcall(io.read, n or 1); + os.execute("stty sane"); + else + ok, char = pcall(io.read, "*l"); + if ok then + char = char:sub(1, n or 1); + end + end + if ok then + return char; + end +end + +function getline() + local ok, line = pcall(io.read, "*l"); + if ok then + return line; + end +end + +function getpass() + local stty_ret = os.execute("stty -echo 2>/dev/null"); + if stty_ret ~= 0 then + io.write("\027[08m"); -- ANSI 'hidden' text attribute + end + local ok, pass = pcall(io.read, "*l"); + if stty_ret == 0 then + os.execute("stty sane"); + else + io.write("\027[00m"); + end + io.write("\n"); + if ok then + return pass; + end +end + +function show_yesno(prompt) + io.write(prompt, " "); + local choice = getchar():lower(); + io.write("\n"); + if not choice:match("%a") then + choice = prompt:match("%[.-(%U).-%]$"); + if not choice then return nil; end + end + return (choice == "y"); +end + +function read_password() + local password; + while true do + io.write("Enter new password: "); + password = getpass(); + if not password then + show_message("No password - cancelled"); + return; + end + io.write("Retype new password: "); + if getpass() ~= password then + if not show_yesno [=[Passwords did not match, try again? [Y/n]]=] then + return; + end + else + break; + end + end + return password; +end + +function show_prompt(prompt) + io.write(prompt, " "); + local line = getline(); + line = line and line:gsub("\n$",""); + return (line and #line > 0) and line or nil; +end + +-- Server control function adduser(params) local user, host, password = nodeprep(params.user), nameprep(params.host), params.password; if not user then @@ -35,12 +136,17 @@ function adduser(params) return false, "invalid-hostname"; end - local provider = prosody.hosts[host].users; + local host_session = prosody.hosts[host]; + if not host_session then + return false, "no-such-host"; + end + + storagemanager.initialize_host(host); + local provider = host_session.users; if not(provider) or provider.name == "null" then usermanager.initialize_host(host); end - storagemanager.initialize_host(host); - + local ok, errmsg = usermanager.create_user(user, password, host); if not ok then return false, errmsg; @@ -50,12 +156,13 @@ end function user_exists(params) local user, host, password = nodeprep(params.user), nameprep(params.host), params.password; + + storagemanager.initialize_host(host); local provider = prosody.hosts[host].users; if not(provider) or provider.name == "null" then usermanager.initialize_host(host); end - storagemanager.initialize_host(host); - + return usermanager.user_exists(user, host); end @@ -63,7 +170,7 @@ function passwd(params) if not _M.user_exists(params) then return false, "no-such-user"; end - + return _M.adduser(params); end @@ -71,40 +178,40 @@ function deluser(params) if not _M.user_exists(params) then return false, "no-such-user"; end - params.password = nil; - - return _M.adduser(params); + local user, host = nodeprep(params.user), nameprep(params.host); + + return usermanager.delete_user(user, host); end function getpid() - local pidfile = config.get("*", "core", "pidfile"); + local pidfile = config.get("*", "pidfile"); if not pidfile then return false, "no-pidfile"; end - - local modules_enabled = set.new(config.get("*", "core", "modules_enabled")); + + local modules_enabled = set.new(config.get("*", "modules_enabled")); if not modules_enabled:contains("posix") then return false, "no-posix"; end - + local file, err = io.open(pidfile, "r+"); if not file then return false, "pidfile-read-failed", err; end - + local locked, err = lfs.lock(file, "w"); if locked then file:close(); return false, "pidfile-not-locked"; end - + local pid = tonumber(file:read("*a")); file:close(); - + if not pid then return false, "invalid-pid"; end - + return true, pid; end @@ -145,10 +252,28 @@ function stop() if not ret then return false, "not-running"; end - + local ok, pid = _M.getpid() if not ok then return false, pid; end - + signal.kill(pid, signal.SIGTERM); return true; end + +function reload() + local ok, ret = _M.isrunning(); + if not ok then + return ok, ret; + end + if not ret then + return false, "not-running"; + end + + local ok, pid = _M.getpid() + if not ok then return false, pid; end + + signal.kill(pid, signal.SIGHUP); + return true; +end + +return _M; diff --git a/util/pubsub.lua b/util/pubsub.lua index 3beafab5..0dfd196b 100644 --- a/util/pubsub.lua +++ b/util/pubsub.lua @@ -1,3 +1,5 @@ +local events = require "util.events"; + module("pubsub", package.seeall); local service = {}; @@ -16,6 +18,7 @@ function new(config) affiliations = {}; subscriptions = {}; nodes = {}; + events = events.new(); }, service_mt); end @@ -26,18 +29,15 @@ end function service:may(node, actor, action) if actor == true then return true; end - - + local node_obj = self.nodes[node]; local node_aff = node_obj and node_obj.affiliations[actor]; local service_aff = self.affiliations[actor] or self.config.get_affiliation(actor, node, action) or "none"; - + + -- Check if node allows/forbids it local node_capabilities = node_obj and node_obj.capabilities; - local service_capabilities = self.config.capabilities; - - -- Check if node allows/forbids it if node_capabilities then local caps = node_capabilities[node_aff or service_aff]; if caps then @@ -47,7 +47,9 @@ function service:may(node, actor, action) end end end + -- Check service-wide capabilities instead + local service_capabilities = self.config.capabilities; local caps = service_capabilities[node_aff or service_aff]; if caps then local can = caps[action]; @@ -55,7 +57,7 @@ function service:may(node, actor, action) return can; end end - + return false; end @@ -70,14 +72,14 @@ function service:set_affiliation(node, actor, jid, affiliation) return false, "item-not-found"; end node_obj.affiliations[jid] = affiliation; - local _, jid_sub = self:get_subscription(node, nil, jid); + local _, jid_sub = self:get_subscription(node, true, jid); if not jid_sub and not self:may(node, jid, "be_unsubscribed") then - local ok, err = self:add_subscription(node, nil, jid); + local ok, err = self:add_subscription(node, true, jid); if not ok then return ok, err; end elseif jid_sub and not self:may(node, jid, "be_subscribed") then - local ok, err = self:add_subscription(node, nil, jid); + local ok, err = self:add_subscription(node, true, jid); if not ok then return ok, err; end @@ -88,7 +90,7 @@ end function service:add_subscription(node, actor, jid, options) -- Access checking local cap; - if jid == actor or self:jids_equal(actor, jid) then + if actor == true or jid == actor or self:jids_equal(actor, jid) then cap = "subscribe"; else cap = "subscribe_other"; @@ -105,7 +107,7 @@ function service:add_subscription(node, actor, jid, options) if not self.config.autocreate_on_subscribe then return false, "item-not-found"; else - local ok, err = self:create(node, actor); + local ok, err = self:create(node, true); if not ok then return ok, err; end @@ -124,13 +126,14 @@ function service:add_subscription(node, actor, jid, options) else self.subscriptions[normal_jid] = { [jid] = { [node] = true } }; end + self.events.fire_event("subscription-added", { node = node, jid = jid, normalized_jid = normal_jid, options = options }); return true; end function service:remove_subscription(node, actor, jid) -- Access checking local cap; - if jid == actor or self:jids_equal(actor, jid) then + if actor == true or jid == actor or self:jids_equal(actor, jid) then cap = "unsubscribe"; else cap = "unsubscribe_other"; @@ -164,13 +167,26 @@ function service:remove_subscription(node, actor, jid) self.subscriptions[normal_jid] = nil; end end + self.events.fire_event("subscription-removed", { node = node, jid = jid, normalized_jid = normal_jid }); + return true; +end + +function service:remove_all_subscriptions(actor, jid) + local normal_jid = self.config.normalize_jid(jid); + local subs = self.subscriptions[normal_jid] + subs = subs and subs[jid]; + if subs then + for node in pairs(subs) do + self:remove_subscription(node, true, jid); + end + end return true; end function service:get_subscription(node, actor, jid) -- Access checking local cap; - if jid == actor or self:jids_equal(actor, jid) then + if actor == true or jid == actor or self:jids_equal(actor, jid) then cap = "get_subscription"; else cap = "get_subscription_other"; @@ -195,7 +211,7 @@ function service:create(node, actor) if self.nodes[node] then return false, "conflict"; end - + self.nodes[node] = { name = node; subscribers = {}; @@ -210,6 +226,21 @@ function service:create(node, actor) return ok, err; end +function service:delete(node, actor) + -- Access checking + if not self:may(node, actor, "delete") then + return false, "forbidden"; + end + -- + local node_obj = self.nodes[node]; + if not node_obj then + return false, "item-not-found"; + end + self.nodes[node] = nil; + self.config.broadcaster("delete", node, node_obj.subscribers); + return true; +end + function service:publish(node, actor, id, item) -- Access checking if not self:may(node, actor, "publish") then @@ -221,14 +252,15 @@ function service:publish(node, actor, id, item) if not self.config.autocreate_on_publish then return false, "item-not-found"; end - local ok, err = self:create(node, actor); + local ok, err = self:create(node, true); if not ok then return ok, err; end node_obj = self.nodes[node]; end node_obj.data[id] = item; - self.config.broadcaster(node, node_obj.subscribers, item); + self.events.fire_event("item-published", { node = node, actor = actor, id = id, item = item }); + self.config.broadcaster("items", node, node_obj.subscribers, item); return true; end @@ -244,7 +276,24 @@ function service:retract(node, actor, id, retract) end node_obj.data[id] = nil; if retract then - self.config.broadcaster(node, node_obj.subscribers, retract); + self.config.broadcaster("items", node, node_obj.subscribers, retract); + end + return true +end + +function service:purge(node, actor, notify) + -- Access checking + if not self:may(node, actor, "retract") then + return false, "forbidden"; + end + -- + local node_obj = self.nodes[node]; + if not node_obj then + return false, "item-not-found"; + end + node_obj.data = {}; -- Purge + if notify then + self.config.broadcaster("purge", node, node_obj.subscribers); end return true end @@ -278,7 +327,7 @@ end function service:get_subscriptions(node, actor, jid) -- Access checking local cap; - if jid == actor or self:jids_equal(actor, jid) then + if actor == true or jid == actor or self:jids_equal(actor, jid) then cap = "get_subscriptions"; else cap = "get_subscriptions_other"; @@ -304,7 +353,7 @@ function service:get_subscriptions(node, actor, jid) if node then -- Return only subscriptions to this node if subscribed_nodes[node] then ret[#ret+1] = { - node = subscribed_node; + node = node; jid = jid; subscription = node_obj.subscribers[jid]; }; diff --git a/util/rfc6724.lua b/util/rfc6724.lua new file mode 100644 index 00000000..c8aec631 --- /dev/null +++ b/util/rfc6724.lua @@ -0,0 +1,142 @@ +-- Prosody IM +-- Copyright (C) 2011-2013 Florian Zeitz +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +-- This is used to sort destination addresses by preference +-- during S2S connections. +-- We can't hand this off to getaddrinfo, since it blocks + +local ip_commonPrefixLength = require"util.ip".commonPrefixLength +local new_ip = require"util.ip".new_ip; + +local function commonPrefixLength(ipA, ipB) + local len = ip_commonPrefixLength(ipA, ipB); + return len < 64 and len or 64; +end + +local function t_sort(t, comp) + for i = 1, (#t - 1) do + for j = (i + 1), #t do + local a, b = t[i], t[j]; + if not comp(a,b) then + t[i], t[j] = b, a; + end + end + end +end + +local function source(dest, candidates) + local function comp(ipA, ipB) + -- Rule 1: Prefer same address + if dest == ipA then + return true; + elseif dest == ipB then + return false; + end + + -- Rule 2: Prefer appropriate scope + if ipA.scope < ipB.scope then + if ipA.scope < dest.scope then + return false; + else + return true; + end + elseif ipA.scope > ipB.scope then + if ipB.scope < dest.scope then + return true; + else + return false; + end + end + + -- Rule 3: Avoid deprecated addresses + -- XXX: No way to determine this + -- Rule 4: Prefer home addresses + -- XXX: Mobility Address related, no way to determine this + -- Rule 5: Prefer outgoing interface + -- XXX: Interface to address relation. No way to determine this + -- Rule 6: Prefer matching label + if ipA.label == dest.label and ipB.label ~= dest.label then + return true; + elseif ipB.label == dest.label and ipA.label ~= dest.label then + return false; + end + + -- Rule 7: Prefer temporary addresses (over public ones) + -- XXX: No way to determine this + -- Rule 8: Use longest matching prefix + if commonPrefixLength(ipA, dest) > commonPrefixLength(ipB, dest) then + return true; + else + return false; + end + end + + t_sort(candidates, comp); + return candidates[1]; +end + +local function destination(candidates, sources) + local sourceAddrs = {}; + local function comp(ipA, ipB) + local ipAsource = sourceAddrs[ipA]; + local ipBsource = sourceAddrs[ipB]; + -- Rule 1: Avoid unusable destinations + -- XXX: No such information + -- Rule 2: Prefer matching scope + if ipA.scope == ipAsource.scope and ipB.scope ~= ipBsource.scope then + return true; + elseif ipA.scope ~= ipAsource.scope and ipB.scope == ipBsource.scope then + return false; + end + + -- Rule 3: Avoid deprecated addresses + -- XXX: No way to determine this + -- Rule 4: Prefer home addresses + -- XXX: Mobility Address related, no way to determine this + -- Rule 5: Prefer matching label + if ipAsource.label == ipA.label and ipBsource.label ~= ipB.label then + return true; + elseif ipBsource.label == ipB.label and ipAsource.label ~= ipA.label then + return false; + end + + -- Rule 6: Prefer higher precedence + if ipA.precedence > ipB.precedence then + return true; + elseif ipA.precedence < ipB.precedence then + return false; + end + + -- Rule 7: Prefer native transport + -- XXX: No way to determine this + -- Rule 8: Prefer smaller scope + if ipA.scope < ipB.scope then + return true; + elseif ipA.scope > ipB.scope then + return false; + end + + -- Rule 9: Use longest matching prefix + if commonPrefixLength(ipA, ipAsource) > commonPrefixLength(ipB, ipBsource) then + return true; + elseif commonPrefixLength(ipA, ipAsource) < commonPrefixLength(ipB, ipBsource) then + return false; + end + + -- Rule 10: Otherwise, leave order unchanged + return true; + end + for _, ip in ipairs(candidates) do + sourceAddrs[ip] = source(ip, sources); + end + + t_sort(candidates, comp); + return candidates; +end + +return {source = source, + destination = destination}; diff --git a/util/sasl.lua b/util/sasl.lua index 393a0919..0d90880d 100644 --- a/util/sasl.lua +++ b/util/sasl.lua @@ -68,13 +68,17 @@ end -- create a new SASL object which can be used to authenticate clients function new(realm, profile) - local mechanisms = {}; - for backend, f in pairs(profile) do - if backend_mechanism[backend] then - for _, mechanism in ipairs(backend_mechanism[backend]) do - mechanisms[mechanism] = true; + local mechanisms = profile.mechanisms; + if not mechanisms then + mechanisms = {}; + for backend, f in pairs(profile) do + if backend_mechanism[backend] then + for _, mechanism in ipairs(backend_mechanism[backend]) do + mechanisms[mechanism] = true; + end end end + profile.mechanisms = mechanisms; end return setmetatable({ profile = profile, realm = realm, mechs = mechanisms }, method); end @@ -131,5 +135,6 @@ require "util.sasl.plain" .init(registerMechanism); require "util.sasl.digest-md5".init(registerMechanism); require "util.sasl.anonymous" .init(registerMechanism); require "util.sasl.scram" .init(registerMechanism); +require "util.sasl.external" .init(registerMechanism); return _M; diff --git a/util/sasl/anonymous.lua b/util/sasl/anonymous.lua index b9af17fe..ca5fe404 100644 --- a/util/sasl/anonymous.lua +++ b/util/sasl/anonymous.lua @@ -16,7 +16,7 @@ local s_match = string.match; local log = require "util.logger".init("sasl"); local generate_uuid = require "util.uuid".generate; -module "anonymous" +module "sasl.anonymous" --========================= --SASL ANONYMOUS according to RFC 4505 @@ -43,4 +43,4 @@ function init(registerMechanism) registerMechanism("ANONYMOUS", {"anonymous"}, anonymous); end -return _M;
\ No newline at end of file +return _M; diff --git a/util/sasl/digest-md5.lua b/util/sasl/digest-md5.lua index 6f2c765e..591d8537 100644 --- a/util/sasl/digest-md5.lua +++ b/util/sasl/digest-md5.lua @@ -23,8 +23,9 @@ local to_byte, to_char = string.byte, string.char; local md5 = require "util.hashes".md5; local log = require "util.logger".init("sasl"); local generate_uuid = require "util.uuid".generate; +local nodeprep = require "util.encodings".stringprep.nodeprep; -module "digest-md5" +module "sasl.digest-md5" --========================= --SASL DIGEST-MD5 according to RFC 2831 @@ -139,10 +140,15 @@ local function digest(self, message) end -- check for username, it's REQUIRED by RFC 2831 - if not response["username"] then + local username = response["username"]; + local _nodeprep = self.profile.nodeprep; + if username and _nodeprep ~= false then + username = (_nodeprep or nodeprep)(username); -- FIXME charset + end + if not username or username == "" then return "failure", "malformed-request"; end - self["username"] = response["username"]; + self.username = username; -- check for nonce, ... if not response["nonce"] then @@ -178,7 +184,6 @@ local function digest(self, message) end --TODO maybe realm support - self.username = response["username"]; local Y, state; if self.profile.plain then local password, state = self.profile.plain(self, response["username"], self.realm) @@ -240,4 +245,4 @@ function init(registerMechanism) registerMechanism("DIGEST-MD5", {"plain"}, digest); end -return _M;
\ No newline at end of file +return _M; diff --git a/util/sasl/external.lua b/util/sasl/external.lua new file mode 100644 index 00000000..4c5c4343 --- /dev/null +++ b/util/sasl/external.lua @@ -0,0 +1,25 @@ +local saslprep = require "util.encodings".stringprep.saslprep; + +module "sasl.external" + +local function external(self, message) + message = saslprep(message); + local state + self.username, state = self.profile.external(message); + + if state == false then + return "failure", "account-disabled"; + elseif state == nil then + return "failure", "not-authorized"; + elseif state == "expired" then + return "false", "credentials-expired"; + end + + return "success"; +end + +function init(registerMechanism) + registerMechanism("EXTERNAL", {"external"}, external); +end + +return _M; diff --git a/util/sasl/plain.lua b/util/sasl/plain.lua index d6ebe304..c9ec2911 100644 --- a/util/sasl/plain.lua +++ b/util/sasl/plain.lua @@ -13,9 +13,10 @@ local s_match = string.match; local saslprep = require "util.encodings".stringprep.saslprep; +local nodeprep = require "util.encodings".stringprep.nodeprep; local log = require "util.logger".init("sasl"); -module "plain" +module "sasl.plain" -- ================================ -- SASL PLAIN according to RFC 4616 @@ -54,6 +55,14 @@ local function plain(self, message) return "failure", "malformed-request", "Invalid username or password."; end + local _nodeprep = self.profile.nodeprep; + if _nodeprep ~= false then + authentication = (_nodeprep or nodeprep)(authentication); + if not authentication or authentication == "" then + return "failure", "malformed-request", "Invalid username or password." + end + end + local correct, state = false, false; if self.profile.plain then local correct_password; @@ -64,15 +73,13 @@ local function plain(self, message) end self.username = authentication - if not state then + if state == false then return "failure", "account-disabled"; - end - - if correct then - return "success"; - else + elseif state == nil or not correct then return "failure", "not-authorized", "Unable to authorize you with the authentication credentials you've sent."; end + + return "success"; end function init(registerMechanism) diff --git a/util/sasl/scram.lua b/util/sasl/scram.lua index 071de505..cf938dba 100644 --- a/util/sasl/scram.lua +++ b/util/sasl/scram.lua @@ -16,16 +16,18 @@ local type = type local string = string local tostring = tostring; local base64 = require "util.encodings".base64; -local hmac_sha1 = require "util.hmac".sha1; +local hmac_sha1 = require "util.hashes".hmac_sha1; local sha1 = require "util.hashes".sha1; +local Hi = require "util.hashes".scram_Hi_sha1; local generate_uuid = require "util.uuid".generate; local saslprep = require "util.encodings".stringprep.saslprep; +local nodeprep = require "util.encodings".stringprep.nodeprep; local log = require "util.logger".init("sasl"); local t_concat = table.concat; local char = string.char; local byte = string.byte; -module "scram" +module "sasl.scram" --========================= --SASL SCRAM-SHA-1 according to RFC 5802 @@ -69,33 +71,26 @@ local function binaryXOR( a, b ) return t_concat(result); end --- hash algorithm independent Hi(PBKDF2) implementation -function Hi(hmac, str, salt, i) - local Ust = hmac(str, salt.."\0\0\0\1"); - local res = Ust; - for n=1,i-1 do - local Und = hmac(str, Ust) - res = binaryXOR(res, Und) - Ust = Und - end - return res -end - -local function validate_username(username) +local function validate_username(username, _nodeprep) -- check for forbidden char sequences for eq in username:gmatch("=(.?.?)") do - if eq ~= "2D" and eq ~= "3D" then + if eq ~= "2C" and eq ~= "3D" then return false end end - - -- replace =2D with , and =3D with = - username = username:gsub("=2D", ","); + + -- replace =2C with , and =3D with = + username = username:gsub("=2C", ","); username = username:gsub("=3D", "="); - + -- apply SASLprep username = saslprep(username); - return username; + + if username and _nodeprep ~= false then + username = (_nodeprep or nodeprep)(username); + end + + return username and #username>0 and username; end local function hashprep(hashname) @@ -109,7 +104,7 @@ function getAuthenticationDatabaseSHA1(password, salt, iteration_count) if iteration_count < 4096 then log("warn", "Iteration count < 4096 which is the suggested minimum according to RFC 5802.") end - local salted_password = Hi(hmac_sha1, password, salt, iteration_count); + local salted_password = Hi(password, salt, iteration_count); local stored_key = sha1(hmac_sha1(salted_password, "Client Key")) local server_key = hmac_sha1(salted_password, "Server Key"); return true, stored_key, server_key @@ -120,12 +115,12 @@ local function scram_gen(hash_name, H_f, HMAC_f) if not self.state then self["state"] = {} end local support_channel_binding = false; if self.profile.cb then support_channel_binding = true; end - + if type(message) ~= "string" or #message == 0 then return "failure", "malformed-request" end if not self.state.name then -- we are processing client_first_message local client_first_message = message; - log("debug", client_first_message); + -- TODO: fail if authzid is provided, since we don't support them yet self.state["client_first_message"] = client_first_message; self.state["gs2_cbind_flag"], self.state["gs2_cbind_name"], self.state["authzid"], self.state["name"], self.state["clientnonce"] @@ -156,21 +151,21 @@ local function scram_gen(hash_name, H_f, HMAC_f) if not self.state.name or not self.state.clientnonce then return "failure", "malformed-request", "Channel binding isn't support at this time."; end - - self.state.name = validate_username(self.state.name); + + self.state.name = validate_username(self.state.name, self.profile.nodeprep); if not self.state.name then log("debug", "Username violates either SASLprep or contains forbidden character sequences.") return "failure", "malformed-request", "Invalid username."; end - + self.state["servernonce"] = generate_uuid(); - + -- retreive credentials if self.profile.plain then local password, state = self.profile.plain(self, self.state.name, self.realm) if state == nil then return "failure", "not-authorized" elseif state == false then return "failure", "account-disabled" end - + password = saslprep(password); if not password then log("debug", "Password violates SASLprep."); @@ -190,20 +185,20 @@ local function scram_gen(hash_name, H_f, HMAC_f) local stored_key, server_key, iteration_count, salt, state = self.profile["scram_"..hashprep(hash_name)](self, self.state.name, self.realm); if state == nil then return "failure", "not-authorized" elseif state == false then return "failure", "account-disabled" end - + self.state.stored_key = stored_key; self.state.server_key = server_key; self.state.iteration_count = iteration_count; self.state.salt = salt end - + local server_first_message = "r="..self.state.clientnonce..self.state.servernonce..",s="..base64.encode(self.state.salt)..",i="..self.state.iteration_count; self.state["server_first_message"] = server_first_message; return "challenge", server_first_message else -- we are processing client_final_message local client_final_message = message; - log("debug", "client_final_message: %s", client_final_message); + self.state["channelbinding"], self.state["nonce"], self.state["proof"] = client_final_message:match("^c=(.*),r=(.*),.*p=(.*)"); if not self.state.proof or not self.state.nonce or not self.state.channelbinding then @@ -223,10 +218,10 @@ local function scram_gen(hash_name, H_f, HMAC_f) if self.state.nonce ~= self.state.clientnonce..self.state.servernonce then return "failure", "malformed-request", "Wrong nonce in client-final-message."; end - + local ServerKey = self.state.server_key; local StoredKey = self.state.stored_key; - + local AuthMessage = "n=" .. s_match(self.state.client_first_message,"n=(.+)") .. "," .. self.state.server_first_message .. "," .. s_match(client_final_message, "(.+),p=.+") local ClientSignature = HMAC_f(StoredKey, AuthMessage) local ClientKey = binaryXOR(ClientSignature, base64.decode(self.state.proof)) diff --git a/util/sasl_cyrus.lua b/util/sasl_cyrus.lua index 002118fd..a0e8bd69 100644 --- a/util/sasl_cyrus.lua +++ b/util/sasl_cyrus.lua @@ -78,11 +78,15 @@ local function init(service_name) end -- create a new SASL object which can be used to authenticate clients -function new(realm, service_name, app_name) +-- host_fqdn may be nil in which case gethostname() gives the value. +-- For GSSAPI, this determines the hostname in the service ticket (after +-- reverse DNS canonicalization, only if [libdefaults] rdns = true which +-- is the default). +function new(realm, service_name, app_name, host_fqdn) init(app_name or service_name); - local st, ret = pcall(cyrussasl.server_new, service_name, nil, realm, nil, nil) + local st, ret = pcall(cyrussasl.server_new, service_name, host_fqdn, realm, nil, nil) if not st then log("error", "Creating SASL server connection failed: %s", ret); return nil; diff --git a/util/serialization.lua b/util/serialization.lua index e193b64f..06e45054 100644 --- a/util/serialization.lua +++ b/util/serialization.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -16,11 +16,12 @@ local pairs = pairs; local next = next; local loadstring = loadstring; -local setfenv = setfenv; local pcall = pcall; local debug_traceback = debug.traceback; local log = require "util.logger".init("serialization"); +local envload = require"util.envload".envload; + module "serialization" local indent = function(i) @@ -84,9 +85,8 @@ end function deserialize(str) if type(str) ~= "string" then return nil; end str = "return "..str; - local f, err = loadstring(str, "@data"); + local f, err = envload(str, "@data", {}); if not f then return nil, err; end - setfenv(f, {}); local success, ret = pcall(f); if not success then return nil, ret; end return ret; diff --git a/util/set.lua b/util/set.lua index e4cc2dff..89cd7cf3 100644 --- a/util/set.lua +++ b/util/set.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -26,8 +26,9 @@ function set_mt.__div(set, func) local new_set, new_items = _M.new(); local items, new_items = set._items, new_set._items; for item in pairs(items) do - if func(item) then - new_items[item] = true; + local new_item = func(item); + if new_item ~= nil then + new_items[new_item] = true; end end return new_set; @@ -39,13 +40,13 @@ function set_mt.__eq(set1, set2) return false; end end - + for item in pairs(set2) do if not set1[item] then return false; end end - + return true; end function set_mt.__tostring(set) @@ -64,56 +65,58 @@ end function new(list) local items = setmetatable({}, items_mt); local set = { _items = items }; - + function set:add(item) items[item] = true; end - + function set:contains(item) return items[item]; end - + function set:items() - return items; + return next, items; end - + function set:remove(item) items[item] = nil; end - + function set:add_list(list) - for _, item in ipairs(list) do - items[item] = true; + if list then + for _, item in ipairs(list) do + items[item] = true; + end end end - + function set:include(otherset) - for item in pairs(otherset) do + for item in otherset do items[item] = true; end end function set:exclude(otherset) - for item in pairs(otherset) do + for item in otherset do items[item] = nil; end end - + function set:empty() return not next(items); end - + if list then set:add_list(list); end - + return setmetatable(set, set_mt); end function union(set1, set2) local set = new(); local items = set._items; - + for item in pairs(set1._items) do items[item] = true; end @@ -121,14 +124,14 @@ function union(set1, set2) for item in pairs(set2._items) do items[item] = true; end - + return set; end function difference(set1, set2) local set = new(); local items = set._items; - + for item in pairs(set1._items) do items[item] = (not set2._items[item]) or nil; end @@ -139,13 +142,13 @@ end function intersection(set1, set2) local set = new(); local items = set._items; - + set1, set2 = set1._items, set2._items; - + for item in pairs(set1) do items[item] = (not not set2[item]) or nil; end - + return set; end diff --git a/util/sql.lua b/util/sql.lua new file mode 100644 index 00000000..b8c16e27 --- /dev/null +++ b/util/sql.lua @@ -0,0 +1,342 @@ + +local setmetatable, getmetatable = setmetatable, getmetatable; +local ipairs, unpack, select = ipairs, unpack, select; +local tonumber, tostring = tonumber, tostring; +local assert, xpcall, debug_traceback = assert, xpcall, debug.traceback; +local t_concat = table.concat; +local s_char = string.char; +local log = require "util.logger".init("sql"); + +local DBI = require "DBI"; +-- This loads all available drivers while globals are unlocked +-- LuaDBI should be fixed to not set globals. +DBI.Drivers(); +local build_url = require "socket.url".build; + +module("sql") + +local column_mt = {}; +local table_mt = {}; +local query_mt = {}; +--local op_mt = {}; +local index_mt = {}; + +function is_column(x) return getmetatable(x)==column_mt; end +function is_index(x) return getmetatable(x)==index_mt; end +function is_table(x) return getmetatable(x)==table_mt; end +function is_query(x) return getmetatable(x)==query_mt; end +--function is_op(x) return getmetatable(x)==op_mt; end +--function expr(...) return setmetatable({...}, op_mt); end +function Integer(n) return "Integer()" end +function String(n) return "String()" end + +--[[local ops = { + __add = function(a, b) return "("..a.."+"..b..")" end; + __sub = function(a, b) return "("..a.."-"..b..")" end; + __mul = function(a, b) return "("..a.."*"..b..")" end; + __div = function(a, b) return "("..a.."/"..b..")" end; + __mod = function(a, b) return "("..a.."%"..b..")" end; + __pow = function(a, b) return "POW("..a..","..b..")" end; + __unm = function(a) return "NOT("..a..")" end; + __len = function(a) return "COUNT("..a..")" end; + __eq = function(a, b) return "("..a.."=="..b..")" end; + __lt = function(a, b) return "("..a.."<"..b..")" end; + __le = function(a, b) return "("..a.."<="..b..")" end; +}; + +local functions = { + +}; + +local cmap = { + [Integer] = Integer(); + [String] = String(); +};]] + +function Column(definition) + return setmetatable(definition, column_mt); +end +function Table(definition) + local c = {} + for i,col in ipairs(definition) do + if is_column(col) then + c[i], c[col.name] = col, col; + elseif is_index(col) then + col.table = definition.name; + end + end + return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt); +end +function Index(definition) + return setmetatable(definition, index_mt); +end + +function table_mt:__tostring() + local s = { 'name="'..self.__table__.name..'"' } + for i,col in ipairs(self.__table__) do + s[#s+1] = tostring(col); + end + return 'Table{ '..t_concat(s, ", ")..' }' +end +table_mt.__index = {}; +function table_mt.__index:create(engine) + return engine:_create_table(self); +end +function table_mt:__call(...) + -- TODO +end +function column_mt:__tostring() + return 'Column{ name="'..self.name..'", type="'..self.type..'" }' +end +function index_mt:__tostring() + local s = 'Index{ name="'..self.name..'"'; + for i=1,#self do s = s..', "'..self[i]:gsub("[\\\"]", "\\%1")..'"'; end + return s..' }'; +-- return 'Index{ name="'..self.name..'", type="'..self.type..'" }' +end +-- + +local function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end)); end +local function parse_url(url) + local scheme, secondpart, database = url:match("^([%w%+]+)://([^/]*)/?(.*)"); + assert(scheme, "Invalid URL format"); + local username, password, host, port; + local authpart, hostpart = secondpart:match("([^@]+)@([^@+])"); + if not authpart then hostpart = secondpart; end + if authpart then + username, password = authpart:match("([^:]*):(.*)"); + username = username or authpart; + password = password and urldecode(password); + end + if hostpart then + host, port = hostpart:match("([^:]*):(.*)"); + host = host or hostpart; + port = port and assert(tonumber(port), "Invalid URL format"); + end + return { + scheme = scheme:lower(); + username = username; password = password; + host = host; port = port; + database = #database > 0 and database or nil; + }; +end + +--[[local session = {}; + +function session.query(...) + local rets = {...}; + local query = setmetatable({ __rets = rets, __filters }, query_mt); + return query; +end +-- + +local function db2uri(params) + return build_url{ + scheme = params.driver, + user = params.username, + password = params.password, + host = params.host, + port = params.port, + path = params.database, + }; +end]] + +local engine = {}; +function engine:connect() + if self.conn then return true; end + + local params = self.params; + assert(params.driver, "no driver") + local dbh, err = DBI.Connect( + params.driver, params.database, + params.username, params.password, + params.host, params.port + ); + if not dbh then return nil, err; end + dbh:autocommit(false); -- don't commit automatically + self.conn = dbh; + self.prepared = {}; + return true; +end +function engine:execute(sql, ...) + local success, err = self:connect(); + if not success then return success, err; end + local prepared = self.prepared; + + local stmt = prepared[sql]; + if not stmt then + local err; + stmt, err = self.conn:prepare(sql); + if not stmt then return stmt, err; end + prepared[sql] = stmt; + end + + local success, err = stmt:execute(...); + if not success then return success, err; end + return stmt; +end + +local result_mt = { __index = { + affected = function(self) return self.__stmt:affected(); end; + rowcount = function(self) return self.__stmt:rowcount(); end; +} }; + +function engine:execute_query(sql, ...) + if self.params.driver == "PostgreSQL" then + sql = sql:gsub("`", "\""); + end + local stmt = assert(self.conn:prepare(sql)); + assert(stmt:execute(...)); + return stmt:rows(); +end +function engine:execute_update(sql, ...) + if self.params.driver == "PostgreSQL" then + sql = sql:gsub("`", "\""); + end + local prepared = self.prepared; + local stmt = prepared[sql]; + if not stmt then + stmt = assert(self.conn:prepare(sql)); + prepared[sql] = stmt; + end + assert(stmt:execute(...)); + return setmetatable({ __stmt = stmt }, result_mt); +end +engine.insert = engine.execute_update; +engine.select = engine.execute_query; +engine.delete = engine.execute_update; +engine.update = engine.execute_update; +function engine:_transaction(func, ...) + if not self.conn then + local a,b = self:connect(); + if not a then return a,b; end + end + --assert(not self.__transaction, "Recursive transactions not allowed"); + local args, n_args = {...}, select("#", ...); + local function f() return func(unpack(args, 1, n_args)); end + self.__transaction = true; + local success, a, b, c = xpcall(f, debug_traceback); + self.__transaction = nil; + if success then + log("debug", "SQL transaction success [%s]", tostring(func)); + local ok, err = self.conn:commit(); + if not ok then return ok, err; end -- commit failed + return success, a, b, c; + else + log("debug", "SQL transaction failure [%s]: %s", tostring(func), a); + if self.conn then self.conn:rollback(); end + return success, a; + end +end +function engine:transaction(...) + local a,b = self:_transaction(...); + if not a then + local conn = self.conn; + if not conn or not conn:ping() then + self.conn = nil; + a,b = self:_transaction(...); + end + end + return a,b; +end +function engine:_create_index(index) + local sql = "CREATE INDEX `"..index.name.."` ON `"..index.table.."` ("; + for i=1,#index do + sql = sql.."`"..index[i].."`"; + if i ~= #index then sql = sql..", "; end + end + sql = sql..");" + if self.params.driver == "PostgreSQL" then + sql = sql:gsub("`", "\""); + elseif self.params.driver == "MySQL" then + sql = sql:gsub("`([,)])", "`(20)%1"); + end + --print(sql); + return self:execute(sql); +end +function engine:_create_table(table) + local sql = "CREATE TABLE `"..table.name.."` ("; + for i,col in ipairs(table.c) do + sql = sql.."`"..col.name.."` "..col.type; + if col.nullable == false then sql = sql.." NOT NULL"; end + if i ~= #table.c then sql = sql..", "; end + end + sql = sql.. ");" + if self.params.driver == "PostgreSQL" then + sql = sql:gsub("`", "\""); + elseif self.params.driver == "MySQL" then + sql = sql:gsub(";$", " CHARACTER SET 'utf8' COLLATE 'utf8_bin';"); + end + local success,err = self:execute(sql); + if not success then return success,err; end + for i,v in ipairs(table.__table__) do + if is_index(v) then + self:_create_index(v); + end + end + return success; +end +local engine_mt = { __index = engine }; + +local function db2uri(params) + return build_url{ + scheme = params.driver, + user = params.username, + password = params.password, + host = params.host, + port = params.port, + path = params.database, + }; +end +local engine_cache = {}; -- TODO make weak valued +function create_engine(self, params) + local url = db2uri(params); + if not engine_cache[url] then + local engine = setmetatable({ url = url, params = params }, engine_mt); + engine_cache[url] = engine; + end + return engine_cache[url]; +end + + +--[[Users = Table { + name="users"; + Column { name="user_id", type=String(), primary_key=true }; +}; +print(Users) +print(Users.c.user_id)]] + +--local engine = create_engine('postgresql://scott:tiger@localhost:5432/mydatabase'); +--[[local engine = create_engine{ driver = "SQLite3", database = "./alchemy.sqlite" }; + +local i = 0; +for row in assert(engine:execute("select * from sqlite_master")):rows(true) do + i = i+1; + print(i); + for k,v in pairs(row) do + print("",k,v); + end +end +print("---") + +Prosody = Table { + name="prosody"; + Column { name="host", type="TEXT", nullable=false }; + Column { name="user", type="TEXT", nullable=false }; + Column { name="store", type="TEXT", nullable=false }; + Column { name="key", type="TEXT", nullable=false }; + Column { name="type", type="TEXT", nullable=false }; + Column { name="value", type="TEXT", nullable=false }; + Index { name="prosody_index", "host", "user", "store", "key" }; +}; +--print(Prosody); +assert(engine:transaction(function() + assert(Prosody:create(engine)); +end)); + +for row in assert(engine:execute("select user from prosody")):rows(true) do + print("username:", row['username']) +end +--result.close();]] + +return _M; diff --git a/util/stanza.lua b/util/stanza.lua index afaf9ce9..82601e63 100644 --- a/util/stanza.lua +++ b/util/stanza.lua @@ -1,29 +1,24 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- local t_insert = table.insert; -local t_concat = table.concat; local t_remove = table.remove; local t_concat = table.concat; local s_format = string.format; local s_match = string.match; local tostring = tostring; local setmetatable = setmetatable; -local getmetatable = getmetatable; local pairs = pairs; local ipairs = ipairs; local type = type; -local next = next; -local print = print; -local unpack = unpack; local s_gsub = string.gsub; -local s_char = string.char; +local s_sub = string.sub; local s_find = string.find; local os = os; @@ -44,11 +39,13 @@ module "stanza" stanza_mt = { __type = "stanza" }; stanza_mt.__index = stanza_mt; +local stanza_mt = stanza_mt; function stanza(name, attr) local stanza = { name = name, attr = attr or {}, tags = {} }; return setmetatable(stanza, stanza_mt); end +local stanza = stanza; function stanza_mt:query(xmlns) return self:tag("query", { xmlns = xmlns }); @@ -102,12 +99,20 @@ function stanza_mt:get_child(name, xmlns) if (not name or child.name == name) and ((not xmlns and self.attr.xmlns == child.attr.xmlns) or child.attr.xmlns == xmlns) then - + return child; end end end +function stanza_mt:get_child_text(name, xmlns) + local tag = self:get_child(name, xmlns); + if tag then + return tag:get_text(); + end + return nil; +end + function stanza_mt:child_with_name(name) for _, child in ipairs(self.tags) do if child.name == name then return child; end @@ -128,37 +133,28 @@ function stanza_mt:children() end, self, i; end -function stanza_mt:matching_tags(name, xmlns) - xmlns = xmlns or self.attr.xmlns; +function stanza_mt:childtags(name, xmlns) local tags = self.tags; local start_i, max_i = 1, #tags; return function () - for i=start_i,max_i do - v = tags[i]; - if (not name or v.name == name) - and (not xmlns or xmlns == v.attr.xmlns) then - start_i = i+1; - return v; - end + for i = start_i, max_i do + local v = tags[i]; + if (not name or v.name == name) + and ((not xmlns and self.attr.xmlns == v.attr.xmlns) + or v.attr.xmlns == xmlns) then + start_i = i+1; + return v; end - end, tags, i; -end - -function stanza_mt:childtags() - local i = 0; - return function (a) - i = i + 1 - local v = self.tags[i] - if v then return v; end - end, self.tags[1], i; + end + end; end function stanza_mt:maptags(callback) local tags, curr_tag = self.tags, 1; local n_children, n_tags = #self, #tags; - + local i = 1; - while curr_tag <= n_tags do + while curr_tag <= n_tags and n_tags > 0 do if self[i] == tags[curr_tag] then local ret = callback(self[i]); if ret == nil then @@ -166,17 +162,44 @@ function stanza_mt:maptags(callback) t_remove(tags, curr_tag); n_children = n_children - 1; n_tags = n_tags - 1; + i = i - 1; + curr_tag = curr_tag - 1; else self[i] = ret; - tags[i] = ret; + tags[curr_tag] = ret; end - i = i + 1; curr_tag = curr_tag + 1; end + i = i + 1; end return self; end +function stanza_mt:find(path) + local pos = 1; + local len = #path + 1; + + repeat + local xmlns, name, text; + local char = s_sub(path, pos, pos); + if char == "@" then + return self.attr[s_sub(path, pos + 1)]; + elseif char == "{" then + xmlns, pos = s_match(path, "^([^}]+)}()", pos + 1); + end + name, text, pos = s_match(path, "^([^@/#]*)([/#]?)()", pos); + name = name ~= "" and name or nil; + if pos == len then + if text == "#" then + return self:get_child_text(name, xmlns); + end + return self:get_child(name, xmlns); + end + self = self:get_child(name, xmlns); + until not self +end + + local xml_escape do local escape_table = { ["'"] = "'", ["\""] = """, ["<"] = "<", [">"] = ">", ["&"] = "&" }; @@ -235,14 +258,14 @@ end function stanza_mt.get_error(stanza) local type, condition, text; - + local error_tag = stanza:get_child("error"); if not error_tag then return nil, nil, nil; end type = error_tag.attr.type; - - for child in error_tag:childtags() do + + for _, child in ipairs(error_tag.tags) do if child.attr.xmlns == xmlns_stanzas then if not text and child.name == "text" then text = child:get_text(); @@ -257,11 +280,6 @@ function stanza_mt.get_error(stanza) return type, condition or "undefined-condition", text; end -function stanza_mt.__add(s1, s2) - return s1:add_direct_child(s2); -end - - do local id = 0; function new_id() @@ -315,28 +333,25 @@ function deserialize(stanza) stanza.tags = tags; end end - + return stanza; end -function clone(stanza) - local lookup_table = {}; - local function _copy(object) - if type(object) ~= "table" then - return object; - elseif lookup_table[object] then - return lookup_table[object]; +local function _clone(stanza) + local attr, tags = {}, {}; + for k,v in pairs(stanza.attr) do attr[k] = v; end + local new = { name = stanza.name, attr = attr, tags = tags }; + for i=1,#stanza do + local child = stanza[i]; + if child.name then + child = _clone(child); + t_insert(tags, child); end - local new_table = {}; - lookup_table[object] = new_table; - for index, value in pairs(object) do - new_table[_copy(index)] = _copy(value); - end - return setmetatable(new_table, getmetatable(object)); + t_insert(new, child); end - - return _copy(stanza) + return setmetatable(new, stanza_mt); end +clone = _clone; function message(attr, body) if not body then @@ -375,7 +390,7 @@ if do_pretty_printing then local style_attrv = getstyle("red"); local style_tagname = getstyle("red"); local style_punc = getstyle("magenta"); - + local attr_format = " "..getstring(style_attrk, "%s")..getstring(style_punc, "=")..getstring(style_attrv, "'%s'"); local top_tag_format = getstring(style_punc, "<")..getstring(style_tagname, "%s").."%s"..getstring(style_punc, ">"); --local tag_format = getstring(style_punc, "<")..getstring(style_tagname, "%s").."%s"..getstring(style_punc, ">").."%s"..getstring(style_punc, "</")..getstring(style_tagname, "%s")..getstring(style_punc, ">"); @@ -396,7 +411,7 @@ if do_pretty_printing then end return s_format(tag_format, t.name, attr_string, children_text, t.name); end - + function stanza_mt.pretty_top_tag(t) local attr_string = ""; if t.attr then diff --git a/util/template.lua b/util/template.lua index ebd8be14..66d4fca7 100644 --- a/util/template.lua +++ b/util/template.lua @@ -1,64 +1,28 @@ -local st = require "util.stanza"; -local lxp = require "lxp"; +local stanza_mt = require "util.stanza".stanza_mt; local setmetatable = setmetatable; local pairs = pairs; local ipairs = ipairs; local error = error; local loadstring = loadstring; local debug = debug; +local t_remove = table.remove; +local parse_xml = require "util.xml".parse; module("template") -local parse_xml = (function() - local ns_prefixes = { - ["http://www.w3.org/XML/1998/namespace"] = "xml"; - }; - local ns_separator = "\1"; - local ns_pattern = "^([^"..ns_separator.."]*)"..ns_separator.."?(.*)$"; - return function(xml) - local handler = {}; - local stanza = st.stanza("root"); - function handler:StartElement(tagname, attr) - local curr_ns,name = tagname:match(ns_pattern); - if name == "" then - curr_ns, name = "", curr_ns; - end - if curr_ns ~= "" then - attr.xmlns = curr_ns; - end - for i=1,#attr do - local k = attr[i]; - attr[i] = nil; - local ns, nm = k:match(ns_pattern); - if nm ~= "" then - ns = ns_prefixes[ns]; - if ns then - attr[ns..":"..nm] = attr[k]; - attr[k] = nil; - end - end - end - stanza:tag(name, attr); - end - function handler:CharacterData(data) - data = data:gsub("^%s*", ""):gsub("%s*$", ""); - stanza:text(data); - end - function handler:EndElement(tagname) - stanza:up(); - end - local parser = lxp.new(handler, "\1"); - local ok, err, line, col = parser:parse(xml); - if ok then ok, err, line, col = parser:parse(); end - --parser:close(); - if ok then - return stanza.tags[1]; +local function trim_xml(stanza) + for i=#stanza,1,-1 do + local child = stanza[i]; + if child.name then + trim_xml(child); else - return ok, err.." (line "..line..", col "..col..")"; + child = child:gsub("^%s*", ""):gsub("%s*$", ""); + stanza[i] = child; + if child == "" then t_remove(stanza, i); end end - end; -end)(); + end +end local function create_string_string(str) str = ("%q"):format(str); @@ -100,7 +64,6 @@ local function create_clone_string(stanza, lookup, xmlns) end return lookup[stanza]; end -local stanza_mt = st.stanza_mt; local function create_cloner(stanza, chunkname) local lookup = {}; local name = create_clone_string(stanza, lookup, ""); @@ -118,6 +81,7 @@ local template_mt = { __tostring = function(t) return t.name end }; local function create_template(templates, text) local stanza, err = parse_xml(text); if not stanza then error(err); end + trim_xml(stanza); local info = debug.getinfo(3, "Sl"); info = info and ("template(%s:%d)"):format(info.short_src:match("[^\\/]*$"), info.currentline) or "template(unknown)"; diff --git a/util/termcolours.lua b/util/termcolours.lua index df204688..ef978364 100644 --- a/util/termcolours.lua +++ b/util/termcolours.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -9,6 +9,7 @@ local t_concat, t_insert = table.concat, table.insert; local char, format = string.char, string.format; +local tonumber = tonumber; local ipairs = ipairs; local io_write = io.write; @@ -34,6 +35,15 @@ local winstylemap = { ["1;31"] = 4+8 -- bold red } +local cssmap = { + [1] = "font-weight: bold", [2] = "opacity: 0.5", [4] = "text-decoration: underline", [8] = "visibility: hidden", + [30] = "color:black", [31] = "color:red", [32]="color:green", [33]="color:#FFD700", + [34] = "color:blue", [35] = "color: magenta", [36] = "color:cyan", [37] = "color: white", + [40] = "background-color:black", [41] = "background-color:red", [42]="background-color:green", + [43]="background-color:yellow", [44] = "background-color:blue", [45] = "background-color: magenta", + [46] = "background-color:cyan", [47] = "background-color: white"; +}; + local fmt_string = char(0x1B).."[%sm%s"..char(0x1B).."[0m"; function getstring(style, text) if style then @@ -76,4 +86,17 @@ if windows then end end +local function ansi2css(ansi_codes) + if ansi_codes == "0" then return "</span>"; end + local css = {}; + for code in ansi_codes:gmatch("[^;]+") do + t_insert(css, cssmap[tonumber(code)]); + end + return "</span><span style='"..t_concat(css, ";").."'>"; +end + +function tohtml(input) + return input:gsub("\027%[(.-)m", ansi2css); +end + return _M; diff --git a/util/throttle.lua b/util/throttle.lua new file mode 100644 index 00000000..55e1d07b --- /dev/null +++ b/util/throttle.lua @@ -0,0 +1,46 @@ + +local gettime = require "socket".gettime; +local setmetatable = setmetatable; +local floor = math.floor; + +module "throttle" + +local throttle = {}; +local throttle_mt = { __index = throttle }; + +function throttle:update() + local newt = gettime(); + local elapsed = newt - self.t; + self.t = newt; + local balance = floor(self.rate * elapsed) + self.balance; + if balance > self.max then + self.balance = self.max; + else + self.balance = balance; + end + return self.balance; +end + +function throttle:peek(cost) + cost = cost or 1; + return self.balance >= cost or self:update() >= cost; +end + +function throttle:poll(cost, split) + if self:peek(cost) then + self.balance = self.balance - cost; + return true; + else + local balance = self.balance; + if split then + self.balance = 0; + end + return false, balance, (cost-balance); + end +end + +function create(max, period) + return setmetatable({ rate = max / period, max = max, t = 0, balance = max }, throttle_mt); +end + +return _M; diff --git a/util/timer.lua b/util/timer.lua index 3061da72..0e10e144 100644 --- a/util/timer.lua +++ b/util/timer.lua @@ -1,22 +1,17 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- - -local ns_addtimer = require "net.server".addtimer; -local event = require "net.server".event; -local event_base = require "net.server".event_base; - +local server = require "net.server"; local math_min = math.min local math_huge = math.huge local get_time = require "socket".gettime; local t_insert = table.insert; -local t_remove = table.remove; -local ipairs, pairs = ipairs, pairs; +local pairs = pairs; local type = type; local data = {}; @@ -25,18 +20,21 @@ local new_data = {}; module "timer" local _add_task; -if not event then - function _add_task(delay, func) +if not server.event then + function _add_task(delay, callback) local current_time = get_time(); delay = delay + current_time; if delay >= current_time then - t_insert(new_data, {delay, func}); + t_insert(new_data, {delay, callback}); else - func(); + local r = callback(current_time); + if r and type(r) == "number" then + return _add_task(r, callback); + end end end - ns_addtimer(function() + server._addtimer(function() local current_time = get_time(); if #new_data > 0 then for _, d in pairs(new_data) do @@ -44,15 +42,15 @@ if not event then end new_data = {}; end - + local next_time = math_huge; for i, d in pairs(data) do - local t, func = d[1], d[2]; + local t, callback = d[1], d[2]; if t <= current_time then data[i] = nil; - local r = func(current_time); + local r = callback(current_time); if type(r) == "number" then - _add_task(r, func); + _add_task(r, callback); next_time = math_min(next_time, r); end else @@ -62,11 +60,14 @@ if not event then return next_time; end); else + local event = server.event; + local event_base = server.event_base; local EVENT_LEAVE = (event.core and event.core.LEAVE) or -1; - function _add_task(delay, func) + + function _add_task(delay, callback) local event_handle; event_handle = event_base:addevent(nil, 0, function () - local ret = func(); + local ret = callback(get_time()); if ret then return 0, ret; elseif event_handle then diff --git a/util/uuid.lua b/util/uuid.lua index 796c8ee4..fc487c72 100644 --- a/util/uuid.lua +++ b/util/uuid.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- diff --git a/util/watchdog.lua b/util/watchdog.lua new file mode 100644 index 00000000..bcb2e274 --- /dev/null +++ b/util/watchdog.lua @@ -0,0 +1,34 @@ +local timer = require "util.timer"; +local setmetatable = setmetatable; +local os_time = os.time; + +module "watchdog" + +local watchdog_methods = {}; +local watchdog_mt = { __index = watchdog_methods }; + +function new(timeout, callback) + local watchdog = setmetatable({ timeout = timeout, last_reset = os_time(), callback = callback }, watchdog_mt); + timer.add_task(timeout+1, function (current_time) + local last_reset = watchdog.last_reset; + if not last_reset then + return; + end + local time_left = (last_reset + timeout) - current_time; + if time_left < 0 then + return watchdog:callback(); + end + return time_left + 1; + end); + return watchdog; +end + +function watchdog_methods:reset() + self.last_reset = os_time(); +end + +function watchdog_methods:cancel() + self.last_reset = nil; +end + +return _M; diff --git a/util/x509.lua b/util/x509.lua index 11f231a0..19d4ec6d 100644 --- a/util/x509.lua +++ b/util/x509.lua @@ -11,8 +11,8 @@ -- IDN libraries complicate that. --- [TLS-CERTS] - http://tools.ietf.org/html/draft-saintandre-tls-server-id-check-10 --- [XMPP-CORE] - http://tools.ietf.org/html/draft-ietf-xmpp-3920bis-18 +-- [TLS-CERTS] - http://tools.ietf.org/html/rfc6125 +-- [XMPP-CORE] - http://tools.ietf.org/html/rfc6120 -- [SRV-ID] - http://tools.ietf.org/html/rfc4985 -- [IDNA] - http://tools.ietf.org/html/rfc5890 -- [LDAP] - http://tools.ietf.org/html/rfc4519 @@ -21,6 +21,10 @@ local nameprep = require "util.encodings".stringprep.nameprep; local idna_to_ascii = require "util.encodings".idna.to_ascii; local log = require "util.logger".init("x509"); +local pairs, ipairs = pairs, ipairs; +local s_format = string.format; +local t_insert = table.insert; +local t_concat = table.concat; module "x509" @@ -32,7 +36,7 @@ local oid_dnssrv = "1.3.6.1.5.5.7.8.7"; -- [SRV-ID] -- Compare a hostname (possibly international) with asserted names -- extracted from a certificate. -- This function follows the rules laid out in --- sections 4.4.1 and 4.4.2 of [TLS-CERTS] +-- sections 6.4.1 and 6.4.2 of [TLS-CERTS] -- -- A wildcard ("*") all by itself is allowed only as the left-most label local function compare_dnsname(host, asserted_names) @@ -150,7 +154,7 @@ function verify_identity(host, service, cert) if ext[oid_subjectaltname] then local sans = ext[oid_subjectaltname]; - -- Per [TLS-CERTS] 4.3, 4.4.4, "a client MUST NOT seek a match for a + -- Per [TLS-CERTS] 6.3, 6.4.4, "a client MUST NOT seek a match for a -- reference identifier if the presented identifiers include a DNS-ID -- SRV-ID, URI-ID, or any application-specific identifier types" local had_supported_altnames = false @@ -183,7 +187,7 @@ function verify_identity(host, service, cert) -- a dNSName subjectAltName (wildcards may apply for, and receive, -- cat treats) -- - -- Per [TLS-CERTS] 1.5, a CN-ID is the Common Name from a cert subject + -- Per [TLS-CERTS] 1.8, a CN-ID is the Common Name from a cert subject -- which has one and only one Common Name local subject = cert:subject() local cn = nil @@ -200,7 +204,7 @@ function verify_identity(host, service, cert) end if cn then - -- Per [TLS-CERTS] 4.4.4, follow the comparison rules for dNSName SANs. + -- Per [TLS-CERTS] 6.4.4, follow the comparison rules for dNSName SANs. return compare_dnsname(host, { cn }) end diff --git a/util/xml.lua b/util/xml.lua new file mode 100644 index 00000000..6dbed65d --- /dev/null +++ b/util/xml.lua @@ -0,0 +1,57 @@ + +local st = require "util.stanza"; +local lxp = require "lxp"; + +module("xml") + +local parse_xml = (function() + local ns_prefixes = { + ["http://www.w3.org/XML/1998/namespace"] = "xml"; + }; + local ns_separator = "\1"; + local ns_pattern = "^([^"..ns_separator.."]*)"..ns_separator.."?(.*)$"; + return function(xml) + local handler = {}; + local stanza = st.stanza("root"); + function handler:StartElement(tagname, attr) + local curr_ns,name = tagname:match(ns_pattern); + if name == "" then + curr_ns, name = "", curr_ns; + end + if curr_ns ~= "" then + attr.xmlns = curr_ns; + end + for i=1,#attr do + local k = attr[i]; + attr[i] = nil; + local ns, nm = k:match(ns_pattern); + if nm ~= "" then + ns = ns_prefixes[ns]; + if ns then + attr[ns..":"..nm] = attr[k]; + attr[k] = nil; + end + end + end + stanza:tag(name, attr); + end + function handler:CharacterData(data) + stanza:text(data); + end + function handler:EndElement(tagname) + stanza:up(); + end + local parser = lxp.new(handler, "\1"); + local ok, err, line, col = parser:parse(xml); + if ok then ok, err, line, col = parser:parse(); end + --parser:close(); + if ok then + return stanza.tags[1]; + else + return ok, err.." (line "..line..", col "..col..")"; + end + end; +end)(); + +parse = parse_xml; +return _M; diff --git a/util/xmlrpc.lua b/util/xmlrpc.lua deleted file mode 100644 index 29815b0d..00000000 --- a/util/xmlrpc.lua +++ /dev/null @@ -1,182 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - - -local pairs = pairs; -local type = type; -local error = error; -local t_concat = table.concat; -local t_insert = table.insert; -local tostring = tostring; -local tonumber = tonumber; -local select = select; -local st = require "util.stanza"; - -module "xmlrpc" - -local _lua_to_xmlrpc; -local map = { - table=function(stanza, object) - stanza:tag("struct"); - for name, value in pairs(object) do - stanza:tag("member"); - stanza:tag("name"):text(tostring(name)):up(); - stanza:tag("value"); - _lua_to_xmlrpc(stanza, value); - stanza:up(); - stanza:up(); - end - stanza:up(); - end; - boolean=function(stanza, object) - stanza:tag("boolean"):text(object and "1" or "0"):up(); - end; - string=function(stanza, object) - stanza:tag("string"):text(object):up(); - end; - number=function(stanza, object) - stanza:tag("int"):text(tostring(object)):up(); - end; - ["nil"]=function(stanza, object) -- nil extension - stanza:tag("nil"):up(); - end; -}; -_lua_to_xmlrpc = function(stanza, object) - local h = map[type(object)]; - if h then - h(stanza, object); - else - error("Type not supported by XML-RPC: " .. type(object)); - end -end -function create_response(object) - local stanza = st.stanza("methodResponse"):tag("params"):tag("param"):tag("value"); - _lua_to_xmlrpc(stanza, object); - stanza:up():up():up(); - return stanza; -end -function create_error_response(faultCode, faultString) - local stanza = st.stanza("methodResponse"):tag("fault"):tag("value"); - _lua_to_xmlrpc(stanza, {faultCode=faultCode, faultString=faultString}); - stanza:up():up(); - return stanza; -end - -function create_request(method_name, ...) - local stanza = st.stanza("methodCall") - :tag("methodName"):text(method_name):up() - :tag("params"); - for i=1,select('#', ...) do - stanza:tag("param"):tag("value"); - _lua_to_xmlrpc(stanza, select(i, ...)); - stanza:up():up(); - end - stanza:up():up():up(); - return stanza; -end - -local _xmlrpc_to_lua; -local int_parse = function(stanza) - if #stanza.tags ~= 0 or #stanza == 0 then error("<"..stanza.name.."> must have a single text child"); end - local n = tonumber(t_concat(stanza)); - if n then return n; end - error("Failed to parse content of <"..stanza.name..">"); -end -local rmap = { - methodCall=function(stanza) - if #stanza.tags ~= 2 then error("<methodCall> must have exactly two subtags"); end -- FIXME <params> is optional - if stanza.tags[1].name ~= "methodName" then error("First <methodCall> child tag must be <methodName>") end - if stanza.tags[2].name ~= "params" then error("Second <methodCall> child tag must be <params>") end - return _xmlrpc_to_lua(stanza.tags[1]), _xmlrpc_to_lua(stanza.tags[2]); - end; - methodName=function(stanza) - if #stanza.tags ~= 0 then error("<methodName> must not have any subtags"); end - if #stanza == 0 then error("<methodName> must have text content"); end - return t_concat(stanza); - end; - params=function(stanza) - local t = {}; - for _, child in pairs(stanza.tags) do - if child.name ~= "param" then error("<params> can only have <param> children"); end; - t_insert(t, _xmlrpc_to_lua(child)); - end - return t; - end; - param=function(stanza) - if not(#stanza.tags == 1 and stanza.tags[1].name == "value") then error("<param> must have exactly one <value> child"); end - return _xmlrpc_to_lua(stanza.tags[1]); - end; - value=function(stanza) - if #stanza.tags == 0 then return t_concat(stanza); end - if #stanza.tags ~= 1 then error("<value> must have a single child"); end - return _xmlrpc_to_lua(stanza.tags[1]); - end; - int=int_parse; - i4=int_parse; - double=int_parse; - boolean=function(stanza) - if #stanza.tags ~= 0 or #stanza == 0 then error("<boolean> must have a single text child"); end - local b = t_concat(stanza); - if b ~= "1" and b ~= "0" then error("Failed to parse content of <boolean>"); end - return b == "1" and true or false; - end; - string=function(stanza) - if #stanza.tags ~= 0 then error("<string> must have a single text child"); end - return t_concat(stanza); - end; - array=function(stanza) - if #stanza.tags ~= 1 then error("<array> must have a single <data> child"); end - return _xmlrpc_to_lua(stanza.tags[1]); - end; - data=function(stanza) - local t = {}; - for _,child in pairs(stanza.tags) do - if child.name ~= "value" then error("<data> can only have <value> children"); end - t_insert(t, _xmlrpc_to_lua(child)); - end - return t; - end; - struct=function(stanza) - local t = {}; - for _,child in pairs(stanza.tags) do - if child.name ~= "member" then error("<struct> can only have <member> children"); end - local name, value = _xmlrpc_to_lua(child); - t[name] = value; - end - return t; - end; - member=function(stanza) - if #stanza.tags ~= 2 then error("<member> must have exactly two subtags"); end -- FIXME <params> is optional - if stanza.tags[1].name ~= "name" then error("First <member> child tag must be <name>") end - if stanza.tags[2].name ~= "value" then error("Second <member> child tag must be <value>") end - return _xmlrpc_to_lua(stanza.tags[1]), _xmlrpc_to_lua(stanza.tags[2]); - end; - name=function(stanza) - if #stanza.tags ~= 0 then error("<name> must have a single text child"); end - local n = t_concat(stanza) - if tostring(tonumber(n)) == n then n = tonumber(n); end - return n; - end; - ["nil"]=function(stanza) -- nil extension - return nil; - end; -} -_xmlrpc_to_lua = function(stanza) - local h = rmap[stanza.name]; - if h then - return h(stanza); - else - error("Unknown element: "..stanza.name); - end -end -function translate_request(stanza) - if stanza.name ~= "methodCall" then error("XML-RPC requests must have <methodCall> as root element"); end - return _xmlrpc_to_lua(stanza); -end - -return _M; diff --git a/util/xmppstream.lua b/util/xmppstream.lua index cbdadd9b..550170c9 100644 --- a/util/xmppstream.lua +++ b/util/xmppstream.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -9,21 +9,27 @@ local lxp = require "lxp"; local st = require "util.stanza"; +local stanza_mt = st.stanza_mt; +local error = error; local tostring = tostring; local t_insert = table.insert; local t_concat = table.concat; +local t_remove = table.remove; +local setmetatable = setmetatable; -local default_log = require "util.logger".init("xmppstream"); - -local error = error; +-- COMPAT: w/LuaExpat 1.1.0 +local lxp_supports_doctype = pcall(lxp.new, { StartDoctypeDecl = false }); module "xmppstream" local new_parser = lxp.new; -local ns_prefixes = { - ["http://www.w3.org/XML/1998/namespace"] = "xml"; +local xml_namespace = { + ["http://www.w3.org/XML/1998/namespace\1lang"] = "xml:lang"; + ["http://www.w3.org/XML/1998/namespace\1space"] = "xml:space"; + ["http://www.w3.org/XML/1998/namespace\1base"] = "xml:base"; + ["http://www.w3.org/XML/1998/namespace\1id"] = "xml:id"; }; local xmlns_streams = "http://etherx.jabber.org/streams"; @@ -36,29 +42,28 @@ _M.ns_pattern = ns_pattern; function new_sax_handlers(session, stream_callbacks) local xml_handlers = {}; - - local log = session.log or default_log; - + local cb_streamopened = stream_callbacks.streamopened; local cb_streamclosed = stream_callbacks.streamclosed; - local cb_error = stream_callbacks.error or function(session, e) error("XML stream error: "..tostring(e)); end; + local cb_error = stream_callbacks.error or function(session, e, stanza) error("XML stream error: "..tostring(e)..(stanza and ": "..tostring(stanza) or ""),2); end; local cb_handlestanza = stream_callbacks.handlestanza; - + local stream_ns = stream_callbacks.stream_ns or xmlns_streams; local stream_tag = stream_callbacks.stream_tag or "stream"; if stream_ns ~= "" then stream_tag = stream_ns..ns_separator..stream_tag; end local stream_error_tag = stream_ns..ns_separator..(stream_callbacks.error_tag or "error"); - + local stream_default_ns = stream_callbacks.default_ns; - + + local stack = {}; local chardata, stanza = {}; local non_streamns_depth = 0; function xml_handlers:StartElement(tagname, attr) if stanza and #chardata > 0 then -- We have some character data in the buffer - stanza:text(t_concat(chardata)); + t_insert(stanza, t_concat(chardata)); chardata = {}; end local curr_ns,name = tagname:match(ns_pattern); @@ -70,21 +75,17 @@ function new_sax_handlers(session, stream_callbacks) attr.xmlns = curr_ns; non_streamns_depth = non_streamns_depth + 1; end - - -- FIXME !!!!! + for i=1,#attr do local k = attr[i]; attr[i] = nil; - local ns, nm = k:match(ns_pattern); - if nm ~= "" then - ns = ns_prefixes[ns]; - if ns then - attr[ns..":"..nm] = attr[k]; - attr[k] = nil; - end + local xmlk = xml_namespace[k]; + if xmlk then + attr[xmlk] = attr[k]; + attr[k] = nil; end end - + if not stanza then --if we are not currently inside a stanza if session.notopen then if tagname == stream_tag then @@ -101,10 +102,14 @@ function new_sax_handlers(session, stream_callbacks) if curr_ns == "jabber:client" and name ~= "iq" and name ~= "presence" and name ~= "message" then cb_error(session, "invalid-top-level-element"); end - - stanza = st.stanza(name, attr); + + stanza = setmetatable({ name = name, attr = attr, tags = {} }, stanza_mt); else -- we are inside a stanza, so add a tag - stanza:tag(name, attr); + t_insert(stack, stanza); + local oldstanza = stanza; + stanza = setmetatable({ name = name, attr = attr, tags = {} }, stanza_mt); + t_insert(oldstanza, stanza); + t_insert(oldstanza.tags, stanza); end end function xml_handlers:CharacterData(data) @@ -119,12 +124,11 @@ function new_sax_handlers(session, stream_callbacks) if stanza then if #chardata > 0 then -- We have some character data in the buffer - stanza:text(t_concat(chardata)); + t_insert(stanza, t_concat(chardata)); chardata = {}; end -- Complete stanza - local last_add = stanza.last_add; - if not last_add or #last_add == 0 then + if #stack == 0 then if tagname ~= stream_error_tag then cb_handlestanza(session, stanza); else @@ -132,33 +136,37 @@ function new_sax_handlers(session, stream_callbacks) end stanza = nil; else - stanza:up(); + stanza = t_remove(stack); end else - if tagname == stream_tag then - if cb_streamclosed then - cb_streamclosed(session); - end - else - local curr_ns,name = tagname:match(ns_pattern); - if name == "" then - curr_ns, name = "", curr_ns; - end - cb_error(session, "parse-error", "unexpected-element-close", name); + if cb_streamclosed then + cb_streamclosed(session); end - stanza, chardata = nil, {}; end end - + + local function restricted_handler(parser) + cb_error(session, "parse-error", "restricted-xml", "Restricted XML, see RFC 6120 section 11.1."); + if not parser.stop or not parser:stop() then + error("Failed to abort parsing"); + end + end + + if lxp_supports_doctype then + xml_handlers.StartDoctypeDecl = restricted_handler; + end + xml_handlers.Comment = restricted_handler; + xml_handlers.ProcessingInstruction = restricted_handler; + local function reset() stanza, chardata = nil, {}; + stack = {}; end - + local function set_session(stream, new_session) session = new_session; - log = new_session.log or default_log; end - + return xml_handlers, { reset = reset, set_session = set_session }; end |