From aaa28d9ab043d4f6b0f07f659f2ce5747db88486 Mon Sep 17 00:00:00 2001 From: Kim Alvefur Date: Wed, 16 May 2018 23:57:09 +0200 Subject: net.server_epoll: Use util.poll --- net/server_epoll.lua | 143 ++++++++++++++++++++++++++------------------------- 1 file changed, 74 insertions(+), 69 deletions(-) diff --git a/net/server_epoll.lua b/net/server_epoll.lua index d556bf37..2bc7c658 100644 --- a/net/server_epoll.lua +++ b/net/server_epoll.lua @@ -5,8 +5,6 @@ -- COPYING file in the source package for more information. -- --- server_epoll --- Server backend based on https://luarocks.org/modules/zash/lua-epoll local t_sort = table.sort; local t_insert = table.insert; @@ -19,14 +17,13 @@ local type = type; local next = next; local pairs = pairs; local log = require "util.logger".init("server_epoll"); -local epoll = require "epoll"; local socket = require "socket"; local luasec = require "ssl"; local gettime = require "util.time".now; local createtable = require "util.table".create; local _SOCKETINVALID = socket._SOCKETINVALID or -1; -assert(socket.tcp6 and socket.tcp4, "Incompatible LuaSocket version"); +local poll = require "util.poll".new(); local _ENV = nil; -- luacheck: std none @@ -260,48 +257,56 @@ function interface:setwritetimeout(t) end end --- lua-epoll flag for currently requested poll state -function interface:flags() - if self._wantread then - if self._wantwrite then - return "rw"; - end - return "r"; - elseif self._wantwrite then - return "w"; +function interface:add(r, w) + local fd = self:getfd(); + if fd < 0 then + return nil, "invalid fd"; + end + if r == nil then r = self._wantread; end + if w == nil then w = self._wantwrite; end + local ok, err = poll:add(fd, r, w); + if not ok then + log("error", "Could not register %s: %s", self, err); + return ok, err; end + self._wantread, self._wantwrite = r, w; + fds[fd] = self; + log("debug", "Registered %s", self); + return true; end --- Add or remove sockets or modify epoll flags -function interface:setflags(r, w) - if r ~= nil then self._wantread = r; end - if w ~= nil then self._wantwrite = w; end - local flags = self:flags(); - local currentflags = self._flags; - if flags == currentflags then - return true; +function interface:set(r, w) + local fd = self:getfd(); + if fd < 0 then + return nil, "invalid fd"; end + if r == nil then r = self._wantread; end + if w == nil then w = self._wantwrite; end + local ok, err = poll:set(fd, r, w); + if not ok then + log("error", "Could not update poller state %s: %s", self, err); + return ok, err; + end + self._wantread, self._wantwrite = r, w; + return true; +end + +function interface:del() local fd = self:getfd(); if fd < 0 then - self._wantread, self._wantwrite = nil, nil; return nil, "invalid fd"; end - local op = "mod"; - if not flags then - op = "del"; - elseif not currentflags then - op = "add"; - end - local ok, err = epoll.ctl(op, fd, flags); --- log("debug", "epoll_ctl(%q, %d, %q) -> %s" .. (err and ", %q" or ""), --- op, fd, flags or "", tostring(ok), err); - if not ok then return ok, err end - if op == "add" then - fds[fd] = self; - elseif op == "del" then - fds[fd] = nil; - end - self._flags = flags; + if fds[fd] ~= self then + return nil, "unregistered fd"; + end + local ok, err = poll:del(fd); + if not ok then + log("error", "Could not unregister %s: %s", self, err); + return ok, err; + end + self._wantread, self._wantwrite = nil, nil; + fds[fd] = nil; + log("debug", "Unregistered %s", self); return true; end @@ -317,9 +322,9 @@ function interface:onreadable() self:on("incoming", partial, err); end if err == "wantread" then - self:setflags(true, nil); + self:set(true, nil); elseif err == "wantwrite" then - self:setflags(nil, true); + self:set(nil, true); elseif err ~= "timeout" then self:on("disconnect", err); self:destroy() @@ -343,7 +348,7 @@ function interface:onwritable() local data = t_concat(buffer); local ok, err, partial = self.conn:send(data); if ok then - self:setflags(nil, false); + self:set(nil, false); for i = #buffer, 1, -1 do buffer[i] = nil; end @@ -358,9 +363,9 @@ function interface:onwritable() self:setwritetimeout(); end if err == "wantwrite" or err == "timeout" then - self:setflags(nil, true); + self:set(nil, true); elseif err == "wantread" then - self:setflags(true, nil); + self:set(true, nil); elseif err ~= "timeout" then self:on("disconnect", err); self:destroy(); @@ -381,7 +386,7 @@ function interface:write(data) self.writebuffer = { data }; end self:setwritetimeout(); - self:setflags(nil, true); + self:set(nil, true); return #data; end interface.send = interface.write; @@ -389,7 +394,7 @@ interface.send = interface.write; -- Close, possibly after writing is done function interface:close() if self.writebuffer and self.writebuffer[1] then - self:setflags(false, true); -- Flush final buffer contents + self:set(false, true); -- Flush final buffer contents self.write, self.send = noop, noop; -- No more writing log("debug", "Close %s after writing", self); self.ondrain = interface.close; @@ -403,7 +408,7 @@ function interface:close() end function interface:destroy() - self:setflags(false, false); + self:del(); self:setwritetimeout(false); self:setreadtimeout(false); self.onreadable = noop; @@ -425,10 +430,10 @@ function interface:starttls(tls_ctx) log("debug", "Start TLS on %s after write", self); self.ondrain = interface.starttls; self.starttls = false; - self:setflags(nil, true); -- make sure wantwrite is set + self:set(nil, true); -- make sure wantwrite is set else log("debug", "Start TLS on %s now", self); - self:setflags(false, false); + self:del(); local conn, err = luasec.wrap(self.conn, tls_ctx or self.tls_ctx); if not conn then self:on("disconnect", err); @@ -440,8 +445,7 @@ function interface:starttls(tls_ctx) self.ondrain = nil; self.onwritable = interface.tlshandskake; self.onreadable = interface.tlshandskake; - self:setflags(true, true); - self:setwritetimeout(cfg.handshake_timeout); + return self:init(); end end @@ -455,14 +459,15 @@ function interface:tlshandskake() self.onreadable = nil; self._tls = true; self:on("status", "ssl-handshake-complete"); - self:init(); + self:setwritetimeout(); + self:set(true, true); elseif err == "wantread" then log("debug", "TLS handshake on %s to wait until readable", self); - self:setflags(true, false); + self:set(true, false); self:setreadtimeout(cfg.handshake_timeout); elseif err == "wantwrite" then log("debug", "TLS handshake on %s to wait until writable", self); - self:setflags(false, true); + self:set(false, true); self:setwritetimeout(cfg.handshake_timeout); else log("debug", "TLS handshake error on %s: %s", self, err); @@ -513,15 +518,15 @@ end -- Initialization function interface:init() self:setwritetimeout(); - return self:setflags(true, true); + return self:add(true, true); end function interface:pause() - return self:setflags(false); + return self:set(false); end function interface:resume() - return self:setflags(true); + return self:set(true); end -- Pause connection for some time @@ -530,13 +535,13 @@ function interface:pausefor(t) self._pausefor:close(); end if t == false then return; end - self:setflags(false); + self:set(false); self._pausefor = addtimer(t, function () self._pausefor = nil; if self.conn and self.conn:dirty() then self:onreadable(); end - self:setflags(true); + self:set(true); end); end @@ -564,7 +569,7 @@ local function addserver(addr, port, listeners, read_size, tls_ctx) sockname = addr; sockport = port; }, interface_mt); - server:setflags(true, false); + server:add(true, false); return server; end @@ -603,7 +608,7 @@ local function watchfd(fd, onreadable, onwriteable) onreadable = onreadable; onwriteable = onwriteable; close = function (self) - self:setflags(false, false); + self:del(); end }, interface_mt); if type(fd) == "number" then @@ -612,7 +617,7 @@ local function watchfd(fd, onreadable, onwriteable) end; -- Otherwise it'll need to be something LuaSocket-compatible end - conn:setflags(onreadable, onwriteable); + conn:add(onreadable, onwriteable); return conn; end; @@ -629,8 +634,8 @@ local function link(from, to) from:resume(); end, }, {__index=to.listeners}); - from:setflags(true, nil); - to:setflags(nil, true); + from:set(true, nil); + to:set(nil, true); end -- XXX What uses this? @@ -662,7 +667,7 @@ end local function loop(once) repeat local t = runtimers(cfg.max_wait, cfg.min_wait); - local fd, r, w = epoll.wait(t); + local fd, r, w = poll:wait(t); if fd then local conn = fds[fd]; if conn then @@ -674,7 +679,7 @@ local function loop(once) end else log("debug", "Removing unknown fd %d", fd); - epoll.ctl("del", fd); + poll:del(fd); end elseif r ~= "timeout" then log("debug", "epoll_wait error: %s", tostring(r)); @@ -705,9 +710,9 @@ return { local function onevent(self) local ret = self:callback(); if ret == -1 then - self:setflags(false, false); + self:set(false, false); elseif ret then - self:setflags(mode == "r" or mode == "rw", mode == "w" or mode == "rw"); + self:set(mode == "r" or mode == "rw", mode == "w" or mode == "rw"); end end @@ -717,11 +722,11 @@ return { onreadable = onevent; onwritable = onevent; close = function (self) - self:setflags(false, false); + self:del(); fds[fd] = nil; end; }, interface_mt); - local ok, err = conn:setflags(mode == "r" or mode == "rw", mode == "w" or mode == "rw"); + local ok, err = conn:add(mode == "r" or mode == "rw", mode == "w" or mode == "rw"); if not ok then return ok, err; end return conn; end; -- cgit v1.2.3