aboutsummaryrefslogtreecommitdiffstats
path: root/net/server_epoll.lua
diff options
context:
space:
mode:
Diffstat (limited to 'net/server_epoll.lua')
-rw-r--r--net/server_epoll.lua310
1 files changed, 249 insertions, 61 deletions
diff --git a/net/server_epoll.lua b/net/server_epoll.lua
index fa275d71..b4477375 100644
--- a/net/server_epoll.lua
+++ b/net/server_epoll.lua
@@ -6,8 +6,6 @@
--
-local t_insert = table.insert;
-local t_concat = table.concat;
local setmetatable = setmetatable;
local pcall = pcall;
local type = type;
@@ -15,24 +13,59 @@ local next = next;
local pairs = pairs;
local ipairs = ipairs;
local traceback = debug.traceback;
-local logger = require "util.logger";
+local logger = require "prosody.util.logger";
local log = logger.init("server_epoll");
local socket = require "socket";
-local luasec = require "ssl";
-local realtime = require "util.time".now;
-local monotonic = require "util.time".monotonic;
-local indexedbheap = require "util.indexedbheap";
-local createtable = require "util.table".create;
-local inet = require "util.net";
+local realtime = require "prosody.util.time".now;
+local monotonic = require "prosody.util.time".monotonic;
+local indexedbheap = require "prosody.util.indexedbheap";
+local createtable = require "prosody.util.table".create;
+local dbuffer = require "prosody.util.dbuffer";
+local inet = require "prosody.util.net";
local inet_pton = inet.pton;
local _SOCKETINVALID = socket._SOCKETINVALID or -1;
-local new_id = require "util.id".short;
-local xpcall = require "util.xpcall".xpcall;
+local new_id = require "prosody.util.id".short;
+local xpcall = require "prosody.util.xpcall".xpcall;
+local sslconfig = require "prosody.util.sslconfig";
+local tls_impl = require "prosody.net.tls_luasec";
+local have_signal, signal = pcall(require, "prosody.util.signal");
-local poller = require "util.poll"
+local poller = require "prosody.util.poll"
local EEXIST = poller.EEXIST;
local ENOENT = poller.ENOENT;
+-- systemd socket activation
+local SD_LISTEN_FDS_START = 3;
+local SD_LISTEN_FDS = tonumber(os.getenv("LISTEN_FDS")) or 0;
+
+local inherited_sockets = setmetatable({}, {
+ __index = function(t, k)
+ local serv_mt = debug.getregistry()["tcp{server}"];
+ for i = 1, SD_LISTEN_FDS do
+ local serv = socket.tcp();
+ if serv:getfd() ~= _SOCKETINVALID then
+ -- If LuaSocket allocated a FD for then we can't really close it and it would leak.
+ log("error", "LuaSocket not compatible with socket activation. Upgrade LuaSocket or disable socket activation.");
+ setmetatable(t, nil);
+ break
+ end
+ serv:setfd(SD_LISTEN_FDS_START + i - 1);
+ debug.setmetatable(serv, serv_mt);
+ serv:settimeout(0);
+ local ip, port = serv:getsockname();
+ t[ip .. ":" .. port] = serv;
+ if ip == "0.0.0.0" then
+ -- LuaSocket treats '*' as an alias for '0.0.0.0'
+ t["*:" .. port] = serv;
+ end
+ end
+
+ -- Disable lazy-loading mechanism once performed
+ setmetatable(t, nil);
+ return t[k];
+ end;
+});
+
local poll = assert(poller.new());
local _ENV = nil;
@@ -60,6 +93,15 @@ local default_config = { __index = {
-- Size of chunks to read from sockets
read_size = 8192;
+ -- Maximum size of send buffer, after which additional data is rejected
+ max_send_buffer_size = 32*1024*1024;
+
+ -- How many chunks (immutable strings) to keep in the send buffer
+ send_buffer_chunks = nil;
+
+ -- Maximum amount of data to send at once (to the TCP buffers), default based on /proc/sys/net/ipv4/tcp_wmem
+ max_send_chunk = 4*1024*1024;
+
-- Timeout used during between steps in TLS handshakes
ssl_handshake_timeout = 60;
@@ -91,6 +133,12 @@ local default_config = { __index = {
--- How long to wait after getting the shutdown signal before forcefully tearing down every socket
shutdown_deadline = 5;
+
+ -- TCP Fast Open
+ tcp_fastopen = false;
+
+ -- Defer accept until incoming data is available
+ tcp_defer_accept = false;
}};
local cfg = default_config.__index;
@@ -393,6 +441,9 @@ function interface:set(r, w)
end
if r == nil then r = self._wantread; end
if w == nil then w = self._wantwrite; end
+ if r == self._wantread and w == self._wantwrite then
+ return true
+ end
local ok, err, errno = poll:set(fd, r, w);
if not ok then
self:debug("Could not update poller state: %s(%d)", err, errno);
@@ -457,7 +508,8 @@ function interface:onreadable()
end
if err == "closed" and self._connected then
self:debug("Connection closed by remote");
- self:close(err);
+ self:on("disconnect", err);
+ self:destroy();
return;
elseif err ~= "timeout" then
self:debug("Read error, closing (%s)", err);
@@ -490,26 +542,21 @@ function interface:onwritable()
self:onconnect();
if not self.conn then return nil, "no-conn"; end -- could have been closed in onconnect
self:on("predrain");
- local buffer = self.writebuffer;
- local data = buffer or "";
- if type(buffer) == "table" then
- if buffer[3] then
- data = t_concat(data);
- elseif buffer[2] then
- data = buffer[1] .. buffer[2];
- else
- data = buffer[1] or "";
- end
- end
+ local buffer = self.writebuffer or "";
+ -- Naming things ... s/data/slice/ ?
+ local data = buffer:sub(1, cfg.max_send_chunk);
local ok, err, partial = self.conn:send(data);
self._writable = ok;
- if ok then
+ if ok and #data < #buffer then
+ -- Sent the whole 'data' but there's more in the buffer
+ ok, err, partial = nil, "timeout", ok;
+ end
+ self:debug("Sent %d out of %d buffered bytes", ok and #data or partial or 0, #buffer);
+ if ok then -- all the data we had was sent successfully
self:set(nil, false);
if cfg.keep_buffers and type(buffer) == "table" then
- for i = #buffer, 1, -1 do
- buffer[i] = nil;
- end
- else
+ buffer:discard(ok);
+ else -- string or don't keep buffers
self.writebuffer = nil;
end
self._writing = nil;
@@ -517,14 +564,10 @@ function interface:onwritable()
self:ondrain(); -- Be aware of writes in ondrain
return ok;
elseif partial then
- self:debug("Sent %d out of %d buffered bytes", partial, #data);
- if cfg.keep_buffers and type(buffer) == "table" then
- buffer[1] = data:sub(partial+1);
- for i = #buffer, 2, -1 do
- buffer[i] = nil;
- end
+ if type(buffer) == "table" then
+ buffer:discard(partial);
else
- self.writebuffer = data:sub(partial+1);
+ self.writebuffer = data:sub(partial + 1);
end
self:set(nil, true);
self:setwritetimeout();
@@ -552,13 +595,51 @@ end
-- Add data to write buffer and set flag for wanting to write
function interface:write(data)
local buffer = self.writebuffer;
- if type(buffer) == "table" then
- t_insert(buffer, data);
- elseif type(buffer) == "string" then
- self:noise("Allocating buffer!")
- self.writebuffer = { buffer, data };
- elseif buffer == nil then
+ -- (nil) -> save string
+ -- (string) -> convert to buffer (3 tables!)
+ -- (buffer) -> write to buffer
+ if not buffer then
self.writebuffer = data;
+ elseif type(buffer) == "string" then
+ local prev_buffer = buffer;
+ buffer = dbuffer.new(cfg.max_send_buffer_size, cfg.send_buffer_chunks);
+ self.writebuffer = buffer;
+ if prev_buffer then
+ -- TODO refactor, there's 3 copies of these lines
+ if not buffer:write(prev_buffer) then
+ if self._write_lock then
+ return false;
+ end
+ -- Try to flush buffer to make room
+ self:onwritable();
+ if not buffer:write(prev_buffer) then
+ self:on("disconnect", "no space left in buffer");
+ self:destroy();
+ return false;
+ end
+ end
+ end
+ if not buffer:write(data) then
+ if self._write_lock then
+ return false;
+ end
+ self:onwritable();
+ if not buffer:write(data) then
+ self:on("disconnect", "no space left in buffer");
+ self:destroy();
+ return false;
+ end
+ end
+ elseif not buffer:write(data) then
+ if self._write_lock then
+ return false;
+ end
+ self:onwritable();
+ if not buffer:write(data) then
+ self:on("disconnect", "no space left in buffer");
+ self:destroy();
+ return false;
+ end
end
if not self._write_lock and not self._writing then
if self._writable and cfg.opportunistic_writes and not self._opportunistic_write then
@@ -576,7 +657,7 @@ interface.send = interface.write;
-- Close, possibly after writing is done
function interface:close()
- if self._connected and self.writebuffer and (self.writebuffer[1] or type(self.writebuffer) == "string") then
+ if self.writebuffer and #self.writebuffer ~= 0 then
self._connected = false;
self:set(false, true); -- Flush final buffer contents
self:setreadtimeout(false);
@@ -614,10 +695,51 @@ function interface:set_sslctx(sslctx)
self._sslctx = sslctx;
end
+function interface:sslctx()
+ return self.tls_ctx
+end
+
+function interface:ssl_info()
+ local sock = self.conn;
+ if not sock then return nil, "not-connected" end
+ if not sock.info then return nil, "not-implemented"; end
+ return sock:info();
+end
+
+function interface:ssl_peercertificate()
+ local sock = self.conn;
+ if not sock then return nil, "not-connected" end
+ if not sock.getpeercertificate then return nil, "not-implemented"; end
+ return sock:getpeercertificate();
+end
+
+function interface:ssl_peerverification()
+ local sock = self.conn;
+ if not sock then return nil, "not-connected" end
+ if not sock.getpeerverification then return nil, { { "Chain verification not supported" } }; end
+ return sock:getpeerverification();
+end
+
+function interface:ssl_peerfinished()
+ local sock = self.conn;
+ if not sock then return nil, "not-connected" end
+ if not sock.getpeerfinished then return nil, "not-implemented"; end
+ return sock:getpeerfinished();
+end
+
+function interface:ssl_exportkeyingmaterial(label, len, context)
+ local sock = self.conn;
+ if not sock then return nil, "not-connected" end
+ if sock.exportkeyingmaterial then
+ return sock:exportkeyingmaterial(label, len, context);
+ end
+end
+
+
function interface:starttls(tls_ctx)
if tls_ctx then self.tls_ctx = tls_ctx; end
self.starttls = false;
- if self.writebuffer and (self.writebuffer[1] or type(self.writebuffer) == "string") then
+ if self.writebuffer and #self.writebuffer ~= 0 then
self:debug("Start TLS after write");
self.ondrain = interface.starttls;
self:set(nil, true); -- make sure wantwrite is set
@@ -641,11 +763,7 @@ function interface:inittls(tls_ctx, now)
self.starttls = false;
self:debug("Starting TLS now");
self:updatenames(); -- Can't getpeer/sockname after wrap()
- local ok, conn, err = pcall(luasec.wrap, self.conn, self.tls_ctx);
- if not ok then
- conn, err = ok, conn;
- self:debug("Failed to initialize TLS: %s", err);
- end
+ local conn, err = self.tls_ctx:wrap(self.conn);
if not conn then
self:on("disconnect", err);
self:destroy();
@@ -656,8 +774,8 @@ function interface:inittls(tls_ctx, now)
if conn.sni then
if self.servername then
conn:sni(self.servername);
- elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then
- conn:sni(self._server.hosts, true);
+ elseif next(self.tls_ctx._sni_contexts) ~= nil then
+ conn:sni(self.tls_ctx._sni_contexts, true);
end
end
if self.extra and self.extra.tlsa and conn.settlsa then
@@ -741,7 +859,6 @@ local function wrapsocket(client, server, read_size, listeners, tls_ctx, extra)
end
end
- conn:updatenames();
return conn;
end
@@ -767,6 +884,7 @@ function interface:onacceptable()
return;
end
local client = wrapsocket(conn, self, nil, self.listeners);
+ client:updatenames();
client:debug("New connection %s on server %s", client, self);
client:defaultoptions();
client._writable = cfg.opportunistic_writes;
@@ -855,7 +973,7 @@ function interface:resume_writes()
end
self:noise("Resume writes");
self._write_lock = nil;
- if self.writebuffer and (self.writebuffer[1] or type(self.writebuffer) == "string") then
+ if self.writebuffer and #self.writebuffer ~= 0 then
self:setwritetimeout();
self:set(nil, true);
end
@@ -885,11 +1003,25 @@ local function wrapserver(conn, addr, port, listeners, config)
log = logger.init(("serv%s"):format(new_id()));
}, interface_mt);
server:debug("Server %s created", server);
+ if cfg.tcp_fastopen then
+ server:setoption("tcp-fastopen", cfg.tcp_fastopen);
+ end
+ if type(cfg.tcp_defer_accept) == "number" then
+ server:setoption("tcp-defer-accept", cfg.tcp_defer_accept);
+ end
server:add(true, false);
return server;
end
local function listen(addr, port, listeners, config)
+ local inherited = inherited_sockets[addr .. ":" .. port];
+ if inherited then
+ local conn = wrapserver(inherited, addr, port, listeners, config);
+ -- sockets created by systemd must not be :close() since we may not have
+ -- privileges to create them
+ conn.destroy = interface.del;
+ return conn;
+ end
local conn, err = socket.bind(addr, port, cfg.tcp_backlog);
if not conn then return conn, err; end
conn:settimeout(0);
@@ -908,6 +1040,7 @@ end
-- COMPAT
local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx, extra)
local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra);
+ client:updatenames();
if not client.peername then
client.peername, client.peerport = addr, port;
end
@@ -941,9 +1074,13 @@ local function addclient(addr, port, listeners, read_size, tls_ctx, typ, extra)
if not conn then return conn, err; end
local ok, err = conn:settimeout(0);
if not ok then return ok, err; end
+ local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra)
+ if cfg.tcp_fastopen then
+ client:setoption("tcp-fastopen-connect", 1);
+ end
local ok, err = conn:setpeername(addr, port);
if not ok and err ~= "timeout" then return ok, err; end
- local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra)
+ client:updatenames();
local ok, err = client:init();
if not client.peername then
-- otherwise not set until connected
@@ -1032,12 +1169,38 @@ local function setquitting(quit)
end
end
+local function loop_once()
+ runtimers(); -- Ignore return value because we only do this once
+ local fd, r, w = poll:wait(0);
+ if fd then
+ local conn = fds[fd];
+ if conn then
+ if r then
+ conn:onreadable();
+ end
+ if w then
+ conn:onwritable();
+ end
+ else
+ log("debug", "Removing unknown fd %d", fd);
+ poll:del(fd);
+ end
+ else
+ return fd, r;
+ end
+end
+
-- Main loop
local function loop(once)
- repeat
- local t = runtimers(cfg.max_wait, cfg.min_wait);
+ if once then
+ return loop_once();
+ end
+
+ local t = 0;
+ while not quitting do
local fd, r, w = poll:wait(t);
- while fd do
+ if fd then
+ t = 0;
local conn = fds[fd];
if conn then
if r then
@@ -1050,15 +1213,35 @@ local function loop(once)
log("debug", "Removing unknown fd %d", fd);
poll:del(fd);
end
- fd, r, w = poll:wait(0);
- end
- if r ~= "timeout" and r ~= "signal" then
+ elseif r == "timeout" then
+ t = runtimers(cfg.max_wait, cfg.min_wait);
+ elseif r ~= "signal" then
log("debug", "epoll_wait error: %s[%d]", r, w);
end
- until once or (quitting and next(fds) == nil);
+ end
return quitting;
end
+local hook_signal;
+if have_signal and signal.signalfd then
+ local function dispatch(self)
+ return self:on("signal", self.conn:read());
+ end
+
+ function hook_signal(signum, cb)
+ local sigfd = signal.signalfd(signum);
+ if not sigfd then
+ log("error", "Could not hook signal %d", signum);
+ return nil, "failed";
+ end
+ local watch = watchfd(sigfd, dispatch);
+ watch.listeners = { onsignal = cb };
+ watch.close = nil; -- revert to default
+ watch:noise("Signal handler %d ready", signum);
+ return watch;
+ end
+end
+
return {
get_backend = function () return "epoll"; end;
addserver = addserver;
@@ -1084,6 +1267,11 @@ return {
set_config = function (newconfig)
cfg = setmetatable(newconfig, default_config);
end;
+ hook_signal = hook_signal;
+
+ tls_builder = function(basedir)
+ return sslconfig._new(tls_impl.new_context, basedir)
+ end,
-- libevent emulation
event = { EV_READ = "r", EV_WRITE = "w", EV_READWRITE = "rw", EV_LEAVE = -1 };