diff options
Diffstat (limited to 'net/server_event.lua')
-rw-r--r-- | net/server_event.lua | 285 |
1 files changed, 113 insertions, 172 deletions
diff --git a/net/server_event.lua b/net/server_event.lua index d505825d..70a6dc37 100644 --- a/net/server_event.lua +++ b/net/server_event.lua @@ -11,6 +11,7 @@ -- when using luasec, there are 4 cases of timeout errors: wantread or wantwrite during reading or writing --]] +-- luacheck: ignore 212/self 431/err 211/ret local SCRIPT_NAME = "server_event.lua" local SCRIPT_VERSION = "0.05" @@ -32,27 +33,32 @@ local cfg = { DEBUG = true, -- show debug messages } -local function use(x) return rawget(_G, x); end -local ipairs = use "ipairs" -local string = use "string" -local select = use "select" -local require = use "require" -local tostring = use "tostring" -local coroutine = use "coroutine" -local setmetatable = use "setmetatable" +local pairs = pairs +local select = select +local require = require +local tostring = tostring +local setmetatable = setmetatable local t_insert = table.insert local t_concat = table.concat +local s_sub = string.sub -local ssl = use "ssl" -local socket = use "socket" or require "socket" +local coroutine_wrap = coroutine.wrap +local coroutine_yield = coroutine.yield + +local has_luasec, ssl = pcall ( require , "ssl" ) +local socket = require "socket" +local levent = require "luaevent.core" + +local socket_gettime = socket.gettime +local getaddrinfo = socket.dns.getaddrinfo local log = require ("util.logger").init("socket") local function debug(...) return log("debug", ("%s "):rep(select('#', ...)), ...) end -local vdebug = debug; +-- local vdebug = debug; local bitor = ( function( ) -- thx Rici Lake local hasbit = function( x, p ) @@ -72,62 +78,25 @@ local bitor = ( function( ) -- thx Rici Lake end end )( ) -local event = require "luaevent.core" -local base = event.new( ) -local EV_READ = event.EV_READ -local EV_WRITE = event.EV_WRITE -local EV_TIMEOUT = event.EV_TIMEOUT -local EV_SIGNAL = event.EV_SIGNAL +local base = levent.new( ) +local addevent = base.addevent +local EV_READ = levent.EV_READ +local EV_WRITE = levent.EV_WRITE +local EV_TIMEOUT = levent.EV_TIMEOUT +local EV_SIGNAL = levent.EV_SIGNAL local EV_READWRITE = bitor( EV_READ, EV_WRITE ) -local interfacelist = ( function( ) -- holds the interfaces for sockets - local array = { } - local len = 0 - return function( method, arg ) - if "add" == method then - len = len + 1 - array[ len ] = arg - arg:_position( len ) - return len - elseif "delete" == method then - if len <= 0 then - return nil, "array is already empty" - end - local position = arg:_position() -- get position in array - if position ~= len then - local interface = array[ len ] -- get last interface - array[ position ] = interface -- copy it into free position - array[ len ] = nil -- free last position - interface:_position( position ) -- set new position in array - else -- free last position - array[ len ] = nil - end - len = len - 1 - return len - else - return array - end - end -end )( ) +local interfacelist = { } -- Client interface methods -local interface_mt -do - interface_mt = {}; interface_mt.__index = interface_mt; - - local addevent = base.addevent - local coroutine_wrap, coroutine_yield = coroutine.wrap,coroutine.yield - +local interface_mt = {}; interface_mt.__index = interface_mt; + -- Private methods - function interface_mt:_position(new_position) - self.position = new_position or self.position - return self.position; - end function interface_mt:_close() return self:_destroy(); end - + function interface_mt:_start_connection(plainssl) -- should be called from addclient local callback = function( event ) if EV_TIMEOUT == event then -- timeout during connection @@ -136,7 +105,7 @@ do self:_close() debug( "new connection failed. id:", self.id, "error:", self.fatalerror ) else - if plainssl and ssl then -- start ssl session + if plainssl and has_luasec then -- start ssl session self:starttls(self._sslctx, true) else -- normal connection self:_start_session(true) @@ -188,8 +157,7 @@ do return false end self.conn:settimeout( 0 ) -- set non blocking - local handshakecallback = coroutine_wrap( - function( event ) + local handshakecallback = coroutine_wrap(function( event ) local _, err local attempt = 0 local maxattempt = cfg.MAX_HANDSHAKE_ATTEMPTS @@ -265,15 +233,15 @@ do self.eventread, self.eventclose = nil, nil self.interface, self.readcallback = nil, nil end - interfacelist( "delete", self ) + interfacelist[ self ] = nil return true end - + function interface_mt:_lock(nointerface, noreading, nowriting) -- lock or unlock this interface or events self.nointerface, self.noreading, self.nowriting = nointerface, noreading, nowriting return nointerface, noreading, nowriting end - + --TODO: Deprecate function interface_mt:lock_read(switch) if switch then @@ -301,7 +269,7 @@ do end return self._connections end - + -- Public methods function interface_mt:write(data) if self.nowriting then return nil, "locked" end @@ -344,27 +312,27 @@ do return true end end - + function interface_mt:socket() return self.conn end - + function interface_mt:server() return self._server or self; end - + function interface_mt:port() return self._port end - + function interface_mt:serverport() return self._serverport end - + function interface_mt:ip() return self._ip end - + function interface_mt:ssl() return self._usingssl end @@ -373,15 +341,15 @@ do function interface_mt:type() return self._type or "client" end - + function interface_mt:connections() return self._connections end - + function interface_mt:address() return self.addr end - + function interface_mt:set_sslctx(sslctx) self._sslctx = sslctx; if sslctx then @@ -397,11 +365,11 @@ do end return self._pattern; end - - function interface_mt:set_send(new_send) + +function interface_mt:set_send(new_send) -- luacheck: ignore 212 -- No-op, we always use the underlying connection's send end - + function interface_mt:starttls(sslctx, call_onconnect) debug( "try to start ssl at client id:", self.id ) local err @@ -430,22 +398,22 @@ do self.starttls = false; return true end - + function interface_mt:setoption(option, value) if self.conn.setoption then return self.conn:setoption(option, value); end return false, "setoption not implemented"; end - + function interface_mt:setlistener(listener) self:ondetach(); -- Notify listener that it is no longer responsible for this connection - self.onconnect, self.ondisconnect, self.onincoming, - self.ontimeout, self.onstatus, self.ondetach - = listener.onconnect, listener.ondisconnect, listener.onincoming, - listener.ontimeout, listener.onstatus, listener.ondetach; + self.onconnect, self.ondisconnect, self.onincoming, self.ontimeout, + self.onreadtimeout, self.onstatus, self.ondetach + = listener.onconnect, listener.ondisconnect, listener.onincoming, listener.ontimeout, + listener.onreadtimeout, listener.onstatus, listener.ondetach; end - + -- Stub handlers function interface_mt:onconnect() end @@ -455,22 +423,22 @@ do end function interface_mt:ontimeout() end +function interface_mt:onreadtimeout() + self.fatalerror = "timeout during receiving" + debug( "connection failed:", self.fatalerror ) + self:_close() + self.eventread = nil +end function interface_mt:ondrain() end function interface_mt:ondetach() end function interface_mt:onstatus() end -end -- End of client interface methods -local handleclient; -do - local string_sub = string.sub -- caching table lookups - local addevent = base.addevent - local socket_gettime = socket.gettime - function handleclient( client, ip, port, server, pattern, listener, sslctx ) -- creates an client interface +local function handleclient( client, ip, port, server, pattern, listener, sslctx ) -- creates an client interface --vdebug("creating client interfacce...") local interface = { type = "client"; @@ -484,6 +452,7 @@ do ondisconnect = listener.ondisconnect; -- will be called when client disconnects onincoming = listener.onincoming; -- will be called when client sends data ontimeout = listener.ontimeout; -- called when fatal socket timeout occurs + onreadtimeout = listener.onreadtimeout; -- called when socket inactivity timeout occurs ondrain = listener.ondrain; -- called when writebuffer is empty ondetach = listener.ondetach; -- called when disassociating this listener from this connection onstatus = listener.onstatus; -- called for status changes (e.g. of SSL/TLS) @@ -499,14 +468,14 @@ do noreading = false, nowriting = false; -- locks of the read/writecallback startsslcallback = false; -- starting handshake callback position = false; -- position of client in interfacelist - + -- Properties _ip = ip, _port = port, _server = server, _pattern = pattern, _serverport = (server and server:port() or nil), _sslctx = sslctx; -- parameters _usingssl = false; -- client is using ssl; } - if not ssl then interface.starttls = false; end + if not has_luasec then interface.starttls = false; end interface.id = tostring(interface):match("%x+$"); interface.writecallback = function( event ) -- called on write events --vdebug( "new client write event, id/ip/port:", interface, ip, port ) @@ -552,7 +521,7 @@ do return -1 elseif byte and (err == "timeout" or err == "wantwrite") then -- want write again --vdebug( "writebuffer is not empty:", err ) - interface.writebuffer[1] = string_sub( interface.writebuffer[1], byte + 1, interface.writebufferlen ) -- new buffer + interface.writebuffer[1] = s_sub( interface.writebuffer[1], byte + 1, interface.writebufferlen ) -- new buffer interface.writebufferlen = interface.writebufferlen - byte if "wantread" == err then -- happens only with luasec local callback = function( ) @@ -575,7 +544,7 @@ do end end end - + interface.readcallback = function( event ) -- called on read events --vdebug( "new client read event, id/ip/port:", tostring(interface.id), tostring(ip), tostring(port) ) if interface.noreading or interface.fatalerror then -- leave this event @@ -583,13 +552,9 @@ do interface.eventread = nil return -1 end - if EV_TIMEOUT == event then -- took too long to get some data from client -> disconnect - interface.fatalerror = "timeout during receiving" - debug( "connection failed:", interface.fatalerror ) - interface:_close() - interface.eventread = nil - return -1 - else -- can read + if EV_TIMEOUT == event and interface:onreadtimeout() ~= true then + return -1 -- took too long to get some data from client -> disconnect + end if interface._usingssl then -- handle luasec if interface.eventwritetimeout then -- ok, in the past writecallback was regged local ret = interface.writecallback( ) -- call it @@ -638,22 +603,19 @@ do end return EV_READ, cfg.READ_TIMEOUT end - end client:settimeout( 0 ) -- set non blocking setmetatable(interface, interface_mt) - interfacelist( "add", interface ) -- add to interfacelist + interfacelist[ interface ] = true -- add to interfacelist return interface end -end -local handleserver -do - function handleserver( server, addr, port, pattern, listener, sslctx ) -- creates an server interface +local function handleserver( server, addr, port, pattern, listener, sslctx ) -- creates an server interface debug "creating server interface..." local interface = { _connections = 0; - + + type = "server"; conn = server; onconnect = listener.onconnect; -- will be called when new client connected eventread = false; -- read event handler @@ -661,7 +623,7 @@ do readcallback = false; -- read event callback fatalerror = false; -- error message nointerface = true; -- lock/unlock parameter - + _ip = addr, _port = port, _pattern = pattern, _sslctx = sslctx; } @@ -694,92 +656,77 @@ do interface._connections = interface._connections + 1 -- increase connection count local clientinterface = handleclient( client, client_ip, client_port, interface, pattern, listener, sslctx ) --vdebug( "client id:", clientinterface, "startssl:", startssl ) - if ssl and sslctx then + if has_luasec and sslctx then clientinterface:starttls(sslctx, true) else clientinterface:_start_session( true ) end debug( "accepted incoming client connection from:", client_ip or "<unknown IP>", client_port or "<unknown port>", "to", port or "<unknown port>"); - + client, err = server:accept() -- try to accept again end return EV_READ end - + server:settimeout( 0 ) setmetatable(interface, interface_mt) - interfacelist( "add", interface ) + interfacelist[ interface ] = true interface:_start_session() return interface end -end -local addserver = ( function( ) - return function( addr, port, listener, pattern, sslcfg, startssl ) -- TODO: check arguments - --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslcfg or "nil", startssl or "nil") +local function addserver( addr, port, listener, pattern, sslctx, startssl ) -- TODO: check arguments + --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslctx or "nil", startssl or "nil") + if sslctx and not has_luasec then + debug "fatal error: luasec not found" + return nil, "luasec not found" +end local server, err = socket.bind( addr, port, cfg.ACCEPT_QUEUE ) -- create server socket if not server then debug( "creating server socket on "..addr.." port "..port.." failed:", err ) return nil, err end - local sslctx - if sslcfg then - if not ssl then - debug "fatal error: luasec not found" - return nil, "luasec not found" - end - sslctx, err = sslcfg - if err then - debug( "error while creating new ssl context for server socket:", err ) - return nil, err - end - end local interface = handleserver( server, addr, port, pattern, listener, sslctx, startssl ) -- new server handler debug( "new server created with id:", tostring(interface)) return interface end -end )( ) -local addclient, wrapclient -do - function wrapclient( client, ip, port, listeners, pattern, sslctx ) +local function wrapclient( client, ip, port, listeners, pattern, sslctx ) local interface = handleclient( client, ip, port, nil, pattern, listeners, sslctx ) interface:_start_connection(sslctx) return interface, client --function handleclient( client, ip, port, server, pattern, listener, _, sslctx ) -- creates an client interface end - - function addclient( addr, serverport, listener, pattern, localaddr, localport, sslcfg, startssl ) - local client, err = socket.tcp() -- creating new socket - if not client then - debug( "cannot create socket:", err ) - return nil, err + +local function addclient( addr, serverport, listener, pattern, sslctx, typ ) + if sslctx and not has_luasec then + debug "need luasec, but not available" + return nil, "luasec not found" end - client:settimeout( 0 ) -- set nonblocking - if localaddr then - local res, err = client:bind( localaddr, localport, -1 ) - if not res then - debug( "cannot bind client:", err ) - return nil, err + if not typ then + local addrinfo, err = getaddrinfo(addr) + if not addrinfo then return nil, err end + if addrinfo[1] and addrinfo[1].family == "inet6" then + typ = "tcp6" + else + typ = "tcp" end end - local sslctx - if sslcfg then -- handle ssl/new context - if not ssl then - debug "need luasec, but not available" - return nil, "luasec not found" + local create = socket[typ] + if type( create ) ~= "function" then + return nil, "invalid socket type" end - sslctx, err = sslcfg - if err then - debug( "cannot create new ssl context:", err ) + local client, err = create() -- creating new socket + if not client then + debug( "cannot create socket:", err ) return nil, err end - end + client:settimeout( 0 ) -- set nonblocking local res, err = client:connect( addr, serverport ) -- connect if res or ( err == "timeout" ) then local ip, port = client:getsockname( ) - local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx, startssl ) - interface:_start_connection( startssl ) + local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx ) + interface:_start_connection( sslctx ) debug( "new connection id:", interface.id ) return interface, err else @@ -787,23 +734,18 @@ do return nil, err end end -end - -local loop = function( ) -- starts the event loop +local function loop( ) -- starts the event loop base:loop( ) return "quitting"; end -local newevent = ( function( ) - local add = base.addevent - return function( ... ) - return add( base, ... ) +local function newevent( ... ) + return addevent( base, ... ) end -end )( ) -local closeallservers = function( arg ) - for _, item in ipairs( interfacelist( ) ) do +local function closeallservers ( arg ) + for item in pairs( interfacelist ) do if item.type == "server" then item:close( arg ) end @@ -826,7 +768,7 @@ end -- being garbage-collected local signal_events = {}; -- [signal_num] -> event object local function hook_signal(signal_num, handler) - local function _handler(event) + local function _handler() local ret = handler(); if ret ~= false then -- Continue handling this signal? return EV_SIGNAL; -- Yes @@ -839,14 +781,14 @@ end local function link(sender, receiver, buffersize) local sender_locked; - + function receiver:ondrain() if sender_locked then sender:resume(); sender_locked = nil; end end - + function sender:onincoming(data) receiver:write(data); if receiver.writebufferlen >= buffersize then @@ -858,12 +800,11 @@ local function link(sender, receiver, buffersize) end return { - cfg = cfg, base = base, loop = loop, link = link, - event = event, + event = levent, event_base = base, addevent = newevent, addserver = addserver, |