diff options
-rw-r--r-- | net/server_select.lua | 178 |
1 files changed, 61 insertions, 117 deletions
diff --git a/net/server_select.lua b/net/server_select.lua index 682a155d..24c82022 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -160,7 +160,7 @@ _maxclientsperserver = 1000 _maxsslhandshake = 30 -- max handshake round-trips ----------------------------------// PRIVATE //-- -wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxconnections, startssl ) -- this function wraps a server +wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxconnections ) -- this function wraps a server maxconnections = maxconnections or _maxclientsperserver @@ -168,58 +168,6 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco local dispatch, disconnect = listeners.onincoming, listeners.ondisconnect - local err - - local ssl = false - - if sslctx then - ssl = true - if not ssl_newcontext then - out_error "luasec not found" - ssl = false - end - if type( sslctx ) ~= "table" then - out_error "server.lua: wrong server sslctx" - ssl = false - end - local ctx; - ctx, err = ssl_newcontext( sslctx ) - if not ctx then - err = err or "wrong sslctx parameters" - local file; - file = err:match("^error loading (.-) %("); - if file then - if file == "private key" then - file = sslctx.key or "your private key"; - elseif file == "certificate" then - file = sslctx.certificate or "your certificate file"; - end - local reason = err:match("%((.+)%)$") or "some reason"; - if reason == "Permission denied" then - reason = "Check that the permissions allow Prosody to read this file."; - elseif reason == "No such file or directory" then - reason = "Check that the path is correct, and the file exists."; - elseif reason == "system lib" then - reason = "Previous error (see logs), or other system error."; - else - reason = "Reason: "..tostring(reason or "unknown"):lower(); - end - log("error", "SSL/TLS: Failed to load %s: %s", file, reason); - else - log("error", "SSL/TLS: Error initialising for port %d: %s", serverport, err ); - end - ssl = false - end - sslctx = ctx; - end - if not ssl then - sslctx = false; - if startssl then - log("error", "Failed to listen on port %d due to SSL/TLS to SSL/TLS initialisation errors (see logs)", serverport ) - return nil, "Cannot start ssl, see log for details" - end - end - local accept = socket.accept --// public methods of the object //-- @@ -229,7 +177,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco handler.shutdown = function( ) end handler.ssl = function( ) - return ssl + return sslctx ~= nil end handler.sslctx = function( ) return sslctx @@ -271,7 +219,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco if client then local ip, clientport = client:getpeername( ) client:settimeout( 0 ) - local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx, startssl ) -- wrap new client socket + local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx ) -- wrap new client socket if err then -- error while wrapping ssl socket return false end @@ -286,7 +234,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco return handler end -wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx, startssl ) -- this function wraps a client to a handler object +wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx ) -- this function wraps a client to a handler object socket:settimeout( 0 ) @@ -520,7 +468,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport bufferqueuelen = 0 bufferlen = 0 _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) -- delete socket from writelist - _ = needtls and handler:starttls(true) + _ = needtls and handler:starttls(nil, true) _writetimes[ handler ] = nil _ = toclose and handler.close( ) return true @@ -584,72 +532,69 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport end if sslctx then -- ssl? handler:set_sslctx(sslctx); - if startssl then -- ssl now? - --out_put("server.lua: ", "starting ssl handshake") - local err + out_put("server.lua: ", "starting ssl handshake") + local err + socket, err = ssl_wrap( socket, sslctx ) -- wrap socket + if err then + out_put( "server.lua: ssl error: ", tostring(err) ) + --mem_free( ) + return nil, nil, err -- fatal error + end + socket:settimeout( 0 ) + handler.readbuffer = handshake + handler.sendbuffer = handshake + handshake( socket ) -- do handshake + if not socket then + return nil, nil, "ssl handshake failed"; + end + else + local sslctx; + handler.starttls = function( self, _sslctx, now ) + if _sslctx then + sslctx = _sslctx; + handler:set_sslctx(sslctx); + end + if not now then + out_put "server.lua: we need to do tls, but delaying until later" + needtls = true + return + end + out_put( "server.lua: attempting to start tls on " .. tostring( socket ) ) + local oldsocket, err = socket socket, err = ssl_wrap( socket, sslctx ) -- wrap socket + --out_put( "server.lua: sslwrapped socket is " .. tostring( socket ) ) if err then - out_put( "server.lua: ssl error: ", tostring(err) ) - --mem_free( ) - return nil, nil, err -- fatal error + out_put( "server.lua: error while starting tls on client: ", tostring(err) ) + return nil, err -- fatal error end + socket:settimeout( 0 ) - handler.readbuffer = handshake - handler.sendbuffer = handshake - handshake( socket ) -- do handshake - if not socket then - return nil, nil, "ssl handshake failed"; - end - else - -- We're not automatically doing SSL, so we're not secure (yet) - ssl = false - handler.starttls = function( self, now ) - if not now then - --out_put "server.lua: we need to do tls, but delaying until later" - needtls = true - return - end - --out_put( "server.lua: attempting to start tls on " .. tostring( socket ) ) - local oldsocket, err = socket - socket, err = ssl_wrap( socket, sslctx ) -- wrap socket - --out_put( "server.lua: sslwrapped socket is " .. tostring( socket ) ) - if err then - out_put( "server.lua: error while starting tls on client: ", tostring(err) ) - return nil, err -- fatal error - end - socket:settimeout( 0 ) + -- add the new socket to our system - -- add the new socket to our system + send = socket.send + receive = socket.receive + shutdown = id - send = socket.send - receive = socket.receive - shutdown = id - - _socketlist[ socket ] = handler - _readlistlen = addsocket(_readlist, socket, _readlistlen) + _socketlist[ socket ] = handler + _readlistlen = addsocket(_readlist, socket, _readlistlen) - -- remove traces of the old socket + -- remove traces of the old socket - _readlistlen = removesocket( _readlist, oldsocket, _readlistlen ) - _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen ) - _socketlist[ oldsocket ] = nil + _readlistlen = removesocket( _readlist, oldsocket, _readlistlen ) + _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen ) + _socketlist[ oldsocket ] = nil - handler.starttls = nil - needtls = nil + handler.starttls = nil + needtls = nil - -- Secure now - ssl = true + -- Secure now + ssl = true - handler.readbuffer = handshake - handler.sendbuffer = handshake - handshake( socket ) -- do handshake - end - handler.readbuffer = _readbuffer - handler.sendbuffer = _sendbuffer + handler.readbuffer = handshake + handler.sendbuffer = handshake + handshake( socket ) -- do handshake end - else -- normal connection - ssl = false handler.readbuffer = _readbuffer handler.sendbuffer = _sendbuffer end @@ -705,9 +650,8 @@ end ----------------------------------// PUBLIC //-- -addserver = function( addr, port, listeners, pattern, sslctx, startssl ) -- this function provides a way for other scripts to reg a server +addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server local err - --out_put("server.lua: autossl on ", port, " is ", startssl) if type( listeners ) ~= "table" then err = "invalid listener table" end @@ -728,7 +672,7 @@ addserver = function( addr, port, listeners, pattern, sslctx, startssl ) -- t out_error( "server.lua, port ", port, ": ", err ) return nil, err end - local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, _maxclientsperserver, startssl ) -- wrap new server socket + local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, _maxclientsperserver ) -- wrap new server socket if not handler then server:close( ) return nil, err @@ -857,14 +801,14 @@ end --// EXPERIMENTAL //-- -local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx, startssl ) - local handler = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx, startssl ) +local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx ) + local handler = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx ) _socketlist[ socket ] = handler _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) return handler, socket end -local addclient = function( address, port, listeners, pattern, sslctx, startssl ) +local addclient = function( address, port, listeners, pattern, sslctx ) local client, err = luasocket.tcp( ) if err then return nil, err @@ -874,7 +818,7 @@ local addclient = function( address, port, listeners, pattern, sslctx, startssl if err then -- try again local handler = wrapclient( client, address, port, listeners ) else - wrapconnection( nil, listeners, client, address, port, "clientport", pattern, sslctx, startssl ) + wrapconnection( nil, listeners, client, address, port, "clientport", pattern, sslctx ) end end |