diff options
Diffstat (limited to 'net/server_select.lua')
-rw-r--r-- | net/server_select.lua | 88 |
1 files changed, 64 insertions, 24 deletions
diff --git a/net/server_select.lua b/net/server_select.lua index 7ac41523..9c5225c6 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -1,7 +1,7 @@ --- +-- -- server.lua by blastbeat of the luadch project -- Re-used here under the MIT/X Consortium License --- +-- -- Modifications (C) 2008-2010 Matthew Wild, Waqas Hussain -- @@ -48,13 +48,14 @@ local coroutine_yield = coroutine.yield --// extern libs //-- -local luasec = use "ssl" +local has_luasec, luasec = pcall ( require , "ssl" ) local luasocket = use "socket" or require "socket" local luasocket_gettime = luasocket.gettime +local getaddrinfo = luasocket.dns.getaddrinfo --// extern lib methods //-- -local ssl_wrap = ( luasec and luasec.wrap ) +local ssl_wrap = ( has_luasec and luasec.wrap ) local socket_bind = luasocket.bind local socket_sleep = luasocket.sleep local socket_select = luasocket.select @@ -145,7 +146,7 @@ _tcpbacklog = 128 -- some kind of hint to the OS _maxsendlen = 51000 * 1024 -- max len of send buffer _maxreadlen = 25000 * 1024 -- max len of read buffer -_checkinterval = 1200000 -- interval in secs to check idle clients +_checkinterval = 30 -- interval in secs to check idle clients _sendtimeout = 60000 -- allowed send idle time in secs _readtimeout = 6 * 60 * 60 -- allowed read idle time in secs @@ -284,6 +285,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport local status = listeners.onstatus local disconnect = listeners.ondisconnect local drain = listeners.ondrain + local onreadtimeout = listeners.onreadtimeout; local detach = listeners.ondetach local bufferqueue = { } -- buffer array @@ -313,6 +315,8 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport handler.disconnect = function( ) return disconnect end + handler.onreadtimeout = onreadtimeout; + handler.setlistener = function( self, listeners ) if detach then detach(self) -- Notify listener that it is no longer responsible for this connection @@ -321,6 +325,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport disconnect = listeners.ondisconnect status = listeners.onstatus drain = listeners.ondrain + handler.onreadtimeout = listeners.onreadtimeout detach = listeners.ondetach end handler.getstats = function( ) @@ -564,6 +569,9 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport _ = status and status( handler, "ssl-handshake-complete" ) if self.autostart_ssl and listeners.onconnect then listeners.onconnect(self); + if bufferqueuelen ~= 0 then + _sendlistlen = addsocket(_sendlist, client, _sendlistlen) + end end _readlistlen = addsocket(_readlist, client, _readlistlen) return true @@ -587,7 +595,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport end ) end - if luasec then + if has_luasec then handler.starttls = function( self, _sslctx) if _sslctx then handler:set_sslctx(_sslctx); @@ -613,7 +621,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport shutdown = id _socketlist[ socket ] = handler _readlistlen = addsocket(_readlist, socket, _readlistlen) - + -- remove traces of the old socket _readlistlen = removesocket( _readlist, oldsocket, _readlistlen ) _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen ) @@ -640,7 +648,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport _socketlist[ socket ] = handler _readlistlen = addsocket(_readlist, socket, _readlistlen) - if sslctx and luasec then + if sslctx and has_luasec then out_put "server.lua: auto-starting ssl negotiation..." handler.autostart_ssl = true; local ok, err = handler:starttls(sslctx); @@ -701,7 +709,7 @@ local function link(sender, receiver, buffersize) sender_locked = nil; end end - + local _readbuffer = sender.readbuffer; function sender.readbuffer() _readbuffer(); @@ -716,22 +724,23 @@ end ----------------------------------// PUBLIC //-- addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server + addr = addr or "*" local err if type( listeners ) ~= "table" then err = "invalid listener table" - end - if type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then + elseif type ( addr ) ~= "string" then + err = "invalid address" + elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then err = "invalid port" elseif _server[ addr..":"..port ] then err = "listeners on '[" .. addr .. "]:" .. port .. "' already exist" - elseif sslctx and not luasec then + elseif sslctx and not has_luasec then err = "luasec not found" end if err then out_error( "server.lua, [", addr, "]:", port, ": ", err ) return nil, err end - addr = addr or "*" local server, err = socket_bind( addr, port, _tcpbacklog ) if err then out_error( "server.lua, [", addr, "]:", port, ": ", err ) @@ -870,16 +879,18 @@ loop = function(once) -- this is the main loop of the program _starttime = _currenttime for handler, timestamp in pairs( _writetimes ) do if os_difftime( _currenttime - timestamp ) > _sendtimeout then - --_writetimes[ handler ] = nil handler.disconnect( )( handler, "send timeout" ) handler:force_close() -- forced disconnect end end for handler, timestamp in pairs( _readtimes ) do if os_difftime( _currenttime - timestamp ) > _readtimeout then - --_readtimes[ handler ] = nil - handler.disconnect( )( handler, "read timeout" ) - handler:close( ) -- forced disconnect? + if not(handler.onreadtimeout) or handler:onreadtimeout() ~= true then + handler.disconnect( )( handler, "read timeout" ) + handler:close( ) -- forced disconnect? + else + _readtimes[ handler ] = _currenttime -- reset timer + end end end end @@ -932,17 +943,46 @@ local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx return handler, socket end -local addclient = function( address, port, listeners, pattern, sslctx ) - local client, err = luasocket.tcp( ) +local addclient = function( address, port, listeners, pattern, sslctx, typ ) + local err + if type( listeners ) ~= "table" then + err = "invalid listener table" + elseif type ( address ) ~= "string" then + err = "invalid address" + elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then + err = "invalid port" + elseif sslctx and not has_luasec then + err = "luasec not found" + end + if not typ then + local addrinfo, err = getaddrinfo(address) + if not addrinfo then return nil, err end + if addrinfo[1] and addrinfo[1].family == "inet6" then + typ = "tcp6" + else + typ = "tcp" + end + end + local create = luasocket[typ] + if type( create ) ~= "function" then + err = "invalid socket type" + end + + if err then + out_error( "server.lua, addclient: ", err ) + return nil, err + end + + local client, err = create( ) if err then return nil, err end client:settimeout( 0 ) - _, err = client:connect( address, port ) - if err then -- try again - local handler = wrapclient( client, address, port, listeners ) + local ok, err = client:connect( address, port ) + if ok or err == "timeout" then + return wrapclient( client, address, port, listeners, pattern, sslctx ) else - wrapconnection( nil, listeners, client, address, port, "clientport", pattern, sslctx ) + return nil, err end end @@ -972,7 +1012,7 @@ return { addclient = addclient, wrapclient = wrapclient, - + loop = loop, link = link, step = step, |