diff options
-rw-r--r-- | Makefile | 2 | ||||
-rw-r--r-- | core/moduleapi.lua | 32 | ||||
-rw-r--r-- | core/modulemanager.lua | 2 | ||||
-rw-r--r-- | core/portmanager.lua | 4 | ||||
-rw-r--r-- | net/http.lua | 17 | ||||
-rw-r--r-- | net/server_event.lua | 77 | ||||
-rw-r--r-- | net/server_select.lua | 56 | ||||
-rw-r--r-- | plugins/mod_auth_anonymous.lua | 2 | ||||
-rw-r--r-- | plugins/mod_auth_internal_hashed.lua | 12 | ||||
-rw-r--r-- | plugins/mod_compression.lua | 6 | ||||
-rw-r--r-- | plugins/mod_http.lua | 5 | ||||
-rw-r--r-- | plugins/mod_http_files.lua | 3 | ||||
-rw-r--r-- | plugins/mod_pep_plus.lua | 368 | ||||
-rw-r--r-- | plugins/mod_proxy65.lua | 19 | ||||
-rw-r--r-- | plugins/mod_pubsub/mod_pubsub.lua | 2 | ||||
-rw-r--r-- | plugins/mod_pubsub/pubsub.lib.lua | 4 | ||||
-rw-r--r-- | plugins/mod_saslauth.lua | 2 | ||||
-rw-r--r-- | plugins/muc/muc.lib.lua | 2 | ||||
-rw-r--r-- | prosody.cfg.lua.dist | 2 | ||||
-rwxr-xr-x | prosodyctl | 7 | ||||
-rwxr-xr-x | tools/ejabberd2prosody.lua | 6 | ||||
-rw-r--r-- | util/dependencies.lua | 10 | ||||
-rw-r--r-- | util/indexedbheap.lua | 153 | ||||
-rw-r--r-- | util/pluginloader.lua | 8 | ||||
-rw-r--r-- | util/sasl.lua | 14 | ||||
-rw-r--r-- | util/timer.lua | 62 | ||||
-rw-r--r-- | util/xmppstream.lua | 90 |
27 files changed, 833 insertions, 134 deletions
@@ -55,7 +55,7 @@ util/%.so: $(MAKE) install -C util-src %.install: % - sed "1s/\blua\b/$(RUNWITH)/; \ + sed "1s| lua$$| $(RUNWITH)|; \ s|^CFG_SOURCEDIR=.*;$$|CFG_SOURCEDIR='$(INSTALLEDSOURCE)';|; \ s|^CFG_CONFIGDIR=.*;$$|CFG_CONFIGDIR='$(INSTALLEDCONFIG)';|; \ s|^CFG_DATADIR=.*;$$|CFG_DATADIR='$(INSTALLEDDATA)';|; \ diff --git a/core/moduleapi.lua b/core/moduleapi.lua index 65e00d41..5a24f69c 100644 --- a/core/moduleapi.lua +++ b/core/moduleapi.lua @@ -16,8 +16,10 @@ local timer = require "util.timer"; local t_insert, t_remove, t_concat = table.insert, table.remove, table.concat; local error, setmetatable, type = error, setmetatable, type; -local ipairs, pairs, select, unpack = ipairs, pairs, select, unpack; +local ipairs, pairs, select = ipairs, pairs, select; local tonumber, tostring = tonumber, tostring; +local pack = table.pack or function(...) return {n=select("#",...), ...}; end -- table.pack is only in 5.2 +local unpack = table.unpack or unpack; -- renamed in 5.2 local prosody = prosody; local hosts = prosody.hosts; @@ -347,11 +349,29 @@ function api:send(stanza) return core_post_stanza(hosts[self.host], stanza); end -function api:add_timer(delay, callback) - return timer.add_task(delay, function (t) - if self.loaded == false then return; end - return callback(t); - end); +local timer_methods = { } +local timer_mt = { + __index = timer_methods; +} +function timer_methods:stop( ) + timer.stop(self.id); +end +timer_methods.disarm = timer_methods.stop +function timer_methods:reschedule(delay) + timer.reschedule(self.id, delay) +end + +local function timer_callback(now, id, t) + if t.module_env.loaded == false then return; end + return t.callback(now, unpack(t, 1, t.n)); +end + +function api:add_timer(delay, callback, ...) + local t = pack(...) + t.module_env = self; + t.callback = callback; + t.id = timer.add_task(delay, timer_callback, t); + return setmetatable(t, timer_mt); end local path_sep = package.config:sub(1,1); diff --git a/core/modulemanager.lua b/core/modulemanager.lua index 2e488fd5..eb1ce733 100644 --- a/core/modulemanager.lua +++ b/core/modulemanager.lua @@ -30,7 +30,7 @@ pcall = function(f, ...) end local autoload_modules = {prosody.platform, "presence", "message", "iq", "offline", "c2s", "s2s"}; -local component_inheritable_modules = {"tls", "dialback", "iq", "s2s"}; +local component_inheritable_modules = {"tls", "saslauth", "dialback", "iq", "s2s"}; -- We need this to let modules access the real global namespace local _G = _G; diff --git a/core/portmanager.lua b/core/portmanager.lua index 95900c08..4cbf3eb3 100644 --- a/core/portmanager.lua +++ b/core/portmanager.lua @@ -29,6 +29,8 @@ if socket.tcp6 and config.get("*", "use_ipv6") ~= false then table.insert(default_local_interfaces, "::1"); end +local default_mode = config.get("*", "network_default_read_size") or 4096; + --- Private state -- service_name -> { service_info, ... } @@ -111,7 +113,7 @@ function activate(service_name) } bind_ports = set.new(type(bind_ports) ~= "table" and { bind_ports } or bind_ports ); - local mode, ssl = listener.default_mode or "*a"; + local mode, ssl = listener.default_mode or default_mode; local hooked_ports = {}; for interface in bind_interfaces do diff --git a/net/http.lua b/net/http.lua index ab9ec7b6..b87c9396 100644 --- a/net/http.lua +++ b/net/http.lua @@ -6,7 +6,6 @@ -- COPYING file in the source package for more information. -- -local socket = require "socket" local b64 = require "util.encodings".base64.encode; local url = require "socket.url" local httpstream_new = require "net.http.parser".new; @@ -160,21 +159,17 @@ function request(u, ex, callback) end local port_number = port and tonumber(port) or (using_https and 443 or 80); - -- Connect the socket, and wrap it with net.server - local conn = socket.tcp(); - conn:settimeout(10); - local ok, err = conn:connect(host, port_number); - if not ok and err ~= "timeout" then - callback(nil, 0, req); - return nil, err; - end - local sslctx = false; if using_https then sslctx = ex and ex.sslctx or { mode = "client", protocol = "sslv23", options = { "no_sslv2" } }; end - req.handler, req.conn = assert(server.wrapclient(conn, host, port_number, listener, "*a", sslctx)); + local handler, conn = server.addclient(host, port_number, listener, "*a", sslctx) + if not handler then + callback(nil, 0, req); + return nil, conn; + end + req.handler, req.conn = handler, conn req.write = function (...) return req.handler:write(...); end req.callback = function (content, code, request, response) log("debug", "Calling callback, status %s", code or "---"); return select(2, xpcall(function () return callback(content, code, request, response) end, handleerr)); end diff --git a/net/server_event.lua b/net/server_event.lua index 59217a0c..a3087847 100644 --- a/net/server_event.lua +++ b/net/server_event.lua @@ -44,8 +44,9 @@ local setmetatable = use "setmetatable" local t_insert = table.insert local t_concat = table.concat -local ssl = use "ssl" +local has_luasec, ssl = pcall ( require , "ssl" ) local socket = use "socket" or require "socket" +local getaddrinfo = socket.dns.getaddrinfo local log = require ("util.logger").init("socket") @@ -128,7 +129,7 @@ do return self:_destroy(); end - function interface_mt:_start_connection(plainssl) -- should be called from addclient + function interface_mt:_start_connection(plainssl) -- called from wrapclient local callback = function( event ) if EV_TIMEOUT == event then -- timeout during connection self.fatalerror = "connection timeout" @@ -136,7 +137,7 @@ do self:_close() debug( "new connection failed. id:", self.id, "error:", self.fatalerror ) else - if plainssl and ssl then -- start ssl session + if plainssl and has_luasec then -- start ssl session self:starttls(self._sslctx, true) else -- normal connection self:_start_session(true) @@ -367,6 +368,7 @@ do function interface_mt:ssl() return self._usingssl end + interface_mt.clientport = interface_mt.port -- COMPAT server_select function interface_mt:type() return self._type or "client" @@ -506,7 +508,7 @@ do _sslctx = sslctx; -- parameters _usingssl = false; -- client is using ssl; } - if not ssl then interface.starttls = false; end + if not has_luasec then interface.starttls = false; end interface.id = tostring(interface):match("%x+$"); interface.writecallback = function( event ) -- called on write events --vdebug( "new client write event, id/ip/port:", interface, ip, port ) @@ -689,7 +691,7 @@ do interface._connections = interface._connections + 1 -- increase connection count local clientinterface = handleclient( client, client_ip, client_port, interface, pattern, listener, sslctx ) --vdebug( "client id:", clientinterface, "startssl:", startssl ) - if ssl and sslctx then + if has_luasec and sslctx then clientinterface:starttls(sslctx, true) else clientinterface:_start_session( true ) @@ -710,25 +712,17 @@ do end local addserver = ( function( ) - return function( addr, port, listener, pattern, sslcfg, startssl ) -- TODO: check arguments - --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslcfg or "nil", startssl or "nil") + return function( addr, port, listener, pattern, sslctx, startssl ) -- TODO: check arguments + --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslctx or "nil", startssl or "nil") + if sslctx and not has_luasec then + debug "fatal error: luasec not found" + return nil, "luasec not found" + end local server, err = socket.bind( addr, port, cfg.ACCEPT_QUEUE ) -- create server socket if not server then debug( "creating server socket on "..addr.." port "..port.." failed:", err ) return nil, err end - local sslctx - if sslcfg then - if not ssl then - debug "fatal error: luasec not found" - return nil, "luasec not found" - end - sslctx, err = sslcfg - if err then - debug( "error while creating new ssl context for server socket:", err ) - return nil, err - end - end local interface = handleserver( server, addr, port, pattern, listener, sslctx, startssl ) -- new server handler debug( "new server created with id:", tostring(interface)) return interface @@ -744,37 +738,34 @@ do --function handleclient( client, ip, port, server, pattern, listener, _, sslctx ) -- creates an client interface end - function addclient( addr, serverport, listener, pattern, localaddr, localport, sslcfg, startssl ) - local client, err = socket.tcp() -- creating new socket + function addclient( addr, serverport, listener, pattern, sslctx, typ ) + if sslctx and not has_luasec then + debug "need luasec, but not available" + return nil, "luasec not found" + end + if getaddrinfo and not typ then + local addrinfo, err = getaddrinfo(addr) + if not addrinfo then return nil, err end + if addrinfo[1] and addrinfo[1].family == "inet6" then + typ = "tcp6" + end + end + local create = socket[typ or "tcp"] + if type( create ) ~= "function" then + return nil, "invalid socket type" + end + local client, err = create() -- creating new socket if not client then debug( "cannot create socket:", err ) return nil, err end client:settimeout( 0 ) -- set nonblocking - if localaddr then - local res, err = client:bind( localaddr, localport, -1 ) - if not res then - debug( "cannot bind client:", err ) - return nil, err - end - end - local sslctx - if sslcfg then -- handle ssl/new context - if not ssl then - debug "need luasec, but not available" - return nil, "luasec not found" - end - sslctx, err = sslcfg - if err then - debug( "cannot create new ssl context:", err ) - return nil, err - end - end local res, err = client:connect( addr, serverport ) -- connect if res or ( err == "timeout" ) then - local ip, port = client:getsockname( ) - local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx, startssl ) - interface:_start_connection( startssl ) + if client.getsockname then + addr = client:getsockname( ) + end + local interface = wrapclient( client, addr, serverport, listener, pattern, sslctx ) debug( "new connection id:", interface.id ) return interface, err else diff --git a/net/server_select.lua b/net/server_select.lua index c5e0772f..4a36617c 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -48,13 +48,14 @@ local coroutine_yield = coroutine.yield --// extern libs //-- -local luasec = use "ssl" +local has_luasec, luasec = pcall ( require , "ssl" ) local luasocket = use "socket" or require "socket" local luasocket_gettime = luasocket.gettime +local getaddrinfo = luasocket.dns.getaddrinfo --// extern lib methods //-- -local ssl_wrap = ( luasec and luasec.wrap ) +local ssl_wrap = ( has_luasec and luasec.wrap ) local socket_bind = luasocket.bind local socket_sleep = luasocket.sleep local socket_select = luasocket.select @@ -401,6 +402,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport handler.clientport = function( ) return clientport end + handler.port = handler.clientport -- COMPAT server_event local write = function( self, data ) bufferlen = bufferlen + #data if bufferlen > maxsendlen then @@ -585,7 +587,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport end ) end - if luasec then + if has_luasec then handler.starttls = function( self, _sslctx) if _sslctx then handler:set_sslctx(_sslctx); @@ -638,7 +640,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport _socketlist[ socket ] = handler _readlistlen = addsocket(_readlist, socket, _readlistlen) - if sslctx and luasec then + if sslctx and has_luasec then out_put "server.lua: auto-starting ssl negotiation..." handler.autostart_ssl = true; local ok, err = handler:starttls(sslctx); @@ -713,22 +715,23 @@ end ----------------------------------// PUBLIC //-- addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server + addr = addr or "*" local err if type( listeners ) ~= "table" then err = "invalid listener table" - end - if type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then + elseif type ( addr ) ~= "string" then + err = "invalid address" + elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then err = "invalid port" elseif _server[ addr..":"..port ] then err = "listeners on '[" .. addr .. "]:" .. port .. "' already exist" - elseif sslctx and not luasec then + elseif sslctx and not has_luasec then err = "luasec not found" end if err then out_error( "server.lua, [", addr, "]:", port, ": ", err ) return nil, err end - addr = addr or "*" local server, err = socket_bind( addr, port, _tcpbacklog ) if err then out_error( "server.lua, [", addr, "]:", port, ": ", err ) @@ -929,17 +932,44 @@ local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx return handler, socket end -local addclient = function( address, port, listeners, pattern, sslctx ) - local client, err = luasocket.tcp( ) +local addclient = function( address, port, listeners, pattern, sslctx, typ ) + local err + if type( listeners ) ~= "table" then + err = "invalid listener table" + elseif type ( address ) ~= "string" then + err = "invalid address" + elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then + err = "invalid port" + elseif sslctx and not has_luasec then + err = "luasec not found" + end + if getaddrinfo and not typ then + local addrinfo, err = getaddrinfo(address) + if not addrinfo then return nil, err end + if addrinfo[1] and addrinfo[1].family == "inet6" then + typ = "tcp6" + end + end + local create = luasocket[typ or "tcp"] + if type( create ) ~= "function" then + err = "invalid socket type" + end + + if err then + out_error( "server.lua, addclient: ", err ) + return nil, err + end + + local client, err = create( ) if err then return nil, err end client:settimeout( 0 ) - _, err = client:connect( address, port ) - if err then -- try again + local ok, err = client:connect( address, port ) + if ok or err == "timeout" then return wrapclient( client, address, port, listeners, pattern, sslctx ) else - return wrapconnection( nil, listeners, client, address, port, "clientport", pattern, sslctx ) + return nil, err end end diff --git a/plugins/mod_auth_anonymous.lua b/plugins/mod_auth_anonymous.lua index c877d532..8de46f8c 100644 --- a/plugins/mod_auth_anonymous.lua +++ b/plugins/mod_auth_anonymous.lua @@ -43,7 +43,7 @@ function provider.get_sasl_handler() end function provider.users() - return next, hosts[host].sessions, nil; + return next, hosts[module.host].sessions, nil; end -- datamanager callback to disable writes diff --git a/plugins/mod_auth_internal_hashed.lua b/plugins/mod_auth_internal_hashed.lua index fb87bb9f..954392c9 100644 --- a/plugins/mod_auth_internal_hashed.lua +++ b/plugins/mod_auth_internal_hashed.lua @@ -7,6 +7,8 @@ -- COPYING file in the source package for more information. -- +local max = math.max; + local getAuthenticationDatabaseSHA1 = require "util.sasl.scram".getAuthenticationDatabaseSHA1; local usermanager = require "core.usermanager"; local generate_uuid = require "util.uuid".generate; @@ -39,7 +41,7 @@ end -- Default; can be set per-user -local iteration_count = 4096; +local default_iteration_count = 4096; -- define auth provider local provider = {}; @@ -80,8 +82,8 @@ function provider.set_password(username, password) log("debug", "set_password for username '%s'", username); local account = accounts:get(username); if account then - account.salt = account.salt or generate_uuid(); - account.iteration_count = account.iteration_count or iteration_count; + account.salt = generate_uuid(); + account.iteration_count = max(account.iteration_count or 0, default_iteration_count); local valid, stored_key, server_key = getAuthenticationDatabaseSHA1(password, account.salt, account.iteration_count); local stored_key_hex = to_hex(stored_key); local server_key_hex = to_hex(server_key); @@ -113,10 +115,10 @@ function provider.create_user(username, password) return accounts:set(username, {}); end local salt = generate_uuid(); - local valid, stored_key, server_key = getAuthenticationDatabaseSHA1(password, salt, iteration_count); + local valid, stored_key, server_key = getAuthenticationDatabaseSHA1(password, salt, default_iteration_count); local stored_key_hex = to_hex(stored_key); local server_key_hex = to_hex(server_key); - return accounts:set(username, {stored_key = stored_key_hex, server_key = server_key_hex, salt = salt, iteration_count = iteration_count}); + return accounts:set(username, {stored_key = stored_key_hex, server_key = server_key_hex, salt = salt, iteration_count = default_iteration_count}); end function provider.delete_user(username) diff --git a/plugins/mod_compression.lua b/plugins/mod_compression.lua index f44e8a6d..9da5254e 100644 --- a/plugins/mod_compression.lua +++ b/plugins/mod_compression.lua @@ -125,8 +125,8 @@ end module:hook("stanza/http://jabber.org/protocol/compress:compressed", function(event) local session = event.origin; - - if session.type == "s2sout_unauthed" or session.type == "s2sout" then + + if session.type == "s2sout" then session.log("debug", "Activating compression...") -- create deflate and inflate streams local deflate_stream = get_deflate_stream(session); @@ -150,7 +150,7 @@ end); module:hook("stanza/http://jabber.org/protocol/compress:compress", function(event) local session, stanza = event.origin, event.stanza; - if session.type == "c2s" or session.type == "s2sin" or session.type == "c2s_unauthed" or session.type == "s2sin_unauthed" then + if session.type == "c2s" or session.type == "s2sin" then -- fail if we are already compressed if session.compressed then local error_st = st.stanza("failure", {xmlns=xmlns_compression_protocol}):tag("setup-failed"); diff --git a/plugins/mod_http.lua b/plugins/mod_http.lua index 95933da5..d987ef74 100644 --- a/plugins/mod_http.lua +++ b/plugins/mod_http.lua @@ -42,7 +42,7 @@ local function get_base_path(host_module, app_name, default_app_path) return (normalize_path(host_module:get_option("http_paths", {})[app_name] -- Host or module:get_option("http_paths", {})[app_name] -- Global or default_app_path)) -- Default - :gsub("%$(%w+)", { host = module.host }); + :gsub("%$(%w+)", { host = host_module.host }); end local ports_by_scheme = { http = 80, https = 443, }; @@ -51,6 +51,9 @@ local ports_by_scheme = { http = 80, https = 443, }; function moduleapi.http_url(module, app_name, default_path) app_name = app_name or (module.name:gsub("^http_", "")); local external_url = url_parse(module:get_option_string("http_external_url")) or {}; + if external_url.scheme and external_url.port == nil then + external_url.port = ports_by_scheme[external_url.scheme]; + end local services = portmanager.get_active_services(); local http_services = services:get("https") or services:get("http") or {}; for interface, ports in pairs(http_services) do diff --git a/plugins/mod_http_files.lua b/plugins/mod_http_files.lua index dd04853b..2e9f4182 100644 --- a/plugins/mod_http_files.lua +++ b/plugins/mod_http_files.lua @@ -14,6 +14,7 @@ local os_date = os.date; local open = io.open; local stat = lfs.attributes; local build_path = require"socket.url".build_path; +local path_sep = package.config:sub(1,1); local base_path = module:get_option_string("http_files_dir", module:get_option_string("http_path")); local dir_indices = module:get_option("http_index_files", { "index.html", "index.htm" }); @@ -61,7 +62,7 @@ function serve(opts) local request, response = event.request, event.response; local orig_path = request.path; local full_path = base_path .. (path and "/"..path or ""); - local attr = stat(full_path); + local attr = stat((full_path:gsub('%'..path_sep..'+$',''))); if not attr then return 404; end diff --git a/plugins/mod_pep_plus.lua b/plugins/mod_pep_plus.lua new file mode 100644 index 00000000..4a74e437 --- /dev/null +++ b/plugins/mod_pep_plus.lua @@ -0,0 +1,368 @@ +local pubsub = require "util.pubsub"; +local jid_bare = require "util.jid".bare; +local jid_split = require "util.jid".split; +local set_new = require "util.set".new; +local st = require "util.stanza"; +local calculate_hash = require "util.caps".calculate_hash; +local is_contact_subscribed = require "core.rostermanager".is_contact_subscribed; + +local xmlns_pubsub = "http://jabber.org/protocol/pubsub"; +local xmlns_pubsub_event = "http://jabber.org/protocol/pubsub#event"; +local xmlns_pubsub_owner = "http://jabber.org/protocol/pubsub#owner"; + +local lib_pubsub = module:require "pubsub"; +local handlers = lib_pubsub.handlers; +local pubsub_error_reply = lib_pubsub.pubsub_error_reply; + +local services = {}; +local recipients = {}; +local hash_map = {}; + +function module.save() + return { services = services }; +end + +function module.restore(data) + services = data.services; +end + +local function subscription_presence(user_bare, recipient) + local recipient_bare = jid_bare(recipient); + if (recipient_bare == user_bare) then return true; end + local username, host = jid_split(user_bare); + return is_contact_subscribed(username, host, recipient_bare); +end + +local function get_broadcaster(name) + local function simple_broadcast(kind, node, jids, item) + if item then + item = st.clone(item); + item.attr.xmlns = nil; -- Clear the pubsub namespace + end + local message = st.message({ from = name, type = "headline" }) + :tag("event", { xmlns = xmlns_pubsub_event }) + :tag(kind, { node = node }) + :add_child(item); + for jid in pairs(jids) do + module:log("debug", "Sending notification to %s from %s: %s", jid, name, tostring(item)); + message.attr.to = jid; + module:send(message); + end + end + return simple_broadcast; +end + +local function get_pep_service(name) + if services[name] then + return services[name]; + end + services[name] = pubsub.new({ + capabilities = { + none = { + create = false; + publish = false; + retract = false; + get_nodes = false; + + subscribe = false; + unsubscribe = false; + get_subscription = false; + get_subscriptions = false; + get_items = false; + + subscribe_other = false; + unsubscribe_other = false; + get_subscription_other = false; + get_subscriptions_other = false; + + be_subscribed = true; + be_unsubscribed = true; + + set_affiliation = false; + }; + subscriber = { + create = false; + publish = false; + retract = false; + get_nodes = true; + + subscribe = true; + unsubscribe = true; + get_subscription = true; + get_subscriptions = true; + get_items = true; + + subscribe_other = false; + unsubscribe_other = false; + get_subscription_other = false; + get_subscriptions_other = false; + + be_subscribed = true; + be_unsubscribed = true; + + set_affiliation = false; + }; + publisher = { + create = false; + publish = true; + retract = true; + get_nodes = true; + + subscribe = true; + unsubscribe = true; + get_subscription = true; + get_subscriptions = true; + get_items = true; + + subscribe_other = false; + unsubscribe_other = false; + get_subscription_other = false; + get_subscriptions_other = false; + + be_subscribed = true; + be_unsubscribed = true; + + set_affiliation = false; + }; + owner = { + create = true; + publish = true; + retract = true; + delete = true; + get_nodes = true; + + subscribe = true; + unsubscribe = true; + get_subscription = true; + get_subscriptions = true; + get_items = true; + + + subscribe_other = true; + unsubscribe_other = true; + get_subscription_other = true; + get_subscriptions_other = true; + + be_subscribed = true; + be_unsubscribed = true; + + set_affiliation = true; + }; + }; + + autocreate_on_publish = true; + autocreate_on_subscribe = true; + + broadcaster = get_broadcaster(name); + get_affiliation = function (jid) + if jid_bare(jid) == name then + return "owner"; + elseif subscription_presence(name, jid) then + return "subscriber"; + end + end; + + normalize_jid = jid_bare; + }); + return services[name]; +end + +function handle_pubsub_iq(event) + local origin, stanza = event.origin, event.stanza; + local pubsub = stanza.tags[1]; + local action = pubsub.tags[1]; + if not action then + return origin.send(st.error_reply(stanza, "cancel", "bad-request")); + end + local service_name = stanza.attr.to or origin.username.."@"..origin.host + local service = get_pep_service(service_name); + local handler = handlers[stanza.attr.type.."_"..action.name]; + if handler then + handler(origin, stanza, action, service); + return true; + end +end + +module:hook("iq/bare/"..xmlns_pubsub..":pubsub", handle_pubsub_iq); +module:hook("iq/bare/"..xmlns_pubsub_owner..":pubsub", handle_pubsub_iq); + +module:add_identity("pubsub", "pep", module:get_option_string("name", "Prosody")); +module:add_feature("http://jabber.org/protocol/pubsub#publish"); + +local function get_caps_hash_from_presence(stanza, current) + local t = stanza.attr.type; + if not t then + local child = stanza:get_child("c", "http://jabber.org/protocol/caps"); + if child then + local attr = child.attr; + if attr.hash then -- new caps + if attr.hash == 'sha-1' and attr.node and attr.ver then + return attr.ver, attr.node.."#"..attr.ver; + end + else -- legacy caps + if attr.node and attr.ver then + return attr.node.."#"..attr.ver.."#"..(attr.ext or ""), attr.node.."#"..attr.ver; + end + end + end + return; -- no or bad caps + elseif t == "unavailable" or t == "error" then + return; + end + return current; -- no caps, could mean caps optimization, so return current +end + +local function resend_last_item(jid, node, service) + local ok, items = service:get_items(node, jid); + if not ok then return; end + for i, id in ipairs(items) do + service.config.broadcaster("items", node, { [jid] = true }, items[id]); + end +end + +local function update_subscriptions(recipient, service_name, nodes) + local service = get_pep_service(service_name); + + recipients[service_name] = recipients[service_name] or {}; + nodes = nodes or set_new(); + local old = recipients[service_name][recipient]; + + if old and type(old) == table then + for node in pairs((old - nodes):items()) do + service:remove_subscription(node, recipient, recipient); + end + end + + for node in nodes:items() do + service:add_subscription(node, recipient, recipient); + resend_last_item(recipient, node, service); + end + recipients[service_name][recipient] = nodes; +end + +module:hook("presence/bare", function(event) + -- inbound presence to bare JID recieved + local origin, stanza = event.origin, event.stanza; + local user = stanza.attr.to or (origin.username..'@'..origin.host); + local t = stanza.attr.type; + local self = not stanza.attr.to; + local service = get_pep_service(user); + + if not t then -- available presence + if self or subscription_presence(user, stanza.attr.from) then + local recipient = stanza.attr.from; + local current = recipients[user] and recipients[user][recipient]; + local hash, query_node = get_caps_hash_from_presence(stanza, current); + if current == hash or (current and current == hash_map[hash]) then return; end + if not hash then + update_subscriptions(recipient, user); + else + recipients[user] = recipients[user] or {}; + if hash_map[hash] then + update_subscriptions(recipient, user, hash_map[hash]); + else + recipients[user][recipient] = hash; + local from_bare = origin.type == "c2s" and origin.username.."@"..origin.host; + if self or origin.type ~= "c2s" or (recipients[from_bare] and recipients[from_bare][origin.full_jid]) ~= hash then + -- COMPAT from ~= stanza.attr.to because OneTeam can't deal with missing from attribute + origin.send( + st.stanza("iq", {from=user, to=stanza.attr.from, id="disco", type="get"}) + :tag("query", {xmlns = "http://jabber.org/protocol/disco#info", node = query_node}) + ); + end + end + end + end + elseif t == "unavailable" then + update_subscriptions(stanza.attr.from, user); + elseif not self and t == "unsubscribe" then + local from = jid_bare(stanza.attr.from); + local subscriptions = recipients[user]; + if subscriptions then + for subscriber in pairs(subscriptions) do + if jid_bare(subscriber) == from then + update_subscriptions(subscriber, user); + end + end + end + end +end, 10); + +module:hook("iq-result/bare/disco", function(event) + local origin, stanza = event.origin, event.stanza; + local disco = stanza:get_child("query", "http://jabber.org/protocol/disco#info"); + if not disco then + return; + end + + -- Process disco response + local self = not stanza.attr.to; + local user = stanza.attr.to or (origin.username..'@'..origin.host); + local contact = stanza.attr.from; + local current = recipients[user] and recipients[user][contact]; + if type(current) ~= "string" then return; end -- check if waiting for recipient's response + local ver = current; + if not string.find(current, "#") then + ver = calculate_hash(disco.tags); -- calculate hash + end + local notify = set_new(); + for _, feature in pairs(disco.tags) do + if feature.name == "feature" and feature.attr.var then + local nfeature = feature.attr.var:match("^(.*)%+notify$"); + if nfeature then notify:add(nfeature); end + end + end + hash_map[ver] = notify; -- update hash map + if self then + for jid, item in pairs(origin.roster) do -- for all interested contacts + if item.subscription == "both" or item.subscription == "from" then + if not recipients[jid] then recipients[jid] = {}; end + update_subscriptions(contact, jid, notify); + end + end + end + update_subscriptions(contact, user, notify); +end); + +module:hook("account-disco-info-node", function(event) + local reply, stanza, origin = event.reply, event.stanza, event.origin; + local service_name = stanza.attr.to or origin.username.."@"..origin.host + local service = get_pep_service(service_name); + local node = event.node; + local ok = service:get_items(node, jid_bare(stanza.attr.from) or true); + if not ok then return; end + event.exists = true; + reply:tag('identity', {category='pubsub', type='leaf'}):up(); +end); + +module:hook("account-disco-info", function(event) + local reply = event.reply; + reply:tag('identity', {category='pubsub', type='pep'}):up(); + reply:tag('feature', {var='http://jabber.org/protocol/pubsub#publish'}):up(); +end); + +module:hook("account-disco-items-node", function(event) + local reply, stanza, origin = event.reply, event.stanza, event.origin; + local node = event.node; + local service_name = stanza.attr.to or origin.username.."@"..origin.host + local service = get_pep_service(service_name); + local ok, ret = service:get_items(node, jid_bare(stanza.attr.from) or true); + if not ok then return; end + event.exists = true; + for _, id in ipairs(ret) do + reply:tag("item", { jid = service_name, name = id }):up(); + end +end); + +module:hook("account-disco-items", function(event) + local reply, stanza, origin = event.reply, event.stanza, event.origin; + + local service_name = reply.attr.from or origin.username.."@"..origin.host + local service = get_pep_service(service_name); + local ok, ret = service:get_nodes(jid_bare(stanza.attr.from)); + if not ok then return; end + + for node, node_obj in pairs(ret) do + reply:tag("item", { jid = service_name, node = node, name = node_obj.config.name }):up(); + end +end); diff --git a/plugins/mod_proxy65.lua b/plugins/mod_proxy65.lua index 2ed9faac..73527cbc 100644 --- a/plugins/mod_proxy65.lua +++ b/plugins/mod_proxy65.lua @@ -101,27 +101,10 @@ function module.add_host(module) module:log("warn", "proxy65_port is deprecated, please put proxy65_ports = { %d } into the global section instead", legacy_config); end + module:depends("disco"); module:add_identity("proxy", "bytestreams", name); module:add_feature("http://jabber.org/protocol/bytestreams"); - module:hook("iq-get/host/http://jabber.org/protocol/disco#info:query", function(event) - local origin, stanza = event.origin, event.stanza; - if not stanza.tags[1].attr.node then - origin.send(st.reply(stanza):query("http://jabber.org/protocol/disco#info") - :tag("identity", {category='proxy', type='bytestreams', name=name}):up() - :tag("feature", {var="http://jabber.org/protocol/bytestreams"}) ); - return true; - end - end, -1); - - module:hook("iq-get/host/http://jabber.org/protocol/disco#items:query", function(event) - local origin, stanza = event.origin, event.stanza; - if not stanza.tags[1].attr.node then - origin.send(st.reply(stanza):query("http://jabber.org/protocol/disco#items")); - return true; - end - end, -1); - module:hook("iq-get/host/http://jabber.org/protocol/bytestreams:query", function(event) local origin, stanza = event.origin, event.stanza; diff --git a/plugins/mod_pubsub/mod_pubsub.lua b/plugins/mod_pubsub/mod_pubsub.lua index c6dbe831..33e729af 100644 --- a/plugins/mod_pubsub/mod_pubsub.lua +++ b/plugins/mod_pubsub/mod_pubsub.lua @@ -100,7 +100,7 @@ module:hook("host-disco-items-node", function (event) return; end - for id, item in pairs(ret) do + for _, id in ipairs(ret) do reply:tag("item", { jid = module.host, name = id }):up(); end event.exists = true; diff --git a/plugins/mod_pubsub/pubsub.lib.lua b/plugins/mod_pubsub/pubsub.lib.lua index 2b015e34..4e9acd68 100644 --- a/plugins/mod_pubsub/pubsub.lib.lua +++ b/plugins/mod_pubsub/pubsub.lib.lua @@ -42,8 +42,8 @@ function handlers.get_items(origin, stanza, items, service) end local data = st.stanza("items", { node = node }); - for _, entry in pairs(results) do - data:add_child(entry); + for _, id in ipairs(results) do + data:add_child(results[id]); end local reply; if data then diff --git a/plugins/mod_saslauth.lua b/plugins/mod_saslauth.lua index 94c060b3..df60aefa 100644 --- a/plugins/mod_saslauth.lua +++ b/plugins/mod_saslauth.lua @@ -197,7 +197,7 @@ module:hook("stanza/urn:ietf:params:xml:ns:xmpp-sasl:auth", function(event) return s2s_external_auth(session, stanza) end - if session.type ~= "c2s_unauthed" then return; end + if session.type ~= "c2s_unauthed" or module:get_host_type() ~= "local" then return; end if session.sasl_handler and session.sasl_handler.selected then session.sasl_handler = nil; -- allow starting a new SASL negotiation before completing an old one diff --git a/plugins/muc/muc.lib.lua b/plugins/muc/muc.lib.lua index 27c50cd4..5debb4a3 100644 --- a/plugins/muc/muc.lib.lua +++ b/plugins/muc/muc.lib.lua @@ -1265,7 +1265,7 @@ function room_mt:can_set_role(actor_jid, occupant_jid, role) if actor_jid == true then return true; end local actor = self._occupants[self:get_occupant_jid(actor_jid)]; - if actor.role == "moderator" then + if actor and actor.role == "moderator" then if occupant.affiliation ~= "owner" and occupant.affiliation ~= "admin" then if actor.affiliation == "owner" or actor.affiliation == "admin" then return true; diff --git a/prosody.cfg.lua.dist b/prosody.cfg.lua.dist index 1d11a658..ade219a8 100644 --- a/prosody.cfg.lua.dist +++ b/prosody.cfg.lua.dist @@ -63,7 +63,6 @@ modules_enabled = { --"http_files"; -- Serve static files from a directory over HTTP -- Other specific functionality - --"posix"; -- POSIX functionality, sends server to background, enables syslog, etc. --"groups"; -- Shared roster support --"announce"; -- Send announcement to all online users --"welcome"; -- Welcome users who register accounts @@ -78,6 +77,7 @@ modules_disabled = { -- "offline"; -- Store offline messages -- "c2s"; -- Handle client connections -- "s2s"; -- Handle server-to-server connections + -- "posix"; -- POSIX functionality, sends server to background, enables syslog, etc. } -- Disable account creation by default, for security @@ -687,7 +687,12 @@ function cert_commands.config(arg) conf.distinguished_name[k] = nv ~= "." and nv or nil; end end - local conf_file = io.open(conf_filename, "w"); + local conf_file, err = io.open(conf_filename, "w"); + if not conf_file then + show_warning("Could not open OpenSSL config file for writing"); + show_warning(err); + os.exit(1); + end conf_file:write(conf:serialize()); conf_file:close(); print(""); diff --git a/tools/ejabberd2prosody.lua b/tools/ejabberd2prosody.lua index e9dbd2dc..66bf4f93 100755 --- a/tools/ejabberd2prosody.lua +++ b/tools/ejabberd2prosody.lua @@ -44,8 +44,10 @@ function build_stanza(tuple, stanza) for _, a in ipairs(tuple[4]) do build_stanza(a, stanza); end if up then stanza:up(); else return stanza end elseif tuple[1] == "xmlcdata" then - assert(type(tuple[2]) == "string", "XML CDATA has unexpected type: "..type(tuple[2])); - stanza:text(tuple[2]); + if type(tuple[2]) ~= "table" then + assert(type(tuple[2]) == "string", "XML CDATA has unexpected type: "..type(tuple[2])); + stanza:text(tuple[2]); + end -- else it's [], i.e., the null value, used for the empty string else error("unknown element type: "..serialize(tuple)); end diff --git a/util/dependencies.lua b/util/dependencies.lua index 109a3332..9d80d241 100644 --- a/util/dependencies.lua +++ b/util/dependencies.lua @@ -140,7 +140,15 @@ function log_warnings() if not pcall(lxp.new, { StartDoctypeDecl = false }) then log("error", "The version of LuaExpat on your system leaves Prosody " .."vulnerable to denial-of-service attacks. You should upgrade to " - .."LuaExpat 1.1.1 or higher as soon as possible. See " + .."LuaExpat 1.3.0 or higher as soon as possible. See " + .."http://prosody.im/doc/depends#luaexpat for more information."); + end + if not lxp.new({}).getcurrentbytecount then + log("error", "The version of LuaExpat on your system does not support " + .."stanza size limits, which may leave servers on untrusted " + .."networks (e.g. the internet) vulnerable to denial-of-service " + .."attacks. You should upgrade to LuaExpat 1.3.0 or higher as " + .."soon as possible. See " .."http://prosody.im/doc/depends#luaexpat for more information."); end end diff --git a/util/indexedbheap.lua b/util/indexedbheap.lua new file mode 100644 index 00000000..3cb03037 --- /dev/null +++ b/util/indexedbheap.lua @@ -0,0 +1,153 @@ + +local setmetatable = setmetatable; +local math_floor = math.floor; +local t_remove = table.remove; + +local function _heap_insert(self, item, sync, item2, index) + local pos = #self + 1; + while true do + local half_pos = math_floor(pos / 2); + if half_pos == 0 or item > self[half_pos] then break; end + self[pos] = self[half_pos]; + sync[pos] = sync[half_pos]; + index[sync[pos]] = pos; + pos = half_pos; + end + self[pos] = item; + sync[pos] = item2; + index[item2] = pos; +end + +local function _percolate_up(self, k, sync, index) + local tmp = self[k]; + local tmp_sync = sync[k]; + while k ~= 1 do + local parent = math_floor(k/2); + if tmp < self[parent] then break; end + self[k] = self[parent]; + sync[k] = sync[parent]; + index[sync[k]] = k; + k = parent; + end + self[k] = tmp; + sync[k] = tmp_sync; + index[tmp_sync] = k; + return k; +end + +local function _percolate_down(self, k, sync, index) + local tmp = self[k]; + local tmp_sync = sync[k]; + local size = #self; + local child = 2*k; + while 2*k <= size do + if child ~= size and self[child] > self[child + 1] then + child = child + 1; + end + if tmp > self[child] then + self[k] = self[child]; + sync[k] = sync[child]; + index[sync[k]] = k; + else + break; + end + + k = child; + child = 2*k; + end + self[k] = tmp; + sync[k] = tmp_sync; + index[tmp_sync] = k; + return k; +end + +local function _heap_pop(self, sync, index) + local size = #self; + if size == 0 then return nil; end + + local result = self[1]; + local result_sync = sync[1]; + index[result_sync] = nil; + if size == 1 then + self[1] = nil; + sync[1] = nil; + return result, result_sync; + end + self[1] = t_remove(self); + sync[1] = t_remove(sync); + index[sync[1]] = 1; + + _percolate_down(self, 1, sync, index); + + return result, result_sync; +end + +local indexed_heap = {}; + +function indexed_heap:insert(item, priority, id) + if id == nil then + id = self.current_id; + self.current_id = id + 1; + end + self.items[id] = item; + _heap_insert(self.priorities, priority, self.ids, id, self.index); + return id; +end +function indexed_heap:pop() + local priority, id = _heap_pop(self.priorities, self.ids, self.index); + if id then + local item = self.items[id]; + self.items[id] = nil; + return priority, item, id; + end +end +function indexed_heap:peek() + return self.priorities[1]; +end +function indexed_heap:reprioritize(id, priority) + local k = self.index[id]; + if k == nil then return; end + self.priorities[k] = priority; + + k = _percolate_up(self.priorities, k, self.ids, self.index); + k = _percolate_down(self.priorities, k, self.ids, self.index); +end +function indexed_heap:remove_index(k) + local size = #self.priorities; + + local result = self.priorities[k]; + local result_sync = self.ids[k]; + local item = self.items[result_sync]; + if result == nil then return; end + self.index[result_sync] = nil; + self.items[result_sync] = nil; + + self.priorities[k] = self.priorities[size]; + self.ids[k] = self.ids[size]; + self.index[self.ids[k]] = k; + t_remove(self.priorities); + t_remove(self.ids); + + k = _percolate_up(self.priorities, k, self.ids, self.index); + k = _percolate_down(self.priorities, k, self.ids, self.index); + + return result, item, result_sync; +end +function indexed_heap:remove(id) + return self:remove_index(self.index[id]); +end + +local mt = { __index = indexed_heap }; + +local _M = { + create = function() + return setmetatable({ + ids = {}; -- heap of ids, sync'd with priorities + items = {}; -- map id->items + priorities = {}; -- heap of priorities + index = {}; -- map of id->index of id in ids + current_id = 1.5 + }, mt); + end +}; +return _M; diff --git a/util/pluginloader.lua b/util/pluginloader.lua index b894f527..b9b3e207 100644 --- a/util/pluginloader.lua +++ b/util/pluginloader.lua @@ -39,10 +39,10 @@ function load_resource(plugin, resource) resource = resource or "mod_"..plugin..".lua"; local names = { - "mod_"..plugin.."/"..plugin.."/"..resource; -- mod_hello/hello/mod_hello.lua - "mod_"..plugin.."/"..resource; -- mod_hello/mod_hello.lua - plugin.."/"..resource; -- hello/mod_hello.lua - resource; -- mod_hello.lua + "mod_"..plugin..dir_sep..plugin..dir_sep..resource; -- mod_hello/hello/mod_hello.lua + "mod_"..plugin..dir_sep..resource; -- mod_hello/mod_hello.lua + plugin..dir_sep..resource; -- hello/mod_hello.lua + resource; -- mod_hello.lua }; return load_file(names); diff --git a/util/sasl.lua b/util/sasl.lua index c8490842..b91e29a6 100644 --- a/util/sasl.lua +++ b/util/sasl.lua @@ -100,14 +100,16 @@ end function method:mechanisms() local current_mechs = {}; for mech, _ in pairs(self.mechs) do - if mechanism_channelbindings[mech] and self.profile.cb then - local ok = false; - for cb_name, _ in pairs(self.profile.cb) do - if mechanism_channelbindings[mech][cb_name] then - ok = true; + if mechanism_channelbindings[mech] then + if self.profile.cb then + local ok = false; + for cb_name, _ in pairs(self.profile.cb) do + if mechanism_channelbindings[mech][cb_name] then + ok = true; + end end + if ok == true then current_mechs[mech] = true; end end - if ok == true then current_mechs[mech] = true; end else current_mechs[mech] = true; end diff --git a/util/timer.lua b/util/timer.lua index 0e10e144..23bd6a37 100644 --- a/util/timer.lua +++ b/util/timer.lua @@ -6,6 +6,8 @@ -- COPYING file in the source package for more information. -- +local indexedbheap = require "util.indexedbheap"; +local log = require "util.logger".init("timer"); local server = require "net.server"; local math_min = math.min local math_huge = math.huge @@ -13,6 +15,9 @@ local get_time = require "socket".gettime; local t_insert = table.insert; local pairs = pairs; local type = type; +local debug_traceback = debug.traceback; +local tostring = tostring; +local xpcall = xpcall; local data = {}; local new_data = {}; @@ -78,6 +83,61 @@ else end end -add_task = _add_task; +--add_task = _add_task; + +local h = indexedbheap.create(); +local params = {}; +local next_time = nil; +local _id, _callback, _now, _param; +local function _call() return _callback(_now, _id, _param); end +local function _traceback_handler(err) log("error", "Traceback[timer]: %s", debug_traceback(tostring(err), 2)); end +local function _on_timer(now) + local peek; + while true do + peek = h:peek(); + if peek == nil or peek > now then break; end + local _; + _, _callback, _id = h:pop(); + _now = now; + _param = params[_id]; + params[_id] = nil; + --item(now, id, _param); -- FIXME pcall + local success, err = xpcall(_call, _traceback_handler); + if success and type(err) == "number" then + h:insert(_callback, err + now, _id); -- re-add + params[_id] = _param; + end + end + next_time = peek; + if peek ~= nil then + return peek - now; + end +end +function add_task(delay, callback, param) + local current_time = get_time(); + local event_time = current_time + delay; + + local id = h:insert(callback, event_time); + params[id] = param; + if next_time == nil or event_time < next_time then + next_time = event_time; + _add_task(next_time - current_time, _on_timer); + end + return id; +end +function stop(id) + params[id] = nil; + return h:remove(id); +end +function reschedule(id, delay) + local current_time = get_time(); + local event_time = current_time + delay; + h:reprioritize(id, delay); + if next_time == nil or event_time < next_time then + next_time = event_time; + _add_task(next_time - current_time, _on_timer); + end + return id; +end return _M; diff --git a/util/xmppstream.lua b/util/xmppstream.lua index 550170c9..586ad5f9 100644 --- a/util/xmppstream.lua +++ b/util/xmppstream.lua @@ -6,7 +6,6 @@ -- COPYING file in the source package for more information. -- - local lxp = require "lxp"; local st = require "util.stanza"; local stanza_mt = st.stanza_mt; @@ -20,6 +19,10 @@ local setmetatable = setmetatable; -- COMPAT: w/LuaExpat 1.1.0 local lxp_supports_doctype = pcall(lxp.new, { StartDoctypeDecl = false }); +local lxp_supports_xmldecl = pcall(lxp.new, { XmlDecl = false }); +local lxp_supports_bytecount = not not lxp.new({}).getcurrentbytecount; + +local default_stanza_size_limit = 1024*1024*10; -- 10MB module "xmppstream" @@ -40,13 +43,16 @@ local ns_pattern = "^([^"..ns_separator.."]*)"..ns_separator.."?(.*)$"; _M.ns_separator = ns_separator; _M.ns_pattern = ns_pattern; -function new_sax_handlers(session, stream_callbacks) +local function dummy_cb() end + +function new_sax_handlers(session, stream_callbacks, cb_handleprogress) local xml_handlers = {}; local cb_streamopened = stream_callbacks.streamopened; local cb_streamclosed = stream_callbacks.streamclosed; local cb_error = stream_callbacks.error or function(session, e, stanza) error("XML stream error: "..tostring(e)..(stanza and ": "..tostring(stanza) or ""),2); end; local cb_handlestanza = stream_callbacks.handlestanza; + cb_handleprogress = cb_handleprogress or dummy_cb; local stream_ns = stream_callbacks.stream_ns or xmlns_streams; local stream_tag = stream_callbacks.stream_tag or "stream"; @@ -59,6 +65,7 @@ function new_sax_handlers(session, stream_callbacks) local stack = {}; local chardata, stanza = {}; + local stanza_size = 0; local non_streamns_depth = 0; function xml_handlers:StartElement(tagname, attr) if stanza and #chardata > 0 then @@ -87,10 +94,17 @@ function new_sax_handlers(session, stream_callbacks) end if not stanza then --if we are not currently inside a stanza + if lxp_supports_bytecount then + stanza_size = self:getcurrentbytecount(); + end if session.notopen then if tagname == stream_tag then non_streamns_depth = 0; if cb_streamopened then + if lxp_supports_bytecount then + cb_handleprogress(stanza_size); + stanza_size = 0; + end cb_streamopened(session, attr); end else @@ -105,6 +119,9 @@ function new_sax_handlers(session, stream_callbacks) stanza = setmetatable({ name = name, attr = attr, tags = {} }, stanza_mt); else -- we are inside a stanza, so add a tag + if lxp_supports_bytecount then + stanza_size = stanza_size + self:getcurrentbytecount(); + end t_insert(stack, stanza); local oldstanza = stanza; stanza = setmetatable({ name = name, attr = attr, tags = {} }, stanza_mt); @@ -112,12 +129,45 @@ function new_sax_handlers(session, stream_callbacks) t_insert(oldstanza.tags, stanza); end end + if lxp_supports_xmldecl then + function xml_handlers:XmlDecl(version, encoding, standalone) + if lxp_supports_bytecount then + cb_handleprogress(self:getcurrentbytecount()); + end + end + end + function xml_handlers:StartCdataSection() + if lxp_supports_bytecount then + if stanza then + stanza_size = stanza_size + self:getcurrentbytecount(); + else + cb_handleprogress(self:getcurrentbytecount()); + end + end + end + function xml_handlers:EndCdataSection() + if lxp_supports_bytecount then + if stanza then + stanza_size = stanza_size + self:getcurrentbytecount(); + else + cb_handleprogress(self:getcurrentbytecount()); + end + end + end function xml_handlers:CharacterData(data) if stanza then + if lxp_supports_bytecount then + stanza_size = stanza_size + self:getcurrentbytecount(); + end t_insert(chardata, data); + elseif lxp_supports_bytecount then + cb_handleprogress(self:getcurrentbytecount()); end end function xml_handlers:EndElement(tagname) + if lxp_supports_bytecount then + stanza_size = stanza_size + self:getcurrentbytecount() + end if non_streamns_depth > 0 then non_streamns_depth = non_streamns_depth - 1; end @@ -129,6 +179,10 @@ function new_sax_handlers(session, stream_callbacks) end -- Complete stanza if #stack == 0 then + if lxp_supports_bytecount then + cb_handleprogress(stanza_size); + end + stanza_size = 0; if tagname ~= stream_error_tag then cb_handlestanza(session, stanza); else @@ -159,7 +213,7 @@ function new_sax_handlers(session, stream_callbacks) xml_handlers.ProcessingInstruction = restricted_handler; local function reset() - stanza, chardata = nil, {}; + stanza, chardata, stanza_size = nil, {}, 0; stack = {}; end @@ -170,19 +224,39 @@ function new_sax_handlers(session, stream_callbacks) return xml_handlers, { reset = reset, set_session = set_session }; end -function new(session, stream_callbacks) - local handlers, meta = new_sax_handlers(session, stream_callbacks); - local parser = new_parser(handlers, ns_separator); +function new(session, stream_callbacks, stanza_size_limit) + -- Used to track parser progress (e.g. to enforce size limits) + local n_outstanding_bytes = 0; + local handle_progress; + if lxp_supports_bytecount then + function handle_progress(n_parsed_bytes) + n_outstanding_bytes = n_outstanding_bytes - n_parsed_bytes; + end + stanza_size_limit = stanza_size_limit or default_stanza_size_limit; + elseif stanza_size_limit then + error("Stanza size limits are not supported on this version of LuaExpat") + end + + local handlers, meta = new_sax_handlers(session, stream_callbacks, handle_progress); + local parser = new_parser(handlers, ns_separator, false); local parse = parser.parse; return { reset = function () - parser = new_parser(handlers, ns_separator); + parser = new_parser(handlers, ns_separator, false); parse = parser.parse; + n_outstanding_bytes = 0; meta.reset(); end, feed = function (self, data) - return parse(parser, data); + if lxp_supports_bytecount then + n_outstanding_bytes = n_outstanding_bytes + #data; + end + local ok, err = parse(parser, data); + if lxp_supports_bytecount and n_outstanding_bytes > stanza_size_limit then + return nil, "stanza-too-large"; + end + return ok, err; end, set_session = meta.set_session; }; |