diff options
Diffstat (limited to 'net/server_select.lua')
-rw-r--r-- | net/server_select.lua | 156 |
1 files changed, 89 insertions, 67 deletions
diff --git a/net/server_select.lua b/net/server_select.lua index 0852d444..7eb330a8 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -10,11 +10,6 @@ local use = function( what ) return _G[ what ] end -local clean = function( tbl ) - for i, k in pairs( tbl ) do - tbl[ i ] = nil - end -end local log, table_concat = require ("util.logger").init("socket"), table.concat; local out_put = function (...) return log("debug", table_concat{...}); end @@ -47,7 +42,6 @@ local os_difftime = os.difftime local math_min = math.min local math_huge = math.huge local table_concat = table.concat -local string_len = string.len local string_sub = string.sub local coroutine_wrap = coroutine.wrap local coroutine_yield = coroutine.yield @@ -107,6 +101,7 @@ local _readtraffic local _selecttimeout local _sleeptime +local _tcpbacklog local _starttime local _currenttime @@ -118,11 +113,10 @@ local _checkinterval local _sendtimeout local _readtimeout -local _cleanqueue - local _timer -local _maxclientsperserver +local _maxselectlen +local _maxfd local _maxsslhandshake @@ -146,6 +140,7 @@ _readtraffic = 0 _selecttimeout = 1 -- timeout of socket.select _sleeptime = 0 -- time to wait at the end of every loop +_tcpbacklog = 128 -- some kind of hint to the OS _maxsendlen = 51000 * 1024 -- max len of send buffer _maxreadlen = 25000 * 1024 -- max len of read buffer @@ -154,17 +149,21 @@ _checkinterval = 1200000 -- interval in secs to check idle clients _sendtimeout = 60000 -- allowed send idle time in secs _readtimeout = 6 * 60 * 60 -- allowed read idle time in secs -_cleanqueue = false -- clean bufferqueue after using - -_maxclientsperserver = 1000 +local is_windows = package.config:sub(1,1) == "\\" -- check the directory separator, to detemine whether this is Windows +_maxfd = luasocket._SETSIZE or (is_windows and math.huge) or 1024 -- max fd number, limit to 1024 by default to prevent glibc buffer overflow, but not on Windows +_maxselectlen = luasocket._SETSIZE or 1024 -- But this still applies on Windows _maxsslhandshake = 30 -- max handshake round-trips ----------------------------------// PRIVATE //-- -wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxconnections ) -- this function wraps a server +wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- this function wraps a server -- FIXME Make sure FD < _maxfd - maxconnections = maxconnections or _maxclientsperserver + if socket:getfd() >= _maxfd then + out_error("server.lua: Disallowed FD number: "..socket:getfd()) + socket:close() + return nil, "fd-too-large" + end local connections = 0 @@ -201,20 +200,23 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco --mem_free( ) out_put "server.lua: closed server handler and removed sockets from list" end - handler.pause = function() + handler.pause = function( hard ) if not handler.paused then - socket:close( ) - _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) _readlistlen = removesocket( _readlist, socket, _readlistlen ) - _socketlist[ socket ] = nil - socket = nil; + if hard then + _socketlist[ socket ] = nil + socket:close( ) + socket = nil; + end handler.paused = true; end end - handler.resume = function() + handler.resume = function( ) if handler.paused then - socket = socket_bind( ip, serverport ); - socket:settimeout( 0 ) + if not socket then + socket = socket_bind( ip, serverport, _tcpbacklog ); + socket:settimeout( 0 ) + end _readlistlen = addsocket(_readlist, socket, _readlistlen) _socketlist[ socket ] = handler handler.paused = false; @@ -230,7 +232,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco return socket end handler.readbuffer = function( ) - if connections > maxconnections then + if _readlistlen >= _maxselectlen or _sendlistlen >= _maxselectlen then handler.pause( ) out_put( "server.lua: refused new client connection: server full" ) return false @@ -244,7 +246,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco end connections = connections + 1 out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport)) - if dispatch then + if dispatch and not sslctx then -- SSL connections will notify onconnect when handshake completes return dispatch( handler ); end return; @@ -258,6 +260,12 @@ end wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx ) -- this function wraps a client to a handler object + if socket:getfd() >= _maxfd then + out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent + socket:close( ) -- Should we send some kind of error here? + server.pause( ) + return nil, nil, "fd-too-large" + end socket:settimeout( 0 ) --// local import of socket methods //-- @@ -335,9 +343,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport handler.force_close = function ( self, err ) if bufferqueuelen ~= 0 then out_put("server.lua: discarding unwritten data for ", tostring(ip), ":", tostring(clientport)) - for i = bufferqueuelen, 1, -1 do - bufferqueue[i] = nil; - end bufferqueuelen = 0; end return self:close(err); @@ -391,7 +396,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport return clientport end local write = function( self, data ) - bufferlen = bufferlen + string_len( data ) + bufferlen = bufferlen + #data if bufferlen > maxsendlen then _closelist[ handler ] = "send buffer exceeded" -- cannot close the client at the moment, have to wait to the end of the cycle handler.write = idfalse -- dont write anymore @@ -473,7 +478,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport local buffer, err, part = receive( socket, pattern ) -- receive buffer with "pattern" if not err or (err == "wantread" or err == "timeout") then -- received something local buffer = buffer or part or "" - local len = string_len( buffer ) + local len = #buffer if len > maxreadlen then handler:close( "receive buffer exceeded" ) return false @@ -499,7 +504,9 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport count = ( succ or byte or 0 ) * STAT_UNIT sendtraffic = sendtraffic + count _sendtraffic = _sendtraffic + count - _ = _cleanqueue and clean( bufferqueue ) + for i = bufferqueuelen,1,-1 do + bufferqueue[ i ] = nil + end --out_put( "server.lua: sended '", buffer, "', bytes: ", tostring(succ), ", error: ", tostring(err), ", part: ", tostring(byte), ", to: ", tostring(ip), ":", tostring(clientport) ) else succ, err, count = false, "unexpected close", 0; @@ -568,7 +575,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport end out_put( "server.lua: ssl handshake error: ", tostring(err or "handshake too long") ) _ = handler and handler:force_close("ssl handshake failed") - return false, err -- handshake failed + return false, err -- handshake failed end ) end @@ -612,7 +619,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport handler.readbuffer = handshake handler.sendbuffer = handshake - return handshake( socket ) -- do handshake + return handshake( socket ) -- do handshake end end @@ -628,10 +635,10 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport if sslctx and luasec then out_put "server.lua: auto-starting ssl negotiation..." handler.autostart_ssl = true; - local ok, err = handler:starttls(sslctx); - if ok == false then - return nil, nil, err - end + local ok, err = handler:starttls(sslctx); + if ok == false then + return nil, nil, err + end end return handler, socket @@ -716,12 +723,12 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function return nil, err end addr = addr or "*" - local server, err = socket_bind( addr, port ) + local server, err = socket_bind( addr, port, _tcpbacklog ) if err then out_error( "server.lua, [", addr, "]:", port, ": ", err ) return nil, err end - local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, _maxclientsperserver ) -- wrap new server socket + local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx ) -- wrap new server socket if not handler then server:close( ) return nil, err @@ -765,7 +772,19 @@ closeall = function( ) end getsettings = function( ) - return _selecttimeout, _sleeptime, _maxsendlen, _maxreadlen, _checkinterval, _sendtimeout, _readtimeout, _cleanqueue, _maxclientsperserver, _maxsslhandshake + return { + select_timeout = _selecttimeout; + select_sleep_time = _sleeptime; + tcp_backlog = _tcpbacklog; + max_send_buffer_size = _maxsendlen; + max_receive_buffer_size = _maxreadlen; + select_idle_check_interval = _checkinterval; + send_timeout = _sendtimeout; + read_timeout = _readtimeout; + max_connections = _maxselectlen; + max_ssl_handshake_roundtrips = _maxsslhandshake; + highest_allowed_fd = _maxfd; + } end changesettings = function( new ) @@ -777,11 +796,12 @@ changesettings = function( new ) _maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen _maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen _checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval + _tcpbacklog = tonumber( new.tcp_backlog ) or _tcpbacklog _sendtimeout = tonumber( new.send_timeout ) or _sendtimeout _readtimeout = tonumber( new.read_timeout ) or _readtimeout - _cleanqueue = new.select_clean_queue - _maxclientsperserver = new.max_connections or _maxclientsperserver + _maxselectlen = new.max_connections or _maxselectlen _maxsslhandshake = new.max_ssl_handshake_roundtrips or _maxsslhandshake + _maxfd = new.highest_allowed_fd or _maxfd return true end @@ -831,9 +851,31 @@ loop = function(once) -- this is the main loop of the program for handler, err in pairs( _closelist ) do handler.disconnect( )( handler, err ) handler:force_close() -- forced disconnect + _closelist[ handler ] = nil; end - clean( _closelist ) _currenttime = luasocket_gettime( ) + + -- Check for socket timeouts + local difftime = os_difftime( _currenttime - _starttime ) + if difftime > _checkinterval then + _starttime = _currenttime + for handler, timestamp in pairs( _writetimes ) do + if os_difftime( _currenttime - timestamp ) > _sendtimeout then + --_writetimes[ handler ] = nil + handler.disconnect( )( handler, "send timeout" ) + handler:force_close() -- forced disconnect + end + end + for handler, timestamp in pairs( _readtimes ) do + if os_difftime( _currenttime - timestamp ) > _readtimeout then + --_readtimes[ handler ] = nil + handler.disconnect( )( handler, "read timeout" ) + handler:close( ) -- forced disconnect? + end + end + end + + -- Fire timers if _currenttime - _timer >= math_min(next_timer_time, 1) then next_timer_time = math_huge; for i = 1, _timerlistlen do @@ -844,8 +886,9 @@ loop = function(once) -- this is the main loop of the program else next_timer_time = next_timer_time - (_currenttime - _timer); end - socket_sleep( _sleeptime ) -- wait some time - --collectgarbage( ) + + -- wait some time (0 by default) + socket_sleep( _sleeptime ) until quitting; if once and quitting == "once" then quitting = nil; return; end return "quitting" @@ -862,7 +905,8 @@ end --// EXPERIMENTAL //-- local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx ) - local handler = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx ) + local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx ) + if not handler then return nil, err end _socketlist[ socket ] = handler if not sslctx then _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) @@ -908,28 +952,6 @@ use "setmetatable" ( _writetimes, { __mode = "k" } ) _timer = luasocket_gettime( ) _starttime = luasocket_gettime( ) -addtimer( function( ) - local difftime = os_difftime( _currenttime - _starttime ) - if difftime > _checkinterval then - _starttime = _currenttime - for handler, timestamp in pairs( _writetimes ) do - if os_difftime( _currenttime - timestamp ) > _sendtimeout then - --_writetimes[ handler ] = nil - handler.disconnect( )( handler, "send timeout" ) - handler:force_close() -- forced disconnect - end - end - for handler, timestamp in pairs( _readtimes ) do - if os_difftime( _currenttime - timestamp ) > _readtimeout then - --_readtimes[ handler ] = nil - handler.disconnect( )( handler, "read timeout" ) - handler:close( ) -- forced disconnect? - end - end - end - end -) - local function setlogger(new_logger) local old_logger = log; if new_logger then |