From 78739d9638b857bf5c8ce249036818a92eceae46 Mon Sep 17 00:00:00 2001 From: matthew Date: Sun, 24 Aug 2008 01:51:02 +0000 Subject: Switched to new connection framework, courtesy of the luadch project Now supports SSL on 5223 Beginning support for presence (aka. the proper routing of stanzas) --- core/stanza_dispatch.lua | 71 ++++-- core/xmlhandlers.lua | 92 +++++++ main.lua | 300 +++++++---------------- server.lua | 611 +++++++++++++++++++++++++++++++++++++++++++++++ util/stanza.lua | 34 ++- 5 files changed, 870 insertions(+), 238 deletions(-) create mode 100644 core/xmlhandlers.lua create mode 100644 server.lua diff --git a/core/stanza_dispatch.lua b/core/stanza_dispatch.lua index b7428ecd..e392b5ba 100644 --- a/core/stanza_dispatch.lua +++ b/core/stanza_dispatch.lua @@ -12,14 +12,17 @@ function init_stanza_dispatcher(session) local session_log = session.log; local log = function (type, msg) session_log(type, "stanza_dispatcher", msg); end local send = session.send; - - + local send_to; + do + local _send_to = session.send_to; + send_to = function (...) _send_to(session, ...); end + end iq_handlers["jabber:iq:auth"] = function (stanza) - local username = stanza[1]:child_with_name("username"); - local password = stanza[1]:child_with_name("password"); - local resource = stanza[1]:child_with_name("resource"); + local username = stanza.tags[1]:child_with_name("username"); + local password = stanza.tags[1]:child_with_name("password"); + local resource = stanza.tags[1]:child_with_name("resource"); if not (username and password and resource) then local reply = st.reply(stanza); send(reply:query("jabber:iq:auth") @@ -78,24 +81,52 @@ function init_stanza_dispatcher(session) return function (stanza) log("info", "--> "..tostring(stanza)); - if stanza.name == "iq" then - if not stanza[1] then log("warn", " without child is invalid"); return; end - if not stanza.attr.id then log("warn", " without id attribute is invalid"); end - local xmlns = stanza[1].attr.xmlns; - if not xmlns then log("warn", "Child of has no xmlns - invalid"); return; end - if (((not stanza.attr.to) or stanza.attr.to == session.host or stanza.attr.to:match("@[^/]+$")) and (stanza.attr.type == "get" or stanza.attr.type == "set")) then -- Stanza sent to us - if iq_handlers[xmlns] then - if iq_handlers[xmlns](stanza) then return; end; + if (not stanza.attr.to) or (hosts[stanza.attr.to] and hosts[stanza.attr.to].type == "local") then + if stanza.name == "iq" then + if not stanza.tags[1] then log("warn", " without child is invalid"); return; end + if not stanza.attr.id then log("warn", " without id attribute is invalid"); end + local xmlns = (stanza.tags[1].attr and stanza.tags[1].attr.xmlns) or nil; + if not xmlns then log("warn", "Child of has no xmlns - invalid"); return; end + if (((not stanza.attr.to) or stanza.attr.to == session.host or stanza.attr.to:match("@[^/]+$")) and (stanza.attr.type == "get" or stanza.attr.type == "set")) then -- Stanza sent to us + if iq_handlers[xmlns] then + if iq_handlers[xmlns](stanza) then return; end; + end + log("warn", "Unhandled namespace: "..xmlns); + send(format("", stanza.attr.id)); + return; + end + elseif stanza.name == "presence" then + if session.roster then + -- Broadcast presence and probes + local broadcast = st.presence({ from = session.username.."@"..session.host.."/"..session.resource }); + local probe = st.presence { from = broadcast.attr.from, type = "probe" }; + + for child in stanza:children() do + broadcast:tag(child.name, child.attr); + end + for contact in pairs(session.roster) do + broadcast.attr.to = contact; + send_to(contact, broadcast); + --local host = jid.host(contact); + --if hosts[host] and hosts[host].type == "local" then + --local node, host = jid.split(contact); + --if host[host].sessions[node] + --local pres = st.presence { from = con + --else + -- probe.attr.to = contact; + -- send_to(contact, probe); + --end + end + + -- Probe for our contacts' presence end - log("warn", "Unhandled namespace: "..xmlns); - send(format("", stanza.attr.id)); end - - end - -- Need to route stanza - if stanza.attr.to and ((not hosts[stanza.attr.to]) or hosts[stanza.attr.to].type ~= "local") then + else + --end + --if stanza.attr.to and ((not hosts[stanza.attr.to]) or hosts[stanza.attr.to].type ~= "local") then + -- Need to route stanza stanza.attr.from = session.username.."@"..session.host; - session.send_to(stanza.attr.to, stanza); + session:send_to(stanza.attr.to, stanza); end end diff --git a/core/xmlhandlers.lua b/core/xmlhandlers.lua new file mode 100644 index 00000000..4d536ce3 --- /dev/null +++ b/core/xmlhandlers.lua @@ -0,0 +1,92 @@ + +require "util.stanza" + +local st = stanza; +local tostring = tostring; +local format = string.format; +local m_random = math.random; +local t_insert = table.insert; +local t_remove = table.remove; +local t_concat = table.concat; +local t_concatall = function (t, sep) local tt = {}; for _, s in ipairs(t) do t_insert(tt, tostring(s)); end return t_concat(tt, sep); end + +local error = error; + +module "xmlhandlers" + +function init_xmlhandlers(session) + local ns_stack = { "" }; + local curr_ns = ""; + local curr_tag; + local chardata = {}; + local xml_handlers = {}; + local log = session.log; + local print = function (...) log("info", "xmlhandlers", t_concatall({...}, "\t")); end + + local send = session.send; + + local stanza + function xml_handlers:StartElement(name, attr) + if stanza and #chardata > 0 then + stanza:text(t_concat(chardata)); + print("Char data:", t_concat(chardata)); + chardata = {}; + end + curr_ns,name = name:match("^(.+):(%w+)$"); + print("Tag received:", name, tostring(curr_ns)); + if not stanza then + if session.notopen then + if name == "stream" then + session.host = attr.to or error("Client failed to specify destination hostname"); + session.version = attr.version or 0; + session.streamid = m_random(1000000, 99999999); + print(session, session.host, "Client opened stream"); + send(""); + send(format("", session.streamid, session.host)); + --send(""); + --send("PLAIN"); + --send [[ ]] + --send(""); + log("info", "core", "Stream opened successfully"); + session.notopen = nil; + return; + end + error("Client failed to open stream successfully"); + end + if name ~= "iq" and name ~= "presence" and name ~= "message" then + error("Client sent invalid top-level stanza"); + end + stanza = st.stanza(name, { to = attr.to, type = attr.type, id = attr.id, xmlns = curr_ns }); + curr_tag = stanza; + else + attr.xmlns = curr_ns; + stanza:tag(name, attr); + end + end + function xml_handlers:CharacterData(data) + if stanza then + t_insert(chardata, data); + end + end + function xml_handlers:EndElement(name) + curr_ns,name = name:match("^(.+):(%w+)$"); + --print("<"..name.."/>", tostring(stanza), tostring(#stanza.last_add < 1), tostring(stanza.last_add[#stanza.last_add].name)); + if (not stanza) or #stanza.last_add < 0 or (#stanza.last_add > 0 and name ~= stanza.last_add[#stanza.last_add].name) then error("XML parse error in client stream"); end + if stanza and #chardata > 0 then + stanza:text(t_concat(chardata)); + print("Char data:", t_concat(chardata)); + chardata = {}; + end + -- Complete stanza + print(name, tostring(#stanza.last_add)); + if #stanza.last_add == 0 then + session.stanza_dispatch(stanza); + stanza = nil; + else + stanza:up(); + end + end + return xml_handlers; +end + +return init_xmlhandlers; diff --git a/main.lua b/main.lua index cb6e03fd..41d812bc 100644 --- a/main.lua +++ b/main.lua @@ -1,6 +1,6 @@ require "luarocks.require" -require "copas" +server = require "server" require "socket" require "ssl" require "lxp" @@ -10,8 +10,10 @@ function log(type, area, message) end require "core.stanza_dispatch" +init_xmlhandlers = require "core.xmlhandlers" require "core.rostermanager" require "core.offlinemessage" +require "core.usermanager" require "util.stanza" require "util.jid" @@ -24,7 +26,7 @@ local format = string.format; local st = stanza; ------------------------------ -users = {}; +sessions = {}; hosts = { ["localhost"] = { type = "local"; @@ -40,236 +42,118 @@ hosts = { local hosts, users = hosts, users; -local ssl_ctx, msg = ssl.newcontext { mode = "server", protocol = "sslv23", key = "/home/matthew/ssl_cert/server.key", +--local ssl_ctx, msg = ssl.newcontext { mode = "server", protocol = "sslv23", key = "/home/matthew/ssl_cert/server.key", +-- certificate = "/home/matthew/ssl_cert/server.crt", capath = "/etc/ssl/certs", verify = "none", } +-- +--if not ssl_ctx then error("Failed to initialise SSL/TLS support: "..tostring(msg)); end + + +local ssl_ctx = { mode = "server", protocol = "sslv23", key = "/home/matthew/ssl_cert/server.key", certificate = "/home/matthew/ssl_cert/server.crt", capath = "/etc/ssl/certs", verify = "none", } - -if not ssl_ctx then error("Failed to initialise SSL/TLS support: "..tostring(msg)); end function connect_host(host) hosts[host] = { type = "remote", sendbuffer = {} }; end -function handler(conn) - local copas_receive, copas_send = copas.receive, copas.send; - local reqdata, sktmsg; - local session = { sendbuffer = { external = {} }, conn = conn, notopen = true, priority = 0 } - - - -- Logging functions -- - - local mainlog, log = log; - do - local conn_name = tostring(conn):match("%w+$"); - log = function (type, area, message) mainlog(type, conn_name, message); end - end - local print = function (...) log("info", "core", t_concatall({...}, "\t")); end - session.log = log; - - -- -- -- - - -- Send buffers -- - - local sendbuffer = session.sendbuffer; - local send = function (data) return t_insert(sendbuffer, tostring(data)); end; - local send_to = function (to, stanza) - local node, host, resource = jid.split(to); - print("Routing stanza to "..to..":", node, host, resource); - if not hosts[host] then - print(" ...but host offline, establishing connection"); - connect_host(host); - t_insert(hosts[host].sendbuffer, stanza); -- This will be sent when s2s connection succeeds - elseif hosts[host].connected then - print(" ...putting in our external send buffer"); - t_insert(sendbuffer.external, { node = node, host = host, resource = resource, data = stanza}); - print(" ...there are now "..tostring(#sendbuffer.external).." stanzas in the external send buffer"); +local function send_to(session, to, stanza) + local node, host, resource = jid.split(to); + if not hosts[host] then + -- s2s + elseif hosts[host].type == "local" then + print(" ...is to a local user") + local destuser = hosts[host].sessions[node]; + if destuser and destuser.sessions then + if not destuser.sessions[resource] then + local best_session; + for resource, session in pairs(destuser.sessions) do + if not best_session then best_session = session; + elseif session.priority >= best_session.priority and session.priority >= 0 then + best_session = session; end end - session.send, session.send_to = send, send_to; - - -- -- -- - print("Client connected"); - conn = ssl.wrap(copas.wrap(conn), ssl_ctx); - - do - local succ, msg - conn:settimeout(15) - while not succ do - succ, msg = conn:dohandshake() - if not succ then - print("SSL: "..tostring(msg)); - if msg == 'wantread' then - socket.select({conn}, nil) - elseif msg == 'wantwrite' then - socket.select(nil, {conn}) + if not best_session then + offlinemessage.new(node, host, stanza); else - -- other error + print("resource '"..resource.."' was not online, have chosen to send to '"..best_session.username.."@"..best_session.host.."/"..best_session.resource.."'"); + resource = best_session.resource; end end + if destuser.sessions[resource] == session then + log("warn", "core", "Attempt to send stanza to self, dropping..."); + else + print("...sending...", tostring(stanza)); + --destuser.sessions[resource].conn.write(tostring(data)); + print(" to conn ", destuser.sessions[resource].conn); + destuser.sessions[resource].conn.write(tostring(stanza)); + print("...sent") + end + elseif stanza.name == "message" then + print(" ...will be stored offline"); + offlinemessage.new(node, host, stanza); + elseif stanza.name == "iq" then + print(" ...is an iq"); + session.send(st.reply(stanza) + :tag("error", { type = "cancel" }) + :tag("service-unavailable", { xmlns = "urn:ietf:params:xml:ns:xmpp-stanzas" })); end + print(" ...done routing"); end - print("SSL handshake complete"); - -- XML parser initialisation -- +end - local parser; - local stanza; +function handler(conn, data, err) + local session = sessions[conn]; - local stanza_dispatch = init_stanza_dispatcher(session); + if not session then + sessions[conn] = { conn = conn, notopen = true, priority = 0 }; + session = sessions[conn]; - local xml_handlers = {}; - - do - local ns_stack = { "" }; - local curr_ns = ""; - local curr_tag; - function xml_handlers:StartElement(name, attr) - curr_ns,name = name:match("^(.+):(%w+)$"); - print("Tag received:", name, tostring(curr_ns)); - if not stanza then - if session.notopen then - if name == "stream" then - session.host = attr.to or error("Client failed to specify destination hostname"); - session.version = attr.version or 0; - session.streamid = m_random(1000000, 99999999); - print(session, session.host, "Client opened stream"); - send(""); - send(format("", session.streamid, session.host)); - --send(""); - --send("PLAIN"); - --send [[ ]] - --send(""); - log("info", "core", "Stream opened successfully"); - session.notopen = nil; - return; - end - error("Client failed to open stream successfully"); - end - if name ~= "iq" and name ~= "presence" and name ~= "message" then - error("Client sent invalid top-level stanza"); - end - stanza = st.stanza(name, { to = attr.to, type = attr.type, id = attr.id, xmlns = curr_ns }); - curr_tag = stanza; - else - attr.xmlns = curr_ns; - stanza:tag(name, attr); - end - end - function xml_handlers:CharacterData(data) - if data:match("%S") then - stanza:text(data); - end - end - function xml_handlers:EndElement(name) - curr_ns,name = name:match("^(.+):(%w+)$"); - --print("<"..name.."/>", tostring(stanza), tostring(#stanza.last_add < 1), tostring(stanza.last_add[#stanza.last_add].name)); - if (not stanza) or #stanza.last_add < 0 or (#stanza.last_add > 0 and name ~= stanza.last_add[#stanza.last_add].name) then error("XML parse error in client stream"); end - -- Complete stanza - print(name, tostring(#stanza.last_add)); - if #stanza.last_add == 0 then - stanza_dispatch(stanza); - stanza = nil; - else - stanza:up(); - end - end ---[[ function xml_handlers:StartNamespaceDecl(namespace) - table.insert(ns_stack, namespace); - curr_ns = namespace; - log("debug", "parser", "Entering namespace "..tostring(curr_ns)); - end - function xml_handlers:EndNamespaceDecl(namespace) - table.remove(ns_stack); - log("debug", "parser", "Leaving namespace "..tostring(curr_ns)); - curr_ns = ns_stack[#ns_stack]; - log("debug", "parser", "Entering namespace "..tostring(curr_ns)); + -- Logging functions -- + + local mainlog, log = log; + do + local conn_name = tostring(conn):match("%w+$"); + log = function (type, area, message) mainlog(type, conn_name, message); end end -]] - end - parser = lxp.new(xml_handlers, ":"); + local print = function (...) log("info", "core", t_concatall({...}, "\t")); end + session.log = log; - -- -- -- + -- -- -- - -- Main loop -- - print "Receiving..." - reqdata = copas_receive(conn, 1); - print "Received" - while reqdata do - parser:parse(reqdata); - if #sendbuffer.external > 0 then - -- Stanzas queued to go to other places, from us - -- ie. other local users, or remote hosts that weren't connected before - print(#sendbuffer.external.." stanzas queued for other recipients, sending now..."); - for n, packet in pairs(sendbuffer.external) do - if not hosts[packet.host] then - connect_host(packet.host); - t_insert(hosts[packet.host].sendbuffer, packet.data); - elseif hosts[packet.host].type == "local" then - print(" ...is to a local user") - local destuser = hosts[packet.host].sessions[packet.node]; - if destuser and destuser.sessions then - if not destuser.sessions[packet.resource] then - local best_resource; - for resource, session in pairs(destuser.sessions) do - if not best_session then best_session = session; - elseif session.priority >= best_session.priority and session.priority >= 0 then - best_session = session; - end - end - if not best_session then - offlinemessage.new(packet.node, packet.host, packet.data); - else - print("resource '"..packet.resource.."' was not online, have chosen to send to '"..best_session.username.."@"..best_session.host.."/"..best_session.resource.."'"); - packet.resource = best_session.resource; - end - end - if destuser.sessions[packet.resource] == session then - log("warn", "core", "Attempt to send stanza to self, dropping..."); - else - print("...sending..."); - copas_send(destuser.sessions[packet.resource].conn, tostring(packet.data)); - print("...sent") - end - elseif packet.data.name == "message" then - print(" ...will be stored offline"); - offlinemessage.new(packet.node, packet.host, packet.data); - elseif packet.data.name == "iq" then - print(" ...is an iq"); - send(st.reply(packet.data) - :tag("error", { type = "cancel" }) - :tag("service-unavailable", { xmlns = "urn:ietf:params:xml:ns:xmpp-stanzas" })); - end - print(" ...removing from send buffer"); - sendbuffer.external[n] = nil; - end - end - end + -- Send buffers -- + + local send = function (data) print("Sending...", tostring(data)); conn.write(tostring(data)); end; + session.send, session.send_to = send, send_to; + + print("Client connected"); - if #sendbuffer > 0 then - for n, data in ipairs(sendbuffer) do - print "Sending..." - copas_send(conn, data); - print "Sent" - sendbuffer[n] = nil; + session.stanza_dispatch = init_stanza_dispatcher(session); + session.xml_handlers = init_xmlhandlers(session); + session.parser = lxp.new(session.xml_handlers, ":"); + + function session.disconnect(err) + print("Disconnected: "..err); end end - print "Receiving..." - repeat - reqdata, sktmsg = copas_receive(conn, 1); - if sktmsg == 'wantread' then - print("Received... wantread"); - --socket.select({conn}, nil) - --print("Socket ready now..."); - elseif sktmsg then - print("Received socket message:", sktmsg); - end - until reqdata or sktmsg == "closed"; - print("Received", tostring(reqdata)); + if data then + session.parser:parse(data); end - log("info", "core", "Client disconnected, connection closed"); + + --log("info", "core", "Client disconnected, connection closed"); +end + +function disconnect(conn, err) + sessions[conn].disconnect(err); end -server = socket.bind("*", 5223) -assert(server, "Failed to bind to socket") -copas.addserver(server, handler) +print("ssl_ctx:", type(ssl_ctx)); + +setmetatable(_G, { __index = function (t, k) print("WARNING: ATTEMPT TO READ A NIL GLOBAL!!!", k); error("Attempt to read a non-existent global. Naughty boy.", 2); end, __newindex = function (t, k, v) print("ATTEMPT TO SET A GLOBAL!!!!", tostring(k).." = "..tostring(v)); error("Attempt to set a global. Naughty boy.", 2); end }) --]][][[]][]; + + +local protected_handler = function (...) local success, ret = pcall(handler, ...); if not success then print("ERROR on "..tostring((select(1, ...)))..": "..ret); end end; + +print( server.add( { listener = protected_handler, disconnect = disconnect }, 5222, "*", 1, nil ) ) -- server.add will send a status message +print( server.add( { listener = protected_handler, disconnect = disconnect }, 5223, "*", 1, ssl_ctx ) ) -- server.add will send a status message -copas.loop(); +server.loop(); diff --git a/server.lua b/server.lua new file mode 100644 index 00000000..0f731315 --- /dev/null +++ b/server.lua @@ -0,0 +1,611 @@ +--[[ + + server.lua by blastbeat + + - this script contains the server loop of the program + - other scripts can reg a server here + +]]-- + +----------------------------------// DECLARATION //-- + +--// constants //-- + +local STAT_UNIT = 1 / ( 1024 * 1024 ) -- mb + +--// lua functions //-- + +local function use( what ) return _G[ what ] end + +local type = use "type" +local pairs = use "pairs" +local ipairs = use "ipairs" +local tostring = use "tostring" +local collectgarbage = use "collectgarbage" + +--// lua libs //-- + +local table = use "table" +local coroutine = use "coroutine" + +--// lua lib methods //-- + +local table_concat = table.concat +local table_remove = table.remove +local string_sub = use'string'.sub +local coroutine_wrap = coroutine.wrap +local coroutine_yield = coroutine.yield +local print = print; +local out_put = function () end --print; +local out_error = print; + +--// extern libs //-- + +local luasec = require "ssl" +local luasocket = require "socket" + +--// extern lib methods //-- + +local ssl_wrap = ( luasec and luasec.wrap ) +local socket_bind = luasocket.bind +local socket_select = luasocket.select +local ssl_newcontext = ( luasec and luasec.newcontext ) + +--// functions //-- + +local loop +local stats +local addtimer +local closeall +local addserver +local firetimer +local closesocket +local removesocket +local wrapserver +local wraptcpclient +local wrapsslclient + +--// tables //-- + +local listener +local readlist +local writelist +local socketlist +local timelistener + +--// simple data types //-- + +local _ +local readlen = 0 -- length of readlist +local writelen = 0 -- lenght of writelist + +local sendstat= 0 +local receivestat = 0 + +----------------------------------// DEFINITION //-- + +listener = { } -- key = port, value = table +readlist = { } -- array with sockets to read from +writelist = { } -- arrary with sockets to write to +socketlist = { } -- key = socket, value = wrapped socket +timelistener = { } + +stats = function( ) + return receivestat, sendstat +end + +wrapserver = function( listener, socket, ip, serverport, mode, sslctx ) -- this function wraps a server + + local dispatch, disconnect = listener.listener, listener.disconnect -- dangerous + + local wrapclient, err + + if sslctx then + if not ssl_newcontext then + return nil, "luasec not found" +-- elseif not cfg_get "use_ssl" then +-- return nil, "ssl is deactivated" + end + if type( sslctx ) ~= "table" then + out_error "server.lua: wrong server sslctx" + return nil, "wrong server sslctx" + end + sslctx, err = ssl_newcontext( sslctx ) + if not sslctx then + err = err or "wrong sslctx parameters" + out_error( "server.lua: ", err ) + return nil, err + end + wrapclient = wrapsslclient + else + wrapclient = wraptcpclient + end + + local accept = socket.accept + local close = socket.close + + --// public methods of the object //-- + + local handler = { } + + handler.shutdown = function( ) end + + --[[handler.listener = function( data, err ) + return ondata( handler, data, err ) + end]] + handler.ssl = function( ) + return sslctx and true or false + end + handler.close = function( closed ) + _ = not closed and close( socket ) + writelen = removesocket( writelist, socket, writelen ) + readlen = removesocket( readlist, socket, readlen ) + socketlist[ socket ] = nil + handler = nil + end + handler.ip = function( ) + return ip + end + handler.serverport = function( ) + return serverport + end + handler.socket = function( ) + return socket + end + handler.receivedata = function( ) + local client, err = accept( socket ) -- try to accept + if client then + local ip, clientport = client:getpeername( ) + client:settimeout( 0 ) + local handler, client, err = wrapclient( listener, client, ip, serverport, clientport, mode, sslctx ) -- wrap new client socket + if err then -- error while wrapping ssl socket + return false + end + out_put( "server.lua: accepted new client connection from ", ip, ":", clientport ) + return dispatch( handler ) + elseif err then -- maybe timeout or something else + out_put( "server.lua: error with new client connection: ", err ) + return false + end + end + return handler +end + +wrapsslclient = function( listener, socket, ip, serverport, clientport, mode, sslctx ) -- this function wraps a ssl cleint + + local dispatch, disconnect = listener.listener, listener.disconnect + + --// transform socket to ssl object //-- + + local err + socket, err = ssl_wrap( socket, sslctx ) -- wrap socket + if err then + out_put( "server.lua: ssl error: ", err ) + return nil, nil, err -- fatal error + end + socket:settimeout( 0 ) + + --// private closures of the object //-- + + local writequeue = { } -- buffer for messages to send + + local eol -- end of buffer + + local sstat, rstat = 0, 0 + + --// local import of socket methods //-- + + local send = socket.send + local receive = socket.receive + local close = socket.close + --local shutdown = socket.shutdown + + --// public methods of the object //-- + + local handler = { } + + handler.getstats = function( ) + return rstat, sstat + end + + handler.listener = function( data, err ) + return listener( handler, data, err ) + end + handler.ssl = function( ) + return true + end + handler.send = function( _, data, i, j ) + return send( socket, data, i, j ) + end + handler.receive = function( pattern, prefix ) + return receive( socket, pattern, prefix ) + end + handler.shutdown = function( pattern ) + --return shutdown( socket, pattern ) + end + handler.close = function( closed ) + close( socket ) + writelen = ( eol and removesocket( writelist, socket, writelen ) ) or writelen + readlen = removesocket( readlist, socket, readlen ) + socketlist[ socket ] = nil + out_put "server.lua: closed handler and removed socket from list" + end + handler.ip = function( ) + return ip + end + handler.serverport = function( ) + return serverport + end + handler.clientport = function( ) + return clientport + end + + handler.write = function( data ) + if not eol then + writelen = writelen + 1 + writelist[ writelen ] = socket + eol = 0 + end + eol = eol + 1 + writequeue[ eol ] = data + end + handler.writequeue = function( ) + return writequeue + end + handler.socket = function( ) + return socket + end + handler.mode = function( ) + return mode + end + handler._receivedata = function( ) + local data, err, part = receive( socket, mode ) -- receive data in "mode" + if not err or ( err == "timeout" or err == "wantread" ) then -- received something + local data = data or part or "" + local count = #data * STAT_UNIT + rstat = rstat + count + receivestat = receivestat + count + out_put( "server.lua: read data '", data, "', error: ", err ) + return dispatch( handler, data, err ) + else -- connections was closed or fatal error + out_put( "server.lua: client ", ip, ":", clientport, " error: ", err ) + handler.close( ) + disconnect( handler, err ) + writequeue = nil + handler = nil + return false + end + end + handler._dispatchdata = function( ) -- this function writes data to handlers + local buffer = table_concat( writequeue, "", 1, eol ) + local succ, err, byte = send( socket, buffer ) + local count = ( succ or 0 ) * STAT_UNIT + sstat = sstat + count + sendstat = sendstat + count + out_put( "server.lua: sended '", buffer, "', bytes: ", succ, ", error: ", err, ", part: ", byte, ", to: ", ip, ":", clientport ) + if succ then -- sending succesful + --writequeue = { } + eol = nil + writelen = removesocket( writelist, socket, writelen ) -- delete socket from writelist + return true + elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write + buffer = string_sub( buffer, byte + 1, -1 ) -- new buffer + writequeue[ 1 ] = buffer -- insert new buffer in queue + eol = 1 + return true + else -- connection was closed during sending or fatal error + out_put( "server.lua: client ", ip, ":", clientport, " error: ", err ) + handler.close( ) + disconnect( handler, err ) + writequeue = nil + handler = nil + return false + end + end + + -- // COMPAT // -- + + handler.getIp = handler.ip + handler.getPort = handler.clientport + + --// handshake //-- + + local wrote + + handler.handshake = coroutine_wrap( function( client ) + local err + for i = 1, 10 do -- 10 handshake attemps + _, err = client:dohandshake( ) + if not err then + out_put( "server.lua: ssl handshake done" ) + writelen = ( wrote and removesocket( writelist, socket, writelen ) ) or writelen + handler.receivedata = handler._receivedata -- when handshake is done, replace the handshake function with regular functions + handler.dispatchdata = handler._dispatchdata + return dispatch( handler ) + else + out_put( "server.lua: error during ssl handshake: ", err ) + if err == "wantwrite" then + if wrote == nil then + writelen = writelen + 1 + writelist[ writelen ] = client + wrote = true + end + end + coroutine_yield( handler, nil, err ) -- handshake not finished + end + end + _ = err ~= "closed" and close( socket ) + handler.close( ) + disconnect( handler, err ) + writequeue = nil + handler = nil + return false -- handshake failed + end + ) + handler.receivedata = handler.handshake + handler.dispatchdata = handler.handshake + + handler.handshake( socket ) -- do handshake + + socketlist[ socket ] = handler + readlen = readlen + 1 + readlist[ readlen ] = socket + + return handler, socket +end + +wraptcpclient = function( listener, socket, ip, serverport, clientport, mode ) -- this function wraps a socket + + local dispatch, disconnect = listener.listener, listener.disconnect + + --// private closures of the object //-- + + local writequeue = { } -- list for messages to send + + local eol + + local rstat, sstat = 0, 0 + + --// local import of socket methods //-- + + local send = socket.send + local receive = socket.receive + local close = socket.close + local shutdown = socket.shutdown + + --// public methods of the object //-- + + local handler = { } + + handler.getstats = function( ) + return rstat, sstat + end + + handler.listener = function( data, err ) + return listener( handler, data, err ) + end + handler.ssl = function( ) + return false + end + handler.send = function( _, data, i, j ) + return send( socket, data, i, j ) + end + handler.receive = function( pattern, prefix ) + return receive( socket, pattern, prefix ) + end + handler.shutdown = function( pattern ) + return shutdown( socket, pattern ) + end + handler.close = function( closed ) + _ = not closed and shutdown( socket ) + _ = not closed and close( socket ) + writelen = ( eol and removesocket( writelist, socket, writelen ) ) or writelen + readlen = removesocket( readlist, socket, readlen ) + socketlist[ socket ] = nil + out_put "server.lua: closed handler and removed socket from list" + end + handler.ip = function( ) + return ip + end + handler.serverport = function( ) + return serverport + end + handler.clientport = function( ) + return clientport + end + handler.write = function( data ) + if not eol then + writelen = writelen + 1 + writelist[ writelen ] = socket + eol = 0 + end + eol = eol + 1 + writequeue[ eol ] = data + end + handler.writequeue = function( ) + return writequeue + end + handler.socket = function( ) + return socket + end + handler.mode = function( ) + return mode + end + handler.receivedata = function( ) + local data, err, part = receive( socket, mode ) -- receive data in "mode" + if not err or ( err == "timeout" or err == "wantread" ) then -- received something + local data = data or part or "" + local count = #data * STAT_UNIT + rstat = rstat + count + receivestat = receivestat + count + out_put( "server.lua: read data '", data, "', error: ", err ) + return dispatch( handler, data, err ) + else -- connections was closed or fatal error + out_put( "server.lua: client ", ip, ":", clientport, " error: ", err ) + handler.close( ) + disconnect( handler, err ) + writequeue = nil + handler = nil + return false + end + end + handler.dispatchdata = function( ) -- this function writes data to handlers + local buffer = table_concat( writequeue, "", 1, eol ) + local succ, err, byte = send( socket, buffer ) + local count = ( succ or 0 ) * STAT_UNIT + sstat = sstat + count + sendstat = sendstat + count + out_put( "server.lua: sended '", buffer, "', bytes: ", succ, ", error: ", err, ", part: ", byte, ", to: ", ip, ":", clientport ) + if succ then -- sending succesful + --writequeue = { } + eol = nil + writelen = removesocket( writelist, socket, writelen ) -- delete socket from writelist + return true + elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write + buffer = string_sub( buffer, byte + 1, -1 ) -- new buffer + writequeue[ 1 ] = buffer -- insert new buffer in queue + eol = 1 + return true + else -- connection was closed during sending or fatal error + out_put( "server.lua: client ", ip, ":", clientport, " error: ", err ) + handler.close( ) + disconnect( handler, err ) + writequeue = nil + handler = nil + return false + end + end + + -- // COMPAT // -- + + handler.getIp = handler.ip + handler.getPort = handler.clientport + + socketlist[ socket ] = handler + readlen = readlen + 1 + readlist[ readlen ] = socket + + return handler, socket +end + +addtimer = function( listener ) + timelistener[ #timelistener + 1 ] = listener +end + +firetimer = function( listener ) + for i, listener in ipairs( timelistener ) do + listener( ) + end +end + +addserver = function( listeners, port, addr, mode, sslctx ) -- this function provides a way for other scripts to reg a server + local err + if type( listeners ) ~= "table" then + err = "invalid listener table" + else + for name, func in pairs( listeners ) do + if type( func ) ~= "function" then + err = "invalid listener function" + break + end + end + end + if not type( port ) == "number" or not ( port >= 0 and port <= 65535 ) then + err = "invalid port" + elseif listener[ port ] then + err= "listeners on port '" .. port .. "' already exist" + elseif sslctx and not luasec then + err = "luasec not found" + end + if err then + out_error( "server.lua: ", err ) + return nil, err + end + addr = addr or "*" + local server, err = socket_bind( addr, port ) + if err then + out_error( "server.lua: ", err ) + return nil, err + end + local handler, err = wrapserver( listeners, server, addr, port, mode, sslctx ) -- wrap new server socket + if not handler then + server:close( ) + return nil, err + end + server:settimeout( 0 ) + readlen = readlen + 1 + readlist[ readlen ] = server + listener[ port ] = listeners + socketlist[ server ] = handler + out_put( "server.lua: new server listener on ", addr, ":", port ) + return true +end + +removesocket = function( tbl, socket, len ) -- this function removes sockets from a list + for i, target in ipairs( tbl ) do + if target == socket then + len = len - 1 + table_remove( tbl, i ) + return len + end + end + return len +end + +closeall = function( ) + for _, handler in pairs( socketlist ) do + handler.shutdown( ) + handler.close( ) + socketlist[ _ ] = nil + end + writelist, readlist, socketlist = { }, { }, { } +end + +closesocket = function( socket ) + writelen = removesocket( writelist, socket, writelen ) + readlen = removesocket( readlist, socket, readlen ) + socketlist[ socket ] = nil + socket:close( ) +end + +loop = function( ) -- this is the main loop of the program + --signal_set( "hub", "run" ) + repeat + local read, write, err = socket_select( readlist, writelist, 1 ) -- 1 sec timeout, nice for timers + for i, socket in ipairs( write ) do -- send data waiting in writequeues + local handler = socketlist[ socket ] + if handler then + handler.dispatchdata( ) + else + closesocket( socket ) + out_put "server.lua: found no handler and closed socket (writelist)" -- this should not happen + end + end + for i, socket in ipairs( read ) do -- receive data + local handler = socketlist[ socket ] + if handler then + handler.receivedata( ) + else + closesocket( socket ) + out_put "server.lua: found no handler and closed socket (readlist)" -- this can happen + end + end + firetimer( ) + --collectgarbage "collect" + until false --signal_get "hub" ~= "run" + return --signal_get "hub" +end + +----------------------------------// BEGIN //-- + +----------------------------------// PUBLIC INTERFACE //-- + +return { + + add = addserver, + loop = loop, + stats = stats, + closeall = closeall, + addtimer = addtimer, + +} diff --git a/util/stanza.lua b/util/stanza.lua index 88d0609f..f41ba699 100644 --- a/util/stanza.lua +++ b/util/stanza.lua @@ -1,10 +1,11 @@ -local t_insert = table.insert; -local t_remove = table.remove; -local format = string.format; -local tostring = tostring; -local setmetatable= setmetatable; -local pairs = pairs; -local ipairs = ipairs; +local t_insert = table.insert; +local t_remove = table.remove; +local format = string.format; +local tostring = tostring; +local setmetatable = setmetatable; +local pairs = pairs; +local ipairs = ipairs; +local type = type; module "stanza" @@ -12,7 +13,7 @@ stanza_mt = {}; stanza_mt.__index = stanza_mt; function stanza(name, attr) - local stanza = { name = name, attr = attr or {}, last_add = {}}; + local stanza = { name = name, attr = attr or {}, tags = {}, last_add = {}}; return setmetatable(stanza, stanza_mt); end @@ -46,6 +47,9 @@ function stanza_mt:up() end function stanza_mt:add_child(child) + if type(child) == "table" then + t_insert(self.tags, child); + end t_insert(self, child); end @@ -55,6 +59,16 @@ function stanza_mt:child_with_name(name) end end +function stanza_mt:children() + local i = 0; + return function (a) + i = i + 1 + local v = a[i] + if v then return v; end + end, self, i; + +end + function stanza_mt.__tostring(t) local children_text = ""; for n, child in ipairs(t) do @@ -63,14 +77,14 @@ function stanza_mt.__tostring(t) local attr_string = ""; if t.attr then - for k, v in pairs(t.attr) do attr_string = attr_string .. format(" %s='%s'", k, tostring(v)); end + for k, v in pairs(t.attr) do if type(k) == "string" then attr_string = attr_string .. format(" %s='%s'", k, tostring(v)); end end end return format("<%s%s>%s", t.name, attr_string, children_text, t.name); end function stanza_mt.__add(s1, s2) - return s:add_child(s2); + return s1:add_child(s2); end -- cgit v1.2.3