aboutsummaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
Diffstat (limited to 'net')
-rw-r--r--net/connect.lua78
-rw-r--r--net/connlisteners.lua2
-rw-r--r--net/cqueues.lua74
-rw-r--r--net/dns.lua13
-rw-r--r--net/http.lua21
-rw-r--r--net/httpserver.lua2
-rw-r--r--net/resolvers/basic.lua60
-rw-r--r--net/resolvers/manual.lua25
-rw-r--r--net/server.lua114
-rw-r--r--net/server_epoll.lua709
-rw-r--r--net/server_event.lua28
-rw-r--r--net/server_select.lua103
-rw-r--r--net/websocket/frames.lua4
13 files changed, 1132 insertions, 101 deletions
diff --git a/net/connect.lua b/net/connect.lua
new file mode 100644
index 00000000..675116e2
--- /dev/null
+++ b/net/connect.lua
@@ -0,0 +1,78 @@
+local server = require "net.server";
+local log = require "util.logger".init("net.connect");
+local new_id = require "util.id".short;
+
+local pending_connection_methods = {};
+local pending_connection_mt = {
+ __name = "pending_connection";
+ __index = pending_connection_methods;
+ __tostring = function (p)
+ return "<pending connection "..p.id.." to "..tostring(p.target_resolver.hostname)..">";
+ end;
+};
+
+function pending_connection_methods:log(level, message, ...)
+ log(level, "[pending connection %s] "..message, self.id, ...);
+end
+
+-- pending_connections_map[conn] = pending_connection
+local pending_connections_map = {};
+
+local pending_connection_listeners = {};
+
+local function attempt_connection(p)
+ p:log("debug", "Checking for targets...");
+ if p.conn then
+ pending_connections_map[p.conn] = nil;
+ p.conn = nil;
+ end
+ p.target_resolver:next(function (conn_type, ip, port, extra)
+ p:log("debug", "Next target to try is %s:%d", ip, port);
+ local conn = assert(server.addclient(ip, port, pending_connection_listeners, p.options.pattern or "*a", p.options.sslctx, conn_type, extra));
+ p.conn = conn;
+ pending_connections_map[conn] = p;
+ end);
+end
+
+function pending_connection_listeners.onconnect(conn)
+ local p = pending_connections_map[conn];
+ if not p then
+ log("warn", "Successful connection, but unexpected! Closing.");
+ conn:close();
+ return;
+ end
+ pending_connections_map[conn] = nil;
+ p:log("debug", "Successfully connected");
+ if p.listeners.onattach then
+ p.listeners.onattach(conn, p.data);
+ end
+ conn:setlistener(p.listeners);
+ return p.listeners.onconnect(conn);
+end
+
+function pending_connection_listeners.ondisconnect(conn, reason)
+ local p = pending_connections_map[conn];
+ if not p then
+ log("warn", "Failed connection, but unexpected!");
+ return;
+ end
+ p:log("debug", "Connection attempt failed");
+ attempt_connection(p);
+end
+
+local function connect(target_resolver, listeners, options, data)
+ local p = setmetatable({
+ id = new_id();
+ target_resolver = target_resolver;
+ listeners = assert(listeners);
+ options = options or {};
+ cb = cb;
+ }, pending_connection_mt);
+
+ p:log("debug", "Starting connection process");
+ attempt_connection(p);
+end
+
+return {
+ connect = connect;
+};
diff --git a/net/connlisteners.lua b/net/connlisteners.lua
index 000bfa63..38cf8f08 100644
--- a/net/connlisteners.lua
+++ b/net/connlisteners.lua
@@ -5,7 +5,7 @@ local traceback = debug.traceback;
local _ENV = nil;
local function fail()
- log("error", "Attempt to use legacy connlisteners API. For more info see http://prosody.im/doc/developers/network");
+ log("error", "Attempt to use legacy connlisteners API. For more info see https://prosody.im/doc/developers/network");
log("error", "Legacy connlisteners API usage, %s", traceback("", 2));
end
diff --git a/net/cqueues.lua b/net/cqueues.lua
new file mode 100644
index 00000000..8c4c756f
--- /dev/null
+++ b/net/cqueues.lua
@@ -0,0 +1,74 @@
+-- Prosody IM
+-- Copyright (C) 2014 Daurnimator
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+-- This module allows you to use cqueues with a net.server mainloop
+--
+
+local server = require "net.server";
+local cqueues = require "cqueues";
+assert(cqueues.VERSION >= 20150113, "cqueues newer than 20150113 required")
+
+-- Create a single top level cqueue
+local cq;
+
+if server.cq then -- server provides cqueues object
+ cq = server.cq;
+elseif server.get_backend() == "select" and server._addtimer then -- server_select
+ cq = cqueues.new();
+ local function step()
+ assert(cq:loop(0));
+ end
+
+ -- Use wrapclient (as wrapconnection isn't exported) to get server_select to watch cq fd
+ local handler = server.wrapclient({
+ getfd = function() return cq:pollfd(); end;
+ settimeout = function() end; -- Method just needs to exist
+ close = function() end; -- Need close method for 'closeall'
+ }, nil, nil, {});
+
+ -- Only need to listen for readable; cqueues handles everything under the hood
+ -- readbuffer is called when `select` notes an fd as readable
+ handler.readbuffer = step;
+
+ -- Use server_select low lever timer facility,
+ -- this callback gets called *every* time there is a timeout in the main loop
+ server._addtimer(function(current_time)
+ -- This may end up in extra step()'s, but cqueues handles it for us.
+ step();
+ return cq:timeout();
+ end);
+elseif server.event and server.base then -- server_event
+ cq = cqueues.new();
+ -- Only need to listen for readable; cqueues handles everything under the hood
+ local EV_READ = server.event.EV_READ;
+ -- Convert a cqueues timeout to an acceptable timeout for luaevent
+ local function luaevent_safe_timeout(cq)
+ local t = cq:timeout();
+ -- if you give luaevent 0 or nil, it re-uses the previous timeout.
+ if t == 0 then
+ t = 0.000001; -- 1 microsecond is the smallest that works (goes into a `struct timeval`)
+ elseif t == nil then -- pick something big if we don't have one
+ t = 0x7FFFFFFF; -- largest 32bit int
+ end
+ return t
+ end
+ local event_handle;
+ event_handle = server.base:addevent(cq:pollfd(), EV_READ, function(e)
+ -- Need to reference event_handle or this callback will get collected
+ -- This creates a circular reference that can only be broken if event_handle is manually :close()'d
+ local _ = event_handle;
+ -- Run as many cqueues things as possible (with a timeout of 0)
+ -- If an error is thrown, it will break the libevent loop; but prosody resumes after logging a top level error
+ assert(cq:loop(0));
+ return EV_READ, luaevent_safe_timeout(cq);
+ end, luaevent_safe_timeout(cq));
+else
+ error "NYI"
+end
+
+return {
+ cq = cq;
+}
diff --git a/net/dns.lua b/net/dns.lua
index eba2b5a0..e6749025 100644
--- a/net/dns.lua
+++ b/net/dns.lua
@@ -15,6 +15,7 @@
local socket = require "socket";
local timer = require "util.timer";
local new_ip = require "util.ip".new_ip;
+local have_util_net, util_net = pcall(require, "util.net");
local _, windows = pcall(require, "util.windows");
local is_windows = (_ and windows) or os.getenv("WINDIR");
@@ -382,6 +383,12 @@ function resolver:A(rr) -- - - - - - - - - - - - - - - - - - - - - - - - A
rr.a = string.format('%i.%i.%i.%i', b1, b2, b3, b4);
end
+if have_util_net and util_net.ntop then
+ function resolver:A(rr)
+ rr.a = util_net.ntop(self:sub(4));
+ end
+end
+
function resolver:AAAA(rr)
local addr = {};
for _ = 1, rr.rdlength, 2 do
@@ -402,6 +409,12 @@ function resolver:AAAA(rr)
rr.aaaa = addr:gsub(zeros[1], "::", 1):gsub("^0::", "::"):gsub("::0$", "::");
end
+if have_util_net and util_net.ntop then
+ function resolver:AAAA(rr)
+ rr.aaaa = util_net.ntop(self:sub(16));
+ end
+end
+
function resolver:CNAME(rr) -- - - - - - - - - - - - - - - - - - - - CNAME
rr.cname = self:name();
end
diff --git a/net/http.lua b/net/http.lua
index 8364a104..1e2854a4 100644
--- a/net/http.lua
+++ b/net/http.lua
@@ -100,9 +100,10 @@ function listener.ondetach(conn)
end
local function destroy_request(request)
- if request.conn then
+ local conn = request.conn;
+ if conn then
request.conn = nil;
- request.handler:close()
+ conn:close()
end
end
@@ -221,14 +222,14 @@ local function request(self, u, ex, callback)
sslctx = ex and ex.sslctx or self.options and self.options.sslctx;
end
- local handler, conn = server.addclient(host, port_number, listener, "*a", sslctx)
- if not handler then
- self.events.fire_event("request-connection-error", { http = self, request = req, url = u, err = conn });
- callback(conn, 0, req);
- return nil, conn;
+ local conn, ret = server.addclient(host, port_number, listener, "*a", sslctx)
+ if not conn then
+ self.events.fire_event("request-connection-error", { http = self, request = req, url = u, err = ret });
+ callback(ret, 0, req);
+ return nil, ret;
end
- req.handler, req.conn = handler, conn
- req.write = function (...) return req.handler:write(...); end
+ req.conn = conn
+ req.write = function (...) return req.conn:write(...); end
req.callback = function (content, code, response, request)
do
@@ -243,7 +244,7 @@ local function request(self, u, ex, callback)
req.reader = request_reader;
req.state = "status";
- requests[req.handler] = req;
+ requests[req.conn] = req;
self.events.fire_event("request", { http = self, request = req, url = u });
return req;
diff --git a/net/httpserver.lua b/net/httpserver.lua
index 6e2e31b9..56561306 100644
--- a/net/httpserver.lua
+++ b/net/httpserver.lua
@@ -5,7 +5,7 @@ local traceback = debug.traceback;
local _ENV = nil;
function fail()
- log("error", "Attempt to use legacy HTTP API. For more info see http://prosody.im/doc/developers/legacy_http");
+ log("error", "Attempt to use legacy HTTP API. For more info see https://prosody.im/doc/developers/legacy_http");
log("error", "Legacy HTTP API usage, %s", traceback("", 2));
end
diff --git a/net/resolvers/basic.lua b/net/resolvers/basic.lua
new file mode 100644
index 00000000..792e6d32
--- /dev/null
+++ b/net/resolvers/basic.lua
@@ -0,0 +1,60 @@
+local adns = require "net.adns";
+
+local methods = {};
+local resolver_mt = { __index = methods };
+
+-- Find the next target to connect to, and
+-- pass it to cb()
+function methods:next(cb)
+ if self.targets then
+ if #self.targets == 0 then
+ cb(nil);
+ return;
+ end
+ local next_target = table.remove(self.targets, 1);
+ cb(unpack(next_target, 1, 4));
+ return;
+ end
+
+ local targets = {};
+ local n = 2;
+ local function ready()
+ n = n - 1;
+ if n > 0 then return; end
+ self.targets = targets;
+ self:next(cb);
+ end
+
+ -- Resolve DNS to target list
+ local dns_resolver = adns.resolver();
+ dns_resolver:lookup(function (answer)
+ if answer then
+ for _, record in ipairs(answer) do
+ table.insert(targets, { self.conn_type, record.a, self.port, self.extra });
+ end
+ end
+ ready();
+ end, self.hostname, "A", "IN");
+
+ dns_resolver:lookup(function (answer)
+ if answer then
+ for _, record in ipairs(answer) do
+ table.insert(targets, { self.conn_type.."6", record.aaaa, self.port, self.extra });
+ end
+ end
+ ready();
+ end, self.hostname, "AAAA", "IN");
+end
+
+local function new(hostname, port, conn_type, extra)
+ return setmetatable({
+ hostname = hostname;
+ port = port;
+ conn_type = conn_type or "tcp";
+ extra = extra;
+ }, resolver_mt);
+end
+
+return {
+ new = new;
+};
diff --git a/net/resolvers/manual.lua b/net/resolvers/manual.lua
new file mode 100644
index 00000000..c0d4e5d5
--- /dev/null
+++ b/net/resolvers/manual.lua
@@ -0,0 +1,25 @@
+local methods = {};
+local resolver_mt = { __index = methods };
+
+-- Find the next target to connect to, and
+-- pass it to cb()
+function methods:next(cb)
+ if #self.targets == 0 then
+ cb(nil);
+ return;
+ end
+ local next_target = table.remove(self.targets, 1);
+ cb(unpack(next_target, 1, 4));
+end
+
+local function new(targets, conn_type, extra)
+ return setmetatable({
+ conn_type = conn_type;
+ extra = extra;
+ targets = targets or {};
+ }, resolver_mt);
+end
+
+return {
+ new = new;
+};
diff --git a/net/server.lua b/net/server.lua
index 41e180fa..8b6fbc0b 100644
--- a/net/server.lua
+++ b/net/server.lua
@@ -6,25 +6,75 @@
-- COPYING file in the source package for more information.
--
-local use_luaevent = prosody and require "core.configmanager".get("*", "use_libevent");
+local server_type = prosody and require "core.configmanager".get("*", "network_backend") or "select";
+if prosody and require "core.configmanager".get("*", "use_libevent") then
+ server_type = "event";
+end
-if use_luaevent then
- use_luaevent = pcall(require, "luaevent.core");
- if not use_luaevent then
+if server_type == "event" then
+ if not pcall(require, "luaevent.core") then
log("error", "libevent not found, falling back to select()");
+ server_type = "select"
end
end
local server;
-
-if use_luaevent then
+local set_config;
+if server_type == "event" then
server = require "net.server_event";
- -- Overwrite signal.signal() because we need to ask libevent to
- -- handle them instead
- local ok, signal = pcall(require, "util.signal");
- if ok and signal then
- local _signal_signal = signal.signal;
+ local defaults = {};
+ for k,v in pairs(server.cfg) do
+ defaults[k] = v;
+ end
+ function set_config(settings)
+ local event_settings = {
+ ACCEPT_DELAY = settings.accept_retry_interval;
+ ACCEPT_QUEUE = settings.tcp_backlog;
+ CLEAR_DELAY = settings.event_clear_interval;
+ CONNECT_TIMEOUT = settings.connect_timeout;
+ DEBUG = settings.debug;
+ HANDSHAKE_TIMEOUT = settings.ssl_handshake_timeout;
+ MAX_CONNECTIONS = settings.max_connections;
+ MAX_HANDSHAKE_ATTEMPTS = settings.max_ssl_handshake_roundtrips;
+ MAX_READ_LENGTH = settings.max_receive_buffer_size;
+ MAX_SEND_LENGTH = settings.max_send_buffer_size;
+ READ_TIMEOUT = settings.read_timeout;
+ WRITE_TIMEOUT = settings.send_timeout;
+ };
+
+ for k,default in pairs(defaults) do
+ server.cfg[k] = event_settings[k] or default;
+ end
+ end
+elseif server_type == "select" then
+ server = require "net.server_select";
+
+ local defaults = {};
+ for k,v in pairs(server.getsettings()) do
+ defaults[k] = v;
+ end
+ function set_config(settings)
+ local select_settings = {};
+ for k,default in pairs(defaults) do
+ select_settings[k] = settings[k] or default;
+ end
+ server.changesettings(select_settings);
+ end
+else
+ server = require("net.server_"..server_type);
+ set_config = server.set_config;
+ if not server.get_backend then
+ function server.get_backend()
+ return server_type;
+ end
+ end
+end
+
+-- If server.hook_signal exists, replace signal.signal()
+local has_signal, signal = pcall(require, "util.signal");
+if has_signal then
+ if server.hook_signal then
function signal.signal(signal_id, handler)
if type(signal_id) == "string" then
signal_id = signal[signal_id:upper()];
@@ -34,46 +84,22 @@ if use_luaevent then
end
return server.hook_signal(signal_id, handler);
end
+ else
+ server.hook_signal = signal.signal;
end
else
- use_luaevent = false;
- server = require "net.server_select";
+ if not server.hook_signal then
+ server.hook_signal = function()
+ return false, "signal hooking not supported"
+ end
+ end
end
-if prosody then
+if prosody and set_config then
local config_get = require "core.configmanager".get;
- local defaults = {};
- for k,v in pairs(server.cfg or server.getsettings()) do
- defaults[k] = v;
- end
local function load_config()
local settings = config_get("*", "network_settings") or {};
- if use_luaevent then
- local event_settings = {
- ACCEPT_DELAY = settings.accept_retry_interval;
- ACCEPT_QUEUE = settings.tcp_backlog;
- CLEAR_DELAY = settings.event_clear_interval;
- CONNECT_TIMEOUT = settings.connect_timeout;
- DEBUG = settings.debug;
- HANDSHAKE_TIMEOUT = settings.ssl_handshake_timeout;
- MAX_CONNECTIONS = settings.max_connections;
- MAX_HANDSHAKE_ATTEMPTS = settings.max_ssl_handshake_roundtrips;
- MAX_READ_LENGTH = settings.max_receive_buffer_size;
- MAX_SEND_LENGTH = settings.max_send_buffer_size;
- READ_TIMEOUT = settings.read_timeout;
- WRITE_TIMEOUT = settings.send_timeout;
- };
-
- for k,default in pairs(defaults) do
- server.cfg[k] = event_settings[k] or default;
- end
- else
- local select_settings = {};
- for k,default in pairs(defaults) do
- select_settings[k] = settings[k] or default;
- end
- server.changesettings(select_settings);
- end
+ return set_config(settings);
end
load_config();
prosody.events.add_handler("config-reloaded", load_config);
diff --git a/net/server_epoll.lua b/net/server_epoll.lua
new file mode 100644
index 00000000..6da05f0e
--- /dev/null
+++ b/net/server_epoll.lua
@@ -0,0 +1,709 @@
+-- Prosody IM
+-- Copyright (C) 2016 Kim Alvefur
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+-- server_epoll
+-- Server backend based on https://luarocks.org/modules/zash/lua-epoll
+
+local t_sort = table.sort;
+local t_insert = table.insert;
+local t_remove = table.remove;
+local t_concat = table.concat;
+local setmetatable = setmetatable;
+local tostring = tostring;
+local pcall = pcall;
+local next = next;
+local pairs = pairs;
+local log = require "util.logger".init("server_epoll");
+local epoll = require "epoll";
+local socket = require "socket";
+local luasec = require "ssl";
+local gettime = require "util.time".now;
+local createtable = require "util.table".create;
+local _SOCKETINVALID = socket._SOCKETINVALID or -1;
+
+assert(socket.tcp6 and socket.tcp4, "Incompatible LuaSocket version");
+
+local _ENV = nil;
+
+local default_config = { __index = {
+ read_timeout = 900;
+ write_timeout = 7;
+ tcp_backlog = 128;
+ accept_retry_interval = 10;
+ read_retry_delay = 1e-06;
+ connect_timeout = 20;
+ handshake_timeout = 60;
+ max_wait = 86400;
+ min_wait = 1e-06;
+}};
+local cfg = default_config.__index;
+
+local fds = createtable(10, 0); -- FD -> conn
+
+-- Timer and scheduling --
+
+local timers = {};
+
+local function noop() end
+local function closetimer(t)
+ t[1] = 0;
+ t[2] = noop;
+end
+
+-- Set to true when timers have changed
+local resort_timers = false;
+
+-- Add absolute timer
+local function at(time, f)
+ local timer = { time, f, close = closetimer };
+ t_insert(timers, timer);
+ resort_timers = true;
+ return timer;
+end
+
+-- Add relative timer
+local function addtimer(timeout, f)
+ return at(gettime() + timeout, f);
+end
+
+-- Run callbacks of expired timers
+-- Return time until next timeout
+local function runtimers(next_delay, min_wait)
+ -- Any timers at all?
+ if not timers[1] then
+ return next_delay;
+ end
+
+ if resort_timers then
+ -- Sort earliest timers to the end
+ t_sort(timers, function (a, b) return a[1] > b[1]; end);
+ resort_timers = false;
+ end
+
+ -- Iterate from the end and remove completed timers
+ for i = #timers, 1, -1 do
+ local timer = timers[i];
+ local t, f = timer[1], timer[2];
+ -- Get time for every iteration to increase accuracy
+ local now = gettime();
+ if t > now then
+ -- This timer should not fire yet
+ local diff = t - now;
+ if diff < next_delay then
+ next_delay = diff;
+ end
+ break;
+ end
+ local new_timeout = f(now);
+ if new_timeout then
+ -- Schedule for 'delay' from the time actually scheduled,
+ -- not from now, in order to prevent timer drift.
+ timer[1] = t + new_timeout;
+ resort_timers = true;
+ else
+ t_remove(timers, i);
+ end
+ end
+
+ if resort_timers or next_delay < min_wait then
+ -- Timers may be added from within a timer callback.
+ -- Those would not be considered for next_delay,
+ -- and we might sleep for too long, so instead
+ -- we return a shorter timeout so we can
+ -- properly sort all new timers.
+ next_delay = min_wait;
+ end
+
+ return next_delay;
+end
+
+-- Socket handler interface
+
+local interface = {};
+local interface_mt = { __index = interface };
+
+function interface_mt:__tostring()
+ if self.sockname and self.peername then
+ return ("FD %d (%s, %d, %s, %d)"):format(self:getfd(), self.peername, self.peerport, self.sockname, self.sockport);
+ elseif self.sockname or self.peername then
+ return ("FD %d (%s, %d)"):format(self:getfd(), self.sockname or self.peername, self.sockport or self.peerport);
+ end
+ return ("%s FD %d"):format(tostring(self.conn), self:getfd());
+end
+
+-- Replace the listener and tell the old one
+function interface:setlistener(listeners)
+ self:on("detach");
+ self.listeners = listeners;
+end
+
+-- Call a listener callback
+function interface:on(what, ...)
+ if not self.listeners then
+ log("error", "%s has no listeners", self);
+ return;
+ end
+ local listener = self.listeners["on"..what];
+ if not listener then
+ -- log("debug", "Missing listener 'on%s'", what); -- uncomment for development and debugging
+ return;
+ end
+ local ok, err = pcall(listener, self, ...);
+ if not ok then
+ log("error", "Error calling on%s: %s", what, err);
+ end
+ return err;
+end
+
+-- Return the file descriptor number
+function interface:getfd()
+ if self.conn then
+ return self.conn:getfd();
+ end
+ return _SOCKETINVALID;
+end
+
+function interface:server()
+ return self._server or self;
+end
+
+-- Get IP address
+function interface:ip()
+ return self.peername or self.sockname;
+end
+
+-- Get a port number, doesn't matter which
+function interface:port()
+ return self.sockport or self.peerport;
+end
+
+-- Get local port number
+function interface:clientport()
+ return self.sockport;
+end
+
+-- Get remote port
+function interface:serverport()
+ if self.sockport then
+ return self.sockport;
+ elseif self._server then
+ self._server:port();
+ end
+end
+
+-- Return underlying socket
+function interface:socket()
+ return self.conn;
+end
+
+function interface:set_mode(new_mode)
+ self._pattern = new_mode;
+end
+
+function interface:setoption(k, v)
+ -- LuaSec doesn't expose setoption :(
+ if self.conn.setoption then
+ self.conn:setoption(k, v);
+ end
+end
+
+-- Timeout for detecting dead or idle sockets
+function interface:setreadtimeout(t)
+ if t == false then
+ if self._readtimeout then
+ self._readtimeout:close();
+ self._readtimeout = nil;
+ end
+ return
+ end
+ t = t or cfg.read_timeout;
+ if self._readtimeout then
+ self._readtimeout[1] = gettime() + t;
+ resort_timers = true;
+ else
+ self._readtimeout = addtimer(t, function ()
+ if self:on("readtimeout") then
+ return cfg.read_timeout;
+ else
+ self:on("disconnect", "read timeout");
+ self:destroy();
+ end
+ end);
+ end
+end
+
+-- Timeout for detecting dead sockets
+function interface:setwritetimeout(t)
+ if t == false then
+ if self._writetimeout then
+ self._writetimeout:close();
+ self._writetimeout = nil;
+ end
+ return
+ end
+ t = t or cfg.write_timeout;
+ if self._writetimeout then
+ self._writetimeout[1] = gettime() + t;
+ resort_timers = true;
+ else
+ self._writetimeout = addtimer(t, function ()
+ self:on("disconnect", "write timeout");
+ self:destroy();
+ end);
+ end
+end
+
+-- lua-epoll flag for currently requested poll state
+function interface:flags()
+ if self._wantread then
+ if self._wantwrite then
+ return "rw";
+ end
+ return "r";
+ elseif self._wantwrite then
+ return "w";
+ end
+end
+
+-- Add or remove sockets or modify epoll flags
+function interface:setflags(r, w)
+ if r ~= nil then self._wantread = r; end
+ if w ~= nil then self._wantwrite = w; end
+ local flags = self:flags();
+ local currentflags = self._flags;
+ if flags == currentflags then
+ return true;
+ end
+ local fd = self:getfd();
+ if fd < 0 then
+ self._wantread, self._wantwrite = nil, nil;
+ return nil, "invalid fd";
+ end
+ local op = "mod";
+ if not flags then
+ op = "del";
+ elseif not currentflags then
+ op = "add";
+ end
+ local ok, err = epoll.ctl(op, fd, flags);
+-- log("debug", "epoll_ctl(%q, %d, %q) -> %s" .. (err and ", %q" or ""),
+-- op, fd, flags or "", tostring(ok), err);
+ if not ok then return ok, err end
+ if op == "add" then
+ fds[fd] = self;
+ elseif op == "del" then
+ fds[fd] = nil;
+ end
+ self._flags = flags;
+ return true;
+end
+
+-- Called when socket is readable
+function interface:onreadable()
+ local data, err, partial = self.conn:receive(self._pattern);
+ if data then
+ self:on("incoming", data);
+ else
+ if partial then
+ self:on("incoming", partial, err);
+ end
+ if err == "wantread" then
+ self:setflags(true, nil);
+ elseif err == "wantwrite" then
+ self:setflags(nil, true);
+ elseif err ~= "timeout" then
+ self:on("disconnect", err);
+ self:destroy()
+ return;
+ end
+ end
+ if not self.conn then return; end
+ if self.conn:dirty() then
+ self:setreadtimeout(false);
+ self:pausefor(cfg.read_retry_delay);
+ else
+ self:setreadtimeout();
+ end
+end
+
+-- Called when socket is writable
+function interface:onwriteable()
+ local buffer = self.writebuffer;
+ local data = t_concat(buffer);
+ local ok, err, partial = self.conn:send(data);
+ if ok then
+ for i = #buffer, 1, -1 do
+ buffer[i] = nil;
+ end
+ self:setflags(nil, false);
+ self:setwritetimeout(false);
+ self:ondrain(); -- Be aware of writes in ondrain
+ return;
+ end
+ if partial then
+ buffer[1] = data:sub(partial+1);
+ for i = #buffer, 2, -1 do
+ buffer[i] = nil;
+ end
+ self:setwritetimeout();
+ end
+ if err == "wantwrite" or err == "timeout" then
+ self:setflags(nil, true);
+ elseif err == "wantread" then
+ self:setflags(true, nil);
+ elseif err ~= "timeout" then
+ self:on("disconnect", err);
+ self:destroy();
+ end
+end
+
+-- The write buffer has been successfully emptied
+function interface:ondrain()
+ return self:on("drain");
+end
+
+-- Add data to write buffer and set flag for wanting to write
+function interface:write(data)
+ local buffer = self.writebuffer;
+ if buffer then
+ t_insert(buffer, data);
+ else
+ self.writebuffer = { data };
+ end
+ self:setwritetimeout();
+ self:setflags(nil, true);
+ return #data;
+end
+interface.send = interface.write;
+
+-- Close, possibly after writing is done
+function interface:close()
+ if self.writebuffer and self.writebuffer[1] then
+ self:setflags(false, true); -- Flush final buffer contents
+ self.write, self.send = noop, noop; -- No more writing
+ log("debug", "Close %s after writing", tostring(self));
+ self.ondrain = interface.close;
+ else
+ log("debug", "Close %s now", tostring(self));
+ self.write, self.send = noop, noop;
+ self.close = noop;
+ self:on("disconnect");
+ self:destroy();
+ end
+end
+
+function interface:destroy()
+ self:setflags(false, false);
+ self:setwritetimeout(false);
+ self:setreadtimeout(false);
+ self.onreadable = noop;
+ self.onwriteable = noop;
+ self.destroy = noop;
+ self.close = noop;
+ self.on = noop;
+ self.conn:close();
+ self.conn = nil;
+end
+
+function interface:ssl()
+ return self._tls;
+end
+
+function interface:starttls(ctx)
+ if ctx then self.tls = ctx; end
+ if self.writebuffer and self.writebuffer[1] then
+ log("debug", "Start TLS on %s after write", tostring(self));
+ self.ondrain = interface.starttls;
+ self.starttls = false;
+ self:setflags(nil, true); -- make sure wantwrite is set
+ else
+ log("debug", "Start TLS on %s now", tostring(self));
+ self:setflags(false, false);
+ local conn, err = luasec.wrap(self.conn, ctx or self.tls);
+ if not conn then
+ self:on("disconnect", err);
+ self:destroy();
+ return conn, err;
+ end
+ conn:settimeout(0);
+ self.conn = conn;
+ self.ondrain = nil;
+ self.onwriteable = interface.tlshandskake;
+ self.onreadable = interface.tlshandskake;
+ self:setflags(true, true);
+ self:setwritetimeout(cfg.handshake_timeout);
+ end
+end
+
+function interface:tlshandskake()
+ self:setwritetimeout(false);
+ self:setreadtimeout(false);
+ local ok, err = self.conn:dohandshake();
+ if ok then
+ log("debug", "TLS handshake on %s complete", tostring(self));
+ self.onwriteable = nil;
+ self.onreadable = nil;
+ self._tls = true;
+ self:on("status", "ssl-handshake-complete");
+ self:init();
+ elseif err == "wantread" then
+ log("debug", "TLS handshake on %s to wait until readable", tostring(self));
+ self:setflags(true, false);
+ self:setreadtimeout(cfg.handshake_timeout);
+ elseif err == "wantwrite" then
+ log("debug", "TLS handshake on %s to wait until writable", tostring(self));
+ self:setflags(false, true);
+ self:setwritetimeout(cfg.handshake_timeout);
+ else
+ log("debug", "TLS handshake error on %s: %s", tostring(self), err);
+ self:on("disconnect", err);
+ self:destroy();
+ end
+end
+
+local function wrapsocket(client, server, pattern, listeners, tls) -- luasocket object -> interface object
+ client:settimeout(0);
+ local conn = setmetatable({
+ conn = client;
+ _server = server;
+ created = gettime();
+ listeners = listeners;
+ _pattern = pattern or (server and server._pattern);
+ writebuffer = {};
+ tls = tls;
+ }, interface_mt);
+
+ if client.getpeername then
+ conn.peername, conn.peerport = client:getpeername();
+ end
+ if client.getsockname then
+ conn.sockname, conn.sockport = client:getsockname();
+ end
+ return conn;
+end
+
+-- A server interface has new incoming connections waiting
+-- This replaces the onreadable callback
+function interface:onacceptable()
+ local conn, err = self.conn:accept();
+ if not conn then
+ log("debug", "Error accepting new client: %s, server will be paused for %ds", err, cfg.accept_retry_interval);
+ self:pausefor(cfg.accept_retry_interval);
+ return;
+ end
+ local client = wrapsocket(conn, self, nil, self.listeners, self.tls);
+ log("debug", "New connection %s", tostring(client));
+ client:init();
+end
+
+-- Initialization
+function interface:init()
+ if self.tls and not self._tls then
+ return self:starttls();
+ else
+ self.onwriteable = interface.onfirstwritable;
+ self.onreadable = interface.onfirstreadable;
+ self:setwritetimeout();
+ return self:setflags(true, true);
+ end
+end
+
+function interface:pause()
+ return self:setflags(false);
+end
+
+function interface:resume()
+ return self:setflags(true);
+end
+
+-- Pause connection for some time
+function interface:pausefor(t)
+ if self._pausefor then
+ self._pausefor:close();
+ end
+ if t == false then return; end
+ self:setflags(false);
+ self._pausefor = addtimer(t, function ()
+ self._pausefor = nil;
+ if self.conn and self.conn:dirty() then
+ self:onreadable();
+ end
+ self:setflags(true);
+ end);
+end
+
+-- Connected!
+function interface:onconnect()
+ self:setflags(true, false);
+ if not self._connected then
+ self._connected = true;
+ self:on("connect");
+ end
+end
+
+function interface:onfirstwritable()
+ self.onreadable = nil;
+ self.onwriteable = nil;
+ self:onconnect();
+ return self:onwriteable();
+end
+
+function interface:onfirstreadable()
+ self.onreadable = nil;
+ self.onwriteable = nil;
+ self:onconnect();
+ return self:onreadable();
+end
+
+local function addserver(addr, port, listeners, pattern, tls)
+ local conn, err = socket.bind(addr, port, cfg.tcp_backlog);
+ if not conn then return conn, err; end
+ conn:settimeout(0);
+ local server = setmetatable({
+ conn = conn;
+ created = gettime();
+ listeners = listeners;
+ _pattern = pattern;
+ onreadable = interface.onacceptable;
+ tls = tls;
+ sockname = addr;
+ sockport = port;
+ }, interface_mt);
+ server:setflags(true, false);
+ return server;
+end
+
+-- COMPAT
+local function wrapclient(conn, addr, port, listeners, pattern, tls)
+ local client = wrapsocket(conn, nil, pattern, listeners, tls);
+ if not client.peername then
+ client.peername, client.peerport = addr, port;
+ end
+ client:init();
+ return client;
+end
+
+-- New outgoing TCP connection
+local function addclient(addr, port, listeners, pattern, tls)
+ local conn, err = socket.tcp();
+ if not conn then return conn, err; end
+ conn:settimeout(0);
+ conn:connect(addr, port);
+ local client = wrapsocket(conn, nil, pattern, listeners, tls)
+ client:init();
+ return client, conn;
+end
+
+-- Dump all data from one connection into another
+local function link(from, to)
+ from.listeners = setmetatable({
+ onincoming = function (_, data)
+ from:pause();
+ to:write(data);
+ end,
+ }, {__index=from.listeners});
+ to.listeners = setmetatable({
+ ondrain = function ()
+ from:resume();
+ end,
+ }, {__index=to.listeners});
+ from:setflags(true, nil);
+ to:setflags(nil, true);
+end
+
+-- XXX What uses this?
+-- net.adns
+function interface:set_send(new_send)
+ self.send = new_send;
+end
+
+-- Close all connections and servers
+local function closeall()
+ for fd, conn in pairs(fds) do -- luacheck: ignore 213/fd
+ conn:close();
+ end
+end
+
+local quitting = nil;
+
+-- Signal main loop about shutdown via above upvalue
+local function setquitting(quit)
+ if quit then
+ quitting = "quitting";
+ closeall();
+ else
+ quitting = nil;
+ end
+end
+
+-- Main loop
+local function loop(once)
+ repeat
+ local t = runtimers(cfg.max_wait, cfg.min_wait);
+ local fd, r, w = epoll.wait(t);
+ if fd then
+ local conn = fds[fd];
+ if conn then
+ if r then
+ conn:onreadable();
+ end
+ if w then
+ conn:onwriteable();
+ end
+ else
+ log("debug", "Removing unknown fd %d", fd);
+ epoll.ctl("del", fd);
+ end
+ elseif r ~= "timeout" then
+ log("debug", "epoll_wait error: %s", tostring(r));
+ end
+ until once or (quitting and next(fds) == nil);
+ return quitting;
+end
+
+return {
+ get_backend = function () return "epoll"; end;
+ addserver = addserver;
+ addclient = addclient;
+ add_task = addtimer;
+ at = at;
+ loop = loop;
+ closeall = closeall;
+ setquitting = setquitting;
+ wrapclient = wrapclient;
+ link = link;
+ set_config = function (newconfig)
+ cfg = setmetatable(newconfig, default_config);
+ end;
+
+ -- libevent emulation
+ event = { EV_READ = "r", EV_WRITE = "w", EV_READWRITE = "rw", EV_LEAVE = -1 };
+ addevent = function (fd, mode, callback)
+ local function onevent(self)
+ local ret = self:callback();
+ if ret == -1 then
+ self:setflags(false, false);
+ elseif ret then
+ self:setflags(mode == "r" or mode == "rw", mode == "w" or mode == "rw");
+ end
+ end
+
+ local conn = setmetatable({
+ getfd = function () return fd; end;
+ callback = callback;
+ onreadable = onevent;
+ onwriteable = onevent;
+ close = function (self)
+ self:setflags(false, false);
+ fds[fd] = nil;
+ end;
+ }, interface_mt);
+ local ok, err = conn:setflags(mode == "r" or mode == "rw", mode == "w" or mode == "rw");
+ if not ok then return ok, err; end
+ return conn;
+ end;
+};
diff --git a/net/server_event.lua b/net/server_event.lua
index 3a907349..42b757d4 100644
--- a/net/server_event.lua
+++ b/net/server_event.lua
@@ -106,6 +106,12 @@ function interface_mt:_start_connection(plainssl) -- called from wrapclient
self:_close()
debug( "new connection failed. id:", self.id, "error:", self.fatalerror )
else
+ if EV_READWRITE == event then
+ if self.readcallback(event) == -1 then
+ -- Fatal error occurred
+ return -1;
+ end
+ end
if plainssl and has_luasec then -- start ssl session
self:starttls(self._sslctx, true)
else -- normal connection
@@ -116,7 +122,7 @@ function interface_mt:_start_connection(plainssl) -- called from wrapclient
self.eventconnect = nil
return -1
end
- self.eventconnect = addevent( base, self.conn, EV_WRITE, callback, cfg.CONNECT_TIMEOUT )
+ self.eventconnect = addevent( base, self.conn, EV_READWRITE, callback, cfg.CONNECT_TIMEOUT )
return true
end
function interface_mt:_start_session(call_onconnect) -- new session, for example after startssl
@@ -223,7 +229,8 @@ function interface_mt:_destroy() -- close this interface + events and call last
_ = self.eventsession and self.eventsession:close( )
_ = self.eventwritetimeout and self.eventwritetimeout:close( )
_ = self.eventreadtimeout and self.eventreadtimeout:close( )
- _ = self.ondisconnect and self:ondisconnect( self.fatalerror ~= "client to close" and self.fatalerror) -- call ondisconnect listener (wont be the case if handshake failed on connect)
+ -- call ondisconnect listener (wont be the case if handshake failed on connect)
+ _ = self.ondisconnect and self:ondisconnect( self.fatalerror ~= "client to close" and self.fatalerror)
_ = self.conn and self.conn:close( ) -- close connection
_ = self._server and self._server:counter(-1);
self.eventread, self.eventwrite = nil, nil
@@ -773,7 +780,7 @@ local function setquitting(yes)
end
local function get_backend()
- return base:method();
+ return "libevent " .. base:method();
end
-- We need to hold onto the events to stop them
@@ -811,6 +818,20 @@ local function link(sender, receiver, buffersize)
sender:set_mode("*a");
end
+local function add_task(delay, callback)
+ local event_handle;
+ event_handle = base:addevent(nil, 0, function ()
+ local ret = callback(socket_gettime());
+ if ret then
+ return 0, ret;
+ elseif event_handle then
+ return -1;
+ end
+ end
+ , delay);
+ return event_handle;
+end
+
return {
cfg = cfg,
base = base,
@@ -826,6 +847,7 @@ return {
closeall = closeallservers,
get_backend = get_backend,
hook_signal = hook_signal,
+ add_task = add_task,
__NAME = SCRIPT_NAME,
__DATE = LAST_MODIFIED,
diff --git a/net/server_select.lua b/net/server_select.lua
index 12aef9d8..31c6306f 100644
--- a/net/server_select.lua
+++ b/net/server_select.lua
@@ -40,6 +40,7 @@ local coroutine = use "coroutine"
local math_min = math.min
local math_huge = math.huge
local table_concat = table.concat
+local table_insert = table.insert
local string_sub = string.sub
local coroutine_wrap = coroutine.wrap
local coroutine_yield = coroutine.yield
@@ -55,7 +56,6 @@ local getaddrinfo = luasocket.dns.getaddrinfo
local ssl_wrap = ( has_luasec and luasec.wrap )
local socket_bind = luasocket.bind
-local socket_sleep = luasocket.sleep
local socket_select = luasocket.select
--// functions //--
@@ -100,7 +100,6 @@ local _sendtraffic
local _readtraffic
local _selecttimeout
-local _sleeptime
local _tcpbacklog
local _accepretry
@@ -114,8 +113,6 @@ local _checkinterval
local _sendtimeout
local _readtimeout
-local _timer
-
local _maxselectlen
local _maxfd
@@ -141,7 +138,6 @@ _sendtraffic = 0 -- some stats
_readtraffic = 0
_selecttimeout = 1 -- timeout of socket.select
-_sleeptime = 0 -- time to wait at the end of every loop
_tcpbacklog = 128 -- some kind of hint to the OS
_accepretry = 10 -- seconds to wait until the next attempt of a full server to accept
@@ -301,7 +297,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
local bufferqueuelen = 0 -- end of buffer array
local toclose
- local fatalerror
local needtls
local bufferlen = 0
@@ -517,7 +512,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
return dispatch( handler, buffer, err )
else -- connections was closed or fatal error
out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) )
- fatalerror = true
_ = handler and handler:force_close( err )
return false
end
@@ -557,7 +551,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
return true
else -- connection was closed during sending or fatal error
out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) )
- fatalerror = true
_ = handler and handler:force_close( err )
return false
end
@@ -806,7 +799,6 @@ end
getsettings = function( )
return {
select_timeout = _selecttimeout;
- select_sleep_time = _sleeptime;
tcp_backlog = _tcpbacklog;
max_send_buffer_size = _maxsendlen;
max_receive_buffer_size = _maxreadlen;
@@ -825,7 +817,6 @@ changesettings = function( new )
return nil, "invalid settings table"
end
_selecttimeout = tonumber( new.select_timeout ) or _selecttimeout
- _sleeptime = tonumber( new.select_sleep_time ) or _sleeptime
_maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen
_maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen
_checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval
@@ -848,6 +839,49 @@ addtimer = function( listener )
return true
end
+local add_task do
+ local data = {};
+ local new_data = {};
+
+ function add_task(delay, callback)
+ local current_time = luasocket_gettime();
+ delay = delay + current_time;
+ if delay >= current_time then
+ table_insert(new_data, {delay, callback});
+ else
+ local r = callback(current_time);
+ if r and type(r) == "number" then
+ return add_task(r, callback);
+ end
+ end
+ end
+
+ addtimer(function(current_time)
+ if #new_data > 0 then
+ for _, d in pairs(new_data) do
+ table_insert(data, d);
+ end
+ new_data = {};
+ end
+
+ local next_time = math_huge;
+ for i, d in pairs(data) do
+ local t, callback = d[1], d[2];
+ if t <= current_time then
+ data[i] = nil;
+ local r = callback(current_time);
+ if type(r) == "number" then
+ add_task(r, callback);
+ next_time = math_min(next_time, r);
+ end
+ else
+ next_time = math_min(next_time, t - current_time);
+ end
+ end
+ return next_time;
+ end);
+end
+
stats = function( )
return _readtraffic, _sendtraffic, _readlistlen, _sendlistlen, _timerlistlen
end
@@ -861,25 +895,32 @@ end
loop = function(once) -- this is the main loop of the program
if quitting then return "quitting"; end
if once then quitting = "once"; end
- local next_timer_time = math_huge;
+ _currenttime = luasocket_gettime( )
repeat
+ -- Fire timers
+ local next_timer_time = math_huge;
+ for i = 1, _timerlistlen do
+ local t = _timerlist[ i ]( _currenttime ) -- fire timers
+ if t then next_timer_time = math_min(next_timer_time, t); end
+ end
+
local read, write, err = socket_select( _readlist, _sendlist, math_min(_selecttimeout, next_timer_time) )
- for _, socket in ipairs( write ) do -- send data waiting in writequeues
+ for _, socket in ipairs( read ) do -- receive data
local handler = _socketlist[ socket ]
if handler then
- handler.sendbuffer( )
+ handler.readbuffer( )
else
closesocket( socket )
- out_put "server.lua: found no handler and closed socket (writelist)" -- this should not happen
+ out_put "server.lua: found no handler and closed socket (readlist)" -- this can happen
end
end
- for _, socket in ipairs( read ) do -- receive data
+ for _, socket in ipairs( write ) do -- send data waiting in writequeues
local handler = _socketlist[ socket ]
if handler then
- handler.readbuffer( )
+ handler.sendbuffer( )
else
closesocket( socket )
- out_put "server.lua: found no handler and closed socket (readlist)" -- this can happen
+ out_put "server.lua: found no handler and closed socket (writelist)" -- this should not happen
end
end
for handler, err in pairs( _closelist ) do
@@ -910,27 +951,12 @@ loop = function(once) -- this is the main loop of the program
end
end
- -- Fire timers
- if _currenttime - _timer >= math_min(next_timer_time, 1) then
- next_timer_time = math_huge;
- for i = 1, _timerlistlen do
- local t = _timerlist[ i ]( _currenttime ) -- fire timers
- if t then next_timer_time = math_min(next_timer_time, t); end
- end
- _timer = _currenttime
- else
- next_timer_time = next_timer_time - (_currenttime - _timer);
- end
-
for server, paused_time in pairs( _fullservers ) do
if _currenttime - paused_time > _accepretry then
_fullservers[ server ] = nil;
server.resume();
end
end
-
- -- wait some time (0 by default)
- socket_sleep( _sleeptime )
until quitting;
if once and quitting == "once" then quitting = nil; return; end
closeall();
@@ -952,6 +978,7 @@ local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx
if not handler then return nil, err end
_socketlist[ socket ] = handler
if not sslctx then
+ _readlistlen = addsocket(_readlist, socket, _readlistlen)
_sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
if listeners.onconnect then
-- When socket is writeable, call onconnect
@@ -977,16 +1004,14 @@ local addclient = function( address, port, listeners, pattern, sslctx, typ )
elseif sslctx and not has_luasec then
err = "luasec not found"
end
- if not typ then
+ if getaddrinfo and not typ then
local addrinfo, err = getaddrinfo(address)
if not addrinfo then return nil, err end
if addrinfo[1] and addrinfo[1].family == "inet6" then
typ = "tcp6"
- else
- typ = "tcp"
end
end
- local create = luasocket[typ]
+ local create = luasocket[typ or "tcp"]
if type( create ) ~= "function" then
err = "invalid socket type"
end
@@ -1002,22 +1027,19 @@ local addclient = function( address, port, listeners, pattern, sslctx, typ )
end
client:settimeout( 0 )
local ok, err = client:connect( address, port )
- if ok or err == "timeout" then
+ if ok or err == "timeout" or err == "Operation already in progress" then
return wrapclient( client, address, port, listeners, pattern, sslctx )
else
return nil, err
end
end
---// EXPERIMENTAL //--
-
----------------------------------// BEGIN //--
use "setmetatable" ( _socketlist, { __mode = "k" } )
use "setmetatable" ( _readtimes, { __mode = "k" } )
use "setmetatable" ( _writetimes, { __mode = "k" } )
-_timer = luasocket_gettime( )
_starttime = luasocket_gettime( )
local function setlogger(new_logger)
@@ -1032,6 +1054,7 @@ end
return {
_addtimer = addtimer,
+ add_task = add_task;
addclient = addclient,
wrapclient = wrapclient,
diff --git a/net/websocket/frames.lua b/net/websocket/frames.lua
index 5fe96d45..385f64ce 100644
--- a/net/websocket/frames.lua
+++ b/net/websocket/frames.lua
@@ -21,8 +21,8 @@ local t_concat = table.concat;
local s_byte = string.byte;
local s_char= string.char;
local s_sub = string.sub;
-local s_pack = string.pack;
-local s_unpack = string.unpack;
+local s_pack = string.pack; -- luacheck: ignore 143
+local s_unpack = string.unpack; -- luacheck: ignore 143
if not s_pack and softreq"struct" then
s_pack = softreq"struct".pack;