aboutsummaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
Diffstat (limited to 'net')
-rw-r--r--net/adns.lua28
-rw-r--r--net/connlisteners.lua8
-rw-r--r--net/dns.lua183
-rw-r--r--net/http.lua123
-rw-r--r--net/httpserver.lua127
-rw-r--r--net/multiplex_listener.lua4
-rw-r--r--net/server.lua2
-rw-r--r--net/server_event.lua43
-rw-r--r--net/server_select.lua78
-rw-r--r--net/xmppclient_listener.lua115
-rw-r--r--net/xmppcomponent_listener.lua119
-rw-r--r--net/xmppserver_listener.lua121
12 files changed, 491 insertions, 460 deletions
diff --git a/net/adns.lua b/net/adns.lua
index 88d4b4b3..cd69a627 100644
--- a/net/adns.lua
+++ b/net/adns.lua
@@ -26,22 +26,26 @@ function lookup(handler, qname, qtype, qclass)
return;
end
log("debug", "Records for %s not in cache, sending query (%s)...", qname, tostring(coroutine.running()));
- dns.query(qname, qtype, qclass);
- coroutine.yield({ qclass or "IN", qtype or "A", qname, coroutine.running()}); -- Wait for reply
- log("debug", "Reply for %s (%s)", qname, tostring(coroutine.running()));
- local ok, err = pcall(handler, dns.peek(qname, qtype, qclass));
+ local ok, err = dns.query(qname, qtype, qclass);
+ if ok then
+ coroutine.yield({ qclass or "IN", qtype or "A", qname, coroutine.running()}); -- Wait for reply
+ log("debug", "Reply for %s (%s)", qname, tostring(coroutine.running()));
+ end
+ if ok then
+ ok, err = pcall(handler, dns.peek(qname, qtype, qclass));
+ else
+ log("error", "Error sending DNS query: %s", err);
+ ok, err = pcall(handler, nil, err);
+ end
if not ok then
log("error", "Error in DNS response handler: %s", tostring(err));
end
end)(dns.peek(qname, qtype, qclass));
end
-function cancel(handle, call_handler)
+function cancel(handle, call_handler, reason)
log("warn", "Cancelling DNS lookup for %s", tostring(handle[3]));
- dns.cancel(handle);
- if call_handler then
- coroutine.resume(handle[4]);
- end
+ dns.cancel(handle[1], handle[2], handle[3], handle[4], call_handler);
end
function new_async_socket(sock, resolver)
@@ -74,7 +78,11 @@ function new_async_socket(sock, resolver)
handler.setpeername = function (_, ...) peername = (...); local ret = sock:setpeername(...); _:set_send(dummy_send); return ret; end
handler.connect = function (_, ...) return sock:connect(...) end
--handler.send = function (_, data) _:write(data); return _.sendbuffer and _.sendbuffer(); end
- handler.send = function (_, data) return sock:send(data); end
+ handler.send = function (_, data)
+ local getpeername = sock.getpeername;
+ log("debug", "Sending DNS query to %s", (getpeername and getpeername(sock)) or "<unconnected>");
+ return sock:send(data);
+ end
return handler;
end
diff --git a/net/connlisteners.lua b/net/connlisteners.lua
index 93dce8b3..7da25c62 100644
--- a/net/connlisteners.lua
+++ b/net/connlisteners.lua
@@ -13,8 +13,10 @@ local server = require "net.server";
local log = require "util.logger".init("connlisteners");
local tostring = tostring;
-local dofile, pcall, error =
- dofile, pcall, error
+local dofile, xpcall, error =
+ dofile, xpcall, error
+
+local debug_traceback = debug.traceback;
module "connlisteners"
@@ -37,7 +39,7 @@ end
function get(name)
local h = listeners[name];
if not h then
- local ok, ret = pcall(dofile, listeners_dir..name:gsub("[^%w%-]", "_").."_listener.lua");
+ local ok, ret = xpcall(function() dofile(listeners_dir..name:gsub("[^%w%-]", "_").."_listener.lua") end, debug_traceback);
if not ok then
log("error", "Error while loading listener '%s': %s", tostring(name), tostring(ret));
return nil, ret;
diff --git a/net/dns.lua b/net/dns.lua
index c0de97fd..c905f56c 100644
--- a/net/dns.lua
+++ b/net/dns.lua
@@ -2,8 +2,6 @@
-- This file is included with Prosody IM. It has modifications,
-- which are hereby placed in the public domain.
--- public domain 20080404 lua@ztact.com
-
-- todo: quick (default) header generation
-- todo: nxdomain, error handling
@@ -15,18 +13,61 @@
local socket = require "socket";
-local ztact = require "util.ztact";
+local timer = require "util.timer";
+
local _, windows = pcall(require, "util.windows");
local is_windows = (_ and windows) or os.getenv("WINDIR");
local coroutine, io, math, string, table =
coroutine, io, math, string, table;
-local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack =
- ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack;
+local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack, select, type=
+ ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack, select, type;
+local ztact = { -- public domain 20080404 lua@ztact.com
+ get = function(parent, ...)
+ local len = select('#', ...);
+ for i=1,len do
+ parent = parent[select(i, ...)];
+ if parent == nil then break; end
+ end
+ return parent;
+ end;
+ set = function(parent, ...)
+ local len = select('#', ...);
+ local key, value = select(len-1, ...);
+ local cutpoint, cutkey;
+
+ for i=1,len-2 do
+ local key = select (i, ...)
+ local child = parent[key]
+
+ if value == nil then
+ if child == nil then
+ return;
+ elseif next(child, next(child)) then
+ cutpoint = nil; cutkey = nil;
+ elseif cutpoint == nil then
+ cutpoint = parent; cutkey = key;
+ end
+ elseif child == nil then
+ child = {};
+ parent[key] = child;
+ end
+ parent = child
+ end
+
+ if value == nil and cutpoint then
+ cutpoint[cutkey] = nil;
+ else
+ parent[key] = value;
+ return value;
+ end
+ end;
+};
local get, set = ztact.get, ztact.set;
+local default_timeout = 15;
-------------------------------------------------- module dns
module('dns')
@@ -115,32 +156,31 @@ end
local resolver = {};
resolver.__index = resolver;
+resolver.timeout = default_timeout;
-local SRV_tostring;
-
+local function default_rr_tostring(rr)
+ local rr_val = rr.type and rr[rr.type:lower()];
+ if type(rr_val) ~= "string" then
+ return "<UNKNOWN RDATA TYPE>";
+ end
+ return rr_val;
+end
+
+local special_tostrings = {
+ LOC = resolver.LOC_tostring;
+ MX = function (rr)
+ return string.format('%2i %s', rr.pref, rr.mx);
+ end;
+ SRV = function (rr)
+ local s = rr.srv;
+ return string.format('%5d %5d %5d %s', s.priority, s.weight, s.port, s.target);
+ end;
+};
local rr_metatable = {}; -- - - - - - - - - - - - - - - - - - - rr_metatable
function rr_metatable.__tostring(rr)
- local s0 = string.format('%2s %-5s %6i %-28s', rr.class, rr.type, rr.ttl, rr.name);
- local s1 = '';
- if rr.type == 'A' then
- s1 = ' '..rr.a;
- elseif rr.type == 'MX' then
- s1 = string.format(' %2i %s', rr.pref, rr.mx);
- elseif rr.type == 'CNAME' then
- s1 = ' '..rr.cname;
- elseif rr.type == 'LOC' then
- s1 = ' '..resolver.LOC_tostring(rr);
- elseif rr.type == 'NS' then
- s1 = ' '..rr.ns;
- elseif rr.type == 'SRV' then
- s1 = ' '..SRV_tostring(rr);
- elseif rr.type == 'TXT' then
- s1 = ' '..rr.txt;
- else
- s1 = ' <UNKNOWN RDATA TYPE>';
- end
- return s0..s1;
+ local rr_string = (special_tostrings[rr.type] or default_rr_tostring)(rr);
+ return string.format('%2s %-5s %6i %-28s %s', rr.class, rr.type, rr.ttl, rr.name, rr_string);
end
@@ -434,13 +474,10 @@ function resolver:SRV(rr) -- - - - - - - - - - - - - - - - - - - - - - SRV
rr.srv.target = self:name();
end
-
-function SRV_tostring(rr) -- - - - - - - - - - - - - - - - - - SRV_tostring
- local s = rr.srv;
- return string.format( '%5d %5d %5d %s', s.priority, s.weight, s.port, s.target );
+function resolver:PTR(rr)
+ rr.ptr = self:name();
end
-
function resolver:TXT(rr) -- - - - - - - - - - - - - - - - - - - - - - TXT
rr.txt = self:sub (rr.rdlength);
end
@@ -524,7 +561,7 @@ end
function resolver:adddefaultnameservers() -- - - - - adddefaultnameservers
if is_windows then
- if windows then
+ if windows and windows.get_nameservers then
for _, server in ipairs(windows.get_nameservers()) do
self:addnameserver(server);
end
@@ -562,7 +599,11 @@ function resolver:getsocket(servernum) -- - - - - - - - - - - - - getsocket
local sock = self.socket[servernum];
if sock then return sock; end
- sock = socket.udp();
+ local err;
+ sock, err = socket.udp();
+ if not sock then
+ return nil, err;
+ end
if self.socket_wrapper then sock = self.socket_wrapper(sock, self); end
sock:settimeout(0);
-- todo: attempt to use a random port, fallback to 0
@@ -667,18 +708,44 @@ function resolver:query(qname, qtype, qclass) -- - - - - - - - - - -- query
retry = socket.gettime() + self.delays[1]
};
- -- remember the query
+ -- remember the query
self.active[id] = self.active[id] or {};
self.active[id][question] = o;
- -- remember which coroutine wants the answer
+ -- remember which coroutine wants the answer
local co = coroutine.running();
if co then
set(self.wanted, qclass, qtype, qname, co, true);
--set(self.yielded, co, qclass, qtype, qname, true);
end
- self:getsocket (o.server):send (o.packet)
+ local conn, err = self:getsocket(o.server)
+ if not conn then
+ return nil, err;
+ end
+ conn:send (o.packet)
+
+ if timer and self.timeout then
+ local num_servers = #self.server;
+ local i = 1;
+ timer.add_task(self.timeout, function ()
+ if get(self.wanted, qclass, qtype, qname, co) then
+ if i < num_servers then
+ i = i + 1;
+ self:servfail(conn);
+ o.server = self.best_server;
+ conn, err = self:getsocket(o.server);
+ if conn then
+ conn:send(o.packet);
+ return self.timeout;
+ end
+ end
+ -- Tried everything, failed
+ self:cancel(qclass, qtype, qname, co, true);
+ end
+ end)
+ end
+ return true;
end
function resolver:servfail(sock)
@@ -710,7 +777,7 @@ function resolver:servfail(sock)
end
end
end
-
+
if num == self.best_server then
self.best_server = self.best_server + 1;
if self.best_server > #self.server then
@@ -720,6 +787,10 @@ function resolver:servfail(sock)
end
end
+function resolver:settimeout(seconds)
+ self.timeout = seconds;
+end
+
function resolver:receive(rset) -- - - - - - - - - - - - - - - - - receive
--print('receive'); print(self.socket);
self.time = socket.gettime();
@@ -769,11 +840,11 @@ function resolver:receive(rset) -- - - - - - - - - - - - - - - - - receive
end
-function resolver:feed(sock, packet)
+function resolver:feed(sock, packet, force)
--print('receive'); print(self.socket);
self.time = socket.gettime();
- local response = self:decode(packet);
+ local response = self:decode(packet, force);
if response and self.active[response.header.id]
and self.active[response.header.id][response.question.raw] then
--print('received response');
@@ -806,10 +877,13 @@ function resolver:feed(sock, packet)
return response;
end
-function resolver:cancel(data)
- local cos = get(self.wanted, unpack(data, 1, 3));
+function resolver:cancel(qclass, qtype, qname, co, call_handler)
+ local cos = get(self.wanted, qclass, qtype, qname);
if cos then
- cos[data[4]] = nil;
+ if call_handler then
+ coroutine.resume(co);
+ end
+ cos[co] = nil;
end
end
@@ -852,12 +926,12 @@ end
function resolver:lookup(qname, qtype, qclass) -- - - - - - - - - - lookup
self:query (qname, qtype, qclass)
while self:pulse() do
- local recvt = {}
- for i, s in ipairs(self.socket) do
- recvt[i] = s
- end
- socket.select(recvt, nil, 4)
- end
+ local recvt = {}
+ for i, s in ipairs(self.socket) do
+ recvt[i] = s
+ end
+ socket.select(recvt, nil, 4)
+ end
--print(self.cache);
return self:peek(qname, qtype, qclass);
end
@@ -866,6 +940,9 @@ function resolver:lookupex(handler, qname, qtype, qclass) -- - - - - - - - -
return self:peek(qname, qtype, qclass) or self:query(qname, qtype, qclass);
end
+function resolver:tohostname(ip)
+ return dns.lookup(ip:gsub("(%d+)%.(%d+)%.(%d+)%.(%d+)", "%4.%3.%2.%1.in-addr.arpa."), "PTR");
+end
--print ---------------------------------------------------------------- print
@@ -941,6 +1018,10 @@ function dns.lookup(...) -- - - - - - - - - - - - - - - - - - - - - lookup
return _resolver:lookup(...);
end
+function dns.tohostname(...)
+ return _resolver:tohostname(...);
+end
+
function dns.purge(...) -- - - - - - - - - - - - - - - - - - - - - - purge
return _resolver:purge(...);
end
@@ -961,6 +1042,10 @@ function dns.cancel(...) -- - - - - - - - - - - - - - - - - - - - - - cancel
return _resolver:cancel(...);
end
+function dns.settimeout(...)
+ return _resolver:settimeout(...);
+end
+
function dns.socket_wrapper_set(...) -- - - - - - - - - socket_wrapper_set
return _resolver:socket_wrapper_set(...);
end
diff --git a/net/http.lua b/net/http.lua
index 0634d773..6c8e0a68 100644
--- a/net/http.lua
+++ b/net/http.lua
@@ -10,6 +10,7 @@
local socket = require "socket"
local mime = require "mime"
local url = require "socket.url"
+local httpstream_new = require "util.httpstream".new;
local server = require "net.server"
@@ -17,8 +18,9 @@ local connlisteners_get = require "net.connlisteners".get;
local listener = connlisteners_get("httpclient") or error("No httpclient listener!");
local t_insert, t_concat = table.insert, table.concat;
-local tonumber, tostring, pairs, xpcall, select, debug_traceback, char, format =
- tonumber, tostring, pairs, xpcall, select, debug.traceback, string.char, string.format;
+local pairs, ipairs = pairs, ipairs;
+local tonumber, tostring, xpcall, select, debug_traceback, char, format =
+ tonumber, tostring, xpcall, select, debug.traceback, string.char, string.format;
local log = require "util.logger".init("http");
@@ -27,107 +29,46 @@ module "http"
function urlencode(s) return s and (s:gsub("%W", function (c) return format("%%%02x", c:byte()); end)); end
function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return char(tonumber(c,16)); end)); end
-local function expectbody(reqt, code)
- if reqt.method == "HEAD" then return nil end
- if code == 204 or code == 304 or code == 301 then return nil end
- if code >= 100 and code < 200 then return nil end
- return 1
+local function _formencodepart(s)
+ return s and (s:gsub("%W", function (c)
+ if c ~= " " then
+ return format("%%%02x", c:byte());
+ else
+ return "+";
+ end
+ end));
+end
+function formencode(form)
+ local result = {};
+ for _, field in ipairs(form) do
+ t_insert(result, _formencodepart(field.name).."=".._formencodepart(field.value));
+ end
+ return t_concat(result, "&");
end
local function request_reader(request, data, startpos)
- if not data then
- if request.body then
- log("debug", "Connection closed, but we have data, calling callback...");
- request.callback(t_concat(request.body), request.code, request);
- elseif request.state ~= "completed" then
- -- Error.. connection was closed prematurely
- request.callback("connection-closed", 0, request);
- return;
- end
- destroy_request(request);
- request.body = nil;
- request.state = "completed";
- return;
- end
- if request.state == "body" and request.state ~= "completed" then
- log("debug", "Reading body...")
- if not request.body then request.body = {}; request.havebodylength, request.bodylength = 0, tonumber(request.responseheaders["content-length"]); end
- if startpos then
- data = data:sub(startpos, -1)
- end
- t_insert(request.body, data);
- if request.bodylength then
- request.havebodylength = request.havebodylength + #data;
- if request.havebodylength >= request.bodylength then
- -- We have the body
- log("debug", "Have full body, calling callback");
- if request.callback then
- request.callback(t_concat(request.body), request.code, request);
- end
- request.body = nil;
- request.state = "completed";
- else
- log("debug", "Have "..request.havebodylength.." bytes out of "..request.bodylength);
- end
- end
- elseif request.state == "headers" then
- log("debug", "Reading headers...")
- local pos = startpos;
- local headers, headers_complete = request.responseheaders;
- if not headers then
- headers = {};
- request.responseheaders = headers;
- end
- for line in data:sub(startpos, -1):gmatch("(.-)\r\n") do
- startpos = startpos + #line + 2;
- local k, v = line:match("(%S+): (.+)");
- if k and v then
- headers[k:lower()] = v;
- --log("debug", "Header: "..k:lower().." = "..v);
- elseif #line == 0 then
- headers_complete = true;
- break;
- else
- log("warn", "Unhandled header line: "..line);
+ if not request.parser then
+ local function success_cb(r)
+ if request.callback then
+ for k,v in pairs(r) do request[k] = v; end
+ request.callback(r.body, r.code, request);
+ request.callback = nil;
end
- end
- if not headers_complete then return; end
- -- Reached the end of the headers
- if not expectbody(request, request.code) then
- request.callback(nil, request.code, request);
- return;
- end
- request.state = "body";
- if #data > startpos then
- return request_reader(request, data, startpos);
- end
- elseif request.state == "status" then
- log("debug", "Reading status...")
- local http, code, text, linelen = data:match("^HTTP/(%S+) (%d+) (.-)\r\n()", startpos);
- code = tonumber(code);
- if not code then
- log("warn", "Invalid HTTP status line, telling callback then closing");
- local ret = request.callback("invalid-status-line", 0, request);
destroy_request(request);
- return ret;
end
-
- request.code, request.responseversion = code, http;
-
- if request.onlystatus then
+ local function error_cb(r)
if request.callback then
- request.callback(nil, code, request);
+ request.callback(r or "connection-closed", 0, request);
+ request.callback = nil;
end
destroy_request(request);
- return;
end
-
- request.state = "headers";
-
- if #data > linelen then
- return request_reader(request, data, linelen);
+ local function options_cb()
+ return request;
end
+ request.parser = httpstream_new(success_cb, error_cb, "client", options_cb);
end
+ request.parser:feed(data);
end
local function handleerr(err) log("error", "Traceback[http]: %s: %s", tostring(err), debug_traceback()); end
diff --git a/net/httpserver.lua b/net/httpserver.lua
index 59ddbb12..74f61c56 100644
--- a/net/httpserver.lua
+++ b/net/httpserver.lua
@@ -7,19 +7,20 @@
--
-local socket = require "socket"
local server = require "net.server"
local url_parse = require "socket.url".parse;
+local httpstream_new = require "util.httpstream".new;
local connlisteners_start = require "net.connlisteners".start;
local connlisteners_get = require "net.connlisteners".get;
local listener;
local t_insert, t_concat = table.insert, table.concat;
-local s_match, s_gmatch = string.match, string.gmatch;
local tonumber, tostring, pairs, ipairs, type = tonumber, tostring, pairs, ipairs, type;
+local xpcall = xpcall;
+local debug_traceback = debug.traceback;
-local urlencode = function (s) return s and (s:gsub("%W", function (c) return string.format("%%%02x", c:byte()); end)); end
+local urlencode = function (s) return s and (s:gsub("%W", function (c) return ("%%%02x"):format(c:byte()); end)); end
local log = require "util.logger".init("httpserver");
@@ -29,10 +30,6 @@ module "httpserver"
local default_handler;
-local function expectbody(reqt)
- return reqt.method == "POST";
-end
-
local function send_response(request, response)
-- Write status line
local resp;
@@ -87,6 +84,22 @@ local function call_callback(request, err)
callback = (request.server and request.server.handlers[base]) or default_handler;
end
if callback then
+ local _callback = callback;
+ function callback(method, body, request)
+ local ok, result = xpcall(function() return _callback(method, body, request) end, debug_traceback);
+ if ok then return result; end
+ log("error", "Error in HTTP server handler: %s", result);
+ -- TODO: When we support pipelining, request.destroyed
+ -- won't be the right flag - we just want to see if there
+ -- has been a response to this request yet.
+ if not request.destroyed then
+ return {
+ status = "500 Internal Server Error";
+ headers = { ["Content-Type"] = "text/plain" };
+ body = "There was an error processing your request. See the error log for more details.";
+ };
+ end
+ end
if err then
log("debug", "Request error: "..err);
if not callback(nil, err, request) then
@@ -114,94 +127,21 @@ local function call_callback(request, err)
end
local function request_reader(request, data, startpos)
- if not data then
- if request.body then
- call_callback(request);
- else
- -- Error.. connection was closed prematurely
- call_callback(request, "connection-closed");
- end
- -- Here we force a destroy... the connection is gone, so we can't reply later
- destroy_request(request);
- return;
- end
- if request.state == "body" then
- log("debug", "Reading body...")
- if not request.body then request.body = {}; request.havebodylength, request.bodylength = 0, tonumber(request.headers["content-length"]); end
- if startpos then
- data = data:sub(startpos, -1)
- end
- t_insert(request.body, data);
- if request.bodylength then
- request.havebodylength = request.havebodylength + #data;
- if request.havebodylength >= request.bodylength then
- -- We have the body
- call_callback(request);
- end
- end
- elseif request.state == "headers" then
- log("debug", "Reading headers...")
- local pos = startpos;
- local headers, headers_complete = request.headers;
- if not headers then
- headers = {};
- request.headers = headers;
- end
-
- for line in data:gmatch("(.-)\r\n") do
- startpos = (startpos or 1) + #line + 2;
- local k, v = line:match("(%S+): (.+)");
- if k and v then
- headers[k:lower()] = v;
- --log("debug", "Header: '"..k:lower().."' = '"..v.."'");
- elseif #line == 0 then
- headers_complete = true;
- break;
- else
- log("debug", "Unhandled header line: "..line);
- end
- end
-
- if not headers_complete then return; end
-
- if not expectbody(request) then
+ if not request.parser then
+ local function success_cb(r)
+ for k,v in pairs(r) do request[k] = v; end
+ request.url = url_parse(request.path);
+ request.url.path = request.url.path and request.url.path:gsub("%%(%x%x)", function(x) return x.char(tonumber(x, 16)) end);
+ request.body = { request.body };
call_callback(request);
- return;
- end
-
- -- Reached the end of the headers
- request.state = "body";
- if #data > startpos then
- return request_reader(request, data:sub(startpos, -1));
- end
- elseif request.state == "request" then
- log("debug", "Reading request line...")
- local method, path, http, linelen = data:match("^(%S+) (%S+) HTTP/(%S+)\r\n()", startpos);
- if not method then
- log("warn", "Invalid HTTP status line, telling callback then closing");
- local ret = call_callback(request, "invalid-status-line");
- request:destroy();
- return ret;
end
-
- request.method, request.path, request.httpversion = method, path, http;
-
- request.url = url_parse(request.path);
-
- log("debug", method.." request for "..tostring(request.path) .. " on port "..request.handler:serverport());
-
- if request.onlystatus then
- if not call_callback(request) then
- return;
- end
- end
-
- request.state = "headers";
-
- if #data > linelen then
- return request_reader(request, data:sub(linelen, -1));
+ local function error_cb(r)
+ call_callback(request, r or "connection-closed");
+ destroy_request(request);
end
+ request.parser = httpstream_new(success_cb, error_cb);
end
+ request.parser:feed(data);
end
-- The default handler for requests
@@ -263,6 +203,7 @@ function new_from_config(ports, handle_request, default_options)
log("warn", "Old syntax of httpserver.new_from_config being used to register %s", handle_request);
handle_request, default_options = default_options, { base = handle_request };
end
+ ports = ports or {5280};
for _, options in ipairs(ports) do
local port = default_options.port or 5280;
local base = default_options.base;
@@ -285,8 +226,8 @@ function new_from_config(ports, handle_request, default_options)
ssl.options = "no_sslv2";
end
- new{ port = port, interface = interface,
- base = base, handler = handle_request,
+ new{ port = port, interface = interface,
+ base = base, handler = handle_request,
ssl = ssl, type = (ssl and "ssl") or "tcp" };
end
end
diff --git a/net/multiplex_listener.lua b/net/multiplex_listener.lua
index bf193ad8..b515ccce 100644
--- a/net/multiplex_listener.lua
+++ b/net/multiplex_listener.lua
@@ -19,6 +19,8 @@ function server.onincoming(conn, data)
if buf:match("^[a-zA-Z]") then
local listener = httpserver_listener;
conn:setlistener(listener);
+ local onconnect = listener.onconnect;
+ if onconnect then onconnect(conn) end
listener.onincoming(conn, buf);
elseif buf:match(">") then
local listener;
@@ -31,6 +33,8 @@ function server.onincoming(conn, data)
listener = xmppclient_listener;
end
conn:setlistener(listener);
+ local onconnect = listener.onconnect;
+ if onconnect then onconnect(conn) end
listener.onincoming(conn, buf);
elseif #buf > 1024 then
conn:close();
diff --git a/net/server.lua b/net/server.lua
index e0d4b85a..1c1a63a4 100644
--- a/net/server.lua
+++ b/net/server.lua
@@ -6,7 +6,7 @@
-- COPYING file in the source package for more information.
--
-local use_luaevent = require "core.configmanager".get("*", "core", "use_libevent");
+local use_luaevent = prosody and require "core.configmanager".get("*", "core", "use_libevent");
if use_luaevent then
use_luaevent = pcall(require, "luaevent.core");
diff --git a/net/server_event.lua b/net/server_event.lua
index 0331e793..528305d3 100644
--- a/net/server_event.lua
+++ b/net/server_event.lua
@@ -143,9 +143,9 @@ do
debug( "new connection failed. id:", self.id, "error:", self.fatalerror )
else
if plainssl and ssl then -- start ssl session
- self:starttls()
+ self:starttls(nil, true)
else -- normal connection
- self:_start_session( self.listener.onconnect )
+ self:_start_session(true)
end
debug( "new connection established. id:", self.id )
end
@@ -155,13 +155,15 @@ do
self.eventconnect = addevent( base, self.conn, EV_WRITE, callback, cfg.CONNECT_TIMEOUT )
return true
end
- function interface_mt:_start_session(onconnect) -- new session, for example after startssl
+ function interface_mt:_start_session(call_onconnect) -- new session, for example after startssl
if self.type == "client" then
local callback = function( )
self:_lock( false, false, false )
--vdebug( "start listening on client socket with id:", self.id )
self.eventread = addevent( base, self.conn, EV_READ, self.readcallback, cfg.READ_TIMEOUT ); -- register callback
- self:onconnect()
+ if call_onconnect then
+ self:onconnect()
+ end
self.eventsession = nil
return -1
end
@@ -173,7 +175,7 @@ do
end
return true
end
- function interface_mt:_start_ssl(arg) -- old socket will be destroyed, therefore we have to close read/write events first
+ function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed, therefore we have to close read/write events first
--vdebug( "starting ssl session with client id:", self.id )
local _
_ = self.eventread and self.eventread:close( ) -- close events; this must be called outside of the event callbacks!
@@ -184,7 +186,7 @@ do
if err then
self.fatalerror = err
self.conn = nil -- cannot be used anymore
- if "onconnect" == arg then
+ if call_onconnect then
self.ondisconnect = nil -- dont call this when client isnt really connected
end
self:_close()
@@ -211,28 +213,25 @@ do
self.send = self.conn.send -- caching table lookups with new client object
self.receive = self.conn.receive
local onsomething
- if "onconnect" == arg then -- trigger listener
- onsomething = self.onconnect
- else
- onsomething = self.onsslconnection
+ if not call_onconnect then -- trigger listener
+ self:onstatus("ssl-handshake-complete");
end
- self:_start_session( onsomething )
+ self:_start_session( call_onconnect )
debug( "ssl handshake done" )
- self:onstatus("ssl-handshake-complete");
self.eventhandshake = nil
return -1
end
- debug( "error during ssl handshake:", err )
if err == "wantwrite" then
event = EV_WRITE
elseif err == "wantread" then
event = EV_READ
else
+ debug( "ssl handshake error:", err )
self.fatalerror = err
end
end
if self.fatalerror then
- if "onconnect" == arg then
+ if call_onconnect then
self.ondisconnect = nil -- dont call this when client isnt really connected
end
self:_close()
@@ -362,6 +361,10 @@ do
end
end
+ function interface_mt:socket()
+ return self.conn
+ end
+
function interface_mt:server()
return self._server or self;
end
@@ -414,7 +417,7 @@ do
-- No-op, we always use the underlying connection's send
end
- function interface_mt:starttls(sslctx)
+ function interface_mt:starttls(sslctx, call_onconnect)
debug( "try to start ssl at client id:", self.id )
local err
self._sslctx = sslctx;
@@ -428,7 +431,7 @@ do
self._usingssl = true
self.startsslcallback = function( ) -- we have to start the handshake outside of a read/write event
self.startsslcallback = nil
- self:_start_ssl();
+ self:_start_ssl(call_onconnect);
self.eventstarthandshake = nil
return -1
end
@@ -468,7 +471,6 @@ do
function interface_mt:ondrain()
end
function interface_mt:onstatus()
- debug("server.lua: Dummy onstatus()")
end
end
@@ -700,9 +702,9 @@ do
local clientinterface = handleclient( client, client_ip, client_port, interface, pattern, listener, nil, sslctx )
--vdebug( "client id:", clientinterface, "startssl:", startssl )
if ssl and sslctx then
- clientinterface:starttls(sslctx)
+ clientinterface:starttls(sslctx, true)
else
- clientinterface:_start_session( clientinterface.onconnect )
+ clientinterface:_start_session( true )
end
debug( "accepted incoming client connection from:", client_ip or "<unknown IP>", client_port or "<unknown port>", "to", port or "<unknown port>");
@@ -724,7 +726,7 @@ local addserver = ( function( )
--vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslcfg or "nil", startssl or "nil")
local server, err = socket.bind( addr, port, cfg.ACCEPT_QUEUE ) -- create server socket
if not server then
- debug( "creating server socket failed because:", err )
+ debug( "creating server socket on "..addr.." port "..port.." failed:", err )
return nil, err
end
local sslctx
@@ -846,7 +848,6 @@ function hook_signal(signal_num, handler)
end
local function link(sender, receiver, buffersize)
- sender:set_mode(buffersize);
local sender_locked;
function receiver:ondrain()
diff --git a/net/server_select.lua b/net/server_select.lua
index 298e560a..c3777a5f 100644
--- a/net/server_select.lua
+++ b/net/server_select.lua
@@ -32,6 +32,7 @@ local STAT_UNIT = 1 -- byte
local type = use "type"
local pairs = use "pairs"
local ipairs = use "ipairs"
+local tonumber = use "tonumber"
local tostring = use "tostring"
local collectgarbage = use "collectgarbage"
@@ -44,8 +45,9 @@ local coroutine = use "coroutine"
--// lua lib methods //--
-local os_time = os.time
local os_difftime = os.difftime
+local math_min = math.min
+local math_huge = math.huge
local table_concat = table.concat
local table_remove = table.remove
local string_len = string.len
@@ -57,6 +59,7 @@ local coroutine_yield = coroutine.yield
local luasec = use "ssl"
local luasocket = use "socket" or require "socket"
+local luasocket_gettime = luasocket.gettime
--// extern lib methods //--
@@ -74,6 +77,7 @@ local stats
local idfalse
local addtimer
local closeall
+local addsocket
local addserver
local getserver
local wrapserver
@@ -125,6 +129,8 @@ local _timer
local _maxclientsperserver
+local _maxsslhandshake
+
----------------------------------// DEFINITION //--
_server = { } -- key = port, value = table; list of listening servers
@@ -167,7 +173,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco
local connections = 0
- local dispatch, disconnect = listeners.onincoming, listeners.ondisconnect
+ local dispatch, disconnect = listeners.onconnect or listeners.onincoming, listeners.ondisconnect
local accept = socket.accept
@@ -483,7 +489,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
if drain then
drain(handler)
end
- _ = needtls and handler:starttls(nil, true)
+ _ = needtls and handler:starttls(nil)
_ = toclose and handler:close( )
return true
elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write
@@ -524,7 +530,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
_readlistlen = addsocket(_readlist, client, _readlistlen)
return true
else
- out_put( "server.lua: error during ssl handshake: ", tostring(err) )
if err == "wantwrite" and not wrote then
_sendlistlen = addsocket(_sendlist, client, _sendlistlen)
wrote = true
@@ -532,6 +537,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
_readlistlen = addsocket(_readlist, client, _readlistlen)
read = true
else
+ out_put( "server.lua: ssl handshake error: ", tostring(err) )
break;
end
--coroutine_yield( handler, nil, err ) -- handshake not finished
@@ -564,13 +570,13 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
end
else
local sslctx;
- handler.starttls = function( self, _sslctx, now )
+ handler.starttls = function( self, _sslctx)
if _sslctx then
sslctx = _sslctx;
handler:set_sslctx(sslctx);
end
- if not now then
- out_put "server.lua: we need to do tls, but delaying until later"
+ if bufferqueuelen > 0 then
+ out_put "server.lua: we need to do tls, but delaying until send buffer empty"
needtls = true
return
end
@@ -623,16 +629,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
_socketlist[ socket ] = handler
_readlistlen = addsocket(_readlist, socket, _readlistlen)
- if listeners.onconnect then
- _sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
- handler.sendbuffer = function ()
- listeners.onconnect(handler);
- handler.sendbuffer = _sendbuffer;
- if bufferqueuelen > 0 then
- return _sendbuffer();
- end
- end
- end
return handler, socket
end
@@ -676,7 +672,6 @@ closesocket = function( socket )
end
local function link(sender, receiver, buffersize)
- sender:set_mode(buffersize);
local sender_locked;
local _sendbuffer = receiver.sendbuffer;
function receiver.sendbuffer()
@@ -798,16 +793,18 @@ stats = function( )
return _readtraffic, _sendtraffic, _readlistlen, _sendlistlen, _timerlistlen
end
-local dontstop = true; -- thinking about tomorrow, ...
+local quitting;
setquitting = function (quit)
- dontstop = not quit;
- return;
+ quitting = not not quit;
end
-loop = function( ) -- this is the main loop of the program
- while dontstop do
- local read, write, err = socket_select( _readlist, _sendlist, _selecttimeout )
+loop = function(once) -- this is the main loop of the program
+ if quitting then return "quitting"; end
+ if once then quitting = "once"; end
+ local next_timer_time = math_huge;
+ repeat
+ local read, write, err = socket_select( _readlist, _sendlist, math_min(_selecttimeout, next_timer_time) )
for i, socket in ipairs( write ) do -- send data waiting in writequeues
local handler = _socketlist[ socket ]
if handler then
@@ -831,19 +828,28 @@ loop = function( ) -- this is the main loop of the program
handler:close( true ) -- forced disconnect
end
clean( _closelist )
- _currenttime = os_time( )
- if os_difftime( _currenttime - _timer ) >= 1 then
+ _currenttime = luasocket_gettime( )
+ if _currenttime - _timer >= math_min(next_timer_time, 1) then
+ next_timer_time = math_huge;
for i = 1, _timerlistlen do
- _timerlist[ i ]( _currenttime ) -- fire timers
+ local t = _timerlist[ i ]( _currenttime ) -- fire timers
+ if t then next_timer_time = math_min(next_timer_time, t); end
end
_timer = _currenttime
+ else
+ next_timer_time = next_timer_time - (_currenttime - _timer);
end
socket_sleep( _sleeptime ) -- wait some time
--collectgarbage( )
- end
+ until quitting;
+ if once and quitting == "once" then quitting = nil; return; end
return "quitting"
end
+step = function ()
+ return loop(true);
+end
+
local function get_backend()
return "select";
end
@@ -854,6 +860,18 @@ local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx
local handler = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx )
_socketlist[ socket ] = handler
_sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
+ if listeners.onconnect then
+ -- When socket is writeable, call onconnect
+ local _sendbuffer = handler.sendbuffer;
+ handler.sendbuffer = function ()
+ handler.sendbuffer = _sendbuffer;
+ listeners.onconnect(handler);
+ -- If there was data with the incoming packet, handle it now.
+ if #handler:bufferqueue() > 0 then
+ return _sendbuffer();
+ end
+ end
+ end
return handler, socket
end
@@ -879,8 +897,8 @@ use "setmetatable" ( _socketlist, { __mode = "k" } )
use "setmetatable" ( _readtimes, { __mode = "k" } )
use "setmetatable" ( _writetimes, { __mode = "k" } )
-_timer = os_time( )
-_starttime = os_time( )
+_timer = luasocket_gettime( )
+_starttime = luasocket_gettime( )
addtimer( function( )
local difftime = os_difftime( _currenttime - _starttime )
diff --git a/net/xmppclient_listener.lua b/net/xmppclient_listener.lua
index 94daa2b2..4cc90cbf 100644
--- a/net/xmppclient_listener.lua
+++ b/net/xmppclient_listener.lua
@@ -10,22 +10,19 @@
local logger = require "logger";
local log = logger.init("xmppclient_listener");
-local lxp = require "lxp"
-local init_xmlhandlers = require "core.xmlhandlers"
-local sm_new_session = require "core.sessionmanager".new_session;
+local new_xmpp_stream = require "util.xmppstream".new;
local connlisteners_register = require "net.connlisteners".register;
-local t_insert = table.insert;
-local t_concat = table.concat;
-local t_concatall = function (t, sep) local tt = {}; for _, s in ipairs(t) do t_insert(tt, tostring(s)); end return t_concat(tt, sep); end
-local m_random = math.random;
-local format = string.format;
local sessionmanager = require "core.sessionmanager";
local sm_new_session, sm_destroy_session = sessionmanager.new_session, sessionmanager.destroy_session;
local sm_streamopened = sessionmanager.streamopened;
local sm_streamclosed = sessionmanager.streamclosed;
local st = require "util.stanza";
+local xpcall = xpcall;
+local tostring = tostring;
+local type = type;
+local traceback = debug.traceback;
local config = require "core.configmanager";
local opt_keepalives = config.get("*", "core", "tcp_keepalives");
@@ -41,7 +38,7 @@ function stream_callbacks.error(session, error, data)
session:close("invalid-namespace");
elseif error == "parse-error" then
(session.log or log)("debug", "Client XML parse error: %s", tostring(data));
- session:close("xml-not-well-formed");
+ session:close("not-well-formed");
elseif error == "stream-error" then
local condition, text = "undefined-condition";
for child in data:children() do
@@ -62,9 +59,12 @@ function stream_callbacks.error(session, error, data)
end
end
-local function handleerr(err) log("error", "Traceback[c2s]: %s: %s", tostring(err), debug.traceback()); end
-function stream_callbacks.handlestanza(a, b)
- xpcall(function () core_process_stanza(a, b) end, handleerr);
+local function handleerr(err) log("error", "Traceback[c2s]: %s: %s", tostring(err), traceback()); end
+function stream_callbacks.handlestanza(session, stanza)
+ stanza = session.filter("stanzas/in", stanza);
+ if stanza then
+ return xpcall(function () return core_process_stanza(session, stanza) end, handleerr);
+ end
end
local sessions = {};
@@ -72,23 +72,6 @@ local xmppclient = { default_port = 5222, default_mode = "*a" };
-- These are session methods --
-local function session_reset_stream(session)
- -- Reset stream
- local parser = lxp.new(init_xmlhandlers(session, stream_callbacks), "\1");
- session.parser = parser;
-
- session.notopen = true;
-
- function session.data(conn, data)
- local ok, err = parser:parse(data);
- if ok then return; end
- log("debug", "Received invalid XML (%s) %d bytes: %s", tostring(err), #data, data:sub(1, 300):gsub("[\r\n]+", " "):gsub("[%z\1-\31]", "_"));
- session:close("xml-not-well-formed");
- end
-
- return true;
-end
-
local stream_xmlns_attr = {xmlns='urn:ietf:params:xml:ns:xmpp-streams'};
local default_stream_attr = { ["xmlns:stream"] = "http://etherx.jabber.org/streams", xmlns = stream_callbacks.default_ns, version = "1.0", id = "" };
local function session_close(session, reason)
@@ -128,32 +111,54 @@ end
-- End of session methods --
-function xmppclient.onincoming(conn, data)
- local session = sessions[conn];
- if not session then
- session = sm_new_session(conn);
- sessions[conn] = session;
-
- session.log("info", "Client connected");
-
- -- Client is using legacy SSL (otherwise mod_tls sets this flag)
- if conn:ssl() then
- session.secure = true;
- end
-
- if opt_keepalives ~= nil then
- conn:setoption("keepalive", opt_keepalives);
+function xmppclient.onconnect(conn)
+ local session = sm_new_session(conn);
+ sessions[conn] = session;
+
+ session.log("info", "Client connected");
+
+ -- Client is using legacy SSL (otherwise mod_tls sets this flag)
+ if conn:ssl() then
+ session.secure = true;
+ end
+
+ if opt_keepalives ~= nil then
+ conn:setoption("keepalive", opt_keepalives);
+ end
+
+ session.close = session_close;
+
+ local stream = new_xmpp_stream(session, stream_callbacks);
+ session.stream = stream;
+
+ session.notopen = true;
+
+ function session.reset_stream()
+ session.notopen = true;
+ session.stream:reset();
+ end
+
+ local filter = session.filter;
+ function session.data(data)
+ data = filter("bytes/in", data);
+ if data then
+ local ok, err = stream:feed(data);
+ if ok then return; end
+ log("debug", "Received invalid XML (%s) %d bytes: %s", tostring(err), #data, data:sub(1, 300):gsub("[\r\n]+", " "):gsub("[%z\1-\31]", "_"));
+ session:close("not-well-formed");
end
-
- session.reset_stream = session_reset_stream;
- session.close = session_close;
-
- session_reset_stream(session); -- Initialise, ready for use
-
- session.dispatch_stanza = stream_callbacks.handlestanza;
end
- if data then
- session.data(conn, data);
+
+ local handlestanza = stream_callbacks.handlestanza;
+ function session.dispatch_stanza(session, stanza)
+ return handlestanza(session, stanza);
+ end
+end
+
+function xmppclient.onincoming(conn, data)
+ local session = sessions[conn];
+ if session then
+ session.data(data);
end
end
@@ -167,4 +172,8 @@ function xmppclient.ondisconnect(conn, err)
end
end
+function xmppclient.associate_session(conn, session)
+ sessions[conn] = session;
+end
+
connlisteners_register("xmppclient", xmppclient);
diff --git a/net/xmppcomponent_listener.lua b/net/xmppcomponent_listener.lua
index b87f7c96..90293559 100644
--- a/net/xmppcomponent_listener.lua
+++ b/net/xmppcomponent_listener.lua
@@ -10,17 +10,19 @@
local hosts = _G.hosts;
local t_concat = table.concat;
+local tostring = tostring;
+local type = type;
+local pairs = pairs;
local lxp = require "lxp";
local logger = require "util.logger";
local config = require "core.configmanager";
local connlisteners = require "net.connlisteners";
-local cm_register_component = require "core.componentmanager".register_component;
-local cm_deregister_component = require "core.componentmanager".deregister_component;
local uuid_gen = require "util.uuid".generate;
+local jid_split = require "util.jid".split;
local sha1 = require "util.hashes".sha1;
local st = require "util.stanza";
-local init_xmlhandlers = require "core.xmlhandlers";
+local new_xmpp_stream = require "util.xmppstream".new;
local sessions = {};
@@ -30,7 +32,7 @@ local component_listener = { default_port = 5347; default_mode = "*a"; default_i
local xmlns_component = 'jabber:component:accept';
---- Callbacks/data for xmlhandlers to handle streams for us ---
+--- Callbacks/data for xmppstream to handle streams for us ---
local stream_callbacks = { default_ns = xmlns_component };
@@ -43,7 +45,7 @@ function stream_callbacks.error(session, error, data, data2)
session:close("invalid-namespace");
elseif error == "parse-error" then
session.log("warn", "External component %s XML parse error: %s", tostring(session.host), tostring(data));
- session:close("xml-not-well-formed");
+ session:close("not-well-formed");
elseif error == "stream-error" then
local condition, text = "undefined-condition";
for child in data:children() do
@@ -66,19 +68,16 @@ end
function stream_callbacks.streamopened(session, attr)
if config.get(attr.to, "core", "component_module") ~= "component" then
- -- Trying to act as a component domain which
+ -- Trying to act as a component domain which
-- hasn't been configured
session:close{ condition = "host-unknown", text = tostring(attr.to).." does not match any configured external components" };
return;
end
- -- Store the original host (this is used for config, etc.)
- session.user = attr.to;
- -- Set the host for future reference
- session.host = config.get(attr.to, "core", "component_address") or attr.to;
- -- Note that we don't create the internal component
+ -- Note that we don't create the internal component
-- until after the external component auths successfully
+ session.host = attr.to;
session.streamid = uuid_gen();
session.notopen = nil;
@@ -88,7 +87,7 @@ function stream_callbacks.streamopened(session, attr)
end
function stream_callbacks.streamclosed(session)
- session.log("Received </stream:stream>");
+ session.log("debug", "Received </stream:stream>");
session:close();
end
@@ -99,6 +98,31 @@ function stream_callbacks.handlestanza(session, stanza)
if not stanza.attr.xmlns and stanza.name == "handshake" then
stanza.attr.xmlns = xmlns_component;
end
+ if not stanza.attr.xmlns or stanza.attr.xmlns == "jabber:client" then
+ local from = stanza.attr.from;
+ if from then
+ if session.component_validate_from then
+ local _, domain = jid_split(stanza.attr.from);
+ if domain ~= session.host then
+ -- Return error
+ session.log("warn", "Component sent stanza with missing or invalid 'from' address");
+ session:close{
+ condition = "invalid-from";
+ text = "Component tried to send from address <"..tostring(from)
+ .."> which is not in domain <"..tostring(session.host)..">";
+ };
+ return;
+ end
+ end
+ else
+ stanza.attr.from = session.host;
+ end
+ if not stanza.attr.to then
+ session.log("warn", "Rejecting stanza with no 'to' address");
+ session.send(st.error_reply(stanza, "modify", "bad-request", "Components MUST specify a 'to' address on stanzas"));
+ return;
+ end
+ end
return core_process_stanza(session, stanza);
end
@@ -141,51 +165,48 @@ local function session_close(session, reason)
end
--- Component connlistener
-function component_listener.onincoming(conn, data)
- local session = sessions[conn];
- if not session then
- local _send = conn.write;
- session = { type = "component", conn = conn, send = function (data) return _send(conn, tostring(data)); end };
- sessions[conn] = session;
-
- -- Logging functions --
-
- local conn_name = "jcp"..tostring(conn):match("[a-f0-9]+$");
- session.log = logger.init(conn_name);
- session.close = session_close;
-
- session.log("info", "Incoming Jabber component connection");
-
- local parser = lxp.new(init_xmlhandlers(session, stream_callbacks), "\1");
- session.parser = parser;
-
+function component_listener.onconnect(conn)
+ local _send = conn.write;
+ local session = { type = "component", conn = conn, send = function (data) return _send(conn, tostring(data)); end };
+
+ -- Logging functions --
+ local conn_name = "jcp"..tostring(conn):match("[a-f0-9]+$");
+ session.log = logger.init(conn_name);
+ session.close = session_close;
+
+ session.log("info", "Incoming Jabber component connection");
+
+ local stream = new_xmpp_stream(session, stream_callbacks);
+ session.stream = stream;
+
+ session.notopen = true;
+
+ function session.reset_stream()
session.notopen = true;
-
- function session.data(conn, data)
- local ok, err = parser:parse(data);
- if ok then return; end
- log("debug", "Received invalid XML (%s) %d bytes: %s", tostring(err), #data, data:sub(1, 300):gsub("[\r\n]+", " "):gsub("[%z\1-\31]", "_"));
- session:close("xml-not-well-formed");
- end
-
- session.dispatch_stanza = stream_callbacks.handlestanza;
-
+ session.stream:reset();
end
- if data then
- session.data(conn, data);
+
+ function session.data(conn, data)
+ local ok, err = stream:feed(data);
+ if ok then return; end
+ log("debug", "Received invalid XML (%s) %d bytes: %s", tostring(err), #data, data:sub(1, 300):gsub("[\r\n]+", " "):gsub("[%z\1-\31]", "_"));
+ session:close("not-well-formed");
end
-end
+ session.dispatch_stanza = stream_callbacks.handlestanza;
+
+ sessions[conn] = session;
+end
+function component_listener.onincoming(conn, data)
+ local session = sessions[conn];
+ session.data(conn, data);
+end
function component_listener.ondisconnect(conn, err)
local session = sessions[conn];
if session then
(session.log or log)("info", "component disconnected: %s (%s)", tostring(session.host), tostring(err));
- if session.host then
- log("debug", "Deregistering component");
- cm_deregister_component(session.host);
- hosts[session.host].connected = nil;
- end
- sessions[conn] = nil;
+ if session.on_destroy then session:on_destroy(err); end
+ sessions[conn] = nil;
for k in pairs(session) do
if k ~= "log" and k ~= "close" then
session[k] = nil;
diff --git a/net/xmppserver_listener.lua b/net/xmppserver_listener.lua
index d1272edb..3af0b962 100644
--- a/net/xmppserver_listener.lua
+++ b/net/xmppserver_listener.lua
@@ -7,11 +7,17 @@
--
+local tostring = tostring;
+local type = type;
+local xpcall = xpcall;
+local s_format = string.format;
+local traceback = debug.traceback;
local logger = require "logger";
local log = logger.init("xmppserver_listener");
-local lxp = require "lxp"
-local init_xmlhandlers = require "core.xmlhandlers"
+local st = require "util.stanza";
+local connlisteners_register = require "net.connlisteners".register;
+local new_xmpp_stream = require "util.xmppstream".new;
local s2s_new_incoming = require "core.s2smanager".new_incoming;
local s2s_streamopened = require "core.s2smanager".streamopened;
local s2s_streamclosed = require "core.s2smanager".streamclosed;
@@ -27,7 +33,7 @@ function stream_callbacks.error(session, error, data)
session:close("invalid-namespace");
elseif error == "parse-error" then
session.log("debug", "Server-to-server XML parse error: %s", tostring(error));
- session:close("xml-not-well-formed");
+ session:close("not-well-formed");
elseif error == "stream-error" then
local condition, text = "undefined-condition";
for child in data:children() do
@@ -48,48 +54,22 @@ function stream_callbacks.error(session, error, data)
end
end
-local function handleerr(err) log("error", "Traceback[s2s]: %s: %s", tostring(err), debug.traceback()); end
-function stream_callbacks.handlestanza(a, b)
- if b.attr.xmlns == "jabber:client" then --COMPAT: Prosody pre-0.6.2 may send jabber:client
- b.attr.xmlns = nil;
+local function handleerr(err) log("error", "Traceback[s2s]: %s: %s", tostring(err), traceback()); end
+function stream_callbacks.handlestanza(session, stanza)
+ if stanza.attr.xmlns == "jabber:client" then --COMPAT: Prosody pre-0.6.2 may send jabber:client
+ stanza.attr.xmlns = nil;
+ end
+ stanza = session.filter("stanzas/in", stanza);
+ if stanza then
+ return xpcall(function () return core_process_stanza(session, stanza) end, handleerr);
end
- xpcall(function () core_process_stanza(a, b) end, handleerr);
end
-local connlisteners_register = require "net.connlisteners".register;
-
-local t_insert = table.insert;
-local t_concat = table.concat;
-local t_concatall = function (t, sep) local tt = {}; for _, s in ipairs(t) do t_insert(tt, tostring(s)); end return t_concat(tt, sep); end
-local m_random = math.random;
-local format = string.format;
-local sessionmanager = require "core.sessionmanager";
-local sm_new_session, sm_destroy_session = sessionmanager.new_session, sessionmanager.destroy_session;
-local st = require "util.stanza";
-
local sessions = {};
local xmppserver = { default_port = 5269, default_mode = "*a" };
-- These are session methods --
-local function session_reset_stream(session)
- -- Reset stream
- local parser = lxp.new(init_xmlhandlers(session, stream_callbacks), "\1");
- session.parser = parser;
-
- session.notopen = true;
-
- function session.data(conn, data)
- local ok, err = parser:parse(data);
- if ok then return; end
- (session.log or log)("warn", "Received invalid XML: %s", data);
- (session.log or log)("warn", "Problem was: %s", err);
- session:close("xml-not-well-formed");
- end
-
- return true;
-end
-
local stream_xmlns_attr = {xmlns='urn:ietf:params:xml:ns:xmpp-streams'};
local default_stream_attr = { ["xmlns:stream"] = "http://etherx.jabber.org/streams", xmlns = stream_callbacks.default_ns, version = "1.0", id = "" };
local function session_close(session, reason, remote_reason)
@@ -132,29 +112,55 @@ end
-- End of session methods --
-function xmppserver.onincoming(conn, data)
- local session = sessions[conn];
- if not session then
- session = s2s_new_incoming(conn);
- sessions[conn] = session;
+local function initialize_session(session)
+ local stream = new_xmpp_stream(session, stream_callbacks);
+ session.stream = stream;
+
+ session.notopen = true;
+
+ function session.reset_stream()
+ session.notopen = true;
+ session.stream:reset();
+ end
+
+ local filter = session.filter;
+ function session.data(data)
+ data = filter("bytes/in", data);
+ if data then
+ local ok, err = stream:feed(data);
+ if ok then return; end
+ (session.log or log)("warn", "Received invalid XML: %s", data);
+ (session.log or log)("warn", "Problem was: %s", err);
+ session:close("not-well-formed");
+ end
+ end
- -- Logging functions --
+ session.close = session_close;
+ local handlestanza = stream_callbacks.handlestanza;
+ function session.dispatch_stanza(session, stanza)
+ return handlestanza(session, stanza);
+ end
+end
-
+function xmppserver.onconnect(conn)
+ if not sessions[conn] then -- May be an existing outgoing session
+ local session = s2s_new_incoming(conn);
+ sessions[conn] = session;
+
+ -- Logging functions --
local conn_name = "s2sin"..tostring(conn):match("[a-f0-9]+$");
session.log = logger.init(conn_name);
session.log("info", "Incoming s2s connection");
- session.reset_stream = session_reset_stream;
- session.close = session_close;
-
- session_reset_stream(session); -- Initialise, ready for use
-
- session.dispatch_stanza = stream_callbacks.handlestanza;
+ initialize_session(session);
end
- if data then
- session.data(conn, data);
+end
+
+function xmppserver.onincoming(conn, data)
+ local session = sessions[conn];
+ if session then
+ session.data(data);
end
end
@@ -162,9 +168,9 @@ function xmppserver.onstatus(conn, status)
if status == "ssl-handshake-complete" then
local session = sessions[conn];
if session and session.direction == "outgoing" then
- local format, to_host, from_host = string.format, session.to_host, session.from_host;
+ local to_host, from_host = session.to_host, session.from_host;
session.log("debug", "Sending stream header...");
- session.sends2s(format([[<stream:stream xmlns='jabber:server' xmlns:db='jabber:server:dialback' xmlns:stream='http://etherx.jabber.org/streams' from='%s' to='%s' version='1.0'>]], from_host, to_host));
+ session.sends2s(s_format([[<stream:stream xmlns='jabber:server' xmlns:db='jabber:server:dialback' xmlns:stream='http://etherx.jabber.org/streams' from='%s' to='%s' version='1.0'>]], from_host, to_host));
end
end
end
@@ -190,12 +196,7 @@ function xmppserver.register_outgoing(conn, session)
session.direction = "outgoing";
sessions[conn] = session;
- session.reset_stream = session_reset_stream;
- session.close = session_close;
- session_reset_stream(session); -- Initialise, ready for use
-
- --local function handleerr(err) print("Traceback:", err, debug.traceback()); end
- --session.stanza_dispatch = function (stanza) return select(2, xpcall(function () return core_process_stanza(session, stanza); end, handleerr)); end
+ initialize_session(session);
end
connlisteners_register("xmppserver", xmppserver);