diff options
Diffstat (limited to 'net/server_epoll.lua')
-rw-r--r-- | net/server_epoll.lua | 245 |
1 files changed, 187 insertions, 58 deletions
diff --git a/net/server_epoll.lua b/net/server_epoll.lua index 0c03ae15..d44149f3 100644 --- a/net/server_epoll.lua +++ b/net/server_epoll.lua @@ -9,12 +9,12 @@ local t_insert = table.insert; local t_concat = table.concat; local setmetatable = setmetatable; -local tostring = tostring; local pcall = pcall; local type = type; local next = next; local pairs = pairs; -local log = require "util.logger".init("server_epoll"); +local logger = require "util.logger"; +local log = logger.init("server_epoll"); local socket = require "socket"; local luasec = require "ssl"; local gettime = require "util.time".now; @@ -23,6 +23,7 @@ local createtable = require "util.table".create; local inet = require "util.net"; local inet_pton = inet.pton; local _SOCKETINVALID = socket._SOCKETINVALID or -1; +local new_id = require "util.id".medium; local poller = require "util.poll" local EEXIST = poller.EEXIST; @@ -38,7 +39,10 @@ local default_config = { __index = { read_timeout = 14 * 60; -- How long to wait for a socket to become writable after queuing data to send - send_timeout = 60; + send_timeout = 180; + + -- How long to wait for a socket to become writable after creation + connect_timeout = 20; -- Some number possibly influencing how many pending connections can be accepted tcp_backlog = 128; @@ -58,6 +62,13 @@ local default_config = { __index = { -- Maximum and minimum amount of time to sleep waiting for events (adjusted for pending timers) max_wait = 86400; min_wait = 1e-06; + + -- EXPERIMENTAL + -- Whether to kill connections in case of callback errors. + fatal_errors = false; + + -- Attempt writes instantly + opportunistic_writes = false; }}; local cfg = default_config.__index; @@ -102,7 +113,7 @@ local function runtimers(next_delay, min_wait) if peek > now then next_delay = peek - now; break; - end + end local _, timer, id = timers:pop(); local ok, ret = pcall(timer[2], now); @@ -110,10 +121,10 @@ local function runtimers(next_delay, min_wait) local next_time = now+ret; timer[1] = next_time; timers:insert(timer, next_time); - end + end peek = timers:peek(); - end + end if peek == nil then return next_delay; end @@ -138,6 +149,15 @@ function interface_mt:__tostring() return ("FD %d"):format(self:getfd()); end +interface.log = log; +function interface:debug(msg, ...) --luacheck: ignore 212/self + self.log("debug", msg, ...); +end + +function interface:error(msg, ...) --luacheck: ignore 212/self + self.log("error", msg, ...); +end + -- Replace the listener and tell the old one function interface:setlistener(listeners, data) self:on("detach"); @@ -148,21 +168,32 @@ end -- Call a listener callback function interface:on(what, ...) if not self.listeners then - log("error", "%s has no listeners", self); + self:debug("Interface is missing listener callbacks"); return; end local listener = self.listeners["on"..what]; if not listener then - -- log("debug", "Missing listener 'on%s'", what); -- uncomment for development and debugging + -- self:debug("Missing listener 'on%s'", what); -- uncomment for development and debugging return; end local ok, err = pcall(listener, self, ...); if not ok then - log("error", "Error calling on%s: %s", what, err); + if cfg.fatal_errors then + self:debug("Closing due to error calling on%s: %s", what, err); + self:destroy(); + else + self:debug("Error calling on%s: %s", what, err); + end + return nil, err; end return err; end +-- Allow this one to be overridden +function interface:onincoming(...) + return self:on("incoming", ...); +end + -- Return the file descriptor number function interface:getfd() if self.conn then @@ -230,8 +261,10 @@ function interface:setreadtimeout(t) else self._readtimeout = addtimer(t, function () if self:on("readtimeout") then + self:debug("Read timeout handled"); return cfg.read_timeout; else + self:debug("Read timeout not handled, disconnecting"); self:on("disconnect", "read timeout"); self:destroy(); end @@ -253,6 +286,7 @@ function interface:setwritetimeout(t) self._writetimeout:reschedule(gettime() + t); else self._writetimeout = addtimer(t, function () + self:debug("Write timeout"); self:on("disconnect", "write timeout"); self:destroy(); end); @@ -269,15 +303,15 @@ function interface:add(r, w) local ok, err, errno = poll:add(fd, r, w); if not ok then if errno == EEXIST then - log("debug", "%s already registered!", self); + self:debug("FD already registered in poller! (EEXIST)"); return self:set(r, w); -- So try to change its flags end - log("error", "Could not register %s: %s(%d)", self, err, errno); + self:debug("Could not register in poller: %s(%d)", err, errno); return ok, err; end self._wantread, self._wantwrite = r, w; fds[fd] = self; - log("debug", "Watching %s", self); + self:debug("Registered in poller"); return true; end @@ -290,7 +324,7 @@ function interface:set(r, w) if w == nil then w = self._wantwrite; end local ok, err, errno = poll:set(fd, r, w); if not ok then - log("error", "Could not update poller state %s: %s(%d)", self, err, errno); + self:debug("Could not update poller state: %s(%d)", err, errno); return ok, err; end self._wantread, self._wantwrite = r, w; @@ -307,12 +341,12 @@ function interface:del() end local ok, err, errno = poll:del(fd); if not ok and errno ~= ENOENT then - log("error", "Could not unregister %s: %s(%d)", self, err, errno); + self:debug("Could not unregister: %s(%d)", err, errno); return ok, err; end self._wantread, self._wantwrite = nil, nil; fds[fd] = nil; - log("debug", "Unwatched %s", self); + self:debug("Unregistered from poller"); return true; end @@ -334,7 +368,7 @@ function interface:onreadable() local data, err, partial = self.conn:receive(self.read_size or cfg.read_size); if data then self:onconnect(); - self:on("incoming", data); + self:onincoming(data); else if err == "wantread" then self:set(true, nil); @@ -345,7 +379,7 @@ function interface:onreadable() end if partial and partial ~= "" then self:onconnect(); - self:on("incoming", partial, err); + self:onincoming(partial, err); end if err ~= "timeout" then self:on("disconnect", err); @@ -354,6 +388,14 @@ function interface:onreadable() end end if not self.conn then return; end + if self._limit and (data or partial) then + local cost = self._limit * #(data or partial); + if cost > cfg.min_wait then + self:setreadtimeout(false); + self:pausefor(cost); + return; + end + end if self._wantread and self.conn:dirty() then self:setreadtimeout(false); self:pausefor(cfg.read_retry_delay); @@ -378,10 +420,12 @@ function interface:onwritable() self:ondrain(); -- Be aware of writes in ondrain return; elseif partial then + self:debug("Sent %d out of %d buffered bytes", partial, #data); buffer[1] = data:sub(partial+1); for i = #buffer, 2, -1 do buffer[i] = nil; end + self:set(nil, true); self:setwritetimeout(); end if err == "wantwrite" or err == "timeout" then @@ -407,8 +451,14 @@ function interface:write(data) else self.writebuffer = { data }; end - self:setwritetimeout(); - self:set(nil, true); + if not self._write_lock then + if cfg.opportunistic_writes then + self:onwritable(); + return #data; + end + self:setwritetimeout(); + self:set(nil, true); + end return #data; end interface.send = interface.write; @@ -418,10 +468,10 @@ function interface:close() if self.writebuffer and self.writebuffer[1] then self:set(false, true); -- Flush final buffer contents self.write, self.send = noop, noop; -- No more writing - log("debug", "Close %s after writing", self); + self:debug("Close after writing remaining buffered data"); self.ondrain = interface.close; else - log("debug", "Close %s now", self); + self:debug("Closing now"); self.write, self.send = noop, noop; self.close = noop; self:on("disconnect"); @@ -450,7 +500,7 @@ function interface:starttls(tls_ctx) if tls_ctx then self.tls_ctx = tls_ctx; end self.starttls = false; if self.writebuffer and self.writebuffer[1] then - log("debug", "Start TLS on %s after write", self); + self:debug("Start TLS after write"); self.ondrain = interface.starttls; self:set(nil, true); -- make sure wantwrite is set else @@ -460,7 +510,7 @@ function interface:starttls(tls_ctx) self.onwritable = interface.tlshandskake; self.onreadable = interface.tlshandskake; self:set(true, true); - log("debug", "Prepare to start TLS on %s", self); + self:debug("Prepared to start TLS"); end end @@ -469,12 +519,13 @@ function interface:tlshandskake() self:setreadtimeout(false); if not self._tls then self._tls = true; - log("debug", "Start TLS on %s now", self); + self:debug("Starting TLS now"); self:del(); + self:updatenames(); -- Can't getpeer/sockname after wrap() local ok, conn, err = pcall(luasec.wrap, self.conn, self.tls_ctx); if not ok then conn, err = ok, conn; - log("error", "Failed to initialize TLS: %s", err); + self:debug("Failed to initialize TLS: %s", err); end if not conn then self:on("disconnect", err); @@ -483,6 +534,13 @@ function interface:tlshandskake() end conn:settimeout(0); self.conn = conn; + if conn.sni then + if self.servername then + conn:sni(self.servername); + elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then + conn:sni(self._server.hosts, true); + end + end self:on("starttls"); self.ondrain = nil; self.onwritable = interface.tlshandskake; @@ -491,29 +549,35 @@ function interface:tlshandskake() end local ok, err = self.conn:dohandshake(); if ok then - log("debug", "TLS handshake on %s complete", self); + local info = self.conn.info and self.conn:info(); + if type(info) == "table" then + self:debug("TLS handshake complete (%s with %s)", info.protocol, info.cipher); + else + self:debug("TLS handshake complete"); + end self.onwritable = nil; self.onreadable = nil; self:on("status", "ssl-handshake-complete"); self:setwritetimeout(); self:set(true, true); elseif err == "wantread" then - log("debug", "TLS handshake on %s to wait until readable", self); + self:debug("TLS handshake to wait until readable"); self:set(true, false); self:setreadtimeout(cfg.ssl_handshake_timeout); elseif err == "wantwrite" then - log("debug", "TLS handshake on %s to wait until writable", self); + self:debug("TLS handshake to wait until writable"); self:set(false, true); self:setwritetimeout(cfg.ssl_handshake_timeout); else - log("debug", "TLS handshake error on %s: %s", self, err); + self:debug("TLS handshake error: %s", err); self:on("disconnect", err); self:destroy(); end end -local function wrapsocket(client, server, read_size, listeners, tls_ctx) -- luasocket object -> interface object +local function wrapsocket(client, server, read_size, listeners, tls_ctx, extra) -- luasocket object -> interface object client:settimeout(0); + local conn_id = ("conn%s"):format(new_id()); local conn = setmetatable({ conn = client; _server = server; @@ -523,8 +587,17 @@ local function wrapsocket(client, server, read_size, listeners, tls_ctx) -- luas writebuffer = {}; tls_ctx = tls_ctx or (server and server.tls_ctx); tls_direct = server and server.tls_direct; + id = conn_id; + log = logger.init(conn_id); + extra = extra; }, interface_mt); + if extra then + if extra.servername then + conn.servername = extra.servername; + end + end + conn:updatenames(); return conn; end @@ -532,11 +605,11 @@ end function interface:updatenames() local conn = self.conn; local ok, peername, peerport = pcall(conn.getpeername, conn); - if ok then + if ok and peername then self.peername, self.peerport = peername, peerport; end local ok, sockname, sockport = pcall(conn.getsockname, conn); - if ok then + if ok and sockname then self.sockname, self.sockport = sockname, sockport; end end @@ -546,34 +619,39 @@ end function interface:onacceptable() local conn, err = self.conn:accept(); if not conn then - log("debug", "Error accepting new client: %s, server will be paused for %ds", err, cfg.accept_retry_interval); + self:debug("Error accepting new client: %s, server will be paused for %ds", err, cfg.accept_retry_interval); self:pausefor(cfg.accept_retry_interval); return; end local client = wrapsocket(conn, self, nil, self.listeners); - log("debug", "New connection %s", tostring(client)); + client:debug("New connection %s on server %s", client, self); client:init(); if self.tls_direct then client:starttls(self.tls_ctx); + else + client:onconnect(); end end -- Initialization function interface:init() - self:setwritetimeout(); + self:setwritetimeout(cfg.connect_timeout); return self:add(true, true); end function interface:pause() + self:debug("Pause reading"); return self:set(false); end function interface:resume() + self:debug("Resume reading"); return self:set(true); end -- Pause connection for some time function interface:pausefor(t) + self:debug("Pause for %fs", t); if self._pausefor then self._pausefor:close(); end @@ -588,16 +666,45 @@ function interface:pausefor(t) end); end +function interface:setlimit(Bps) + if Bps > 0 then + self._limit = 1/Bps; + else + self._limit = nil; + end +end + +function interface:pause_writes() + if self._write_lock then + return + end + self:debug("Pause writes"); + self._write_lock = true; + self:setwritetimeout(false); + self:set(nil, false); +end + +function interface:resume_writes() + if not self._write_lock then + return + end + self:debug("Resume writes"); + self._write_lock = nil; + if self.writebuffer[1] then + self:setwritetimeout(); + self:set(nil, true); + end +end + -- Connected! function interface:onconnect() - if self.conn and not self.peername and self.conn.getpeername then - self.peername, self.peerport = self.conn:getpeername(); - end + self:updatenames(); + self:debug("Connected (%s)", self); self.onconnect = noop; self:on("connect"); end -local function addserver(addr, port, listeners, read_size, tls_ctx) +local function listen(addr, port, listeners, config) local conn, err = socket.bind(addr, port, cfg.tcp_backlog); if not conn then return conn, err; end conn:settimeout(0); @@ -605,20 +712,32 @@ local function addserver(addr, port, listeners, read_size, tls_ctx) conn = conn; created = gettime(); listeners = listeners; - read_size = read_size; + read_size = config and config.read_size; onreadable = interface.onacceptable; - tls_ctx = tls_ctx; - tls_direct = tls_ctx and true or false; + tls_ctx = config and config.tls_ctx; + tls_direct = config and config.tls_direct; + hosts = config and config.sni_hosts; sockname = addr; sockport = port; + log = logger.init(("serv%s"):format(new_id())); }, interface_mt); + server:debug("Server %s created", server); server:add(true, false); return server; end -- COMPAT -local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx) - local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx); +local function addserver(addr, port, listeners, read_size, tls_ctx) + return listen(addr, port, listeners, { + read_size = read_size; + tls_ctx = tls_ctx; + tls_direct = tls_ctx and true or false; + }); +end + +-- COMPAT +local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx, extra) + local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra); if not client.peername then client.peername, client.peerport = addr, port; end @@ -631,7 +750,7 @@ local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx) end -- New outgoing TCP connection -local function addclient(addr, port, listeners, read_size, tls_ctx, typ) +local function addclient(addr, port, listeners, read_size, tls_ctx, typ, extra) local create; if not typ then local n = inet_pton(addr); @@ -649,13 +768,19 @@ local function addclient(addr, port, listeners, read_size, tls_ctx, typ) return nil, "invalid socket type"; end local conn, err = create(); + if not conn then return conn, err; end local ok, err = conn:settimeout(0); if not ok then return ok, err; end local ok, err = conn:setpeername(addr, port); if not ok and err ~= "timeout" then return ok, err; end - local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx) + local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra) local ok, err = client:init(); + if not client.peername then + -- otherwise not set until connected + client.peername, client.peerport = addr, port; + end if not ok then return ok, err; end + client:debug("Client %s created", client); if tls_ctx then client:starttls(tls_ctx); end @@ -677,23 +802,23 @@ local function watchfd(fd, onreadable, onwritable) end; -- Otherwise it'll need to be something LuaSocket-compatible end + conn.id = new_id(); + conn.log = logger.init(("fdwatch%s"):format(conn.id)); conn:add(onreadable, onwritable); return conn; end; -- Dump all data from one connection into another -local function link(from, to) - from.listeners = setmetatable({ - onincoming = function (_, data) - from:pause(); - to:write(data); - end, - }, {__index=from.listeners}); - to.listeners = setmetatable({ - ondrain = function () - from:resume(); - end, - }, {__index=to.listeners}); +local function link(from, to, read_size) + from:debug("Linking to %s", to.id); + function from:onincoming(data) + self:pause(); + to:write(data); + end + function to:ondrain() -- luacheck: ignore 212/self + from:resume(); + end + from:set_mode(read_size); from:set(true, nil); to:set(nil, true); end @@ -752,6 +877,7 @@ return { addserver = addserver; addclient = addclient; add_task = addtimer; + listen = listen; at = at; loop = loop; closeall = closeall; @@ -766,6 +892,7 @@ return { -- libevent emulation event = { EV_READ = "r", EV_WRITE = "w", EV_READWRITE = "rw", EV_LEAVE = -1 }; addevent = function (fd, mode, callback) + log("warn", "Using deprecated libevent emulation, please update code to use watchfd API instead"); local function onevent(self) local ret = self:callback(); if ret == -1 then @@ -785,6 +912,8 @@ return { fds[fd] = nil; end; }, interface_mt); + conn.id = conn:getfd(); + conn.log = logger.init(("fdwatch%d"):format(conn.id)); local ok, err = conn:add(mode == "r" or mode == "rw", mode == "w" or mode == "rw"); if not ok then return ok, err; end return conn; |