diff options
Diffstat (limited to 'net/server_select.lua')
-rw-r--r-- | net/server_select.lua | 105 |
1 files changed, 68 insertions, 37 deletions
diff --git a/net/server_select.lua b/net/server_select.lua index 1a40a6d3..4b156409 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -68,6 +68,7 @@ local idfalse local closeall local addsocket local addserver +local listen local addtimer local getserver local wrapserver @@ -157,7 +158,7 @@ _maxsslhandshake = 30 -- max handshake round-trips ----------------------------------// PRIVATE //-- -wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- this function wraps a server -- FIXME Make sure FD < _maxfd +wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, ssldirect ) -- this function wraps a server -- FIXME Make sure FD < _maxfd if socket:getfd() >= _maxfd then out_error("server.lua: Disallowed FD number: "..socket:getfd()) @@ -183,6 +184,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- t handler.sslctx = function( ) return sslctx end + handler.hosts = {} -- sni handler.remove = function( ) connections = connections - 1 if handler then @@ -244,13 +246,13 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- t local client, err = accept( socket ) -- try to accept if client then local ip, clientport = client:getpeername( ) - local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx ) -- wrap new client socket + local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx, ssldirect ) -- wrap new client socket if err then -- error while wrapping ssl socket return false end connections = connections + 1 out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport)) - if dispatch and not sslctx then -- SSL connections will notify onconnect when handshake completes + if dispatch and not ssldirect then -- SSL connections will notify onconnect when handshake completes return dispatch( handler ); end return; @@ -264,7 +266,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- t return handler end -wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx ) -- this function wraps a client to a handler object +wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx, ssldirect ) -- this function wraps a client to a handler object if socket:getfd() >= _maxfd then out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent @@ -424,9 +426,8 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport bufferlen = bufferlen + #data if bufferlen > maxsendlen then _closelist[ handler ] = "send buffer exceeded" -- cannot close the client at the moment, have to wait to the end of the cycle - handler.write = idfalse -- don't write anymore return false - elseif socket and not _sendlist[ socket ] then + elseif not nosend and socket and not _sendlist[ socket ] then _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) end bufferqueuelen = bufferqueuelen + 1 @@ -456,49 +457,57 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport maxreadlen = readlen or maxreadlen return bufferlen, maxreadlen, maxsendlen end - --TODO: Deprecate handler.lock_read = function (self, switch) + out_error( "server.lua, lock_read() is deprecated, use pause() and resume()" ) if switch == true then - local tmp = _readlistlen - _readlistlen = removesocket( _readlist, socket, _readlistlen ) - _readtimes[ handler ] = nil - if _readlistlen ~= tmp then - noread = true - end + return self:pause() elseif switch == false then - if noread then - noread = false - _readlistlen = addsocket(_readlist, socket, _readlistlen) - _readtimes[ handler ] = _currenttime - end + return self:resume() end return noread end handler.pause = function (self) - return self:lock_read(true); + local tmp = _readlistlen + _readlistlen = removesocket( _readlist, socket, _readlistlen ) + _readtimes[ handler ] = nil + if _readlistlen ~= tmp then + noread = true + end + return noread; end handler.resume = function (self) - return self:lock_read(false); + if noread then + noread = false + _readlistlen = addsocket(_readlist, socket, _readlistlen) + _readtimes[ handler ] = _currenttime + end + return noread; end handler.lock = function( self, switch ) - handler.lock_read (switch) + out_error( "server.lua, lock() is deprecated" ) + handler.lock_read (self, switch) if switch == true then - handler.write = idfalse - local tmp = _sendlistlen - _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) - _writetimes[ handler ] = nil - if _sendlistlen ~= tmp then - nosend = true - end + handler.pause_writes (self) elseif switch == false then - handler.write = write - if nosend then - nosend = false - write( "" ) - end + handler.resume_writes (self) end return noread, nosend end + handler.pause_writes = function (self) + local tmp = _sendlistlen + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) + _writetimes[ handler ] = nil + if _sendlistlen ~= tmp then + nosend = true + end + end + handler.resume_writes = function (self) + if nosend then + nosend = false + write( "" ) + end + end + local _readbuffer = function( ) -- this function reads data local buffer, err, part = receive( socket, pattern ) -- receive buffer with "pattern" if not err or (err == "wantread" or err == "timeout") then -- received something @@ -619,11 +628,20 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport out_put( "server.lua: attempting to start tls on " .. tostring( socket ) ) local oldsocket, err = socket socket, err = ssl_wrap( socket, sslctx ) -- wrap socket + if not socket then out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") ) return nil, err -- fatal error end + if socket.sni then + if self.servername then + socket:sni(self.servername); + elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then + socket:sni(self.server().hosts, true); + end + end + socket:settimeout( 0 ) -- add the new socket to our system @@ -659,7 +677,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport _socketlist[ socket ] = handler _readlistlen = addsocket(_readlist, socket, _readlistlen) - if sslctx and has_luasec then + if sslctx and ssldirect and has_luasec then out_put "server.lua: auto-starting ssl negotiation..." handler.autostart_ssl = true; local ok, err = handler:starttls(sslctx); @@ -734,9 +752,13 @@ end ----------------------------------// PUBLIC //-- -addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server +listen = function ( addr, port, listeners, config ) addr = addr or "*" + config = config or {} local err + local sslctx = config.tls_ctx; + local ssldirect = config.tls_direct; + local pattern = config.read_size; if type( listeners ) ~= "table" then err = "invalid listener table" elseif type ( addr ) ~= "string" then @@ -757,7 +779,7 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function out_error( "server.lua, [", addr, "]:", port, ": ", err ) return nil, err end - local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx ) -- wrap new server socket + local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, ssldirect ) -- wrap new server socket if not handler then server:close( ) return nil, err @@ -770,6 +792,14 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function return handler end +addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server + return listen(addr, port, listeners, { + read_size = pattern; + tls_ctx = sslctx; + tls_direct = sslctx and true or false; + }); +end + getserver = function ( addr, port ) return _server[ addr..":"..port ]; end @@ -978,7 +1008,7 @@ end --// EXPERIMENTAL //-- local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx ) - local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx ) + local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx, sslctx) if not handler then return nil, err end _socketlist[ socket ] = handler if not sslctx then @@ -1114,6 +1144,7 @@ return { stats = stats, closeall = closeall, addserver = addserver, + listen = listen, getserver = getserver, setlogger = setlogger, getsettings = getsettings, |