diff options
author | Matthew Wild <mwild1@gmail.com> | 2009-01-29 17:54:37 +0000 |
---|---|---|
committer | Matthew Wild <mwild1@gmail.com> | 2009-01-29 17:54:37 +0000 |
commit | 4dc046b0afa381ac22982bd1d4265775832c4553 (patch) | |
tree | 808a534fe9e3bc2cfadb1ea6bb7bacea05277631 /net/server.lua | |
parent | 8fc233d7f5e6a5bee04d6b26a761ac2f9f88e57d (diff) | |
parent | 75a4fe601018e84510f1526faa9d155a9299e528 (diff) | |
download | prosody-4dc046b0afa381ac22982bd1d4265775832c4553.tar.gz prosody-4dc046b0afa381ac22982bd1d4265775832c4553.zip |
Automated merge with http://waqas.ath.cx:8000/
Diffstat (limited to 'net/server.lua')
-rw-r--r-- | net/server.lua | 1710 |
1 files changed, 823 insertions, 887 deletions
diff --git a/net/server.lua b/net/server.lua index 7416276f..b5d84dd9 100644 --- a/net/server.lua +++ b/net/server.lua @@ -1,887 +1,823 @@ ---[[ - - server.lua by blastbeat of the luadch project - - re-used here under the MIT/X Consortium License - - Modifications (C) 2008 Matthew Wild, Waqas Hussain -]]-- - -----------------------------------// DECLARATION //-- - ---// constants //-- - -local STAT_UNIT = 1 / ( 1024 * 1024 ) -- mb - ---// lua functions //-- - -local function use( what ) return _G[ what ] end - -local type = use "type" -local pairs = use "pairs" -local ipairs = use "ipairs" -local tostring = use "tostring" -local collectgarbage = use "collectgarbage" - ---// lua libs //-- - -local table = use "table" -local coroutine = use "coroutine" - ---// lua lib methods //-- - -local table_concat = table.concat -local table_remove = table.remove -local string_sub = use'string'.sub -local coroutine_wrap = coroutine.wrap -local coroutine_yield = coroutine.yield -local print = print; - -local log = require "util.logger".init("server"); -local out_put = function () end --print; -local out_error = function (...) log("error", table_concat({...}, " ")); end - ---// extern libs //-- - -local luasec = select(2, pcall(require, "ssl")) -local luasocket = require "socket" - ---// extern lib methods //-- - -local ssl_wrap = ( luasec and luasec.wrap ) -local socket_bind = luasocket.bind -local socket_select = luasocket.select -local ssl_newcontext = ( luasec and luasec.newcontext ) - ---// functions //-- - -local loop -local stats -local addtimer -local closeall -local addserver -local firetimer -local closesocket -local removesocket -local wrapserver -local wraptcpclient -local wrapsslclient - ---// tables //-- - -local listener -local readlist -local writelist -local socketlist -local timelistener - ---// simple data types //-- - -local _ -local readlen = 0 -- length of readlist -local writelen = 0 -- lenght of writelist - -local sendstat= 0 -local receivestat = 0 - -----------------------------------// DEFINITION //-- - -listener = { } -- key = port, value = table -readlist = { } -- array with sockets to read from -writelist = { } -- arrary with sockets to write to -socketlist = { } -- key = socket, value = wrapped socket -timelistener = { } - -stats = function( ) - return receivestat, sendstat -end - -wrapserver = function( listener, socket, ip, serverport, mode, sslctx, wrapper_function ) -- this function wraps a server - - local dispatch, disconnect = listener.listener, listener.disconnect -- dangerous - - local wrapclient, err - - out_put("Starting a new server on "..tostring(serverport).." with ssl: "..tostring(sslctx)); - if sslctx then - if not ssl_newcontext then - return nil, "luasec not found" - end - if type( sslctx ) ~= "table" then - out_error "server.lua: wrong server sslctx" - return nil, "wrong server sslctx" - end - sslctx, err = ssl_newcontext( sslctx ) - if not sslctx then - err = err or "wrong sslctx parameters" - out_error( "server.lua: ", err ) - return nil, err - end - end - - if wrapper_function then - wrapclient = wrapper_function - elseif sslctx then - wrapclient = wrapsslclient - else - wrapclient = wraptcpclient - end - - local accept = socket.accept - local close = socket.close - - --// public methods of the object //-- - - local handler = { } - - handler.shutdown = function( ) end - - --[[handler.listener = function( data, err ) - return ondata( handler, data, err ) - end]] - handler.ssl = function( ) - return sslctx and true or false - end - handler.close = function( closed ) - _ = not closed and close( socket ) - writelen = removesocket( writelist, socket, writelen ) - readlen = removesocket( readlist, socket, readlen ) - socketlist[ socket ] = nil - handler = nil - end - handler.ip = function( ) - return ip - end - handler.serverport = function( ) - return serverport - end - handler.socket = function( ) - return socket - end - handler.receivedata = function( ) - local client, err = accept( socket ) -- try to accept - if client then - local ip, clientport = client:getpeername( ) - client:settimeout( 0 ) - local handler, client, err = wrapclient( listener, client, ip, serverport, clientport, mode, sslctx ) -- wrap new client socket - if err then -- error while wrapping ssl socket - return false - end - out_put( "server.lua: accepted new client connection from ", ip, ":", clientport ) - return dispatch( handler ) - elseif err then -- maybe timeout or something else - out_put( "server.lua: error with new client connection: ", err ) - return false - end - end - return handler -end - -wrapsslclient = function( listener, socket, ip, serverport, clientport, mode, sslctx ) -- this function wraps a ssl cleint - - local dispatch, disconnect = listener.listener, listener.disconnect - - --// transform socket to ssl object //-- - - local err - socket, err = ssl_wrap( socket, sslctx ) -- wrap socket - if err then - out_put( "server.lua: ssl error: ", err ) - return nil, nil, err -- fatal error - end - socket:settimeout( 0 ) - - --// private closures of the object //-- - - local writequeue = { } -- buffer for messages to send - - local eol, fatal_send_error, wants_closing - - local sstat, rstat = 0, 0 - - --// local import of socket methods //-- - - local send = socket.send - local receive = socket.receive - local close = socket.close - --local shutdown = socket.shutdown - - --// public methods of the object //-- - - local handler = { } - - handler.getstats = function( ) - return rstat, sstat - end - - handler.listener = function( data, err ) - return listener( handler, data, err ) - end - handler.ssl = function( ) - return true - end - handler.send = function( _, data, i, j ) - return send( socket, data, i, j ) - end - handler.receive = function( pattern, prefix ) - return receive( socket, pattern, prefix ) - end - handler.shutdown = function( pattern ) - --return shutdown( socket, pattern ) - end - handler.close = function( closed ) - if eol and not fatal_send_error then - -- There is data in the buffer, and we haven't experienced - -- an error trying to send yet, so we'll flush the buffer now - handler._dispatchdata(); - if eol then - -- and there is *still* data in the buffer - -- we'll give up for now, and close later - wants_closing = true; - return; - end - end - close( socket ) - writelen = ( eol and removesocket( writelist, socket, writelen ) ) or writelen - readlen = removesocket( readlist, socket, readlen ) - socketlist[ socket ] = nil - out_put "server.lua: closed handler and removed socket from list" - end - handler.ip = function( ) - return ip - end - handler.serverport = function( ) - return serverport - end - handler.clientport = function( ) - return clientport - end - - handler.write = function( data ) - if not eol then - writelen = writelen + 1 - writelist[ writelen ] = socket - eol = 0 - end - eol = eol + 1 - writequeue[ eol ] = data - end - handler.writequeue = function( ) - return writequeue - end - handler.socket = function( ) - return socket - end - handler.mode = function( ) - return mode - end - handler._receivedata = function( ) - local data, err, part = receive( socket, mode ) -- receive data in "mode" - if not err or ( err == "timeout" or err == "wantread" ) then -- received something - local data = data or part or "" - local count = #data * STAT_UNIT - rstat = rstat + count - receivestat = receivestat + count - --out_put( "server.lua: read data '", data, "', error: ", err ) - return dispatch( handler, data, err ) - else -- connections was closed or fatal error - out_put( "server.lua: client ", ip, ":", clientport, " error: ", err ) - handler.close( ) - disconnect( handler, err ) - writequeue = nil - handler = nil - return false - end - end - handler._dispatchdata = function( ) -- this function writes data to handlers - local buffer = table_concat( writequeue, "", 1, eol ) - local succ, err, byte = send( socket, buffer ) - local count = ( succ or 0 ) * STAT_UNIT - sstat = sstat + count - sendstat = sendstat + count - out_put( "server.lua: sended '", buffer, "', bytes: ", succ, ", error: ", err, ", part: ", byte, ", to: ", ip, ":", clientport ) - if succ then -- sending succesful - --writequeue = { } - eol = nil - writelen = removesocket( writelist, socket, writelen ) -- delete socket from writelist - if wants_closing then - handler.close(); - end - return true - elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write - buffer = string_sub( buffer, byte + 1, -1 ) -- new buffer - writequeue[ 1 ] = buffer -- insert new buffer in queue - eol = 1 - return true - else -- connection was closed during sending or fatal error - fatal_send_error = true; - out_put( "server.lua: client ", ip, ":", clientport, " error: ", err ) - handler.close( ) - disconnect( handler, err ) - writequeue = nil - handler = nil - return false - end - end - - -- // COMPAT // -- - - handler.getIp = handler.ip - handler.getPort = handler.clientport - - --// handshake //-- - - local wrote - - handler.handshake = coroutine_wrap( function( client ) - local err - for i = 1, 10 do -- 10 handshake attemps - _, err = client:dohandshake( ) - if not err then - out_put( "server.lua: ssl handshake done" ) - writelen = ( wrote and removesocket( writelist, socket, writelen ) ) or writelen - handler.receivedata = handler._receivedata -- when handshake is done, replace the handshake function with regular functions - handler.dispatchdata = handler._dispatchdata - return dispatch( handler ) - else - out_put( "server.lua: error during ssl handshake: ", err ) - if err == "wantwrite" then - if wrote == nil then - writelen = writelen + 1 - writelist[ writelen ] = client - wrote = true - end - end - coroutine_yield( handler, nil, err ) -- handshake not finished - end - end - _ = err ~= "closed" and close( socket ) - handler.close( ) - disconnect( handler, err ) - writequeue = nil - handler = nil - return false -- handshake failed - end - ) - handler.receivedata = handler.handshake - handler.dispatchdata = handler.handshake - - handler.handshake( socket ) -- do handshake - - socketlist[ socket ] = handler - readlen = readlen + 1 - readlist[ readlen ] = socket - - return handler, socket -end - -wraptlsclient = function( listener, socket, ip, serverport, clientport, mode, sslctx ) -- this function wraps a tls cleint - - local dispatch, disconnect = listener.listener, listener.disconnect - - --// transform socket to ssl object //-- - - local err - - socket:settimeout( 0 ) - --// private closures of the object //-- - - local writequeue = { } -- buffer for messages to send - - local eol, fatal_send_error, wants_closing - - local sstat, rstat = 0, 0 - - --// local import of socket methods //-- - - local send = socket.send - local receive = socket.receive - local close = socket.close - --local shutdown = socket.shutdown - - --// public methods of the object //-- - - local handler = { } - - handler.getstats = function( ) - return rstat, sstat - end - - handler.listener = function( data, err ) - return listener( handler, data, err ) - end - handler.ssl = function( ) - return false - end - handler.send = function( _, data, i, j ) - return send( socket, data, i, j ) - end - handler.receive = function( pattern, prefix ) - return receive( socket, pattern, prefix ) - end - handler.shutdown = function( pattern ) - --return shutdown( socket, pattern ) - end - handler.close = function( closed ) - if eol and not fatal_send_error then - -- There is data in the buffer, and we haven't experienced - -- an error trying to send yet, so we'll flush the buffer now - handler._dispatchdata(); - if eol then - -- and there is *still* data in the buffer - -- we'll give up for now, and close later - wants_closing = true; - return; - end - end - close( socket ) - writelen = ( eol and removesocket( writelist, socket, writelen ) ) or writelen - readlen = removesocket( readlist, socket, readlen ) - socketlist[ socket ] = nil - out_put "server.lua: closed handler and removed socket from list" - end - handler.ip = function( ) - return ip - end - handler.serverport = function( ) - return serverport - end - handler.clientport = function( ) - return clientport - end - - handler.write = function( data ) - if not eol then - writelen = writelen + 1 - writelist[ writelen ] = socket - eol = 0 - end - eol = eol + 1 - writequeue[ eol ] = data - end - handler.writequeue = function( ) - return writequeue - end - handler.socket = function( ) - return socket - end - handler.mode = function( ) - return mode - end - handler._receivedata = function( ) - local data, err, part = receive( socket, mode ) -- receive data in "mode" - if not err or ( err == "timeout" or err == "wantread" ) then -- received something - local data = data or part or "" - local count = #data * STAT_UNIT - rstat = rstat + count - receivestat = receivestat + count - --out_put( "server.lua: read data '", data, "', error: ", err ) - return dispatch( handler, data, err ) - else -- connections was closed or fatal error - out_put( "server.lua: client ", ip, ":", clientport, " error: ", err ) - handler.close( ) - disconnect( handler, err ) - writequeue = nil - handler = nil - return false - end - end - handler._dispatchdata = function( ) -- this function writes data to handlers - local buffer = table_concat( writequeue, "", 1, eol ) - local succ, err, byte = send( socket, buffer ) - local count = ( succ or 0 ) * STAT_UNIT - sstat = sstat + count - sendstat = sendstat + count - out_put( "server.lua: sended '", buffer, "', bytes: ", succ, ", error: ", err, ", part: ", byte, ", to: ", ip, ":", clientport ) - if succ then -- sending succesful - --writequeue = { } - eol = nil - writelen = removesocket( writelist, socket, writelen ) -- delete socket from writelist - if handler.need_tls then - out_put("server.lua: connection is ready for tls handshake"); - handler.starttls(true); - end - if wants_closing then - handler.close(); - end - return true - elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write - buffer = string_sub( buffer, byte + 1, -1 ) -- new buffer - writequeue[ 1 ] = buffer -- insert new buffer in queue - eol = 1 - return true - else -- connection was closed during sending or fatal error - fatal_send_error = true; -- :( - out_put( "server.lua: client ", ip, ":", clientport, " error: ", err ) - handler.close( ) - disconnect( handler, err ) - writequeue = nil - handler = nil - return false - end - end - - handler.receivedata, handler.dispatchdata = handler._receivedata, handler._dispatchdata; - -- // COMPAT // -- - - handler.getIp = handler.ip - handler.getPort = handler.clientport - - --// handshake //-- - - local wrote, read - - handler.starttls = function (now) - if not now then out_put("server.lua: we need to do tls, but delaying until later"); handler.need_tls = true; return; end - out_put( "server.lua: attempting to start tls on "..tostring(socket) ) - local oldsocket = socket; - socket, err = ssl_wrap( socket, sslctx ) -- wrap socket - out_put("sslwrapped socket is "..tostring(socket)); - if err then - out_put( "server.lua: ssl error: ", err ) - return nil, nil, err -- fatal error - end - socket:settimeout(0); - - -- Add the new socket to our system - socketlist[ socket ] = handler - readlen = readlen + 1 - readlist[ readlen ] = socket - - -- Remove traces of the old socket - readlen = removesocket( readlist, oldsocket, readlen ) - socketlist [ oldsocket ] = nil; - - send = socket.send - receive = socket.receive - close = socket.close - handler.ssl = function( ) - return true - end - handler.send = function( _, data, i, j ) - return send( socket, data, i, j ) - end - handler.receive = function( pattern, prefix ) - return receive( socket, pattern, prefix ) - end - - handler.starttls = nil; - handler.need_tls = nil - - handler.handshake = coroutine_wrap( function( client ) - local err - for i = 1, 10 do -- 10 handshake attemps - _, err = client:dohandshake( ) - if not err then - out_put( "server.lua: ssl handshake done" ) - writelen = ( wrote and removesocket( writelist, socket, writelen ) ) or writelen - handler.receivedata = handler._receivedata -- when handshake is done, replace the handshake function with regular functions - handler.dispatchdata = handler._dispatchdata; - return true; - else - out_put( "server.lua: error during ssl handshake: ", err ) - if err == "wantwrite" then - if wrote == nil then - writelen = writelen + 1 - writelist[ writelen ] = client - wrote = true - end - end - coroutine_yield( handler, nil, err ) -- handshake not finished - end - end - _ = err ~= "closed" and close( socket ) - handler.close( ) - disconnect( handler, err ) - writequeue = nil - handler = nil - return false -- handshake failed - end - ) - handler.receivedata = handler.handshake - handler.dispatchdata = handler.handshake - - handler.handshake( socket ) -- do handshake - end - socketlist[ socket ] = handler - readlen = readlen + 1 - readlist[ readlen ] = socket - - return handler, socket -end - -wraptcpclient = function( listener, socket, ip, serverport, clientport, mode ) -- this function wraps a socket - - local dispatch, disconnect = listener.listener, listener.disconnect - - --// private closures of the object //-- - - local writequeue = { } -- list for messages to send - - local eol, fatal_send_error, wants_closing - - socket:settimeout(0); - - local rstat, sstat = 0, 0 - - --// local import of socket methods //-- - - local send = socket.send - local receive = socket.receive - local close = socket.close - local shutdown = socket.shutdown - - --// public methods of the object //-- - - local handler = { } - - handler.getstats = function( ) - return rstat, sstat - end - - handler.listener = function( data, err ) - return listener( handler, data, err ) - end - handler.ssl = function( ) - return false - end - handler.send = function( _, data, i, j ) - return send( socket, data, i, j ) - end - handler.receive = function( pattern, prefix ) - return receive( socket, pattern, prefix ) - end - handler.shutdown = function( pattern ) - return shutdown( socket, pattern ) - end - handler.close = function( closed ) - if eol and not fatal_send_error then - -- There is data in the buffer, and we haven't experienced - -- an error trying to send yet, so we'll flush the buffer now - handler.dispatchdata(); - if eol then - -- and there is *still* data in the buffer - -- we'll give up for now, and close later - wants_closing = true; - return; - end - end - _ = not closed and shutdown( socket ) - _ = not closed and close( socket ) - writelen = ( eol and removesocket( writelist, socket, writelen ) ) or writelen - readlen = removesocket( readlist, socket, readlen ) - socketlist[ socket ] = nil - out_put "server.lua: closed handler and removed socket from list" - end - handler.ip = function( ) - return ip - end - handler.serverport = function( ) - return serverport - end - handler.clientport = function( ) - return clientport - end - handler.write = function( data ) - if not eol then - writelen = writelen + 1 - writelist[ writelen ] = socket - eol = 0 - end - eol = eol + 1 - writequeue[ eol ] = data - end - handler.writequeue = function( ) - return writequeue - end - handler.socket = function( ) - return socket - end - handler.mode = function( ) - return mode - end - - handler.receivedata = function( ) - local data, err, part = receive( socket, mode ) -- receive data in "mode" - if not err or ( err == "timeout" or err == "wantread" ) then -- received something - local data = data or part or "" - local count = #data * STAT_UNIT - rstat = rstat + count - receivestat = receivestat + count - --out_put( "server.lua: read data '", data, "', error: ", err ) - return dispatch( handler, data, err ) - else -- connections was closed or fatal error - out_put( "server.lua: client ", ip, ":", clientport, " error: ", err ) - handler.close( ) - disconnect( handler, err ) - writequeue = nil - handler = nil - return false - end - end - - handler.dispatchdata = function( ) -- this function writes data to handlers - local buffer = table_concat( writequeue, "", 1, eol ) - local succ, err, byte = send( socket, buffer ) - local count = ( succ or 0 ) * STAT_UNIT - sstat = sstat + count - sendstat = sendstat + count - out_put( "server.lua: sended '", buffer, "', bytes: ", succ, ", error: ", err, ", part: ", byte, ", to: ", ip, ":", clientport ) - if succ then -- sending succesful - --writequeue = { } - eol = nil - writelen = removesocket( writelist, socket, writelen ) -- delete socket from writelist - if wants_closing then - handler.close(); - end - return true - elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write - buffer = string_sub( buffer, byte + 1, -1 ) -- new buffer - writequeue[ 1 ] = buffer -- insert new buffer in queue - eol = 1 - return true - else -- connection was closed during sending or fatal error - fatal_send_error = true; -- :'-( - out_put( "server.lua: client ", ip, ":", clientport, " error: ", err ) - handler.close( ) - disconnect( handler, err ) - writequeue = nil - handler = nil - return false - end - end - - -- // COMPAT // -- - - handler.getIp = handler.ip - handler.getPort = handler.clientport - - socketlist[ socket ] = handler - readlen = readlen + 1 - readlist[ readlen ] = socket - - return handler, socket -end - -addtimer = function( listener ) - timelistener[ #timelistener + 1 ] = listener -end - -firetimer = function( listener ) - for i, listener in ipairs( timelistener ) do - listener( ) - end -end - -addserver = function( listeners, port, addr, mode, sslctx, wrapper_function ) -- this function provides a way for other scripts to reg a server - local err - if type( listeners ) ~= "table" then - err = "invalid listener table" - else - for name, func in pairs( listeners ) do - if type( func ) ~= "function" then - --err = "invalid listener function" - break - end - end - end - if not type( port ) == "number" or not ( port >= 0 and port <= 65535 ) then - err = "invalid port" - elseif listener[ port ] then - err= "listeners on port '" .. port .. "' already exist" - elseif sslctx and not luasec then - err = "luasec not found" - end - if err then - out_error( "server.lua: ", err ) - return nil, err - end - addr = addr or "*" - local server, err = socket_bind( addr, port ) - if err then - out_error( addr..":"..port.." -", err ) - return nil, err - end - local handler, err = wrapserver( listeners, server, addr, port, mode, sslctx, wrapper_function ) -- wrap new server socket - if not handler then - server:close( ) - return nil, err - end - server:settimeout( 0 ) - readlen = readlen + 1 - readlist[ readlen ] = server - listener[ port ] = listeners - socketlist[ server ] = handler - out_put( "server.lua: new server listener on ", addr, ":", port ) - return true -end - -removesocket = function( tbl, socket, len ) -- this function removes sockets from a list - for i, target in ipairs( tbl ) do - if target == socket then - len = len - 1 - table_remove( tbl, i ) - return len - end - end - return len -end - -closeall = function( ) - for sock, handler in pairs( socketlist ) do - handler.shutdown( ) - handler.close( ) - socketlist[ sock ] = nil - end - writelist, readlist, socketlist = { }, { }, { } -end - -closesocket = function( socket ) - writelen = removesocket( writelist, socket, writelen ) - readlen = removesocket( readlist, socket, readlen ) - socketlist[ socket ] = nil - socket:close( ) -end - -loop = function( ) -- this is the main loop of the program - --signal_set( "hub", "run" ) - repeat - local read, write, err = socket_select( readlist, writelist, 1 ) -- 1 sec timeout, nice for timers - for i, socket in ipairs( write ) do -- send data waiting in writequeues - local handler = socketlist[ socket ] - if handler then - handler.dispatchdata( ) - else - closesocket( socket ) - out_put "server.lua: found no handler and closed socket (writelist)" -- this should not happen - end - end - for i, socket in ipairs( read ) do -- receive data - local handler = socketlist[ socket ] - if handler then - handler.receivedata( ) - else - closesocket( socket ) - out_put "server.lua: found no handler and closed socket (readlist)" -- this can happen - end - end - firetimer( ) - until false - return -end - -----------------------------------// BEGIN //-- - -----------------------------------// PUBLIC INTERFACE //-- - -return { - - add = addserver, - loop = loop, - stats = stats, - closeall = closeall, - addtimer = addtimer, - wraptcpclient = wraptcpclient, - wrapsslclient = wrapsslclient, - wraptlsclient = wraptlsclient, -} +--[[
+
+ server.lua by blastbeat
+
+ - this script contains the server loop of the program
+ - other scripts can reg a server here
+
+]]--
+
+-- // wrapping luadch stuff // --
+
+local use = function( what )
+ return _G[ what ]
+end
+local clean = function( tbl )
+ for i, k in pairs( tbl ) do
+ tbl[ i ] = nil
+ end
+end
+
+local log, table_concat = require ("util.logger").init("socket"), table.concat;
+local out_put = function (...) return log("debug", table_concat{...}); end
+local out_error = function (...) return log("warn", table_concat{...}); end
+local mem_free = collectgarbage
+
+----------------------------------// DECLARATION //--
+
+--// constants //--
+
+local STAT_UNIT = 1 -- byte
+
+--// lua functions //--
+
+local type = use "type"
+local pairs = use "pairs"
+local ipairs = use "ipairs"
+local tostring = use "tostring"
+local collectgarbage = use "collectgarbage"
+
+--// lua libs //--
+
+local os = use "os"
+local table = use "table"
+local string = use "string"
+local coroutine = use "coroutine"
+
+--// lua lib methods //--
+
+local os_time = os.time
+local os_difftime = os.difftime
+local table_concat = table.concat
+local table_remove = table.remove
+local string_len = string.len
+local string_sub = string.sub
+local coroutine_wrap = coroutine.wrap
+local coroutine_yield = coroutine.yield
+
+--// extern libs //--
+
+local luasec = select( 2, pcall( require, "ssl" ) )
+local luasocket = require "socket"
+
+--// extern lib methods //--
+
+local ssl_wrap = ( luasec and luasec.wrap )
+local socket_bind = luasocket.bind
+local socket_sleep = luasocket.sleep
+local socket_select = luasocket.select
+local ssl_newcontext = ( luasec and luasec.newcontext )
+
+--// functions //--
+
+local id
+local loop
+local stats
+local idfalse
+local addtimer
+local closeall
+local addserver
+local wrapserver
+local getsettings
+local closesocket
+local removesocket
+local removeserver
+local changetimeout
+local wrapconnection
+local changesettings
+
+--// tables //--
+
+local _server
+local _readlist
+local _timerlist
+local _sendlist
+local _socketlist
+local _closelist
+local _readtimes
+local _writetimes
+
+--// simple data types //--
+
+local _
+local _readlistlen
+local _sendlistlen
+local _timerlistlen
+
+local _sendtraffic
+local _readtraffic
+
+local _selecttimeout
+local _sleeptime
+
+local _starttime
+local _currenttime
+
+local _maxsendlen
+local _maxreadlen
+
+local _checkinterval
+local _sendtimeout
+local _readtimeout
+
+local _cleanqueue
+
+local _timer
+
+local _maxclientsperserver
+
+----------------------------------// DEFINITION //--
+
+_server = { } -- key = port, value = table; list of listening servers
+_readlist = { } -- array with sockets to read from
+_sendlist = { } -- arrary with sockets to write to
+_timerlist = { } -- array of timer functions
+_socketlist = { } -- key = socket, value = wrapped socket (handlers)
+_readtimes = { } -- key = handler, value = timestamp of last data reading
+_writetimes = { } -- key = handler, value = timestamp of last data writing/sending
+_closelist = { } -- handlers to close
+
+_readlistlen = 0 -- length of readlist
+_sendlistlen = 0 -- length of sendlist
+_timerlistlen = 0 -- lenght of timerlist
+
+_sendtraffic = 0 -- some stats
+_readtraffic = 0
+
+_selecttimeout = 1 -- timeout of socket.select
+_sleeptime = 0 -- time to wait at the end of every loop
+
+_maxsendlen = 51000 * 1024 -- max len of send buffer
+_maxreadlen = 25000 * 1024 -- max len of read buffer
+
+_checkinterval = 1200000 -- interval in secs to check idle clients
+_sendtimeout = 60000 -- allowed send idle time in secs
+_readtimeout = 6 * 60 * 60 -- allowed read idle time in secs
+
+_cleanqueue = false -- clean bufferqueue after using
+
+_maxclientsperserver = 1000
+
+----------------------------------// PRIVATE //--
+
+wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxconnections, startssl ) -- this function wraps a server
+
+ maxconnections = maxconnections or _maxclientsperserver
+
+ local connections = 0
+
+ local dispatch, disconnect = listeners.incoming or listeners.listener, listeners.disconnect
+
+ local err
+
+ local ssl = false
+
+ if sslctx then
+ if not ssl_newcontext then
+ return nil, "luasec not found"
+ end
+ if type( sslctx ) ~= "table" then
+ out_error "server.lua: wrong server sslctx"
+ return nil, "wrong server sslctx"
+ end
+ sslctx, err = ssl_newcontext( sslctx )
+ if not sslctx then
+ err = err or "wrong sslctx parameters"
+ out_error( "server.lua: ", err )
+ return nil, err
+ end
+ ssl = true
+ else
+ out_put("server.lua: ", "ssl not enabled on ", serverport);
+ end
+
+ local accept = socket.accept
+
+ --// public methods of the object //--
+
+ local handler = { }
+
+ handler.shutdown = function( ) end
+
+ handler.ssl = function( )
+ return ssl
+ end
+ handler.remove = function( )
+ connections = connections - 1
+ end
+ handler.close = function( )
+ for _, handler in pairs( _socketlist ) do
+ if handler.serverport == serverport then
+ handler.disconnect( handler, "server closed" )
+ handler.close( true )
+ end
+ end
+ socket:close( )
+ _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
+ _readlistlen = removesocket( _readlist, socket, _readlistlen )
+ _socketlist[ socket ] = nil
+ handler = nil
+ socket = nil
+ mem_free( )
+ out_put "server.lua: closed server handler and removed sockets from list"
+ end
+ handler.ip = function( )
+ return ip
+ end
+ handler.serverport = function( )
+ return serverport
+ end
+ handler.socket = function( )
+ return socket
+ end
+ handler.readbuffer = function( )
+ if connections > maxconnections then
+ out_put( "server.lua: refused new client connection: server full" )
+ return false
+ end
+ local client, err = accept( socket ) -- try to accept
+ if client then
+ local ip, clientport = client:getpeername( )
+ client:settimeout( 0 )
+ local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx, startssl ) -- wrap new client socket
+ if err then -- error while wrapping ssl socket
+ return false
+ end
+ connections = connections + 1
+ out_put( "server.lua: accepted new client connection from ", ip, ":", clientport, " to ", serverport)
+ return dispatch( handler )
+ elseif err then -- maybe timeout or something else
+ out_put( "server.lua: error with new client connection: ", err )
+ return false
+ end
+ end
+ return handler
+end
+
+wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx, startssl ) -- this function wraps a client to a handler object
+
+ socket:settimeout( 0 )
+
+ --// local import of socket methods //--
+
+ local send
+ local receive
+ local shutdown
+
+ --// private closures of the object //--
+
+ local ssl
+
+ local dispatch = listeners.incoming or listeners.listener
+ local disconnect = listeners.disconnect
+
+ local bufferqueue = { } -- buffer array
+ local bufferqueuelen = 0 -- end of buffer array
+
+ local toclose
+ local fatalerror
+ local needtls
+
+ local bufferlen = 0
+
+ local noread = false
+ local nosend = false
+
+ local sendtraffic, readtraffic = 0, 0
+
+ local maxsendlen = _maxsendlen
+ local maxreadlen = _maxreadlen
+
+ --// public methods of the object //--
+
+ local handler = bufferqueue -- saves a table ^_^
+
+ handler.dispatch = function( )
+ return dispatch
+ end
+ handler.disconnect = function( )
+ return disconnect
+ end
+ handler.setlistener = function( listeners )
+ dispatch = listeners.incoming
+ disconnect = listeners.disconnect
+ end
+ handler.getstats = function( )
+ return readtraffic, sendtraffic
+ end
+ handler.ssl = function( )
+ return ssl
+ end
+ handler.send = function( _, data, i, j )
+ return send( socket, data, i, j )
+ end
+ handler.receive = function( pattern, prefix )
+ return receive( socket, pattern, prefix )
+ end
+ handler.shutdown = function( pattern )
+ return shutdown( socket, pattern )
+ end
+ handler.close = function( forced )
+ _readlistlen = removesocket( _readlist, socket, _readlistlen )
+ _readtimes[ handler ] = nil
+ if bufferqueuelen ~= 0 then
+ if not ( forced or fatalerror ) then
+ handler.sendbuffer( )
+ if bufferqueuelen ~= 0 then -- try again...
+ handler.write = nil -- ... but no further writing allowed
+ toclose = true
+ return false
+ end
+ else
+ send( socket, table_concat( bufferqueue, "", 1, bufferqueuelen ), 1, bufferlen ) -- forced send
+ end
+ end
+ shutdown( socket )
+ socket:close( )
+ _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
+ _socketlist[ socket ] = nil
+ _writetimes[ handler ] = nil
+ _closelist[ handler ] = nil
+ handler = nil
+ socket = nil
+ mem_free( )
+ if server then
+ server.remove( )
+ end
+ out_put "server.lua: closed client handler and removed socket from list"
+ return true
+ end
+ handler.ip = function( )
+ return ip
+ end
+ handler.serverport = function( )
+ return serverport
+ end
+ handler.clientport = function( )
+ return clientport
+ end
+ local write = function( data )
+ bufferlen = bufferlen + string_len( data )
+ if bufferlen > maxsendlen then
+ _closelist[ handler ] = "send buffer exceeded" -- cannot close the client at the moment, have to wait to the end of the cycle
+ handler.write = idfalse -- dont write anymore
+ return false
+ elseif not _sendlist[ socket ] then
+ _sendlistlen = _sendlistlen + 1
+ _sendlist[ _sendlistlen ] = socket
+ _sendlist[ socket ] = _sendlistlen
+ end
+ bufferqueuelen = bufferqueuelen + 1
+ bufferqueue[ bufferqueuelen ] = data
+ _writetimes[ handler ] = _writetimes[ handler ] or _currenttime
+ return true
+ end
+ handler.write = write
+ handler.bufferqueue = function( )
+ return bufferqueue
+ end
+ handler.socket = function( )
+ return socket
+ end
+ handler.pattern = function( new )
+ pattern = new or pattern
+ return pattern
+ end
+ handler.bufferlen = function( readlen, sendlen )
+ maxsendlen = sendlen or maxsendlen
+ maxreadlen = readlen or maxreadlen
+ return maxreadlen, maxsendlen
+ end
+ handler.lock = function( switch )
+ if switch == true then
+ handler.write = idfalse
+ local tmp = _sendlistlen
+ _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
+ _writetimes[ handler ] = nil
+ if _sendlistlen ~= tmp then
+ nosend = true
+ end
+ tmp = _readlistlen
+ _readlistlen = removesocket( _readlist, socket, _readlistlen )
+ _readtimes[ handler ] = nil
+ if _readlistlen ~= tmp then
+ noread = true
+ end
+ elseif switch == false then
+ handler.write = write
+ if noread then
+ noread = false
+ _readlistlen = _readlistlen + 1
+ _readlist[ socket ] = _readlistlen
+ _readlist[ _readlistlen ] = socket
+ _readtimes[ handler ] = _currenttime
+ end
+ if nosend then
+ nosend = false
+ write( "" )
+ end
+ end
+ return noread, nosend
+ end
+ local _readbuffer = function( ) -- this function reads data
+ local buffer, err, part = receive( socket, pattern ) -- receive buffer with "pattern"
+ if not err or ( err == "timeout" or err == "wantread" ) then -- received something
+ local buffer = buffer or part or ""
+ local len = string_len( buffer )
+ if len > maxreadlen then
+ disconnect( handler, "receive buffer exceeded" )
+ handler.close( true )
+ return false
+ end
+ local count = len * STAT_UNIT
+ readtraffic = readtraffic + count
+ _readtraffic = _readtraffic + count
+ _readtimes[ handler ] = _currenttime
+ --out_put( "server.lua: read data '", buffer, "', error: ", err )
+ return dispatch( handler, buffer, err )
+ else -- connections was closed or fatal error
+ out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )
+ fatalerror = true
+ disconnect( handler, err )
+ handler.close( )
+ return false
+ end
+ end
+ local _sendbuffer = function( ) -- this function sends data
+ local buffer = table_concat( bufferqueue, "", 1, bufferqueuelen )
+ local succ, err, byte = send( socket, buffer, 1, bufferlen )
+ local count = ( succ or byte or 0 ) * STAT_UNIT
+ sendtraffic = sendtraffic + count
+ _sendtraffic = _sendtraffic + count
+ _ = _cleanqueue and clean( bufferqueue )
+ --out_put( "server.lua: sended '", buffer, "', bytes: ", succ, ", error: ", err, ", part: ", byte, ", to: ", ip, ":", clientport )
+ if succ then -- sending succesful
+ bufferqueuelen = 0
+ bufferlen = 0
+ _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) -- delete socket from writelist
+ _ = needtls and handler.starttls(true)
+ _ = toclose and handler.close( )
+ _writetimes[ handler ] = nil
+ return true
+ elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write
+ buffer = string_sub( buffer, byte + 1, bufferlen ) -- new buffer
+ bufferqueue[ 1 ] = buffer -- insert new buffer in queue
+ bufferqueuelen = 1
+ bufferlen = bufferlen - byte
+ _writetimes[ handler ] = _currenttime
+ return true
+ else -- connection was closed during sending or fatal error
+ out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )
+ fatalerror = true
+ disconnect( handler, err )
+ handler.close( )
+ return false
+ end
+ end
+
+ if sslctx then -- ssl?
+ ssl = true
+ local wrote
+ local handshake = coroutine_wrap( function( client ) -- create handshake coroutine
+ local err
+ for i = 1, 10 do -- 10 handshake attemps
+ _, err = client:dohandshake( )
+ if not err then
+ --out_put( "server.lua: ssl handshake done" )
+ _sendlistlen = ( wrote and removesocket( _sendlist, socket, _sendlistlen ) ) or _sendlistlen
+ handler.readbuffer = _readbuffer -- when handshake is done, replace the handshake function with regular functions
+ handler.sendbuffer = _sendbuffer
+ --return dispatch( handler )
+ return true
+ else
+ out_put( "server.lua: error during ssl handshake: ", err )
+ if err == "wantwrite" and not wrote then
+ _sendlistlen = _sendlistlen + 1
+ _sendlist[ _sendlistlen ] = client
+ wrote = true
+ end
+ --coroutine_yield( handler, nil, err ) -- handshake not finished
+ coroutine_yield( )
+ end
+ end
+ disconnect( handler, "max handshake attemps exceeded" )
+ handler.close( true ) -- forced disconnect
+ return false -- handshake failed
+ end
+ )
+ if startssl then -- ssl now?
+ --out_put("server.lua: ", "starting ssl handshake")
+ local err
+ socket, err = ssl_wrap( socket, sslctx ) -- wrap socket
+ if err then
+ out_put( "server.lua: ssl error: ", err )
+ mem_free( )
+ return nil, nil, err -- fatal error
+ end
+ socket:settimeout( 0 )
+ handler.readbuffer = handshake
+ handler.sendbuffer = handshake
+ handshake( socket ) -- do handshake
+ else
+ handler.starttls = function( now )
+ if not now then
+ --out_put "server.lua: we need to do tls, but delaying until later"
+ needtls = true
+ return
+ end
+ --out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
+ local oldsocket, err = socket
+ socket, err = ssl_wrap( socket, sslctx ) -- wrap socket
+ --out_put( "server.lua: sslwrapped socket is " .. tostring( socket ) )
+ if err then
+ out_put( "server.lua: error while starting tls on client: ", err )
+ return nil, err -- fatal error
+ end
+
+ socket:settimeout( 0 )
+
+ -- add the new socket to our system
+
+ send = socket.send
+ receive = socket.receive
+ shutdown = id
+
+ _socketlist[ socket ] = handler
+ _readlistlen = _readlistlen + 1
+ _readlist[ _readlistlen ] = socket
+ _readlist[ socket ] = _readlistlen
+
+ -- remove traces of the old socket
+
+ _readlistlen = removesocket( _readlist, oldsocket, _readlistlen )
+ _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen )
+ _socketlist[ oldsocket ] = nil
+
+ handler.starttls = nil
+ needtls = nil
+
+ handler.receivedata = handler.handshake
+ handler.dispatchdata = handler.handshake
+ handshake( socket ) -- do handshake
+ end
+ handler.readbuffer = _readbuffer
+ handler.sendbuffer = _sendbuffer
+ end
+ else -- normal connection
+ ssl = false
+ handler.readbuffer = _readbuffer
+ handler.sendbuffer = _sendbuffer
+ end
+
+ send = socket.send
+ receive = socket.receive
+ shutdown = ( ssl and id ) or socket.shutdown
+
+ _socketlist[ socket ] = handler
+ _readlistlen = _readlistlen + 1
+ _readlist[ _readlistlen ] = socket
+ _readlist[ socket ] = _readlistlen
+
+ return handler, socket
+end
+
+id = function( )
+end
+
+idfalse = function( )
+ return false
+end
+
+removesocket = function( list, socket, len ) -- this function removes sockets from a list ( copied from copas )
+ local pos = list[ socket ]
+ if pos then
+ list[ socket ] = nil
+ local last = list[ len ]
+ list[ len ] = nil
+ if last ~= socket then
+ list[ last ] = pos
+ list[ pos ] = last
+ end
+ return len - 1
+ end
+ return len
+end
+
+closesocket = function( socket )
+ _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
+ _readlistlen = removesocket( _readlist, socket, _readlistlen )
+ _socketlist[ socket ] = nil
+ socket:close( )
+ mem_free( )
+end
+
+----------------------------------// PUBLIC //--
+
+addserver = function( listeners, port, addr, pattern, sslctx, maxconnections, startssl ) -- this function provides a way for other scripts to reg a server
+ local err
+ --out_put("server.lua: autossl on ", port, " is ", startssl)
+ if type( listeners ) ~= "table" then
+ err = "invalid listener table"
+ end
+ if not type( port ) == "number" or not ( port >= 0 and port <= 65535 ) then
+ err = "invalid port"
+ elseif _server[ port ] then
+ err = "listeners on port '" .. port .. "' already exist"
+ elseif sslctx and not luasec then
+ err = "luasec not found"
+ end
+ if err then
+ out_error( "server.lua: ", err )
+ return nil, err
+ end
+ addr = addr or "*"
+ local server, err = socket_bind( addr, port )
+ if err then
+ out_error( "server.lua: ", err )
+ return nil, err
+ end
+ local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, maxconnections, startssl ) -- wrap new server socket
+ if not handler then
+ server:close( )
+ return nil, err
+ end
+ server:settimeout( 0 )
+ _readlistlen = _readlistlen + 1
+ _readlist[ _readlistlen ] = server
+ _server[ port ] = handler
+ _socketlist[ server ] = handler
+ out_put( "server.lua: new server listener on '", addr, ":", port, "'" )
+ return handler
+end
+
+removeserver = function( port )
+ local handler = _server[ port ]
+ if not handler then
+ return nil, "no server found on port '" .. tostring( port ) "'"
+ end
+ handler.close( )
+ return true
+end
+
+closeall = function( )
+ for _, handler in pairs( _socketlist ) do
+ handler.close( )
+ _socketlist[ _ ] = nil
+ end
+ _readlistlen = 0
+ _sendlistlen = 0
+ _timerlistlen = 0
+ _server = { }
+ _readlist = { }
+ _sendlist = { }
+ _timerlist = { }
+ _socketlist = { }
+ mem_free( )
+end
+
+getsettings = function( )
+ return _selecttimeout, _sleeptime, _maxsendlen, _maxreadlen, _checkinterval, _sendtimeout, _readtimeout, _cleanqueue, _maxclientsperserver
+end
+
+changesettings = function( new )
+ if type( new ) ~= "table" then
+ return nil, "invalid settings table"
+ end
+ _selecttimeout = tonumber( new.timeout ) or _selecttimeout
+ _sleeptime = tonumber( new.sleeptime ) or _sleeptime
+ _maxsendlen = tonumber( new.maxsendlen ) or _maxsendlen
+ _maxreadlen = tonumber( new.maxreadlen ) or _maxreadlen
+ _checkinterval = tonumber( new.checkinterval ) or _checkinterval
+ _sendtimeout = tonumber( new.sendtimeout ) or _sendtimeout
+ _readtimeout = tonumber( new.readtimeout ) or _readtimeout
+ _cleanqueue = new.cleanqueue
+ _maxclientsperserver = new._maxclientsperserver or _maxclientsperserver
+ return true
+end
+
+addtimer = function( listener )
+ if type( listener ) ~= "function" then
+ return nil, "invalid listener function"
+ end
+ _timerlistlen = _timerlistlen + 1
+ _timerlist[ _timerlistlen ] = listener
+ return true
+end
+
+stats = function( )
+ return _readtraffic, _sendtraffic, _readlistlen, _sendlistlen, _timerlistlen
+end
+
+loop = function( ) -- this is the main loop of the program
+ while true do
+ local read, write, err = socket_select( _readlist, _sendlist, _selecttimeout )
+ for i, socket in ipairs( write ) do -- send data waiting in writequeues
+ local handler = _socketlist[ socket ]
+ if handler then
+ handler.sendbuffer( )
+ else
+ closesocket( socket )
+ out_put "server.lua: found no handler and closed socket (writelist)" -- this should not happen
+ end
+ end
+ for i, socket in ipairs( read ) do -- receive data
+ local handler = _socketlist[ socket ]
+ if handler then
+ handler.readbuffer( )
+ else
+ closesocket( socket )
+ out_put "server.lua: found no handler and closed socket (readlist)" -- this can happen
+ end
+ end
+ for handler, err in pairs( _closelist ) do
+ handler.disconnect( )( handler, err )
+ handler.close( true ) -- forced disconnect
+ end
+ clean( _closelist )
+ _currenttime = os_time( )
+ if os_difftime( _currenttime - _timer ) >= 1 then
+ for i = 1, _timerlistlen do
+ _timerlist[ i ]( ) -- fire timers
+ end
+ _timer = _currenttime
+ end
+ socket_sleep( _sleeptime ) -- wait some time
+ --collectgarbage( )
+ end
+end
+
+--// EXPERIMENTAL //--
+
+local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx, startssl )
+ local handler = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx, startssl )
+ _socketlist[ socket ] = handler
+ _sendlistlen = _sendlistlen + 1
+ _sendlist[ _sendlistlen ] = socket
+ _sendlist[ socket ] = _sendlistlen
+ return handler, socket
+end
+
+local addclient = function( address, port, listeners, pattern, sslctx, startssl )
+ local client, err = socket.tcp( )
+ if err then
+ return nil, err
+ end
+ client:settimeout( 0 )
+ _, err = client:connect( address, port )
+ if err then -- try again
+ local handler = wrapclient( client, address, port, listeners )
+ else
+ wrapconnection( server, listeners, socket, address, port, "clientport", pattern, sslctx, startssl )
+ end
+end
+
+--// EXPERIMENTAL //--
+
+----------------------------------// BEGIN //--
+
+use "setmetatable" ( _socketlist, { __mode = "k" } )
+use "setmetatable" ( _readtimes, { __mode = "k" } )
+use "setmetatable" ( _writetimes, { __mode = "k" } )
+
+_timer = os_time( )
+_starttime = os_time( )
+
+addtimer( function( )
+ local difftime = os_difftime( _currenttime - _starttime )
+ if difftime > _checkinterval then
+ _starttime = _currenttime
+ for handler, timestamp in pairs( _writetimes ) do
+ if os_difftime( _currenttime - timestamp ) > _sendtimeout then
+ --_writetimes[ handler ] = nil
+ handler.disconnect( )( handler, "send timeout" )
+ handler.close( true ) -- forced disconnect
+ end
+ end
+ for handler, timestamp in pairs( _readtimes ) do
+ if os_difftime( _currenttime - timestamp ) > _readtimeout then
+ --_readtimes[ handler ] = nil
+ handler.disconnect( )( handler, "read timeout" )
+ handler.close( ) -- forced disconnect?
+ end
+ end
+ end
+ end
+)
+
+----------------------------------// PUBLIC INTERFACE //--
+
+return {
+
+ addclient = addclient,
+ wrapclient = wrapclient,
+
+ loop = loop,
+ stats = stats,
+ closeall = closeall,
+ addtimer = addtimer,
+ addserver = addserver,
+ getsettings = getsettings,
+ removeserver = removeserver,
+ changesettings = changesettings,
+
+}
|