From 34e978cc074e534fab4319a5d11f94b5673eec81 Mon Sep 17 00:00:00 2001 From: daurnimator Date: Wed, 18 Dec 2013 17:50:38 -0500 Subject: net/server_select: pcall require ssl (easy to forget to require ssl) --- net/server_select.lua | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'net') diff --git a/net/server_select.lua b/net/server_select.lua index c5e0772f..61078202 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -48,13 +48,13 @@ local coroutine_yield = coroutine.yield --// extern libs //-- -local luasec = use "ssl" +local has_luasec, luasec = pcall ( require , "ssl" ) local luasocket = use "socket" or require "socket" local luasocket_gettime = luasocket.gettime --// extern lib methods //-- -local ssl_wrap = ( luasec and luasec.wrap ) +local ssl_wrap = ( has_luasec and luasec.wrap ) local socket_bind = luasocket.bind local socket_sleep = luasocket.sleep local socket_select = luasocket.select @@ -585,7 +585,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport end ) end - if luasec then + if has_luasec then handler.starttls = function( self, _sslctx) if _sslctx then handler:set_sslctx(_sslctx); @@ -638,7 +638,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport _socketlist[ socket ] = handler _readlistlen = addsocket(_readlist, socket, _readlistlen) - if sslctx and luasec then + if sslctx and has_luasec then out_put "server.lua: auto-starting ssl negotiation..." handler.autostart_ssl = true; local ok, err = handler:starttls(sslctx); @@ -721,7 +721,7 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function err = "invalid port" elseif _server[ addr..":"..port ] then err = "listeners on '[" .. addr .. "]:" .. port .. "' already exist" - elseif sslctx and not luasec then + elseif sslctx and not has_luasec then err = "luasec not found" end if err then -- cgit v1.2.3 From 927d7917375186ae365d85c6d59f7f9f6a012b57 Mon Sep 17 00:00:00 2001 From: daurnimator Date: Wed, 18 Dec 2013 17:51:27 -0500 Subject: net/server_select: Check arguments to add_server correctly --- net/server_select.lua | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'net') diff --git a/net/server_select.lua b/net/server_select.lua index 61078202..e319e016 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -713,11 +713,13 @@ end ----------------------------------// PUBLIC //-- addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server + addr = addr or "*" local err if type( listeners ) ~= "table" then err = "invalid listener table" - end - if type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then + elseif type ( addr ) ~= "string" then + err = "invalid address" + elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then err = "invalid port" elseif _server[ addr..":"..port ] then err = "listeners on '[" .. addr .. "]:" .. port .. "' already exist" @@ -728,7 +730,6 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function out_error( "server.lua, [", addr, "]:", port, ": ", err ) return nil, err end - addr = addr or "*" local server, err = socket_bind( addr, port, _tcpbacklog ) if err then out_error( "server.lua, [", addr, "]:", port, ": ", err ) -- cgit v1.2.3 From 0639f3d4a3fbd107d93ca064a3248c7f18862aad Mon Sep 17 00:00:00 2001 From: daurnimator Date: Wed, 18 Dec 2013 17:52:28 -0500 Subject: net/server_event: add_client should have same arguments no-matter the server backend --- net/server_event.lua | 27 ++++++--------------------- 1 file changed, 6 insertions(+), 21 deletions(-) (limited to 'net') diff --git a/net/server_event.lua b/net/server_event.lua index 59217a0c..82accc99 100644 --- a/net/server_event.lua +++ b/net/server_event.lua @@ -744,36 +744,21 @@ do --function handleclient( client, ip, port, server, pattern, listener, _, sslctx ) -- creates an client interface end - function addclient( addr, serverport, listener, pattern, localaddr, localport, sslcfg, startssl ) + function addclient( addr, serverport, listener, pattern, sslctx ) + if sslctx and not ssl then + debug "need luasec, but not available" + return nil, "luasec not found" + end local client, err = socket.tcp() -- creating new socket if not client then debug( "cannot create socket:", err ) return nil, err end client:settimeout( 0 ) -- set nonblocking - if localaddr then - local res, err = client:bind( localaddr, localport, -1 ) - if not res then - debug( "cannot bind client:", err ) - return nil, err - end - end - local sslctx - if sslcfg then -- handle ssl/new context - if not ssl then - debug "need luasec, but not available" - return nil, "luasec not found" - end - sslctx, err = sslcfg - if err then - debug( "cannot create new ssl context:", err ) - return nil, err - end - end local res, err = client:connect( addr, serverport ) -- connect if res or ( err == "timeout" ) then local ip, port = client:getsockname( ) - local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx, startssl ) + local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx ) interface:_start_connection( startssl ) debug( "new connection id:", interface.id ) return interface, err -- cgit v1.2.3 From 98bd53004dab9e535058cc46cf019610df7ced0f Mon Sep 17 00:00:00 2001 From: daurnimator Date: Wed, 18 Dec 2013 17:54:31 -0500 Subject: net/server_select: addclient: Check for failure correctly; remove wrapconnection call on failure --- net/server_select.lua | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'net') diff --git a/net/server_select.lua b/net/server_select.lua index e319e016..bd4e59df 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -936,11 +936,11 @@ local addclient = function( address, port, listeners, pattern, sslctx ) return nil, err end client:settimeout( 0 ) - _, err = client:connect( address, port ) - if err then -- try again + local ok, err = client:connect( address, port ) + if ok or err == "timeout" then return wrapclient( client, address, port, listeners, pattern, sslctx ) else - return wrapconnection( nil, listeners, client, address, port, "clientport", pattern, sslctx ) + return nil, err end end -- cgit v1.2.3 From e3c0bdb6f3179edbb38371a301af4da309760246 Mon Sep 17 00:00:00 2001 From: daurnimator Date: Wed, 18 Dec 2013 17:55:03 -0500 Subject: net/server_select: addclient: Check arguments --- net/server_select.lua | 15 +++++++++++++++ 1 file changed, 15 insertions(+) (limited to 'net') diff --git a/net/server_select.lua b/net/server_select.lua index bd4e59df..c707e48f 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -931,6 +931,21 @@ local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx end local addclient = function( address, port, listeners, pattern, sslctx ) + local err + if type( listeners ) ~= "table" then + err = "invalid listener table" + elseif type ( addr ) ~= "string" then + err = "invalid address" + elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then + err = "invalid port" + elseif sslctx and not has_luasec then + err = "luasec not found" + end + if err then + out_error( "server.lua, addclient: ", err ) + return nil, err + end + local client, err = luasocket.tcp( ) if err then return nil, err -- cgit v1.2.3 From 31e5378e58926c552870a92f33d79da12af75507 Mon Sep 17 00:00:00 2001 From: daurnimator Date: Wed, 18 Dec 2013 18:06:33 -0500 Subject: net/server_select: Fix typo --- net/server_select.lua | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'net') diff --git a/net/server_select.lua b/net/server_select.lua index c707e48f..91b8b01f 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -934,7 +934,7 @@ local addclient = function( address, port, listeners, pattern, sslctx ) local err if type( listeners ) ~= "table" then err = "invalid listener table" - elseif type ( addr ) ~= "string" then + elseif type ( address ) ~= "string" then err = "invalid address" elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then err = "invalid port" -- cgit v1.2.3 From 48f909666b964c8170ea8785a5e6b6043a25648b Mon Sep 17 00:00:00 2001 From: daurnimator Date: Wed, 18 Dec 2013 18:11:17 -0500 Subject: net/server_event: pcall require ssl rather than relying on globals --- net/server_event.lua | 30 +++++++++++------------------- 1 file changed, 11 insertions(+), 19 deletions(-) (limited to 'net') diff --git a/net/server_event.lua b/net/server_event.lua index 82accc99..502cc80a 100644 --- a/net/server_event.lua +++ b/net/server_event.lua @@ -44,7 +44,7 @@ local setmetatable = use "setmetatable" local t_insert = table.insert local t_concat = table.concat -local ssl = use "ssl" +local has_luasec, ssl = pcall ( require , "ssl" ) local socket = use "socket" or require "socket" local log = require ("util.logger").init("socket") @@ -136,7 +136,7 @@ do self:_close() debug( "new connection failed. id:", self.id, "error:", self.fatalerror ) else - if plainssl and ssl then -- start ssl session + if plainssl and has_luasec then -- start ssl session self:starttls(self._sslctx, true) else -- normal connection self:_start_session(true) @@ -506,7 +506,7 @@ do _sslctx = sslctx; -- parameters _usingssl = false; -- client is using ssl; } - if not ssl then interface.starttls = false; end + if not has_luasec then interface.starttls = false; end interface.id = tostring(interface):match("%x+$"); interface.writecallback = function( event ) -- called on write events --vdebug( "new client write event, id/ip/port:", interface, ip, port ) @@ -689,7 +689,7 @@ do interface._connections = interface._connections + 1 -- increase connection count local clientinterface = handleclient( client, client_ip, client_port, interface, pattern, listener, sslctx ) --vdebug( "client id:", clientinterface, "startssl:", startssl ) - if ssl and sslctx then + if has_luasec and sslctx then clientinterface:starttls(sslctx, true) else clientinterface:_start_session( true ) @@ -710,25 +710,17 @@ do end local addserver = ( function( ) - return function( addr, port, listener, pattern, sslcfg, startssl ) -- TODO: check arguments - --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslcfg or "nil", startssl or "nil") + return function( addr, port, listener, pattern, sslctx, startssl ) -- TODO: check arguments + --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslctx or "nil", startssl or "nil") + if sslctx and not has_luasec then + debug "fatal error: luasec not found" + return nil, "luasec not found" + end local server, err = socket.bind( addr, port, cfg.ACCEPT_QUEUE ) -- create server socket if not server then debug( "creating server socket on "..addr.." port "..port.." failed:", err ) return nil, err end - local sslctx - if sslcfg then - if not ssl then - debug "fatal error: luasec not found" - return nil, "luasec not found" - end - sslctx, err = sslcfg - if err then - debug( "error while creating new ssl context for server socket:", err ) - return nil, err - end - end local interface = handleserver( server, addr, port, pattern, listener, sslctx, startssl ) -- new server handler debug( "new server created with id:", tostring(interface)) return interface @@ -745,7 +737,7 @@ do end function addclient( addr, serverport, listener, pattern, sslctx ) - if sslctx and not ssl then + if sslctx and not has_luasec then debug "need luasec, but not available" return nil, "luasec not found" end -- cgit v1.2.3 From 248c6e05ed1c5927c6c18c0481f152993f1a03a3 Mon Sep 17 00:00:00 2001 From: daurnimator Date: Wed, 18 Dec 2013 18:11:47 -0500 Subject: net/server: addclient: wrapclient already calls startconnection for us --- net/server_event.lua | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'net') diff --git a/net/server_event.lua b/net/server_event.lua index 502cc80a..81dc4512 100644 --- a/net/server_event.lua +++ b/net/server_event.lua @@ -128,7 +128,7 @@ do return self:_destroy(); end - function interface_mt:_start_connection(plainssl) -- should be called from addclient + function interface_mt:_start_connection(plainssl) -- called from wrapclient local callback = function( event ) if EV_TIMEOUT == event then -- timeout during connection self.fatalerror = "connection timeout" @@ -751,7 +751,6 @@ do if res or ( err == "timeout" ) then local ip, port = client:getsockname( ) local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx ) - interface:_start_connection( startssl ) debug( "new connection id:", interface.id ) return interface, err else -- cgit v1.2.3 From 8184587aa4a1ace6268e9ceb2059f1c72bf14633 Mon Sep 17 00:00:00 2001 From: daurnimator Date: Wed, 18 Dec 2013 19:00:24 -0500 Subject: net/http: Use server.addclient --- net/http.lua | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) (limited to 'net') diff --git a/net/http.lua b/net/http.lua index ab9ec7b6..b87c9396 100644 --- a/net/http.lua +++ b/net/http.lua @@ -6,7 +6,6 @@ -- COPYING file in the source package for more information. -- -local socket = require "socket" local b64 = require "util.encodings".base64.encode; local url = require "socket.url" local httpstream_new = require "net.http.parser".new; @@ -160,21 +159,17 @@ function request(u, ex, callback) end local port_number = port and tonumber(port) or (using_https and 443 or 80); - -- Connect the socket, and wrap it with net.server - local conn = socket.tcp(); - conn:settimeout(10); - local ok, err = conn:connect(host, port_number); - if not ok and err ~= "timeout" then - callback(nil, 0, req); - return nil, err; - end - local sslctx = false; if using_https then sslctx = ex and ex.sslctx or { mode = "client", protocol = "sslv23", options = { "no_sslv2" } }; end - req.handler, req.conn = assert(server.wrapclient(conn, host, port_number, listener, "*a", sslctx)); + local handler, conn = server.addclient(host, port_number, listener, "*a", sslctx) + if not handler then + callback(nil, 0, req); + return nil, conn; + end + req.handler, req.conn = handler, conn req.write = function (...) return req.handler:write(...); end req.callback = function (content, code, request, response) log("debug", "Calling callback, status %s", code or "---"); return select(2, xpcall(function () return callback(content, code, request, response) end, handleerr)); end -- cgit v1.2.3 From 24118d8f74b097bf71f61d240c855ff5655e510d Mon Sep 17 00:00:00 2001 From: Kim Alvefur Date: Mon, 23 Dec 2013 17:55:41 +0100 Subject: net.server_{select,event}: addclient: Add argument for overriding socket type --- net/server_event.lua | 17 +++++++++++++---- net/server_select.lua | 11 +++++++++-- 2 files changed, 22 insertions(+), 6 deletions(-) (limited to 'net') diff --git a/net/server_event.lua b/net/server_event.lua index 81dc4512..ae64d50e 100644 --- a/net/server_event.lua +++ b/net/server_event.lua @@ -736,12 +736,19 @@ do --function handleclient( client, ip, port, server, pattern, listener, _, sslctx ) -- creates an client interface end - function addclient( addr, serverport, listener, pattern, sslctx ) + function addclient( addr, serverport, listener, pattern, sslctx, typ ) if sslctx and not has_luasec then debug "need luasec, but not available" return nil, "luasec not found" end - local client, err = socket.tcp() -- creating new socket + if not typ then + typ = "tcp" + end + local create = socket[typ] + if type( create ) ~= "function" then + return nil, "invalid socket type" + end + local client, err = create() -- creating new socket if not client then debug( "cannot create socket:", err ) return nil, err @@ -749,8 +756,10 @@ do client:settimeout( 0 ) -- set nonblocking local res, err = client:connect( addr, serverport ) -- connect if res or ( err == "timeout" ) then - local ip, port = client:getsockname( ) - local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx ) + if client.getsockname then + addr = client:getsockname( ) + end + local interface = wrapclient( client, addr, serverport, listener, pattern, sslctx ) debug( "new connection id:", interface.id ) return interface, err else diff --git a/net/server_select.lua b/net/server_select.lua index 91b8b01f..1ce3c8c7 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -930,7 +930,7 @@ local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx return handler, socket end -local addclient = function( address, port, listeners, pattern, sslctx ) +local addclient = function( address, port, listeners, pattern, sslctx, typ ) local err if type( listeners ) ~= "table" then err = "invalid listener table" @@ -941,12 +941,19 @@ local addclient = function( address, port, listeners, pattern, sslctx ) elseif sslctx and not has_luasec then err = "luasec not found" end + if not typ then + typ = "tcp" + end + local create = luasocket[typ] + if type( create ) ~= "function" then + err = "invalid socket type" + end if err then out_error( "server.lua, addclient: ", err ) return nil, err end - local client, err = luasocket.tcp( ) + local client, err = create( ) if err then return nil, err end -- cgit v1.2.3 From bcb9af93c22dba3008a79ecd44eb947c2c7cd910 Mon Sep 17 00:00:00 2001 From: Kim Alvefur Date: Mon, 23 Dec 2013 17:57:53 +0100 Subject: net.server_{select,event}: addclient: Use getaddrinfo to detect IP address type if no socket type argument given. (Argument must be given for non-TCP) --- net/server_event.lua | 9 ++++++++- net/server_select.lua | 10 +++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) (limited to 'net') diff --git a/net/server_event.lua b/net/server_event.lua index ae64d50e..1a3b8ca6 100644 --- a/net/server_event.lua +++ b/net/server_event.lua @@ -46,6 +46,7 @@ local t_concat = table.concat local has_luasec, ssl = pcall ( require , "ssl" ) local socket = use "socket" or require "socket" +local getaddrinfo = socket.dns.getaddrinfo local log = require ("util.logger").init("socket") @@ -742,7 +743,13 @@ do return nil, "luasec not found" end if not typ then - typ = "tcp" + local addrinfo, err = getaddrinfo(addr) + if not addrinfo then return nil, err end + if addrinfo[1] and addrinfo[1].family == "inet6" then + typ = "tcp6" + else + typ = "tcp" + end end local create = socket[typ] if type( create ) ~= "function" then diff --git a/net/server_select.lua b/net/server_select.lua index 1ce3c8c7..ee9cac7e 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -51,6 +51,7 @@ local coroutine_yield = coroutine.yield local has_luasec, luasec = pcall ( require , "ssl" ) local luasocket = use "socket" or require "socket" local luasocket_gettime = luasocket.gettime +local getaddrinfo = luasocket.dns.getaddrinfo --// extern lib methods //-- @@ -942,12 +943,19 @@ local addclient = function( address, port, listeners, pattern, sslctx, typ ) err = "luasec not found" end if not typ then - typ = "tcp" + local addrinfo, err = getaddrinfo(address) + if not addrinfo then return nil, err end + if addrinfo[1] and addrinfo[1].family == "inet6" then + typ = "tcp6" + else + typ = "tcp" + end end local create = luasocket[typ] if type( create ) ~= "function" then err = "invalid socket type" end + if err then out_error( "server.lua, addclient: ", err ) return nil, err -- cgit v1.2.3 From 29d3c27219e71e8f2ed0c835bb83333ca5ce9fbc Mon Sep 17 00:00:00 2001 From: Kim Alvefur Date: Mon, 23 Dec 2013 23:23:59 +0100 Subject: net.server_{select,event}: addclient: Handle missing getaddrinfo --- net/server_event.lua | 6 ++---- net/server_select.lua | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) (limited to 'net') diff --git a/net/server_event.lua b/net/server_event.lua index 1a3b8ca6..ef0a27d8 100644 --- a/net/server_event.lua +++ b/net/server_event.lua @@ -742,16 +742,14 @@ do debug "need luasec, but not available" return nil, "luasec not found" end - if not typ then + if getaddrinfo and not typ then local addrinfo, err = getaddrinfo(addr) if not addrinfo then return nil, err end if addrinfo[1] and addrinfo[1].family == "inet6" then typ = "tcp6" - else - typ = "tcp" end end - local create = socket[typ] + local create = socket[typ or "tcp"] if type( create ) ~= "function" then return nil, "invalid socket type" end diff --git a/net/server_select.lua b/net/server_select.lua index ee9cac7e..b69b5fc7 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -942,16 +942,14 @@ local addclient = function( address, port, listeners, pattern, sslctx, typ ) elseif sslctx and not has_luasec then err = "luasec not found" end - if not typ then + if getaddrinfo and not typ then local addrinfo, err = getaddrinfo(address) if not addrinfo then return nil, err end if addrinfo[1] and addrinfo[1].family == "inet6" then typ = "tcp6" - else - typ = "tcp" end end - local create = luasocket[typ] + local create = luasocket[typ or "tcp"] if type( create ) ~= "function" then err = "invalid socket type" end -- cgit v1.2.3 From 55a097981ec6c98f4dd3115e143c066fbe90d8e3 Mon Sep 17 00:00:00 2001 From: daurnimator Date: Wed, 25 Jun 2014 12:15:00 -0400 Subject: net/server_*: Fix addclient: LuaSocket 3.0-rc1 sometimes returns EALREADY instead of EINPROGRESS when the dns lookup has multiple results --- net/server_event.lua | 2 +- net/server_select.lua | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'net') diff --git a/net/server_event.lua b/net/server_event.lua index a3087847..b79fc463 100644 --- a/net/server_event.lua +++ b/net/server_event.lua @@ -761,7 +761,7 @@ do end client:settimeout( 0 ) -- set nonblocking local res, err = client:connect( addr, serverport ) -- connect - if res or ( err == "timeout" ) then + if res or ( err == "timeout" or err == "Operation already in progress" ) then if client.getsockname then addr = client:getsockname( ) end diff --git a/net/server_select.lua b/net/server_select.lua index 4a36617c..0aaea4be 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -966,7 +966,7 @@ local addclient = function( address, port, listeners, pattern, sslctx, typ ) end client:settimeout( 0 ) local ok, err = client:connect( address, port ) - if ok or err == "timeout" then + if ok or err == "timeout" or err == "Operation already in progress" then return wrapclient( client, address, port, listeners, pattern, sslctx ) else return nil, err -- cgit v1.2.3 From 716ad8b24a7dc699c39b8b30c484405035a078e8 Mon Sep 17 00:00:00 2001 From: daurnimator Date: Wed, 3 Sep 2014 15:28:46 -0400 Subject: net/websocket: Add new websocket client code --- net/websocket.lua | 264 +++++++++++++++++++++++++++++++++++++++++++++++ net/websocket/frames.lua | 196 +++++++++++++++++++++++++++++++++++ 2 files changed, 460 insertions(+) create mode 100644 net/websocket.lua create mode 100644 net/websocket/frames.lua (limited to 'net') diff --git a/net/websocket.lua b/net/websocket.lua new file mode 100644 index 00000000..32bb1a6e --- /dev/null +++ b/net/websocket.lua @@ -0,0 +1,264 @@ +-- Prosody IM +-- Copyright (C) 2012 Florian Zeitz +-- Copyright (C) 2014 Daurnimator +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local http = require "net.http"; +local frames = require "net.websocket.frames"; +local base64 = require "util.encodings".base64; +local sha1 = require "util.hashes".sha1; +local random_bytes = require "util.random".bytes; +local timer = require "util.timer"; +local log = require "util.logger".init "websocket"; + +local close_timeout = 3; -- Seconds to wait after sending close frame until closing connection. + +local websockets = {}; + +local websocket_listeners = {}; +function websocket_listeners.ondisconnect(handler, err) + local s = websockets[handler]; + websockets[handler] = nil; + if s.close_timer then + timer.stop(s.close_timer); + s.close_timer = nil; + end + s.readyState = 3; + if s.close_code == nil and s.onerror then s:onerror(err); end + if s.onclose then s:onclose(s.close_code, s.close_message or err); end +end + +function websocket_listeners.ondetach(handler) + websockets[handler] = nil; +end + +local function fail(s, code, reason) + module:log("warn", "WebSocket connection failed, closing. %d %s", code, reason); + s:close(code, reason); + s.handler:close(); + return false +end + +function websocket_listeners.onincoming(handler, buffer, err) + local s = websockets[handler]; + s.readbuffer = s.readbuffer..buffer; + while true do + local frame, len = frames.parse(s.readbuffer); + if frame == nil then break end + s.readbuffer = s.readbuffer:sub(len+1); + + log("debug", "Websocket received frame: opcode=%0x, %i bytes", frame.opcode, #frame.data); + + -- Error cases + if frame.RSV1 or frame.RSV2 or frame.RSV3 then -- Reserved bits non zero + return fail(s, 1002, "Reserved bits not zero"); + end + + if frame.opcode < 0x8 then + local databuffer = s.databuffer; + if frame.opcode == 0x0 then -- Continuation frames + if not databuffer then + return fail(s, 1002, "Unexpected continuation frame"); + end + databuffer[#databuffer+1] = frame.data; + elseif frame.opcode == 0x1 or frame.opcode == 0x2 then -- Text or Binary frame + if databuffer then + return fail(s, 1002, "Continuation frame expected"); + end + databuffer = {type=frame.opcode, frame.data}; + s.databuffer = databuffer; + else + return fail(s, 1002, "Reserved opcode"); + end + if frame.FIN then + s.databuffer = nil; + if s.onmessage then + s:onmessage(table.concat(databuffer), databuffer.type); + end + end + else -- Control frame + if frame.length > 125 then -- Control frame with too much payload + return fail(s, 1002, "Payload too large"); + elseif not frame.FIN then -- Fragmented control frame + return fail(s, 1002, "Fragmented control frame"); + end + if frame.opcode == 0x8 then -- Close request + if frame.length == 1 then + return fail(s, 1002, "Close frame with payload, but too short for status code"); + end + local status_code, message = frames.parse_close(frame.data); + if status_code == nil then + --[[ RFC 6455 7.4.1 + 1005 is a reserved value and MUST NOT be set as a status code in a + Close control frame by an endpoint. It is designated for use in + applications expecting a status code to indicate that no status + code was actually present. + ]] + status_code = 1005 + elseif status_code < 1000 then + return fail(s, 1002, "Closed with invalid status code"); + elseif ((status_code > 1003 and status_code < 1007) or status_code > 1011) and status_code < 3000 then + return fail(s, 1002, "Closed with reserved status code"); + end + s.close_code, s.close_message = status_code, message; + s:close(1000); + return true; + elseif frame.opcode == 0x9 then -- Ping frame + frame.opcode = 0xA; + frame.MASK = true; -- RFC 6455 6.1.5: If the data is being sent by the client, the frame(s) MUST be masked + handler:write(frames.build(frame)); + elseif frame.opcode == 0xA then -- Pong frame + log("debug", "Received unexpected pong frame: " .. tostring(frame.data)); + else + return fail(s, 1002, "Reserved opcode"); + end + end + end + return true; +end + +local websocket_methods = {}; +local function close_timeout_cb(now, timerid, s) + s.close_timer = nil; + log("warn", "Close timeout waiting for server to close, closing manually."); + s.handler:close(); +end +function websocket_methods:close(code, reason) + if self.readyState < 2 then + code = code or 1000; + log("debug", "closing WebSocket with code %i: %s" , code , tostring(reason)); + self.readyState = 2; + local handler = self.handler; + handler:write(frames.build_close(code, reason)); + -- Do not close socket straight away, wait for acknowledgement from server. + self.close_timer = timer.add_task(close_timeout, close_timeout_cb, self); + elseif self.readyState == 2 then + log("debug", "tried to close a closing WebSocket, closing the raw socket."); + -- Stop timer + if self.close_timer then + timer.stop(self.close_timer); + self.close_timer = nil; + end + local handler = self.handler; + handler:close(); + else + log("debug", "tried to close a closed WebSocket, ignoring."); + end +end +function websocket_methods:send(data, opcode) + if self.readyState < 1 then + return nil, "WebSocket not open yet, unable to send data."; + elseif self.readyState >= 2 then + return nil, "WebSocket closed, unable to send data."; + end + if opcode == "text" or opcode == nil then + opcode = 0x1; + elseif opcode == "binary" then + opcode = 0x2; + end + local frame = { + FIN = true; + MASK = true; -- RFC 6455 6.1.5: If the data is being sent by the client, the frame(s) MUST be masked + opcode = opcode; + data = tostring(data); + }; + log("debug", "WebSocket sending frame: opcode=%0x, %i bytes", frame.opcode, #frame.data); + return self.handler:write(frames.build(frame)); +end + +local websocket_metatable = { + __index = websocket_methods; +}; + +local function connect(url, ex, listeners) + ex = ex or {}; + + --[[ RFC 6455 4.1.7: + The request MUST include a header field with the name + |Sec-WebSocket-Key|. The value of this header field MUST be a + nonce consisting of a randomly selected 16-byte value that has + been base64-encoded (see Section 4 of [RFC4648]). The nonce + MUST be selected randomly for each connection. + ]] + local key = base64.encode(random_bytes(16)); + + -- Either a single protocol string or an array of protocol strings. + local protocol = ex.protocol; + if type(protocol) == "table" then + protocol = table.concat(protocol, ", "); + end + + local headers = { + ["Upgrade"] = "websocket"; + ["Connection"] = "Upgrade"; + ["Sec-WebSocket-Key"] = key; + ["Sec-WebSocket-Protocol"] = protocol; + ["Sec-WebSocket-Version"] = "13"; + ["Sec-WebSocket-Extensions"] = ex.extensions; + } + if ex.headers then + for k,v in pairs(ex.headers) do + headers[k] = v; + end + end + + local s = setmetatable({ + readbuffer = ""; + databuffer = nil; + handler = nil; + close_code = nil; + close_message = nil; + close_timer = nil; + readyState = 0; + protocol = nil; + + url = url; + + onopen = listeners.onopen; + onclose = listeners.onclose; + onmessage = listeners.onmessage; + onerror = listeners.onerror; + }, websocket_metatable); + + local http_url = url:gsub("^(ws)", "http"); + local http_req = http.request(http_url, { + method = "GET"; + headers = headers; + sslctx = ex.sslctx; + }, function(b, c, r, http_req) + if c ~= 101 + or r.headers["connection"]:lower() ~= "upgrade" + or r.headers["upgrade"] ~= "websocket" + or r.headers["sec-websocket-accept"] ~= base64.encode(sha1(key .. "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) + -- TODO: check "Sec-WebSocket-Protocol" + then + s.readyState = 3; + log("warn", "WebSocket connection to %s failed: %s", url, tostring(b)); + if s.onerror then s:onerror("connecting-failed"); end + return + end + + s.protocol = r.headers["sec-websocket-protocol"]; + + -- Take possession of socket from http + http_req.conn = nil; + local handler = http_req.handler; + s.handler = handler; + websockets[handler] = s; + handler:setlistener(websocket_listeners); + + log("debug", "WebSocket connected successfully to %s", url); + s.readyState = 1; + if s.onopen then s:onopen(); end + websocket_listeners.onincoming(handler, b); + end); + + return s; +end + +return { + connect = connect; +}; diff --git a/net/websocket/frames.lua b/net/websocket/frames.lua new file mode 100644 index 00000000..a5fcdad9 --- /dev/null +++ b/net/websocket/frames.lua @@ -0,0 +1,196 @@ +-- Prosody IM +-- Copyright (C) 2012 Florian Zeitz +-- Copyright (C) 2014 Daurnimator +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local softreq = require "util.dependencies".softreq; +local log = require "util.logger".init "websocket.frames"; +local random_bytes = require "util.random".bytes; + +local bit; +pcall(function() bit = require"bit"; end); +bit = bit or softreq"bit32" +if not bit then log("error", "No bit module found. Either LuaJIT 2, lua-bitop or Lua 5.2 is required"); end +local band = bit.band; +local bor = bit.bor; +local bxor = bit.bxor; +local lshift = bit.lshift; +local rshift = bit.rshift; + +local t_concat = table.concat; +local s_byte = string.byte; +local s_char= string.char; +local s_sub = string.sub; + +local function read_uint16be(str, pos) + local l1, l2 = s_byte(str, pos, pos+1); + return l1*256 + l2; +end +-- TODO: this may lose precision +local function read_uint64be(str, pos) + local l1, l2, l3, l4, l5, l6, l7, l8 = s_byte(str, pos, pos+7); + return lshift(l1, 56) + lshift(l2, 48) + lshift(l3, 40) + lshift(l4, 32) + + lshift(l5, 24) + lshift(l6, 16) + lshift(l7, 8) + l8; +end +local function pack_uint16be(x) + return s_char(rshift(x, 8), band(x, 0xFF)); +end +local function sm(x, n) + return band(rshift(x, n), 0xFF); +end +local function pack_uint64be(x) + return s_char(rshift(x, 56), sm(x, 48), sm(x, 40), sm(x, 32), + sm(x, 24), sm(x, 16), sm(x, 8), band(x, 0xFF)); +end + +local function parse_frame_header(frame) + if #frame < 2 then return; end + + local byte1, byte2 = s_byte(frame, 1, 2); + local result = { + FIN = band(byte1, 0x80) > 0; + RSV1 = band(byte1, 0x40) > 0; + RSV2 = band(byte1, 0x20) > 0; + RSV3 = band(byte1, 0x10) > 0; + opcode = band(byte1, 0x0F); + + MASK = band(byte2, 0x80) > 0; + length = band(byte2, 0x7F); + }; + + local length_bytes = 0; + if result.length == 126 then + length_bytes = 2; + elseif result.length == 127 then + length_bytes = 8; + end + + local header_length = 2 + length_bytes + (result.MASK and 4 or 0); + if #frame < header_length then return; end + + if length_bytes == 2 then + result.length = read_uint16be(frame, 3); + elseif length_bytes == 8 then + result.length = read_uint64be(frame, 3); + end + + if result.MASK then + result.key = { s_byte(frame, pos+1, pos+4) }; + end + + return result, header_length; +end + +-- XORs the string `str` with the array of bytes `key` +-- TODO: optimize +local function apply_mask(str, key, from, to) + from = from or 1 + if from < 0 then from = #str + from + 1 end -- negative indicies + to = to or #str + if to < 0 then to = #str + to + 1 end -- negative indicies + local key_len = #key + local counter = 0; + local data = {}; + for i = from, to do + local key_index = counter%key_len + 1; + counter = counter + 1; + data[counter] = s_char(bxor(key[key_index], s_byte(str, i))); + end + return t_concat(data); +end + +local function parse_frame_body(frame, header, pos) + if header.MASK then + return apply_mask(frame, header.key, pos, pos + header.length - 1); + else + return frame:sub(pos, pos + header.length - 1); + end +end + +local function parse_frame(frame) + local result, pos = parse_frame_header(frame); + if result == nil or #frame < (pos + result.length) then return; end + result.data = parse_frame_body(frame, result, pos+1); + return result, pos + result.length; +end + +local function build_frame(desc) + local data = desc.data or ""; + + assert(desc.opcode and desc.opcode >= 0 and desc.opcode <= 0xF, "Invalid WebSocket opcode"); + if desc.opcode >= 0x8 then + -- RFC 6455 5.5 + assert(#data <= 125, "WebSocket control frames MUST have a payload length of 125 bytes or less."); + end + + local b1 = bor(desc.opcode, + desc.FIN and 0x80 or 0, + desc.RSV1 and 0x40 or 0, + desc.RSV2 and 0x20 or 0, + desc.RSV3 and 0x10 or 0); + local b2; + + local length_extra + if #data <= 125 then -- 7-bit length + b2 = #data; + length_extra = ""; + elseif #data <= 0xFFFF then -- 2-byte length + b2 = 126; + length_extra = pack_uint16be(#data); + else -- 8-byte length + b2 = 127; + length_extra = pack_uint64be(#data); + end + + local key = "" + if desc.MASK then + local key_a = desc.key + if key_a then + key = s_char(unpack(key_a, 1, 4)); + else + key = random_bytes(4); + key_a = {key:byte(1,4)}; + end + b2 = bor(b2, 0x80); + data = apply_mask(data, key_a); + end + + return s_char(b1, b2) .. length_extra .. key .. data +end + +local function parse_close(data) + local code, message + if #data >= 2 then + code = read_uint16be(data, 1); + if #data > 2 then + message = s_sub(data, 3); + end + end + return code, message +end + +local function build_close(code, message) + local data = pack_uint16be(code); + if message then + assert(#message<=123, "Close reason must be <=123 bytes"); + data = data .. message; + end + return build_frame({ + opcode = 0x8; + FIN = true; + MASK = true; + data = data; + }); +end + +return { + parse_header = parse_frame_header; + parse_body = parse_frame_body; + parse = parse_frame; + build = build_frame; + parse_close = parse_close; + build_close = build_close; +}; -- cgit v1.2.3 From f0f0c0393ca2c0f6645fc2055c8e4aeffdac4225 Mon Sep 17 00:00:00 2001 From: daurnimator Date: Fri, 17 Oct 2014 17:30:21 -0400 Subject: net/server: Split up different backends in a nicer way. Add global config option 'server' --- net/server.lua | 89 +++++++++++++++++++++++++++++++++------------------------- 1 file changed, 51 insertions(+), 38 deletions(-) (limited to 'net') diff --git a/net/server.lua b/net/server.lua index 2a0b89ae..449632ca 100644 --- a/net/server.lua +++ b/net/server.lua @@ -6,18 +6,22 @@ -- COPYING file in the source package for more information. -- -local use_luaevent = prosody and require "core.configmanager".get("*", "use_libevent"); +local server_type = prosody and require "core.configmanager".get("*", "server") or "select"; +if prosody and require "core.configmanager".get("*", "use_libevent") then + server_type = "event"; +end -if use_luaevent then - use_luaevent = pcall(require, "luaevent.core"); - if not use_luaevent then +if server_type == "event" then + if not pcall(require, "luaevent.core") then + print(log) log("error", "libevent not found, falling back to select()"); + server_type = "select" end end local server; - -if use_luaevent then +local set_config; +if server_type == "event" then server = require "net.server_event"; -- Overwrite signal.signal() because we need to ask libevent to @@ -35,45 +39,54 @@ if use_luaevent then return server.hook_signal(signal_id, handler); end end -else - use_luaevent = false; + + local defaults = {}; + for k,v in pairs(server.cfg) do + defaults[k] = v; + end + function set_config(settings) + local event_settings = { + ACCEPT_DELAY = settings.event_accept_retry_interval; + ACCEPT_QUEUE = settings.tcp_backlog; + CLEAR_DELAY = settings.event_clear_interval; + CONNECT_TIMEOUT = settings.connect_timeout; + DEBUG = settings.debug; + HANDSHAKE_TIMEOUT = settings.ssl_handshake_timeout; + MAX_CONNECTIONS = settings.max_connections; + MAX_HANDSHAKE_ATTEMPTS = settings.max_ssl_handshake_roundtrips; + MAX_READ_LENGTH = settings.max_receive_buffer_size; + MAX_SEND_LENGTH = settings.max_send_buffer_size; + READ_TIMEOUT = settings.read_timeout; + WRITE_TIMEOUT = settings.send_timeout; + }; + + for k,default in pairs(defaults) do + server.cfg[k] = event_settings[k] or default; + end + end +elseif server_type == "select" then server = require "net.server_select"; -end -if prosody then - local config_get = require "core.configmanager".get; local defaults = {}; - for k,v in pairs(server.cfg or server.getsettings()) do + for k,v in pairs(server.getsettings()) do defaults[k] = v; end + function set_config(settings) + local select_settings = {}; + for k,default in pairs(defaults) do + select_settings[k] = settings[k] or default; + end + server.changesettings(select_settings); + end +else + error("Unsupported server type") +end + +if prosody then + local config_get = require "core.configmanager".get; local function load_config() local settings = config_get("*", "network_settings") or {}; - if use_luaevent then - local event_settings = { - ACCEPT_DELAY = settings.event_accept_retry_interval; - ACCEPT_QUEUE = settings.tcp_backlog; - CLEAR_DELAY = settings.event_clear_interval; - CONNECT_TIMEOUT = settings.connect_timeout; - DEBUG = settings.debug; - HANDSHAKE_TIMEOUT = settings.ssl_handshake_timeout; - MAX_CONNECTIONS = settings.max_connections; - MAX_HANDSHAKE_ATTEMPTS = settings.max_ssl_handshake_roundtrips; - MAX_READ_LENGTH = settings.max_receive_buffer_size; - MAX_SEND_LENGTH = settings.max_send_buffer_size; - READ_TIMEOUT = settings.read_timeout; - WRITE_TIMEOUT = settings.send_timeout; - }; - - for k,default in pairs(defaults) do - server.cfg[k] = event_settings[k] or default; - end - else - local select_settings = {}; - for k,default in pairs(defaults) do - select_settings[k] = settings[k] or default; - end - server.changesettings(select_settings); - end + return set_config(settings); end load_config(); prosody.events.add_handler("config-reloaded", load_config); -- cgit v1.2.3 From dcd855afaa62797fc8285ff4a7a8e1a8f6279a1f Mon Sep 17 00:00:00 2001 From: daurnimator Date: Mon, 20 Oct 2014 16:13:24 -0400 Subject: Move timer code out of util.timer and into relevant net.server backends --- net/server_event.lua | 18 ++++++++++++++++++ net/server_select.lua | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+) (limited to 'net') diff --git a/net/server_event.lua b/net/server_event.lua index 480d876d..fa6dda19 100644 --- a/net/server_event.lua +++ b/net/server_event.lua @@ -848,6 +848,23 @@ local function link(sender, receiver, buffersize) sender:set_mode("*a"); end +local add_task do + local EVENT_LEAVE = (event.core and event.core.LEAVE) or -1; + local socket_gettime = socket.gettime + function add_task(delay, callback) + local event_handle; + event_handle = base:addevent(nil, 0, function () + local ret = callback(socket_gettime()); + if ret then + return 0, ret; + elseif event_handle then + return EVENT_LEAVE; + end + end + , delay); + end +end + return { cfg = cfg, @@ -864,6 +881,7 @@ return { closeall = closeallservers, get_backend = get_backend, hook_signal = hook_signal, + add_task = add_task, __NAME = SCRIPT_NAME, __DATE = LAST_MODIFIED, diff --git a/net/server_select.lua b/net/server_select.lua index 51449bdf..d8404001 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -42,6 +42,7 @@ local os_difftime = os.difftime local math_min = math.min local math_huge = math.huge local table_concat = table.concat +local table_insert = table.insert local string_sub = string.sub local coroutine_wrap = coroutine.wrap local coroutine_yield = coroutine.yield @@ -832,6 +833,50 @@ addtimer = function( listener ) return true end +local add_task do + local data = {}; + local new_data = {}; + + function add_task(delay, callback) + local current_time = luasocket_gettime(); + delay = delay + current_time; + if delay >= current_time then + table_insert(new_data, {delay, callback}); + else + local r = callback(current_time); + if r and type(r) == "number" then + return add_task(r, callback); + end + end + end + + addtimer(function() + local current_time = luasocket_gettime(); + if #new_data > 0 then + for _, d in pairs(new_data) do + table_insert(data, d); + end + new_data = {}; + end + + local next_time = math_huge; + for i, d in pairs(data) do + local t, callback = d[1], d[2]; + if t <= current_time then + data[i] = nil; + local r = callback(current_time); + if type(r) == "number" then + add_task(r, callback); + next_time = math_min(next_time, r); + end + else + next_time = math_min(next_time, t - current_time); + end + end + return next_time; + end); +end + stats = function( ) return _readtraffic, _sendtraffic, _readlistlen, _sendlistlen, _timerlistlen end @@ -1007,6 +1052,7 @@ end return { _addtimer = addtimer, + add_task = add_task; addclient = addclient, wrapclient = wrapclient, -- cgit v1.2.3 From 8dd15926f8b11511330217d5ddef32a9b75e6bdc Mon Sep 17 00:00:00 2001 From: daurnimator Date: Tue, 21 Oct 2014 17:26:48 -0400 Subject: net/server: If server.hook_signal exists, overwrite signal.signal; else make server.hook_signal == signal.signal No longer server_event specific server.hook_signal will always exist --- net/server.lua | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) (limited to 'net') diff --git a/net/server.lua b/net/server.lua index 449632ca..5c0d0d24 100644 --- a/net/server.lua +++ b/net/server.lua @@ -24,22 +24,6 @@ local set_config; if server_type == "event" then server = require "net.server_event"; - -- Overwrite signal.signal() because we need to ask libevent to - -- handle them instead - local ok, signal = pcall(require, "util.signal"); - if ok and signal then - local _signal_signal = signal.signal; - function signal.signal(signal_id, handler) - if type(signal_id) == "string" then - signal_id = signal[signal_id:upper()]; - end - if type(signal_id) ~= "number" then - return false, "invalid-signal"; - end - return server.hook_signal(signal_id, handler); - end - end - local defaults = {}; for k,v in pairs(server.cfg) do defaults[k] = v; @@ -82,6 +66,24 @@ else error("Unsupported server type") end +-- If server.hook_signal exists, replace signal.signal() +local ok, signal = pcall(require, "util.signal"); +if server.hook_signal then + if ok then + function signal.signal(signal_id, handler) + if type(signal_id) == "string" then + signal_id = signal[signal_id:upper()]; + end + if type(signal_id) ~= "number" then + return false, "invalid-signal"; + end + return server.hook_signal(signal_id, handler); + end + end +else + server.hook_signal = signal.signal; +end + if prosody then local config_get = require "core.configmanager".get; local function load_config() -- cgit v1.2.3 From 36ac759ec5e9c9a0a666840854f5284d0d7b6bb5 Mon Sep 17 00:00:00 2001 From: Matthew Wild Date: Wed, 22 Oct 2014 12:56:41 +0100 Subject: net.server: Rename 'server' config option to 'network_backend' (to select which net.server implementation to use) --- net/server.lua | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'net') diff --git a/net/server.lua b/net/server.lua index 5c0d0d24..163b4476 100644 --- a/net/server.lua +++ b/net/server.lua @@ -6,7 +6,7 @@ -- COPYING file in the source package for more information. -- -local server_type = prosody and require "core.configmanager".get("*", "server") or "select"; +local server_type = prosody and require "core.configmanager".get("*", "network_backend") or "select"; if prosody and require "core.configmanager".get("*", "use_libevent") then server_type = "event"; end -- cgit v1.2.3 From 7c48f41f8fb2850bf5b38f3a17e8bc52fd94ce0f Mon Sep 17 00:00:00 2001 From: daurnimator Date: Wed, 22 Oct 2014 15:59:51 -0400 Subject: net/server: Remove print --- net/server.lua | 1 - 1 file changed, 1 deletion(-) (limited to 'net') diff --git a/net/server.lua b/net/server.lua index 163b4476..9f24b0a6 100644 --- a/net/server.lua +++ b/net/server.lua @@ -13,7 +13,6 @@ end if server_type == "event" then if not pcall(require, "luaevent.core") then - print(log) log("error", "libevent not found, falling back to select()"); server_type = "select" end -- cgit v1.2.3 From f08fe049ef978fdaa00132f238d6f038605ed3a1 Mon Sep 17 00:00:00 2001 From: daurnimator Date: Wed, 22 Oct 2014 16:00:40 -0400 Subject: net/server: Handle lack of util.signal correctly --- net/server.lua | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) (limited to 'net') diff --git a/net/server.lua b/net/server.lua index 9f24b0a6..a753a19c 100644 --- a/net/server.lua +++ b/net/server.lua @@ -66,9 +66,9 @@ else end -- If server.hook_signal exists, replace signal.signal() -local ok, signal = pcall(require, "util.signal"); -if server.hook_signal then - if ok then +local has_signal, signal = pcall(require, "util.signal"); +if has_signal then + if server.hook_signal then function signal.signal(signal_id, handler) if type(signal_id) == "string" then signal_id = signal[signal_id:upper()]; @@ -78,9 +78,15 @@ if server.hook_signal then end return server.hook_signal(signal_id, handler); end + else + server.hook_signal = signal.signal; end else - server.hook_signal = signal.signal; + if not server.hook_signal then + server.hook_signal = function() + return false, "signal hooking not supported" + end + end end if prosody then -- cgit v1.2.3 From 726d063ff7eac971074b7e6a7ac397a605721175 Mon Sep 17 00:00:00 2001 From: daurnimator Date: Tue, 18 Nov 2014 14:14:41 -0500 Subject: net.cqueues: Add module that allows use of cqueues while still using net.server as main loop --- net/cqueues.lua | 65 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 net/cqueues.lua (limited to 'net') diff --git a/net/cqueues.lua b/net/cqueues.lua new file mode 100644 index 00000000..e82fe4ad --- /dev/null +++ b/net/cqueues.lua @@ -0,0 +1,65 @@ +-- Prosody IM +-- Copyright (C) 2014 Daurnimator +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- +-- This module allows you to use cqueues with a net.server mainloop +-- + +local server = require "net.server"; +local cqueues = require "cqueues"; + +-- Create a single top level cqueue +local cq; + +if server.cq then -- server provides cqueues object + cq = server.cq; +elseif server.get_backend() == "select" and server._addtimer then -- server_select + cq = cqueues.new(); + local function step() + assert(cq:loop(0)); + end + + -- Use wrapclient (as wrapconnection isn't exported) to get server_select to watch cq fd + local handler = server.wrapclient({ + getfd = function() return cq:pollfd(); end; + settimeout = function() end; -- Method just needs to exist + close = function() end; -- Need close method for 'closeall' + }, nil, nil, {}); + + -- Only need to listen for readable; cqueues handles everything under the hood + -- readbuffer is called when `select` notes an fd as readable + handler.readbuffer = step; + + -- Use server_select low lever timer facility, + -- this callback gets called *every* time there is a timeout in the main loop + server._addtimer(function(current_time) + -- This may end up in extra step()'s, but cqueues handles it for us. + step(); + return cq:timeout(); + end); +elseif server.event and server.base then -- server_event + cq = cqueues.new(); + -- Only need to listen for readable; cqueues handles everything under the hood + local EV_READ = server.event.EV_READ; + server.base:addevent(cq:pollfd(), EV_READ, function(e) + assert(cq:loop(0)); + -- Convert a cq timeout to an acceptable timeout for luaevent + local t = cq:timeout(); + if t == 0 then -- if you give luaevent 0, it won't call this callback again + t = 0.000001; -- 1 microsecond is the smallest that works (goes into a `struct timeval`) + elseif t == nil then -- you always need to give a timeout, pick something big if we don't have one + t = 0x7FFFFFFF; -- largest 32bit int + end + return EV_READ, t; + end, + -- Schedule the callback to fire on first tick to ensure any cq:wrap calls that happen during start-up are serviced. + 0.000001); +else + error "NYI" +end + +return { + cq = cq; +} -- cgit v1.2.3 From f777a56e6634fc54ac9d2f8d9599861af8bb18d5 Mon Sep 17 00:00:00 2001 From: daurnimator Date: Tue, 6 Jan 2015 20:01:59 -0500 Subject: net.cqueues: Add workaround for luaevent callback getting collected --- net/cqueues.lua | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'net') diff --git a/net/cqueues.lua b/net/cqueues.lua index e82fe4ad..a67e405a 100644 --- a/net/cqueues.lua +++ b/net/cqueues.lua @@ -43,7 +43,11 @@ elseif server.event and server.base then -- server_event cq = cqueues.new(); -- Only need to listen for readable; cqueues handles everything under the hood local EV_READ = server.event.EV_READ; - server.base:addevent(cq:pollfd(), EV_READ, function(e) + local event_handle; + event_handle = server.base:addevent(cq:pollfd(), EV_READ, function(e) + -- Need to reference event_handle or this callback will get collected + -- This creates a circular reference that can only be broken if event_handle is manually :close()'d + local _ = event_handle; assert(cq:loop(0)); -- Convert a cq timeout to an acceptable timeout for luaevent local t = cq:timeout(); -- cgit v1.2.3 From 11d0ce1c4a12c2b2570c910d09c69afdbeef87f0 Mon Sep 17 00:00:00 2001 From: daurnimator Date: Tue, 13 Jan 2015 18:36:00 -0500 Subject: net.cqueues: Fixes hardcoded timeout for first iteration This was originally put in place as a fix for what ended up a cqueues bug: https://github.com/wahern/cqueues/issues/40 A check for a cqueues version with the bug fix is included. --- net/cqueues.lua | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) (limited to 'net') diff --git a/net/cqueues.lua b/net/cqueues.lua index a67e405a..f6bfd949 100644 --- a/net/cqueues.lua +++ b/net/cqueues.lua @@ -9,6 +9,7 @@ local server = require "net.server"; local cqueues = require "cqueues"; +assert(cqueues.VERSION >= 20150112, "cqueues newer than 20151013 required") -- Create a single top level cqueue local cq; @@ -43,23 +44,27 @@ elseif server.event and server.base then -- server_event cq = cqueues.new(); -- Only need to listen for readable; cqueues handles everything under the hood local EV_READ = server.event.EV_READ; + -- Convert a cqueues timeout to an acceptable timeout for luaevent + local function luaevent_safe_timeout(cq) + local t = cq:timeout(); + -- if you give luaevent 0 or nil, it re-uses the previous timeout. + if t == 0 then + t = 0.000001; -- 1 microsecond is the smallest that works (goes into a `struct timeval`) + elseif t == nil then -- pick something big if we don't have one + t = 0x7FFFFFFF; -- largest 32bit int + end + return t + end local event_handle; event_handle = server.base:addevent(cq:pollfd(), EV_READ, function(e) -- Need to reference event_handle or this callback will get collected -- This creates a circular reference that can only be broken if event_handle is manually :close()'d local _ = event_handle; + -- Run as many cqueues things as possible (with a timeout of 0) + -- If an error is thrown, it will break the libevent loop; but prosody resumes after logging a top level error assert(cq:loop(0)); - -- Convert a cq timeout to an acceptable timeout for luaevent - local t = cq:timeout(); - if t == 0 then -- if you give luaevent 0, it won't call this callback again - t = 0.000001; -- 1 microsecond is the smallest that works (goes into a `struct timeval`) - elseif t == nil then -- you always need to give a timeout, pick something big if we don't have one - t = 0x7FFFFFFF; -- largest 32bit int - end - return EV_READ, t; - end, - -- Schedule the callback to fire on first tick to ensure any cq:wrap calls that happen during start-up are serviced. - 0.000001); + return EV_READ, luaevent_safe_timeout(cq); + end, luaevent_safe_timeout(cq)); else error "NYI" end -- cgit v1.2.3 From 9b719f1f477341c979e14d1561a24cb8553c65d3 Mon Sep 17 00:00:00 2001 From: daurnimator Date: Fri, 16 Jan 2015 12:06:42 -0500 Subject: net.cqueues: Fix incorrect version check --- net/cqueues.lua | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'net') diff --git a/net/cqueues.lua b/net/cqueues.lua index f6bfd949..8c4c756f 100644 --- a/net/cqueues.lua +++ b/net/cqueues.lua @@ -9,7 +9,7 @@ local server = require "net.server"; local cqueues = require "cqueues"; -assert(cqueues.VERSION >= 20150112, "cqueues newer than 20151013 required") +assert(cqueues.VERSION >= 20150113, "cqueues newer than 20150113 required") -- Create a single top level cqueue local cq; -- cgit v1.2.3 From 7a1a86aba7a1c1bacb7b082f8c94ec5ee20c405d Mon Sep 17 00:00:00 2001 From: daurnimator Date: Thu, 15 Jan 2015 09:03:00 -0500 Subject: net.server_select: Fix timers not being fired until another timer fixes (or 1 second passes) --- net/server_select.lua | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) (limited to 'net') diff --git a/net/server_select.lua b/net/server_select.lua index d8404001..6d98ccac 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -115,8 +115,6 @@ local _checkinterval local _sendtimeout local _readtimeout -local _timer - local _maxselectlen local _maxfd @@ -890,8 +888,15 @@ end loop = function(once) -- this is the main loop of the program if quitting then return "quitting"; end if once then quitting = "once"; end - local next_timer_time = math_huge; + _currenttime = luasocket_gettime( ) repeat + -- Fire timers + local next_timer_time = math_huge; + for i = 1, _timerlistlen do + local t = _timerlist[ i ]( _currenttime ) -- fire timers + if t then next_timer_time = math_min(next_timer_time, t); end + end + local read, write, err = socket_select( _readlist, _sendlist, math_min(_selecttimeout, next_timer_time) ) for i, socket in ipairs( write ) do -- send data waiting in writequeues local handler = _socketlist[ socket ] @@ -940,18 +945,6 @@ loop = function(once) -- this is the main loop of the program end end - -- Fire timers - if _currenttime - _timer >= math_min(next_timer_time, 1) then - next_timer_time = math_huge; - for i = 1, _timerlistlen do - local t = _timerlist[ i ]( _currenttime ) -- fire timers - if t then next_timer_time = math_min(next_timer_time, t); end - end - _timer = _currenttime - else - next_timer_time = next_timer_time - (_currenttime - _timer); - end - -- wait some time (0 by default) socket_sleep( _sleeptime ) until quitting; @@ -1037,7 +1030,6 @@ use "setmetatable" ( _socketlist, { __mode = "k" } ) use "setmetatable" ( _readtimes, { __mode = "k" } ) use "setmetatable" ( _writetimes, { __mode = "k" } ) -_timer = luasocket_gettime( ) _starttime = luasocket_gettime( ) local function setlogger(new_logger) -- cgit v1.2.3 From c6bc543f5e75b3f589000de7382071ce263b1969 Mon Sep 17 00:00:00 2001 From: daurnimator Date: Thu, 15 Jan 2015 09:05:08 -0500 Subject: net.server_select: In add_task timer callback, use passed in time rather than re-fetching --- net/server_select.lua | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'net') diff --git a/net/server_select.lua b/net/server_select.lua index 6d98ccac..b9e72342 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -848,8 +848,7 @@ local add_task do end end - addtimer(function() - local current_time = luasocket_gettime(); + addtimer(function(current_time) if #new_data > 0 then for _, d in pairs(new_data) do table_insert(data, d); -- cgit v1.2.3 From ea84b7e27817d24e462bac02d64959a9c3483a9d Mon Sep 17 00:00:00 2001 From: daurnimator Date: Mon, 19 Jan 2015 14:01:11 -0500 Subject: net.server_select: Remove do-nothing os_difftime calls --- net/server_select.lua | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) (limited to 'net') diff --git a/net/server_select.lua b/net/server_select.lua index b9e72342..cf55be5d 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -38,7 +38,6 @@ local coroutine = use "coroutine" --// lua lib methods //-- -local os_difftime = os.difftime local math_min = math.min local math_huge = math.huge local table_concat = table.concat @@ -923,17 +922,16 @@ loop = function(once) -- this is the main loop of the program _currenttime = luasocket_gettime( ) -- Check for socket timeouts - local difftime = os_difftime( _currenttime - _starttime ) - if difftime > _checkinterval then + if _currenttime - _starttime > _checkinterval then _starttime = _currenttime for handler, timestamp in pairs( _writetimes ) do - if os_difftime( _currenttime - timestamp ) > _sendtimeout then + if _currenttime - timestamp > _sendtimeout then handler.disconnect( )( handler, "send timeout" ) handler:force_close() -- forced disconnect end end for handler, timestamp in pairs( _readtimes ) do - if os_difftime( _currenttime - timestamp ) > _readtimeout then + if _currenttime - timestamp > _readtimeout then if not(handler.onreadtimeout) or handler:onreadtimeout() ~= true then handler.disconnect( )( handler, "read timeout" ) handler:close( ) -- forced disconnect? -- cgit v1.2.3 From 8843844d0a481c43a63268f7f942a88037e4c422 Mon Sep 17 00:00:00 2001 From: daurnimator Date: Mon, 19 Jan 2015 14:05:37 -0500 Subject: net.server_select: Remove socket.sleep call from main loop It's been there since the start; but should really not be required. People can remember an issue with FreeBSD that this solved, but this was a hack solution anyway. If that issue rears it's head again, we will solve it properly. --- net/server_select.lua | 8 -------- 1 file changed, 8 deletions(-) (limited to 'net') diff --git a/net/server_select.lua b/net/server_select.lua index cf55be5d..a0574f33 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -57,7 +57,6 @@ local getaddrinfo = luasocket.dns.getaddrinfo local ssl_wrap = ( has_luasec and luasec.wrap ) local socket_bind = luasocket.bind -local socket_sleep = luasocket.sleep local socket_select = luasocket.select --// functions //-- @@ -101,7 +100,6 @@ local _sendtraffic local _readtraffic local _selecttimeout -local _sleeptime local _tcpbacklog local _starttime @@ -138,7 +136,6 @@ _sendtraffic = 0 -- some stats _readtraffic = 0 _selecttimeout = 1 -- timeout of socket.select -_sleeptime = 0 -- time to wait at the end of every loop _tcpbacklog = 128 -- some kind of hint to the OS _maxsendlen = 51000 * 1024 -- max len of send buffer @@ -790,7 +787,6 @@ end getsettings = function( ) return { select_timeout = _selecttimeout; - select_sleep_time = _sleeptime; tcp_backlog = _tcpbacklog; max_send_buffer_size = _maxsendlen; max_receive_buffer_size = _maxreadlen; @@ -808,7 +804,6 @@ changesettings = function( new ) return nil, "invalid settings table" end _selecttimeout = tonumber( new.select_timeout ) or _selecttimeout - _sleeptime = tonumber( new.select_sleep_time ) or _sleeptime _maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen _maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen _checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval @@ -941,9 +936,6 @@ loop = function(once) -- this is the main loop of the program end end end - - -- wait some time (0 by default) - socket_sleep( _sleeptime ) until quitting; if once and quitting == "once" then quitting = nil; return; end return "quitting" -- cgit v1.2.3 From 6591bd9db584bf6c5b8920c4438cd087bbfefce3 Mon Sep 17 00:00:00 2001 From: daurnimator Date: Mon, 19 Jan 2015 14:09:13 -0500 Subject: net.server_select: Remove unused code --- net/server_select.lua | 6 ------ 1 file changed, 6 deletions(-) (limited to 'net') diff --git a/net/server_select.lua b/net/server_select.lua index a0574f33..35dcb5a7 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -31,7 +31,6 @@ local tostring = use "tostring" --// lua libs //-- -local os = use "os" local table = use "table" local string = use "string" local coroutine = use "coroutine" @@ -287,7 +286,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport local bufferqueuelen = 0 -- end of buffer array local toclose - local fatalerror local needtls local bufferlen = 0 @@ -499,7 +497,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport return dispatch( handler, buffer, err ) else -- connections was closed or fatal error out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) ) - fatalerror = true _ = handler and handler:force_close( err ) return false end @@ -539,7 +536,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport return true else -- connection was closed during sending or fatal error out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) ) - fatalerror = true _ = handler and handler:force_close( err ) return false end @@ -1011,8 +1007,6 @@ local addclient = function( address, port, listeners, pattern, sslctx, typ ) end end ---// EXPERIMENTAL //-- - ----------------------------------// BEGIN //-- use "setmetatable" ( _socketlist, { __mode = "k" } ) -- cgit v1.2.3