diff options
Diffstat (limited to 'net')
-rw-r--r-- | net/adns.lua | 28 | ||||
-rw-r--r-- | net/connlisteners.lua | 8 | ||||
-rw-r--r-- | net/dns.lua | 183 | ||||
-rw-r--r-- | net/http.lua | 123 | ||||
-rw-r--r-- | net/httpserver.lua | 127 | ||||
-rw-r--r-- | net/multiplex_listener.lua | 4 | ||||
-rw-r--r-- | net/server.lua | 2 | ||||
-rw-r--r-- | net/server_event.lua | 43 | ||||
-rw-r--r-- | net/server_select.lua | 78 | ||||
-rw-r--r-- | net/xmppclient_listener.lua | 115 | ||||
-rw-r--r-- | net/xmppcomponent_listener.lua | 119 | ||||
-rw-r--r-- | net/xmppserver_listener.lua | 121 |
12 files changed, 491 insertions, 460 deletions
diff --git a/net/adns.lua b/net/adns.lua index 88d4b4b3..cd69a627 100644 --- a/net/adns.lua +++ b/net/adns.lua @@ -26,22 +26,26 @@ function lookup(handler, qname, qtype, qclass) return; end log("debug", "Records for %s not in cache, sending query (%s)...", qname, tostring(coroutine.running())); - dns.query(qname, qtype, qclass); - coroutine.yield({ qclass or "IN", qtype or "A", qname, coroutine.running()}); -- Wait for reply - log("debug", "Reply for %s (%s)", qname, tostring(coroutine.running())); - local ok, err = pcall(handler, dns.peek(qname, qtype, qclass)); + local ok, err = dns.query(qname, qtype, qclass); + if ok then + coroutine.yield({ qclass or "IN", qtype or "A", qname, coroutine.running()}); -- Wait for reply + log("debug", "Reply for %s (%s)", qname, tostring(coroutine.running())); + end + if ok then + ok, err = pcall(handler, dns.peek(qname, qtype, qclass)); + else + log("error", "Error sending DNS query: %s", err); + ok, err = pcall(handler, nil, err); + end if not ok then log("error", "Error in DNS response handler: %s", tostring(err)); end end)(dns.peek(qname, qtype, qclass)); end -function cancel(handle, call_handler) +function cancel(handle, call_handler, reason) log("warn", "Cancelling DNS lookup for %s", tostring(handle[3])); - dns.cancel(handle); - if call_handler then - coroutine.resume(handle[4]); - end + dns.cancel(handle[1], handle[2], handle[3], handle[4], call_handler); end function new_async_socket(sock, resolver) @@ -74,7 +78,11 @@ function new_async_socket(sock, resolver) handler.setpeername = function (_, ...) peername = (...); local ret = sock:setpeername(...); _:set_send(dummy_send); return ret; end handler.connect = function (_, ...) return sock:connect(...) end --handler.send = function (_, data) _:write(data); return _.sendbuffer and _.sendbuffer(); end - handler.send = function (_, data) return sock:send(data); end + handler.send = function (_, data) + local getpeername = sock.getpeername; + log("debug", "Sending DNS query to %s", (getpeername and getpeername(sock)) or "<unconnected>"); + return sock:send(data); + end return handler; end diff --git a/net/connlisteners.lua b/net/connlisteners.lua index 93dce8b3..7da25c62 100644 --- a/net/connlisteners.lua +++ b/net/connlisteners.lua @@ -13,8 +13,10 @@ local server = require "net.server"; local log = require "util.logger".init("connlisteners"); local tostring = tostring; -local dofile, pcall, error = - dofile, pcall, error +local dofile, xpcall, error = + dofile, xpcall, error + +local debug_traceback = debug.traceback; module "connlisteners" @@ -37,7 +39,7 @@ end function get(name) local h = listeners[name]; if not h then - local ok, ret = pcall(dofile, listeners_dir..name:gsub("[^%w%-]", "_").."_listener.lua"); + local ok, ret = xpcall(function() dofile(listeners_dir..name:gsub("[^%w%-]", "_").."_listener.lua") end, debug_traceback); if not ok then log("error", "Error while loading listener '%s': %s", tostring(name), tostring(ret)); return nil, ret; diff --git a/net/dns.lua b/net/dns.lua index c0de97fd..c905f56c 100644 --- a/net/dns.lua +++ b/net/dns.lua @@ -2,8 +2,6 @@ -- This file is included with Prosody IM. It has modifications, -- which are hereby placed in the public domain. --- public domain 20080404 lua@ztact.com - -- todo: quick (default) header generation -- todo: nxdomain, error handling @@ -15,18 +13,61 @@ local socket = require "socket"; -local ztact = require "util.ztact"; +local timer = require "util.timer"; + local _, windows = pcall(require, "util.windows"); local is_windows = (_ and windows) or os.getenv("WINDIR"); local coroutine, io, math, string, table = coroutine, io, math, string, table; -local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack = - ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack; +local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack, select, type= + ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack, select, type; +local ztact = { -- public domain 20080404 lua@ztact.com + get = function(parent, ...) + local len = select('#', ...); + for i=1,len do + parent = parent[select(i, ...)]; + if parent == nil then break; end + end + return parent; + end; + set = function(parent, ...) + local len = select('#', ...); + local key, value = select(len-1, ...); + local cutpoint, cutkey; + + for i=1,len-2 do + local key = select (i, ...) + local child = parent[key] + + if value == nil then + if child == nil then + return; + elseif next(child, next(child)) then + cutpoint = nil; cutkey = nil; + elseif cutpoint == nil then + cutpoint = parent; cutkey = key; + end + elseif child == nil then + child = {}; + parent[key] = child; + end + parent = child + end + + if value == nil and cutpoint then + cutpoint[cutkey] = nil; + else + parent[key] = value; + return value; + end + end; +}; local get, set = ztact.get, ztact.set; +local default_timeout = 15; -------------------------------------------------- module dns module('dns') @@ -115,32 +156,31 @@ end local resolver = {}; resolver.__index = resolver; +resolver.timeout = default_timeout; -local SRV_tostring; - +local function default_rr_tostring(rr) + local rr_val = rr.type and rr[rr.type:lower()]; + if type(rr_val) ~= "string" then + return "<UNKNOWN RDATA TYPE>"; + end + return rr_val; +end + +local special_tostrings = { + LOC = resolver.LOC_tostring; + MX = function (rr) + return string.format('%2i %s', rr.pref, rr.mx); + end; + SRV = function (rr) + local s = rr.srv; + return string.format('%5d %5d %5d %s', s.priority, s.weight, s.port, s.target); + end; +}; local rr_metatable = {}; -- - - - - - - - - - - - - - - - - - - rr_metatable function rr_metatable.__tostring(rr) - local s0 = string.format('%2s %-5s %6i %-28s', rr.class, rr.type, rr.ttl, rr.name); - local s1 = ''; - if rr.type == 'A' then - s1 = ' '..rr.a; - elseif rr.type == 'MX' then - s1 = string.format(' %2i %s', rr.pref, rr.mx); - elseif rr.type == 'CNAME' then - s1 = ' '..rr.cname; - elseif rr.type == 'LOC' then - s1 = ' '..resolver.LOC_tostring(rr); - elseif rr.type == 'NS' then - s1 = ' '..rr.ns; - elseif rr.type == 'SRV' then - s1 = ' '..SRV_tostring(rr); - elseif rr.type == 'TXT' then - s1 = ' '..rr.txt; - else - s1 = ' <UNKNOWN RDATA TYPE>'; - end - return s0..s1; + local rr_string = (special_tostrings[rr.type] or default_rr_tostring)(rr); + return string.format('%2s %-5s %6i %-28s %s', rr.class, rr.type, rr.ttl, rr.name, rr_string); end @@ -434,13 +474,10 @@ function resolver:SRV(rr) -- - - - - - - - - - - - - - - - - - - - - - SRV rr.srv.target = self:name(); end - -function SRV_tostring(rr) -- - - - - - - - - - - - - - - - - - SRV_tostring - local s = rr.srv; - return string.format( '%5d %5d %5d %s', s.priority, s.weight, s.port, s.target ); +function resolver:PTR(rr) + rr.ptr = self:name(); end - function resolver:TXT(rr) -- - - - - - - - - - - - - - - - - - - - - - TXT rr.txt = self:sub (rr.rdlength); end @@ -524,7 +561,7 @@ end function resolver:adddefaultnameservers() -- - - - - adddefaultnameservers if is_windows then - if windows then + if windows and windows.get_nameservers then for _, server in ipairs(windows.get_nameservers()) do self:addnameserver(server); end @@ -562,7 +599,11 @@ function resolver:getsocket(servernum) -- - - - - - - - - - - - - getsocket local sock = self.socket[servernum]; if sock then return sock; end - sock = socket.udp(); + local err; + sock, err = socket.udp(); + if not sock then + return nil, err; + end if self.socket_wrapper then sock = self.socket_wrapper(sock, self); end sock:settimeout(0); -- todo: attempt to use a random port, fallback to 0 @@ -667,18 +708,44 @@ function resolver:query(qname, qtype, qclass) -- - - - - - - - - - -- query retry = socket.gettime() + self.delays[1] }; - -- remember the query + -- remember the query self.active[id] = self.active[id] or {}; self.active[id][question] = o; - -- remember which coroutine wants the answer + -- remember which coroutine wants the answer local co = coroutine.running(); if co then set(self.wanted, qclass, qtype, qname, co, true); --set(self.yielded, co, qclass, qtype, qname, true); end - self:getsocket (o.server):send (o.packet) + local conn, err = self:getsocket(o.server) + if not conn then + return nil, err; + end + conn:send (o.packet) + + if timer and self.timeout then + local num_servers = #self.server; + local i = 1; + timer.add_task(self.timeout, function () + if get(self.wanted, qclass, qtype, qname, co) then + if i < num_servers then + i = i + 1; + self:servfail(conn); + o.server = self.best_server; + conn, err = self:getsocket(o.server); + if conn then + conn:send(o.packet); + return self.timeout; + end + end + -- Tried everything, failed + self:cancel(qclass, qtype, qname, co, true); + end + end) + end + return true; end function resolver:servfail(sock) @@ -710,7 +777,7 @@ function resolver:servfail(sock) end end end - + if num == self.best_server then self.best_server = self.best_server + 1; if self.best_server > #self.server then @@ -720,6 +787,10 @@ function resolver:servfail(sock) end end +function resolver:settimeout(seconds) + self.timeout = seconds; +end + function resolver:receive(rset) -- - - - - - - - - - - - - - - - - receive --print('receive'); print(self.socket); self.time = socket.gettime(); @@ -769,11 +840,11 @@ function resolver:receive(rset) -- - - - - - - - - - - - - - - - - receive end -function resolver:feed(sock, packet) +function resolver:feed(sock, packet, force) --print('receive'); print(self.socket); self.time = socket.gettime(); - local response = self:decode(packet); + local response = self:decode(packet, force); if response and self.active[response.header.id] and self.active[response.header.id][response.question.raw] then --print('received response'); @@ -806,10 +877,13 @@ function resolver:feed(sock, packet) return response; end -function resolver:cancel(data) - local cos = get(self.wanted, unpack(data, 1, 3)); +function resolver:cancel(qclass, qtype, qname, co, call_handler) + local cos = get(self.wanted, qclass, qtype, qname); if cos then - cos[data[4]] = nil; + if call_handler then + coroutine.resume(co); + end + cos[co] = nil; end end @@ -852,12 +926,12 @@ end function resolver:lookup(qname, qtype, qclass) -- - - - - - - - - - lookup self:query (qname, qtype, qclass) while self:pulse() do - local recvt = {} - for i, s in ipairs(self.socket) do - recvt[i] = s - end - socket.select(recvt, nil, 4) - end + local recvt = {} + for i, s in ipairs(self.socket) do + recvt[i] = s + end + socket.select(recvt, nil, 4) + end --print(self.cache); return self:peek(qname, qtype, qclass); end @@ -866,6 +940,9 @@ function resolver:lookupex(handler, qname, qtype, qclass) -- - - - - - - - - return self:peek(qname, qtype, qclass) or self:query(qname, qtype, qclass); end +function resolver:tohostname(ip) + return dns.lookup(ip:gsub("(%d+)%.(%d+)%.(%d+)%.(%d+)", "%4.%3.%2.%1.in-addr.arpa."), "PTR"); +end --print ---------------------------------------------------------------- print @@ -941,6 +1018,10 @@ function dns.lookup(...) -- - - - - - - - - - - - - - - - - - - - - lookup return _resolver:lookup(...); end +function dns.tohostname(...) + return _resolver:tohostname(...); +end + function dns.purge(...) -- - - - - - - - - - - - - - - - - - - - - - purge return _resolver:purge(...); end @@ -961,6 +1042,10 @@ function dns.cancel(...) -- - - - - - - - - - - - - - - - - - - - - - cancel return _resolver:cancel(...); end +function dns.settimeout(...) + return _resolver:settimeout(...); +end + function dns.socket_wrapper_set(...) -- - - - - - - - - socket_wrapper_set return _resolver:socket_wrapper_set(...); end diff --git a/net/http.lua b/net/http.lua index 0634d773..6c8e0a68 100644 --- a/net/http.lua +++ b/net/http.lua @@ -10,6 +10,7 @@ local socket = require "socket" local mime = require "mime" local url = require "socket.url" +local httpstream_new = require "util.httpstream".new; local server = require "net.server" @@ -17,8 +18,9 @@ local connlisteners_get = require "net.connlisteners".get; local listener = connlisteners_get("httpclient") or error("No httpclient listener!"); local t_insert, t_concat = table.insert, table.concat; -local tonumber, tostring, pairs, xpcall, select, debug_traceback, char, format = - tonumber, tostring, pairs, xpcall, select, debug.traceback, string.char, string.format; +local pairs, ipairs = pairs, ipairs; +local tonumber, tostring, xpcall, select, debug_traceback, char, format = + tonumber, tostring, xpcall, select, debug.traceback, string.char, string.format; local log = require "util.logger".init("http"); @@ -27,107 +29,46 @@ module "http" function urlencode(s) return s and (s:gsub("%W", function (c) return format("%%%02x", c:byte()); end)); end function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return char(tonumber(c,16)); end)); end -local function expectbody(reqt, code) - if reqt.method == "HEAD" then return nil end - if code == 204 or code == 304 or code == 301 then return nil end - if code >= 100 and code < 200 then return nil end - return 1 +local function _formencodepart(s) + return s and (s:gsub("%W", function (c) + if c ~= " " then + return format("%%%02x", c:byte()); + else + return "+"; + end + end)); +end +function formencode(form) + local result = {}; + for _, field in ipairs(form) do + t_insert(result, _formencodepart(field.name).."=".._formencodepart(field.value)); + end + return t_concat(result, "&"); end local function request_reader(request, data, startpos) - if not data then - if request.body then - log("debug", "Connection closed, but we have data, calling callback..."); - request.callback(t_concat(request.body), request.code, request); - elseif request.state ~= "completed" then - -- Error.. connection was closed prematurely - request.callback("connection-closed", 0, request); - return; - end - destroy_request(request); - request.body = nil; - request.state = "completed"; - return; - end - if request.state == "body" and request.state ~= "completed" then - log("debug", "Reading body...") - if not request.body then request.body = {}; request.havebodylength, request.bodylength = 0, tonumber(request.responseheaders["content-length"]); end - if startpos then - data = data:sub(startpos, -1) - end - t_insert(request.body, data); - if request.bodylength then - request.havebodylength = request.havebodylength + #data; - if request.havebodylength >= request.bodylength then - -- We have the body - log("debug", "Have full body, calling callback"); - if request.callback then - request.callback(t_concat(request.body), request.code, request); - end - request.body = nil; - request.state = "completed"; - else - log("debug", "Have "..request.havebodylength.." bytes out of "..request.bodylength); - end - end - elseif request.state == "headers" then - log("debug", "Reading headers...") - local pos = startpos; - local headers, headers_complete = request.responseheaders; - if not headers then - headers = {}; - request.responseheaders = headers; - end - for line in data:sub(startpos, -1):gmatch("(.-)\r\n") do - startpos = startpos + #line + 2; - local k, v = line:match("(%S+): (.+)"); - if k and v then - headers[k:lower()] = v; - --log("debug", "Header: "..k:lower().." = "..v); - elseif #line == 0 then - headers_complete = true; - break; - else - log("warn", "Unhandled header line: "..line); + if not request.parser then + local function success_cb(r) + if request.callback then + for k,v in pairs(r) do request[k] = v; end + request.callback(r.body, r.code, request); + request.callback = nil; end - end - if not headers_complete then return; end - -- Reached the end of the headers - if not expectbody(request, request.code) then - request.callback(nil, request.code, request); - return; - end - request.state = "body"; - if #data > startpos then - return request_reader(request, data, startpos); - end - elseif request.state == "status" then - log("debug", "Reading status...") - local http, code, text, linelen = data:match("^HTTP/(%S+) (%d+) (.-)\r\n()", startpos); - code = tonumber(code); - if not code then - log("warn", "Invalid HTTP status line, telling callback then closing"); - local ret = request.callback("invalid-status-line", 0, request); destroy_request(request); - return ret; end - - request.code, request.responseversion = code, http; - - if request.onlystatus then + local function error_cb(r) if request.callback then - request.callback(nil, code, request); + request.callback(r or "connection-closed", 0, request); + request.callback = nil; end destroy_request(request); - return; end - - request.state = "headers"; - - if #data > linelen then - return request_reader(request, data, linelen); + local function options_cb() + return request; end + request.parser = httpstream_new(success_cb, error_cb, "client", options_cb); end + request.parser:feed(data); end local function handleerr(err) log("error", "Traceback[http]: %s: %s", tostring(err), debug_traceback()); end diff --git a/net/httpserver.lua b/net/httpserver.lua index 59ddbb12..74f61c56 100644 --- a/net/httpserver.lua +++ b/net/httpserver.lua @@ -7,19 +7,20 @@ -- -local socket = require "socket" local server = require "net.server" local url_parse = require "socket.url".parse; +local httpstream_new = require "util.httpstream".new; local connlisteners_start = require "net.connlisteners".start; local connlisteners_get = require "net.connlisteners".get; local listener; local t_insert, t_concat = table.insert, table.concat; -local s_match, s_gmatch = string.match, string.gmatch; local tonumber, tostring, pairs, ipairs, type = tonumber, tostring, pairs, ipairs, type; +local xpcall = xpcall; +local debug_traceback = debug.traceback; -local urlencode = function (s) return s and (s:gsub("%W", function (c) return string.format("%%%02x", c:byte()); end)); end +local urlencode = function (s) return s and (s:gsub("%W", function (c) return ("%%%02x"):format(c:byte()); end)); end local log = require "util.logger".init("httpserver"); @@ -29,10 +30,6 @@ module "httpserver" local default_handler; -local function expectbody(reqt) - return reqt.method == "POST"; -end - local function send_response(request, response) -- Write status line local resp; @@ -87,6 +84,22 @@ local function call_callback(request, err) callback = (request.server and request.server.handlers[base]) or default_handler; end if callback then + local _callback = callback; + function callback(method, body, request) + local ok, result = xpcall(function() return _callback(method, body, request) end, debug_traceback); + if ok then return result; end + log("error", "Error in HTTP server handler: %s", result); + -- TODO: When we support pipelining, request.destroyed + -- won't be the right flag - we just want to see if there + -- has been a response to this request yet. + if not request.destroyed then + return { + status = "500 Internal Server Error"; + headers = { ["Content-Type"] = "text/plain" }; + body = "There was an error processing your request. See the error log for more details."; + }; + end + end if err then log("debug", "Request error: "..err); if not callback(nil, err, request) then @@ -114,94 +127,21 @@ local function call_callback(request, err) end local function request_reader(request, data, startpos) - if not data then - if request.body then - call_callback(request); - else - -- Error.. connection was closed prematurely - call_callback(request, "connection-closed"); - end - -- Here we force a destroy... the connection is gone, so we can't reply later - destroy_request(request); - return; - end - if request.state == "body" then - log("debug", "Reading body...") - if not request.body then request.body = {}; request.havebodylength, request.bodylength = 0, tonumber(request.headers["content-length"]); end - if startpos then - data = data:sub(startpos, -1) - end - t_insert(request.body, data); - if request.bodylength then - request.havebodylength = request.havebodylength + #data; - if request.havebodylength >= request.bodylength then - -- We have the body - call_callback(request); - end - end - elseif request.state == "headers" then - log("debug", "Reading headers...") - local pos = startpos; - local headers, headers_complete = request.headers; - if not headers then - headers = {}; - request.headers = headers; - end - - for line in data:gmatch("(.-)\r\n") do - startpos = (startpos or 1) + #line + 2; - local k, v = line:match("(%S+): (.+)"); - if k and v then - headers[k:lower()] = v; - --log("debug", "Header: '"..k:lower().."' = '"..v.."'"); - elseif #line == 0 then - headers_complete = true; - break; - else - log("debug", "Unhandled header line: "..line); - end - end - - if not headers_complete then return; end - - if not expectbody(request) then + if not request.parser then + local function success_cb(r) + for k,v in pairs(r) do request[k] = v; end + request.url = url_parse(request.path); + request.url.path = request.url.path and request.url.path:gsub("%%(%x%x)", function(x) return x.char(tonumber(x, 16)) end); + request.body = { request.body }; call_callback(request); - return; - end - - -- Reached the end of the headers - request.state = "body"; - if #data > startpos then - return request_reader(request, data:sub(startpos, -1)); - end - elseif request.state == "request" then - log("debug", "Reading request line...") - local method, path, http, linelen = data:match("^(%S+) (%S+) HTTP/(%S+)\r\n()", startpos); - if not method then - log("warn", "Invalid HTTP status line, telling callback then closing"); - local ret = call_callback(request, "invalid-status-line"); - request:destroy(); - return ret; end - - request.method, request.path, request.httpversion = method, path, http; - - request.url = url_parse(request.path); - - log("debug", method.." request for "..tostring(request.path) .. " on port "..request.handler:serverport()); - - if request.onlystatus then - if not call_callback(request) then - return; - end - end - - request.state = "headers"; - - if #data > linelen then - return request_reader(request, data:sub(linelen, -1)); + local function error_cb(r) + call_callback(request, r or "connection-closed"); + destroy_request(request); end + request.parser = httpstream_new(success_cb, error_cb); end + request.parser:feed(data); end -- The default handler for requests @@ -263,6 +203,7 @@ function new_from_config(ports, handle_request, default_options) log("warn", "Old syntax of httpserver.new_from_config being used to register %s", handle_request); handle_request, default_options = default_options, { base = handle_request }; end + ports = ports or {5280}; for _, options in ipairs(ports) do local port = default_options.port or 5280; local base = default_options.base; @@ -285,8 +226,8 @@ function new_from_config(ports, handle_request, default_options) ssl.options = "no_sslv2"; end - new{ port = port, interface = interface, - base = base, handler = handle_request, + new{ port = port, interface = interface, + base = base, handler = handle_request, ssl = ssl, type = (ssl and "ssl") or "tcp" }; end end diff --git a/net/multiplex_listener.lua b/net/multiplex_listener.lua index bf193ad8..b515ccce 100644 --- a/net/multiplex_listener.lua +++ b/net/multiplex_listener.lua @@ -19,6 +19,8 @@ function server.onincoming(conn, data) if buf:match("^[a-zA-Z]") then local listener = httpserver_listener; conn:setlistener(listener); + local onconnect = listener.onconnect; + if onconnect then onconnect(conn) end listener.onincoming(conn, buf); elseif buf:match(">") then local listener; @@ -31,6 +33,8 @@ function server.onincoming(conn, data) listener = xmppclient_listener; end conn:setlistener(listener); + local onconnect = listener.onconnect; + if onconnect then onconnect(conn) end listener.onincoming(conn, buf); elseif #buf > 1024 then conn:close(); diff --git a/net/server.lua b/net/server.lua index e0d4b85a..1c1a63a4 100644 --- a/net/server.lua +++ b/net/server.lua @@ -6,7 +6,7 @@ -- COPYING file in the source package for more information. -- -local use_luaevent = require "core.configmanager".get("*", "core", "use_libevent"); +local use_luaevent = prosody and require "core.configmanager".get("*", "core", "use_libevent"); if use_luaevent then use_luaevent = pcall(require, "luaevent.core"); diff --git a/net/server_event.lua b/net/server_event.lua index 0331e793..528305d3 100644 --- a/net/server_event.lua +++ b/net/server_event.lua @@ -143,9 +143,9 @@ do debug( "new connection failed. id:", self.id, "error:", self.fatalerror ) else if plainssl and ssl then -- start ssl session - self:starttls() + self:starttls(nil, true) else -- normal connection - self:_start_session( self.listener.onconnect ) + self:_start_session(true) end debug( "new connection established. id:", self.id ) end @@ -155,13 +155,15 @@ do self.eventconnect = addevent( base, self.conn, EV_WRITE, callback, cfg.CONNECT_TIMEOUT ) return true end - function interface_mt:_start_session(onconnect) -- new session, for example after startssl + function interface_mt:_start_session(call_onconnect) -- new session, for example after startssl if self.type == "client" then local callback = function( ) self:_lock( false, false, false ) --vdebug( "start listening on client socket with id:", self.id ) self.eventread = addevent( base, self.conn, EV_READ, self.readcallback, cfg.READ_TIMEOUT ); -- register callback - self:onconnect() + if call_onconnect then + self:onconnect() + end self.eventsession = nil return -1 end @@ -173,7 +175,7 @@ do end return true end - function interface_mt:_start_ssl(arg) -- old socket will be destroyed, therefore we have to close read/write events first + function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed, therefore we have to close read/write events first --vdebug( "starting ssl session with client id:", self.id ) local _ _ = self.eventread and self.eventread:close( ) -- close events; this must be called outside of the event callbacks! @@ -184,7 +186,7 @@ do if err then self.fatalerror = err self.conn = nil -- cannot be used anymore - if "onconnect" == arg then + if call_onconnect then self.ondisconnect = nil -- dont call this when client isnt really connected end self:_close() @@ -211,28 +213,25 @@ do self.send = self.conn.send -- caching table lookups with new client object self.receive = self.conn.receive local onsomething - if "onconnect" == arg then -- trigger listener - onsomething = self.onconnect - else - onsomething = self.onsslconnection + if not call_onconnect then -- trigger listener + self:onstatus("ssl-handshake-complete"); end - self:_start_session( onsomething ) + self:_start_session( call_onconnect ) debug( "ssl handshake done" ) - self:onstatus("ssl-handshake-complete"); self.eventhandshake = nil return -1 end - debug( "error during ssl handshake:", err ) if err == "wantwrite" then event = EV_WRITE elseif err == "wantread" then event = EV_READ else + debug( "ssl handshake error:", err ) self.fatalerror = err end end if self.fatalerror then - if "onconnect" == arg then + if call_onconnect then self.ondisconnect = nil -- dont call this when client isnt really connected end self:_close() @@ -362,6 +361,10 @@ do end end + function interface_mt:socket() + return self.conn + end + function interface_mt:server() return self._server or self; end @@ -414,7 +417,7 @@ do -- No-op, we always use the underlying connection's send end - function interface_mt:starttls(sslctx) + function interface_mt:starttls(sslctx, call_onconnect) debug( "try to start ssl at client id:", self.id ) local err self._sslctx = sslctx; @@ -428,7 +431,7 @@ do self._usingssl = true self.startsslcallback = function( ) -- we have to start the handshake outside of a read/write event self.startsslcallback = nil - self:_start_ssl(); + self:_start_ssl(call_onconnect); self.eventstarthandshake = nil return -1 end @@ -468,7 +471,6 @@ do function interface_mt:ondrain() end function interface_mt:onstatus() - debug("server.lua: Dummy onstatus()") end end @@ -700,9 +702,9 @@ do local clientinterface = handleclient( client, client_ip, client_port, interface, pattern, listener, nil, sslctx ) --vdebug( "client id:", clientinterface, "startssl:", startssl ) if ssl and sslctx then - clientinterface:starttls(sslctx) + clientinterface:starttls(sslctx, true) else - clientinterface:_start_session( clientinterface.onconnect ) + clientinterface:_start_session( true ) end debug( "accepted incoming client connection from:", client_ip or "<unknown IP>", client_port or "<unknown port>", "to", port or "<unknown port>"); @@ -724,7 +726,7 @@ local addserver = ( function( ) --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslcfg or "nil", startssl or "nil") local server, err = socket.bind( addr, port, cfg.ACCEPT_QUEUE ) -- create server socket if not server then - debug( "creating server socket failed because:", err ) + debug( "creating server socket on "..addr.." port "..port.." failed:", err ) return nil, err end local sslctx @@ -846,7 +848,6 @@ function hook_signal(signal_num, handler) end local function link(sender, receiver, buffersize) - sender:set_mode(buffersize); local sender_locked; function receiver:ondrain() diff --git a/net/server_select.lua b/net/server_select.lua index 298e560a..c3777a5f 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -32,6 +32,7 @@ local STAT_UNIT = 1 -- byte local type = use "type" local pairs = use "pairs" local ipairs = use "ipairs" +local tonumber = use "tonumber" local tostring = use "tostring" local collectgarbage = use "collectgarbage" @@ -44,8 +45,9 @@ local coroutine = use "coroutine" --// lua lib methods //-- -local os_time = os.time local os_difftime = os.difftime +local math_min = math.min +local math_huge = math.huge local table_concat = table.concat local table_remove = table.remove local string_len = string.len @@ -57,6 +59,7 @@ local coroutine_yield = coroutine.yield local luasec = use "ssl" local luasocket = use "socket" or require "socket" +local luasocket_gettime = luasocket.gettime --// extern lib methods //-- @@ -74,6 +77,7 @@ local stats local idfalse local addtimer local closeall +local addsocket local addserver local getserver local wrapserver @@ -125,6 +129,8 @@ local _timer local _maxclientsperserver +local _maxsslhandshake + ----------------------------------// DEFINITION //-- _server = { } -- key = port, value = table; list of listening servers @@ -167,7 +173,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco local connections = 0 - local dispatch, disconnect = listeners.onincoming, listeners.ondisconnect + local dispatch, disconnect = listeners.onconnect or listeners.onincoming, listeners.ondisconnect local accept = socket.accept @@ -483,7 +489,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport if drain then drain(handler) end - _ = needtls and handler:starttls(nil, true) + _ = needtls and handler:starttls(nil) _ = toclose and handler:close( ) return true elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write @@ -524,7 +530,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport _readlistlen = addsocket(_readlist, client, _readlistlen) return true else - out_put( "server.lua: error during ssl handshake: ", tostring(err) ) if err == "wantwrite" and not wrote then _sendlistlen = addsocket(_sendlist, client, _sendlistlen) wrote = true @@ -532,6 +537,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport _readlistlen = addsocket(_readlist, client, _readlistlen) read = true else + out_put( "server.lua: ssl handshake error: ", tostring(err) ) break; end --coroutine_yield( handler, nil, err ) -- handshake not finished @@ -564,13 +570,13 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport end else local sslctx; - handler.starttls = function( self, _sslctx, now ) + handler.starttls = function( self, _sslctx) if _sslctx then sslctx = _sslctx; handler:set_sslctx(sslctx); end - if not now then - out_put "server.lua: we need to do tls, but delaying until later" + if bufferqueuelen > 0 then + out_put "server.lua: we need to do tls, but delaying until send buffer empty" needtls = true return end @@ -623,16 +629,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport _socketlist[ socket ] = handler _readlistlen = addsocket(_readlist, socket, _readlistlen) - if listeners.onconnect then - _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) - handler.sendbuffer = function () - listeners.onconnect(handler); - handler.sendbuffer = _sendbuffer; - if bufferqueuelen > 0 then - return _sendbuffer(); - end - end - end return handler, socket end @@ -676,7 +672,6 @@ closesocket = function( socket ) end local function link(sender, receiver, buffersize) - sender:set_mode(buffersize); local sender_locked; local _sendbuffer = receiver.sendbuffer; function receiver.sendbuffer() @@ -798,16 +793,18 @@ stats = function( ) return _readtraffic, _sendtraffic, _readlistlen, _sendlistlen, _timerlistlen end -local dontstop = true; -- thinking about tomorrow, ... +local quitting; setquitting = function (quit) - dontstop = not quit; - return; + quitting = not not quit; end -loop = function( ) -- this is the main loop of the program - while dontstop do - local read, write, err = socket_select( _readlist, _sendlist, _selecttimeout ) +loop = function(once) -- this is the main loop of the program + if quitting then return "quitting"; end + if once then quitting = "once"; end + local next_timer_time = math_huge; + repeat + local read, write, err = socket_select( _readlist, _sendlist, math_min(_selecttimeout, next_timer_time) ) for i, socket in ipairs( write ) do -- send data waiting in writequeues local handler = _socketlist[ socket ] if handler then @@ -831,19 +828,28 @@ loop = function( ) -- this is the main loop of the program handler:close( true ) -- forced disconnect end clean( _closelist ) - _currenttime = os_time( ) - if os_difftime( _currenttime - _timer ) >= 1 then + _currenttime = luasocket_gettime( ) + if _currenttime - _timer >= math_min(next_timer_time, 1) then + next_timer_time = math_huge; for i = 1, _timerlistlen do - _timerlist[ i ]( _currenttime ) -- fire timers + local t = _timerlist[ i ]( _currenttime ) -- fire timers + if t then next_timer_time = math_min(next_timer_time, t); end end _timer = _currenttime + else + next_timer_time = next_timer_time - (_currenttime - _timer); end socket_sleep( _sleeptime ) -- wait some time --collectgarbage( ) - end + until quitting; + if once and quitting == "once" then quitting = nil; return; end return "quitting" end +step = function () + return loop(true); +end + local function get_backend() return "select"; end @@ -854,6 +860,18 @@ local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx local handler = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx ) _socketlist[ socket ] = handler _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) + if listeners.onconnect then + -- When socket is writeable, call onconnect + local _sendbuffer = handler.sendbuffer; + handler.sendbuffer = function () + handler.sendbuffer = _sendbuffer; + listeners.onconnect(handler); + -- If there was data with the incoming packet, handle it now. + if #handler:bufferqueue() > 0 then + return _sendbuffer(); + end + end + end return handler, socket end @@ -879,8 +897,8 @@ use "setmetatable" ( _socketlist, { __mode = "k" } ) use "setmetatable" ( _readtimes, { __mode = "k" } ) use "setmetatable" ( _writetimes, { __mode = "k" } ) -_timer = os_time( ) -_starttime = os_time( ) +_timer = luasocket_gettime( ) +_starttime = luasocket_gettime( ) addtimer( function( ) local difftime = os_difftime( _currenttime - _starttime ) diff --git a/net/xmppclient_listener.lua b/net/xmppclient_listener.lua index 94daa2b2..4cc90cbf 100644 --- a/net/xmppclient_listener.lua +++ b/net/xmppclient_listener.lua @@ -10,22 +10,19 @@ local logger = require "logger"; local log = logger.init("xmppclient_listener"); -local lxp = require "lxp" -local init_xmlhandlers = require "core.xmlhandlers" -local sm_new_session = require "core.sessionmanager".new_session; +local new_xmpp_stream = require "util.xmppstream".new; local connlisteners_register = require "net.connlisteners".register; -local t_insert = table.insert; -local t_concat = table.concat; -local t_concatall = function (t, sep) local tt = {}; for _, s in ipairs(t) do t_insert(tt, tostring(s)); end return t_concat(tt, sep); end -local m_random = math.random; -local format = string.format; local sessionmanager = require "core.sessionmanager"; local sm_new_session, sm_destroy_session = sessionmanager.new_session, sessionmanager.destroy_session; local sm_streamopened = sessionmanager.streamopened; local sm_streamclosed = sessionmanager.streamclosed; local st = require "util.stanza"; +local xpcall = xpcall; +local tostring = tostring; +local type = type; +local traceback = debug.traceback; local config = require "core.configmanager"; local opt_keepalives = config.get("*", "core", "tcp_keepalives"); @@ -41,7 +38,7 @@ function stream_callbacks.error(session, error, data) session:close("invalid-namespace"); elseif error == "parse-error" then (session.log or log)("debug", "Client XML parse error: %s", tostring(data)); - session:close("xml-not-well-formed"); + session:close("not-well-formed"); elseif error == "stream-error" then local condition, text = "undefined-condition"; for child in data:children() do @@ -62,9 +59,12 @@ function stream_callbacks.error(session, error, data) end end -local function handleerr(err) log("error", "Traceback[c2s]: %s: %s", tostring(err), debug.traceback()); end -function stream_callbacks.handlestanza(a, b) - xpcall(function () core_process_stanza(a, b) end, handleerr); +local function handleerr(err) log("error", "Traceback[c2s]: %s: %s", tostring(err), traceback()); end +function stream_callbacks.handlestanza(session, stanza) + stanza = session.filter("stanzas/in", stanza); + if stanza then + return xpcall(function () return core_process_stanza(session, stanza) end, handleerr); + end end local sessions = {}; @@ -72,23 +72,6 @@ local xmppclient = { default_port = 5222, default_mode = "*a" }; -- These are session methods -- -local function session_reset_stream(session) - -- Reset stream - local parser = lxp.new(init_xmlhandlers(session, stream_callbacks), "\1"); - session.parser = parser; - - session.notopen = true; - - function session.data(conn, data) - local ok, err = parser:parse(data); - if ok then return; end - log("debug", "Received invalid XML (%s) %d bytes: %s", tostring(err), #data, data:sub(1, 300):gsub("[\r\n]+", " "):gsub("[%z\1-\31]", "_")); - session:close("xml-not-well-formed"); - end - - return true; -end - local stream_xmlns_attr = {xmlns='urn:ietf:params:xml:ns:xmpp-streams'}; local default_stream_attr = { ["xmlns:stream"] = "http://etherx.jabber.org/streams", xmlns = stream_callbacks.default_ns, version = "1.0", id = "" }; local function session_close(session, reason) @@ -128,32 +111,54 @@ end -- End of session methods -- -function xmppclient.onincoming(conn, data) - local session = sessions[conn]; - if not session then - session = sm_new_session(conn); - sessions[conn] = session; - - session.log("info", "Client connected"); - - -- Client is using legacy SSL (otherwise mod_tls sets this flag) - if conn:ssl() then - session.secure = true; - end - - if opt_keepalives ~= nil then - conn:setoption("keepalive", opt_keepalives); +function xmppclient.onconnect(conn) + local session = sm_new_session(conn); + sessions[conn] = session; + + session.log("info", "Client connected"); + + -- Client is using legacy SSL (otherwise mod_tls sets this flag) + if conn:ssl() then + session.secure = true; + end + + if opt_keepalives ~= nil then + conn:setoption("keepalive", opt_keepalives); + end + + session.close = session_close; + + local stream = new_xmpp_stream(session, stream_callbacks); + session.stream = stream; + + session.notopen = true; + + function session.reset_stream() + session.notopen = true; + session.stream:reset(); + end + + local filter = session.filter; + function session.data(data) + data = filter("bytes/in", data); + if data then + local ok, err = stream:feed(data); + if ok then return; end + log("debug", "Received invalid XML (%s) %d bytes: %s", tostring(err), #data, data:sub(1, 300):gsub("[\r\n]+", " "):gsub("[%z\1-\31]", "_")); + session:close("not-well-formed"); end - - session.reset_stream = session_reset_stream; - session.close = session_close; - - session_reset_stream(session); -- Initialise, ready for use - - session.dispatch_stanza = stream_callbacks.handlestanza; end - if data then - session.data(conn, data); + + local handlestanza = stream_callbacks.handlestanza; + function session.dispatch_stanza(session, stanza) + return handlestanza(session, stanza); + end +end + +function xmppclient.onincoming(conn, data) + local session = sessions[conn]; + if session then + session.data(data); end end @@ -167,4 +172,8 @@ function xmppclient.ondisconnect(conn, err) end end +function xmppclient.associate_session(conn, session) + sessions[conn] = session; +end + connlisteners_register("xmppclient", xmppclient); diff --git a/net/xmppcomponent_listener.lua b/net/xmppcomponent_listener.lua index b87f7c96..90293559 100644 --- a/net/xmppcomponent_listener.lua +++ b/net/xmppcomponent_listener.lua @@ -10,17 +10,19 @@ local hosts = _G.hosts; local t_concat = table.concat; +local tostring = tostring; +local type = type; +local pairs = pairs; local lxp = require "lxp"; local logger = require "util.logger"; local config = require "core.configmanager"; local connlisteners = require "net.connlisteners"; -local cm_register_component = require "core.componentmanager".register_component; -local cm_deregister_component = require "core.componentmanager".deregister_component; local uuid_gen = require "util.uuid".generate; +local jid_split = require "util.jid".split; local sha1 = require "util.hashes".sha1; local st = require "util.stanza"; -local init_xmlhandlers = require "core.xmlhandlers"; +local new_xmpp_stream = require "util.xmppstream".new; local sessions = {}; @@ -30,7 +32,7 @@ local component_listener = { default_port = 5347; default_mode = "*a"; default_i local xmlns_component = 'jabber:component:accept'; ---- Callbacks/data for xmlhandlers to handle streams for us --- +--- Callbacks/data for xmppstream to handle streams for us --- local stream_callbacks = { default_ns = xmlns_component }; @@ -43,7 +45,7 @@ function stream_callbacks.error(session, error, data, data2) session:close("invalid-namespace"); elseif error == "parse-error" then session.log("warn", "External component %s XML parse error: %s", tostring(session.host), tostring(data)); - session:close("xml-not-well-formed"); + session:close("not-well-formed"); elseif error == "stream-error" then local condition, text = "undefined-condition"; for child in data:children() do @@ -66,19 +68,16 @@ end function stream_callbacks.streamopened(session, attr) if config.get(attr.to, "core", "component_module") ~= "component" then - -- Trying to act as a component domain which + -- Trying to act as a component domain which -- hasn't been configured session:close{ condition = "host-unknown", text = tostring(attr.to).." does not match any configured external components" }; return; end - -- Store the original host (this is used for config, etc.) - session.user = attr.to; - -- Set the host for future reference - session.host = config.get(attr.to, "core", "component_address") or attr.to; - -- Note that we don't create the internal component + -- Note that we don't create the internal component -- until after the external component auths successfully + session.host = attr.to; session.streamid = uuid_gen(); session.notopen = nil; @@ -88,7 +87,7 @@ function stream_callbacks.streamopened(session, attr) end function stream_callbacks.streamclosed(session) - session.log("Received </stream:stream>"); + session.log("debug", "Received </stream:stream>"); session:close(); end @@ -99,6 +98,31 @@ function stream_callbacks.handlestanza(session, stanza) if not stanza.attr.xmlns and stanza.name == "handshake" then stanza.attr.xmlns = xmlns_component; end + if not stanza.attr.xmlns or stanza.attr.xmlns == "jabber:client" then + local from = stanza.attr.from; + if from then + if session.component_validate_from then + local _, domain = jid_split(stanza.attr.from); + if domain ~= session.host then + -- Return error + session.log("warn", "Component sent stanza with missing or invalid 'from' address"); + session:close{ + condition = "invalid-from"; + text = "Component tried to send from address <"..tostring(from) + .."> which is not in domain <"..tostring(session.host)..">"; + }; + return; + end + end + else + stanza.attr.from = session.host; + end + if not stanza.attr.to then + session.log("warn", "Rejecting stanza with no 'to' address"); + session.send(st.error_reply(stanza, "modify", "bad-request", "Components MUST specify a 'to' address on stanzas")); + return; + end + end return core_process_stanza(session, stanza); end @@ -141,51 +165,48 @@ local function session_close(session, reason) end --- Component connlistener -function component_listener.onincoming(conn, data) - local session = sessions[conn]; - if not session then - local _send = conn.write; - session = { type = "component", conn = conn, send = function (data) return _send(conn, tostring(data)); end }; - sessions[conn] = session; - - -- Logging functions -- - - local conn_name = "jcp"..tostring(conn):match("[a-f0-9]+$"); - session.log = logger.init(conn_name); - session.close = session_close; - - session.log("info", "Incoming Jabber component connection"); - - local parser = lxp.new(init_xmlhandlers(session, stream_callbacks), "\1"); - session.parser = parser; - +function component_listener.onconnect(conn) + local _send = conn.write; + local session = { type = "component", conn = conn, send = function (data) return _send(conn, tostring(data)); end }; + + -- Logging functions -- + local conn_name = "jcp"..tostring(conn):match("[a-f0-9]+$"); + session.log = logger.init(conn_name); + session.close = session_close; + + session.log("info", "Incoming Jabber component connection"); + + local stream = new_xmpp_stream(session, stream_callbacks); + session.stream = stream; + + session.notopen = true; + + function session.reset_stream() session.notopen = true; - - function session.data(conn, data) - local ok, err = parser:parse(data); - if ok then return; end - log("debug", "Received invalid XML (%s) %d bytes: %s", tostring(err), #data, data:sub(1, 300):gsub("[\r\n]+", " "):gsub("[%z\1-\31]", "_")); - session:close("xml-not-well-formed"); - end - - session.dispatch_stanza = stream_callbacks.handlestanza; - + session.stream:reset(); end - if data then - session.data(conn, data); + + function session.data(conn, data) + local ok, err = stream:feed(data); + if ok then return; end + log("debug", "Received invalid XML (%s) %d bytes: %s", tostring(err), #data, data:sub(1, 300):gsub("[\r\n]+", " "):gsub("[%z\1-\31]", "_")); + session:close("not-well-formed"); end -end + session.dispatch_stanza = stream_callbacks.handlestanza; + + sessions[conn] = session; +end +function component_listener.onincoming(conn, data) + local session = sessions[conn]; + session.data(conn, data); +end function component_listener.ondisconnect(conn, err) local session = sessions[conn]; if session then (session.log or log)("info", "component disconnected: %s (%s)", tostring(session.host), tostring(err)); - if session.host then - log("debug", "Deregistering component"); - cm_deregister_component(session.host); - hosts[session.host].connected = nil; - end - sessions[conn] = nil; + if session.on_destroy then session:on_destroy(err); end + sessions[conn] = nil; for k in pairs(session) do if k ~= "log" and k ~= "close" then session[k] = nil; diff --git a/net/xmppserver_listener.lua b/net/xmppserver_listener.lua index d1272edb..3af0b962 100644 --- a/net/xmppserver_listener.lua +++ b/net/xmppserver_listener.lua @@ -7,11 +7,17 @@ -- +local tostring = tostring; +local type = type; +local xpcall = xpcall; +local s_format = string.format; +local traceback = debug.traceback; local logger = require "logger"; local log = logger.init("xmppserver_listener"); -local lxp = require "lxp" -local init_xmlhandlers = require "core.xmlhandlers" +local st = require "util.stanza"; +local connlisteners_register = require "net.connlisteners".register; +local new_xmpp_stream = require "util.xmppstream".new; local s2s_new_incoming = require "core.s2smanager".new_incoming; local s2s_streamopened = require "core.s2smanager".streamopened; local s2s_streamclosed = require "core.s2smanager".streamclosed; @@ -27,7 +33,7 @@ function stream_callbacks.error(session, error, data) session:close("invalid-namespace"); elseif error == "parse-error" then session.log("debug", "Server-to-server XML parse error: %s", tostring(error)); - session:close("xml-not-well-formed"); + session:close("not-well-formed"); elseif error == "stream-error" then local condition, text = "undefined-condition"; for child in data:children() do @@ -48,48 +54,22 @@ function stream_callbacks.error(session, error, data) end end -local function handleerr(err) log("error", "Traceback[s2s]: %s: %s", tostring(err), debug.traceback()); end -function stream_callbacks.handlestanza(a, b) - if b.attr.xmlns == "jabber:client" then --COMPAT: Prosody pre-0.6.2 may send jabber:client - b.attr.xmlns = nil; +local function handleerr(err) log("error", "Traceback[s2s]: %s: %s", tostring(err), traceback()); end +function stream_callbacks.handlestanza(session, stanza) + if stanza.attr.xmlns == "jabber:client" then --COMPAT: Prosody pre-0.6.2 may send jabber:client + stanza.attr.xmlns = nil; + end + stanza = session.filter("stanzas/in", stanza); + if stanza then + return xpcall(function () return core_process_stanza(session, stanza) end, handleerr); end - xpcall(function () core_process_stanza(a, b) end, handleerr); end -local connlisteners_register = require "net.connlisteners".register; - -local t_insert = table.insert; -local t_concat = table.concat; -local t_concatall = function (t, sep) local tt = {}; for _, s in ipairs(t) do t_insert(tt, tostring(s)); end return t_concat(tt, sep); end -local m_random = math.random; -local format = string.format; -local sessionmanager = require "core.sessionmanager"; -local sm_new_session, sm_destroy_session = sessionmanager.new_session, sessionmanager.destroy_session; -local st = require "util.stanza"; - local sessions = {}; local xmppserver = { default_port = 5269, default_mode = "*a" }; -- These are session methods -- -local function session_reset_stream(session) - -- Reset stream - local parser = lxp.new(init_xmlhandlers(session, stream_callbacks), "\1"); - session.parser = parser; - - session.notopen = true; - - function session.data(conn, data) - local ok, err = parser:parse(data); - if ok then return; end - (session.log or log)("warn", "Received invalid XML: %s", data); - (session.log or log)("warn", "Problem was: %s", err); - session:close("xml-not-well-formed"); - end - - return true; -end - local stream_xmlns_attr = {xmlns='urn:ietf:params:xml:ns:xmpp-streams'}; local default_stream_attr = { ["xmlns:stream"] = "http://etherx.jabber.org/streams", xmlns = stream_callbacks.default_ns, version = "1.0", id = "" }; local function session_close(session, reason, remote_reason) @@ -132,29 +112,55 @@ end -- End of session methods -- -function xmppserver.onincoming(conn, data) - local session = sessions[conn]; - if not session then - session = s2s_new_incoming(conn); - sessions[conn] = session; +local function initialize_session(session) + local stream = new_xmpp_stream(session, stream_callbacks); + session.stream = stream; + + session.notopen = true; + + function session.reset_stream() + session.notopen = true; + session.stream:reset(); + end + + local filter = session.filter; + function session.data(data) + data = filter("bytes/in", data); + if data then + local ok, err = stream:feed(data); + if ok then return; end + (session.log or log)("warn", "Received invalid XML: %s", data); + (session.log or log)("warn", "Problem was: %s", err); + session:close("not-well-formed"); + end + end - -- Logging functions -- + session.close = session_close; + local handlestanza = stream_callbacks.handlestanza; + function session.dispatch_stanza(session, stanza) + return handlestanza(session, stanza); + end +end - +function xmppserver.onconnect(conn) + if not sessions[conn] then -- May be an existing outgoing session + local session = s2s_new_incoming(conn); + sessions[conn] = session; + + -- Logging functions -- local conn_name = "s2sin"..tostring(conn):match("[a-f0-9]+$"); session.log = logger.init(conn_name); session.log("info", "Incoming s2s connection"); - session.reset_stream = session_reset_stream; - session.close = session_close; - - session_reset_stream(session); -- Initialise, ready for use - - session.dispatch_stanza = stream_callbacks.handlestanza; + initialize_session(session); end - if data then - session.data(conn, data); +end + +function xmppserver.onincoming(conn, data) + local session = sessions[conn]; + if session then + session.data(data); end end @@ -162,9 +168,9 @@ function xmppserver.onstatus(conn, status) if status == "ssl-handshake-complete" then local session = sessions[conn]; if session and session.direction == "outgoing" then - local format, to_host, from_host = string.format, session.to_host, session.from_host; + local to_host, from_host = session.to_host, session.from_host; session.log("debug", "Sending stream header..."); - session.sends2s(format([[<stream:stream xmlns='jabber:server' xmlns:db='jabber:server:dialback' xmlns:stream='http://etherx.jabber.org/streams' from='%s' to='%s' version='1.0'>]], from_host, to_host)); + session.sends2s(s_format([[<stream:stream xmlns='jabber:server' xmlns:db='jabber:server:dialback' xmlns:stream='http://etherx.jabber.org/streams' from='%s' to='%s' version='1.0'>]], from_host, to_host)); end end end @@ -190,12 +196,7 @@ function xmppserver.register_outgoing(conn, session) session.direction = "outgoing"; sessions[conn] = session; - session.reset_stream = session_reset_stream; - session.close = session_close; - session_reset_stream(session); -- Initialise, ready for use - - --local function handleerr(err) print("Traceback:", err, debug.traceback()); end - --session.stanza_dispatch = function (stanza) return select(2, xpcall(function () return core_process_stanza(session, stanza); end, handleerr)); end + initialize_session(session); end connlisteners_register("xmppserver", xmppserver); |