From 80c0a2c47045c59218f185c4d878ff09eef7da0e Mon Sep 17 00:00:00 2001 From: Matthew Wild Date: Sat, 20 Aug 2011 15:06:14 -0400 Subject: net.server_select: Merge straight-SSL and starttls code paths, also fixes onconnect being called before handshake completion for straight-SSL --- net/server_select.lua | 126 +++++++++++++++++++++++--------------------------- 1 file changed, 57 insertions(+), 69 deletions(-) (limited to 'net') diff --git a/net/server_select.lua b/net/server_select.lua index 9a0cfa6e..ad854b25 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -525,6 +525,9 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport handler.readbuffer = _readbuffer -- when handshake is done, replace the handshake function with regular functions handler.sendbuffer = _sendbuffer _ = status and status( handler, "ssl-handshake-complete" ) + if self.autostart_ssl and listeners.onconnect then + listeners.onconnect(self); + end _readlistlen = addsocket(_readlist, client, _readlistlen) return true else @@ -549,74 +552,56 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport ) end if luasec then - if sslctx then -- ssl? - handler:set_sslctx(sslctx); - out_put("server.lua: ", "starting ssl handshake") - local err - socket, err = ssl_wrap( socket, sslctx ) -- wrap socket - if err then - out_put( "server.lua: ssl error: ", tostring(err) ) - --mem_free( ) - return nil, nil, err -- fatal error + handler.starttls = function( self, _sslctx) + if _sslctx then + handler:set_sslctx(_sslctx); end - socket:settimeout( 0 ) - handler.readbuffer = handshake - handler.sendbuffer = handshake - handshake( socket ) -- do handshake + if bufferqueuelen > 0 then + out_put "server.lua: we need to do tls, but delaying until send buffer empty" + needtls = true + return + end + out_put( "server.lua: attempting to start tls on " .. tostring( socket ) ) + local oldsocket, err = socket + socket, err = ssl_wrap( socket, sslctx ) -- wrap socket if not socket then - return nil, nil, "ssl handshake failed"; + out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") ) + return nil, err -- fatal error end - else - local sslctx; - handler.starttls = function( self, _sslctx) - if _sslctx then - sslctx = _sslctx; - handler:set_sslctx(sslctx); - end - if bufferqueuelen > 0 then - out_put "server.lua: we need to do tls, but delaying until send buffer empty" - needtls = true - return - end - out_put( "server.lua: attempting to start tls on " .. tostring( socket ) ) - local oldsocket, err = socket - socket, err = ssl_wrap( socket, sslctx ) -- wrap socket - --out_put( "server.lua: sslwrapped socket is " .. tostring( socket ) ) - if err then - out_put( "server.lua: error while starting tls on client: ", tostring(err) ) - return nil, err -- fatal error - end - socket:settimeout( 0 ) - - -- add the new socket to our system - - send = socket.send - receive = socket.receive - shutdown = id - - _socketlist[ socket ] = handler - _readlistlen = addsocket(_readlist, socket, _readlistlen) - - -- remove traces of the old socket + socket:settimeout( 0 ) - _readlistlen = removesocket( _readlist, oldsocket, _readlistlen ) - _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen ) - _socketlist[ oldsocket ] = nil + -- add the new socket to our system + send = socket.send + receive = socket.receive + shutdown = id + _socketlist[ socket ] = handler + _readlistlen = addsocket(_readlist, socket, _readlistlen) + + -- remove traces of the old socket + _readlistlen = removesocket( _readlist, oldsocket, _readlistlen ) + _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen ) + _socketlist[ oldsocket ] = nil - handler.starttls = nil - needtls = nil + handler.starttls = nil + needtls = nil - -- Secure now - ssl = true + -- Secure now (if handshake fails connection will close) + ssl = true - handler.readbuffer = handshake - handler.sendbuffer = handshake - handshake( socket ) -- do handshake - end - handler.readbuffer = _readbuffer - handler.sendbuffer = _sendbuffer + handler.readbuffer = handshake + handler.sendbuffer = handshake + handshake( socket ) -- do handshake + end + handler.readbuffer = _readbuffer + handler.sendbuffer = _sendbuffer + + if sslctx then + out_put "server.lua: auto-starting ssl negotiation..." + handler.autostart_ssl = true; + handler:starttls(sslctx); end + else handler.readbuffer = _readbuffer handler.sendbuffer = _sendbuffer @@ -857,16 +842,19 @@ end local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx ) local handler = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx ) _socketlist[ socket ] = handler - _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) - if listeners.onconnect then - -- When socket is writeable, call onconnect - local _sendbuffer = handler.sendbuffer; - handler.sendbuffer = function () - handler.sendbuffer = _sendbuffer; - listeners.onconnect(handler); - -- If there was data with the incoming packet, handle it now. - if #handler:bufferqueue() > 0 then - return _sendbuffer(); + if not sslctx then + _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) + if listeners.onconnect then + -- When socket is writeable, call onconnect + local _sendbuffer = handler.sendbuffer; + handler.sendbuffer = function () + handler.sendbuffer = _sendbuffer; + listeners.onconnect(handler); + -- If there was data with the incoming packet, handle it now. + if #handler:bufferqueue() > 0 then + return _sendbuffer(); + end + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) end end end -- cgit v1.2.3