diff options
45 files changed, 1848 insertions, 355 deletions
@@ -55,7 +55,7 @@ util/%.so: $(MAKE) install -C util-src %.install: % - sed "1s/\blua\b/$(RUNWITH)/; \ + sed "1s| lua$$| $(RUNWITH)|; \ s|^CFG_SOURCEDIR=.*;$$|CFG_SOURCEDIR='$(INSTALLEDSOURCE)';|; \ s|^CFG_CONFIGDIR=.*;$$|CFG_CONFIGDIR='$(INSTALLEDCONFIG)';|; \ s|^CFG_DATADIR=.*;$$|CFG_DATADIR='$(INSTALLEDDATA)';|; \ diff --git a/core/certmanager.lua b/core/certmanager.lua index 9dfb8f3a..d6a59b9f 100644 --- a/core/certmanager.lua +++ b/core/certmanager.lua @@ -15,9 +15,11 @@ local tostring = tostring; local pairs = pairs; local type = type; local io_open = io.open; +local t_concat = table.concat; +local t_insert = table.insert; local prosody = prosody; -local resolve_path = configmanager.resolve_relative_path; +local resolve_path = require"util.paths".resolve_relative_path; local config_path = prosody.paths.config; local luasec_has_noticket, luasec_has_verifyext, luasec_has_no_compression; @@ -33,11 +35,19 @@ module "certmanager" -- Global SSL options if not overridden per-host local global_ssl_config = configmanager.get("*", "ssl"); +-- Built-in defaults local core_defaults = { capath = "/etc/ssl/certs"; - protocol = "sslv23"; + protocol = "tlsv1+"; verify = (ssl and ssl.x509 and { "peer", "client_once", }) or "none"; - options = { "no_sslv2", "no_sslv3", "cipher_server_preference", luasec_has_noticket and "no_ticket" or nil }; + options = { + cipher_server_preference = true; + no_ticket = luasec_has_noticket; + no_compression = luasec_has_no_compression and configmanager.get("*", "ssl_compression") ~= true; + -- Has no_compression? Then it has these too... + single_dh_use = luasec_has_no_compression; + single_ecdh_use = luasec_has_no_compression; + }; verifyext = { "lsec_continue", "lsec_ignore_purpose" }; curve = "secp384r1"; ciphers = "HIGH+kEDH:HIGH+kEECDH:HIGH:!PSK:!SRP:!3DES:!aNULL"; @@ -45,6 +55,9 @@ local core_defaults = { local path_options = { -- These we pass through resolve_path() key = true, certificate = true, cafile = true, capath = true, dhparam = true } +local set_options = { + options = true, verify = true, verifyext = true +} if ssl and not luasec_has_verifyext and ssl.x509 then -- COMPAT mw/luasec-hg @@ -53,14 +66,21 @@ if ssl and not luasec_has_verifyext and ssl.x509 then end end -if luasec_has_no_compression then -- Has no_compression? Then it has these too... - core_defaults.options[#core_defaults.options+1] = "single_dh_use"; - core_defaults.options[#core_defaults.options+1] = "single_ecdh_use"; - if configmanager.get("*", "ssl_compression") ~= true then - core_defaults.options[#core_defaults.options+1] = "no_compression"; +local function merge_set(t, o) + if type(t) ~= "table" then t = { t } end + for k,v in pairs(t) do + if v == true or v == false then + o[k] = v; + else + o[v] = true; + end end + return o; end +local protocols = { "sslv2", "sslv3", "tlsv1", "tlsv1_1", "tlsv1_2" }; +for i = 1, #protocols do protocols[protocols[i] .. "+"] = i - 1; end + function create_context(host, mode, user_ssl_config) user_ssl_config = user_ssl_config or {} user_ssl_config.mode = mode; @@ -69,25 +89,61 @@ function create_context(host, mode, user_ssl_config) if global_ssl_config then for option,default_value in pairs(global_ssl_config) do - if not user_ssl_config[option] then + if user_ssl_config[option] == nil then user_ssl_config[option] = default_value; end end end + for option,default_value in pairs(core_defaults) do - if not user_ssl_config[option] then + if user_ssl_config[option] == nil then user_ssl_config[option] = default_value; end end - user_ssl_config.password = user_ssl_config.password or function() log("error", "Encrypted certificate for %s requires 'ssl' 'password' to be set in config", host); end; + + for option in pairs(set_options) do + local merged = {}; + merge_set(core_defaults[option], merged); + if global_ssl_config then + merge_set(global_ssl_config[option], merged); + end + merge_set(user_ssl_config[option], merged); + local final_array = {}; + for opt, enable in pairs(merged) do + if enable then + final_array[#final_array+1] = opt; + end + end + user_ssl_config[option] = final_array; + end + + local min_protocol = protocols[user_ssl_config.protocol]; + if min_protocol then + user_ssl_config.protocol = "sslv23"; + for i = 1, min_protocol do + t_insert(user_ssl_config.options, "no_"..protocols[i]); + end + end + + -- We can't read the password interactively when daemonized + user_ssl_config.password = user_ssl_config.password or + function() log("error", "Encrypted certificate for %s requires 'ssl' 'password' to be set in config", host); end; + for option in pairs(path_options) do if type(user_ssl_config[option]) == "string" then user_ssl_config[option] = resolve_path(config_path, user_ssl_config[option]); end end - if not user_ssl_config.key then return nil, "No key present in SSL/TLS configuration for "..host; end - if not user_ssl_config.certificate then return nil, "No certificate present in SSL/TLS configuration for "..host; end + -- Allow the cipher list to be a table + if type(user_ssl_config.ciphers) == "table" then + user_ssl_config.ciphers = t_concat(user_ssl_config.ciphers, ":") + end + + if mode == "server" then + if not user_ssl_config.key then return nil, "No key present in SSL/TLS configuration for "..host; end + if not user_ssl_config.certificate then return nil, "No certificate present in SSL/TLS configuration for "..host; end + end -- LuaSec expects dhparam to be a callback that takes two arguments. -- We ignore those because it is mostly used for having a separate @@ -141,6 +197,9 @@ end function reload_ssl_config() global_ssl_config = configmanager.get("*", "ssl"); + if luasec_has_no_compression then + core_defaults.options.no_compression = configmanager.get("*", "ssl_compression") ~= true; + end end prosody.events.add_handler("config-reloaded", reload_ssl_config); diff --git a/core/configmanager.lua b/core/configmanager.lua index d92120d0..1f7342b2 100644 --- a/core/configmanager.lua +++ b/core/configmanager.lua @@ -14,11 +14,15 @@ local format, math_max = string.format, math.max; local fire_event = prosody and prosody.events.fire_event or function () end; local envload = require"util.envload".envload; -local lfs = require "lfs"; +local deps = require"util.dependencies"; +local resolve_relative_path = require"util.paths".resolve_relative_path; +local glob_to_pattern = require"util.paths".glob_to_pattern; local path_sep = package.config:sub(1,1); module "configmanager" +_M.resolve_relative_path = resolve_relative_path; -- COMPAT + local parsers = {}; local config_mt = { __index = function (t, k) return rawget(t, "*"); end}; @@ -66,41 +70,6 @@ function _M.set(host, key, value, _oldvalue) return set(config, host, key, value); end --- Helper function to resolve relative paths (needed by config) -do - function resolve_relative_path(parent_path, path) - if path then - -- Some normalization - parent_path = parent_path:gsub("%"..path_sep.."+$", ""); - path = path:gsub("^%.%"..path_sep.."+", ""); - - local is_relative; - if path_sep == "/" and path:sub(1,1) ~= "/" then - is_relative = true; - elseif path_sep == "\\" and (path:sub(1,1) ~= "/" and (path:sub(2,3) ~= ":\\" and path:sub(2,3) ~= ":/")) then - is_relative = true; - end - if is_relative then - return parent_path..path_sep..path; - end - end - return path; - end -end - --- Helper function to convert a glob to a Lua pattern -local function glob_to_pattern(glob) - return "^"..glob:gsub("[%p*?]", function (c) - if c == "*" then - return ".*"; - elseif c == "?" then - return "."; - else - return "%"..c; - end - end).."$"; -end - function load(filename, format) format = format or filename:match("%w+$"); @@ -214,6 +183,10 @@ do function env.Include(file) if file:match("[*?]") then + local lfs = deps.softreq "lfs"; + if not lfs then + error(format("Error expanding wildcard pattern in Include %q - LuaFileSystem not available", file)); + end local path_pos, glob = file:match("()([^"..path_sep.."]+)$"); local path = file:sub(1, math_max(path_pos-2,0)); local config_path = config_file:gsub("[^"..path_sep.."]+$", ""); diff --git a/core/moduleapi.lua b/core/moduleapi.lua index 65e00d41..30d28418 100644 --- a/core/moduleapi.lua +++ b/core/moduleapi.lua @@ -13,11 +13,14 @@ local set = require "util.set"; local logger = require "util.logger"; local pluginloader = require "util.pluginloader"; local timer = require "util.timer"; +local resolve_relative_path = require"util.paths".resolve_relative_path; local t_insert, t_remove, t_concat = table.insert, table.remove, table.concat; local error, setmetatable, type = error, setmetatable, type; -local ipairs, pairs, select, unpack = ipairs, pairs, select, unpack; +local ipairs, pairs, select = ipairs, pairs, select; local tonumber, tostring = tonumber, tostring; +local pack = table.pack or function(...) return {n=select("#",...), ...}; end -- table.pack is only in 5.2 +local unpack = table.unpack or unpack; -- renamed in 5.2 local prosody = prosody; local hosts = prosody.hosts; @@ -347,11 +350,29 @@ function api:send(stanza) return core_post_stanza(hosts[self.host], stanza); end -function api:add_timer(delay, callback) - return timer.add_task(delay, function (t) - if self.loaded == false then return; end - return callback(t); - end); +local timer_methods = { } +local timer_mt = { + __index = timer_methods; +} +function timer_methods:stop( ) + timer.stop(self.id); +end +timer_methods.disarm = timer_methods.stop +function timer_methods:reschedule(delay) + timer.reschedule(self.id, delay) +end + +local function timer_callback(now, id, t) + if t.module_env.loaded == false then return; end + return t.callback(now, unpack(t, 1, t.n)); +end + +function api:add_timer(delay, callback, ...) + local t = pack(...) + t.module_env = self; + t.callback = callback; + t.id = timer.add_task(delay, timer_callback, t); + return setmetatable(t, timer_mt); end local path_sep = package.config:sub(1,1); @@ -360,7 +381,7 @@ function api:get_directory() end function api:load_resource(path, mode) - path = config.resolve_relative_path(self:get_directory(), path); + path = resolve_relative_path(self:get_directory(), path); return io.open(path, mode); end diff --git a/core/modulemanager.lua b/core/modulemanager.lua index 2e488fd5..eb1ce733 100644 --- a/core/modulemanager.lua +++ b/core/modulemanager.lua @@ -30,7 +30,7 @@ pcall = function(f, ...) end local autoload_modules = {prosody.platform, "presence", "message", "iq", "offline", "c2s", "s2s"}; -local component_inheritable_modules = {"tls", "dialback", "iq", "s2s"}; +local component_inheritable_modules = {"tls", "saslauth", "dialback", "iq", "s2s"}; -- We need this to let modules access the real global namespace local _G = _G; diff --git a/core/portmanager.lua b/core/portmanager.lua index 95900c08..4cbf3eb3 100644 --- a/core/portmanager.lua +++ b/core/portmanager.lua @@ -29,6 +29,8 @@ if socket.tcp6 and config.get("*", "use_ipv6") ~= false then table.insert(default_local_interfaces, "::1"); end +local default_mode = config.get("*", "network_default_read_size") or 4096; + --- Private state -- service_name -> { service_info, ... } @@ -111,7 +113,7 @@ function activate(service_name) } bind_ports = set.new(type(bind_ports) ~= "table" and { bind_ports } or bind_ports ); - local mode, ssl = listener.default_mode or "*a"; + local mode, ssl = listener.default_mode or default_mode; local hooked_ports = {}; for interface in bind_interfaces do diff --git a/net/http.lua b/net/http.lua index ab9ec7b6..b87c9396 100644 --- a/net/http.lua +++ b/net/http.lua @@ -6,7 +6,6 @@ -- COPYING file in the source package for more information. -- -local socket = require "socket" local b64 = require "util.encodings".base64.encode; local url = require "socket.url" local httpstream_new = require "net.http.parser".new; @@ -160,21 +159,17 @@ function request(u, ex, callback) end local port_number = port and tonumber(port) or (using_https and 443 or 80); - -- Connect the socket, and wrap it with net.server - local conn = socket.tcp(); - conn:settimeout(10); - local ok, err = conn:connect(host, port_number); - if not ok and err ~= "timeout" then - callback(nil, 0, req); - return nil, err; - end - local sslctx = false; if using_https then sslctx = ex and ex.sslctx or { mode = "client", protocol = "sslv23", options = { "no_sslv2" } }; end - req.handler, req.conn = assert(server.wrapclient(conn, host, port_number, listener, "*a", sslctx)); + local handler, conn = server.addclient(host, port_number, listener, "*a", sslctx) + if not handler then + callback(nil, 0, req); + return nil, conn; + end + req.handler, req.conn = handler, conn req.write = function (...) return req.handler:write(...); end req.callback = function (content, code, request, response) log("debug", "Calling callback, status %s", code or "---"); return select(2, xpcall(function () return callback(content, code, request, response) end, handleerr)); end diff --git a/net/http/server.lua b/net/http/server.lua index 5961169f..510b77fb 100644 --- a/net/http/server.lua +++ b/net/http/server.lua @@ -185,6 +185,7 @@ function handle_request(conn, request, finish_cb) persistent = persistent; conn = conn; send = _M.send_response; + done = _M.finish_response; finish_cb = finish_cb; }; conn._http_open_response = response; @@ -246,24 +247,30 @@ function handle_request(conn, request, finish_cb) response.status_code = 404; response:send(events.fire_event("http-error", { code = 404 })); end -function _M.send_response(response, body) - if response.finished then return; end - response.finished = true; - response.conn._http_open_response = nil; - +local function prepare_header(response) local status_line = "HTTP/"..response.request.httpversion.." "..(response.status or codes[response.status_code]); local headers = response.headers; - body = body or response.body or ""; - headers.content_length = #body; - local output = { status_line }; for k,v in pairs(headers) do t_insert(output, headerfix[k]..v); end t_insert(output, "\r\n\r\n"); + return output; +end +_M.prepare_header = prepare_header; +function _M.send_response(response, body) + if response.finished then return; end + body = body or response.body or ""; + response.headers.content_length = #body; + local output = prepare_header(response); t_insert(output, body); - response.conn:write(t_concat(output)); + response:done(); +end +function _M.finish_response(response) + if response.finished then return; end + response.finished = true; + response.conn._http_open_response = nil; if response.on_destroy then response:on_destroy(); response.on_destroy = nil; diff --git a/net/server_event.lua b/net/server_event.lua index 59217a0c..a3087847 100644 --- a/net/server_event.lua +++ b/net/server_event.lua @@ -44,8 +44,9 @@ local setmetatable = use "setmetatable" local t_insert = table.insert local t_concat = table.concat -local ssl = use "ssl" +local has_luasec, ssl = pcall ( require , "ssl" ) local socket = use "socket" or require "socket" +local getaddrinfo = socket.dns.getaddrinfo local log = require ("util.logger").init("socket") @@ -128,7 +129,7 @@ do return self:_destroy(); end - function interface_mt:_start_connection(plainssl) -- should be called from addclient + function interface_mt:_start_connection(plainssl) -- called from wrapclient local callback = function( event ) if EV_TIMEOUT == event then -- timeout during connection self.fatalerror = "connection timeout" @@ -136,7 +137,7 @@ do self:_close() debug( "new connection failed. id:", self.id, "error:", self.fatalerror ) else - if plainssl and ssl then -- start ssl session + if plainssl and has_luasec then -- start ssl session self:starttls(self._sslctx, true) else -- normal connection self:_start_session(true) @@ -367,6 +368,7 @@ do function interface_mt:ssl() return self._usingssl end + interface_mt.clientport = interface_mt.port -- COMPAT server_select function interface_mt:type() return self._type or "client" @@ -506,7 +508,7 @@ do _sslctx = sslctx; -- parameters _usingssl = false; -- client is using ssl; } - if not ssl then interface.starttls = false; end + if not has_luasec then interface.starttls = false; end interface.id = tostring(interface):match("%x+$"); interface.writecallback = function( event ) -- called on write events --vdebug( "new client write event, id/ip/port:", interface, ip, port ) @@ -689,7 +691,7 @@ do interface._connections = interface._connections + 1 -- increase connection count local clientinterface = handleclient( client, client_ip, client_port, interface, pattern, listener, sslctx ) --vdebug( "client id:", clientinterface, "startssl:", startssl ) - if ssl and sslctx then + if has_luasec and sslctx then clientinterface:starttls(sslctx, true) else clientinterface:_start_session( true ) @@ -710,25 +712,17 @@ do end local addserver = ( function( ) - return function( addr, port, listener, pattern, sslcfg, startssl ) -- TODO: check arguments - --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslcfg or "nil", startssl or "nil") + return function( addr, port, listener, pattern, sslctx, startssl ) -- TODO: check arguments + --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslctx or "nil", startssl or "nil") + if sslctx and not has_luasec then + debug "fatal error: luasec not found" + return nil, "luasec not found" + end local server, err = socket.bind( addr, port, cfg.ACCEPT_QUEUE ) -- create server socket if not server then debug( "creating server socket on "..addr.." port "..port.." failed:", err ) return nil, err end - local sslctx - if sslcfg then - if not ssl then - debug "fatal error: luasec not found" - return nil, "luasec not found" - end - sslctx, err = sslcfg - if err then - debug( "error while creating new ssl context for server socket:", err ) - return nil, err - end - end local interface = handleserver( server, addr, port, pattern, listener, sslctx, startssl ) -- new server handler debug( "new server created with id:", tostring(interface)) return interface @@ -744,37 +738,34 @@ do --function handleclient( client, ip, port, server, pattern, listener, _, sslctx ) -- creates an client interface end - function addclient( addr, serverport, listener, pattern, localaddr, localport, sslcfg, startssl ) - local client, err = socket.tcp() -- creating new socket + function addclient( addr, serverport, listener, pattern, sslctx, typ ) + if sslctx and not has_luasec then + debug "need luasec, but not available" + return nil, "luasec not found" + end + if getaddrinfo and not typ then + local addrinfo, err = getaddrinfo(addr) + if not addrinfo then return nil, err end + if addrinfo[1] and addrinfo[1].family == "inet6" then + typ = "tcp6" + end + end + local create = socket[typ or "tcp"] + if type( create ) ~= "function" then + return nil, "invalid socket type" + end + local client, err = create() -- creating new socket if not client then debug( "cannot create socket:", err ) return nil, err end client:settimeout( 0 ) -- set nonblocking - if localaddr then - local res, err = client:bind( localaddr, localport, -1 ) - if not res then - debug( "cannot bind client:", err ) - return nil, err - end - end - local sslctx - if sslcfg then -- handle ssl/new context - if not ssl then - debug "need luasec, but not available" - return nil, "luasec not found" - end - sslctx, err = sslcfg - if err then - debug( "cannot create new ssl context:", err ) - return nil, err - end - end local res, err = client:connect( addr, serverport ) -- connect if res or ( err == "timeout" ) then - local ip, port = client:getsockname( ) - local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx, startssl ) - interface:_start_connection( startssl ) + if client.getsockname then + addr = client:getsockname( ) + end + local interface = wrapclient( client, addr, serverport, listener, pattern, sslctx ) debug( "new connection id:", interface.id ) return interface, err else diff --git a/net/server_select.lua b/net/server_select.lua index c5e0772f..4a36617c 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -48,13 +48,14 @@ local coroutine_yield = coroutine.yield --// extern libs //-- -local luasec = use "ssl" +local has_luasec, luasec = pcall ( require , "ssl" ) local luasocket = use "socket" or require "socket" local luasocket_gettime = luasocket.gettime +local getaddrinfo = luasocket.dns.getaddrinfo --// extern lib methods //-- -local ssl_wrap = ( luasec and luasec.wrap ) +local ssl_wrap = ( has_luasec and luasec.wrap ) local socket_bind = luasocket.bind local socket_sleep = luasocket.sleep local socket_select = luasocket.select @@ -401,6 +402,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport handler.clientport = function( ) return clientport end + handler.port = handler.clientport -- COMPAT server_event local write = function( self, data ) bufferlen = bufferlen + #data if bufferlen > maxsendlen then @@ -585,7 +587,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport end ) end - if luasec then + if has_luasec then handler.starttls = function( self, _sslctx) if _sslctx then handler:set_sslctx(_sslctx); @@ -638,7 +640,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport _socketlist[ socket ] = handler _readlistlen = addsocket(_readlist, socket, _readlistlen) - if sslctx and luasec then + if sslctx and has_luasec then out_put "server.lua: auto-starting ssl negotiation..." handler.autostart_ssl = true; local ok, err = handler:starttls(sslctx); @@ -713,22 +715,23 @@ end ----------------------------------// PUBLIC //-- addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server + addr = addr or "*" local err if type( listeners ) ~= "table" then err = "invalid listener table" - end - if type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then + elseif type ( addr ) ~= "string" then + err = "invalid address" + elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then err = "invalid port" elseif _server[ addr..":"..port ] then err = "listeners on '[" .. addr .. "]:" .. port .. "' already exist" - elseif sslctx and not luasec then + elseif sslctx and not has_luasec then err = "luasec not found" end if err then out_error( "server.lua, [", addr, "]:", port, ": ", err ) return nil, err end - addr = addr or "*" local server, err = socket_bind( addr, port, _tcpbacklog ) if err then out_error( "server.lua, [", addr, "]:", port, ": ", err ) @@ -929,17 +932,44 @@ local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx return handler, socket end -local addclient = function( address, port, listeners, pattern, sslctx ) - local client, err = luasocket.tcp( ) +local addclient = function( address, port, listeners, pattern, sslctx, typ ) + local err + if type( listeners ) ~= "table" then + err = "invalid listener table" + elseif type ( address ) ~= "string" then + err = "invalid address" + elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then + err = "invalid port" + elseif sslctx and not has_luasec then + err = "luasec not found" + end + if getaddrinfo and not typ then + local addrinfo, err = getaddrinfo(address) + if not addrinfo then return nil, err end + if addrinfo[1] and addrinfo[1].family == "inet6" then + typ = "tcp6" + end + end + local create = luasocket[typ or "tcp"] + if type( create ) ~= "function" then + err = "invalid socket type" + end + + if err then + out_error( "server.lua, addclient: ", err ) + return nil, err + end + + local client, err = create( ) if err then return nil, err end client:settimeout( 0 ) - _, err = client:connect( address, port ) - if err then -- try again + local ok, err = client:connect( address, port ) + if ok or err == "timeout" then return wrapclient( client, address, port, listeners, pattern, sslctx ) else - return wrapconnection( nil, listeners, client, address, port, "clientport", pattern, sslctx ) + return nil, err end end diff --git a/plugins/mod_admin_telnet.lua b/plugins/mod_admin_telnet.lua index 2aa9bd9b..66560d44 100644 --- a/plugins/mod_admin_telnet.lua +++ b/plugins/mod_admin_telnet.lua @@ -154,6 +154,14 @@ function console_listener.onincoming(conn, data) session.partial_data = data:match("[^\n]+$"); end +function console_listener.onreadtimeout(conn) + local session = sessions[conn]; + if session then + session.send("\0"); + return true; + end +end + function console_listener.ondisconnect(conn, err) local session = sessions[conn]; if session then @@ -212,9 +220,11 @@ function commands.help(session, data) print [[c2s:show(jid) - Show all client sessions with the specified JID (or all if no JID given)]] print [[c2s:show_insecure() - Show all unencrypted client connections]] print [[c2s:show_secure() - Show all encrypted client connections]] + print [[c2s:show_tls() - Show TLS cipher info for encrypted sessions]] print [[c2s:close(jid) - Close all sessions for the specified JID]] elseif section == "s2s" then print [[s2s:show(domain) - Show all s2s connections for the given domain (or all if no domain given)]] + print [[s2s:show_tls(domain) - Show TLS cipher info for encrypted sessions]] print [[s2s:close(from, to) - Close a connection from one domain to another]] print [[s2s:closeall(host) - Close all the incoming/outgoing s2s sessions to specified host]] elseif section == "module" then @@ -471,22 +481,28 @@ function def_env.config:reload() return ok, (ok and "Config reloaded (you may need to reload modules to take effect)") or tostring(err); end -def_env.hosts = {}; -function def_env.hosts:list() - for host, host_session in pairs(hosts) do - self.session.print(host); +local function common_info(session, line) + if session.id then + line[#line+1] = "["..session.id.."]" + else + line[#line+1] = "["..session.type..(tostring(session):match("%x*$")).."]" end - return true, "Done"; -end - -function def_env.hosts:add(name) end local function session_flags(session, line) line = line or {}; + common_info(session, line); + if session.type == "c2s" then + local status, priority = "unavailable", tostring(session.priority or "-"); + if session.presence then + status = session.presence:get_child_text("show") or "available"; + end + line[#line+1] = status.."("..priority..")"; + end if session.cert_identity_status == "valid" then - line[#line+1] = "(secure)"; - elseif session.secure then + line[#line+1] = "(authenticated)"; + end + if session.secure then line[#line+1] = "(encrypted)"; end if session.compressed then @@ -501,6 +517,23 @@ local function session_flags(session, line) return table.concat(line, " "); end +local function tls_info(session, line) + line = line or {}; + common_info(session, line); + if session.secure then + local sock = session.conn and session.conn.socket and session.conn:socket(); + if sock and sock.info then + local info = sock:info(); + line[#line+1] = ("(%s with %s)"):format(info.protocol, info.cipher); + else + line[#line+1] = "(cipher info unavailable)"; + end + else + line[#line+1] = "(insecure)"; + end + return table.concat(line, " "); +end + def_env.c2s = {}; local function show_c2s(callback) @@ -524,8 +557,9 @@ function def_env.c2s:count(match_jid) return true, "Total: "..count.." clients"; end -function def_env.c2s:show(match_jid) +function def_env.c2s:show(match_jid, annotate) local print, count = self.session.print, 0; + annotate = annotate or session_flags; local curr_host; show_c2s(function (jid, session) if curr_host ~= session.host then @@ -534,11 +568,7 @@ function def_env.c2s:show(match_jid) end if (not match_jid) or jid:match(match_jid) then count = count + 1; - local status, priority = "unavailable", tostring(session.priority or "-"); - if session.presence then - status = session.presence:get_child_text("show") or "available"; - end - print(session_flags(session, { " "..jid.." - "..status.."("..priority..")" })); + print(annotate(session, { " ", jid })); end end); return true, "Total: "..count.." clients"; @@ -566,6 +596,10 @@ function def_env.c2s:show_secure(match_jid) return true, "Total: "..count.." secure client connections"; end +function def_env.c2s:show_tls(match_jid) + return self:show(match_jid, tls_info); +end + function def_env.c2s:close(match_jid) local count = 0; show_c2s(function (jid, session) @@ -579,8 +613,9 @@ end def_env.s2s = {}; -function def_env.s2s:show(match_jid) +function def_env.s2s:show(match_jid, annotate) local print = self.session.print; + annotate = annotate or session_flags; local count_in, count_out = 0,0; local s2s_list = { }; @@ -598,8 +633,7 @@ function def_env.s2s:show(match_jid) remotehost, localhost = session.from_host or "?", session.to_host or "?"; end local sess_lines = { l = localhost, r = remotehost, - session_flags(session, { "", direction, remotehost or "?", - "["..session.type..tostring(session):match("[a-f0-9]*$").."]" })}; + annotate(session, { "", direction, remotehost or "?" })}; if (not match_jid) or remotehost:match(match_jid) or localhost:match(match_jid) then table.insert(s2s_list, sess_lines); @@ -654,6 +688,10 @@ function def_env.s2s:show(match_jid) return true, "Total: "..count_out.." outgoing, "..count_in.." incoming connections"; end +function def_env.s2s:show_tls(match_jid) + return self:show(match_jid, tls_info); +end + local function print_subject(print, subject) for _, entry in ipairs(subject) do print( @@ -823,9 +861,19 @@ end function def_env.host:list() local print = self.session.print; local i = 0; + local type; for host in values(array.collect(keys(prosody.hosts)):sort()) do i = i + 1; - print(host); + type = hosts[host].type; + if type == "local" then + print(host); + else + type = module:context(host):get_option_string("component_module", type); + if type ~= "component" then + type = type .. " component"; + end + print(("%s (%s)"):format(host, type)); + end end return true, i.." hosts"; end @@ -896,6 +944,9 @@ end function def_env.muc:create(room_jid) local room, host = check_muc(room_jid); + if not room_name then + return room_name, host; + end if not room then return nil, host end if hosts[host].modules.muc.rooms[room_jid] then return nil, "Room exists already" end return hosts[host].modules.muc.create_room(room_jid); @@ -903,6 +954,9 @@ end function def_env.muc:room(room_jid) local room_name, host = check_muc(room_jid); + if not room_name then + return room_name, host; + end local room_obj = hosts[host].modules.muc.rooms[room_jid]; if not room_obj then return nil, "No such room: "..room_jid; @@ -910,6 +964,19 @@ function def_env.muc:room(room_jid) return setmetatable({ room = room_obj }, console_room_mt); end +function def_env.muc:list(host) + local host_session = hosts[host]; + if not host_session or not host_session.modules.muc then + return nil, "Please supply the address of a local MUC component"; + end + local c = 0; + for name in keys(host_session.modules.muc.rooms) do + print(name); + c = c + 1; + end + return true, c.." rooms"; +end + local um = require"core.usermanager"; def_env.user = {}; diff --git a/plugins/mod_auth_anonymous.lua b/plugins/mod_auth_anonymous.lua index c877d532..8de46f8c 100644 --- a/plugins/mod_auth_anonymous.lua +++ b/plugins/mod_auth_anonymous.lua @@ -43,7 +43,7 @@ function provider.get_sasl_handler() end function provider.users() - return next, hosts[host].sessions, nil; + return next, hosts[module.host].sessions, nil; end -- datamanager callback to disable writes diff --git a/plugins/mod_auth_internal_hashed.lua b/plugins/mod_auth_internal_hashed.lua index fb87bb9f..954392c9 100644 --- a/plugins/mod_auth_internal_hashed.lua +++ b/plugins/mod_auth_internal_hashed.lua @@ -7,6 +7,8 @@ -- COPYING file in the source package for more information. -- +local max = math.max; + local getAuthenticationDatabaseSHA1 = require "util.sasl.scram".getAuthenticationDatabaseSHA1; local usermanager = require "core.usermanager"; local generate_uuid = require "util.uuid".generate; @@ -39,7 +41,7 @@ end -- Default; can be set per-user -local iteration_count = 4096; +local default_iteration_count = 4096; -- define auth provider local provider = {}; @@ -80,8 +82,8 @@ function provider.set_password(username, password) log("debug", "set_password for username '%s'", username); local account = accounts:get(username); if account then - account.salt = account.salt or generate_uuid(); - account.iteration_count = account.iteration_count or iteration_count; + account.salt = generate_uuid(); + account.iteration_count = max(account.iteration_count or 0, default_iteration_count); local valid, stored_key, server_key = getAuthenticationDatabaseSHA1(password, account.salt, account.iteration_count); local stored_key_hex = to_hex(stored_key); local server_key_hex = to_hex(server_key); @@ -113,10 +115,10 @@ function provider.create_user(username, password) return accounts:set(username, {}); end local salt = generate_uuid(); - local valid, stored_key, server_key = getAuthenticationDatabaseSHA1(password, salt, iteration_count); + local valid, stored_key, server_key = getAuthenticationDatabaseSHA1(password, salt, default_iteration_count); local stored_key_hex = to_hex(stored_key); local server_key_hex = to_hex(server_key); - return accounts:set(username, {stored_key = stored_key_hex, server_key = server_key_hex, salt = salt, iteration_count = iteration_count}); + return accounts:set(username, {stored_key = stored_key_hex, server_key = server_key_hex, salt = salt, iteration_count = default_iteration_count}); end function provider.delete_user(username) diff --git a/plugins/mod_c2s.lua b/plugins/mod_c2s.lua index 7a8af406..f0cdd7fb 100644 --- a/plugins/mod_c2s.lua +++ b/plugins/mod_c2s.lua @@ -174,19 +174,6 @@ local function session_close(session, reason) end end -local function session_open_stream(session) - local attr = { - ["xmlns:stream"] = 'http://etherx.jabber.org/streams', - xmlns = stream_callbacks.default_ns, - version = "1.0", - ["xml:lang"] = 'en', - id = session.streamid or "", - from = session.host - }; - session.send("<?xml version='1.0'?>"); - session.send(st.stanza("stream:stream", attr):top_tag()); -end - module:hook_global("user-deleted", function(event) local username, host = event.username, event.host; local user = hosts[host].sessions[username]; @@ -234,7 +221,6 @@ function listener.onconnect(conn) conn:setoption("keepalive", opt_keepalives); end - session.open_stream = session_open_stream; session.close = session_close; local stream = new_xmpp_stream(session, stream_callbacks); diff --git a/plugins/mod_component.lua b/plugins/mod_component.lua index 1497b12f..297609d8 100644 --- a/plugins/mod_component.lua +++ b/plugins/mod_component.lua @@ -177,9 +177,7 @@ function stream_callbacks.streamopened(session, attr) session.streamid = uuid_gen(); session.notopen = nil; -- Return stream header - session.send("<?xml version='1.0'?>"); - session.send(st.stanza("stream:stream", { xmlns=xmlns_component, - ["xmlns:stream"]='http://etherx.jabber.org/streams', id=session.streamid, from=session.host }):top_tag()); + session:open_stream(); end function stream_callbacks.streamclosed(session) diff --git a/plugins/mod_compression.lua b/plugins/mod_compression.lua index f44e8a6d..969172fd 100644 --- a/plugins/mod_compression.lua +++ b/plugins/mod_compression.lua @@ -26,7 +26,7 @@ end module:hook("stream-features", function(event) local origin, features = event.origin, event.features; - if not origin.compressed and (origin.type == "c2s" or origin.type == "s2sin" or origin.type == "s2sout") then + if not origin.compressed and origin.type == "c2s" then -- FIXME only advertise compression support when TLS layer has no compression enabled features:add_child(compression_stream_feature); end @@ -35,7 +35,7 @@ end); module:hook("s2s-stream-features", function(event) local origin, features = event.origin, event.features; -- FIXME only advertise compression support when TLS layer has no compression enabled - if not origin.compressed and (origin.type == "c2s" or origin.type == "s2sin" or origin.type == "s2sout") then + if not origin.compressed and origin.type == "s2sin" then features:add_child(compression_stream_feature); end end); @@ -43,13 +43,13 @@ end); -- Hook to activate compression if remote server supports it. module:hook_stanza(xmlns_stream, "features", function (session, stanza) - if not session.compressed and (session.type == "c2s" or session.type == "s2sin" or session.type == "s2sout") then + if not session.compressed and session.type == "s2sout" then -- does remote server support compression? - local comp_st = stanza:child_with_name("compression"); + local comp_st = stanza:get_child("compression", xmlns_compression_feature); if comp_st then -- do we support the mechanism - for a in comp_st:children() do - local algorithm = a[1] + for a in comp_st:childtags("method") do + local algorithm = a:get_text(); if algorithm == "zlib" then session.sends2s(st.stanza("compress", {xmlns=xmlns_compression_protocol}):tag("method"):text("zlib")) session.log("debug", "Enabled compression using zlib.") @@ -125,8 +125,8 @@ end module:hook("stanza/http://jabber.org/protocol/compress:compressed", function(event) local session = event.origin; - - if session.type == "s2sout_unauthed" or session.type == "s2sout" then + + if session.type == "s2sout" then session.log("debug", "Activating compression...") -- create deflate and inflate streams local deflate_stream = get_deflate_stream(session); @@ -150,7 +150,7 @@ end); module:hook("stanza/http://jabber.org/protocol/compress:compress", function(event) local session, stanza = event.origin, event.stanza; - if session.type == "c2s" or session.type == "s2sin" or session.type == "c2s_unauthed" or session.type == "s2sin_unauthed" then + if session.type == "c2s" or session.type == "s2sin" then -- fail if we are already compressed if session.compressed then local error_st = st.stanza("failure", {xmlns=xmlns_compression_protocol}):tag("setup-failed"); @@ -160,8 +160,7 @@ module:hook("stanza/http://jabber.org/protocol/compress:compress", function(even end -- checking if the compression method is supported - local method = stanza:child_with_name("method"); - method = method and (method[1] or ""); + local method = stanza:get_child_text("method"); if method == "zlib" then session.log("debug", "zlib compression enabled."); diff --git a/plugins/mod_http.lua b/plugins/mod_http.lua index 95933da5..49529ea2 100644 --- a/plugins/mod_http.lua +++ b/plugins/mod_http.lua @@ -42,7 +42,7 @@ local function get_base_path(host_module, app_name, default_app_path) return (normalize_path(host_module:get_option("http_paths", {})[app_name] -- Host or module:get_option("http_paths", {})[app_name] -- Global or default_app_path)) -- Default - :gsub("%$(%w+)", { host = module.host }); + :gsub("%$(%w+)", { host = host_module.host }); end local ports_by_scheme = { http = 80, https = 443, }; @@ -51,6 +51,9 @@ local ports_by_scheme = { http = 80, https = 443, }; function moduleapi.http_url(module, app_name, default_path) app_name = app_name or (module.name:gsub("^http_", "")); local external_url = url_parse(module:get_option_string("http_external_url")) or {}; + if external_url.scheme and external_url.port == nil then + external_url.port = ports_by_scheme[external_url.scheme]; + end local services = portmanager.get_active_services(); local http_services = services:get("https") or services:get("http") or {}; for interface, ports in pairs(http_services) do @@ -139,7 +142,13 @@ module:provides("net", { listener = server.listener; default_port = 5281; encryption = "ssl"; - ssl_config = { verify = "none" }; + ssl_config = { + verify = { + peer = false, + client_once = false, + "none", + } + }; multiplex = { pattern = "^[A-Z]"; }; diff --git a/plugins/mod_http_files.lua b/plugins/mod_http_files.lua index dd04853b..2e9f4182 100644 --- a/plugins/mod_http_files.lua +++ b/plugins/mod_http_files.lua @@ -14,6 +14,7 @@ local os_date = os.date; local open = io.open; local stat = lfs.attributes; local build_path = require"socket.url".build_path; +local path_sep = package.config:sub(1,1); local base_path = module:get_option_string("http_files_dir", module:get_option_string("http_path")); local dir_indices = module:get_option("http_index_files", { "index.html", "index.htm" }); @@ -61,7 +62,7 @@ function serve(opts) local request, response = event.request, event.response; local orig_path = request.path; local full_path = base_path .. (path and "/"..path or ""); - local attr = stat(full_path); + local attr = stat((full_path:gsub('%'..path_sep..'+$',''))); if not attr then return 404; end diff --git a/plugins/mod_pep_plus.lua b/plugins/mod_pep_plus.lua new file mode 100644 index 00000000..ee57e647 --- /dev/null +++ b/plugins/mod_pep_plus.lua @@ -0,0 +1,368 @@ +local pubsub = require "util.pubsub"; +local jid_bare = require "util.jid".bare; +local jid_split = require "util.jid".split; +local set_new = require "util.set".new; +local st = require "util.stanza"; +local calculate_hash = require "util.caps".calculate_hash; +local is_contact_subscribed = require "core.rostermanager".is_contact_subscribed; + +local xmlns_pubsub = "http://jabber.org/protocol/pubsub"; +local xmlns_pubsub_event = "http://jabber.org/protocol/pubsub#event"; +local xmlns_pubsub_owner = "http://jabber.org/protocol/pubsub#owner"; + +local lib_pubsub = module:require "pubsub"; +local handlers = lib_pubsub.handlers; +local pubsub_error_reply = lib_pubsub.pubsub_error_reply; + +local services = {}; +local recipients = {}; +local hash_map = {}; + +function module.save() + return { services = services }; +end + +function module.restore(data) + services = data.services; +end + +local function subscription_presence(user_bare, recipient) + local recipient_bare = jid_bare(recipient); + if (recipient_bare == user_bare) then return true; end + local username, host = jid_split(user_bare); + return is_contact_subscribed(username, host, recipient_bare); +end + +local function get_broadcaster(name) + local function simple_broadcast(kind, node, jids, item) + if item then + item = st.clone(item); + item.attr.xmlns = nil; -- Clear the pubsub namespace + end + local message = st.message({ from = name, type = "headline" }) + :tag("event", { xmlns = xmlns_pubsub_event }) + :tag(kind, { node = node }) + :add_child(item); + for jid in pairs(jids) do + module:log("debug", "Sending notification to %s from %s: %s", jid, name, tostring(item)); + message.attr.to = jid; + module:send(message); + end + end + return simple_broadcast; +end + +function get_pep_service(name) + if services[name] then + return services[name]; + end + services[name] = pubsub.new({ + capabilities = { + none = { + create = false; + publish = false; + retract = false; + get_nodes = false; + + subscribe = false; + unsubscribe = false; + get_subscription = false; + get_subscriptions = false; + get_items = false; + + subscribe_other = false; + unsubscribe_other = false; + get_subscription_other = false; + get_subscriptions_other = false; + + be_subscribed = true; + be_unsubscribed = true; + + set_affiliation = false; + }; + subscriber = { + create = false; + publish = false; + retract = false; + get_nodes = true; + + subscribe = true; + unsubscribe = true; + get_subscription = true; + get_subscriptions = true; + get_items = true; + + subscribe_other = false; + unsubscribe_other = false; + get_subscription_other = false; + get_subscriptions_other = false; + + be_subscribed = true; + be_unsubscribed = true; + + set_affiliation = false; + }; + publisher = { + create = false; + publish = true; + retract = true; + get_nodes = true; + + subscribe = true; + unsubscribe = true; + get_subscription = true; + get_subscriptions = true; + get_items = true; + + subscribe_other = false; + unsubscribe_other = false; + get_subscription_other = false; + get_subscriptions_other = false; + + be_subscribed = true; + be_unsubscribed = true; + + set_affiliation = false; + }; + owner = { + create = true; + publish = true; + retract = true; + delete = true; + get_nodes = true; + + subscribe = true; + unsubscribe = true; + get_subscription = true; + get_subscriptions = true; + get_items = true; + + + subscribe_other = true; + unsubscribe_other = true; + get_subscription_other = true; + get_subscriptions_other = true; + + be_subscribed = true; + be_unsubscribed = true; + + set_affiliation = true; + }; + }; + + autocreate_on_publish = true; + autocreate_on_subscribe = true; + + broadcaster = get_broadcaster(name); + get_affiliation = function (jid) + if jid_bare(jid) == name then + return "owner"; + elseif subscription_presence(name, jid) then + return "subscriber"; + end + end; + + normalize_jid = jid_bare; + }); + return services[name]; +end + +function handle_pubsub_iq(event) + local origin, stanza = event.origin, event.stanza; + local pubsub = stanza.tags[1]; + local action = pubsub.tags[1]; + if not action then + return origin.send(st.error_reply(stanza, "cancel", "bad-request")); + end + local service_name = stanza.attr.to or origin.username.."@"..origin.host + local service = get_pep_service(service_name); + local handler = handlers[stanza.attr.type.."_"..action.name]; + if handler then + handler(origin, stanza, action, service); + return true; + end +end + +module:hook("iq/bare/"..xmlns_pubsub..":pubsub", handle_pubsub_iq); +module:hook("iq/bare/"..xmlns_pubsub_owner..":pubsub", handle_pubsub_iq); + +module:add_identity("pubsub", "pep", module:get_option_string("name", "Prosody")); +module:add_feature("http://jabber.org/protocol/pubsub#publish"); + +local function get_caps_hash_from_presence(stanza, current) + local t = stanza.attr.type; + if not t then + local child = stanza:get_child("c", "http://jabber.org/protocol/caps"); + if child then + local attr = child.attr; + if attr.hash then -- new caps + if attr.hash == 'sha-1' and attr.node and attr.ver then + return attr.ver, attr.node.."#"..attr.ver; + end + else -- legacy caps + if attr.node and attr.ver then + return attr.node.."#"..attr.ver.."#"..(attr.ext or ""), attr.node.."#"..attr.ver; + end + end + end + return; -- no or bad caps + elseif t == "unavailable" or t == "error" then + return; + end + return current; -- no caps, could mean caps optimization, so return current +end + +local function resend_last_item(jid, node, service) + local ok, items = service:get_items(node, jid); + if not ok then return; end + for i, id in ipairs(items) do + service.config.broadcaster("items", node, { [jid] = true }, items[id]); + end +end + +local function update_subscriptions(recipient, service_name, nodes) + local service = get_pep_service(service_name); + + recipients[service_name] = recipients[service_name] or {}; + nodes = nodes or set_new(); + local old = recipients[service_name][recipient]; + + if old and type(old) == table then + for node in pairs((old - nodes):items()) do + service:remove_subscription(node, recipient, recipient); + end + end + + for node in nodes:items() do + service:add_subscription(node, recipient, recipient); + resend_last_item(recipient, node, service); + end + recipients[service_name][recipient] = nodes; +end + +module:hook("presence/bare", function(event) + -- inbound presence to bare JID recieved + local origin, stanza = event.origin, event.stanza; + local user = stanza.attr.to or (origin.username..'@'..origin.host); + local t = stanza.attr.type; + local self = not stanza.attr.to; + local service = get_pep_service(user); + + if not t then -- available presence + if self or subscription_presence(user, stanza.attr.from) then + local recipient = stanza.attr.from; + local current = recipients[user] and recipients[user][recipient]; + local hash, query_node = get_caps_hash_from_presence(stanza, current); + if current == hash or (current and current == hash_map[hash]) then return; end + if not hash then + update_subscriptions(recipient, user); + else + recipients[user] = recipients[user] or {}; + if hash_map[hash] then + update_subscriptions(recipient, user, hash_map[hash]); + else + recipients[user][recipient] = hash; + local from_bare = origin.type == "c2s" and origin.username.."@"..origin.host; + if self or origin.type ~= "c2s" or (recipients[from_bare] and recipients[from_bare][origin.full_jid]) ~= hash then + -- COMPAT from ~= stanza.attr.to because OneTeam can't deal with missing from attribute + origin.send( + st.stanza("iq", {from=user, to=stanza.attr.from, id="disco", type="get"}) + :tag("query", {xmlns = "http://jabber.org/protocol/disco#info", node = query_node}) + ); + end + end + end + end + elseif t == "unavailable" then + update_subscriptions(stanza.attr.from, user); + elseif not self and t == "unsubscribe" then + local from = jid_bare(stanza.attr.from); + local subscriptions = recipients[user]; + if subscriptions then + for subscriber in pairs(subscriptions) do + if jid_bare(subscriber) == from then + update_subscriptions(subscriber, user); + end + end + end + end +end, 10); + +module:hook("iq-result/bare/disco", function(event) + local origin, stanza = event.origin, event.stanza; + local disco = stanza:get_child("query", "http://jabber.org/protocol/disco#info"); + if not disco then + return; + end + + -- Process disco response + local self = not stanza.attr.to; + local user = stanza.attr.to or (origin.username..'@'..origin.host); + local contact = stanza.attr.from; + local current = recipients[user] and recipients[user][contact]; + if type(current) ~= "string" then return; end -- check if waiting for recipient's response + local ver = current; + if not string.find(current, "#") then + ver = calculate_hash(disco.tags); -- calculate hash + end + local notify = set_new(); + for _, feature in pairs(disco.tags) do + if feature.name == "feature" and feature.attr.var then + local nfeature = feature.attr.var:match("^(.*)%+notify$"); + if nfeature then notify:add(nfeature); end + end + end + hash_map[ver] = notify; -- update hash map + if self then + for jid, item in pairs(origin.roster) do -- for all interested contacts + if item.subscription == "both" or item.subscription == "from" then + if not recipients[jid] then recipients[jid] = {}; end + update_subscriptions(contact, jid, notify); + end + end + end + update_subscriptions(contact, user, notify); +end); + +module:hook("account-disco-info-node", function(event) + local reply, stanza, origin = event.reply, event.stanza, event.origin; + local service_name = stanza.attr.to or origin.username.."@"..origin.host + local service = get_pep_service(service_name); + local node = event.node; + local ok = service:get_items(node, jid_bare(stanza.attr.from) or true); + if not ok then return; end + event.exists = true; + reply:tag('identity', {category='pubsub', type='leaf'}):up(); +end); + +module:hook("account-disco-info", function(event) + local reply = event.reply; + reply:tag('identity', {category='pubsub', type='pep'}):up(); + reply:tag('feature', {var='http://jabber.org/protocol/pubsub#publish'}):up(); +end); + +module:hook("account-disco-items-node", function(event) + local reply, stanza, origin = event.reply, event.stanza, event.origin; + local node = event.node; + local service_name = stanza.attr.to or origin.username.."@"..origin.host + local service = get_pep_service(service_name); + local ok, ret = service:get_items(node, jid_bare(stanza.attr.from) or true); + if not ok then return; end + event.exists = true; + for _, id in ipairs(ret) do + reply:tag("item", { jid = service_name, name = id }):up(); + end +end); + +module:hook("account-disco-items", function(event) + local reply, stanza, origin = event.reply, event.stanza, event.origin; + + local service_name = reply.attr.from or origin.username.."@"..origin.host + local service = get_pep_service(service_name); + local ok, ret = service:get_nodes(jid_bare(stanza.attr.from)); + if not ok then return; end + + for node, node_obj in pairs(ret) do + reply:tag("item", { jid = service_name, node = node, name = node_obj.config.name }):up(); + end +end); diff --git a/plugins/mod_posix.lua b/plugins/mod_posix.lua index 69542c96..89d6d2b6 100644 --- a/plugins/mod_posix.lua +++ b/plugins/mod_posix.lua @@ -129,14 +129,6 @@ end require "core.loggingmanager".register_sink_type("syslog", syslog_sink_maker); local daemonize = module:get_option("daemonize", prosody.installed); -if daemonize == nil then - local no_daemonize = module:get_option("no_daemonize"); --COMPAT w/ 0.5 - daemonize = not no_daemonize; - if no_daemonize ~= nil then - module:log("warn", "The 'no_daemonize' option is now replaced by 'daemonize'"); - module:log("warn", "Update your config from 'no_daemonize = %s' to 'daemonize = %s'", tostring(no_daemonize), tostring(daemonize)); - end -end local function remove_log_sinks() local lm = require "core.loggingmanager"; diff --git a/plugins/mod_proxy65.lua b/plugins/mod_proxy65.lua index 2ed9faac..73527cbc 100644 --- a/plugins/mod_proxy65.lua +++ b/plugins/mod_proxy65.lua @@ -101,27 +101,10 @@ function module.add_host(module) module:log("warn", "proxy65_port is deprecated, please put proxy65_ports = { %d } into the global section instead", legacy_config); end + module:depends("disco"); module:add_identity("proxy", "bytestreams", name); module:add_feature("http://jabber.org/protocol/bytestreams"); - module:hook("iq-get/host/http://jabber.org/protocol/disco#info:query", function(event) - local origin, stanza = event.origin, event.stanza; - if not stanza.tags[1].attr.node then - origin.send(st.reply(stanza):query("http://jabber.org/protocol/disco#info") - :tag("identity", {category='proxy', type='bytestreams', name=name}):up() - :tag("feature", {var="http://jabber.org/protocol/bytestreams"}) ); - return true; - end - end, -1); - - module:hook("iq-get/host/http://jabber.org/protocol/disco#items:query", function(event) - local origin, stanza = event.origin, event.stanza; - if not stanza.tags[1].attr.node then - origin.send(st.reply(stanza):query("http://jabber.org/protocol/disco#items")); - return true; - end - end, -1); - module:hook("iq-get/host/http://jabber.org/protocol/bytestreams:query", function(event) local origin, stanza = event.origin, event.stanza; diff --git a/plugins/mod_pubsub/mod_pubsub.lua b/plugins/mod_pubsub/mod_pubsub.lua index c6dbe831..33e729af 100644 --- a/plugins/mod_pubsub/mod_pubsub.lua +++ b/plugins/mod_pubsub/mod_pubsub.lua @@ -100,7 +100,7 @@ module:hook("host-disco-items-node", function (event) return; end - for id, item in pairs(ret) do + for _, id in ipairs(ret) do reply:tag("item", { jid = module.host, name = id }):up(); end event.exists = true; diff --git a/plugins/mod_pubsub/pubsub.lib.lua b/plugins/mod_pubsub/pubsub.lib.lua index 2b015e34..4e9acd68 100644 --- a/plugins/mod_pubsub/pubsub.lib.lua +++ b/plugins/mod_pubsub/pubsub.lib.lua @@ -42,8 +42,8 @@ function handlers.get_items(origin, stanza, items, service) end local data = st.stanza("items", { node = node }); - for _, entry in pairs(results) do - data:add_child(entry); + for _, id in ipairs(results) do + data:add_child(results[id]); end local reply; if data then diff --git a/plugins/mod_s2s/mod_s2s.lua b/plugins/mod_s2s/mod_s2s.lua index 5531ca3e..3de59d35 100644 --- a/plugins/mod_s2s/mod_s2s.lua +++ b/plugins/mod_s2s/mod_s2s.lua @@ -150,6 +150,13 @@ function module.add_host(module) module:hook("route/remote", route_to_new_session, -10); module:hook("s2s-authenticated", make_authenticated, -1); module:hook("s2s-read-timeout", keepalive, -1); + module:hook_stanza("http://etherx.jabber.org/streams", "features", function (session, stanza) + if session.type == "s2sout" then + -- Stream is authenticated and we are seem to be done with feature negotiation, + -- so the stream is ready for stanzas. RFC 6120 Section 4.3 + mark_connected(session); + end + end, -1); end -- Stream is authorised, and ready for normal stanzas @@ -219,7 +226,10 @@ function make_authenticated(event) end session.log("debug", "connection %s->%s is now authenticated for %s", session.from_host, session.to_host, host); - mark_connected(session); + if (session.type == "s2sout" and session.external_auth ~= "succeeded") or session.type == "s2sin" then + -- Stream either used dialback for authentication or is an incoming stream. + mark_connected(session); + end return true; end @@ -510,27 +520,16 @@ local function session_close(session, reason, remote_reason) end end -function session_open_stream(session, from, to) - local attr = { - ["xmlns:stream"] = 'http://etherx.jabber.org/streams', - xmlns = 'jabber:server', - version = session.version and (session.version > 0 and "1.0" or nil), - ["xml:lang"] = 'en', - id = session.streamid, - from = from, to = to, - } +function session_stream_attrs(session, from, to, attr) if not from or (hosts[from] and hosts[from].modules.dialback) then attr["xmlns:db"] = 'jabber:server:dialback'; end - - session.sends2s("<?xml version='1.0'?>"); - session.sends2s(st.stanza("stream:stream", attr):top_tag()); - return true; end -- Session initialization logic shared by incoming and outgoing local function initialize_session(session) local stream = new_xmpp_stream(session, stream_callbacks); + local log = session.log or log; session.stream = stream; session.notopen = true; @@ -540,16 +539,32 @@ local function initialize_session(session) session.stream:reset(); end - session.open_stream = session_open_stream; + session.stream_attrs = session_stream_attrs; + + local filter = initialize_filters(session); + local conn = session.conn; + local w = conn.write; + + function session.sends2s(t) + log("debug", "sending: %s", t.top_tag and t:top_tag() or t:match("^[^>]*>?")); + if t.name then + t = filter("stanzas/out", t); + end + if t then + t = filter("bytes/out", tostring(t)); + if t then + return w(conn, t); + end + end + end - local filter = session.filter; function session.data(data) data = filter("bytes/in", data); if data then local ok, err = stream:feed(data); if ok then return; end - (session.log or log)("warn", "Received invalid XML: %s", data); - (session.log or log)("warn", "Problem was: %s", err); + log("warn", "Received invalid XML: %s", data); + log("warn", "Problem was: %s", err); session:close("not-well-formed"); end end @@ -561,6 +576,8 @@ local function initialize_session(session) return handlestanza(session, stanza); end + module:fire_event("s2s-created", { session = session }); + add_task(connect_timeout, function () if session.type == "s2sin" or session.type == "s2sout" then return; -- Ok, we're connected @@ -581,22 +598,6 @@ function listener.onconnect(conn) session = s2s_new_incoming(conn); sessions[conn] = session; session.log("debug", "Incoming s2s connection"); - - local filter = initialize_filters(session); - local w = conn.write; - session.sends2s = function (t) - log("debug", "sending: %s", t.top_tag and t:top_tag() or t:match("^([^>]*>?)")); - if t.name then - t = filter("stanzas/out", t); - end - if t then - t = filter("bytes/out", tostring(t)); - if t then - return w(conn, t); - end - end - end - initialize_session(session); else -- Outgoing session connected session:open_stream(session.from_host, session.to_host); @@ -644,7 +645,6 @@ function listener.onreadtimeout(conn) end function listener.register_outgoing(conn, session) - session.direction = "outgoing"; sessions[conn] = session; initialize_session(session); end diff --git a/plugins/mod_s2s/s2sout.lib.lua b/plugins/mod_s2s/s2sout.lib.lua index 42b4281c..942a618d 100644 --- a/plugins/mod_s2s/s2sout.lib.lua +++ b/plugins/mod_s2s/s2sout.lib.lua @@ -297,21 +297,6 @@ function s2sout.make_connect(host_session, connect_host, connect_port) conn = wrapclient(conn, connect_host.addr, connect_port, s2s_listener, "*a"); host_session.conn = conn; - local filter = initialize_filters(host_session); - local w, log = conn.write, host_session.log; - host_session.sends2s = function (t) - log("debug", "sending: %s", (t.top_tag and t:top_tag()) or t:match("^[^>]*>?")); - if t.name then - t = filter("stanzas/out", t); - end - if t then - t = filter("bytes/out", tostring(t)); - if t then - return w(conn, tostring(t)); - end - end - end - -- Register this outgoing connection so that xmppserver_listener knows about it -- otherwise it will assume it is a new incoming connection s2s_listener.register_outgoing(conn, host_session); diff --git a/plugins/mod_saslauth.lua b/plugins/mod_saslauth.lua index 94c060b3..df60aefa 100644 --- a/plugins/mod_saslauth.lua +++ b/plugins/mod_saslauth.lua @@ -197,7 +197,7 @@ module:hook("stanza/urn:ietf:params:xml:ns:xmpp-sasl:auth", function(event) return s2s_external_auth(session, stanza) end - if session.type ~= "c2s_unauthed" then return; end + if session.type ~= "c2s_unauthed" or module:get_host_type() ~= "local" then return; end if session.sasl_handler and session.sasl_handler.selected then session.sasl_handler = nil; -- allow starting a new SASL negotiation before completing an old one diff --git a/plugins/mod_storage_sql.lua b/plugins/mod_storage_sql.lua index 1f453d42..7b810ab8 100644 --- a/plugins/mod_storage_sql.lua +++ b/plugins/mod_storage_sql.lua @@ -49,7 +49,7 @@ local function db2uri(params) end -local resolve_relative_path = require "core.configmanager".resolve_relative_path; +local resolve_relative_path = require "util.paths".resolve_relative_path; local function test_connection() if not connection then return nil; end diff --git a/plugins/mod_storage_sql2.lua b/plugins/mod_storage_sql2.lua index 7a2ec4a7..249c72a7 100644 --- a/plugins/mod_storage_sql2.lua +++ b/plugins/mod_storage_sql2.lua @@ -2,7 +2,7 @@ local json = require "util.json"; local xml_parse = require "util.xml".parse; local uuid = require "util.uuid"; -local resolve_relative_path = require "core.configmanager".resolve_relative_path; +local resolve_relative_path = require "util.paths".resolve_relative_path; local stanza_mt = require"util.stanza".stanza_mt; local getmetatable = getmetatable; @@ -289,7 +289,7 @@ function archive_store:find(username, query) -- Total matching if query.total then - local stats = engine:select(sql_query:gsub("^(SELECT).-(FROM)", "%1 COUNT(*) %2"):format(t_concat(where, " AND "), "DESC", ""), unpack(args)); + local stats = engine:select("SELECT COUNT(*) FROM `prosodyarchive` WHERE " .. t_concat(where, " AND "), unpack(args)); if stats then local _total = stats() total = _total and _total[1]; @@ -49,9 +49,6 @@ _G.prosody = prosody; -- Check dependencies local dependencies = require "util.dependencies"; -if not dependencies.check_dependencies() then - os.exit(1); -end -- Load the config-parsing module config = require "core.configmanager" @@ -116,6 +113,12 @@ function read_config() end end +function check_dependencies() + if not dependencies.check_dependencies() then + os.exit(1); + end +end + function load_libraries() -- Load socket framework server = require "net.server" @@ -382,6 +385,7 @@ init_logging(); sanity_check(); sandbox_require(); set_function_metatable(); +check_dependencies(); load_libraries(); init_global_state(); read_version(); diff --git a/prosody.cfg.lua.dist b/prosody.cfg.lua.dist index 1d11a658..ade219a8 100644 --- a/prosody.cfg.lua.dist +++ b/prosody.cfg.lua.dist @@ -63,7 +63,6 @@ modules_enabled = { --"http_files"; -- Serve static files from a directory over HTTP -- Other specific functionality - --"posix"; -- POSIX functionality, sends server to background, enables syslog, etc. --"groups"; -- Shared roster support --"announce"; -- Send announcement to all online users --"welcome"; -- Welcome users who register accounts @@ -78,6 +77,7 @@ modules_disabled = { -- "offline"; -- Store offline messages -- "c2s"; -- Handle client connections -- "s2s"; -- Handle server-to-server connections + -- "posix"; -- POSIX functionality, sends server to background, enables syslog, etc. } -- Disable account creation by default, for security @@ -414,7 +414,11 @@ function commands.start(arg) local ok, ret = prosodyctl.start(); if ok then - if config.get("*", "daemonize") ~= false then + local daemonize = config.get("*", "daemonize"); + if daemonize == nil then + daemonize = prosody.installed; + end + if daemonize then local i=1; while true do local ok, running = prosodyctl.isrunning(); @@ -687,7 +691,12 @@ function cert_commands.config(arg) conf.distinguished_name[k] = nv ~= "." and nv or nil; end end - local conf_file = io.open(conf_filename, "w"); + local conf_file, err = io.open(conf_filename, "w"); + if not conf_file then + show_warning("Could not open OpenSSL config file for writing"); + show_warning(err); + os.exit(1); + end conf_file:write(conf:serialize()); conf_file:close(); print(""); @@ -788,8 +797,28 @@ function commands.check(arg) local array, set = require "util.array", require "util.set"; local it = require "util.iterators"; local ok = true; + local function disabled_hosts(host, conf) return host ~= "*" and conf.enabled ~= false; end + local function enabled_hosts() return it.filter(disabled_hosts, pairs(config.getconfig())); end + if not what or what == "disabled" then + local disabled_hosts = set.new(); + for host, host_options in it.filter("*", pairs(config.getconfig())) do + if host_options.enabled == false then + disabled_hosts:add(host); + end + end + if not disabled_hosts:empty() then + local msg = "Checks will be skipped for these disabled hosts: %s"; + if what then msg = "These hosts are disabled: %s"; end + show_warning(msg, tostring(disabled_hosts)); + if what then return 0; end + print"" + end + end if not what or what == "config" then print("Checking config..."); + local deprecated = set.new({ + "bosh_ports", "disallow_s2s", "no_daemonize", "anonymous_login", + }); local known_global_options = set.new({ "pidfile", "log", "plugin_paths", "prosody_user", "prosody_group", "daemonize", "umask", "prosodyctl_timeout", "use_ipv6", "use_libevent", "network_settings" @@ -802,9 +831,27 @@ function commands.check(arg) print(" No global options defined. Perhaps you have put a host definition at the top") print(" of the config file? They should be at the bottom, see http://prosody.im/doc/configure#overview"); end + if it.count(enabled_hosts()) == 0 then + ok = false; + print(""); + if it.count(it.filter("*", pairs(config))) == 0 then + print(" No hosts are defined, please add at least one VirtualHost section") + elseif config["*"]["enabled"] == false then + print(" No hosts are enabled. Remove enabled = false from the global section or put enabled = true under at least one VirtualHost section") + else + print(" All hosts are disabled. Remove enabled = false from at least one VirtualHost section") + end + end -- Check for global options under hosts local global_options = set.new(it.to_array(it.keys(config["*"]))); - for host, options in it.filter("*", pairs(config)) do + local deprecated_global_options = set.intersection(global_options, deprecated); + if not deprecated_global_options:empty() then + print(""); + print(" You have some deprecated options in the global section:"); + print(" "..tostring(deprecated_global_options)) + ok = false; + end + for host, options in enabled_hosts() do local host_options = set.new(it.to_array(it.keys(options))); local misplaced_options = set.intersection(host_options, known_global_options); for name in pairs(options) do @@ -889,7 +936,7 @@ function commands.check(arg) local v6_supported = not not socket.tcp6; - for host, host_options in it.filter("*", pairs(config.getconfig())) do + for host, host_options in enabled_hosts() do local all_targets_ok, some_targets_ok = true, false; local is_component = not not host_options.component_module; @@ -1038,54 +1085,52 @@ function commands.check(arg) print("This version of LuaSec (" .. ssl._VERSION .. ") does not support certificate checking"); cert_ok = false else - for host in pairs(hosts) do - if host ~= "*" then -- Should check global certs too. - print("Checking certificate for "..host); - -- First, let's find out what certificate this host uses. - local ssl_config = config.rawget(host, "ssl"); - if not ssl_config then - local base_host = host:match("%.(.*)"); - ssl_config = config.get(base_host, "ssl"); - end - if not ssl_config then - print(" No 'ssl' option defined for "..host) - cert_ok = false - elseif not ssl_config.certificate then - print(" No 'certificate' set in ssl option for "..host) + for host in enabled_hosts() do + print("Checking certificate for "..host); + -- First, let's find out what certificate this host uses. + local ssl_config = config.rawget(host, "ssl"); + if not ssl_config then + local base_host = host:match("%.(.*)"); + ssl_config = config.get(base_host, "ssl"); + end + if not ssl_config then + print(" No 'ssl' option defined for "..host) + cert_ok = false + elseif not ssl_config.certificate then + print(" No 'certificate' set in ssl option for "..host) + cert_ok = false + elseif not ssl_config.key then + print(" No 'key' set in ssl option for "..host) + cert_ok = false + else + local key, err = io.open(ssl_config.key); -- Permissions check only + if not key then + print(" Could not open "..ssl_config.key..": "..err); cert_ok = false - elseif not ssl_config.key then - print(" No 'key' set in ssl option for "..host) + else + key:close(); + end + local cert_fh, err = io.open(ssl_config.certificate); -- Load the file. + if not cert_fh then + print(" Could not open "..ssl_config.certificate..": "..err); cert_ok = false else - local key, err = io.open(ssl_config.key); -- Permissions check only - if not key then - print(" Could not open "..ssl_config.key..": "..err); + print(" Certificate: "..ssl_config.certificate) + local cert = load_cert(cert_fh:read"*a"); cert_fh = cert_fh:close(); + if not cert:validat(os.time()) then + print(" Certificate has expired.") cert_ok = false - else - key:close(); end - local cert_fh, err = io.open(ssl_config.certificate); -- Load the file. - if not cert_fh then - print(" Could not open "..ssl_config.certificate..": "..err); - cert_ok = false - else - print(" Certificate: "..ssl_config.certificate) - local cert = load_cert(cert_fh:read"*a"); cert_fh = cert_fh:close(); - if not cert:validat(os.time()) then - print(" Certificate has expired.") - cert_ok = false - end - if config.get(host, "component_module") == nil + if config.get(host, "component_module") == nil and not x509_verify_identity(host, "_xmpp-client", cert) then - print(" Not vaild for client connections to "..host..".") - cert_ok = false - end - if (not (config.get(name, "anonymous_login") - or config.get(name, "authentication") == "anonymous")) + print(" Not vaild for client connections to "..host..".") + cert_ok = false + end + if (not (config.get(host, "anonymous_login") + or config.get(host, "authentication") == "anonymous")) and not x509_verify_identity(host, "_xmpp-client", cert) then - print(" Not vaild for server-to-server connections to "..host..".") - cert_ok = false - end + print(" Not vaild for server-to-server connections to "..host..".") + cert_ok = false end end end diff --git a/tools/ejabberd2prosody.lua b/tools/ejabberd2prosody.lua index e9dbd2dc..66bf4f93 100755 --- a/tools/ejabberd2prosody.lua +++ b/tools/ejabberd2prosody.lua @@ -44,8 +44,10 @@ function build_stanza(tuple, stanza) for _, a in ipairs(tuple[4]) do build_stanza(a, stanza); end if up then stanza:up(); else return stanza end elseif tuple[1] == "xmlcdata" then - assert(type(tuple[2]) == "string", "XML CDATA has unexpected type: "..type(tuple[2])); - stanza:text(tuple[2]); + if type(tuple[2]) ~= "table" then + assert(type(tuple[2]) == "string", "XML CDATA has unexpected type: "..type(tuple[2])); + stanza:text(tuple[2]); + end -- else it's [], i.e., the null value, used for the empty string else error("unknown element type: "..serialize(tuple)); end diff --git a/tools/jabberd14sql2prosody.lua b/tools/jabberd14sql2prosody.lua index 03376b30..e43dc296 100644 --- a/tools/jabberd14sql2prosody.lua +++ b/tools/jabberd14sql2prosody.lua @@ -428,7 +428,7 @@ end end -- import modules -package.path = package.path.."..\?.lua;"; +package.path = package.path..";../?.lua;"; local my_name = arg[0]; if my_name:match("[/\\]") then diff --git a/util-src/pposix.c b/util-src/pposix.c index 73e0d6e3..9b3e97eb 100644 --- a/util-src/pposix.c +++ b/util-src/pposix.c @@ -674,6 +674,7 @@ int lc_meminfo(lua_State* L) #if _XOPEN_SOURCE >= 600 || _POSIX_C_SOURCE >= 200112L || defined(_GNU_SOURCE) int lc_fallocate(lua_State* L) { + int ret; off_t offset, len; FILE *f = *(FILE**) luaL_checkudata(L, 1, LUA_FILEHANDLE); if (f == NULL) @@ -683,11 +684,15 @@ int lc_fallocate(lua_State* L) len = luaL_checkinteger(L, 3); #if defined(__linux__) && defined(_GNU_SOURCE) - if(fallocate(fileno(f), FALLOC_FL_KEEP_SIZE, offset, len) == 0) + errno = 0; + ret = fallocate(fileno(f), FALLOC_FL_KEEP_SIZE, offset, len); + if(ret == 0) { lua_pushboolean(L, 1); return 1; } + /* Some old versions of Linux apparently use the return value instead of errno */ + if(errno == 0) errno = ret; if(errno != ENOSYS && errno != EOPNOTSUPP) { @@ -701,7 +706,8 @@ int lc_fallocate(lua_State* L) #warning Note that posix_fallocate() will still be used on filesystems that dont support fallocate() #endif - if(posix_fallocate(fileno(f), offset, len) == 0) + ret = posix_fallocate(fileno(f), offset, len); + if(ret == 0) { lua_pushboolean(L, 1); return 1; @@ -709,7 +715,7 @@ int lc_fallocate(lua_State* L) else { lua_pushnil(L); - lua_pushstring(L, strerror(errno)); + lua_pushstring(L, strerror(ret)); /* posix_fallocate() can leave a bunch of NULs at the end, so we cut that * this assumes that offset == length of the file */ ftruncate(fileno(f), offset); diff --git a/util/dataforms.lua b/util/dataforms.lua index b38d0e27..c352858c 100644 --- a/util/dataforms.lua +++ b/util/dataforms.lua @@ -94,6 +94,15 @@ function form_t.form(layout, data, formtype) end end + local media = field.media; + if media then + form:tag("media", { xmlns = "urn:xmpp:media-element", height = media.height, width = media.width }); + for _, val in ipairs(media) do + form:tag("uri", { type = val.type }):text(val.uri):up() + end + form:up(); + end + if field.required then form:tag("required"):up(); end diff --git a/util/dependencies.lua b/util/dependencies.lua index 109a3332..ea19d9a8 100644 --- a/util/dependencies.lua +++ b/util/dependencies.lua @@ -49,6 +49,14 @@ package.preload["util.ztact"] = function () end; function check_dependencies() + if _VERSION ~= "Lua 5.1" then + print "***********************************" + print("Unsupported Lua version: ".._VERSION); + print("Only Lua 5.1 is supported."); + print "***********************************" + return false; + end + local fatal; local lxp = softreq "lxp" @@ -140,7 +148,15 @@ function log_warnings() if not pcall(lxp.new, { StartDoctypeDecl = false }) then log("error", "The version of LuaExpat on your system leaves Prosody " .."vulnerable to denial-of-service attacks. You should upgrade to " - .."LuaExpat 1.1.1 or higher as soon as possible. See " + .."LuaExpat 1.3.0 or higher as soon as possible. See " + .."http://prosody.im/doc/depends#luaexpat for more information."); + end + if not lxp.new({}).getcurrentbytecount then + log("error", "The version of LuaExpat on your system does not support " + .."stanza size limits, which may leave servers on untrusted " + .."networks (e.g. the internet) vulnerable to denial-of-service " + .."attacks. You should upgrade to LuaExpat 1.3.0 or higher as " + .."soon as possible. See " .."http://prosody.im/doc/depends#luaexpat for more information."); end end diff --git a/util/indexedbheap.lua b/util/indexedbheap.lua new file mode 100644 index 00000000..c60861e8 --- /dev/null +++ b/util/indexedbheap.lua @@ -0,0 +1,157 @@ + +local setmetatable = setmetatable; +local math_floor = math.floor; +local t_remove = table.remove; + +local function _heap_insert(self, item, sync, item2, index) + local pos = #self + 1; + while true do + local half_pos = math_floor(pos / 2); + if half_pos == 0 or item > self[half_pos] then break; end + self[pos] = self[half_pos]; + sync[pos] = sync[half_pos]; + index[sync[pos]] = pos; + pos = half_pos; + end + self[pos] = item; + sync[pos] = item2; + index[item2] = pos; +end + +local function _percolate_up(self, k, sync, index) + local tmp = self[k]; + local tmp_sync = sync[k]; + while k ~= 1 do + local parent = math_floor(k/2); + if tmp < self[parent] then break; end + self[k] = self[parent]; + sync[k] = sync[parent]; + index[sync[k]] = k; + k = parent; + end + self[k] = tmp; + sync[k] = tmp_sync; + index[tmp_sync] = k; + return k; +end + +local function _percolate_down(self, k, sync, index) + local tmp = self[k]; + local tmp_sync = sync[k]; + local size = #self; + local child = 2*k; + while 2*k <= size do + if child ~= size and self[child] > self[child + 1] then + child = child + 1; + end + if tmp > self[child] then + self[k] = self[child]; + sync[k] = sync[child]; + index[sync[k]] = k; + else + break; + end + + k = child; + child = 2*k; + end + self[k] = tmp; + sync[k] = tmp_sync; + index[tmp_sync] = k; + return k; +end + +local function _heap_pop(self, sync, index) + local size = #self; + if size == 0 then return nil; end + + local result = self[1]; + local result_sync = sync[1]; + index[result_sync] = nil; + if size == 1 then + self[1] = nil; + sync[1] = nil; + return result, result_sync; + end + self[1] = t_remove(self); + sync[1] = t_remove(sync); + index[sync[1]] = 1; + + _percolate_down(self, 1, sync, index); + + return result, result_sync; +end + +local indexed_heap = {}; + +function indexed_heap:insert(item, priority, id) + if id == nil then + id = self.current_id; + self.current_id = id + 1; + end + self.items[id] = item; + _heap_insert(self.priorities, priority, self.ids, id, self.index); + return id; +end +function indexed_heap:pop() + local priority, id = _heap_pop(self.priorities, self.ids, self.index); + if id then + local item = self.items[id]; + self.items[id] = nil; + return priority, item, id; + end +end +function indexed_heap:peek() + return self.priorities[1]; +end +function indexed_heap:reprioritize(id, priority) + local k = self.index[id]; + if k == nil then return; end + self.priorities[k] = priority; + + k = _percolate_up(self.priorities, k, self.ids, self.index); + k = _percolate_down(self.priorities, k, self.ids, self.index); +end +function indexed_heap:remove_index(k) + local result = self.priorities[k]; + if result == nil then return; end + + local result_sync = self.ids[k]; + local item = self.items[result_sync]; + local size = #self.priorities; + + self.priorities[k] = self.priorities[size]; + self.ids[k] = self.ids[size]; + self.index[self.ids[k]] = k; + + t_remove(self.priorities); + t_remove(self.ids); + + self.index[result_sync] = nil; + self.items[result_sync] = nil; + + if size > k then + k = _percolate_up(self.priorities, k, self.ids, self.index); + k = _percolate_down(self.priorities, k, self.ids, self.index); + end + + return result, item, result_sync; +end +function indexed_heap:remove(id) + return self:remove_index(self.index[id]); +end + +local mt = { __index = indexed_heap }; + +local _M = { + create = function() + return setmetatable({ + ids = {}; -- heap of ids, sync'd with priorities + items = {}; -- map id->items + priorities = {}; -- heap of priorities + index = {}; -- map of id->index of id in ids + current_id = 1.5 + }, mt); + end +}; +return _M; diff --git a/util/paths.lua b/util/paths.lua new file mode 100644 index 00000000..3e5744df --- /dev/null +++ b/util/paths.lua @@ -0,0 +1,38 @@ +local path_sep = package.config:sub(1,1); + +local path_util = {} + +-- Helper function to resolve relative paths (needed by config) +function path_util.resolve_relative_path(parent_path, path) + if path then + -- Some normalization + parent_path = parent_path:gsub("%"..path_sep.."+$", ""); + path = path:gsub("^%.%"..path_sep.."+", ""); + + local is_relative; + if path_sep == "/" and path:sub(1,1) ~= "/" then + is_relative = true; + elseif path_sep == "\\" and (path:sub(1,1) ~= "/" and (path:sub(2,3) ~= ":\\" and path:sub(2,3) ~= ":/")) then + is_relative = true; + end + if is_relative then + return parent_path..path_sep..path; + end + end + return path; +end + +-- Helper function to convert a glob to a Lua pattern +function path_util.glob_to_pattern(glob) + return "^"..glob:gsub("[%p*?]", function (c) + if c == "*" then + return ".*"; + elseif c == "?" then + return "."; + else + return "%"..c; + end + end).."$"; +end + +return path_util; diff --git a/util/pluginloader.lua b/util/pluginloader.lua index b894f527..b9b3e207 100644 --- a/util/pluginloader.lua +++ b/util/pluginloader.lua @@ -39,10 +39,10 @@ function load_resource(plugin, resource) resource = resource or "mod_"..plugin..".lua"; local names = { - "mod_"..plugin.."/"..plugin.."/"..resource; -- mod_hello/hello/mod_hello.lua - "mod_"..plugin.."/"..resource; -- mod_hello/mod_hello.lua - plugin.."/"..resource; -- hello/mod_hello.lua - resource; -- mod_hello.lua + "mod_"..plugin..dir_sep..plugin..dir_sep..resource; -- mod_hello/hello/mod_hello.lua + "mod_"..plugin..dir_sep..resource; -- mod_hello/mod_hello.lua + plugin..dir_sep..resource; -- hello/mod_hello.lua + resource; -- mod_hello.lua }; return load_file(names); diff --git a/util/prosodyctl.lua b/util/prosodyctl.lua index fe862114..d59c163c 100644 --- a/util/prosodyctl.lua +++ b/util/prosodyctl.lua @@ -189,8 +189,8 @@ function getpid() return false, "no-pidfile"; end - local modules_enabled = set.new(config.get("*", "modules_enabled")); - if not modules_enabled:contains("posix") then + local modules_enabled = set.new(config.get("*", "modules_disabled")); + if prosody.platform ~= "posix" or modules_enabled:contains("posix") then return false, "no-posix"; end diff --git a/util/sasl.lua b/util/sasl.lua index c8490842..b91e29a6 100644 --- a/util/sasl.lua +++ b/util/sasl.lua @@ -100,14 +100,16 @@ end function method:mechanisms() local current_mechs = {}; for mech, _ in pairs(self.mechs) do - if mechanism_channelbindings[mech] and self.profile.cb then - local ok = false; - for cb_name, _ in pairs(self.profile.cb) do - if mechanism_channelbindings[mech][cb_name] then - ok = true; + if mechanism_channelbindings[mech] then + if self.profile.cb then + local ok = false; + for cb_name, _ in pairs(self.profile.cb) do + if mechanism_channelbindings[mech][cb_name] then + ok = true; + end end + if ok == true then current_mechs[mech] = true; end end - if ok == true then current_mechs[mech] = true; end else current_mechs[mech] = true; end diff --git a/util/timer.lua b/util/timer.lua index 0e10e144..23bd6a37 100644 --- a/util/timer.lua +++ b/util/timer.lua @@ -6,6 +6,8 @@ -- COPYING file in the source package for more information. -- +local indexedbheap = require "util.indexedbheap"; +local log = require "util.logger".init("timer"); local server = require "net.server"; local math_min = math.min local math_huge = math.huge @@ -13,6 +15,9 @@ local get_time = require "socket".gettime; local t_insert = table.insert; local pairs = pairs; local type = type; +local debug_traceback = debug.traceback; +local tostring = tostring; +local xpcall = xpcall; local data = {}; local new_data = {}; @@ -78,6 +83,61 @@ else end end -add_task = _add_task; +--add_task = _add_task; + +local h = indexedbheap.create(); +local params = {}; +local next_time = nil; +local _id, _callback, _now, _param; +local function _call() return _callback(_now, _id, _param); end +local function _traceback_handler(err) log("error", "Traceback[timer]: %s", debug_traceback(tostring(err), 2)); end +local function _on_timer(now) + local peek; + while true do + peek = h:peek(); + if peek == nil or peek > now then break; end + local _; + _, _callback, _id = h:pop(); + _now = now; + _param = params[_id]; + params[_id] = nil; + --item(now, id, _param); -- FIXME pcall + local success, err = xpcall(_call, _traceback_handler); + if success and type(err) == "number" then + h:insert(_callback, err + now, _id); -- re-add + params[_id] = _param; + end + end + next_time = peek; + if peek ~= nil then + return peek - now; + end +end +function add_task(delay, callback, param) + local current_time = get_time(); + local event_time = current_time + delay; + + local id = h:insert(callback, event_time); + params[id] = param; + if next_time == nil or event_time < next_time then + next_time = event_time; + _add_task(next_time - current_time, _on_timer); + end + return id; +end +function stop(id) + params[id] = nil; + return h:remove(id); +end +function reschedule(id, delay) + local current_time = get_time(); + local event_time = current_time + delay; + h:reprioritize(id, delay); + if next_time == nil or event_time < next_time then + next_time = event_time; + _add_task(next_time - current_time, _on_timer); + end + return id; +end return _M; diff --git a/util/vcard.lua b/util/vcard.lua new file mode 100644 index 00000000..29a40844 --- /dev/null +++ b/util/vcard.lua @@ -0,0 +1,576 @@ +-- Copyright (C) 2011-2014 Kim Alvefur +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +-- TODO +-- Fix folding. + +local st = require "util.stanza"; +local t_insert, t_concat = table.insert, table.concat; +local type = type; +local next, pairs, ipairs = next, pairs, ipairs; + +local from_text, to_text, from_xep54, to_xep54; + +local line_sep = "\n"; + +local vCard_dtd; -- See end of file +local vCard4_dtd; + +local function fold_line() + error "Not implemented" --TODO +end +local function unfold_line() + error "Not implemented" + -- gsub("\r?\n[ \t]([^\r\n])", "%1"); +end + +local function vCard_esc(s) + return s:gsub("[,:;\\]", "\\%1"):gsub("\n","\\n"); +end + +local function vCard_unesc(s) + return s:gsub("\\?[\\nt:;,]", { + ["\\\\"] = "\\", + ["\\n"] = "\n", + ["\\r"] = "\r", + ["\\t"] = "\t", + ["\\:"] = ":", -- FIXME Shouldn't need to espace : in values, just params + ["\\;"] = ";", + ["\\,"] = ",", + [":"] = "\29", + [";"] = "\30", + [","] = "\31", + }); +end + +local function item_to_xep54(item) + local t = st.stanza(item.name, { xmlns = "vcard-temp" }); + + local prop_def = vCard_dtd[item.name]; + if prop_def == "text" then + t:text(item[1]); + elseif type(prop_def) == "table" then + if prop_def.types and item.TYPE then + if type(item.TYPE) == "table" then + for _,v in pairs(prop_def.types) do + for _,typ in pairs(item.TYPE) do + if typ:upper() == v then + t:tag(v):up(); + break; + end + end + end + else + t:tag(item.TYPE:upper()):up(); + end + end + + if prop_def.props then + for _,v in pairs(prop_def.props) do + if item[v] then + t:tag(v):up(); + end + end + end + + if prop_def.value then + t:tag(prop_def.value):text(item[1]):up(); + elseif prop_def.values then + local prop_def_values = prop_def.values; + local repeat_last = prop_def_values.behaviour == "repeat-last" and prop_def_values[#prop_def_values]; + for i=1,#item do + t:tag(prop_def.values[i] or repeat_last):text(item[i]):up(); + end + end + end + + return t; +end + +local function vcard_to_xep54(vCard) + local t = st.stanza("vCard", { xmlns = "vcard-temp" }); + for i=1,#vCard do + t:add_child(item_to_xep54(vCard[i])); + end + return t; +end + +function to_xep54(vCards) + if not vCards[1] or vCards[1].name then + return vcard_to_xep54(vCards) + else + local t = st.stanza("xCard", { xmlns = "vcard-temp" }); + for i=1,#vCards do + t:add_child(vcard_to_xep54(vCards[i])); + end + return t; + end +end + +function from_text(data) + data = data -- unfold and remove empty lines + :gsub("\r\n","\n") + :gsub("\n ", "") + :gsub("\n\n+","\n"); + local vCards = {}; + local c; -- current item + for line in data:gmatch("[^\n]+") do + local line = vCard_unesc(line); + local name, params, value = line:match("^([-%a]+)(\30?[^\29]*)\29(.*)$"); + value = value:gsub("\29",":"); + if #params > 0 then + local _params = {}; + for k,isval,v in params:gmatch("\30([^=]+)(=?)([^\30]*)") do + k = k:upper(); + local _vt = {}; + for _p in v:gmatch("[^\31]+") do + _vt[#_vt+1]=_p + _vt[_p]=true; + end + if isval == "=" then + _params[k]=_vt; + else + _params[k]=true; + end + end + params = _params; + end + if name == "BEGIN" and value == "VCARD" then + c = {}; + vCards[#vCards+1] = c; + elseif name == "END" and value == "VCARD" then + c = nil; + elseif c and vCard_dtd[name] then + local dtd = vCard_dtd[name]; + local p = { name = name }; + c[#c+1]=p; + --c[name]=p; + local up = c; + c = p; + if dtd.types then + for _, t in ipairs(dtd.types) do + local t = t:lower(); + if ( params.TYPE and params.TYPE[t] == true) + or params[t] == true then + c.TYPE=t; + end + end + end + if dtd.props then + for _, p in ipairs(dtd.props) do + if params[p] then + if params[p] == true then + c[p]=true; + else + for _, prop in ipairs(params[p]) do + c[p]=prop; + end + end + end + end + end + if dtd == "text" or dtd.value then + t_insert(c, value); + elseif dtd.values then + local value = "\30"..value; + for p in value:gmatch("\30([^\30]*)") do + t_insert(c, p); + end + end + c = up; + end + end + return vCards; +end + +local function item_to_text(item) + local value = {}; + for i=1,#item do + value[i] = vCard_esc(item[i]); + end + value = t_concat(value, ";"); + + local params = ""; + for k,v in pairs(item) do + if type(k) == "string" and k ~= "name" then + params = params .. (";%s=%s"):format(k, type(v) == "table" and t_concat(v,",") or v); + end + end + + return ("%s%s:%s"):format(item.name, params, value) +end + +local function vcard_to_text(vcard) + local t={}; + t_insert(t, "BEGIN:VCARD") + for i=1,#vcard do + t_insert(t, item_to_text(vcard[i])); + end + t_insert(t, "END:VCARD") + return t_concat(t, line_sep); +end + +function to_text(vCards) + if vCards[1] and vCards[1].name then + return vcard_to_text(vCards) + else + local t = {}; + for i=1,#vCards do + t[i]=vcard_to_text(vCards[i]); + end + return t_concat(t, line_sep); + end +end + +local function from_xep54_item(item) + local prop_name = item.name; + local prop_def = vCard_dtd[prop_name]; + + local prop = { name = prop_name }; + + if prop_def == "text" then + prop[1] = item:get_text(); + elseif type(prop_def) == "table" then + if prop_def.value then --single item + prop[1] = item:get_child_text(prop_def.value) or ""; + elseif prop_def.values then --array + local value_names = prop_def.values; + if value_names.behaviour == "repeat-last" then + for i=1,#item.tags do + t_insert(prop, item.tags[i]:get_text() or ""); + end + else + for i=1,#value_names do + t_insert(prop, item:get_child_text(value_names[i]) or ""); + end + end + elseif prop_def.names then + local names = prop_def.names; + for i=1,#names do + if item:get_child(names[i]) then + prop[1] = names[i]; + break; + end + end + end + + if prop_def.props_verbatim then + for k,v in pairs(prop_def.props_verbatim) do + prop[k] = v; + end + end + + if prop_def.types then + local types = prop_def.types; + prop.TYPE = {}; + for i=1,#types do + if item:get_child(types[i]) then + t_insert(prop.TYPE, types[i]:lower()); + end + end + if #prop.TYPE == 0 then + prop.TYPE = nil; + end + end + + -- A key-value pair, within a key-value pair? + if prop_def.props then + local params = prop_def.props; + for i=1,#params do + local name = params[i] + local data = item:get_child_text(name); + if data then + prop[name] = prop[name] or {}; + t_insert(prop[name], data); + end + end + end + else + return nil + end + + return prop; +end + +local function from_xep54_vCard(vCard) + local tags = vCard.tags; + local t = {}; + for i=1,#tags do + t_insert(t, from_xep54_item(tags[i])); + end + return t +end + +function from_xep54(vCard) + if vCard.attr.xmlns ~= "vcard-temp" then + return nil, "wrong-xmlns"; + end + if vCard.name == "xCard" then -- A collection of vCards + local t = {}; + local vCards = vCard.tags; + for i=1,#vCards do + t[i] = from_xep54_vCard(vCards[i]); + end + return t + elseif vCard.name == "vCard" then -- A single vCard + return from_xep54_vCard(vCard) + end +end + +local vcard4 = { } + +function vcard4:text(node, params, value) + self:tag(node:lower()) + -- FIXME params + if type(value) == "string" then + self:tag("text"):text(value):up() + elseif vcard4[node] then + vcard4[node](value); + end + self:up(); +end + +function vcard4.N(value) + for i, k in ipairs(vCard_dtd.N.values) do + value:tag(k):text(value[i]):up(); + end +end + +local xmlns_vcard4 = "urn:ietf:params:xml:ns:vcard-4.0" + +local function item_to_vcard4(item) + local typ = item.name:lower(); + local t = st.stanza(typ, { xmlns = xmlns_vcard4 }); + + local prop_def = vCard4_dtd[typ]; + if prop_def == "text" then + t:tag("text"):text(item[1]):up(); + elseif type(prop_def) == "table" then + if prop_def.values then + for i, v in ipairs(prop_def.values) do + t:tag(v:lower()):text(item[i] or ""):up(); + end + else + t:tag("unsupported",{xmlns="http://zash.se/protocol/vcardlib"}) + end + else + t:tag("unsupported",{xmlns="http://zash.se/protocol/vcardlib"}) + end + return t; +end + +local function vcard_to_vcard4xml(vCard) + local t = st.stanza("vcard", { xmlns = xmlns_vcard4 }); + for i=1,#vCard do + t:add_child(item_to_vcard4(vCard[i])); + end + return t; +end + +local function vcards_to_vcard4xml(vCards) + if not vCards[1] or vCards[1].name then + return vcard_to_vcard4xml(vCards) + else + local t = st.stanza("vcards", { xmlns = xmlns_vcard4 }); + for i=1,#vCards do + t:add_child(vcard_to_vcard4xml(vCards[i])); + end + return t; + end +end + +-- This was adapted from http://xmpp.org/extensions/xep-0054.html#dtd +vCard_dtd = { + VERSION = "text", --MUST be 3.0, so parsing is redundant + FN = "text", + N = { + values = { + "FAMILY", + "GIVEN", + "MIDDLE", + "PREFIX", + "SUFFIX", + }, + }, + NICKNAME = "text", + PHOTO = { + props_verbatim = { ENCODING = { "b" } }, + props = { "TYPE" }, + value = "BINVAL", --{ "EXTVAL", }, + }, + BDAY = "text", + ADR = { + types = { + "HOME", + "WORK", + "POSTAL", + "PARCEL", + "DOM", + "INTL", + "PREF", + }, + values = { + "POBOX", + "EXTADD", + "STREET", + "LOCALITY", + "REGION", + "PCODE", + "CTRY", + } + }, + LABEL = { + types = { + "HOME", + "WORK", + "POSTAL", + "PARCEL", + "DOM", + "INTL", + "PREF", + }, + value = "LINE", + }, + TEL = { + types = { + "HOME", + "WORK", + "VOICE", + "FAX", + "PAGER", + "MSG", + "CELL", + "VIDEO", + "BBS", + "MODEM", + "ISDN", + "PCS", + "PREF", + }, + value = "NUMBER", + }, + EMAIL = { + types = { + "HOME", + "WORK", + "INTERNET", + "PREF", + "X400", + }, + value = "USERID", + }, + JABBERID = "text", + MAILER = "text", + TZ = "text", + GEO = { + values = { + "LAT", + "LON", + }, + }, + TITLE = "text", + ROLE = "text", + LOGO = "copy of PHOTO", + AGENT = "text", + ORG = { + values = { + behaviour = "repeat-last", + "ORGNAME", + "ORGUNIT", + } + }, + CATEGORIES = { + values = "KEYWORD", + }, + NOTE = "text", + PRODID = "text", + REV = "text", + SORTSTRING = "text", + SOUND = "copy of PHOTO", + UID = "text", + URL = "text", + CLASS = { + names = { -- The item.name is the value if it's one of these. + "PUBLIC", + "PRIVATE", + "CONFIDENTIAL", + }, + }, + KEY = { + props = { "TYPE" }, + value = "CRED", + }, + DESC = "text", +}; +vCard_dtd.LOGO = vCard_dtd.PHOTO; +vCard_dtd.SOUND = vCard_dtd.PHOTO; + +vCard4_dtd = { + source = "uri", + kind = "text", + xml = "text", + fn = "text", + n = { + values = { + "family", + "given", + "middle", + "prefix", + "suffix", + }, + }, + nickname = "text", + photo = "uri", + bday = "date-and-or-time", + anniversary = "date-and-or-time", + gender = "text", + adr = { + values = { + "pobox", + "ext", + "street", + "locality", + "region", + "code", + "country", + } + }, + tel = "text", + email = "text", + impp = "uri", + lang = "language-tag", + tz = "text", + geo = "uri", + title = "text", + role = "text", + logo = "uri", + org = "text", + member = "uri", + related = "uri", + categories = "text", + note = "text", + prodid = "text", + rev = "timestamp", + sound = "uri", + uid = "uri", + clientpidmap = "number, uuid", + url = "uri", + version = "text", + key = "uri", + fburl = "uri", + caladruri = "uri", + caluri = "uri", +}; + +return { + from_text = from_text; + to_text = to_text; + + from_xep54 = from_xep54; + to_xep54 = to_xep54; + + to_vcard4 = vcards_to_vcard4xml; +}; diff --git a/util/x509.lua b/util/x509.lua index 857f02a4..5e1b49e5 100644 --- a/util/x509.lua +++ b/util/x509.lua @@ -20,11 +20,9 @@ local nameprep = require "util.encodings".stringprep.nameprep; local idna_to_ascii = require "util.encodings".idna.to_ascii; +local base64 = require "util.encodings".base64; local log = require "util.logger".init("x509"); -local pairs, ipairs = pairs, ipairs; local s_format = string.format; -local t_insert = table.insert; -local t_concat = table.concat; module "x509" @@ -214,4 +212,23 @@ function verify_identity(host, service, cert) return false end +local pat = "%-%-%-%-%-BEGIN ([A-Z ]+)%-%-%-%-%-\r?\n".. +"([0-9A-Za-z+/=\r\n]*)\r?\n%-%-%-%-%-END %1%-%-%-%-%-"; + +function pem2der(pem) + local typ, data = pem:match(pat); + if typ and data then + return base64.decode(data), typ; + end +end + +local wrap = ('.'):rep(64); +local envelope = "-----BEGIN %s-----\n%s\n-----END %s-----\n" + +function der2pem(data, typ) + typ = typ and typ:upper() or "CERTIFICATE"; + data = base64.encode(data); + return s_format(envelope, typ, data:gsub(wrap, '%0\n', (#data-1)/64), typ); +end + return _M; diff --git a/util/xmppstream.lua b/util/xmppstream.lua index 550170c9..6982aae3 100644 --- a/util/xmppstream.lua +++ b/util/xmppstream.lua @@ -6,7 +6,6 @@ -- COPYING file in the source package for more information. -- - local lxp = require "lxp"; local st = require "util.stanza"; local stanza_mt = st.stanza_mt; @@ -20,6 +19,10 @@ local setmetatable = setmetatable; -- COMPAT: w/LuaExpat 1.1.0 local lxp_supports_doctype = pcall(lxp.new, { StartDoctypeDecl = false }); +local lxp_supports_xmldecl = pcall(lxp.new, { XmlDecl = false }); +local lxp_supports_bytecount = not not lxp.new({}).getcurrentbytecount; + +local default_stanza_size_limit = 1024*1024*10; -- 10MB module "xmppstream" @@ -40,13 +43,16 @@ local ns_pattern = "^([^"..ns_separator.."]*)"..ns_separator.."?(.*)$"; _M.ns_separator = ns_separator; _M.ns_pattern = ns_pattern; -function new_sax_handlers(session, stream_callbacks) +local function dummy_cb() end + +function new_sax_handlers(session, stream_callbacks, cb_handleprogress) local xml_handlers = {}; local cb_streamopened = stream_callbacks.streamopened; local cb_streamclosed = stream_callbacks.streamclosed; local cb_error = stream_callbacks.error or function(session, e, stanza) error("XML stream error: "..tostring(e)..(stanza and ": "..tostring(stanza) or ""),2); end; local cb_handlestanza = stream_callbacks.handlestanza; + cb_handleprogress = cb_handleprogress or dummy_cb; local stream_ns = stream_callbacks.stream_ns or xmlns_streams; local stream_tag = stream_callbacks.stream_tag or "stream"; @@ -59,6 +65,7 @@ function new_sax_handlers(session, stream_callbacks) local stack = {}; local chardata, stanza = {}; + local stanza_size = 0; local non_streamns_depth = 0; function xml_handlers:StartElement(tagname, attr) if stanza and #chardata > 0 then @@ -87,10 +94,17 @@ function new_sax_handlers(session, stream_callbacks) end if not stanza then --if we are not currently inside a stanza + if lxp_supports_bytecount then + stanza_size = self:getcurrentbytecount(); + end if session.notopen then if tagname == stream_tag then non_streamns_depth = 0; if cb_streamopened then + if lxp_supports_bytecount then + cb_handleprogress(stanza_size); + stanza_size = 0; + end cb_streamopened(session, attr); end else @@ -105,6 +119,9 @@ function new_sax_handlers(session, stream_callbacks) stanza = setmetatable({ name = name, attr = attr, tags = {} }, stanza_mt); else -- we are inside a stanza, so add a tag + if lxp_supports_bytecount then + stanza_size = stanza_size + self:getcurrentbytecount(); + end t_insert(stack, stanza); local oldstanza = stanza; stanza = setmetatable({ name = name, attr = attr, tags = {} }, stanza_mt); @@ -112,12 +129,45 @@ function new_sax_handlers(session, stream_callbacks) t_insert(oldstanza.tags, stanza); end end + if lxp_supports_xmldecl then + function xml_handlers:XmlDecl(version, encoding, standalone) + if lxp_supports_bytecount then + cb_handleprogress(self:getcurrentbytecount()); + end + end + end + function xml_handlers:StartCdataSection() + if lxp_supports_bytecount then + if stanza then + stanza_size = stanza_size + self:getcurrentbytecount(); + else + cb_handleprogress(self:getcurrentbytecount()); + end + end + end + function xml_handlers:EndCdataSection() + if lxp_supports_bytecount then + if stanza then + stanza_size = stanza_size + self:getcurrentbytecount(); + else + cb_handleprogress(self:getcurrentbytecount()); + end + end + end function xml_handlers:CharacterData(data) if stanza then + if lxp_supports_bytecount then + stanza_size = stanza_size + self:getcurrentbytecount(); + end t_insert(chardata, data); + elseif lxp_supports_bytecount then + cb_handleprogress(self:getcurrentbytecount()); end end function xml_handlers:EndElement(tagname) + if lxp_supports_bytecount then + stanza_size = stanza_size + self:getcurrentbytecount() + end if non_streamns_depth > 0 then non_streamns_depth = non_streamns_depth - 1; end @@ -129,6 +179,10 @@ function new_sax_handlers(session, stream_callbacks) end -- Complete stanza if #stack == 0 then + if lxp_supports_bytecount then + cb_handleprogress(stanza_size); + end + stanza_size = 0; if tagname ~= stream_error_tag then cb_handlestanza(session, stanza); else @@ -159,7 +213,7 @@ function new_sax_handlers(session, stream_callbacks) xml_handlers.ProcessingInstruction = restricted_handler; local function reset() - stanza, chardata = nil, {}; + stanza, chardata, stanza_size = nil, {}, 0; stack = {}; end @@ -170,19 +224,58 @@ function new_sax_handlers(session, stream_callbacks) return xml_handlers, { reset = reset, set_session = set_session }; end -function new(session, stream_callbacks) - local handlers, meta = new_sax_handlers(session, stream_callbacks); - local parser = new_parser(handlers, ns_separator); +function new(session, stream_callbacks, stanza_size_limit) + -- Used to track parser progress (e.g. to enforce size limits) + local n_outstanding_bytes = 0; + local handle_progress; + if lxp_supports_bytecount then + function handle_progress(n_parsed_bytes) + n_outstanding_bytes = n_outstanding_bytes - n_parsed_bytes; + end + stanza_size_limit = stanza_size_limit or default_stanza_size_limit; + elseif stanza_size_limit then + error("Stanza size limits are not supported on this version of LuaExpat") + end + + local handlers, meta = new_sax_handlers(session, stream_callbacks, handle_progress); + local parser = new_parser(handlers, ns_separator, false); local parse = parser.parse; + function session.open_stream(session, from, to) + local send = session.sends2s or session.send; + + local attr = { + ["xmlns:stream"] = "http://etherx.jabber.org/streams", + ["xml:lang"] = "en", + xmlns = stream_callbacks.default_ns, + version = session.version and (session.version > 0 and "1.0" or nil), + id = session.streamid or "", + from = from or session.host, to = to, + }; + if session.stream_attrs then + session:stream_attrs(from, to, attr) + end + send("<?xml version='1.0'?>"); + send(st.stanza("stream:stream", attr):top_tag()); + return true; + end + return { reset = function () - parser = new_parser(handlers, ns_separator); + parser = new_parser(handlers, ns_separator, false); parse = parser.parse; + n_outstanding_bytes = 0; meta.reset(); end, feed = function (self, data) - return parse(parser, data); + if lxp_supports_bytecount then + n_outstanding_bytes = n_outstanding_bytes + #data; + end + local ok, err = parse(parser, data); + if lxp_supports_bytecount and n_outstanding_bytes > stanza_size_limit then + return nil, "stanza-too-large"; + end + return ok, err; end, set_session = meta.set_session; }; |