aboutsummaryrefslogtreecommitdiffstats
path: root/net/server_epoll.lua
diff options
context:
space:
mode:
Diffstat (limited to 'net/server_epoll.lua')
-rw-r--r--net/server_epoll.lua786
1 files changed, 786 insertions, 0 deletions
diff --git a/net/server_epoll.lua b/net/server_epoll.lua
new file mode 100644
index 00000000..e0189179
--- /dev/null
+++ b/net/server_epoll.lua
@@ -0,0 +1,786 @@
+-- Prosody IM
+-- Copyright (C) 2016-2018 Kim Alvefur
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+
+local t_sort = table.sort;
+local t_insert = table.insert;
+local t_remove = table.remove;
+local t_concat = table.concat;
+local setmetatable = setmetatable;
+local tostring = tostring;
+local pcall = pcall;
+local type = type;
+local next = next;
+local pairs = pairs;
+local log = require "util.logger".init("server_epoll");
+local socket = require "socket";
+local luasec = require "ssl";
+local gettime = require "util.time".now;
+local createtable = require "util.table".create;
+local inet = require "util.net";
+local inet_pton = inet.pton;
+local _SOCKETINVALID = socket._SOCKETINVALID or -1;
+
+local poll = assert(require "util.poll".new());
+
+local _ENV = nil;
+-- luacheck: std none
+
+local default_config = { __index = {
+ read_timeout = 14 * 60;
+ write_timeout = 7;
+ tcp_backlog = 128;
+ accept_retry_interval = 10;
+ read_retry_delay = 1e-06;
+ read_size = 8192;
+ connect_timeout = 20;
+ handshake_timeout = 60;
+ max_wait = 86400;
+ min_wait = 1e-06;
+}};
+local cfg = default_config.__index;
+
+local fds = createtable(10, 0); -- FD -> conn
+
+-- Timer and scheduling --
+
+local timers = {};
+
+local function noop() end
+local function closetimer(t)
+ t[1] = 0;
+ t[2] = noop;
+end
+
+-- Set to true when timers have changed
+local resort_timers = false;
+
+-- Add absolute timer
+local function at(time, f)
+ local timer = { time, f, close = closetimer };
+ t_insert(timers, timer);
+ resort_timers = true;
+ return timer;
+end
+
+-- Add relative timer
+local function addtimer(timeout, f)
+ return at(gettime() + timeout, f);
+end
+
+-- Run callbacks of expired timers
+-- Return time until next timeout
+local function runtimers(next_delay, min_wait)
+ -- Any timers at all?
+ if not timers[1] then
+ return next_delay;
+ end
+
+ if resort_timers then
+ -- Sort earliest timers to the end
+ t_sort(timers, function (a, b) return a[1] > b[1]; end);
+ resort_timers = false;
+ end
+
+ -- Iterate from the end and remove completed timers
+ for i = #timers, 1, -1 do
+ local timer = timers[i];
+ local t, f = timer[1], timer[2];
+ -- Get time for every iteration to increase accuracy
+ local now = gettime();
+ if t > now then
+ -- This timer should not fire yet
+ local diff = t - now;
+ if diff < next_delay then
+ next_delay = diff;
+ end
+ break;
+ end
+ local new_timeout = f(now);
+ if new_timeout then
+ -- Schedule for 'delay' from the time actually scheduled,
+ -- not from now, in order to prevent timer drift.
+ timer[1] = t + new_timeout;
+ resort_timers = true;
+ else
+ t_remove(timers, i);
+ end
+ end
+
+ if resort_timers or next_delay < min_wait then
+ -- Timers may be added from within a timer callback.
+ -- Those would not be considered for next_delay,
+ -- and we might sleep for too long, so instead
+ -- we return a shorter timeout so we can
+ -- properly sort all new timers.
+ next_delay = min_wait;
+ end
+
+ return next_delay;
+end
+
+-- Socket handler interface
+
+local interface = {};
+local interface_mt = { __index = interface };
+
+function interface_mt:__tostring()
+ if self.sockname and self.peername then
+ return ("FD %d (%s, %d, %s, %d)"):format(self:getfd(), self.peername, self.peerport, self.sockname, self.sockport);
+ elseif self.sockname or self.peername then
+ return ("FD %d (%s, %d)"):format(self:getfd(), self.sockname or self.peername, self.sockport or self.peerport);
+ end
+ return ("FD %d"):format(self:getfd());
+end
+
+-- Replace the listener and tell the old one
+function interface:setlistener(listeners, data)
+ self:on("detach");
+ self.listeners = listeners;
+ self:on("attach", data);
+end
+
+-- Call a listener callback
+function interface:on(what, ...)
+ if not self.listeners then
+ log("error", "%s has no listeners", self);
+ return;
+ end
+ local listener = self.listeners["on"..what];
+ if not listener then
+ -- log("debug", "Missing listener 'on%s'", what); -- uncomment for development and debugging
+ return;
+ end
+ local ok, err = pcall(listener, self, ...);
+ if not ok then
+ log("error", "Error calling on%s: %s", what, err);
+ end
+ return err;
+end
+
+-- Return the file descriptor number
+function interface:getfd()
+ if self.conn then
+ return self.conn:getfd();
+ end
+ return _SOCKETINVALID;
+end
+
+function interface:server()
+ return self._server or self;
+end
+
+-- Get IP address
+function interface:ip()
+ return self.peername or self.sockname;
+end
+
+-- Get a port number, doesn't matter which
+function interface:port()
+ return self.sockport or self.peerport;
+end
+
+-- Get local port number
+function interface:clientport()
+ return self.sockport;
+end
+
+-- Get remote port
+function interface:serverport()
+ if self.sockport then
+ return self.sockport;
+ elseif self._server then
+ self._server:port();
+ end
+end
+
+-- Return underlying socket
+function interface:socket()
+ return self.conn;
+end
+
+function interface:set_mode(new_mode)
+ self.read_size = new_mode;
+end
+
+function interface:setoption(k, v)
+ -- LuaSec doesn't expose setoption :(
+ if self.conn.setoption then
+ self.conn:setoption(k, v);
+ end
+end
+
+-- Timeout for detecting dead or idle sockets
+function interface:setreadtimeout(t)
+ if t == false then
+ if self._readtimeout then
+ self._readtimeout:close();
+ self._readtimeout = nil;
+ end
+ return
+ end
+ t = t or cfg.read_timeout;
+ if self._readtimeout then
+ self._readtimeout[1] = gettime() + t;
+ resort_timers = true;
+ else
+ self._readtimeout = addtimer(t, function ()
+ if self:on("readtimeout") then
+ return cfg.read_timeout;
+ else
+ self:on("disconnect", "read timeout");
+ self:destroy();
+ end
+ end);
+ end
+end
+
+-- Timeout for detecting dead sockets
+function interface:setwritetimeout(t)
+ if t == false then
+ if self._writetimeout then
+ self._writetimeout:close();
+ self._writetimeout = nil;
+ end
+ return
+ end
+ t = t or cfg.write_timeout;
+ if self._writetimeout then
+ self._writetimeout[1] = gettime() + t;
+ resort_timers = true;
+ else
+ self._writetimeout = addtimer(t, function ()
+ self:on("disconnect", "write timeout");
+ self:destroy();
+ end);
+ end
+end
+
+function interface:add(r, w)
+ local fd = self:getfd();
+ if fd < 0 then
+ return nil, "invalid fd";
+ end
+ if r == nil then r = self._wantread; end
+ if w == nil then w = self._wantwrite; end
+ local ok, err, errno = poll:add(fd, r, w);
+ if not ok then
+ log("error", "Could not register %s: %s(%d)", self, err, errno);
+ return ok, err;
+ end
+ self._wantread, self._wantwrite = r, w;
+ fds[fd] = self;
+ log("debug", "Watching %s", self);
+ return true;
+end
+
+function interface:set(r, w)
+ local fd = self:getfd();
+ if fd < 0 then
+ return nil, "invalid fd";
+ end
+ if r == nil then r = self._wantread; end
+ if w == nil then w = self._wantwrite; end
+ local ok, err, errno = poll:set(fd, r, w);
+ if not ok then
+ log("error", "Could not update poller state %s: %s(%d)", self, err, errno);
+ return ok, err;
+ end
+ self._wantread, self._wantwrite = r, w;
+ return true;
+end
+
+function interface:del()
+ local fd = self:getfd();
+ if fd < 0 then
+ return nil, "invalid fd";
+ end
+ if fds[fd] ~= self then
+ return nil, "unregistered fd";
+ end
+ local ok, err, errno = poll:del(fd);
+ if not ok then
+ log("error", "Could not unregister %s: %s(%d)", self, err, errno);
+ return ok, err;
+ end
+ self._wantread, self._wantwrite = nil, nil;
+ fds[fd] = nil;
+ log("debug", "Unwatched %s", self);
+ return true;
+end
+
+function interface:setflags(r, w)
+ if not(self._wantread or self._wantwrite) then
+ if not(r or w) then
+ return true; -- no change
+ end
+ return self:add(r, w);
+ end
+ if not(r or w) then
+ return self:del();
+ end
+ return self:set(r, w);
+end
+
+-- Called when socket is readable
+function interface:onreadable()
+ local data, err, partial = self.conn:receive(self.read_size or cfg.read_size);
+ if data then
+ self:onconnect();
+ self:on("incoming", data);
+ else
+ if partial and partial ~= "" then
+ self:onconnect();
+ self:on("incoming", partial, err);
+ end
+ if err == "wantread" then
+ self:set(true, nil);
+ elseif err == "wantwrite" then
+ self:set(nil, true);
+ elseif err ~= "timeout" then
+ self:on("disconnect", err);
+ self:destroy()
+ return;
+ end
+ end
+ if not self.conn then return; end
+ if self.conn:dirty() then
+ self:setreadtimeout(false);
+ self:pausefor(cfg.read_retry_delay);
+ else
+ self:setreadtimeout();
+ end
+end
+
+-- Called when socket is writable
+function interface:onwritable()
+ self:onconnect();
+ if not self.conn then return; end -- could have been closed in onconnect
+ local buffer = self.writebuffer;
+ local data = t_concat(buffer);
+ local ok, err, partial = self.conn:send(data);
+ if ok then
+ self:set(nil, false);
+ for i = #buffer, 1, -1 do
+ buffer[i] = nil;
+ end
+ self:setwritetimeout(false);
+ self:ondrain(); -- Be aware of writes in ondrain
+ return;
+ elseif partial then
+ buffer[1] = data:sub(partial+1);
+ for i = #buffer, 2, -1 do
+ buffer[i] = nil;
+ end
+ self:setwritetimeout();
+ end
+ if err == "wantwrite" or err == "timeout" then
+ self:set(nil, true);
+ elseif err == "wantread" then
+ self:set(true, nil);
+ elseif err ~= "timeout" then
+ self:on("disconnect", err);
+ self:destroy();
+ end
+end
+
+-- The write buffer has been successfully emptied
+function interface:ondrain()
+ return self:on("drain");
+end
+
+-- Add data to write buffer and set flag for wanting to write
+function interface:write(data)
+ local buffer = self.writebuffer;
+ if buffer then
+ t_insert(buffer, data);
+ else
+ self.writebuffer = { data };
+ end
+ self:setwritetimeout();
+ self:set(nil, true);
+ return #data;
+end
+interface.send = interface.write;
+
+-- Close, possibly after writing is done
+function interface:close()
+ if self.writebuffer and self.writebuffer[1] then
+ self:set(false, true); -- Flush final buffer contents
+ self.write, self.send = noop, noop; -- No more writing
+ log("debug", "Close %s after writing", self);
+ self.ondrain = interface.close;
+ else
+ log("debug", "Close %s now", self);
+ self.write, self.send = noop, noop;
+ self.close = noop;
+ self:on("disconnect");
+ self:destroy();
+ end
+end
+
+function interface:destroy()
+ self:del();
+ self:setwritetimeout(false);
+ self:setreadtimeout(false);
+ self.onreadable = noop;
+ self.onwritable = noop;
+ self.destroy = noop;
+ self.close = noop;
+ self.on = noop;
+ self.conn:close();
+ self.conn = nil;
+end
+
+function interface:ssl()
+ return self._tls;
+end
+
+function interface:starttls(tls_ctx)
+ if tls_ctx then self.tls_ctx = tls_ctx; end
+ self.starttls = false;
+ if self.writebuffer and self.writebuffer[1] then
+ log("debug", "Start TLS on %s after write", self);
+ self.ondrain = interface.starttls;
+ self:set(nil, true); -- make sure wantwrite is set
+ else
+ if self.ondrain == interface.starttls then
+ self.ondrain = nil;
+ end
+ self.onwritable = interface.tlshandskake;
+ self.onreadable = interface.tlshandskake;
+ self:set(true, true);
+ log("debug", "Prepare to start TLS on %s", self);
+ end
+end
+
+function interface:tlshandskake()
+ self:setwritetimeout(false);
+ self:setreadtimeout(false);
+ if not self._tls then
+ self._tls = true;
+ log("debug", "Start TLS on %s now", self);
+ self:del();
+ local ok, conn, err = pcall(luasec.wrap, self.conn, self.tls_ctx);
+ if not ok then
+ log("error", "Failed to initialize TLS: %s", conn);
+ conn, err = ok, conn;
+ end
+ if not conn then
+ self:on("disconnect", err);
+ self:destroy();
+ return conn, err;
+ end
+ conn:settimeout(0);
+ self.conn = conn;
+ self:on("starttls");
+ self.ondrain = nil;
+ self.onwritable = interface.tlshandskake;
+ self.onreadable = interface.tlshandskake;
+ return self:init();
+ end
+ local ok, err = self.conn:dohandshake();
+ if ok then
+ log("debug", "TLS handshake on %s complete", self);
+ self.onwritable = nil;
+ self.onreadable = nil;
+ self:on("status", "ssl-handshake-complete");
+ self:setwritetimeout();
+ self:set(true, true);
+ elseif err == "wantread" then
+ log("debug", "TLS handshake on %s to wait until readable", self);
+ self:set(true, false);
+ self:setreadtimeout(cfg.handshake_timeout);
+ elseif err == "wantwrite" then
+ log("debug", "TLS handshake on %s to wait until writable", self);
+ self:set(false, true);
+ self:setwritetimeout(cfg.handshake_timeout);
+ else
+ log("debug", "TLS handshake error on %s: %s", self, err);
+ self:on("disconnect", err);
+ self:destroy();
+ end
+end
+
+local function wrapsocket(client, server, read_size, listeners, tls_ctx) -- luasocket object -> interface object
+ client:settimeout(0);
+ local conn = setmetatable({
+ conn = client;
+ _server = server;
+ created = gettime();
+ listeners = listeners;
+ read_size = read_size or (server and server.read_size);
+ writebuffer = {};
+ tls_ctx = tls_ctx or (server and server.tls_ctx);
+ tls_direct = server and server.tls_direct;
+ }, interface_mt);
+
+ conn:updatenames();
+ return conn;
+end
+
+function interface:updatenames()
+ local conn = self.conn;
+ local ok, peername, peerport = pcall(conn.getpeername, conn);
+ if ok then
+ self.peername, self.peerport = peername, peerport;
+ end
+ local ok, sockname, sockport = pcall(conn.getsockname, conn);
+ if ok then
+ self.sockname, self.sockport = sockname, sockport;
+ end
+end
+
+-- A server interface has new incoming connections waiting
+-- This replaces the onreadable callback
+function interface:onacceptable()
+ local conn, err = self.conn:accept();
+ if not conn then
+ log("debug", "Error accepting new client: %s, server will be paused for %ds", err, cfg.accept_retry_interval);
+ self:pausefor(cfg.accept_retry_interval);
+ return;
+ end
+ local client = wrapsocket(conn, self, nil, self.listeners);
+ log("debug", "New connection %s", tostring(client));
+ client:init();
+ if self.tls_direct then
+ client:starttls(self.tls_ctx);
+ end
+end
+
+-- Initialization
+function interface:init()
+ self:setwritetimeout();
+ return self:add(true, true);
+end
+
+function interface:pause()
+ return self:set(false);
+end
+
+function interface:resume()
+ return self:set(true);
+end
+
+-- Pause connection for some time
+function interface:pausefor(t)
+ if self._pausefor then
+ self._pausefor:close();
+ end
+ if t == false then return; end
+ self:set(false);
+ self._pausefor = addtimer(t, function ()
+ self._pausefor = nil;
+ if self.conn and self.conn:dirty() then
+ self:onreadable();
+ end
+ self:set(true);
+ end);
+end
+
+-- Connected!
+function interface:onconnect()
+ if self.conn and not self.peername and self.conn.getpeername then
+ self.peername, self.peerport = self.conn:getpeername();
+ end
+ self.onconnect = noop;
+ self:on("connect");
+end
+
+local function addserver(addr, port, listeners, read_size, tls_ctx)
+ local conn, err = socket.bind(addr, port, cfg.tcp_backlog);
+ if not conn then return conn, err; end
+ conn:settimeout(0);
+ local server = setmetatable({
+ conn = conn;
+ created = gettime();
+ listeners = listeners;
+ read_size = read_size;
+ onreadable = interface.onacceptable;
+ tls_ctx = tls_ctx;
+ tls_direct = tls_ctx and true or false;
+ sockname = addr;
+ sockport = port;
+ }, interface_mt);
+ server:add(true, false);
+ return server;
+end
+
+-- COMPAT
+local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx)
+ local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx);
+ if not client.peername then
+ client.peername, client.peerport = addr, port;
+ end
+ local ok, err = client:init();
+ if not ok then return ok, err; end
+ if tls_ctx then
+ client:starttls(tls_ctx);
+ end
+ return client;
+end
+
+-- New outgoing TCP connection
+local function addclient(addr, port, listeners, read_size, tls_ctx, typ)
+ local create;
+ if not typ then
+ local n = inet_pton(addr);
+ if not n then return nil, "invalid-ip"; end
+ if #n == 16 then
+ typ = "tcp6";
+ else
+ typ = "tcp4";
+ end
+ end
+ if typ then
+ create = socket[typ];
+ end
+ if type(create) ~= "function" then
+ return nil, "invalid socket type";
+ end
+ local conn, err = create();
+ local ok, err = conn:settimeout(0);
+ if not ok then return ok, err; end
+ local ok, err = conn:setpeername(addr, port);
+ if not ok and err ~= "timeout" then return ok, err; end
+ local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx)
+ local ok, err = client:init();
+ if not ok then return ok, err; end
+ if tls_ctx then
+ client:starttls(tls_ctx);
+ end
+ return client, conn;
+end
+
+local function watchfd(fd, onreadable, onwritable)
+ local conn = setmetatable({
+ conn = fd;
+ onreadable = onreadable;
+ onwritable = onwritable;
+ close = function (self)
+ self:del();
+ end
+ }, interface_mt);
+ if type(fd) == "number" then
+ conn.getfd = function ()
+ return fd;
+ end;
+ -- Otherwise it'll need to be something LuaSocket-compatible
+ end
+ conn:add(onreadable, onwritable);
+ return conn;
+end;
+
+-- Dump all data from one connection into another
+local function link(from, to)
+ from.listeners = setmetatable({
+ onincoming = function (_, data)
+ from:pause();
+ to:write(data);
+ end,
+ }, {__index=from.listeners});
+ to.listeners = setmetatable({
+ ondrain = function ()
+ from:resume();
+ end,
+ }, {__index=to.listeners});
+ from:set(true, nil);
+ to:set(nil, true);
+end
+
+-- COMPAT
+-- net.adns calls this but then replaces :send so this can be a noop
+function interface:set_send(new_send) -- luacheck: ignore 212
+end
+
+-- Close all connections and servers
+local function closeall()
+ for fd, conn in pairs(fds) do -- luacheck: ignore 213/fd
+ conn:close();
+ end
+end
+
+local quitting = nil;
+
+-- Signal main loop about shutdown via above upvalue
+local function setquitting(quit)
+ if quit then
+ quitting = "quitting";
+ closeall();
+ else
+ quitting = nil;
+ end
+end
+
+-- Main loop
+local function loop(once)
+ repeat
+ local t = runtimers(cfg.max_wait, cfg.min_wait);
+ local fd, r, w = poll:wait(t);
+ if fd then
+ local conn = fds[fd];
+ if conn then
+ if r then
+ conn:onreadable();
+ end
+ if w then
+ conn:onwritable();
+ end
+ else
+ log("debug", "Removing unknown fd %d", fd);
+ poll:del(fd);
+ end
+ elseif r ~= "timeout" then
+ log("debug", "epoll_wait error: %s[%d]", r, w);
+ end
+ until once or (quitting and next(fds) == nil);
+ return quitting;
+end
+
+return {
+ get_backend = function () return "epoll"; end;
+ addserver = addserver;
+ addclient = addclient;
+ add_task = addtimer;
+ at = at;
+ loop = loop;
+ closeall = closeall;
+ setquitting = setquitting;
+ wrapclient = wrapclient;
+ watchfd = watchfd;
+ link = link;
+ set_config = function (newconfig)
+ cfg = setmetatable(newconfig, default_config);
+ end;
+
+ -- libevent emulation
+ event = { EV_READ = "r", EV_WRITE = "w", EV_READWRITE = "rw", EV_LEAVE = -1 };
+ addevent = function (fd, mode, callback)
+ local function onevent(self)
+ local ret = self:callback();
+ if ret == -1 then
+ self:set(false, false);
+ elseif ret then
+ self:set(mode == "r" or mode == "rw", mode == "w" or mode == "rw");
+ end
+ end
+
+ local conn = setmetatable({
+ getfd = function () return fd; end;
+ callback = callback;
+ onreadable = onevent;
+ onwritable = onevent;
+ close = function (self)
+ self:del();
+ fds[fd] = nil;
+ end;
+ }, interface_mt);
+ local ok, err = conn:add(mode == "r" or mode == "rw", mode == "w" or mode == "rw");
+ if not ok then return ok, err; end
+ return conn;
+ end;
+};