aboutsummaryrefslogtreecommitdiffstats
path: root/net/server_select.lua
diff options
context:
space:
mode:
Diffstat (limited to 'net/server_select.lua')
-rw-r--r--net/server_select.lua146
1 files changed, 98 insertions, 48 deletions
diff --git a/net/server_select.lua b/net/server_select.lua
index 486e953b..35dcb5a7 100644
--- a/net/server_select.lua
+++ b/net/server_select.lua
@@ -31,32 +31,31 @@ local tostring = use "tostring"
--// lua libs //--
-local os = use "os"
local table = use "table"
local string = use "string"
local coroutine = use "coroutine"
--// lua lib methods //--
-local os_difftime = os.difftime
local math_min = math.min
local math_huge = math.huge
local table_concat = table.concat
+local table_insert = table.insert
local string_sub = string.sub
local coroutine_wrap = coroutine.wrap
local coroutine_yield = coroutine.yield
--// extern libs //--
-local luasec = use "ssl"
+local has_luasec, luasec = pcall ( require , "ssl" )
local luasocket = use "socket" or require "socket"
local luasocket_gettime = luasocket.gettime
+local getaddrinfo = luasocket.dns.getaddrinfo
--// extern lib methods //--
-local ssl_wrap = ( luasec and luasec.wrap )
+local ssl_wrap = ( has_luasec and luasec.wrap )
local socket_bind = luasocket.bind
-local socket_sleep = luasocket.sleep
local socket_select = luasocket.select
--// functions //--
@@ -100,7 +99,6 @@ local _sendtraffic
local _readtraffic
local _selecttimeout
-local _sleeptime
local _tcpbacklog
local _starttime
@@ -113,8 +111,6 @@ local _checkinterval
local _sendtimeout
local _readtimeout
-local _timer
-
local _maxselectlen
local _maxfd
@@ -139,7 +135,6 @@ _sendtraffic = 0 -- some stats
_readtraffic = 0
_selecttimeout = 1 -- timeout of socket.select
-_sleeptime = 0 -- time to wait at the end of every loop
_tcpbacklog = 128 -- some kind of hint to the OS
_maxsendlen = 51000 * 1024 -- max len of send buffer
@@ -291,7 +286,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
local bufferqueuelen = 0 -- end of buffer array
local toclose
- local fatalerror
local needtls
local bufferlen = 0
@@ -503,7 +497,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
return dispatch( handler, buffer, err )
else -- connections was closed or fatal error
out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) )
- fatalerror = true
_ = handler and handler:force_close( err )
return false
end
@@ -543,7 +536,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
return true
else -- connection was closed during sending or fatal error
out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) )
- fatalerror = true
_ = handler and handler:force_close( err )
return false
end
@@ -594,7 +586,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
end
)
end
- if luasec then
+ if has_luasec then
handler.starttls = function( self, _sslctx)
if _sslctx then
handler:set_sslctx(_sslctx);
@@ -647,7 +639,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
_socketlist[ socket ] = handler
_readlistlen = addsocket(_readlist, socket, _readlistlen)
- if sslctx and luasec then
+ if sslctx and has_luasec then
out_put "server.lua: auto-starting ssl negotiation..."
handler.autostart_ssl = true;
local ok, err = handler:starttls(sslctx);
@@ -723,22 +715,23 @@ end
----------------------------------// PUBLIC //--
addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server
+ addr = addr or "*"
local err
if type( listeners ) ~= "table" then
err = "invalid listener table"
- end
- if type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then
+ elseif type ( addr ) ~= "string" then
+ err = "invalid address"
+ elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then
err = "invalid port"
elseif _server[ addr..":"..port ] then
err = "listeners on '[" .. addr .. "]:" .. port .. "' already exist"
- elseif sslctx and not luasec then
+ elseif sslctx and not has_luasec then
err = "luasec not found"
end
if err then
out_error( "server.lua, [", addr, "]:", port, ": ", err )
return nil, err
end
- addr = addr or "*"
local server, err = socket_bind( addr, port, _tcpbacklog )
if err then
out_error( "server.lua, [", addr, "]:", port, ": ", err )
@@ -790,7 +783,6 @@ end
getsettings = function( )
return {
select_timeout = _selecttimeout;
- select_sleep_time = _sleeptime;
tcp_backlog = _tcpbacklog;
max_send_buffer_size = _maxsendlen;
max_receive_buffer_size = _maxreadlen;
@@ -808,7 +800,6 @@ changesettings = function( new )
return nil, "invalid settings table"
end
_selecttimeout = tonumber( new.select_timeout ) or _selecttimeout
- _sleeptime = tonumber( new.select_sleep_time ) or _sleeptime
_maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen
_maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen
_checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval
@@ -830,6 +821,49 @@ addtimer = function( listener )
return true
end
+local add_task do
+ local data = {};
+ local new_data = {};
+
+ function add_task(delay, callback)
+ local current_time = luasocket_gettime();
+ delay = delay + current_time;
+ if delay >= current_time then
+ table_insert(new_data, {delay, callback});
+ else
+ local r = callback(current_time);
+ if r and type(r) == "number" then
+ return add_task(r, callback);
+ end
+ end
+ end
+
+ addtimer(function(current_time)
+ if #new_data > 0 then
+ for _, d in pairs(new_data) do
+ table_insert(data, d);
+ end
+ new_data = {};
+ end
+
+ local next_time = math_huge;
+ for i, d in pairs(data) do
+ local t, callback = d[1], d[2];
+ if t <= current_time then
+ data[i] = nil;
+ local r = callback(current_time);
+ if type(r) == "number" then
+ add_task(r, callback);
+ next_time = math_min(next_time, r);
+ end
+ else
+ next_time = math_min(next_time, t - current_time);
+ end
+ end
+ return next_time;
+ end);
+end
+
stats = function( )
return _readtraffic, _sendtraffic, _readlistlen, _sendlistlen, _timerlistlen
end
@@ -843,8 +877,15 @@ end
loop = function(once) -- this is the main loop of the program
if quitting then return "quitting"; end
if once then quitting = "once"; end
- local next_timer_time = math_huge;
+ _currenttime = luasocket_gettime( )
repeat
+ -- Fire timers
+ local next_timer_time = math_huge;
+ for i = 1, _timerlistlen do
+ local t = _timerlist[ i ]( _currenttime ) -- fire timers
+ if t then next_timer_time = math_min(next_timer_time, t); end
+ end
+
local read, write, err = socket_select( _readlist, _sendlist, math_min(_selecttimeout, next_timer_time) )
for i, socket in ipairs( write ) do -- send data waiting in writequeues
local handler = _socketlist[ socket ]
@@ -872,17 +913,16 @@ loop = function(once) -- this is the main loop of the program
_currenttime = luasocket_gettime( )
-- Check for socket timeouts
- local difftime = os_difftime( _currenttime - _starttime )
- if difftime > _checkinterval then
+ if _currenttime - _starttime > _checkinterval then
_starttime = _currenttime
for handler, timestamp in pairs( _writetimes ) do
- if os_difftime( _currenttime - timestamp ) > _sendtimeout then
+ if _currenttime - timestamp > _sendtimeout then
handler.disconnect( )( handler, "send timeout" )
handler:force_close() -- forced disconnect
end
end
for handler, timestamp in pairs( _readtimes ) do
- if os_difftime( _currenttime - timestamp ) > _readtimeout then
+ if _currenttime - timestamp > _readtimeout then
if not(handler.onreadtimeout) or handler:onreadtimeout() ~= true then
handler.disconnect( )( handler, "read timeout" )
handler:close( ) -- forced disconnect?
@@ -892,21 +932,6 @@ loop = function(once) -- this is the main loop of the program
end
end
end
-
- -- Fire timers
- if _currenttime - _timer >= math_min(next_timer_time, 1) then
- next_timer_time = math_huge;
- for i = 1, _timerlistlen do
- local t = _timerlist[ i ]( _currenttime ) -- fire timers
- if t then next_timer_time = math_min(next_timer_time, t); end
- end
- _timer = _currenttime
- else
- next_timer_time = next_timer_time - (_currenttime - _timer);
- end
-
- -- wait some time (0 by default)
- socket_sleep( _sleeptime )
until quitting;
if once and quitting == "once" then quitting = nil; return; end
return "quitting"
@@ -941,29 +966,53 @@ local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx
return handler, socket
end
-local addclient = function( address, port, listeners, pattern, sslctx )
- local client, err = luasocket.tcp( )
+local addclient = function( address, port, listeners, pattern, sslctx, typ )
+ local err
+ if type( listeners ) ~= "table" then
+ err = "invalid listener table"
+ elseif type ( address ) ~= "string" then
+ err = "invalid address"
+ elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then
+ err = "invalid port"
+ elseif sslctx and not has_luasec then
+ err = "luasec not found"
+ end
+ if getaddrinfo and not typ then
+ local addrinfo, err = getaddrinfo(address)
+ if not addrinfo then return nil, err end
+ if addrinfo[1] and addrinfo[1].family == "inet6" then
+ typ = "tcp6"
+ end
+ end
+ local create = luasocket[typ or "tcp"]
+ if type( create ) ~= "function" then
+ err = "invalid socket type"
+ end
+
+ if err then
+ out_error( "server.lua, addclient: ", err )
+ return nil, err
+ end
+
+ local client, err = create( )
if err then
return nil, err
end
client:settimeout( 0 )
- _, err = client:connect( address, port )
- if err then -- try again
+ local ok, err = client:connect( address, port )
+ if ok or err == "timeout" or err == "Operation already in progress" then
return wrapclient( client, address, port, listeners, pattern, sslctx )
else
- return wrapconnection( nil, listeners, client, address, port, "clientport", pattern, sslctx )
+ return nil, err
end
end
---// EXPERIMENTAL //--
-
----------------------------------// BEGIN //--
use "setmetatable" ( _socketlist, { __mode = "k" } )
use "setmetatable" ( _readtimes, { __mode = "k" } )
use "setmetatable" ( _writetimes, { __mode = "k" } )
-_timer = luasocket_gettime( )
_starttime = luasocket_gettime( )
local function setlogger(new_logger)
@@ -978,6 +1027,7 @@ end
return {
_addtimer = addtimer,
+ add_task = add_task;
addclient = addclient,
wrapclient = wrapclient,