aboutsummaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
Diffstat (limited to 'net')
-rw-r--r--net/cqueues.lua74
-rw-r--r--net/server.lua106
-rw-r--r--net/server_event.lua44
-rw-r--r--net/server_select.lua99
4 files changed, 225 insertions, 98 deletions
diff --git a/net/cqueues.lua b/net/cqueues.lua
new file mode 100644
index 00000000..8c4c756f
--- /dev/null
+++ b/net/cqueues.lua
@@ -0,0 +1,74 @@
+-- Prosody IM
+-- Copyright (C) 2014 Daurnimator
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+-- This module allows you to use cqueues with a net.server mainloop
+--
+
+local server = require "net.server";
+local cqueues = require "cqueues";
+assert(cqueues.VERSION >= 20150113, "cqueues newer than 20150113 required")
+
+-- Create a single top level cqueue
+local cq;
+
+if server.cq then -- server provides cqueues object
+ cq = server.cq;
+elseif server.get_backend() == "select" and server._addtimer then -- server_select
+ cq = cqueues.new();
+ local function step()
+ assert(cq:loop(0));
+ end
+
+ -- Use wrapclient (as wrapconnection isn't exported) to get server_select to watch cq fd
+ local handler = server.wrapclient({
+ getfd = function() return cq:pollfd(); end;
+ settimeout = function() end; -- Method just needs to exist
+ close = function() end; -- Need close method for 'closeall'
+ }, nil, nil, {});
+
+ -- Only need to listen for readable; cqueues handles everything under the hood
+ -- readbuffer is called when `select` notes an fd as readable
+ handler.readbuffer = step;
+
+ -- Use server_select low lever timer facility,
+ -- this callback gets called *every* time there is a timeout in the main loop
+ server._addtimer(function(current_time)
+ -- This may end up in extra step()'s, but cqueues handles it for us.
+ step();
+ return cq:timeout();
+ end);
+elseif server.event and server.base then -- server_event
+ cq = cqueues.new();
+ -- Only need to listen for readable; cqueues handles everything under the hood
+ local EV_READ = server.event.EV_READ;
+ -- Convert a cqueues timeout to an acceptable timeout for luaevent
+ local function luaevent_safe_timeout(cq)
+ local t = cq:timeout();
+ -- if you give luaevent 0 or nil, it re-uses the previous timeout.
+ if t == 0 then
+ t = 0.000001; -- 1 microsecond is the smallest that works (goes into a `struct timeval`)
+ elseif t == nil then -- pick something big if we don't have one
+ t = 0x7FFFFFFF; -- largest 32bit int
+ end
+ return t
+ end
+ local event_handle;
+ event_handle = server.base:addevent(cq:pollfd(), EV_READ, function(e)
+ -- Need to reference event_handle or this callback will get collected
+ -- This creates a circular reference that can only be broken if event_handle is manually :close()'d
+ local _ = event_handle;
+ -- Run as many cqueues things as possible (with a timeout of 0)
+ -- If an error is thrown, it will break the libevent loop; but prosody resumes after logging a top level error
+ assert(cq:loop(0));
+ return EV_READ, luaevent_safe_timeout(cq);
+ end, luaevent_safe_timeout(cq));
+else
+ error "NYI"
+end
+
+return {
+ cq = cq;
+}
diff --git a/net/server.lua b/net/server.lua
index 2a0b89ae..a753a19c 100644
--- a/net/server.lua
+++ b/net/server.lua
@@ -6,25 +6,69 @@
-- COPYING file in the source package for more information.
--
-local use_luaevent = prosody and require "core.configmanager".get("*", "use_libevent");
+local server_type = prosody and require "core.configmanager".get("*", "network_backend") or "select";
+if prosody and require "core.configmanager".get("*", "use_libevent") then
+ server_type = "event";
+end
-if use_luaevent then
- use_luaevent = pcall(require, "luaevent.core");
- if not use_luaevent then
+if server_type == "event" then
+ if not pcall(require, "luaevent.core") then
log("error", "libevent not found, falling back to select()");
+ server_type = "select"
end
end
local server;
-
-if use_luaevent then
+local set_config;
+if server_type == "event" then
server = require "net.server_event";
- -- Overwrite signal.signal() because we need to ask libevent to
- -- handle them instead
- local ok, signal = pcall(require, "util.signal");
- if ok and signal then
- local _signal_signal = signal.signal;
+ local defaults = {};
+ for k,v in pairs(server.cfg) do
+ defaults[k] = v;
+ end
+ function set_config(settings)
+ local event_settings = {
+ ACCEPT_DELAY = settings.event_accept_retry_interval;
+ ACCEPT_QUEUE = settings.tcp_backlog;
+ CLEAR_DELAY = settings.event_clear_interval;
+ CONNECT_TIMEOUT = settings.connect_timeout;
+ DEBUG = settings.debug;
+ HANDSHAKE_TIMEOUT = settings.ssl_handshake_timeout;
+ MAX_CONNECTIONS = settings.max_connections;
+ MAX_HANDSHAKE_ATTEMPTS = settings.max_ssl_handshake_roundtrips;
+ MAX_READ_LENGTH = settings.max_receive_buffer_size;
+ MAX_SEND_LENGTH = settings.max_send_buffer_size;
+ READ_TIMEOUT = settings.read_timeout;
+ WRITE_TIMEOUT = settings.send_timeout;
+ };
+
+ for k,default in pairs(defaults) do
+ server.cfg[k] = event_settings[k] or default;
+ end
+ end
+elseif server_type == "select" then
+ server = require "net.server_select";
+
+ local defaults = {};
+ for k,v in pairs(server.getsettings()) do
+ defaults[k] = v;
+ end
+ function set_config(settings)
+ local select_settings = {};
+ for k,default in pairs(defaults) do
+ select_settings[k] = settings[k] or default;
+ end
+ server.changesettings(select_settings);
+ end
+else
+ error("Unsupported server type")
+end
+
+-- If server.hook_signal exists, replace signal.signal()
+local has_signal, signal = pcall(require, "util.signal");
+if has_signal then
+ if server.hook_signal then
function signal.signal(signal_id, handler)
if type(signal_id) == "string" then
signal_id = signal[signal_id:upper()];
@@ -34,46 +78,22 @@ if use_luaevent then
end
return server.hook_signal(signal_id, handler);
end
+ else
+ server.hook_signal = signal.signal;
end
else
- use_luaevent = false;
- server = require "net.server_select";
+ if not server.hook_signal then
+ server.hook_signal = function()
+ return false, "signal hooking not supported"
+ end
+ end
end
if prosody then
local config_get = require "core.configmanager".get;
- local defaults = {};
- for k,v in pairs(server.cfg or server.getsettings()) do
- defaults[k] = v;
- end
local function load_config()
local settings = config_get("*", "network_settings") or {};
- if use_luaevent then
- local event_settings = {
- ACCEPT_DELAY = settings.event_accept_retry_interval;
- ACCEPT_QUEUE = settings.tcp_backlog;
- CLEAR_DELAY = settings.event_clear_interval;
- CONNECT_TIMEOUT = settings.connect_timeout;
- DEBUG = settings.debug;
- HANDSHAKE_TIMEOUT = settings.ssl_handshake_timeout;
- MAX_CONNECTIONS = settings.max_connections;
- MAX_HANDSHAKE_ATTEMPTS = settings.max_ssl_handshake_roundtrips;
- MAX_READ_LENGTH = settings.max_receive_buffer_size;
- MAX_SEND_LENGTH = settings.max_send_buffer_size;
- READ_TIMEOUT = settings.read_timeout;
- WRITE_TIMEOUT = settings.send_timeout;
- };
-
- for k,default in pairs(defaults) do
- server.cfg[k] = event_settings[k] or default;
- end
- else
- local select_settings = {};
- for k,default in pairs(defaults) do
- select_settings[k] = settings[k] or default;
- end
- server.changesettings(select_settings);
- end
+ return set_config(settings);
end
load_config();
prosody.events.add_handler("config-reloaded", load_config);
diff --git a/net/server_event.lua b/net/server_event.lua
index e5705ab5..67758664 100644
--- a/net/server_event.lua
+++ b/net/server_event.lua
@@ -97,7 +97,7 @@ function interface_mt:_close()
return self:_destroy();
end
-function interface_mt:_start_connection(plainssl) -- should be called from addclient
+function interface_mt:_start_connection(plainssl) -- called from wrapclient
local callback = function( event )
if EV_TIMEOUT == event then -- timeout during connection
self.fatalerror = "connection timeout"
@@ -238,8 +238,8 @@ function interface_mt:_destroy() -- close this interface + events and call last
end
function interface_mt:_lock(nointerface, noreading, nowriting) -- lock or unlock this interface or events
- self.nointerface, self.noreading, self.nowriting = nointerface, noreading, nowriting
- return nointerface, noreading, nowriting
+ self.nointerface, self.noreading, self.nowriting = nointerface, noreading, nowriting
+ return nointerface, noreading, nowriting
end
--TODO: Deprecate
@@ -702,16 +702,14 @@ local function addclient( addr, serverport, listener, pattern, sslctx, typ )
debug "need luasec, but not available"
return nil, "luasec not found"
end
- if not typ then
+ if getaddrinfo and not typ then
local addrinfo, err = getaddrinfo(addr)
if not addrinfo then return nil, err end
if addrinfo[1] and addrinfo[1].family == "inet6" then
typ = "tcp6"
- else
- typ = "tcp"
end
end
- local create = socket[typ]
+ local create = socket[typ or "tcp"]
if type( create ) ~= "function" then
return nil, "invalid socket type"
end
@@ -722,10 +720,11 @@ local function addclient( addr, serverport, listener, pattern, sslctx, typ )
end
client:settimeout( 0 ) -- set nonblocking
local res, err = client:connect( addr, serverport ) -- connect
- if res or ( err == "timeout" ) then
- local ip, port = client:getsockname( )
- local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx )
- interface:_start_connection( sslctx )
+ if res or ( err == "timeout" or err == "Operation already in progress" ) then
+ if client.getsockname then
+ addr = client:getsockname( )
+ end
+ local interface = wrapclient( client, addr, serverport, listener, pattern, sslctx )
debug( "new connection id:", interface.id )
return interface, err
else
@@ -741,7 +740,7 @@ end
local function newevent( ... )
return addevent( base, ... )
-end
+ end
local function closeallservers ( arg )
for item in pairs( interfacelist ) do
@@ -753,9 +752,9 @@ end
local function setquitting(yes)
if yes then
- -- Quit now
- closeallservers();
- base:loopexit();
+ -- Quit now
+ closeallservers();
+ base:loopexit();
end
end
@@ -798,6 +797,20 @@ local function link(sender, receiver, buffersize)
sender:set_mode("*a");
end
+local function add_task(delay, callback)
+ local event_handle;
+ event_handle = base:addevent(nil, 0, function ()
+ local ret = callback(socket_gettime());
+ if ret then
+ return 0, ret;
+ elseif event_handle then
+ return -1;
+ end
+ end
+ , delay);
+ return event_handle;
+end
+
return {
cfg = cfg,
base = base,
@@ -813,6 +826,7 @@ return {
closeall = closeallservers,
get_backend = get_backend,
hook_signal = hook_signal,
+ add_task = add_task,
__NAME = SCRIPT_NAME,
__DATE = LAST_MODIFIED,
diff --git a/net/server_select.lua b/net/server_select.lua
index d9826c02..4df6e12d 100644
--- a/net/server_select.lua
+++ b/net/server_select.lua
@@ -31,17 +31,16 @@ local tostring = use "tostring"
--// lua libs //--
-local os = use "os"
local table = use "table"
local string = use "string"
local coroutine = use "coroutine"
--// lua lib methods //--
-local os_difftime = os.difftime
local math_min = math.min
local math_huge = math.huge
local table_concat = table.concat
+local table_insert = table.insert
local string_sub = string.sub
local coroutine_wrap = coroutine.wrap
local coroutine_yield = coroutine.yield
@@ -57,7 +56,6 @@ local getaddrinfo = luasocket.dns.getaddrinfo
local ssl_wrap = ( has_luasec and luasec.wrap )
local socket_bind = luasocket.bind
-local socket_sleep = luasocket.sleep
local socket_select = luasocket.select
--// functions //--
@@ -101,7 +99,6 @@ local _sendtraffic
local _readtraffic
local _selecttimeout
-local _sleeptime
local _tcpbacklog
local _starttime
@@ -114,8 +111,6 @@ local _checkinterval
local _sendtimeout
local _readtimeout
-local _timer
-
local _maxselectlen
local _maxfd
@@ -140,7 +135,6 @@ _sendtraffic = 0 -- some stats
_readtraffic = 0
_selecttimeout = 1 -- timeout of socket.select
-_sleeptime = 0 -- time to wait at the end of every loop
_tcpbacklog = 128 -- some kind of hint to the OS
_maxsendlen = 51000 * 1024 -- max len of send buffer
@@ -292,7 +286,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
local bufferqueuelen = 0 -- end of buffer array
local toclose
- local fatalerror
local needtls
local bufferlen = 0
@@ -504,7 +497,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
return dispatch( handler, buffer, err )
else -- connections was closed or fatal error
out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) )
- fatalerror = true
_ = handler and handler:force_close( err )
return false
end
@@ -544,7 +536,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
return true
else -- connection was closed during sending or fatal error
out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) )
- fatalerror = true
_ = handler and handler:force_close( err )
return false
end
@@ -793,7 +784,6 @@ end
getsettings = function( )
return {
select_timeout = _selecttimeout;
- select_sleep_time = _sleeptime;
tcp_backlog = _tcpbacklog;
max_send_buffer_size = _maxsendlen;
max_receive_buffer_size = _maxreadlen;
@@ -811,7 +801,6 @@ changesettings = function( new )
return nil, "invalid settings table"
end
_selecttimeout = tonumber( new.select_timeout ) or _selecttimeout
- _sleeptime = tonumber( new.select_sleep_time ) or _sleeptime
_maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen
_maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen
_checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval
@@ -833,6 +822,49 @@ addtimer = function( listener )
return true
end
+local add_task do
+ local data = {};
+ local new_data = {};
+
+ function add_task(delay, callback)
+ local current_time = luasocket_gettime();
+ delay = delay + current_time;
+ if delay >= current_time then
+ table_insert(new_data, {delay, callback});
+ else
+ local r = callback(current_time);
+ if r and type(r) == "number" then
+ return add_task(r, callback);
+ end
+ end
+ end
+
+ addtimer(function(current_time)
+ if #new_data > 0 then
+ for _, d in pairs(new_data) do
+ table_insert(data, d);
+ end
+ new_data = {};
+ end
+
+ local next_time = math_huge;
+ for i, d in pairs(data) do
+ local t, callback = d[1], d[2];
+ if t <= current_time then
+ data[i] = nil;
+ local r = callback(current_time);
+ if type(r) == "number" then
+ add_task(r, callback);
+ next_time = math_min(next_time, r);
+ end
+ else
+ next_time = math_min(next_time, t - current_time);
+ end
+ end
+ return next_time;
+ end);
+end
+
stats = function( )
return _readtraffic, _sendtraffic, _readlistlen, _sendlistlen, _timerlistlen
end
@@ -846,8 +878,15 @@ end
loop = function(once) -- this is the main loop of the program
if quitting then return "quitting"; end
if once then quitting = "once"; end
- local next_timer_time = math_huge;
+ _currenttime = luasocket_gettime( )
repeat
+ -- Fire timers
+ local next_timer_time = math_huge;
+ for i = 1, _timerlistlen do
+ local t = _timerlist[ i ]( _currenttime ) -- fire timers
+ if t then next_timer_time = math_min(next_timer_time, t); end
+ end
+
local read, write, err = socket_select( _readlist, _sendlist, math_min(_selecttimeout, next_timer_time) )
for i, socket in ipairs( write ) do -- send data waiting in writequeues
local handler = _socketlist[ socket ]
@@ -875,17 +914,16 @@ loop = function(once) -- this is the main loop of the program
_currenttime = luasocket_gettime( )
-- Check for socket timeouts
- local difftime = os_difftime( _currenttime - _starttime )
- if difftime > _checkinterval then
+ if _currenttime - _starttime > _checkinterval then
_starttime = _currenttime
for handler, timestamp in pairs( _writetimes ) do
- if os_difftime( _currenttime - timestamp ) > _sendtimeout then
+ if _currenttime - timestamp > _sendtimeout then
handler.disconnect( )( handler, "send timeout" )
handler:force_close() -- forced disconnect
end
end
for handler, timestamp in pairs( _readtimes ) do
- if os_difftime( _currenttime - timestamp ) > _readtimeout then
+ if _currenttime - timestamp > _readtimeout then
if not(handler.onreadtimeout) or handler:onreadtimeout() ~= true then
handler.disconnect( )( handler, "read timeout" )
handler:close( ) -- forced disconnect?
@@ -895,21 +933,6 @@ loop = function(once) -- this is the main loop of the program
end
end
end
-
- -- Fire timers
- if _currenttime - _timer >= math_min(next_timer_time, 1) then
- next_timer_time = math_huge;
- for i = 1, _timerlistlen do
- local t = _timerlist[ i ]( _currenttime ) -- fire timers
- if t then next_timer_time = math_min(next_timer_time, t); end
- end
- _timer = _currenttime
- else
- next_timer_time = next_timer_time - (_currenttime - _timer);
- end
-
- -- wait some time (0 by default)
- socket_sleep( _sleeptime )
until quitting;
if once and quitting == "once" then quitting = nil; return; end
closeall();
@@ -956,16 +979,14 @@ local addclient = function( address, port, listeners, pattern, sslctx, typ )
elseif sslctx and not has_luasec then
err = "luasec not found"
end
- if not typ then
+ if getaddrinfo and not typ then
local addrinfo, err = getaddrinfo(address)
if not addrinfo then return nil, err end
if addrinfo[1] and addrinfo[1].family == "inet6" then
typ = "tcp6"
- else
- typ = "tcp"
end
end
- local create = luasocket[typ]
+ local create = luasocket[typ or "tcp"]
if type( create ) ~= "function" then
err = "invalid socket type"
end
@@ -981,22 +1002,19 @@ local addclient = function( address, port, listeners, pattern, sslctx, typ )
end
client:settimeout( 0 )
local ok, err = client:connect( address, port )
- if ok or err == "timeout" then
+ if ok or err == "timeout" or err == "Operation already in progress" then
return wrapclient( client, address, port, listeners, pattern, sslctx )
else
return nil, err
end
end
---// EXPERIMENTAL //--
-
----------------------------------// BEGIN //--
use "setmetatable" ( _socketlist, { __mode = "k" } )
use "setmetatable" ( _readtimes, { __mode = "k" } )
use "setmetatable" ( _writetimes, { __mode = "k" } )
-_timer = luasocket_gettime( )
_starttime = luasocket_gettime( )
local function setlogger(new_logger)
@@ -1011,6 +1029,7 @@ end
return {
_addtimer = addtimer,
+ add_task = add_task;
addclient = addclient,
wrapclient = wrapclient,