diff options
author | Matthew Wild <mwild1@gmail.com> | 2015-03-27 22:24:57 +0000 |
---|---|---|
committer | Matthew Wild <mwild1@gmail.com> | 2015-03-27 22:24:57 +0000 |
commit | 315e3b3b937c22224c4530a194420f1b2fd77316 (patch) | |
tree | f7d0930c79494da535857eaf4f2a976bb052fa35 /net/server_select.lua | |
parent | 2f531bc782386106af9c8f297594a13c81ae660c (diff) | |
parent | 847f4204accf62ae15f623698c4842b08440a2cd (diff) | |
download | prosody-315e3b3b937c22224c4530a194420f1b2fd77316.tar.gz prosody-315e3b3b937c22224c4530a194420f1b2fd77316.zip |
Merge 0.10->trunk
Diffstat (limited to 'net/server_select.lua')
-rw-r--r-- | net/server_select.lua | 146 |
1 files changed, 98 insertions, 48 deletions
diff --git a/net/server_select.lua b/net/server_select.lua index 486e953b..35dcb5a7 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -31,32 +31,31 @@ local tostring = use "tostring" --// lua libs //-- -local os = use "os" local table = use "table" local string = use "string" local coroutine = use "coroutine" --// lua lib methods //-- -local os_difftime = os.difftime local math_min = math.min local math_huge = math.huge local table_concat = table.concat +local table_insert = table.insert local string_sub = string.sub local coroutine_wrap = coroutine.wrap local coroutine_yield = coroutine.yield --// extern libs //-- -local luasec = use "ssl" +local has_luasec, luasec = pcall ( require , "ssl" ) local luasocket = use "socket" or require "socket" local luasocket_gettime = luasocket.gettime +local getaddrinfo = luasocket.dns.getaddrinfo --// extern lib methods //-- -local ssl_wrap = ( luasec and luasec.wrap ) +local ssl_wrap = ( has_luasec and luasec.wrap ) local socket_bind = luasocket.bind -local socket_sleep = luasocket.sleep local socket_select = luasocket.select --// functions //-- @@ -100,7 +99,6 @@ local _sendtraffic local _readtraffic local _selecttimeout -local _sleeptime local _tcpbacklog local _starttime @@ -113,8 +111,6 @@ local _checkinterval local _sendtimeout local _readtimeout -local _timer - local _maxselectlen local _maxfd @@ -139,7 +135,6 @@ _sendtraffic = 0 -- some stats _readtraffic = 0 _selecttimeout = 1 -- timeout of socket.select -_sleeptime = 0 -- time to wait at the end of every loop _tcpbacklog = 128 -- some kind of hint to the OS _maxsendlen = 51000 * 1024 -- max len of send buffer @@ -291,7 +286,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport local bufferqueuelen = 0 -- end of buffer array local toclose - local fatalerror local needtls local bufferlen = 0 @@ -503,7 +497,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport return dispatch( handler, buffer, err ) else -- connections was closed or fatal error out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) ) - fatalerror = true _ = handler and handler:force_close( err ) return false end @@ -543,7 +536,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport return true else -- connection was closed during sending or fatal error out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) ) - fatalerror = true _ = handler and handler:force_close( err ) return false end @@ -594,7 +586,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport end ) end - if luasec then + if has_luasec then handler.starttls = function( self, _sslctx) if _sslctx then handler:set_sslctx(_sslctx); @@ -647,7 +639,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport _socketlist[ socket ] = handler _readlistlen = addsocket(_readlist, socket, _readlistlen) - if sslctx and luasec then + if sslctx and has_luasec then out_put "server.lua: auto-starting ssl negotiation..." handler.autostart_ssl = true; local ok, err = handler:starttls(sslctx); @@ -723,22 +715,23 @@ end ----------------------------------// PUBLIC //-- addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server + addr = addr or "*" local err if type( listeners ) ~= "table" then err = "invalid listener table" - end - if type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then + elseif type ( addr ) ~= "string" then + err = "invalid address" + elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then err = "invalid port" elseif _server[ addr..":"..port ] then err = "listeners on '[" .. addr .. "]:" .. port .. "' already exist" - elseif sslctx and not luasec then + elseif sslctx and not has_luasec then err = "luasec not found" end if err then out_error( "server.lua, [", addr, "]:", port, ": ", err ) return nil, err end - addr = addr or "*" local server, err = socket_bind( addr, port, _tcpbacklog ) if err then out_error( "server.lua, [", addr, "]:", port, ": ", err ) @@ -790,7 +783,6 @@ end getsettings = function( ) return { select_timeout = _selecttimeout; - select_sleep_time = _sleeptime; tcp_backlog = _tcpbacklog; max_send_buffer_size = _maxsendlen; max_receive_buffer_size = _maxreadlen; @@ -808,7 +800,6 @@ changesettings = function( new ) return nil, "invalid settings table" end _selecttimeout = tonumber( new.select_timeout ) or _selecttimeout - _sleeptime = tonumber( new.select_sleep_time ) or _sleeptime _maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen _maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen _checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval @@ -830,6 +821,49 @@ addtimer = function( listener ) return true end +local add_task do + local data = {}; + local new_data = {}; + + function add_task(delay, callback) + local current_time = luasocket_gettime(); + delay = delay + current_time; + if delay >= current_time then + table_insert(new_data, {delay, callback}); + else + local r = callback(current_time); + if r and type(r) == "number" then + return add_task(r, callback); + end + end + end + + addtimer(function(current_time) + if #new_data > 0 then + for _, d in pairs(new_data) do + table_insert(data, d); + end + new_data = {}; + end + + local next_time = math_huge; + for i, d in pairs(data) do + local t, callback = d[1], d[2]; + if t <= current_time then + data[i] = nil; + local r = callback(current_time); + if type(r) == "number" then + add_task(r, callback); + next_time = math_min(next_time, r); + end + else + next_time = math_min(next_time, t - current_time); + end + end + return next_time; + end); +end + stats = function( ) return _readtraffic, _sendtraffic, _readlistlen, _sendlistlen, _timerlistlen end @@ -843,8 +877,15 @@ end loop = function(once) -- this is the main loop of the program if quitting then return "quitting"; end if once then quitting = "once"; end - local next_timer_time = math_huge; + _currenttime = luasocket_gettime( ) repeat + -- Fire timers + local next_timer_time = math_huge; + for i = 1, _timerlistlen do + local t = _timerlist[ i ]( _currenttime ) -- fire timers + if t then next_timer_time = math_min(next_timer_time, t); end + end + local read, write, err = socket_select( _readlist, _sendlist, math_min(_selecttimeout, next_timer_time) ) for i, socket in ipairs( write ) do -- send data waiting in writequeues local handler = _socketlist[ socket ] @@ -872,17 +913,16 @@ loop = function(once) -- this is the main loop of the program _currenttime = luasocket_gettime( ) -- Check for socket timeouts - local difftime = os_difftime( _currenttime - _starttime ) - if difftime > _checkinterval then + if _currenttime - _starttime > _checkinterval then _starttime = _currenttime for handler, timestamp in pairs( _writetimes ) do - if os_difftime( _currenttime - timestamp ) > _sendtimeout then + if _currenttime - timestamp > _sendtimeout then handler.disconnect( )( handler, "send timeout" ) handler:force_close() -- forced disconnect end end for handler, timestamp in pairs( _readtimes ) do - if os_difftime( _currenttime - timestamp ) > _readtimeout then + if _currenttime - timestamp > _readtimeout then if not(handler.onreadtimeout) or handler:onreadtimeout() ~= true then handler.disconnect( )( handler, "read timeout" ) handler:close( ) -- forced disconnect? @@ -892,21 +932,6 @@ loop = function(once) -- this is the main loop of the program end end end - - -- Fire timers - if _currenttime - _timer >= math_min(next_timer_time, 1) then - next_timer_time = math_huge; - for i = 1, _timerlistlen do - local t = _timerlist[ i ]( _currenttime ) -- fire timers - if t then next_timer_time = math_min(next_timer_time, t); end - end - _timer = _currenttime - else - next_timer_time = next_timer_time - (_currenttime - _timer); - end - - -- wait some time (0 by default) - socket_sleep( _sleeptime ) until quitting; if once and quitting == "once" then quitting = nil; return; end return "quitting" @@ -941,29 +966,53 @@ local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx return handler, socket end -local addclient = function( address, port, listeners, pattern, sslctx ) - local client, err = luasocket.tcp( ) +local addclient = function( address, port, listeners, pattern, sslctx, typ ) + local err + if type( listeners ) ~= "table" then + err = "invalid listener table" + elseif type ( address ) ~= "string" then + err = "invalid address" + elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then + err = "invalid port" + elseif sslctx and not has_luasec then + err = "luasec not found" + end + if getaddrinfo and not typ then + local addrinfo, err = getaddrinfo(address) + if not addrinfo then return nil, err end + if addrinfo[1] and addrinfo[1].family == "inet6" then + typ = "tcp6" + end + end + local create = luasocket[typ or "tcp"] + if type( create ) ~= "function" then + err = "invalid socket type" + end + + if err then + out_error( "server.lua, addclient: ", err ) + return nil, err + end + + local client, err = create( ) if err then return nil, err end client:settimeout( 0 ) - _, err = client:connect( address, port ) - if err then -- try again + local ok, err = client:connect( address, port ) + if ok or err == "timeout" or err == "Operation already in progress" then return wrapclient( client, address, port, listeners, pattern, sslctx ) else - return wrapconnection( nil, listeners, client, address, port, "clientport", pattern, sslctx ) + return nil, err end end ---// EXPERIMENTAL //-- - ----------------------------------// BEGIN //-- use "setmetatable" ( _socketlist, { __mode = "k" } ) use "setmetatable" ( _readtimes, { __mode = "k" } ) use "setmetatable" ( _writetimes, { __mode = "k" } ) -_timer = luasocket_gettime( ) _starttime = luasocket_gettime( ) local function setlogger(new_logger) @@ -978,6 +1027,7 @@ end return { _addtimer = addtimer, + add_task = add_task; addclient = addclient, wrapclient = wrapclient, |