aboutsummaryrefslogtreecommitdiffstats
path: root/net/server_epoll.lua
diff options
context:
space:
mode:
authorKim Alvefur <zash@zash.se>2018-05-16 23:57:09 +0200
committerKim Alvefur <zash@zash.se>2018-05-16 23:57:09 +0200
commitfa6dce05f891682dc5c10a6e2da4f0d4eb65ec64 (patch)
tree4a78979ef92b7a1935e9efe9b6dda2fa6d3eb658 /net/server_epoll.lua
parentfa4507823faa2887120b24c897456abf42ed3a6e (diff)
downloadprosody-fa6dce05f891682dc5c10a6e2da4f0d4eb65ec64.tar.gz
prosody-fa6dce05f891682dc5c10a6e2da4f0d4eb65ec64.zip
net.server_epoll: Use util.poll
Diffstat (limited to 'net/server_epoll.lua')
-rw-r--r--net/server_epoll.lua143
1 files changed, 74 insertions, 69 deletions
diff --git a/net/server_epoll.lua b/net/server_epoll.lua
index d556bf37..2bc7c658 100644
--- a/net/server_epoll.lua
+++ b/net/server_epoll.lua
@@ -5,8 +5,6 @@
-- COPYING file in the source package for more information.
--
--- server_epoll
--- Server backend based on https://luarocks.org/modules/zash/lua-epoll
local t_sort = table.sort;
local t_insert = table.insert;
@@ -19,14 +17,13 @@ local type = type;
local next = next;
local pairs = pairs;
local log = require "util.logger".init("server_epoll");
-local epoll = require "epoll";
local socket = require "socket";
local luasec = require "ssl";
local gettime = require "util.time".now;
local createtable = require "util.table".create;
local _SOCKETINVALID = socket._SOCKETINVALID or -1;
-assert(socket.tcp6 and socket.tcp4, "Incompatible LuaSocket version");
+local poll = require "util.poll".new();
local _ENV = nil;
-- luacheck: std none
@@ -260,48 +257,56 @@ function interface:setwritetimeout(t)
end
end
--- lua-epoll flag for currently requested poll state
-function interface:flags()
- if self._wantread then
- if self._wantwrite then
- return "rw";
- end
- return "r";
- elseif self._wantwrite then
- return "w";
+function interface:add(r, w)
+ local fd = self:getfd();
+ if fd < 0 then
+ return nil, "invalid fd";
+ end
+ if r == nil then r = self._wantread; end
+ if w == nil then w = self._wantwrite; end
+ local ok, err = poll:add(fd, r, w);
+ if not ok then
+ log("error", "Could not register %s: %s", self, err);
+ return ok, err;
end
+ self._wantread, self._wantwrite = r, w;
+ fds[fd] = self;
+ log("debug", "Registered %s", self);
+ return true;
end
--- Add or remove sockets or modify epoll flags
-function interface:setflags(r, w)
- if r ~= nil then self._wantread = r; end
- if w ~= nil then self._wantwrite = w; end
- local flags = self:flags();
- local currentflags = self._flags;
- if flags == currentflags then
- return true;
+function interface:set(r, w)
+ local fd = self:getfd();
+ if fd < 0 then
+ return nil, "invalid fd";
end
+ if r == nil then r = self._wantread; end
+ if w == nil then w = self._wantwrite; end
+ local ok, err = poll:set(fd, r, w);
+ if not ok then
+ log("error", "Could not update poller state %s: %s", self, err);
+ return ok, err;
+ end
+ self._wantread, self._wantwrite = r, w;
+ return true;
+end
+
+function interface:del()
local fd = self:getfd();
if fd < 0 then
- self._wantread, self._wantwrite = nil, nil;
return nil, "invalid fd";
end
- local op = "mod";
- if not flags then
- op = "del";
- elseif not currentflags then
- op = "add";
- end
- local ok, err = epoll.ctl(op, fd, flags);
--- log("debug", "epoll_ctl(%q, %d, %q) -> %s" .. (err and ", %q" or ""),
--- op, fd, flags or "", tostring(ok), err);
- if not ok then return ok, err end
- if op == "add" then
- fds[fd] = self;
- elseif op == "del" then
- fds[fd] = nil;
- end
- self._flags = flags;
+ if fds[fd] ~= self then
+ return nil, "unregistered fd";
+ end
+ local ok, err = poll:del(fd);
+ if not ok then
+ log("error", "Could not unregister %s: %s", self, err);
+ return ok, err;
+ end
+ self._wantread, self._wantwrite = nil, nil;
+ fds[fd] = nil;
+ log("debug", "Unregistered %s", self);
return true;
end
@@ -317,9 +322,9 @@ function interface:onreadable()
self:on("incoming", partial, err);
end
if err == "wantread" then
- self:setflags(true, nil);
+ self:set(true, nil);
elseif err == "wantwrite" then
- self:setflags(nil, true);
+ self:set(nil, true);
elseif err ~= "timeout" then
self:on("disconnect", err);
self:destroy()
@@ -343,7 +348,7 @@ function interface:onwritable()
local data = t_concat(buffer);
local ok, err, partial = self.conn:send(data);
if ok then
- self:setflags(nil, false);
+ self:set(nil, false);
for i = #buffer, 1, -1 do
buffer[i] = nil;
end
@@ -358,9 +363,9 @@ function interface:onwritable()
self:setwritetimeout();
end
if err == "wantwrite" or err == "timeout" then
- self:setflags(nil, true);
+ self:set(nil, true);
elseif err == "wantread" then
- self:setflags(true, nil);
+ self:set(true, nil);
elseif err ~= "timeout" then
self:on("disconnect", err);
self:destroy();
@@ -381,7 +386,7 @@ function interface:write(data)
self.writebuffer = { data };
end
self:setwritetimeout();
- self:setflags(nil, true);
+ self:set(nil, true);
return #data;
end
interface.send = interface.write;
@@ -389,7 +394,7 @@ interface.send = interface.write;
-- Close, possibly after writing is done
function interface:close()
if self.writebuffer and self.writebuffer[1] then
- self:setflags(false, true); -- Flush final buffer contents
+ self:set(false, true); -- Flush final buffer contents
self.write, self.send = noop, noop; -- No more writing
log("debug", "Close %s after writing", self);
self.ondrain = interface.close;
@@ -403,7 +408,7 @@ function interface:close()
end
function interface:destroy()
- self:setflags(false, false);
+ self:del();
self:setwritetimeout(false);
self:setreadtimeout(false);
self.onreadable = noop;
@@ -425,10 +430,10 @@ function interface:starttls(tls_ctx)
log("debug", "Start TLS on %s after write", self);
self.ondrain = interface.starttls;
self.starttls = false;
- self:setflags(nil, true); -- make sure wantwrite is set
+ self:set(nil, true); -- make sure wantwrite is set
else
log("debug", "Start TLS on %s now", self);
- self:setflags(false, false);
+ self:del();
local conn, err = luasec.wrap(self.conn, tls_ctx or self.tls_ctx);
if not conn then
self:on("disconnect", err);
@@ -440,8 +445,7 @@ function interface:starttls(tls_ctx)
self.ondrain = nil;
self.onwritable = interface.tlshandskake;
self.onreadable = interface.tlshandskake;
- self:setflags(true, true);
- self:setwritetimeout(cfg.handshake_timeout);
+ return self:init();
end
end
@@ -455,14 +459,15 @@ function interface:tlshandskake()
self.onreadable = nil;
self._tls = true;
self:on("status", "ssl-handshake-complete");
- self:init();
+ self:setwritetimeout();
+ self:set(true, true);
elseif err == "wantread" then
log("debug", "TLS handshake on %s to wait until readable", self);
- self:setflags(true, false);
+ self:set(true, false);
self:setreadtimeout(cfg.handshake_timeout);
elseif err == "wantwrite" then
log("debug", "TLS handshake on %s to wait until writable", self);
- self:setflags(false, true);
+ self:set(false, true);
self:setwritetimeout(cfg.handshake_timeout);
else
log("debug", "TLS handshake error on %s: %s", self, err);
@@ -513,15 +518,15 @@ end
-- Initialization
function interface:init()
self:setwritetimeout();
- return self:setflags(true, true);
+ return self:add(true, true);
end
function interface:pause()
- return self:setflags(false);
+ return self:set(false);
end
function interface:resume()
- return self:setflags(true);
+ return self:set(true);
end
-- Pause connection for some time
@@ -530,13 +535,13 @@ function interface:pausefor(t)
self._pausefor:close();
end
if t == false then return; end
- self:setflags(false);
+ self:set(false);
self._pausefor = addtimer(t, function ()
self._pausefor = nil;
if self.conn and self.conn:dirty() then
self:onreadable();
end
- self:setflags(true);
+ self:set(true);
end);
end
@@ -564,7 +569,7 @@ local function addserver(addr, port, listeners, read_size, tls_ctx)
sockname = addr;
sockport = port;
}, interface_mt);
- server:setflags(true, false);
+ server:add(true, false);
return server;
end
@@ -603,7 +608,7 @@ local function watchfd(fd, onreadable, onwriteable)
onreadable = onreadable;
onwriteable = onwriteable;
close = function (self)
- self:setflags(false, false);
+ self:del();
end
}, interface_mt);
if type(fd) == "number" then
@@ -612,7 +617,7 @@ local function watchfd(fd, onreadable, onwriteable)
end;
-- Otherwise it'll need to be something LuaSocket-compatible
end
- conn:setflags(onreadable, onwriteable);
+ conn:add(onreadable, onwriteable);
return conn;
end;
@@ -629,8 +634,8 @@ local function link(from, to)
from:resume();
end,
}, {__index=to.listeners});
- from:setflags(true, nil);
- to:setflags(nil, true);
+ from:set(true, nil);
+ to:set(nil, true);
end
-- XXX What uses this?
@@ -662,7 +667,7 @@ end
local function loop(once)
repeat
local t = runtimers(cfg.max_wait, cfg.min_wait);
- local fd, r, w = epoll.wait(t);
+ local fd, r, w = poll:wait(t);
if fd then
local conn = fds[fd];
if conn then
@@ -674,7 +679,7 @@ local function loop(once)
end
else
log("debug", "Removing unknown fd %d", fd);
- epoll.ctl("del", fd);
+ poll:del(fd);
end
elseif r ~= "timeout" then
log("debug", "epoll_wait error: %s", tostring(r));
@@ -705,9 +710,9 @@ return {
local function onevent(self)
local ret = self:callback();
if ret == -1 then
- self:setflags(false, false);
+ self:set(false, false);
elseif ret then
- self:setflags(mode == "r" or mode == "rw", mode == "w" or mode == "rw");
+ self:set(mode == "r" or mode == "rw", mode == "w" or mode == "rw");
end
end
@@ -717,11 +722,11 @@ return {
onreadable = onevent;
onwritable = onevent;
close = function (self)
- self:setflags(false, false);
+ self:del();
fds[fd] = nil;
end;
}, interface_mt);
- local ok, err = conn:setflags(mode == "r" or mode == "rw", mode == "w" or mode == "rw");
+ local ok, err = conn:add(mode == "r" or mode == "rw", mode == "w" or mode == "rw");
if not ok then return ok, err; end
return conn;
end;