diff options
Diffstat (limited to 'net')
-rw-r--r-- | net/cqueues.lua | 74 | ||||
-rw-r--r-- | net/server.lua | 106 | ||||
-rw-r--r-- | net/server_event.lua | 41 | ||||
-rw-r--r-- | net/server_select.lua | 99 |
4 files changed, 225 insertions, 95 deletions
diff --git a/net/cqueues.lua b/net/cqueues.lua new file mode 100644 index 00000000..8c4c756f --- /dev/null +++ b/net/cqueues.lua @@ -0,0 +1,74 @@ +-- Prosody IM +-- Copyright (C) 2014 Daurnimator +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- +-- This module allows you to use cqueues with a net.server mainloop +-- + +local server = require "net.server"; +local cqueues = require "cqueues"; +assert(cqueues.VERSION >= 20150113, "cqueues newer than 20150113 required") + +-- Create a single top level cqueue +local cq; + +if server.cq then -- server provides cqueues object + cq = server.cq; +elseif server.get_backend() == "select" and server._addtimer then -- server_select + cq = cqueues.new(); + local function step() + assert(cq:loop(0)); + end + + -- Use wrapclient (as wrapconnection isn't exported) to get server_select to watch cq fd + local handler = server.wrapclient({ + getfd = function() return cq:pollfd(); end; + settimeout = function() end; -- Method just needs to exist + close = function() end; -- Need close method for 'closeall' + }, nil, nil, {}); + + -- Only need to listen for readable; cqueues handles everything under the hood + -- readbuffer is called when `select` notes an fd as readable + handler.readbuffer = step; + + -- Use server_select low lever timer facility, + -- this callback gets called *every* time there is a timeout in the main loop + server._addtimer(function(current_time) + -- This may end up in extra step()'s, but cqueues handles it for us. + step(); + return cq:timeout(); + end); +elseif server.event and server.base then -- server_event + cq = cqueues.new(); + -- Only need to listen for readable; cqueues handles everything under the hood + local EV_READ = server.event.EV_READ; + -- Convert a cqueues timeout to an acceptable timeout for luaevent + local function luaevent_safe_timeout(cq) + local t = cq:timeout(); + -- if you give luaevent 0 or nil, it re-uses the previous timeout. + if t == 0 then + t = 0.000001; -- 1 microsecond is the smallest that works (goes into a `struct timeval`) + elseif t == nil then -- pick something big if we don't have one + t = 0x7FFFFFFF; -- largest 32bit int + end + return t + end + local event_handle; + event_handle = server.base:addevent(cq:pollfd(), EV_READ, function(e) + -- Need to reference event_handle or this callback will get collected + -- This creates a circular reference that can only be broken if event_handle is manually :close()'d + local _ = event_handle; + -- Run as many cqueues things as possible (with a timeout of 0) + -- If an error is thrown, it will break the libevent loop; but prosody resumes after logging a top level error + assert(cq:loop(0)); + return EV_READ, luaevent_safe_timeout(cq); + end, luaevent_safe_timeout(cq)); +else + error "NYI" +end + +return { + cq = cq; +} diff --git a/net/server.lua b/net/server.lua index 2a0b89ae..a753a19c 100644 --- a/net/server.lua +++ b/net/server.lua @@ -6,25 +6,69 @@ -- COPYING file in the source package for more information. -- -local use_luaevent = prosody and require "core.configmanager".get("*", "use_libevent"); +local server_type = prosody and require "core.configmanager".get("*", "network_backend") or "select"; +if prosody and require "core.configmanager".get("*", "use_libevent") then + server_type = "event"; +end -if use_luaevent then - use_luaevent = pcall(require, "luaevent.core"); - if not use_luaevent then +if server_type == "event" then + if not pcall(require, "luaevent.core") then log("error", "libevent not found, falling back to select()"); + server_type = "select" end end local server; - -if use_luaevent then +local set_config; +if server_type == "event" then server = require "net.server_event"; - -- Overwrite signal.signal() because we need to ask libevent to - -- handle them instead - local ok, signal = pcall(require, "util.signal"); - if ok and signal then - local _signal_signal = signal.signal; + local defaults = {}; + for k,v in pairs(server.cfg) do + defaults[k] = v; + end + function set_config(settings) + local event_settings = { + ACCEPT_DELAY = settings.event_accept_retry_interval; + ACCEPT_QUEUE = settings.tcp_backlog; + CLEAR_DELAY = settings.event_clear_interval; + CONNECT_TIMEOUT = settings.connect_timeout; + DEBUG = settings.debug; + HANDSHAKE_TIMEOUT = settings.ssl_handshake_timeout; + MAX_CONNECTIONS = settings.max_connections; + MAX_HANDSHAKE_ATTEMPTS = settings.max_ssl_handshake_roundtrips; + MAX_READ_LENGTH = settings.max_receive_buffer_size; + MAX_SEND_LENGTH = settings.max_send_buffer_size; + READ_TIMEOUT = settings.read_timeout; + WRITE_TIMEOUT = settings.send_timeout; + }; + + for k,default in pairs(defaults) do + server.cfg[k] = event_settings[k] or default; + end + end +elseif server_type == "select" then + server = require "net.server_select"; + + local defaults = {}; + for k,v in pairs(server.getsettings()) do + defaults[k] = v; + end + function set_config(settings) + local select_settings = {}; + for k,default in pairs(defaults) do + select_settings[k] = settings[k] or default; + end + server.changesettings(select_settings); + end +else + error("Unsupported server type") +end + +-- If server.hook_signal exists, replace signal.signal() +local has_signal, signal = pcall(require, "util.signal"); +if has_signal then + if server.hook_signal then function signal.signal(signal_id, handler) if type(signal_id) == "string" then signal_id = signal[signal_id:upper()]; @@ -34,46 +78,22 @@ if use_luaevent then end return server.hook_signal(signal_id, handler); end + else + server.hook_signal = signal.signal; end else - use_luaevent = false; - server = require "net.server_select"; + if not server.hook_signal then + server.hook_signal = function() + return false, "signal hooking not supported" + end + end end if prosody then local config_get = require "core.configmanager".get; - local defaults = {}; - for k,v in pairs(server.cfg or server.getsettings()) do - defaults[k] = v; - end local function load_config() local settings = config_get("*", "network_settings") or {}; - if use_luaevent then - local event_settings = { - ACCEPT_DELAY = settings.event_accept_retry_interval; - ACCEPT_QUEUE = settings.tcp_backlog; - CLEAR_DELAY = settings.event_clear_interval; - CONNECT_TIMEOUT = settings.connect_timeout; - DEBUG = settings.debug; - HANDSHAKE_TIMEOUT = settings.ssl_handshake_timeout; - MAX_CONNECTIONS = settings.max_connections; - MAX_HANDSHAKE_ATTEMPTS = settings.max_ssl_handshake_roundtrips; - MAX_READ_LENGTH = settings.max_receive_buffer_size; - MAX_SEND_LENGTH = settings.max_send_buffer_size; - READ_TIMEOUT = settings.read_timeout; - WRITE_TIMEOUT = settings.send_timeout; - }; - - for k,default in pairs(defaults) do - server.cfg[k] = event_settings[k] or default; - end - else - local select_settings = {}; - for k,default in pairs(defaults) do - select_settings[k] = settings[k] or default; - end - server.changesettings(select_settings); - end + return set_config(settings); end load_config(); prosody.events.add_handler("config-reloaded", load_config); diff --git a/net/server_event.lua b/net/server_event.lua index 6c8a7632..a4cf1146 100644 --- a/net/server_event.lua +++ b/net/server_event.lua @@ -129,7 +129,7 @@ do return self:_destroy(); end - function interface_mt:_start_connection(plainssl) -- should be called from addclient + function interface_mt:_start_connection(plainssl) -- called from wrapclient local callback = function( event ) if EV_TIMEOUT == event then -- timeout during connection self.fatalerror = "connection timeout" @@ -748,30 +748,29 @@ do debug "need luasec, but not available" return nil, "luasec not found" end - if not typ then + if getaddrinfo and not typ then local addrinfo, err = getaddrinfo(addr) if not addrinfo then return nil, err end if addrinfo[1] and addrinfo[1].family == "inet6" then typ = "tcp6" - else - typ = "tcp" end end - local create = socket[typ] + local create = socket[typ or "tcp"] if type( create ) ~= "function" then return nil, "invalid socket type" - end + end local client, err = create() -- creating new socket if not client then debug( "cannot create socket:", err ) - return nil, err - end + return nil, err + end client:settimeout( 0 ) -- set nonblocking local res, err = client:connect( addr, serverport ) -- connect - if res or ( err == "timeout" ) then - local ip, port = client:getsockname( ) - local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx ) - interface:_start_connection( startssl ) + if res or ( err == "timeout" or err == "Operation already in progress" ) then + if client.getsockname then + addr = client:getsockname( ) + end + local interface = wrapclient( client, addr, serverport, listener, pattern, sslctx ) debug( "new connection id:", interface.id ) return interface, err else @@ -849,6 +848,23 @@ local function link(sender, receiver, buffersize) sender:set_mode("*a"); end +local add_task do + local EVENT_LEAVE = (event.core and event.core.LEAVE) or -1; + local socket_gettime = socket.gettime + function add_task(delay, callback) + local event_handle; + event_handle = base:addevent(nil, 0, function () + local ret = callback(socket_gettime()); + if ret then + return 0, ret; + elseif event_handle then + return EVENT_LEAVE; + end + end + , delay); + end +end + return { cfg = cfg, @@ -865,6 +881,7 @@ return { closeall = closeallservers, get_backend = get_backend, hook_signal = hook_signal, + add_task = add_task, __NAME = SCRIPT_NAME, __DATE = LAST_MODIFIED, diff --git a/net/server_select.lua b/net/server_select.lua index 9c5225c6..35dcb5a7 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -31,17 +31,16 @@ local tostring = use "tostring" --// lua libs //-- -local os = use "os" local table = use "table" local string = use "string" local coroutine = use "coroutine" --// lua lib methods //-- -local os_difftime = os.difftime local math_min = math.min local math_huge = math.huge local table_concat = table.concat +local table_insert = table.insert local string_sub = string.sub local coroutine_wrap = coroutine.wrap local coroutine_yield = coroutine.yield @@ -57,7 +56,6 @@ local getaddrinfo = luasocket.dns.getaddrinfo local ssl_wrap = ( has_luasec and luasec.wrap ) local socket_bind = luasocket.bind -local socket_sleep = luasocket.sleep local socket_select = luasocket.select --// functions //-- @@ -101,7 +99,6 @@ local _sendtraffic local _readtraffic local _selecttimeout -local _sleeptime local _tcpbacklog local _starttime @@ -114,8 +111,6 @@ local _checkinterval local _sendtimeout local _readtimeout -local _timer - local _maxselectlen local _maxfd @@ -140,7 +135,6 @@ _sendtraffic = 0 -- some stats _readtraffic = 0 _selecttimeout = 1 -- timeout of socket.select -_sleeptime = 0 -- time to wait at the end of every loop _tcpbacklog = 128 -- some kind of hint to the OS _maxsendlen = 51000 * 1024 -- max len of send buffer @@ -292,7 +286,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport local bufferqueuelen = 0 -- end of buffer array local toclose - local fatalerror local needtls local bufferlen = 0 @@ -504,7 +497,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport return dispatch( handler, buffer, err ) else -- connections was closed or fatal error out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) ) - fatalerror = true _ = handler and handler:force_close( err ) return false end @@ -544,7 +536,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport return true else -- connection was closed during sending or fatal error out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) ) - fatalerror = true _ = handler and handler:force_close( err ) return false end @@ -792,7 +783,6 @@ end getsettings = function( ) return { select_timeout = _selecttimeout; - select_sleep_time = _sleeptime; tcp_backlog = _tcpbacklog; max_send_buffer_size = _maxsendlen; max_receive_buffer_size = _maxreadlen; @@ -810,7 +800,6 @@ changesettings = function( new ) return nil, "invalid settings table" end _selecttimeout = tonumber( new.select_timeout ) or _selecttimeout - _sleeptime = tonumber( new.select_sleep_time ) or _sleeptime _maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen _maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen _checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval @@ -832,6 +821,49 @@ addtimer = function( listener ) return true end +local add_task do + local data = {}; + local new_data = {}; + + function add_task(delay, callback) + local current_time = luasocket_gettime(); + delay = delay + current_time; + if delay >= current_time then + table_insert(new_data, {delay, callback}); + else + local r = callback(current_time); + if r and type(r) == "number" then + return add_task(r, callback); + end + end + end + + addtimer(function(current_time) + if #new_data > 0 then + for _, d in pairs(new_data) do + table_insert(data, d); + end + new_data = {}; + end + + local next_time = math_huge; + for i, d in pairs(data) do + local t, callback = d[1], d[2]; + if t <= current_time then + data[i] = nil; + local r = callback(current_time); + if type(r) == "number" then + add_task(r, callback); + next_time = math_min(next_time, r); + end + else + next_time = math_min(next_time, t - current_time); + end + end + return next_time; + end); +end + stats = function( ) return _readtraffic, _sendtraffic, _readlistlen, _sendlistlen, _timerlistlen end @@ -845,8 +877,15 @@ end loop = function(once) -- this is the main loop of the program if quitting then return "quitting"; end if once then quitting = "once"; end - local next_timer_time = math_huge; + _currenttime = luasocket_gettime( ) repeat + -- Fire timers + local next_timer_time = math_huge; + for i = 1, _timerlistlen do + local t = _timerlist[ i ]( _currenttime ) -- fire timers + if t then next_timer_time = math_min(next_timer_time, t); end + end + local read, write, err = socket_select( _readlist, _sendlist, math_min(_selecttimeout, next_timer_time) ) for i, socket in ipairs( write ) do -- send data waiting in writequeues local handler = _socketlist[ socket ] @@ -874,17 +913,16 @@ loop = function(once) -- this is the main loop of the program _currenttime = luasocket_gettime( ) -- Check for socket timeouts - local difftime = os_difftime( _currenttime - _starttime ) - if difftime > _checkinterval then + if _currenttime - _starttime > _checkinterval then _starttime = _currenttime for handler, timestamp in pairs( _writetimes ) do - if os_difftime( _currenttime - timestamp ) > _sendtimeout then + if _currenttime - timestamp > _sendtimeout then handler.disconnect( )( handler, "send timeout" ) handler:force_close() -- forced disconnect end end for handler, timestamp in pairs( _readtimes ) do - if os_difftime( _currenttime - timestamp ) > _readtimeout then + if _currenttime - timestamp > _readtimeout then if not(handler.onreadtimeout) or handler:onreadtimeout() ~= true then handler.disconnect( )( handler, "read timeout" ) handler:close( ) -- forced disconnect? @@ -894,21 +932,6 @@ loop = function(once) -- this is the main loop of the program end end end - - -- Fire timers - if _currenttime - _timer >= math_min(next_timer_time, 1) then - next_timer_time = math_huge; - for i = 1, _timerlistlen do - local t = _timerlist[ i ]( _currenttime ) -- fire timers - if t then next_timer_time = math_min(next_timer_time, t); end - end - _timer = _currenttime - else - next_timer_time = next_timer_time - (_currenttime - _timer); - end - - -- wait some time (0 by default) - socket_sleep( _sleeptime ) until quitting; if once and quitting == "once" then quitting = nil; return; end return "quitting" @@ -954,16 +977,14 @@ local addclient = function( address, port, listeners, pattern, sslctx, typ ) elseif sslctx and not has_luasec then err = "luasec not found" end - if not typ then + if getaddrinfo and not typ then local addrinfo, err = getaddrinfo(address) if not addrinfo then return nil, err end if addrinfo[1] and addrinfo[1].family == "inet6" then typ = "tcp6" - else - typ = "tcp" end end - local create = luasocket[typ] + local create = luasocket[typ or "tcp"] if type( create ) ~= "function" then err = "invalid socket type" end @@ -979,22 +1000,19 @@ local addclient = function( address, port, listeners, pattern, sslctx, typ ) end client:settimeout( 0 ) local ok, err = client:connect( address, port ) - if ok or err == "timeout" then + if ok or err == "timeout" or err == "Operation already in progress" then return wrapclient( client, address, port, listeners, pattern, sslctx ) else return nil, err end end ---// EXPERIMENTAL //-- - ----------------------------------// BEGIN //-- use "setmetatable" ( _socketlist, { __mode = "k" } ) use "setmetatable" ( _readtimes, { __mode = "k" } ) use "setmetatable" ( _writetimes, { __mode = "k" } ) -_timer = luasocket_gettime( ) _starttime = luasocket_gettime( ) local function setlogger(new_logger) @@ -1009,6 +1027,7 @@ end return { _addtimer = addtimer, + add_task = add_task; addclient = addclient, wrapclient = wrapclient, |