diff options
Diffstat (limited to 'net/server_select.lua')
-rw-r--r-- | net/server_select.lua | 105 |
1 files changed, 74 insertions, 31 deletions
diff --git a/net/server_select.lua b/net/server_select.lua index c50a6ce1..3503c30d 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -1,7 +1,7 @@ --- +-- -- server.lua by blastbeat of the luadch project -- Re-used here under the MIT/X Consortium License --- +-- -- Modifications (C) 2008-2010 Matthew Wild, Waqas Hussain -- @@ -38,7 +38,6 @@ local coroutine = use "coroutine" --// lua lib methods //-- -local os_difftime = os.difftime local math_min = math.min local math_huge = math.huge local table_concat = table.concat @@ -48,13 +47,14 @@ local coroutine_yield = coroutine.yield --// extern libs //-- -local luasec = use "ssl" +local has_luasec, luasec = pcall ( require , "ssl" ) local luasocket = use "socket" or require "socket" local luasocket_gettime = luasocket.gettime +local getaddrinfo = luasocket.dns.getaddrinfo --// extern lib methods //-- -local ssl_wrap = ( luasec and luasec.wrap ) +local ssl_wrap = ( has_luasec and luasec.wrap ) local socket_bind = luasocket.bind local socket_sleep = luasocket.sleep local socket_select = luasocket.select @@ -149,7 +149,7 @@ _accepretry = 10 -- seconds to wait until the next attempt of a full server to a _maxsendlen = 51000 * 1024 -- max len of send buffer _maxreadlen = 25000 * 1024 -- max len of read buffer -_checkinterval = 1200000 -- interval in secs to check idle clients +_checkinterval = 30 -- interval in secs to check idle clients _sendtimeout = 60000 -- allowed send idle time in secs _readtimeout = 6 * 60 * 60 -- allowed read idle time in secs @@ -295,6 +295,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport local status = listeners.onstatus local disconnect = listeners.ondisconnect local drain = listeners.ondrain + local onreadtimeout = listeners.onreadtimeout; local detach = listeners.ondetach local bufferqueue = { } -- buffer array @@ -324,6 +325,8 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport handler.disconnect = function( ) return disconnect end + handler.onreadtimeout = onreadtimeout; + handler.setlistener = function( self, listeners ) if detach then detach(self) -- Notify listener that it is no longer responsible for this connection @@ -332,6 +335,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport disconnect = listeners.ondisconnect status = listeners.onstatus drain = listeners.ondrain + handler.onreadtimeout = listeners.onreadtimeout detach = listeners.ondetach end handler.getstats = function( ) @@ -404,6 +408,9 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport out_put "server.lua: closed client handler and removed socket from list" return true end + handler.server = function ( ) + return server + end handler.ip = function( ) return ip end @@ -575,6 +582,9 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport _ = status and status( handler, "ssl-handshake-complete" ) if self.autostart_ssl and listeners.onconnect then listeners.onconnect(self); + if bufferqueuelen ~= 0 then + _sendlistlen = addsocket(_sendlist, client, _sendlistlen) + end end _readlistlen = addsocket(_readlist, client, _readlistlen) return true @@ -592,13 +602,14 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport coroutine_yield( ) -- handshake not finished end end - out_put( "server.lua: ssl handshake error: ", tostring(err or "handshake too long") ) - _ = handler and handler:force_close("ssl handshake failed") + err = "ssl handshake error: " .. ( err or "handshake too long" ); + out_put( "server.lua: ", err ); + _ = handler and handler:force_close(err) return false, err -- handshake failed end ) end - if luasec then + if has_luasec then handler.starttls = function( self, _sslctx) if _sslctx then handler:set_sslctx(_sslctx); @@ -624,7 +635,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport shutdown = id _socketlist[ socket ] = handler _readlistlen = addsocket(_readlist, socket, _readlistlen) - + -- remove traces of the old socket _readlistlen = removesocket( _readlist, oldsocket, _readlistlen ) _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen ) @@ -651,7 +662,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport _socketlist[ socket ] = handler _readlistlen = addsocket(_readlist, socket, _readlistlen) - if sslctx and luasec then + if sslctx and has_luasec then out_put "server.lua: auto-starting ssl negotiation..." handler.autostart_ssl = true; local ok, err = handler:starttls(sslctx); @@ -712,7 +723,7 @@ local function link(sender, receiver, buffersize) sender_locked = nil; end end - + local _readbuffer = sender.readbuffer; function sender.readbuffer() _readbuffer(); @@ -727,22 +738,23 @@ end ----------------------------------// PUBLIC //-- addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server + addr = addr or "*" local err if type( listeners ) ~= "table" then err = "invalid listener table" - end - if type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then + elseif type ( addr ) ~= "string" then + err = "invalid address" + elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then err = "invalid port" elseif _server[ addr..":"..port ] then err = "listeners on '[" .. addr .. "]:" .. port .. "' already exist" - elseif sslctx and not luasec then + elseif sslctx and not has_luasec then err = "luasec not found" end if err then out_error( "server.lua, [", addr, "]:", port, ": ", err ) return nil, err end - addr = addr or "*" local server, err = socket_bind( addr, port, _tcpbacklog ) if err then out_error( "server.lua, [", addr, "]:", port, ": ", err ) @@ -878,21 +890,22 @@ loop = function(once) -- this is the main loop of the program _currenttime = luasocket_gettime( ) -- Check for socket timeouts - local difftime = os_difftime( _currenttime - _starttime ) - if difftime > _checkinterval then + if _currenttime - _starttime > _checkinterval then _starttime = _currenttime for handler, timestamp in pairs( _writetimes ) do - if os_difftime( _currenttime - timestamp ) > _sendtimeout then - --_writetimes[ handler ] = nil + if _currenttime - timestamp > _sendtimeout then handler.disconnect( )( handler, "send timeout" ) handler:force_close() -- forced disconnect end end for handler, timestamp in pairs( _readtimes ) do - if os_difftime( _currenttime - timestamp ) > _readtimeout then - --_readtimes[ handler ] = nil - handler.disconnect( )( handler, "read timeout" ) - handler:close( ) -- forced disconnect? + if _currenttime - timestamp > _readtimeout then + if not(handler.onreadtimeout) or handler:onreadtimeout() ~= true then + handler.disconnect( )( handler, "read timeout" ) + handler:close( ) -- forced disconnect? + else + _readtimes[ handler ] = _currenttime -- reset timer + end end end end @@ -920,6 +933,7 @@ loop = function(once) -- this is the main loop of the program socket_sleep( _sleeptime ) until quitting; if once and quitting == "once" then quitting = nil; return; end + closeall(); return "quitting" end @@ -952,17 +966,46 @@ local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx return handler, socket end -local addclient = function( address, port, listeners, pattern, sslctx ) - local client, err = luasocket.tcp( ) +local addclient = function( address, port, listeners, pattern, sslctx, typ ) + local err + if type( listeners ) ~= "table" then + err = "invalid listener table" + elseif type ( address ) ~= "string" then + err = "invalid address" + elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then + err = "invalid port" + elseif sslctx and not has_luasec then + err = "luasec not found" + end + if not typ then + local addrinfo, err = getaddrinfo(address) + if not addrinfo then return nil, err end + if addrinfo[1] and addrinfo[1].family == "inet6" then + typ = "tcp6" + else + typ = "tcp" + end + end + local create = luasocket[typ] + if type( create ) ~= "function" then + err = "invalid socket type" + end + + if err then + out_error( "server.lua, addclient: ", err ) + return nil, err + end + + local client, err = create( ) if err then return nil, err end client:settimeout( 0 ) - _, err = client:connect( address, port ) - if err then -- try again - local handler = wrapclient( client, address, port, listeners ) + local ok, err = client:connect( address, port ) + if ok or err == "timeout" then + return wrapclient( client, address, port, listeners, pattern, sslctx ) else - wrapconnection( nil, listeners, client, address, port, "clientport", pattern, sslctx ) + return nil, err end end @@ -992,7 +1035,7 @@ return { addclient = addclient, wrapclient = wrapclient, - + loop = loop, link = link, step = step, |