diff options
37 files changed, 2287 insertions, 262 deletions
diff --git a/core/configmanager.lua b/core/configmanager.lua new file mode 100644 index 00000000..5f5648b9 --- /dev/null +++ b/core/configmanager.lua @@ -0,0 +1,121 @@ + +local _G = _G; +local setmetatable, loadfile, pcall, rawget, rawset, io = + setmetatable, loadfile, pcall, rawget, rawset, io; +module "configmanager" + +local parsers = {}; + +local config = { ["*"] = { core = {} } }; + +local global_config = config["*"]; + +-- When host not found, use global +setmetatable(config, { __index = function () return global_config; end}); +local host_mt = { __index = global_config }; + +-- When key not found in section, check key in global's section +function section_mt(section_name) + return { __index = function (t, k) + local section = rawget(global_config, section_name); + if not section then return nil; end + return section[k]; + end }; +end + +function getconfig() + return config; +end + +function get(host, section, key) + local sec = config[host][section]; + if sec then + return sec[key]; + end + return nil; +end + +function set(host, section, key, value) + if host and section and key then + local hostconfig = rawget(config, host); + if not hostconfig then + hostconfig = rawset(config, host, setmetatable({}, host_mt))[host]; + end + if not rawget(hostconfig, section) then + hostconfig[section] = setmetatable({}, section_mt(section)); + end + hostconfig[section][key] = value; + return true; + end + return false; +end + +function load(filename, format) + format = format or filename:match("%w+$"); + if parsers[format] and parsers[format].load then + local f = io.open(filename); + if f then + local ok, err = parsers[format].load(f:read("*a")); + f:close(); + return ok, err; + end + end + if not format then + return nil, "no parser specified"; + else + return false, "no parser"; + end +end + +function save(filename, format) +end + +function addparser(format, parser) + if format and parser then + parsers[format] = parser; + end +end + +-- Built-in Lua parser +do + local loadstring, pcall, setmetatable = _G.loadstring, _G.pcall, _G.setmetatable; + local setfenv, rawget, tostring = _G.setfenv, _G.rawget, _G.tostring; + parsers.lua = {}; + function parsers.lua.load(data) + local env; + env = setmetatable({ Host = true; host = true; }, { __index = function (t, k) + return rawget(_G, k) or + function (settings_table) + config[__currenthost or "*"][k] = settings_table; + end; + end, + __newindex = function (t, k, v) + set(env.__currenthost or "*", "core", k, v); + end}); + + function env.Host(name) + rawset(env, "__currenthost", name); + set(name or "*", "core", "defined", true); + end + env.host = env.Host; + + local chunk, err = loadstring(data); + + if not chunk then + return nil, err; + end + + setfenv(chunk, env); + + local ok, err = pcall(chunk); + + if not ok then + return nil, err; + end + + return true; + end + +end + +return _M;
\ No newline at end of file diff --git a/core/discomanager.lua b/core/discomanager.lua new file mode 100644 index 00000000..5f7b3c78 --- /dev/null +++ b/core/discomanager.lua @@ -0,0 +1,39 @@ +
+local helper = require "util.discohelper".new();
+local hosts = hosts;
+local jid_split = require "util.jid".split;
+local jid_bare = require "util.jid".bare;
+local usermanager_user_exists = require "core.usermanager".user_exists;
+local rostermanager_is_contact_subscribed = require "core.rostermanager".is_contact_subscribed;
+
+do
+ helper:addDiscoInfoHandler("*host", function(reply, to, from, node)
+ if hosts[to] then
+ reply:tag("identity", {category="server", type="im", name="lxmppd"}):up();
+ return true;
+ end
+ end);
+ helper:addDiscoInfoHandler("*node", function(reply, to, from, node)
+ local node, host = jid_split(to);
+ if hosts[host] and rostermanager_is_contact_subscribed(node, host, jid_bare(from)) then
+ reply:tag("identity", {category="account", type="registered"}):up();
+ return true;
+ end
+ end);
+end
+
+module "discomanager"
+
+function handle(stanza)
+ return helper:handle(stanza);
+end
+
+function addDiscoItemsHandler(jid, func)
+ return helper:addDiscoItemsHandler(jid, func);
+end
+
+function addDiscoInfoHandler(jid, func)
+ return helper:addDiscoInfoHandler(jid, func);
+end
+
+return _M;
diff --git a/core/modulemanager.lua b/core/modulemanager.lua index 783fea55..d313130c 100644 --- a/core/modulemanager.lua +++ b/core/modulemanager.lua @@ -78,7 +78,7 @@ function load(name) local success, ret = pcall(mod); if not success then log("error", "Error initialising module '%s': %s", name or "nil", ret or "nil"); - return nil, err; + return nil, ret; end return true; end @@ -92,8 +92,8 @@ function handle_stanza(origin, stanza) if child then local xmlns = child.attr.xmlns or xmlns; log("debug", "Stanza of type %s from %s has xmlns: %s", name, origin_type, xmlns); - local handler = handlers[origin_type][name][xmlns]; - if handler then + local handler = handlers[origin_type][name] and handlers[origin_type][name][xmlns]; + if handler then log("debug", "Passing stanza to mod_%s", handler_info[handler].name); return handler(origin, stanza) or true; end diff --git a/core/presencemanager.lua b/core/presencemanager.lua new file mode 100644 index 00000000..c6619fea --- /dev/null +++ b/core/presencemanager.lua @@ -0,0 +1,121 @@ +
+local log = require "util.logger".init("presencemanager")
+
+local require = require;
+local pairs = pairs;
+
+local st = require "util.stanza";
+local jid_split = require "util.jid".split;
+local hosts = hosts;
+
+local rostermanager = require "core.rostermanager";
+local sessionmanager = require "core.sessionmanager";
+
+module "presencemanager"
+
+function send_presence_of_available_resources(user, host, jid, recipient_session, core_route_stanza)
+ local h = hosts[host];
+ local count = 0;
+ if h and h.type == "local" then
+ local u = h.sessions[user];
+ if u then
+ for k, session in pairs(u.sessions) do
+ local pres = session.presence;
+ if pres then
+ pres.attr.to = jid;
+ pres.attr.from = session.full_jid;
+ core_route_stanza(session, pres);
+ pres.attr.to = nil;
+ pres.attr.from = nil;
+ count = count + 1;
+ end
+ end
+ end
+ end
+ log("info", "broadcasted presence of "..count.." resources from "..user.."@"..host.." to "..jid);
+ return count;
+end
+
+function handle_outbound_presence_subscriptions_and_probes(origin, stanza, from_bare, to_bare, core_route_stanza)
+ local node, host = jid_split(from_bare);
+ local st_from, st_to = stanza.attr.from, stanza.attr.to;
+ stanza.attr.from, stanza.attr.to = from_bare, to_bare;
+ log("debug", "outbound presence "..stanza.attr.type.." from "..from_bare.." for "..to_bare);
+ if stanza.attr.type == "subscribe" then
+ -- 1. route stanza
+ -- 2. roster push (subscription = none, ask = subscribe)
+ if rostermanager.set_contact_pending_out(node, host, to_bare) then
+ rostermanager.roster_push(node, host, to_bare);
+ end -- else file error
+ core_route_stanza(origin, stanza);
+ elseif stanza.attr.type == "unsubscribe" then
+ -- 1. route stanza
+ -- 2. roster push (subscription = none or from)
+ if rostermanager.unsubscribe(node, host, to_bare) then
+ rostermanager.roster_push(node, host, to_bare); -- FIXME do roster push when roster has in fact not changed?
+ end -- else file error
+ core_route_stanza(origin, stanza);
+ elseif stanza.attr.type == "subscribed" then
+ -- 1. route stanza
+ -- 2. roster_push ()
+ -- 3. send_presence_of_available_resources
+ if rostermanager.subscribed(node, host, to_bare) then
+ rostermanager.roster_push(node, host, to_bare);
+ end
+ core_route_stanza(origin, stanza);
+ send_presence_of_available_resources(node, host, to_bare, origin, core_route_stanza);
+ elseif stanza.attr.type == "unsubscribed" then
+ -- 1. route stanza
+ -- 2. roster push (subscription = none or to)
+ if rostermanager.unsubscribed(node, host, to_bare) then
+ rostermanager.roster_push(node, host, to_bare);
+ end
+ core_route_stanza(origin, stanza);
+ end
+ stanza.attr.from, stanza.attr.to = st_from, st_to;
+end
+
+function handle_inbound_presence_subscriptions_and_probes(origin, stanza, from_bare, to_bare, core_route_stanza)
+ local node, host = jid_split(to_bare);
+ local st_from, st_to = stanza.attr.from, stanza.attr.to;
+ stanza.attr.from, stanza.attr.to = from_bare, to_bare;
+ log("debug", "inbound presence "..stanza.attr.type.." from "..from_bare.." for "..to_bare);
+ if stanza.attr.type == "probe" then
+ if rostermanager.is_contact_subscribed(node, host, from_bare) then
+ if 0 == send_presence_of_available_resources(node, host, from_bare, origin, core_route_stanza) then
+ -- TODO send last recieved unavailable presence (or we MAY do nothing, which is fine too)
+ end
+ else
+ core_route_stanza(origin, st.presence({from=to_bare, to=from_bare, type="unsubscribed"}));
+ end
+ elseif stanza.attr.type == "subscribe" then
+ if rostermanager.is_contact_subscribed(node, host, from_bare) then
+ core_route_stanza(origin, st.presence({from=to_bare, to=from_bare, type="subscribed"})); -- already subscribed
+ -- Sending presence is not clearly stated in the RFC, but it seems appropriate
+ if 0 == send_presence_of_available_resources(node, host, from_bare, origin, core_route_stanza) then
+ -- TODO send last recieved unavailable presence (or we MAY do nothing, which is fine too)
+ end
+ else
+ if not rostermanager.is_contact_pending_in(node, host, from_bare) then
+ if rostermanager.set_contact_pending_in(node, host, from_bare) then
+ sessionmanager.send_to_available_resources(node, host, stanza);
+ end -- TODO else return error, unable to save
+ end
+ end
+ elseif stanza.attr.type == "unsubscribe" then
+ if rostermanager.process_inbound_unsubscribe(node, host, from_bare) then
+ rostermanager.roster_push(node, host, from_bare);
+ end
+ elseif stanza.attr.type == "subscribed" then
+ if rostermanager.process_inbound_subscription_approval(node, host, from_bare) then
+ rostermanager.roster_push(node, host, from_bare);
+ end
+ elseif stanza.attr.type == "unsubscribed" then
+ if rostermanager.process_inbound_subscription_cancellation(node, host, from_bare) then
+ rostermanager.roster_push(node, host, from_bare);
+ end
+ end -- discard any other type
+ stanza.attr.from, stanza.attr.to = st_from, st_to;
+end
+
+return _M;
diff --git a/core/s2smanager.lua b/core/s2smanager.lua index c3d9bdb4..6d8f3a00 100644 --- a/core/s2smanager.lua +++ b/core/s2smanager.lua @@ -3,7 +3,7 @@ local hosts = hosts; local sessions = sessions; local socket = require "socket"; local format = string.format; -local t_insert = table.insert; +local t_insert, t_sort = table.insert, table.sort; local get_traceback = debug.traceback; local tostring, pairs, ipairs, getmetatable, print, newproxy, error, tonumber = tostring, pairs, ipairs, getmetatable, print, newproxy, error, tonumber; @@ -24,16 +24,19 @@ local md5_hash = require "util.hashes".md5; local dialback_secret = "This is very secret!!! Ha!"; -local srvmap = { ["gmail.com"] = "talk.google.com", ["identi.ca"] = "hampton.controlezvous.ca", ["cdr.se"] = "jabber.cdr.se" }; +local dns = require "net.dns"; module "s2smanager" +local function compare_srv_priorities(a,b) return a.priority < b.priority or a.weight < b.weight; end + function send_to_host(from_host, to_host, data) + if data.name then data = tostring(data); end local host = hosts[from_host].s2sout[to_host]; if host then -- We have a connection to this host already - if host.type == "s2sout_unauthed" then - host.log("debug", "trying to send over unauthed s2sout to "..to_host..", authing it now..."); + if host.type == "s2sout_unauthed" and ((not data.xmlns) or data.xmlns == "jabber:client" or data.xmlns == "jabber:server") then + (host.log or log)("debug", "trying to send over unauthed s2sout to "..to_host..", authing it now..."); if not host.notopen and not host.dialback_key then host.log("debug", "dialback had not been initiated"); initiate_dialback(host); @@ -51,7 +54,7 @@ function send_to_host(from_host, to_host, data) -- FIXME if host.from_host ~= from_host then log("error", "WARNING! This might, possibly, be a bug, but it might not..."); - log("error", "We are going to send from %s instead of %s", host.from_host, from_host); + log("error", "We are going to send from %s instead of %s", tostring(host.from_host), tostring(from_host)); end host.sends2s(data); host.log("debug", "stanza sent over "..host.type); @@ -73,8 +76,8 @@ function new_incoming(conn) getmetatable(session.trace).__gc = function () open_sessions = open_sessions - 1; print("s2s session got collected, now "..open_sessions.." s2s sessions are allocated") end; end open_sessions = open_sessions + 1; - local w = conn.write; - session.sends2s = function (t) w(tostring(t)); end + local w, log = conn.write, logger_init("s2sin"..tostring(conn):match("[a-f0-9]+$")); + session.sends2s = function (t) log("debug", "sending: %s", tostring(t)); w(tostring(t)); end return session; end @@ -84,30 +87,49 @@ function new_outgoing(from_host, to_host) local cl = connlisteners_get("xmppserver"); local conn, handler = socket.tcp() + + local connect_host, connect_port = to_host, 5269; - --FIXME: Below parameters (ports/ip) are incorrect (use SRV) - to_host = srvmap[to_host] or to_host; + local answer = dns.lookup("_xmpp-server._tcp."..to_host..".", "SRV"); + + if answer then + log("debug", to_host.." has SRV records, handling..."); + local srv_hosts = {}; + host_session.srv_hosts = srv_hosts; + for _, record in ipairs(answer) do + t_insert(srv_hosts, record.srv); + end + t_sort(srv_hosts, compare_srv_priorities); + + local srv_choice = srv_hosts[1]; + if srv_choice then + connect_host, connect_port = srv_choice.target or to_host, srv_choice.port or connect_port; + log("debug", "Best record found, will connect to %s:%d", connect_host, connect_port); + end + end conn:settimeout(0); - local success, err = conn:connect(to_host, 5269); - if not success then + local success, err = conn:connect(connect_host, connect_port); + if not success and err ~= "timeout" then log("warn", "s2s connect() failed: %s", err); end - conn = wraptlsclient(cl, conn, to_host, 5269, 0, 1, hosts[from_host].ssl_ctx ); + conn = wraptlsclient(cl, conn, connect_host, connect_port, 0, 1, hosts[from_host].ssl_ctx ); host_session.conn = conn; -- Register this outgoing connection so that xmppserver_listener knows about it -- otherwise it will assume it is a new incoming connection cl.register_outgoing(conn, host_session); + local log; do local conn_name = "s2sout"..tostring(conn):match("[a-f0-9]*$"); - host_session.log = logger_init(conn_name); + log = logger_init(conn_name); + host_session.log = log; end local w = conn.write; - host_session.sends2s = function (t) w(tostring(t)); end + host_session.sends2s = function (t) log("debug", "sending: %s", tostring(t)); w(tostring(t)); end conn.write(format([[<stream:stream xmlns='jabber:server' xmlns:db='jabber:server:dialback' xmlns:stream='http://etherx.jabber.org/streams' from='%s' to='%s' version='1.0'>]], from_host, to_host)); @@ -119,23 +141,31 @@ function streamopened(session, attr) session.version = tonumber(attr.version) or 0; if session.version >= 1.0 and not (attr.to and attr.from) then - print("to: "..tostring(attr.to).." from: "..tostring(attr.from)); - --error(session.to_host.." failed to specify 'to' or 'from' hostname as per RFC"); + --print("to: "..tostring(attr.to).." from: "..tostring(attr.from)); log("warn", (session.to_host or "(unknown)").." failed to specify 'to' or 'from' hostname as per RFC"); end if session.direction == "incoming" then -- Send a reply stream header - for k,v in pairs(attr) do print("", tostring(k), ":::", tostring(v)); end + --for k,v in pairs(attr) do print("", tostring(k), ":::", tostring(v)); end session.to_host = attr.to; session.from_host = attr.from; session.streamid = uuid_gen(); - print(session, session.from_host, "incoming s2s stream opened"); + (session.log or log)("debug", "incoming s2s received <stream:stream>"); send("<?xml version='1.0'?>"); send(stanza("stream:stream", { xmlns='jabber:server', ["xmlns:db"]='jabber:server:dialback', ["xmlns:stream"]='http://etherx.jabber.org/streams', id=session.streamid, from=session.to_host }):top_tag()); + if session.to_host and not hosts[session.to_host] then + -- Attempting to connect to a host we don't serve + session:close("host-unknown"); + return; + end + if session.version >= 1.0 then + send(st.stanza("stream:features") + :tag("dialback", { xmlns='urn:xmpp:features:dialback' }):tag("optional"):up():up()); + end elseif session.direction == "outgoing" then -- If we are just using the connection for verifying dialback keys, we won't try and auth it if not attr.id then error("stream response did not give us a streamid!!!"); end @@ -147,17 +177,6 @@ function streamopened(session, attr) mark_connected(session); end end - --[[ - local features = {}; - modulemanager.fire_event("stream-features-s2s", session, features); - - send("<stream:features>"); - - for _, feature in ipairs(features) do - send(tostring(feature)); - end - - send("</stream:features>");]] session.notopen = nil; end @@ -217,11 +236,13 @@ end function destroy_session(session) (session.log or log)("info", "Destroying "..tostring(session.direction).." session "..tostring(session.from_host).."->"..tostring(session.to_host)); + + -- FIXME: Flush sendq here/report errors to originators + if session.direction == "outgoing" then hosts[session.from_host].s2sout[session.to_host] = nil; end - session.conn = nil; - session.disconnect = nil; + for k in pairs(session) do if k ~= "trace" then session[k] = nil; diff --git a/core/servermanager.lua b/core/servermanager.lua index 99eb4c23..8cbf2f12 100644 --- a/core/servermanager.lua +++ b/core/servermanager.lua @@ -2,7 +2,7 @@ local st = require "util.stanza"; local xmlns_stanzas ='urn:ietf:params:xml:ns:xmpp-stanzas'; -require "modulemanager" +local modulemanager = require "core.modulemanager"; -- Handle stanzas that were addressed to the server (whether they came from c2s, s2s, etc.) function handle_stanza(origin, stanza) diff --git a/core/sessionmanager.lua b/core/sessionmanager.lua index 0d65f6d6..e83b7c23 100644 --- a/core/sessionmanager.lua +++ b/core/sessionmanager.lua @@ -14,6 +14,8 @@ local error = error; local uuid_generate = require "util.uuid".generate; local rm_load_roster = require "core.rostermanager".load_roster; +local st = require "util.stanza"; + local newproxy = newproxy; local getmetatable = getmetatable; @@ -28,13 +30,24 @@ function new_session(conn) getmetatable(session.trace).__gc = function () open_sessions = open_sessions - 1; print("Session got collected, now "..open_sessions.." sessions are allocated") end; end open_sessions = open_sessions + 1; + log("info", "open sessions now: ".. open_sessions); local w = conn.write; session.send = function (t) w(tostring(t)); end return session; end -function destroy_session(session) +function destroy_session(session, err) (session.log or log)("info", "Destroying session"); + + -- Send unavailable presence + if session.presence then + local pres = st.presence{ type = "unavailable" }; + if (not err) or err == "closed" then err = "connection closed"; end + pres:tag("status"):text("Disconnected: "..err); + session.stanza_dispatch(pres); + end + + -- Remove session/resource from user's session list if session.host and session.username then if session.resource then hosts[session.host].sessions[session.username].sessions[session.resource] = nil; @@ -46,8 +59,7 @@ function destroy_session(session) end end end - session.conn = nil; - session.disconnect = nil; + for k in pairs(session) do if k ~= "trace" then session[k] = nil; @@ -96,21 +108,25 @@ function streamopened(session, attr) session.host = attr.to or error("Client failed to specify destination hostname"); session.version = tonumber(attr.version) or 0; session.streamid = m_random(1000000, 99999999); - print(session, session.host, "Client opened stream"); - send("<?xml version='1.0'?>"); + (session.log or session)("debug", "Client sent opening <stream:stream> to %s", session.host); + + + send("<?xml version='1.0'?>"); send(format("<stream:stream xmlns='jabber:client' xmlns:stream='http://etherx.jabber.org/streams' id='%s' from='%s' version='1.0'>", session.streamid, session.host)); - local features = {}; + if not hosts[session.host] then + -- We don't serve this host... + session:close{ condition = "host-unknown", text = "This server does not serve "..tostring(session.host)}; + return; + end + + + local features = st.stanza("stream:features"); modulemanager.fire_event("stream-features", session, features); - send("<stream:features>"); + send(features); - for _, feature in ipairs(features) do - send(tostring(feature)); - end - - send("</stream:features>"); - log("info", "Stream opened successfully"); + (session.log or log)("info", "Sent reply <stream:stream> to client"); session.notopen = nil; end diff --git a/core/stanza_router.lua b/core/stanza_router.lua index 2b0e1f4b..2505fca3 100644 --- a/core/stanza_router.lua +++ b/core/stanza_router.lua @@ -21,6 +21,9 @@ local s2s_make_authenticated = require "core.s2smanager".make_authenticated; local modules_handle_stanza = require "core.modulemanager".handle_stanza; local component_handle_stanza = require "core.componentmanager".handle_stanza; +local handle_outbound_presence_subscriptions_and_probes = require "core.presencemanager".handle_outbound_presence_subscriptions_and_probes; +local handle_inbound_presence_subscriptions_and_probes = require "core.presencemanager".handle_inbound_presence_subscriptions_and_probes; + local format = string.format; local tostring = tostring; local t_concat = table.concat; @@ -32,7 +35,7 @@ local jid_split = require "util.jid".split; local print = print; function core_process_stanza(origin, stanza) - log("debug", "Received[%s]: %s", origin.type, stanza:pretty_top_tag()) + (origin.log or log)("debug", "Received[%s]: %s", origin.type, stanza:pretty_print()) --top_tag()) if not stanza.attr.xmlns then stanza.attr.xmlns = "jabber:client"; end -- FIXME Hack. This should be removed when we fix namespace handling. -- TODO verify validity of stanza (as well as JID validity) @@ -87,7 +90,7 @@ function core_process_stanza(origin, stanza) elseif hosts[host] and hosts[host].type == "component" then -- directed at a component component_handle_stanza(origin, stanza); elseif origin.type == "c2s" and stanza.name == "presence" and stanza.attr.type ~= nil and stanza.attr.type ~= "unavailable" then - handle_outbound_presence_subscriptions_and_probes(origin, stanza, from_bare, to_bare); + handle_outbound_presence_subscriptions_and_probes(origin, stanza, from_bare, to_bare, core_route_stanza); elseif origin.type ~= "c2s" and stanza.name == "iq" and not resource then -- directed at bare JID core_handle_stanza(origin, stanza); else @@ -174,130 +177,23 @@ function core_handle_stanza(origin, stanza) stanza.attr.to = nil; -- reset it else log("warn", "Unhandled c2s presence: %s", tostring(stanza)); - if stanza.attr.type ~= "error" then + if (stanza.attr.xmlns == "jabber:client" or stanza.attr.xmlns == "jabber:server") and stanza.attr.type ~= "error" then origin.send(st.error_reply(stanza, "cancel", "service-unavailable")); -- FIXME correct error? end end else log("warn", "Unhandled c2s stanza: %s", tostring(stanza)); - if stanza.attr.type ~= "error" and stanza.attr.type ~= "result" then + if (stanza.attr.xmlns == "jabber:client" or stanza.attr.xmlns == "jabber:server") and stanza.attr.type ~= "error" and stanza.attr.type ~= "result" then origin.send(st.error_reply(stanza, "cancel", "service-unavailable")); -- FIXME correct error? end end -- TODO handle other stanzas else log("warn", "Unhandled origin: %s", origin.type); - if stanza.attr.type ~= "error" and stanza.attr.type ~= "result" then + if (stanza.attr.xmlns == "jabber:client" or stanza.attr.xmlns == "jabber:server") and stanza.attr.type ~= "error" and stanza.attr.type ~= "result" then -- s2s stanzas can get here - (origin.sends2s or origin.send)(st.error_reply(stanza, "cancel", "service-unavailable")); -- FIXME correct error? - end - end -end - -function send_presence_of_available_resources(user, host, jid, recipient_session) - local h = hosts[host]; - local count = 0; - if h and h.type == "local" then - local u = h.sessions[user]; - if u then - for k, session in pairs(u.sessions) do - local pres = session.presence; - if pres then - pres.attr.to = jid; - pres.attr.from = session.full_jid; - recipient_session.send(pres); - pres.attr.to = nil; - pres.attr.from = nil; - count = count + 1; - end - end + origin.send(st.error_reply(stanza, "cancel", "service-unavailable")); -- FIXME correct error? end end - return count; -end - -function handle_outbound_presence_subscriptions_and_probes(origin, stanza, from_bare, to_bare) - local node, host = jid_split(from_bare); - local st_from, st_to = stanza.attr.from, stanza.attr.to; - stanza.attr.from, stanza.attr.to = from_bare, to_bare; - if stanza.attr.type == "subscribe" then - log("debug", "outbound subscribe from "..from_bare.." for "..to_bare); - -- 1. route stanza - -- 2. roster push (subscription = none, ask = subscribe) - if rostermanager.set_contact_pending_out(node, host, to_bare) then - rostermanager.roster_push(node, host, to_bare); - end -- else file error - core_route_stanza(origin, stanza); - elseif stanza.attr.type == "unsubscribe" then - log("debug", "outbound unsubscribe from "..from_bare.." for "..to_bare); - -- 1. route stanza - -- 2. roster push (subscription = none or from) - if rostermanager.unsubscribe(node, host, to_bare) then - rostermanager.roster_push(node, host, to_bare); -- FIXME do roster push when roster has in fact not changed? - end -- else file error - core_route_stanza(origin, stanza); - elseif stanza.attr.type == "subscribed" then - log("debug", "outbound subscribed from "..from_bare.." for "..to_bare); - -- 1. route stanza - -- 2. roster_push () - -- 3. send_presence_of_available_resources - if rostermanager.subscribed(node, host, to_bare) then - rostermanager.roster_push(node, host, to_bare); - core_route_stanza(origin, stanza); - send_presence_of_available_resources(node, host, to_bare, origin); - end - elseif stanza.attr.type == "unsubscribed" then - log("debug", "outbound unsubscribed from "..from_bare.." for "..to_bare); - -- 1. route stanza - -- 2. roster push (subscription = none or to) - if rostermanager.unsubscribed(node, host, to_bare) then - rostermanager.roster_push(node, host, to_bare); - core_route_stanza(origin, stanza); - end - end - stanza.attr.from, stanza.attr.to = st_from, st_to; -end - -function handle_inbound_presence_subscriptions_and_probes(origin, stanza, from_bare, to_bare) - local node, host = jid_split(to_bare); - local st_from, st_to = stanza.attr.from, stanza.attr.to; - stanza.attr.from, stanza.attr.to = from_bare, to_bare; - if stanza.attr.type == "probe" then - log("debug", "inbound probe from "..from_bare.." for "..to_bare); - if rostermanager.is_contact_subscribed(node, host, from_bare) then - if 0 == send_presence_of_available_resources(node, host, from_bare, origin) then - -- TODO send last recieved unavailable presence (or we MAY do nothing, which is fine too) - end - else - core_route_stanza(origin, st.presence({from=to_bare, to=from_bare, type="unsubscribed"})); - end - elseif stanza.attr.type == "subscribe" then - log("debug", "inbound subscribe from "..from_bare.." for "..to_bare); - if rostermanager.is_contact_subscribed(node, host, from_bare) then - core_route_stanza(origin, st.presence({from=to_bare, to=from_bare, type="subscribed"})); -- already subscribed - else - if not rostermanager.is_contact_pending_in(node, host, from_bare) then - if rostermanager.set_contact_pending_in(node, host, from_bare) then - sessionmanager.send_to_available_resources(node, host, stanza); - end -- TODO else return error, unable to save - end - end - elseif stanza.attr.type == "unsubscribe" then - log("debug", "inbound unsubscribe from "..from_bare.." for "..to_bare); - if rostermanager.process_inbound_unsubscribe(node, host, from_bare) then - rostermanager.roster_push(node, host, from_bare); - end - elseif stanza.attr.type == "subscribed" then - log("debug", "inbound subscribed from "..from_bare.." for "..to_bare); - if rostermanager.process_inbound_subscription_approval(node, host, from_bare) then - rostermanager.roster_push(node, host, from_bare); - end - elseif stanza.attr.type == "unsubscribed" then - log("debug", "inbound unsubscribed from "..from_bare.." for "..to_bare); - if rostermanager.process_inbound_subscription_approval(node, host, from_bare) then - rostermanager.roster_push(node, host, from_bare); - end - end -- discard any other type - stanza.attr.from, stanza.attr.to = st_from, st_to; end function core_route_stanza(origin, stanza) @@ -312,6 +208,10 @@ function core_route_stanza(origin, stanza) local from_node, from_host, from_resource = jid_split(from); local from_bare = from_node and (from_node.."@"..from_host) or from_host; -- bare JID + -- Auto-detect origin if not specified + origin = origin or hosts[from_host]; + if not origin then return false; end + if stanza.name == "presence" and (stanza.attr.type ~= nil and stanza.attr.type ~= "unavailable") then resource = nil; end local host_session = hosts[host] @@ -324,7 +224,7 @@ function core_route_stanza(origin, stanza) -- if we get here, resource was not specified or was unavailable if stanza.name == "presence" then if stanza.attr.type ~= nil and stanza.attr.type ~= "unavailable" then - handle_inbound_presence_subscriptions_and_probes(origin, stanza, from_bare, to_bare); + handle_inbound_presence_subscriptions_and_probes(origin, stanza, from_bare, to_bare, core_route_stanza); else -- sender is available or unavailable for _, session in pairs(user.sessions) do -- presence broadcast to all user resources. if session.full_jid then -- FIXME should this be just for available resources? Do we need to check subscription? @@ -367,7 +267,7 @@ function core_route_stanza(origin, stanza) if user_exists(node, host) then if stanza.name == "presence" then if stanza.attr.type ~= nil and stanza.attr.type ~= "unavailable" then - handle_inbound_presence_subscriptions_and_probes(origin, stanza, from_bare, to_bare); + handle_inbound_presence_subscriptions_and_probes(origin, stanza, from_bare, to_bare, core_route_stanza); else -- TODO send unavailable presence or unsubscribed end @@ -404,8 +304,6 @@ function core_route_stanza(origin, stanza) elseif origin.type == "component" or origin.type == "local" then -- Route via s2s for components and modules log("debug", "Routing outgoing stanza for %s to %s", origin.host, host); - for k,v in pairs(origin) do print("origin:", tostring(k), tostring(v)); end - print(tostring(host), tostring(from_host)) send_s2s(origin.host, host, stanza); else log("warn", "received stanza from unhandled connection type: %s", origin.type); diff --git a/core/xmlhandlers.lua b/core/xmlhandlers.lua index 3037a848..a97db8e9 100644 --- a/core/xmlhandlers.lua +++ b/core/xmlhandlers.lua @@ -25,7 +25,7 @@ local ns_prefixes = { ["http://www.w3.org/XML/1998/namespace"] = "xml"; } -function init_xmlhandlers(session, streamopened) +function init_xmlhandlers(session, stream_callbacks) local ns_stack = { "" }; local curr_ns = ""; local curr_tag; @@ -36,6 +36,9 @@ function init_xmlhandlers(session, streamopened) local send = session.send; + local cb_streamopened = stream_callbacks.streamopened; + local cb_streamclosed = stream_callbacks.streamclosed; + local stanza function xml_handlers:StartElement(name, attr) if stanza and #chardata > 0 then @@ -66,7 +69,9 @@ function init_xmlhandlers(session, streamopened) if not stanza then --if we are not currently inside a stanza if session.notopen then if name == "stream" then - streamopened(session, attr); + if cb_streamopened then + cb_streamopened(session, attr); + end return; end error("Client failed to open stream successfully"); @@ -75,7 +80,7 @@ function init_xmlhandlers(session, streamopened) error("Client sent invalid top-level stanza"); end - stanza = st.stanza(name, attr); --{ to = attr.to, type = attr.type, id = attr.id, xmlns = curr_ns }); + stanza = st.stanza(name, attr); curr_tag = stanza; else -- we are inside a stanza, so add a tag attr.xmlns = nil; @@ -92,15 +97,17 @@ function init_xmlhandlers(session, streamopened) end function xml_handlers:EndElement(name) curr_ns,name = name:match("^(.+)|([%w%-]+)$"); - if (not stanza) or #stanza.last_add < 0 or (#stanza.last_add > 0 and name ~= stanza.last_add[#stanza.last_add].name) then + if (not stanza) or (#stanza.last_add > 0 and name ~= stanza.last_add[#stanza.last_add].name) then if name == "stream" then log("debug", "Stream closed"); - sm_destroy_session(session); + if cb_streamclosed then + cb_streamclosed(session); + end return; elseif name == "error" then error("Stream error: "..tostring(name)..": "..tostring(stanza)); else - error("XML parse error in client stream"); + error("XML parse error in client stream with element: "..name); end end if stanza and #chardata > 0 then diff --git a/lxmppd.cfg.dist b/lxmppd.cfg.dist deleted file mode 100644 index 59b85b97..00000000 --- a/lxmppd.cfg.dist +++ /dev/null @@ -1,31 +0,0 @@ - ----- lxmppd configuration file ---- - -config = { - hosts = { -- local hosts - "localhost"; - --"snikket.com"; - }; - -- If the following is commented, no SSL will be set up on 5223 - --[[ssl_ctx = { - mode = "server"; - protocol = "sslv23"; - - key = "/home/matthew/ssl_cert/server.key"; - certificate = "/home/matthew/ssl_cert/server.crt"; - capath = "/etc/ssl/certs"; - verify = "none"; - };]] - modules = { -- enabled modules - "saslauth"; - "legacyauth"; - "roster"; - "register"; - "tls"; - "vcard"; - "private"; - "version"; - "dialback"; - "uptime"; - }; -} diff --git a/lxmppd.cfg.lua.dist b/lxmppd.cfg.lua.dist new file mode 100644 index 00000000..d2c6d3ff --- /dev/null +++ b/lxmppd.cfg.lua.dist @@ -0,0 +1,74 @@ +-- lxmppd Example Configuration File +-- +-- If it wasn't already obvious, -- starts a comment, and all +-- text after it is ignored by lxmppd. +-- +-- The config is split into sections, a global section, and one +-- for each defined host that we serve. You can add as many host +-- sections as you like. +-- +-- Lists are written { "like", "this", "one" } +-- Lists can also be of { 1, 2, 3 } numbers, etc. +-- Either commas, or semi-colons; may be used +-- as seperators. +-- +-- A table is a list of values, except each value has a name. An +-- example would be: +-- +-- logging = { type = "html", directory = "/var/logs", rotate = "daily" } +-- +-- Whitespace (that is tabs, spaces, line breaks) is insignificant, so can +-- be placed anywhere +-- that you deem fitting. Youcouldalsoremoveitentirely,butforobviousrea +--sonsIdon'trecommendit. +-- +-- Tip: You can check that the syntax of this file is correct when you have finished +-- by running: luac -p lxmppd.cfg.lua +-- If there are any errors, it will let you know what and where they are, otherwise it +-- will keep quiet. +-- +-- The only thing left to do is rename this file to remove the .dist ending, and fill in the +-- blanks. Good luck, and happy Jabbering! + +-- Global settings go in this section +Host "*" + + -- This is the list of modules lxmppd will load on startup. + -- It looks for plugins/mod_modulename.lua, so make sure that exists too. + modules_enabled = { + "saslauth"; -- Authentication for clients and servers. Recommended if you want to log in. + "legacyauth"; -- Legacy authentication. Only used by some old clients and bots. + "roster"; -- Allow users to have a roster. Recommended ;) + "register"; -- Allow users to register on this server using a client + "tls"; -- Add support for secure TLS on c2s/s2s connections + "vcard"; -- Allow users to set vCards + "private"; -- Private XML storage (for room bookmarks, etc.) + "version"; -- Replies to server version requests + "dialback"; -- s2s dialback support + }; + + -- These are the SSL/TLS-related settings. If you don't want + -- to use SSL/TLS, you may comment or remove this + ssl = { + key = "certs/server.key"; + certificate = "certs/server.crt"; + } + +-- This allows clients to connect to localhost. No harm in it. +Host "localhost" + +-- Section for example.com +-- (replace example.com with your domain name) +Host "example.com" + -- Assign this host a certificate for TLS, otherwise it would use the one + -- set in the global section (if any). + -- Note that old-style SSL on port 5223 only supports one certificate, and will always + -- use the global one. + ssl = { + key = "certs/example.com.key"; + certificate = "certs/example.com.crt"; + } + +Host "example.org" + enabled = false -- This will disable the host, preserving the config, but denying connections + @@ -4,22 +4,42 @@ local server = require "net.server" require "lxp" require "socket" require "ssl" +local config = require "core.configmanager" -function log(type, area, message) - print(type, area, message); -end +log = require "util.logger".init("general"); -dofile "lxmppd.cfg" +do + -- TODO: Check for other formats when we add support for them + -- Use lfs? Make a new conf/ dir? + local ok, err = config.load("lxmppd.cfg.lua"); + if not ok then + log("error", "Couldn't load config file: %s", err); + log("info", "Falling back to old config file format...") + ok, err = pcall(dofile, "lxmppd.cfg"); + if not ok then + log("error", "Old config format loading failed too: %s", err); + else + for _, host in ipairs(_G.config.hosts) do + config.set(host, "core", "defined", true); + end + + config.set("*", "core", "modules_enabled", _G.config.modules); + config.set("*", "core", "ssl", _G.config.ssl_ctx); + end + end +end -- Maps connections to sessions -- sessions = {}; hosts = {}; -if config.hosts and #config.hosts > 0 then - for _, host in pairs(config.hosts) do +local defined_hosts = config.getconfig(); + +for host, host_config in pairs(defined_hosts) do + if host ~= "*" and (host_config.core.enabled == nil or host_config.core.enabled) then hosts[host] = {type = "local", connected = true, sessions = {}, host = host, s2sout = {} }; end -else error("No hosts defined in the configuration file"); end +end -- Load and initialise core modules -- @@ -32,8 +52,10 @@ require "core.usermanager" require "core.sessionmanager" require "core.stanza_router" +--[[ pcall(require, "remdebug.engine"); if remdebug then remdebug.engine.start() end +]] local start = require "net.connlisteners".start; require "util.stanza" @@ -42,11 +64,12 @@ require "util.jid" ------------------------------------------------------------------------ -- Initialise modules -if config.modules and #config.modules > 0 then - for _, module in pairs(config.modules) do +local modules_enabled = config.get("*", "core", "modules_enabled"); +if modules_enabled then + for _, module in pairs(modules_enabled) do modulemanager.load(module); end -else error("No modules enabled in the configuration file"); end +end -- setup error handling setmetatable(_G, { __index = function (t, k) print("WARNING: ATTEMPT TO READ A NIL GLOBAL!!!", k); error("Attempt to read a non-existent global. Naughty boy.", 2); end, __newindex = function (t, k, v) print("ATTEMPT TO SET A GLOBAL!!!!", tostring(k).." = "..tostring(v)); error("Attempt to set a global. Naughty boy.", 2); end }) --]][][[]][]; @@ -54,8 +77,21 @@ setmetatable(_G, { __index = function (t, k) print("WARNING: ATTEMPT TO READ A N local protected_handler = function (conn, data, err) local success, ret = pcall(handler, conn, data, err); if not success then print("ERROR on "..tostring(conn)..": "..ret); conn:close(); end end; local protected_disconnect = function (conn, err) local success, ret = pcall(disconnect, conn, err); if not success then print("ERROR on "..tostring(conn).." disconnect: "..ret); conn:close(); end end; + +local global_ssl_ctx = config.get("*", "core", "ssl"); +if global_ssl_ctx then + local default_ssl_ctx = { mode = "server", protocol = "sslv23", capath = "/etc/ssl/certs", verify = "none"; }; + setmetatable(global_ssl_ctx, { __index = default_ssl_ctx }); +end + -- start listening on sockets -start("xmppclient", { ssl = config.ssl_ctx }) -start("xmppserver", { ssl = config.ssl_ctx }) +start("xmppclient", { ssl = global_ssl_ctx }) +start("xmppserver", { ssl = global_ssl_ctx }) + +if config.get("*", "core", "console_enabled") then + start("console") +end + +modulemanager.fire_event("server-started"); server.loop(); diff --git a/net/connlisteners.lua b/net/connlisteners.lua index 431d8717..2b95331c 100644 --- a/net/connlisteners.lua +++ b/net/connlisteners.lua @@ -38,8 +38,8 @@ function start(name, udata) error("No such connection module: "..name, 0); end return server_add(h, - udata.port or h.default_port or error("Can't start listener "..name.." because no port was specified, and it has no default port", 0), - udata.interface or "*", udata.mode or h.default_mode or 1, udata.ssl ); + (udata and udata.port) or h.default_port or error("Can't start listener "..name.." because no port was specified, and it has no default port", 0), + (udata and udata.interface) or "*", (udata and udata.mode) or h.default_mode or 1, (udata and udata.ssl) or nil ); end return _M;
\ No newline at end of file diff --git a/net/dns.lua b/net/dns.lua new file mode 100644 index 00000000..7364161e --- /dev/null +++ b/net/dns.lua @@ -0,0 +1,802 @@ + + +-- public domain 20080404 lua@ztact.com + + +-- todo: quick (default) header generation +-- todo: nxdomain, error handling +-- todo: cache results of encodeName + + +-- reference: http://tools.ietf.org/html/rfc1035 +-- reference: http://tools.ietf.org/html/rfc1876 (LOC) + + +require 'socket' +local ztact = require 'util.ztact' + + +local coroutine, io, math, socket, string, table = + coroutine, io, math, socket, string, table + +local ipairs, next, pairs, print, setmetatable, tostring, assert, error = + ipairs, next, pairs, print, setmetatable, tostring, assert, error + +local get, set = ztact.get, ztact.set + + +-------------------------------------------------- module dns +module ('dns') +local dns = _M; + + +-- dns type & class codes ------------------------------ dns type & class codes + + +local append = table.insert + + +local function highbyte (i) -- - - - - - - - - - - - - - - - - - - highbyte + return (i-(i%0x100))/0x100 + end + + +local function augment (t) -- - - - - - - - - - - - - - - - - - - - augment + local a = {} + for i,s in pairs (t) do a[i] = s a[s] = s a[string.lower (s)] = s end + return a + end + + +local function encode (t) -- - - - - - - - - - - - - - - - - - - - - encode + local code = {} + for i,s in pairs (t) do + local word = string.char (highbyte (i), i %0x100) + code[i] = word + code[s] = word + code[string.lower (s)] = word + end + return code + end + + +dns.types = { + 'A', 'NS', 'MD', 'MF', 'CNAME', 'SOA', 'MB', 'MG', 'MR', 'NULL', 'WKS', + 'PTR', 'HINFO', 'MINFO', 'MX', 'TXT', + [ 28] = 'AAAA', [ 29] = 'LOC', [ 33] = 'SRV', + [252] = 'AXFR', [253] = 'MAILB', [254] = 'MAILA', [255] = '*' } + + +dns.classes = { 'IN', 'CS', 'CH', 'HS', [255] = '*' } + + +dns.type = augment (dns.types) +dns.class = augment (dns.classes) +dns.typecode = encode (dns.types) +dns.classcode = encode (dns.classes) + + + +local function standardize (qname, qtype, qclass) -- - - - - - - standardize + if string.byte (qname, -1) ~= 0x2E then qname = qname..'.' end + qname = string.lower (qname) + return qname, dns.type[qtype or 'A'], dns.class[qclass or 'IN'] + end + + +local function prune (rrs, time, soft) -- - - - - - - - - - - - - - - prune + + time = time or socket.gettime () + for i,rr in pairs (rrs) do + + if rr.tod then + -- rr.tod = rr.tod - 50 -- accelerated decripitude + rr.ttl = math.floor (rr.tod - time) + if rr.ttl <= 0 then rrs[i] = nil end + + elseif soft == 'soft' then -- What is this? I forget! + assert (rr.ttl == 0) + rrs[i] = nil + end end end + + +-- metatables & co. ------------------------------------------ metatables & co. + + +local resolver = {} +resolver.__index = resolver + + +local SRV_tostring + + +local rr_metatable = {} -- - - - - - - - - - - - - - - - - - - rr_metatable +function rr_metatable.__tostring (rr) + local s0 = string.format ( + '%2s %-5s %6i %-28s', rr.class, rr.type, rr.ttl, rr.name ) + local s1 = '' + if rr.type == 'A' then s1 = ' '..rr.a + elseif rr.type == 'MX' then + s1 = string.format (' %2i %s', rr.pref, rr.mx) + elseif rr.type == 'CNAME' then s1 = ' '..rr.cname + elseif rr.type == 'LOC' then s1 = ' '..resolver.LOC_tostring (rr) + elseif rr.type == 'NS' then s1 = ' '..rr.ns + elseif rr.type == 'SRV' then s1 = ' '..SRV_tostring (rr) + elseif rr.type == 'TXT' then s1 = ' '..rr.txt + else s1 = ' <UNKNOWN RDATA TYPE>' end + return s0..s1 + end + + +local rrs_metatable = {} -- - - - - - - - - - - - - - - - - - rrs_metatable +function rrs_metatable.__tostring (rrs) + local t = {} + for i,rr in pairs (rrs) do append (t, tostring (rr)..'\n') end + return table.concat (t) + end + + +local cache_metatable = {} -- - - - - - - - - - - - - - - - cache_metatable +function cache_metatable.__tostring (cache) + local time = socket.gettime () + local t = {} + for class,types in pairs (cache) do + for type,names in pairs (types) do + for name,rrs in pairs (names) do + prune (rrs, time) + append (t, tostring (rrs)) end end end + return table.concat (t) + end + + +function resolver:new () -- - - - - - - - - - - - - - - - - - - - - resolver + local r = { active = {}, cache = {}, unsorted = {} } + setmetatable (r, resolver) + setmetatable (r.cache, cache_metatable) + setmetatable (r.unsorted, { __mode = 'kv' }) + return r + end + + +-- packet layer -------------------------------------------------- packet layer + + +function dns.random (...) -- - - - - - - - - - - - - - - - - - - dns.random + math.randomseed (10000*socket.gettime ()) + dns.random = math.random + return dns.random (...) + end + + +local function encodeHeader (o) -- - - - - - - - - - - - - - - encodeHeader + + o = o or {} + + o.id = o.id or -- 16b (random) id + dns.random (0, 0xffff) + + o.rd = o.rd or 1 -- 1b 1 recursion desired + o.tc = o.tc or 0 -- 1b 1 truncated response + o.aa = o.aa or 0 -- 1b 1 authoritative response + o.opcode = o.opcode or 0 -- 4b 0 query + -- 1 inverse query + -- 2 server status request + -- 3-15 reserved + o.qr = o.qr or 0 -- 1b 0 query, 1 response + + o.rcode = o.rcode or 0 -- 4b 0 no error + -- 1 format error + -- 2 server failure + -- 3 name error + -- 4 not implemented + -- 5 refused + -- 6-15 reserved + o.z = o.z or 0 -- 3b 0 resvered + o.ra = o.ra or 0 -- 1b 1 recursion available + + o.qdcount = o.qdcount or 1 -- 16b number of question RRs + o.ancount = o.ancount or 0 -- 16b number of answers RRs + o.nscount = o.nscount or 0 -- 16b number of nameservers RRs + o.arcount = o.arcount or 0 -- 16b number of additional RRs + + -- string.char() rounds, so prevent roundup with -0.4999 + local header = string.char ( + highbyte (o.id), o.id %0x100, + o.rd + 2*o.tc + 4*o.aa + 8*o.opcode + 128*o.qr, + o.rcode + 16*o.z + 128*o.ra, + highbyte (o.qdcount), o.qdcount %0x100, + highbyte (o.ancount), o.ancount %0x100, + highbyte (o.nscount), o.nscount %0x100, + highbyte (o.arcount), o.arcount %0x100 ) + + return header, o.id + end + + +local function encodeName (name) -- - - - - - - - - - - - - - - - encodeName + local t = {} + for part in string.gmatch (name, '[^.]+') do + append (t, string.char (string.len (part))) + append (t, part) + end + append (t, string.char (0)) + return table.concat (t) + end + + +local function encodeQuestion (qname, qtype, qclass) -- - - - encodeQuestion + qname = encodeName (qname) + qtype = dns.typecode[qtype or 'a'] + qclass = dns.classcode[qclass or 'in'] + return qname..qtype..qclass; + end + + +function resolver:byte (len) -- - - - - - - - - - - - - - - - - - - - - byte + len = len or 1 + local offset = self.offset + local last = offset + len - 1 + if last > #self.packet then + error (string.format ('out of bounds: %i>%i', last, #self.packet)) end + self.offset = offset + len + return string.byte (self.packet, offset, last) + end + + +function resolver:word () -- - - - - - - - - - - - - - - - - - - - - - word + local b1, b2 = self:byte (2) + return 0x100*b1 + b2 + end + + +function resolver:dword () -- - - - - - - - - - - - - - - - - - - - - dword + local b1, b2, b3, b4 = self:byte (4) + -- print ('dword', b1, b2, b3, b4) + return 0x1000000*b1 + 0x10000*b2 + 0x100*b3 + b4 + end + + +function resolver:sub (len) -- - - - - - - - - - - - - - - - - - - - - - sub + len = len or 1 + local s = string.sub (self.packet, self.offset, self.offset + len - 1) + self.offset = self.offset + len + return s + end + + +function resolver:header (force) -- - - - - - - - - - - - - - - - - - header + + local id = self:word () + -- print (string.format (':header id %x', id)) + if not self.active[id] and not force then return nil end + + local h = { id = id } + + local b1, b2 = self:byte (2) + + h.rd = b1 %2 + h.tc = b1 /2%2 + h.aa = b1 /4%2 + h.opcode = b1 /8%16 + h.qr = b1 /128 + + h.rcode = b2 %16 + h.z = b2 /16%8 + h.ra = b2 /128 + + h.qdcount = self:word () + h.ancount = self:word () + h.nscount = self:word () + h.arcount = self:word () + + for k,v in pairs (h) do h[k] = v-v%1 end + + return h + end + + +function resolver:name () -- - - - - - - - - - - - - - - - - - - - - - name + local remember, pointers = nil, 0 + local len = self:byte () + local n = {} + while len > 0 do + if len >= 0xc0 then -- name is "compressed" + pointers = pointers + 1 + if pointers >= 20 then error ('dns error: 20 pointers') end + local offset = ((len-0xc0)*0x100) + self:byte () + remember = remember or self.offset + self.offset = offset + 1 -- +1 for lua + else -- name is not compressed + append (n, self:sub (len)..'.') + end + len = self:byte () + end + self.offset = remember or self.offset + return table.concat (n) + end + + +function resolver:question () -- - - - - - - - - - - - - - - - - - question + local q = {} + q.name = self:name () + q.type = dns.type[self:word ()] + q.class = dns.type[self:word ()] + return q + end + + +function resolver:A (rr) -- - - - - - - - - - - - - - - - - - - - - - - - A + local b1, b2, b3, b4 = self:byte (4) + rr.a = string.format ('%i.%i.%i.%i', b1, b2, b3, b4) + end + + +function resolver:CNAME (rr) -- - - - - - - - - - - - - - - - - - - - CNAME + rr.cname = self:name () + end + + +function resolver:MX (rr) -- - - - - - - - - - - - - - - - - - - - - - - MX + rr.pref = self:word () + rr.mx = self:name () + end + + +function resolver:LOC_nibble_power () -- - - - - - - - - - LOC_nibble_power + local b = self:byte () + -- print ('nibbles', ((b-(b%0x10))/0x10), (b%0x10)) + return ((b-(b%0x10))/0x10) * (10^(b%0x10)) + end + + +function resolver:LOC (rr) -- - - - - - - - - - - - - - - - - - - - - - LOC + rr.version = self:byte () + if rr.version == 0 then + rr.loc = rr.loc or {} + rr.loc.size = self:LOC_nibble_power () + rr.loc.horiz_pre = self:LOC_nibble_power () + rr.loc.vert_pre = self:LOC_nibble_power () + rr.loc.latitude = self:dword () + rr.loc.longitude = self:dword () + rr.loc.altitude = self:dword () + end end + + +local function LOC_tostring_degrees (f, pos, neg) -- - - - - - - - - - - - - + f = f - 0x80000000 + if f < 0 then pos = neg f = -f end + local deg, min, msec + msec = f%60000 + f = (f-msec)/60000 + min = f%60 + deg = (f-min)/60 + return string.format ('%3d %2d %2.3f %s', deg, min, msec/1000, pos) + end + + +function resolver.LOC_tostring (rr) -- - - - - - - - - - - - - LOC_tostring + + local t = {} + + --[[ + for k,name in pairs { 'size', 'horiz_pre', 'vert_pre', + 'latitude', 'longitude', 'altitude' } do + append (t, string.format ('%4s%-10s: %12.0f\n', '', name, rr.loc[name])) + end + --]] + + append ( t, string.format ( + '%s %s %.2fm %.2fm %.2fm %.2fm', + LOC_tostring_degrees (rr.loc.latitude, 'N', 'S'), + LOC_tostring_degrees (rr.loc.longitude, 'E', 'W'), + (rr.loc.altitude - 10000000) / 100, + rr.loc.size / 100, + rr.loc.horiz_pre / 100, + rr.loc.vert_pre / 100 ) ) + + return table.concat (t) + end + + +function resolver:NS (rr) -- - - - - - - - - - - - - - - - - - - - - - - NS + rr.ns = self:name () + end + + +function resolver:SOA (rr) -- - - - - - - - - - - - - - - - - - - - - - SOA + end + + +function resolver:SRV (rr) -- - - - - - - - - - - - - - - - - - - - - - SRV + rr.srv = {} + rr.srv.priority = self:word () + rr.srv.weight = self:word () + rr.srv.port = self:word () + rr.srv.target = self:name () + end + + +function SRV_tostring (rr) -- - - - - - - - - - - - - - - - - - SRV_tostring + local s = rr.srv + return string.format ( '%5d %5d %5d %s', + s.priority, s.weight, s.port, s.target ) + end + + +function resolver:TXT (rr) -- - - - - - - - - - - - - - - - - - - - - - TXT + rr.txt = self:sub (rr.rdlength) + end + + +function resolver:rr () -- - - - - - - - - - - - - - - - - - - - - - - - rr + local rr = {} + setmetatable (rr, rr_metatable) + rr.name = self:name (self) + rr.type = dns.type[self:word ()] or rr.type + rr.class = dns.class[self:word ()] or rr.class + rr.ttl = 0x10000*self:word () + self:word () + rr.rdlength = self:word () + + if rr.ttl == 0 then -- pass + else rr.tod = self.time + rr.ttl end + + local remember = self.offset + local rr_parser = self[dns.type[rr.type]] + if rr_parser then rr_parser (self, rr) end + self.offset = remember + rr.rdata = self:sub (rr.rdlength) + return rr + end + + +function resolver:rrs (count) -- - - - - - - - - - - - - - - - - - - - - rrs + local rrs = {} + for i = 1,count do append (rrs, self:rr ()) end + return rrs + end + + +function resolver:decode (packet, force) -- - - - - - - - - - - - - - decode + + self.packet, self.offset = packet, 1 + local header = self:header (force) + if not header then return nil end + local response = { header = header } + + response.question = {} + local offset = self.offset + for i = 1,response.header.qdcount do + append (response.question, self:question ()) end + response.question.raw = string.sub (self.packet, offset, self.offset - 1) + + if not force then + if not self.active[response.header.id] or + not self.active[response.header.id][response.question.raw] then + return nil end end + + response.answer = self:rrs (response.header.ancount) + response.authority = self:rrs (response.header.nscount) + response.additional = self:rrs (response.header.arcount) + + return response + end + + +-- socket layer -------------------------------------------------- socket layer + + +resolver.delays = { 1, 3, 11, 45 } + + +function resolver:addnameserver (address) -- - - - - - - - - - addnameserver + self.server = self.server or {} + append (self.server, address) + end + + +function resolver:setnameserver (address) -- - - - - - - - - - setnameserver + self.server = {} + self:addnameserver (address) + end + + +function resolver:adddefaultnameservers () -- - - - - adddefaultnameservers + local resolv_conf = io.open("/etc/resolv.conf"); + if resolv_conf then + for line in resolv_conf:lines() do + local address = string.match (line, 'nameserver%s+(%d+%.%d+%.%d+%.%d+)') + if address then self:addnameserver (address) end + end + else -- FIXME correct for windows, using opendns nameservers for now + self:addnameserver ("208.67.222.222") + self:addnameserver ("208.67.220.220") + end +end + + +function resolver:getsocket (servernum) -- - - - - - - - - - - - - getsocket + + self.socket = self.socket or {} + self.socketset = self.socketset or {} + + local sock = self.socket[servernum] + if sock then return sock end + + sock = socket.udp () + if self.socket_wrapper then sock = self.socket_wrapper (sock) end + sock:settimeout (0) + -- todo: attempt to use a random port, fallback to 0 + sock:setsockname ('*', 0) + sock:setpeername (self.server[servernum], 53) + self.socket[servernum] = sock + self.socketset[sock] = sock + return sock + end + + +function resolver:socket_wrapper_set (func) -- - - - - - - socket_wrapper_set + self.socket_wrapper = func + end + + +function resolver:closeall () -- - - - - - - - - - - - - - - - - - closeall + for i,sock in ipairs (self.socket) do self.socket[i]:close () end + self.socket = {} + end + + +function resolver:remember (rr, type) -- - - - - - - - - - - - - - remember + + -- print ('remember', type, rr.class, rr.type, rr.name) + + if type ~= '*' then + type = rr.type + local all = get (self.cache, rr.class, '*', rr.name) + -- print ('remember all', all) + if all then append (all, rr) end + end + + self.cache = self.cache or setmetatable ({}, cache_metatable) + local rrs = get (self.cache, rr.class, type, rr.name) or + set (self.cache, rr.class, type, rr.name, setmetatable ({}, rrs_metatable)) + append (rrs, rr) + + if type == 'MX' then self.unsorted[rrs] = true end + end + + +local function comp_mx (a, b) -- - - - - - - - - - - - - - - - - - - comp_mx + return (a.pref == b.pref) and (a.mx < b.mx) or (a.pref < b.pref) + end + + +function resolver:peek (qname, qtype, qclass) -- - - - - - - - - - - - peek + qname, qtype, qclass = standardize (qname, qtype, qclass) + local rrs = get (self.cache, qclass, qtype, qname) + if not rrs then return nil end + if prune (rrs, socket.gettime ()) and qtype == '*' or not next (rrs) then + set (self.cache, qclass, qtype, qname, nil) return nil end + if self.unsorted[rrs] then table.sort (rrs, comp_mx) end + return rrs + end + + +function resolver:purge (soft) -- - - - - - - - - - - - - - - - - - - purge + if soft == 'soft' then + self.time = socket.gettime () + for class,types in pairs (self.cache or {}) do + for type,names in pairs (types) do + for name,rrs in pairs (names) do + prune (rrs, self.time, 'soft') + end end end + else self.cache = {} end + end + + +function resolver:query (qname, qtype, qclass) -- - - - - - - - - - -- query + + qname, qtype, qclass = standardize (qname, qtype, qclass) + + if not self.server then self:adddefaultnameservers () end + + local question = encodeQuestion (qname, qtype, qclass) + local peek = self:peek (qname, qtype, qclass) + if peek then return peek end + + local header, id = encodeHeader () + -- print ('query id', id, qclass, qtype, qname) + local o = { packet = header..question, + server = 1, + delay = 1, + retry = socket.gettime () + self.delays[1] } + self:getsocket (o.server):send (o.packet) + + -- remember the query + self.active[id] = self.active[id] or {} + self.active[id][question] = o + + -- remember which coroutine wants the answer + local co = coroutine.running () + if co then + set (self.wanted, qclass, qtype, qname, co, true) + set (self.yielded, co, qclass, qtype, qname, true) + end end + + +function resolver:receive (rset) -- - - - - - - - - - - - - - - - - receive + + -- print 'receive' print (self.socket) + self.time = socket.gettime () + rset = rset or self.socket + + local response + for i,sock in pairs (rset) do + + if self.socketset[sock] then + local packet = sock:receive () + if packet then + + response = self:decode (packet) + if response then + -- print 'received response' + -- self.print (response) + + for i,section in pairs { 'answer', 'authority', 'additional' } do + for j,rr in pairs (response[section]) do + self:remember (rr, response.question[1].type) end end + + -- retire the query + local queries = self.active[response.header.id] + if queries[response.question.raw] then + queries[response.question.raw] = nil end + if not next (queries) then self.active[response.header.id] = nil end + if not next (self.active) then self:closeall () end + + -- was the query on the wanted list? + local q = response.question + local cos = get (self.wanted, q.class, q.type, q.name) + if cos then + for co in pairs (cos) do + set (self.yielded, co, q.class, q.type, q.name, nil) + if not self.yielded[co] then coroutine.resume (co) end + end + set (self.wanted, q.class, q.type, q.name, nil) + end end end end end + + return response + end + + +function resolver:pulse () -- - - - - - - - - - - - - - - - - - - - - pulse + + -- print ':pulse' + while self:receive () do end + if not next (self.active) then return nil end + + self.time = socket.gettime () + for id,queries in pairs (self.active) do + for question,o in pairs (queries) do + if self.time >= o.retry then + + o.server = o.server + 1 + if o.server > #self.server then + o.server = 1 + o.delay = o.delay + 1 + end + + if o.delay > #self.delays then + print ('timeout') + queries[question] = nil + if not next (queries) then self.active[id] = nil end + if not next (self.active) then return nil end + else + -- print ('retry', o.server, o.delay) + self.socket[o.server]:send (o.packet) + o.retry = self.time + self.delays[o.delay] + end end end end + + if next (self.active) then return true end + return nil + end + + +function resolver:lookup (qname, qtype, qclass) -- - - - - - - - - - lookup + self:query (qname, qtype, qclass) + while self:pulse () do socket.select (self.socket, nil, 4) end + -- print (self.cache) + return self:peek (qname, qtype, qclass) + end + + +-- print ---------------------------------------------------------------- print + + +local hints = { -- - - - - - - - - - - - - - - - - - - - - - - - - - - hints + qr = { [0]='query', 'response' }, + opcode = { [0]='query', 'inverse query', 'server status request' }, + aa = { [0]='non-authoritative', 'authoritative' }, + tc = { [0]='complete', 'truncated' }, + rd = { [0]='recursion not desired', 'recursion desired' }, + ra = { [0]='recursion not available', 'recursion available' }, + z = { [0]='(reserved)' }, + rcode = { [0]='no error', 'format error', 'server failure', 'name error', + 'not implemented' }, + + type = dns.type, + class = dns.class, } + + +local function hint (p, s) -- - - - - - - - - - - - - - - - - - - - - - hint + return (hints[s] and hints[s][p[s]]) or '' end + + +function resolver.print (response) -- - - - - - - - - - - - - resolver.print + + for s,s in pairs { 'id', 'qr', 'opcode', 'aa', 'tc', 'rd', 'ra', 'z', + 'rcode', 'qdcount', 'ancount', 'nscount', 'arcount' } do + print ( string.format ('%-30s', 'header.'..s), + response.header[s], hint (response.header, s) ) + end + + for i,question in ipairs (response.question) do + print (string.format ('question[%i].name ', i), question.name) + print (string.format ('question[%i].type ', i), question.type) + print (string.format ('question[%i].class ', i), question.class) + end + + local common = { name=1, type=1, class=1, ttl=1, rdlength=1, rdata=1 } + local tmp + for s,s in pairs {'answer', 'authority', 'additional'} do + for i,rr in pairs (response[s]) do + for j,t in pairs { 'name', 'type', 'class', 'ttl', 'rdlength' } do + tmp = string.format ('%s[%i].%s', s, i, t) + print (string.format ('%-30s', tmp), rr[t], hint (rr, t)) + end + for j,t in pairs (rr) do + if not common[j] then + tmp = string.format ('%s[%i].%s', s, i, j) + print (string.format ('%-30s %s', tmp, t)) + end end end end end + + +-- module api ------------------------------------------------------ module api + + +local function resolve (func, ...) -- - - - - - - - - - - - - - resolver_get + dns._resolver = dns._resolver or dns.resolver () + return func (dns._resolver, ...) + end + + +function dns.resolver () -- - - - - - - - - - - - - - - - - - - - - resolver + + -- this function seems to be redundant with resolver.new () + + local r = { active = {}, cache = {}, unsorted = {}, wanted = {}, yielded = {} } + setmetatable (r, resolver) + setmetatable (r.cache, cache_metatable) + setmetatable (r.unsorted, { __mode = 'kv' }) + return r + end + + +function dns.lookup (...) -- - - - - - - - - - - - - - - - - - - - - lookup + return resolve (resolver.lookup, ...) end + + +function dns.purge (...) -- - - - - - - - - - - - - - - - - - - - - - purge + return resolve (resolver.purge, ...) end + +function dns.peek (...) -- - - - - - - - - - - - - - - - - - - - - - - peek + return resolve (resolver.peek, ...) end + + +function dns.query (...) -- - - - - - - - - - - - - - - - - - - - - - query + return resolve (resolver.query, ...) end + + +function dns:socket_wrapper_set (...) -- - - - - - - - - socket_wrapper_set + return resolve (resolver.socket_wrapper_set, ...) end + + +return dns diff --git a/net/server.lua b/net/server.lua index 3bbe80b4..40f37345 100644 --- a/net/server.lua +++ b/net/server.lua @@ -190,7 +190,7 @@ wrapsslclient = function( listener, socket, ip, serverport, clientport, mode, ss local writequeue = { } -- buffer for messages to send
- local eol -- end of buffer
+ local eol, fatal_send_error -- end of buffer
local sstat, rstat = 0, 0
@@ -225,6 +225,7 @@ wrapsslclient = function( listener, socket, ip, serverport, clientport, mode, ss --return shutdown( socket, pattern )
end
handler.close = function( closed )
+ if eol and not fatal_send_error then handler._dispatchdata(); end
close( socket )
writelen = ( eol and removesocket( writelist, socket, writelen ) ) or writelen
readlen = removesocket( readlist, socket, readlen )
@@ -295,6 +296,7 @@ wrapsslclient = function( listener, socket, ip, serverport, clientport, mode, ss eol = 1
return true
else -- connection was closed during sending or fatal error
+ fatal_send_error = true;
out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )
handler.close( )
disconnect( handler, err )
@@ -364,12 +366,13 @@ wraptlsclient = function( listener, socket, ip, serverport, clientport, mode, ss local err
socket:settimeout( 0 )
-
+ out_put("setting linger on "..tostring(socket))
+ socket:setoption("linger", { on = true, timeout = 10 });
--// private closures of the object //--
local writequeue = { } -- buffer for messages to send
- local eol -- end of buffer
+ local eol, fatal_send_error -- end of buffer
local sstat, rstat = 0, 0
@@ -404,6 +407,7 @@ wraptlsclient = function( listener, socket, ip, serverport, clientport, mode, ss --return shutdown( socket, pattern )
end
handler.close = function( closed )
+ if eol and not fatal_send_error then handler._dispatchdata(); end
close( socket )
writelen = ( eol and removesocket( writelist, socket, writelen ) ) or writelen
readlen = removesocket( readlist, socket, readlen )
@@ -481,6 +485,7 @@ wraptlsclient = function( listener, socket, ip, serverport, clientport, mode, ss eol = 1
return true
else -- connection was closed during sending or fatal error
+ fatal_send_error = true; -- :(
out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )
handler.close( )
disconnect( handler, err )
@@ -579,7 +584,7 @@ wraptcpclient = function( listener, socket, ip, serverport, clientport, mode ) local writequeue = { } -- list for messages to send
- local eol
+ local eol, fatal_send_error
local rstat, sstat = 0, 0
@@ -614,6 +619,7 @@ wraptcpclient = function( listener, socket, ip, serverport, clientport, mode ) return shutdown( socket, pattern )
end
handler.close = function( closed )
+ if eol and not fatal_send_error then handler.dispatchdata(); end
_ = not closed and shutdown( socket )
_ = not closed and close( socket )
writelen = ( eol and removesocket( writelist, socket, writelen ) ) or writelen
@@ -686,6 +692,7 @@ wraptcpclient = function( listener, socket, ip, serverport, clientport, mode ) eol = 1
return true
else -- connection was closed during sending or fatal error
+ fatal_send_error = true; -- :'-(
out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )
handler.close( )
disconnect( handler, err )
diff --git a/net/xmppclient_listener.lua b/net/xmppclient_listener.lua index b5028db0..0f5511b4 100644 --- a/net/xmppclient_listener.lua +++ b/net/xmppclient_listener.lua @@ -13,8 +13,11 @@ local m_random = math.random; local format = string.format; local sm_new_session, sm_destroy_session = sessionmanager.new_session, sessionmanager.destroy_session; --import("core.sessionmanager", "new_session", "destroy_session"); local sm_streamopened = sessionmanager.streamopened; +local sm_streamclosed = sessionmanager.streamclosed; local st = stanza; +local stream_callbacks = { streamopened = sm_streamopened, streamclosed = sm_streamclosed }; + local sessions = {}; local xmppclient = { default_port = 5222 }; @@ -22,7 +25,7 @@ local xmppclient = { default_port = 5222 }; local function session_reset_stream(session) -- Reset stream - local parser = lxp.new(init_xmlhandlers(session, sm_streamopened), "|"); + local parser = lxp.new(init_xmlhandlers(session, stream_callbacks), "|"); session.parser = parser; session.notopen = true; @@ -33,6 +36,39 @@ local function session_reset_stream(session) return true; end + +local stream_xmlns_attr = {xmlns='urn:ietf:params:xml:ns:xmpp-streams'}; +local function session_close(session, reason) + local log = session.log or log; + if session.conn then + if reason then + if type(reason) == "string" then -- assume stream error + log("info", "Disconnecting client, <stream:error> is: %s", reason); + session.send(st.stanza("stream:error"):tag(reason, {xmlns = 'urn:ietf:params:xml:ns:xmpp-streams' })); + elseif type(reason) == "table" then + if reason.condition then + local stanza = st.stanza("stream:error"):tag(reason.condition, stream_xmlns_attr):up(); + if reason.text then + stanza:tag("text", stream_xmlns_attr):text(reason.text):up(); + end + if reason.extra then + stanza:add_child(reason.extra); + end + log("info", "Disconnecting client, <stream:error> is: %s", tostring(stanza)); + session.send(stanza); + elseif reason.name then -- a stanza + log("info", "Disconnecting client, <stream:error> is: %s", tostring(reason)); + session.send(reason); + end + end + end + session.send("</stream:stream>"); + session.conn.close(); + xmppclient.disconnect(session.conn, "stream error"); + end +end + + -- End of session methods -- function xmppclient.listener(conn, data) @@ -54,6 +90,7 @@ function xmppclient.listener(conn, data) print("Client connected"); session.reset_stream = session_reset_stream; + session.close = session_close; session_reset_stream(session); -- Initialise, ready for use @@ -64,9 +101,6 @@ function xmppclient.listener(conn, data) -- Debug version -- local function handleerr(err) print("Traceback:", err, debug.traceback()); end session.stanza_dispatch = function (stanza) return select(2, xpcall(function () return core_process_stanza(session, stanza); end, handleerr)); end - --- session.stanza_dispatch = function (stanza) return core_process_stanza(session, stanza); end - end if data then session.data(conn, data); @@ -76,12 +110,6 @@ end function xmppclient.disconnect(conn, err) local session = sessions[conn]; if session then - if session.presence and session.presence.attr.type ~= "unavailable" then - local pres = st.presence{ type = "unavailable" }; - if err == "closed" then err = "connection closed"; end - pres:tag("status"):text("Disconnected: "..err); - session.stanza_dispatch(pres); - end (session.log or log)("info", "Client disconnected: %s", err); sm_destroy_session(session); sessions[conn] = nil; diff --git a/net/xmppserver_listener.lua b/net/xmppserver_listener.lua index ee3faa8f..51116a5e 100644 --- a/net/xmppserver_listener.lua +++ b/net/xmppserver_listener.lua @@ -5,8 +5,11 @@ local init_xmlhandlers = require "core.xmlhandlers" local sm_new_session = require "core.sessionmanager".new_session; local s2s_new_incoming = require "core.s2smanager".new_incoming; local s2s_streamopened = require "core.s2smanager".streamopened; +local s2s_streamclosed = require "core.s2smanager".streamclosed; local s2s_destroy_session = require "core.s2smanager".destroy_session; +local stream_callbacks = { streamopened = s2s_streamopened, streamclosed = s2s_streamclosed }; + local connlisteners_register = require "net.connlisteners".register; local t_insert = table.insert; @@ -24,7 +27,7 @@ local xmppserver = { default_port = 5269 }; local function session_reset_stream(session) -- Reset stream - local parser = lxp.new(init_xmlhandlers(session, s2s_streamopened), "|"); + local parser = lxp.new(init_xmlhandlers(session, stream_callbacks), "|"); session.parser = parser; session.notopen = true; @@ -35,6 +38,39 @@ local function session_reset_stream(session) return true; end + +local stream_xmlns_attr = {xmlns='urn:ietf:params:xml:ns:xmpp-streams'}; +local function session_close(session, reason) + local log = session.log or log; + if session.conn then + if reason then + if type(reason) == "string" then -- assume stream error + log("info", "Disconnecting %s[%s], <stream:error> is: %s", session.host or "(unknown host)", session.type, reason); + session.sends2s(st.stanza("stream:error"):tag(reason, {xmlns = 'urn:ietf:params:xml:ns:xmpp-streams' })); + elseif type(reason) == "table" then + if reason.condition then + local stanza = st.stanza("stream:error"):tag(reason.condition, stream_xmlns_attr):up(); + if reason.text then + stanza:tag("text", stream_xmlns_attr):text(reason.text):up(); + end + if reason.extra then + stanza:add_child(reason.extra); + end + log("info", "Disconnecting %s[%s], <stream:error> is: %s", session.host or "(unknown host)", session.type, tostring(stanza)); + session.sends2s(stanza); + elseif reason.name then -- a stanza + log("info", "Disconnecting %s->%s[%s], <stream:error> is: %s", session.from_host or "(unknown host)", session.to_host or "(unknown host)", session.type, tostring(reason)); + session.sends2s(reason); + end + end + end + session.sends2s("</stream:stream>"); + session.conn.close(); + xmppserver.disconnect(session.conn, "stream error"); + end +end + + -- End of session methods -- function xmppserver.listener(conn, data) @@ -56,6 +92,7 @@ function xmppserver.listener(conn, data) print("Incoming s2s connection"); session.reset_stream = session_reset_stream; + session.close = session_close; session_reset_stream(session); -- Initialise, ready for use @@ -66,9 +103,6 @@ function xmppserver.listener(conn, data) -- Debug version -- local function handleerr(err) print("Traceback:", err, debug.traceback()); end session.stanza_dispatch = function (stanza) return select(2, xpcall(function () return core_process_stanza(session, stanza); end, handleerr)); end - --- session.stanza_dispatch = function (stanza) return core_process_stanza(session, stanza); end - end if data then session.data(conn, data); @@ -78,6 +112,7 @@ end function xmppserver.disconnect(conn) local session = sessions[conn]; if session then + (session.log or log)("info", "s2s disconnected: %s->%s", tostring(session.from_host), tostring(session.to_host)); s2s_destroy_session(session); sessions[conn] = nil; session = nil; diff --git a/plugins/mod_console.lua b/plugins/mod_console.lua new file mode 100644 index 00000000..5787ad25 --- /dev/null +++ b/plugins/mod_console.lua @@ -0,0 +1,140 @@ +
+local connlisteners_register = require "net.connlisteners".register;
+
+local console_listener = { default_port = 5582; default_mode = "*l"; };
+
+local commands = {};
+local default_env = {};
+local default_env_mt = { __index = default_env };
+
+console = {};
+
+function console:new_session(conn)
+ local w = conn.write;
+ return { conn = conn;
+ send = function (t) w(tostring(t)); end;
+ print = function (t) w("| "..tostring(t).."\n"); end;
+ disconnect = function () conn.close(); end;
+ env = setmetatable({}, default_env_mt);
+ };
+end
+
+local sessions = {};
+
+function console_listener.listener(conn, data)
+ local session = sessions[conn];
+
+ if not session then
+ -- Handle new connection
+ session = console:new_session(conn);
+ sessions[conn] = session;
+ session.print("Welcome to the lxmppd admin console!");
+ end
+ if data then
+ -- Handle data
+
+ if data:match("[!.]$") then
+ local command = data:lower();
+ command = data:match("^%w+") or data:match("%p");
+ if commands[command] then
+ commands[command](session, data);
+ return;
+ end
+ end
+
+ session.env._ = data;
+
+ local chunk, err = loadstring("return "..data);
+ if not chunk then
+ chunk, err = loadstring(data);
+ if not chunk then
+ err = err:gsub("^%[string .-%]:%d+: ", "");
+ err = err:gsub("^:%d+: ", "");
+ err = err:gsub("'<eof>'", "the end of the line");
+ session.print("Sorry, I couldn't understand that... "..err);
+ return;
+ end
+ end
+
+ setfenv(chunk, session.env);
+ local ranok, taskok, message = pcall(chunk);
+
+ if not ranok then
+ session.print("Fatal error while running command, it did not complete");
+ session.print("Error: "..taskok);
+ return;
+ end
+
+ if not message then
+ session.print("Result: "..tostring(taskok));
+ return;
+ elseif (not taskok) and message then
+ session.print("Command completed with a problem");
+ session.print("Message: "..tostring(message));
+ return;
+ end
+
+ session.print("OK: "..tostring(message));
+ end
+end
+
+function console_listener.disconnect(conn, err)
+
+end
+
+connlisteners_register('console', console_listener);
+
+-- Console commands --
+-- These are simple commands, not valid standalone in Lua
+
+function commands.bye(session)
+ session.print("See you! :)");
+ session.disconnect();
+end
+
+commands["!"] = function (session, data)
+ if data:match("^!!") then
+ session.print("!> "..session.env._);
+ return console_listener.listener(session.conn, session.env._);
+ end
+ local old, new = data:match("^!(.-[^\\])!(.-)!$");
+ if old and new then
+ local ok, res = pcall(string.gsub, session.env._, old, new);
+ if not ok then
+ session.print(res)
+ return;
+ end
+ session.print("!> "..res);
+ return console_listener.listener(session.conn, res);
+ end
+ session.print("Sorry, not sure what you want");
+end
+
+-- Session environment --
+-- Anything in default_env will be accessible within the session as a global variable
+
+default_env.server = {};
+function default_env.server.reload()
+ dofile "main.lua"
+ return true, "Server reloaded";
+end
+
+default_env.module = {};
+function default_env.module.load(name)
+ local mm = require "modulemanager";
+ local ok, err = mm.load(name);
+ if not ok then
+ return false, err or "Unknown error loading module";
+ end
+ return true, "Module loaded";
+end
+
+default_env.config = {};
+function default_env.config.load(filename, format)
+ local cfgm_load = require "core.configmanager".load;
+ local ok, err = cfgm_load(filename, format);
+ if not ok then
+ return false, err or "Unknown error loading config";
+ end
+ return true, "Config loaded";
+end
diff --git a/plugins/mod_dialback.lua b/plugins/mod_dialback.lua index c17cbcaf..87ac303b 100644 --- a/plugins/mod_dialback.lua +++ b/plugins/mod_dialback.lua @@ -55,8 +55,12 @@ add_handler({ "s2sout_unauthed", "s2sout" }, "verify", xmlns_dialback, log("warn", "dialback for "..(origin.dialback_verifying.from_host or "(unknown)").." failed"); valid = "invalid"; end - origin.dialback_verifying.sends2s(format("<db:result from='%s' to='%s' id='%s' type='%s'>%s</db:result>", - attr.from, attr.to, attr.id, valid, origin.dialback_verifying.dialback_key)); + if not origin.dialback_verifying.sends2s then + log("warn", "Incoming s2s session %s was closed in the meantime, so we can't notify it of the db result", tostring(origin.dialback_verifying):match("%w+$")); + else + origin.dialback_verifying.sends2s(format("<db:result from='%s' to='%s' id='%s' type='%s'>%s</db:result>", + attr.to, attr.from, attr.id, valid, origin.dialback_verifying.dialback_key)); + end end end); diff --git a/plugins/mod_disco.lua b/plugins/mod_disco.lua new file mode 100644 index 00000000..261650ce --- /dev/null +++ b/plugins/mod_disco.lua @@ -0,0 +1,9 @@ +
+local discomanager_handle = require "core.discomanager".handle;
+
+add_iq_handler({"c2s", "s2sin"}, "http://jabber.org/protocol/disco#info", function (session, stanza)
+ session.send(discomanager_handle(stanza));
+end);
+add_iq_handler({"c2s", "s2sin"}, "http://jabber.org/protocol/disco#items", function (session, stanza)
+ session.send(discomanager_handle(stanza));
+end);
diff --git a/plugins/mod_register.lua b/plugins/mod_register.lua index fb001392..c2b85bae 100644 --- a/plugins/mod_register.lua +++ b/plugins/mod_register.lua @@ -2,6 +2,7 @@ local st = require "util.stanza"; local usermanager_user_exists = require "core.usermanager".user_exists; local usermanager_create_user = require "core.usermanager".create_user; +local datamanager_store = require "util.datamanager".store; add_iq_handler("c2s", "jabber:iq:register", function (session, stanza) if stanza.tags[1].name == "query" then @@ -16,7 +17,33 @@ add_iq_handler("c2s", "jabber:iq:register", function (session, stanza) elseif stanza.attr.type == "set" then if query.tags[1] and query.tags[1].name == "remove" then -- TODO delete user auth data, send iq response, kick all user resources with a <not-authorized/>, delete all user data - session.send(st.error_reply(stanza, "cancel", "not-allowed")); + --session.send(st.error_reply(stanza, "cancel", "not-allowed")); + --return; + usermanager_create_user(session.username, nil, session.host); -- Disable account + -- FIXME the disabling currently allows a different user to recreate the account + -- we should add an in-memory account block mode when we have threading + session.send(st.reply(stanza)); + local roster = session.roster; + for _, session in pairs(hosts[session.host].sessions[session.username].sessions) do -- disconnect all resources + session:disconnect({condition = "not-authorized", text = "Account deleted"}); + end + -- TODO datamanager should be able to delete all user data itself + datamanager.store(session.username, session.host, "roster", nil); + datamanager.store(session.username, session.host, "vCard", nil); + datamanager.store(session.username, session.host, "private", nil); + datamanager.store(session.username, session.host, "offline", nil); + local bare = session.username.."@"..session.host; + for jid, item in pairs(roster) do + if jid ~= "pending" then + if item.subscription == "both" or item.subscription == "to" then + -- TODO unsubscribe + end + if item.subscription == "both" or item.subscription == "from" then + -- TODO unsubscribe + end + end + end + datamanager.store(session.username, session.host, "accounts", nil); -- delete accounts datastore at the end else local username = query:child_with_name("username"); local password = query:child_with_name("password"); diff --git a/plugins/mod_roster.lua b/plugins/mod_roster.lua index 23a19828..24d858e7 100644 --- a/plugins/mod_roster.lua +++ b/plugins/mod_roster.lua @@ -4,6 +4,7 @@ local st = require "util.stanza" local jid_split = require "util.jid".split; local t_concat = table.concat; +local handle_outbound_presence_subscriptions_and_probes = require "core.presencemanager".handle_outbound_presence_subscriptions_and_probes; local rm_remove_from_roster = require "core.rostermanager".remove_from_roster; local rm_add_to_roster = require "core.rostermanager".add_to_roster; local rm_roster_push = require "core.rostermanager".roster_push; @@ -38,15 +39,25 @@ add_iq_handler("c2s", "jabber:iq:roster", and query.tags[1].attr.jid ~= "pending" then local item = query.tags[1]; local from_node, from_host = jid_split(stanza.attr.from); + local from_bare = from_node and (from_node.."@"..from_host) or from_host; -- bare JID local node, host, resource = jid_split(item.attr.jid); - if not resource then + local to_bare = node and (node.."@"..host) or host; -- bare JID + if not resource and host then if item.attr.jid ~= from_node.."@"..from_host then if item.attr.subscription == "remove" then - if session.roster[item.attr.jid] then + local r_item = session.roster[item.attr.jid]; + if r_item then local success, err_type, err_cond, err_msg = rm_remove_from_roster(session, item.attr.jid); if success then session.send(st.reply(stanza)); rm_roster_push(from_node, from_host, item.attr.jid); + if r_item.subscription == "both" or r_item.subscription == "from" then + handle_outbound_presence_subscriptions_and_probes(session, + st.presence({type="unsubscribed"}), from_bare, to_bare); + elseif r_item.subscription == "both" or r_item.subscription == "to" then + handle_outbound_presence_subscriptions_and_probes(session, + st.presence({type="unsubscribe"}), from_bare, to_bare); + end else session.send(st.error_reply(stanza, err_type, err_cond, err_msg)); end diff --git a/plugins/mod_saslauth.lua b/plugins/mod_saslauth.lua index 6ceb0be3..7ca4308b 100644 --- a/plugins/mod_saslauth.lua +++ b/plugins/mod_saslauth.lua @@ -83,19 +83,21 @@ add_handler("c2s_unauthed", "auth", xmlns_sasl, sasl_handler); add_handler("c2s_unauthed", "abort", xmlns_sasl, sasl_handler); add_handler("c2s_unauthed", "response", xmlns_sasl, sasl_handler); +local mechanisms_attr = { xmlns='urn:ietf:params:xml:ns:xmpp-sasl' }; +local bind_attr = { xmlns='urn:ietf:params:xml:ns:xmpp-bind' }; +local xmpp_session_attr = { xmlns='urn:ietf:params:xml:ns:xmpp-session' }; add_event_hook("stream-features", function (session, features) if not session.username then - t_insert(features, "<mechanisms xmlns='urn:ietf:params:xml:ns:xmpp-sasl'>"); + features:tag("mechanisms", mechanisms_attr); -- TODO: Provide PLAIN only if TLS is active, this is a SHOULD from the introduction of RFC 4616. This behavior could be overridden via configuration but will issuing a warning or so. - t_insert(features, "<mechanism>PLAIN</mechanism>"); - t_insert(features, "<mechanism>DIGEST-MD5</mechanism>"); - t_insert(features, "</mechanisms>"); + features:tag("mechanism"):text("PLAIN"):up(); + features:tag("mechanism"):text("DIGEST-MD5"):up(); + features:up(); else - t_insert(features, "<bind xmlns='urn:ietf:params:xml:ns:xmpp-bind'><required/></bind>"); - t_insert(features, "<session xmlns='urn:ietf:params:xml:ns:xmpp-session'/>"); + features:tag("bind", bind_attr):tag("required"):up():up(); + features:tag("session", xmpp_session_attr):up(); end - --send [[<register xmlns="http://jabber.org/features/iq-register"/> ]] end); add_iq_handler("c2s", "urn:ietf:params:xml:ns:xmpp-bind", diff --git a/plugins/mod_selftests.lua b/plugins/mod_selftests.lua new file mode 100644 index 00000000..4f128504 --- /dev/null +++ b/plugins/mod_selftests.lua @@ -0,0 +1,53 @@ + +local st = require "util.stanza"; +local register_component = require "core.componentmanager".register_component; +local core_route_stanza = core_route_stanza; +local socket = require "socket"; +local config = require "core.configmanager"; +local ping_hosts = config.get("*", "mod_selftests", "ping_hosts") or { "jabber.org" }; + +local open_pings = {}; + +local t_insert = table.insert; + +local log = require "util.logger".init("mod_selftests"); + +local tests_jid = "self_tests@getjabber.ath.cx"; +local host = "getjabber.ath.cx"; + +if not (tests_jid and host) then + for currhost in pairs(host) do + if currhost ~= "localhost" then + tests_jid, host = "self_tests@"..currhost, currhost; + end + end +end + +if tests_jid and host then + local bot = register_component(tests_jid, function(origin, stanza, ourhost) + local time = open_pings[stanza.attr.id]; + + if time then + log("info", "Ping reply from %s in %fs", tostring(stanza.attr.from), socket.gettime() - time); + else + log("info", "Unexpected reply: %s", stanza:pretty_print()); + end + end); + + + local our_origin = hosts[host]; + add_event_hook("server-started", + function () + local id = st.new_id(); + local ping_attr = { xmlns = 'urn:xmpp:ping' }; + local function send_ping(to) + log("info", "Sending ping to %s", to); + core_route_stanza(our_origin, st.iq{ to = to, from = tests_jid, id = id, type = "get" }:tag("ping", ping_attr)); + open_pings[id] = socket.gettime(); + end + + for _, host in ipairs(ping_hosts) do + send_ping(host); + end + end); +end diff --git a/plugins/mod_tls.lua b/plugins/mod_tls.lua index b5ca5015..cc46d556 100644 --- a/plugins/mod_tls.lua +++ b/plugins/mod_tls.lua @@ -24,9 +24,10 @@ add_handler("c2s_unauthed", "starttls", xmlns_starttls, end end); +local starttls_attr = { xmlns = xmlns_starttls }; add_event_hook("stream-features", function (session, features) if session.conn.starttls then - t_insert(features, "<starttls xmlns='"..xmlns_starttls.."'/>"); + features:tag("starttls", starttls_attr):up(); end end); diff --git a/plugins/mod_vcard.lua b/plugins/mod_vcard.lua index fb7382c2..d2f2c7ba 100644 --- a/plugins/mod_vcard.lua +++ b/plugins/mod_vcard.lua @@ -43,9 +43,10 @@ add_iq_handler({"c2s", "s2sin"}, "vcard-temp", end end); +local feature_vcard_attr = { var='vcard-temp' }; add_event_hook("stream-features", function (session, features) if session.type == "c2s" then - t_insert(features, "<feature var='vcard-temp'/>"); + features:tag("feature", feature_vcard_attr):up(); end end); diff --git a/tests/reports/empty b/tests/reports/empty new file mode 100644 index 00000000..0e3c9a08 --- /dev/null +++ b/tests/reports/empty @@ -0,0 +1 @@ +This file was intentionally left blank. diff --git a/tests/run_tests.sh b/tests/run_tests.sh new file mode 100755 index 00000000..d93cd39b --- /dev/null +++ b/tests/run_tests.sh @@ -0,0 +1,3 @@ +#!/bin/sh +rm reports/*.report +lua test.lua $* diff --git a/tests/test.lua b/tests/test.lua index c028e859..33af3e98 100644 --- a/tests/test.lua +++ b/tests/test.lua @@ -1,4 +1,11 @@ +function run_all_tests() + dotest "util.jid" + dotest "core.stanza_router" + dotest "core.s2smanager" + dotest "core.configmanager" +end + local verbosity = tonumber(arg[1]) or 2; package.path = package.path..";../?.lua"; @@ -36,7 +43,8 @@ function dotest(unitname) local unit = setmetatable({}, { __index = setmetatable({ module = function () end }, { __index = _G }) }); - local chunk, err = loadfile("../"..unitname:gsub("%.", "/")..".lua"); + local fn = "../"..unitname:gsub("%.", "/")..".lua"; + local chunk, err = loadfile(fn); if not chunk then print("WARNING: ", "Failed to load module: "..unitname, err); return; @@ -50,21 +58,29 @@ function dotest(unitname) end for name, f in pairs(unit) do + local test = rawget(tests, name); if type(f) ~= "function" then if verbosity >= 3 then print("INFO: ", "Skipping "..unitname.."."..name.." because it is not a function"); end - elseif type(tests[name]) ~= "function" then + elseif type(test) ~= "function" then if verbosity >= 1 then print("WARNING: ", unitname.."."..name.." has no test!"); end else - local success, ret = pcall(tests[name], f, unit); + local line_hook, line_info = new_line_coverage_monitor(fn); + debug.sethook(line_hook, "l") + local success, ret = pcall(test, f, unit); + debug.sethook(); if not success then print("TEST FAILED! Unit: ["..unitname.."] Function: ["..name.."]"); print(" Location: "..ret:gsub(":%s*\n", "\n")); + line_info(name, false, report_file); elseif verbosity >= 2 then print("TEST SUCCEEDED: ", unitname, name); + print(string.format("TEST COVERED %d/%d lines", line_info(name, true, report_file))); + else + line_info(name, success, report_file); end end end @@ -81,5 +97,45 @@ function runtest(f, msg) end end -dotest "util.jid" -dotest "core.stanza_router" +function new_line_coverage_monitor(file) + local lines_hit, funcs_hit = {}, {}; + local total_lines, covered_lines = 0, 0; + + for line in io.lines(file) do + total_lines = total_lines + 1; + end + + return function (event, line) -- Line hook + if not lines_hit[line] then + local info = debug.getinfo(2, "fSL") + if not info.source:find(file) then return; end + if not funcs_hit[info.func] and info.activelines then + funcs_hit[info.func] = true; + for line in pairs(info.activelines) do + lines_hit[line] = false; -- Marks it as hittable, but not hit yet + end + end + if lines_hit[line] == false then + --print("New line hit: "..line.." in "..debug.getinfo(2, "S").source); + lines_hit[line] = true; + covered_lines = covered_lines + 1; + end + end + end, + function (test_name, success) -- Get info + local fn = file:gsub("^%W*", ""); + local total_active_lines = 0; + local coverage_file = io.open("reports/coverage_"..fn:gsub("%W+", "_")..".report", "a+"); + for line, active in pairs(lines_hit) do + if active ~= nil then total_active_lines = total_active_lines + 1; end + if coverage_file then + if active == false then coverage_file:write(fn, "|", line, "|", name or "", "|miss\n"); + else coverage_file:write(fn, "|", line, "|", name or "", "|", tostring(success), "\n"); end + end + end + if coverage_file then coverage_file:close(); end + return covered_lines, total_active_lines, lines_hit; + end +end + +run_all_tests() diff --git a/tests/test_core_configmanager.lua b/tests/test_core_configmanager.lua new file mode 100644 index 00000000..099ff3c3 --- /dev/null +++ b/tests/test_core_configmanager.lua @@ -0,0 +1,28 @@ + +function get(get, config) + config.set("example.com", "test", "testkey", 123); + assert_equal(get("example.com", "test", "testkey"), 123, "Retrieving a set key"); + + config.set("*", "test", "testkey1", 321); + assert_equal(get("*", "test", "testkey1"), 321, "Retrieving a set global key"); + assert_equal(get("example.com", "test", "testkey1"), 321, "Retrieving a set key of undefined host, of which only a globally set one exists"); + + config.set("example.com", "test", ""); -- Creates example.com host in config + assert_equal(get("example.com", "test", "testkey1"), 321, "Retrieving a set key, of which only a globally set one exists"); + + assert_equal(get(), nil, "No parameters to get()"); + assert_equal(get("undefined host"), nil, "Getting for undefined host"); + assert_equal(get("undefined host", "undefined section"), nil, "Getting for undefined host & section"); + assert_equal(get("undefined host", "undefined section", "undefined key"), nil, "Getting for undefined host & section & key"); + + assert_equal(get("example.com", "undefined section", "testkey"), nil, "Defined host, undefined section"); +end + +function set(set, u) + assert_equal(set("*"), false, "Set with no section/key"); + assert_equal(set("*", "set_test"), false, "Set with no key"); + + assert_equal(set("*", "set_test", "testkey"), true, "Setting a nil global value"); + assert_equal(set("*", "set_test", "testkey", 123), true, "Setting a global value"); +end + diff --git a/tests/test_core_s2smanager.lua b/tests/test_core_s2smanager.lua new file mode 100644 index 00000000..69715b26 --- /dev/null +++ b/tests/test_core_s2smanager.lua @@ -0,0 +1,38 @@ +function compare_srv_priorities(csp) + local r1 = { priority = 10, weight = 0 } + local r2 = { priority = 100, weight = 0 } + local r3 = { priority = 1000, weight = 2 } + local r4 = { priority = 1000, weight = 2 } + local r5 = { priority = 1000, weight = 5 } + + assert_equal(csp(r1, r1), false); + assert_equal(csp(r1, r2), true); + assert_equal(csp(r1, r3), true); + assert_equal(csp(r1, r4), true); + assert_equal(csp(r1, r5), true); + + assert_equal(csp(r2, r1), false); + assert_equal(csp(r2, r2), false); + assert_equal(csp(r2, r3), true); + assert_equal(csp(r2, r4), true); + assert_equal(csp(r2, r5), true); + + assert_equal(csp(r3, r1), false); + assert_equal(csp(r3, r2), false); + assert_equal(csp(r3, r3), false); + assert_equal(csp(r3, r4), false); + assert_equal(csp(r3, r5), true); + + assert_equal(csp(r4, r1), false); + assert_equal(csp(r4, r2), false); + assert_equal(csp(r4, r3), false); + assert_equal(csp(r4, r4), false); + assert_equal(csp(r4, r5), true); + + assert_equal(csp(r5, r1), false); + assert_equal(csp(r5, r2), false); + assert_equal(csp(r5, r3), false); + assert_equal(csp(r5, r4), false); + assert_equal(csp(r5, r5), false); + +end diff --git a/tests/test_util_jid.lua b/tests/test_util_jid.lua index 1dbd72b7..7a616008 100644 --- a/tests/test_util_jid.lua +++ b/tests/test_util_jid.lua @@ -11,4 +11,24 @@ function split(split) test("server", nil, "server", nil ); test("server/resource", nil, "server", "resource" ); test(nil, nil, nil , nil ); + + test("node@/server", nil, nil, nil , nil ); +end + +function bare(bare) + assert_equal(bare("user@host"), "user@host", "bare JID remains bare"); + assert_equal(bare("host"), "host", "Host JID remains host"); + assert_equal(bare("host/resource"), "host", "Host JID with resource becomes host"); + assert_equal(bare("user@host/resource"), "user@host", "user@host JID with resource becomes user@host"); + assert_equal(bare("user@/resource"), nil, "invalid JID is nil"); + assert_equal(bare("@/resource"), nil, "invalid JID is nil"); + assert_equal(bare("@/"), nil, "invalid JID is nil"); + assert_equal(bare("/"), nil, "invalid JID is nil"); + assert_equal(bare(""), nil, "invalid JID is nil"); + assert_equal(bare("@"), nil, "invalid JID is nil"); + assert_equal(bare("user@"), nil, "invalid JID is nil"); + assert_equal(bare("user@@"), nil, "invalid JID is nil"); + assert_equal(bare("user@@host"), nil, "invalid JID is nil"); + assert_equal(bare("user@@host/resource"), nil, "invalid JID is nil"); + assert_equal(bare("user@host/"), nil, "invalid JID is nil"); end diff --git a/util/discohelper.lua b/util/discohelper.lua new file mode 100644 index 00000000..4ac8f227 --- /dev/null +++ b/util/discohelper.lua @@ -0,0 +1,79 @@ +
+local t_insert = table.insert;
+local jid_split = require "util.jid".split;
+local ipairs = ipairs;
+local st = require "util.stanza";
+
+module "discohelper";
+
+local function addDiscoItemsHandler(self, jid, func)
+ if self.item_handlers[jid] then
+ t_insert(self.item_handlers[jid], func);
+ else
+ self.item_handlers[jid] = {func};
+ end
+end
+
+local function addDiscoInfoHandler(self, jid, func)
+ if self.info_handlers[jid] then
+ t_insert(self.info_handlers[jid], func);
+ else
+ self.info_handlers[jid] = {func};
+ end
+end
+
+local function handle(self, stanza)
+ if stanza.name == "iq" and stanza.tags[1].name == "query" then
+ local query = stanza.tags[1];
+ local to = stanza.attr.to;
+ local from = stanza.attr.from
+ local node = query.attr.node or "";
+ local to_node, to_host = jid_split(to);
+
+ local reply = st.reply(stanza):query(query.attr.xmlns);
+ local handlers;
+ if query.attr.xmlns == "http://jabber.org/protocol/disco#info" then -- select handler set
+ handlers = self.info_handlers;
+ elseif query.attr.xmlns == "http://jabber.org/protocol/disco#items" then
+ handlers = self.item_handlers;
+ end
+ local handler = handlers[to]; -- get the handler
+ if not handler then -- if not found then use default handler
+ if to_node then
+ handler = handlers["*defaultnode"];
+ else
+ handler = handlers["*defaulthost"];
+ end
+ end
+ local found; -- to keep track of any handlers found
+ if handler then
+ for _, h in ipairs(handler) do
+ if h(reply, to, from, node) then found = true; end
+ end
+ end
+ if to_node then -- handlers which get called always
+ handler = handlers["*node"];
+ else
+ handler = handlers["*host"];
+ end
+ if handler then -- call always called handler
+ for _, h in ipairs(handler) do
+ if h(reply, to, from, node) then found = true; end
+ end
+ end
+ if found then return reply; end -- return the reply if there was one
+ return st.error_reply(stanza, "cancel", "service-unavailable");
+ end
+end
+
+function new()
+ return {
+ item_handlers = {};
+ info_handlers = {};
+ addDiscoItemsHandler = addDiscoItemsHandler;
+ addDiscoInfoHandler = addDiscoInfoHandler;
+ handle = handle;
+ };
+end
+
+return _M;
diff --git a/util/jid.lua b/util/jid.lua index b1e4131d..065f176f 100644 --- a/util/jid.lua +++ b/util/jid.lua @@ -5,11 +5,20 @@ module "jid" function split(jid) if not jid then return; end - -- TODO verify JID, and return; if invalid - local node = match(jid, "^([^@]+)@"); - local server = (node and match(jid, ".-@([^@/]+)")) or match(jid, "^([^@/]+)"); - local resource = match(jid, "/(.+)$"); - return node, server, resource; + local node, nodepos = match(jid, "^([^@]+)@()"); + local host, hostpos = match(jid, "^([^@/]+)()", nodepos) + if node and not host then return nil, nil, nil; end + local resource = match(jid, "^/(.+)$", hostpos); + if (not host) or ((not resource) and #jid >= hostpos) then return nil, nil, nil; end + return node, host, resource; end -return _M;
\ No newline at end of file +function bare(jid) + local node, host = split(jid); + if node and host then + return node.."@"..host; + end + return host; +end + +return _M; diff --git a/util/logger.lua b/util/logger.lua index f93cafc1..e9440a04 100644 --- a/util/logger.lua +++ b/util/logger.lua @@ -33,4 +33,4 @@ function init(name) end end -return _M;
\ No newline at end of file +return _M; diff --git a/util/stanza.lua b/util/stanza.lua index 5a6ba8c5..df0d43ff 100644 --- a/util/stanza.lua +++ b/util/stanza.lua @@ -30,6 +30,11 @@ end function stanza_mt:query(xmlns) return self:tag("query", { xmlns = xmlns }); end + +function stanza_mt:body(text, attr) + return self:tag("body", attr):text(text); +end + function stanza_mt:tag(name, attrs) local s = stanza(name, attrs); (self.last_add[#self.last_add] or self):add_direct_child(s); @@ -103,7 +108,7 @@ function stanza_mt.__tostring(t) local attr_string = ""; if t.attr then - for k, v in pairs(t.attr) do if type(k) == "string" then attr_string = attr_string .. s_format(" %s='%s'", k, tostring(v)); end end + for k, v in pairs(t.attr) do if type(k) == "string" then attr_string = attr_string .. s_format(" %s='%s'", k, xml_escape(tostring(v))); end end end return s_format("<%s%s>%s</%s>", t.name, attr_string, children_text, t.name); end @@ -111,7 +116,7 @@ end function stanza_mt.top_tag(t) local attr_string = ""; if t.attr then - for k, v in pairs(t.attr) do if type(k) == "string" then attr_string = attr_string .. s_format(" %s='%s'", k, tostring(v)); end end + for k, v in pairs(t.attr) do if type(k) == "string" then attr_string = attr_string .. s_format(" %s='%s'", k, xml_escape(tostring(v))); end end end return s_format("<%s%s>", t.name, attr_string); end diff --git a/util/ztact.lua b/util/ztact.lua new file mode 100644 index 00000000..15bcffad --- /dev/null +++ b/util/ztact.lua @@ -0,0 +1,364 @@ + + +-- public domain 20080410 lua@ztact.com + + +pcall (require, 'lfs') -- lfs may not be installed/necessary. +pcall (require, 'pozix') -- pozix may not be installed/necessary. + + +local getfenv, ipairs, next, pairs, pcall, require, select, tostring, type = + getfenv, ipairs, next, pairs, pcall, require, select, tostring, type +local unpack, xpcall = + unpack, xpcall + +local io, lfs, os, string, table, pozix = io, lfs, os, string, table, pozix + +local assert, print = assert, print + +local error = error + + +module ((...) or 'ztact') ------------------------------------- module ztact + + +-- dir -------------------------------------------------------------------- dir + + +function dir (path) -- - - - - - - - - - - - - - - - - - - - - - - - - - dir + local it = lfs.dir (path) + return function () + repeat + local dir = it () + if dir ~= '.' and dir ~= '..' then return dir end + until not dir + end end + + +function is_file (path) -- - - - - - - - - - - - - - - - - - is_file (path) + local mode = lfs.attributes (path, 'mode') + return mode == 'file' and path + end + + +-- network byte ordering -------------------------------- network byte ordering + + +function htons (word) -- - - - - - - - - - - - - - - - - - - - - - - - htons + return (word-word%0x100)/0x100, word%0x100 + end + + +-- pcall2 -------------------------------------------------------------- pcall2 + + +getfenv ().pcall = pcall -- store the original pcall as ztact.pcall + + +local argc, argv, errorhandler, pcall2_f + + +local function _pcall2 () -- - - - - - - - - - - - - - - - - - - - - _pcall2 + local tmpv = argv + argv = nil + return pcall2_f (unpack (tmpv, 1, argc)) + end + + +function seterrorhandler (func) -- - - - - - - - - - - - - - seterrorhandler + errorhandler = func + end + + +function pcall2 (f, ...) -- - - - - - - - - - - - - - - - - - - - - - pcall2 + + pcall2_f = f + argc = select ('#', ...) + argv = { ... } + + if not errorhandler then + local debug = require ('debug') + errorhandler = debug.traceback + end + + return xpcall (_pcall2, errorhandler) + end + + +function append (t, ...) -- - - - - - - - - - - - - - - - - - - - - - append + local insert = table.insert + for i,v in ipairs {...} do + insert (t, v) + end end + + +function print_r (d, indent) -- - - - - - - - - - - - - - - - - - - print_r + local rep = string.rep (' ', indent or 0) + if type (d) == 'table' then + for k,v in pairs (d) do + if type (v) == 'table' then + io.write (rep, k, '\n') + print_r (v, (indent or 0) + 1) + else io.write (rep, k, ' = ', tostring (v), '\n') end + end + else io.write (d, '\n') end + end + + +function tohex (s) -- - - - - - - - - - - - - - - - - - - - - - - - - tohex + return string.format (string.rep ('%02x ', #s), string.byte (s, 1, #s)) + end + + +function tostring_r (d, indent, tab0) -- - - - - - - - - - - - - tostring_r + + tab1 = tab0 or {} + local rep = string.rep (' ', indent or 0) + if type (d) == 'table' then + for k,v in pairs (d) do + if type (v) == 'table' then + append (tab1, rep, k, '\n') + tostring_r (v, (indent or 0) + 1, tab1) + else append (tab1, rep, k, ' = ', tostring (v), '\n') end + end + else append (tab1, d, '\n') end + + if not tab0 then return table.concat (tab1) end + end + + +-- queue manipulation -------------------------------------- queue manipulation + + +-- Possible queue states. 1 (i.e. queue.p[1]) is head of queue. +-- +-- 1..2 +-- 3..4 1..2 +-- 3..4 1..2 5..6 +-- 1..2 5..6 +-- 1..2 + + +local function print_queue (queue, ...) -- - - - - - - - - - - - print_queue + for i=1,10 do io.write ((queue[i] or '.')..' ') end + io.write ('\t') + for i=1,6 do io.write ((queue.p[i] or '.')..' ') end + print (...) + end + + +function dequeue (queue) -- - - - - - - - - - - - - - - - - - - - - dequeue + + local p = queue.p + if not p and queue[1] then queue.p = { 1, #queue } p = queue.p end + + if not p[1] then return nil end + + local element = queue[p[1]] + queue[p[1]] = nil + + if p[1] < p[2] then p[1] = p[1] + 1 + + elseif p[4] then p[1], p[2], p[3], p[4] = p[3], p[4], nil, nil + + elseif p[5] then p[1], p[2], p[5], p[6] = p[5], p[6], nil, nil + + else p[1], p[2] = nil, nil end + + print_queue (queue, ' de '..element) + return element + end + + +function enqueue (queue, element) -- - - - - - - - - - - - - - - - - enqueue + + local p = queue.p + if not p then queue.p = {} p = queue.p end + + if p[5] then -- p3..p4 p1..p2 p5..p6 + p[6] = p[6]+1 + queue[p[6]] = element + + elseif p[3] then -- p3..p4 p1..p2 + + if p[4]+1 < p[1] then + p[4] = p[4] + 1 + queue[p[4]] = element + + else + p[5] = p[2]+1 + p[6], queue[p[5]] = p[5], element + end + + elseif p[1] then -- p1..p2 + if p[1] == 1 then + p[2] = p[2] + 1 + queue[p[2]] = element + + else + p[3], p[4], queue[1] = 1, 1, element + end + + else -- empty queue + p[1], p[2], queue[1] = 1, 1, element + end + + print_queue (queue, ' '..element) + end + + +local function test_queue () + t = {} + enqueue (t, 1) + enqueue (t, 2) + enqueue (t, 3) + enqueue (t, 4) + enqueue (t, 5) + dequeue (t) + dequeue (t) + enqueue (t, 6) + enqueue (t, 7) + enqueue (t, 8) + enqueue (t, 9) + dequeue (t) + dequeue (t) + dequeue (t) + dequeue (t) + enqueue (t, 'a') + dequeue (t) + enqueue (t, 'b') + enqueue (t, 'c') + dequeue (t) + dequeue (t) + dequeue (t) + dequeue (t) + dequeue (t) + enqueue (t, 'd') + dequeue (t) + dequeue (t) + dequeue (t) + end + + +-- test_queue () + + +function queue_len (queue) + end + + +function queue_peek (queue) + end + + +-- tree manipulation ---------------------------------------- tree manipulation + + +function set (parent, ...) --- - - - - - - - - - - - - - - - - - - - - - set + + -- print ('set', ...) + + local len = select ('#', ...) + local key, value = select (len-1, ...) + local cutpoint, cutkey + + for i=1,len-2 do + + local key = select (i, ...) + local child = parent[key] + + if value == nil then + if child == nil then return + elseif next (child, next (child)) then cutpoint = nil cutkey = nil + elseif cutpoint == nil then cutpoint = parent cutkey = key end + + elseif child == nil then child = {} parent[key] = child end + + parent = child + end + + if value == nil and cutpoint then cutpoint[cutkey] = nil + else parent[key] = value return value end + end + + +function get (parent, ...) --- - - - - - - - - - - - - - - - - - - - - - get + local len = select ('#', ...) + for i=1,len do + parent = parent[select (i, ...)] + if parent == nil then break end + end + return parent + end + + +-- misc ------------------------------------------------------------------ misc + + +function find (path, ...) --------------------------------------------- find + + local dirs, operators = { path }, {...} + for operator in ivalues (operators) do + if not operator (path) then break end end + + while next (dirs) do + local parent = table.remove (dirs) + for child in assert (pozix.opendir (parent)) do + if child and child ~= '.' and child ~= '..' then + local path = parent..'/'..child + if pozix.stat (path, 'is_dir') then table.insert (dirs, path) end + for operator in ivalues (operators) do + if not operator (path) then break end end + end end end end + + +function ivalues (t) ----------------------------------------------- ivalues + local i = 0 + return function () if t[i+1] then i = i + 1 return t[i] end end + end + + +function lson_encode (mixed, f, indent, indents) --------------- lson_encode + + + local capture + if not f then + capture = {} + f = function (s) append (capture, s) end + end + + indent = indent or 0 + indents = indents or {} + indents[indent] = indents[indent] or string.rep (' ', 2*indent) + + local type = type (mixed) + + if type == 'number' then f (mixed) + + else if type == 'string' then f (string.format ('%q', mixed)) + + else if type == 'table' then + f ('{') + for k,v in pairs (mixed) do + f ('\n') + f (indents[indent]) + f ('[') f (lson_encode (k)) f ('] = ') + lson_encode (v, f, indent+1, indents) + f (',') + end + f (' }') + end end end + + if capture then return table.concat (capture) end + end + + +function timestamp (time) ---------------------------------------- timestamp + return os.date ('%Y%m%d.%H%M%S', time) + end + + +function values (t) ------------------------------------------------- values + local k, v + return function () k, v = next (t, k) return v end + end |