diff options
Diffstat (limited to 'net')
-rw-r--r-- | net/adns.lua | 16 | ||||
-rw-r--r-- | net/connlisteners.lua | 18 | ||||
-rw-r--r-- | net/http/files.lua | 149 | ||||
-rw-r--r-- | net/resolvers/basic.lua | 1 | ||||
-rw-r--r-- | net/resolvers/manual.lua | 1 | ||||
-rw-r--r-- | net/resolvers/service.lua | 1 | ||||
-rw-r--r-- | net/server_epoll.lua | 139 | ||||
-rw-r--r-- | net/server_event.lua | 50 | ||||
-rw-r--r-- | net/server_select.lua | 107 | ||||
-rw-r--r-- | net/websocket/frames.lua | 7 |
10 files changed, 354 insertions, 135 deletions
diff --git a/net/adns.lua b/net/adns.lua index 560e4b53..4fa01f8a 100644 --- a/net/adns.lua +++ b/net/adns.lua @@ -14,7 +14,7 @@ local log = require "util.logger".init("adns"); local coroutine, tostring, pcall = coroutine, tostring, pcall; local setmetatable = setmetatable; -local function dummy_send(sock, data, i, j) return (j-i)+1; end +local function dummy_send(sock, data, i, j) return (j-i)+1; end -- luacheck: ignore 212 local _ENV = nil; -- luacheck: std none @@ -29,8 +29,7 @@ local function new_async_socket(sock, resolver) local peername = "<unknown>"; local listener = {}; local handler = {}; - local err; - function listener.onincoming(conn, data) + function listener.onincoming(conn, data) -- luacheck: ignore 212/conn if data then resolver:feed(handler, data); end @@ -46,9 +45,12 @@ local function new_async_socket(sock, resolver) resolver:servfail(conn); -- Let the magic commence end end - handler, err = server.wrapclient(sock, "dns", 53, listener); - if not handler then - return nil, err; + do + local err; + handler, err = server.wrapclient(sock, "dns", 53, listener); + if not handler then + return nil, err; + end end handler.settimeout = function () end @@ -89,7 +91,7 @@ function async_resolver_methods:lookup(handler, qname, qtype, qclass) end)(resolver:peek(qname, qtype, qclass)); end -function query_methods:cancel(call_handler, reason) +function query_methods:cancel(call_handler, reason) -- luacheck: ignore 212/reason log("warn", "Cancelling DNS lookup for %s", tostring(self[4])); self[1].cancel(self[2], self[3], self[4], self[5], call_handler); end diff --git a/net/connlisteners.lua b/net/connlisteners.lua deleted file mode 100644 index 9b8f88c3..00000000 --- a/net/connlisteners.lua +++ /dev/null @@ -1,18 +0,0 @@ --- COMPAT w/pre-0.9 -local log = require "util.logger".init("net.connlisteners"); -local traceback = debug.traceback; - -local _ENV = nil; --- luacheck: std none - -local function fail() - log("error", "Attempt to use legacy connlisteners API. For more info see https://prosody.im/doc/developers/network"); - log("error", "Legacy connlisteners API usage, %s", traceback("", 2)); -end - -return { - register = fail; - get = fail; - start = fail; - -- epic fail -}; diff --git a/net/http/files.lua b/net/http/files.lua new file mode 100644 index 00000000..090b15c8 --- /dev/null +++ b/net/http/files.lua @@ -0,0 +1,149 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local server = require"net.http.server"; +local lfs = require "lfs"; +local new_cache = require "util.cache".new; +local log = require "util.logger".init("net.http.files"); + +local os_date = os.date; +local open = io.open; +local stat = lfs.attributes; +local build_path = require"socket.url".build_path; +local path_sep = package.config:sub(1,1); + + +local forbidden_chars_pattern = "[/%z]"; +if package.config:sub(1,1) == "\\" then + forbidden_chars_pattern = "[/%z\001-\031\127\"*:<>?|]" +end + +local urldecode = require "util.http".urldecode; +local function sanitize_path(path) --> util.paths or util.http? + if not path then return end + local out = {}; + + local c = 0; + for component in path:gmatch("([^/]+)") do + component = urldecode(component); + if component:find(forbidden_chars_pattern) then + return nil; + elseif component == ".." then + if c <= 0 then + return nil; + end + out[c] = nil; + c = c - 1; + elseif component ~= "." then + c = c + 1; + out[c] = component; + end + end + if path:sub(-1,-1) == "/" then + out[c+1] = ""; + end + return "/"..table.concat(out, "/"); +end + +local function serve(opts) + if type(opts) ~= "table" then -- assume path string + opts = { path = opts }; + end + local mime_map = opts.mime_map or { html = "text/html" }; + local cache = new_cache(opts.cache_size or 256); + local cache_max_file_size = tonumber(opts.cache_max_file_size) or 1024 + -- luacheck: ignore 431 + local base_path = opts.path; + local dir_indices = opts.index_files or { "index.html", "index.htm" }; + local directory_index = opts.directory_index; + local function serve_file(event, path) + local request, response = event.request, event.response; + local sanitized_path = sanitize_path(path); + if path and not sanitized_path then + return 400; + end + path = sanitized_path; + local orig_path = sanitize_path(request.path); + local full_path = base_path .. (path or ""):gsub("/", path_sep); + local attr = stat(full_path:match("^.*[^\\/]")); -- Strip trailing path separator because Windows + if not attr then + return 404; + end + + local request_headers, response_headers = request.headers, response.headers; + + local last_modified = os_date('!%a, %d %b %Y %H:%M:%S GMT', attr.modification); + response_headers.last_modified = last_modified; + + local etag = ('"%02x-%x-%x-%x"'):format(attr.dev or 0, attr.ino or 0, attr.size or 0, attr.modification or 0); + response_headers.etag = etag; + + local if_none_match = request_headers.if_none_match + local if_modified_since = request_headers.if_modified_since; + if etag == if_none_match + or (not if_none_match and last_modified == if_modified_since) then + return 304; + end + + local data = cache:get(orig_path); + if data and data.etag == etag then + response_headers.content_type = data.content_type; + data = data.data; + cache:get(orig_path, data); + elseif attr.mode == "directory" and path then + if full_path:sub(-1) ~= "/" then + local dir_path = { is_absolute = true, is_directory = true }; + for dir in orig_path:gmatch("[^/]+") do dir_path[#dir_path+1]=dir; end + response_headers.location = build_path(dir_path); + return 301; + end + for i=1,#dir_indices do + if stat(full_path..dir_indices[i], "mode") == "file" then + return serve_file(event, path..dir_indices[i]); + end + end + + if directory_index then + data = server._events.fire_event("directory-index", { path = request.path, full_path = full_path }); + end + if not data then + return 403; + end + cache:set(orig_path, { data = data, content_type = mime_map.html; etag = etag; }); + response_headers.content_type = mime_map.html; + + else + local f, err = open(full_path, "rb"); + if not f then + log("debug", "Could not open %s. Error was %s", full_path, err); + return 403; + end + local ext = full_path:match("%.([^./]+)$"); + local content_type = ext and mime_map[ext]; + response_headers.content_type = content_type; + if attr.size > cache_max_file_size then + response_headers.content_length = attr.size; + log("debug", "%d > cache_max_file_size", attr.size); + return response:send_file(f); + else + data = f:read("*a"); + f:close(); + end + cache:set(orig_path, { data = data; content_type = content_type; etag = etag }); + end + + return response:send(data); + end + + return serve_file; +end + +return { + serve = serve; +} + diff --git a/net/resolvers/basic.lua b/net/resolvers/basic.lua index 9a3c9952..56f9c77d 100644 --- a/net/resolvers/basic.lua +++ b/net/resolvers/basic.lua @@ -1,5 +1,6 @@ local adns = require "net.adns"; local inet_pton = require "util.net".pton; +local unpack = table.unpack or unpack; -- luacheck: ignore 113 local methods = {}; local resolver_mt = { __index = methods }; diff --git a/net/resolvers/manual.lua b/net/resolvers/manual.lua index c0d4e5d5..dbc40256 100644 --- a/net/resolvers/manual.lua +++ b/net/resolvers/manual.lua @@ -1,5 +1,6 @@ local methods = {}; local resolver_mt = { __index = methods }; +local unpack = table.unpack or unpack; -- luacheck: ignore 113 -- Find the next target to connect to, and -- pass it to cb() diff --git a/net/resolvers/service.lua b/net/resolvers/service.lua index b5a2d821..d1b8556c 100644 --- a/net/resolvers/service.lua +++ b/net/resolvers/service.lua @@ -1,5 +1,6 @@ local adns = require "net.adns"; local basic = require "net.resolvers.basic"; +local unpack = table.unpack or unpack; -- luacheck: ignore 113 local methods = {}; local resolver_mt = { __index = methods }; diff --git a/net/server_epoll.lua b/net/server_epoll.lua index c279c579..5f62d931 100644 --- a/net/server_epoll.lua +++ b/net/server_epoll.lua @@ -6,9 +6,7 @@ -- -local t_sort = table.sort; local t_insert = table.insert; -local t_remove = table.remove; local t_concat = table.concat; local setmetatable = setmetatable; local tostring = tostring; @@ -20,6 +18,7 @@ local log = require "util.logger".init("server_epoll"); local socket = require "socket"; local luasec = require "ssl"; local gettime = require "util.time".now; +local indexedbheap = require "util.indexedbheap"; local createtable = require "util.table".create; local inet = require "util.net"; local inet_pton = inet.pton; @@ -39,7 +38,10 @@ local default_config = { __index = { read_timeout = 14 * 60; -- How long to wait for a socket to become writable after queuing data to send - send_timeout = 60; + send_timeout = 180; + + -- How long to wait for a socket to become writable after creation + connect_timeout = 20; -- Some number possibly influencing how many pending connections can be accepted tcp_backlog = 128; @@ -66,22 +68,24 @@ local fds = createtable(10, 0); -- FD -> conn -- Timer and scheduling -- -local timers = {}; +local timers = indexedbheap.create(); local function noop() end local function closetimer(t) t[1] = 0; t[2] = noop; + timers:remove(t.id); end --- Set to true when timers have changed -local resort_timers = false; +local function reschedule(t, time) + t[1] = time; + timers:reprioritize(t.id, time); +end -- Add absolute timer local function at(time, f) - local timer = { time, f, close = closetimer }; - t_insert(timers, timer); - resort_timers = true; + local timer = { time, f, close = closetimer, reschedule = reschedule, id = nil }; + timer.id = timers:insert(timer, time); return timer; end @@ -94,50 +98,32 @@ end -- Return time until next timeout local function runtimers(next_delay, min_wait) -- Any timers at all? - if not timers[1] then - return next_delay; - end - - if resort_timers then - -- Sort earliest timers to the end - t_sort(timers, function (a, b) return a[1] > b[1]; end); - resort_timers = false; - end + local now = gettime(); + local peek = timers:peek(); + while peek do - -- Iterate from the end and remove completed timers - for i = #timers, 1, -1 do - local timer = timers[i]; - local t, f = timer[1], timer[2]; - -- Get time for every iteration to increase accuracy - local now = gettime(); - if t > now then - -- This timer should not fire yet - local diff = t - now; - if diff < next_delay then - next_delay = diff; - end + if peek > now then + next_delay = peek - now; break; end - local new_timeout = f(now); - if new_timeout then - -- Schedule for 'delay' from the time actually scheduled, - -- not from now, in order to prevent timer drift. - timer[1] = t + new_timeout; - resort_timers = true; - else - t_remove(timers, i); + + local _, timer, id = timers:pop(); + local ok, ret = pcall(timer[2], now); + if ok and type(ret) == "number" then + local next_time = now+ret; + timer[1] = next_time; + timers:insert(timer, next_time); end - end - if resort_timers or next_delay < min_wait then - -- Timers may be added from within a timer callback. - -- Those would not be considered for next_delay, - -- and we might sleep for too long, so instead - -- we return a shorter timeout so we can - -- properly sort all new timers. - next_delay = min_wait; + peek = timers:peek(); + end + if peek == nil then + return next_delay; end + if next_delay < min_wait then + return min_wait; + end return next_delay; end @@ -176,6 +162,7 @@ function interface:on(what, ...) local ok, err = pcall(listener, self, ...); if not ok then log("error", "Error calling on%s: %s", what, err); + return; end return err; end @@ -243,8 +230,7 @@ function interface:setreadtimeout(t) end t = t or cfg.read_timeout; if self._readtimeout then - self._readtimeout[1] = gettime() + t; - resort_timers = true; + self._readtimeout:reschedule(gettime() + t); else self._readtimeout = addtimer(t, function () if self:on("readtimeout") then @@ -268,8 +254,7 @@ function interface:setwritetimeout(t) end t = t or cfg.send_timeout; if self._writetimeout then - self._writetimeout[1] = gettime() + t; - resort_timers = true; + self._writetimeout:reschedule(gettime() + t); else self._writetimeout = addtimer(t, function () self:on("disconnect", "write timeout"); @@ -426,8 +411,10 @@ function interface:write(data) else self.writebuffer = { data }; end - self:setwritetimeout(); - self:set(nil, true); + if not self._write_lock then + self:setwritetimeout(); + self:set(nil, true); + end return #data; end interface.send = interface.write; @@ -502,6 +489,13 @@ function interface:tlshandskake() end conn:settimeout(0); self.conn = conn; + if conn.sni then + if self.servername then + conn:sni(self.servername); + elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then + conn:sni(self._server.hosts, true); + end + end self:on("starttls"); self.ondrain = nil; self.onwritable = interface.tlshandskake; @@ -574,12 +568,14 @@ function interface:onacceptable() client:init(); if self.tls_direct then client:starttls(self.tls_ctx); + else + client:onconnect(); end end -- Initialization function interface:init() - self:setwritetimeout(); + self:setwritetimeout(cfg.connect_timeout); return self:add(true, true); end @@ -607,16 +603,28 @@ function interface:pausefor(t) end); end +function interface:pause_writes() + self._write_lock = true; + self:setwritetimeout(false); + self:set(nil, false); +end + +function interface:resume_writes() + self._write_lock = nil; + if self.writebuffer[1] then + self:setwritetimeout(); + self:set(nil, true); + end +end + -- Connected! function interface:onconnect() - if self.conn and not self.peername and self.conn.getpeername then - self.peername, self.peerport = self.conn:getpeername(); - end + self:updatenames(); self.onconnect = noop; self:on("connect"); end -local function addserver(addr, port, listeners, read_size, tls_ctx) +local function listen(addr, port, listeners, config) local conn, err = socket.bind(addr, port, cfg.tcp_backlog); if not conn then return conn, err; end conn:settimeout(0); @@ -624,10 +632,11 @@ local function addserver(addr, port, listeners, read_size, tls_ctx) conn = conn; created = gettime(); listeners = listeners; - read_size = read_size; + read_size = config and config.read_size; onreadable = interface.onacceptable; - tls_ctx = tls_ctx; - tls_direct = tls_ctx and true or false; + tls_ctx = config and config.tls_ctx; + tls_direct = config and config.tls_direct; + hosts = config and config.sni_hosts; sockname = addr; sockport = port; }, interface_mt); @@ -636,6 +645,15 @@ local function addserver(addr, port, listeners, read_size, tls_ctx) end -- COMPAT +local function addserver(addr, port, listeners, read_size, tls_ctx) + return listen(addr, port, listeners, { + read_size = read_size; + tls_ctx = tls_ctx; + tls_direct = tls_ctx and true or false; + }); +end + +-- COMPAT local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx) local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx); if not client.peername then @@ -771,6 +789,7 @@ return { addserver = addserver; addclient = addclient; add_task = addtimer; + listen = listen; at = at; loop = loop; closeall = closeall; diff --git a/net/server_event.lua b/net/server_event.lua index 11bd6a29..fde79d86 100644 --- a/net/server_event.lua +++ b/net/server_event.lua @@ -164,6 +164,15 @@ function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed debug( "fatal error while ssl wrapping:", err ) return false end + + if self.conn.sni then + if self.servername then + self.conn:sni(self.servername); + elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then + self.conn:sni(self._server.hosts, true); + end + end + self.conn:settimeout( 0 ) -- set non blocking local handshakecallback = coroutine_wrap(function( event ) local _, err @@ -253,6 +262,7 @@ end --TODO: Deprecate function interface_mt:lock_read(switch) + log("warn", ":lock_read is deprecated, use :pasue() and :resume()"); if switch then return self:pause(); else @@ -272,6 +282,19 @@ function interface_mt:resume() end end +function interface_mt:pause_writes() + return self:_lock(self.nointerface, self.noreading, true); +end + +function interface_mt:resume_writes() + self:_lock(self.nointerface, self.noreading, false); + if self.writecallback and not self.eventwrite then + self.eventwrite = addevent( base, self.conn, EV_WRITE, self.writecallback, cfg.WRITE_TIMEOUT ); -- register callback + return true; + end +end + + function interface_mt:counter(c) if c then self._connections = self._connections + c @@ -281,7 +304,7 @@ end -- Public methods function interface_mt:write(data) - if self.nowriting then return nil, "locked" end + if self.nointerface then return nil, "locked"; end --vdebug( "try to send data to client, id/data:", self.id, data ) data = tostring( data ) local len = #data @@ -293,7 +316,7 @@ function interface_mt:write(data) end t_insert(self.writebuffer, data) -- new buffer self.writebufferlen = total - if not self.eventwrite then -- register new write event + if not self.eventwrite and not self.nowriting then -- register new write event --vdebug( "register new write event" ) self.eventwrite = addevent( base, self.conn, EV_WRITE, self.writecallback, cfg.WRITE_TIMEOUT ) end @@ -635,7 +658,7 @@ local function handleclient( client, ip, port, server, pattern, listener, sslctx return interface end -local function handleserver( server, addr, port, pattern, listener, sslctx ) -- creates an server interface +local function handleserver( server, addr, port, pattern, listener, sslctx, startssl ) -- creates a server interface debug "creating server interface..." local interface = { _connections = 0; @@ -651,6 +674,7 @@ local function handleserver( server, addr, port, pattern, listener, sslctx ) -- _ip = addr, _port = port, _pattern = pattern, _sslctx = sslctx; + hosts = {}; } interface.id = tostring(interface):match("%x+$"); interface.readcallback = function( event ) -- server handler, called on incoming connections @@ -681,7 +705,7 @@ local function handleserver( server, addr, port, pattern, listener, sslctx ) -- interface._connections = interface._connections + 1 -- increase connection count local clientinterface = handleclient( client, client_ip, client_port, interface, pattern, listener, sslctx ) --vdebug( "client id:", clientinterface, "startssl:", startssl ) - if has_luasec and sslctx then + if has_luasec and startssl then clientinterface:starttls(sslctx, true) else clientinterface:_start_session( true ) @@ -700,9 +724,9 @@ local function handleserver( server, addr, port, pattern, listener, sslctx ) -- return interface end -local function addserver( addr, port, listener, pattern, sslctx, startssl ) -- TODO: check arguments - --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslctx or "nil", startssl or "nil") - if sslctx and not has_luasec then +local function listen(addr, port, listener, config) + config = config or {} + if config.sslctx and not has_luasec then debug "fatal error: luasec not found" return nil, "luasec not found" end @@ -711,11 +735,20 @@ local function addserver( addr, port, listener, pattern, sslctx, startssl ) -- debug( "creating server socket on "..addr.." port "..port.." failed:", err ) return nil, err end - local interface = handleserver( server, addr, port, pattern, listener, sslctx, startssl ) -- new server handler + local interface = handleserver( server, addr, port, config.read_size, listener, config.tls_ctx, config.tls_direct) -- new server handler debug( "new server created with id:", tostring(interface)) return interface end +local function addserver( addr, port, listener, pattern, sslctx ) -- TODO: check arguments + --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslctx or "nil", startssl or "nil") + return listen( addr, port, listener, { + read_size = pattern, + tls_ctx = sslctx, + tls_direct = not not sslctx, + }); +end + local function wrapclient( client, ip, port, listeners, pattern, sslctx ) local interface = handleclient( client, ip, port, nil, pattern, listeners, sslctx ) interface:_start_connection(sslctx) @@ -876,6 +909,7 @@ return { event_base = base, addevent = newevent, addserver = addserver, + listen = listen, addclient = addclient, wrapclient = wrapclient, setquitting = setquitting, diff --git a/net/server_select.lua b/net/server_select.lua index 1a40a6d3..e14c126e 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -68,6 +68,7 @@ local idfalse local closeall local addsocket local addserver +local listen local addtimer local getserver local wrapserver @@ -123,7 +124,7 @@ local _maxsslhandshake _server = { } -- key = port, value = table; list of listening servers _readlist = { } -- array with sockets to read from -_sendlist = { } -- arrary with sockets to write to +_sendlist = { } -- array with sockets to write to _timerlist = { } -- array of timer functions _socketlist = { } -- key = socket, value = wrapped socket (handlers) _readtimes = { } -- key = handler, value = timestamp of last data reading @@ -149,7 +150,7 @@ _checkinterval = 30 -- interval in secs to check idle clients _sendtimeout = 60000 -- allowed send idle time in secs _readtimeout = 14 * 60 -- allowed read idle time in secs -local is_windows = package.config:sub(1,1) == "\\" -- check the directory separator, to detemine whether this is Windows +local is_windows = package.config:sub(1,1) == "\\" -- check the directory separator, to determine whether this is Windows _maxfd = (is_windows and math.huge) or luasocket._SETSIZE or 1024 -- max fd number, limit to 1024 by default to prevent glibc buffer overflow, but not on Windows _maxselectlen = luasocket._SETSIZE or 1024 -- But this still applies on Windows @@ -157,7 +158,7 @@ _maxsslhandshake = 30 -- max handshake round-trips ----------------------------------// PRIVATE //-- -wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- this function wraps a server -- FIXME Make sure FD < _maxfd +wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, ssldirect ) -- this function wraps a server -- FIXME Make sure FD < _maxfd if socket:getfd() >= _maxfd then out_error("server.lua: Disallowed FD number: "..socket:getfd()) @@ -183,6 +184,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- t handler.sslctx = function( ) return sslctx end + handler.hosts = {} -- sni handler.remove = function( ) connections = connections - 1 if handler then @@ -244,13 +246,13 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- t local client, err = accept( socket ) -- try to accept if client then local ip, clientport = client:getpeername( ) - local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx ) -- wrap new client socket + local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx, ssldirect ) -- wrap new client socket if err then -- error while wrapping ssl socket return false end connections = connections + 1 out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport)) - if dispatch and not sslctx then -- SSL connections will notify onconnect when handshake completes + if dispatch and not ssldirect then -- SSL connections will notify onconnect when handshake completes return dispatch( handler ); end return; @@ -264,7 +266,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- t return handler end -wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx ) -- this function wraps a client to a handler object +wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx, ssldirect ) -- this function wraps a client to a handler object if socket:getfd() >= _maxfd then out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent @@ -424,9 +426,8 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport bufferlen = bufferlen + #data if bufferlen > maxsendlen then _closelist[ handler ] = "send buffer exceeded" -- cannot close the client at the moment, have to wait to the end of the cycle - handler.write = idfalse -- don't write anymore return false - elseif socket and not _sendlist[ socket ] then + elseif not nosend and socket and not _sendlist[ socket ] then _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) end bufferqueuelen = bufferqueuelen + 1 @@ -456,49 +457,55 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport maxreadlen = readlen or maxreadlen return bufferlen, maxreadlen, maxsendlen end - --TODO: Deprecate handler.lock_read = function (self, switch) + out_error( "server.lua, lock_read() is deprecated, use pause() and resume()" ) if switch == true then - local tmp = _readlistlen - _readlistlen = removesocket( _readlist, socket, _readlistlen ) - _readtimes[ handler ] = nil - if _readlistlen ~= tmp then - noread = true - end + return self:pause() elseif switch == false then - if noread then - noread = false - _readlistlen = addsocket(_readlist, socket, _readlistlen) - _readtimes[ handler ] = _currenttime - end + return self:resume() end return noread end handler.pause = function (self) - return self:lock_read(true); + local tmp = _readlistlen + _readlistlen = removesocket( _readlist, socket, _readlistlen ) + _readtimes[ handler ] = nil + if _readlistlen ~= tmp then + noread = true + end + return noread; end handler.resume = function (self) - return self:lock_read(false); + if noread then + noread = false + _readlistlen = addsocket(_readlist, socket, _readlistlen) + _readtimes[ handler ] = _currenttime + end + return noread; end handler.lock = function( self, switch ) - handler.lock_read (switch) + out_error( "server.lua, lock() is deprecated" ) + handler.lock_read (self, switch) if switch == true then - handler.write = idfalse - local tmp = _sendlistlen - _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) - _writetimes[ handler ] = nil - if _sendlistlen ~= tmp then - nosend = true - end + handler.pause_writes (self) elseif switch == false then - handler.write = write - if nosend then - nosend = false - write( "" ) - end + handler.resume_writes (self) end return noread, nosend end + handler.pause_writes = function (self) + local tmp = _sendlistlen + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) + _writetimes[ handler ] = nil + nosend = true + end + handler.resume_writes = function (self) + nosend = false + if bufferlen > 0 then + _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) + end + end + local _readbuffer = function( ) -- this function reads data local buffer, err, part = receive( socket, pattern ) -- receive buffer with "pattern" if not err or (err == "wantread" or err == "timeout") then -- received something @@ -619,11 +626,20 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport out_put( "server.lua: attempting to start tls on " .. tostring( socket ) ) local oldsocket, err = socket socket, err = ssl_wrap( socket, sslctx ) -- wrap socket + if not socket then out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") ) return nil, err -- fatal error end + if socket.sni then + if self.servername then + socket:sni(self.servername); + elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then + socket:sni(self.server().hosts, true); + end + end + socket:settimeout( 0 ) -- add the new socket to our system @@ -659,7 +675,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport _socketlist[ socket ] = handler _readlistlen = addsocket(_readlist, socket, _readlistlen) - if sslctx and has_luasec then + if sslctx and ssldirect and has_luasec then out_put "server.lua: auto-starting ssl negotiation..." handler.autostart_ssl = true; local ok, err = handler:starttls(sslctx); @@ -734,9 +750,13 @@ end ----------------------------------// PUBLIC //-- -addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server +listen = function ( addr, port, listeners, config ) addr = addr or "*" + config = config or {} local err + local sslctx = config.tls_ctx; + local ssldirect = config.tls_direct; + local pattern = config.read_size; if type( listeners ) ~= "table" then err = "invalid listener table" elseif type ( addr ) ~= "string" then @@ -757,7 +777,7 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function out_error( "server.lua, [", addr, "]:", port, ": ", err ) return nil, err end - local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx ) -- wrap new server socket + local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, ssldirect ) -- wrap new server socket if not handler then server:close( ) return nil, err @@ -770,6 +790,14 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function return handler end +addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server + return listen(addr, port, listeners, { + read_size = pattern; + tls_ctx = sslctx; + tls_direct = sslctx and true or false; + }); +end + getserver = function ( addr, port ) return _server[ addr..":"..port ]; end @@ -978,7 +1006,7 @@ end --// EXPERIMENTAL //-- local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx ) - local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx ) + local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx, sslctx) if not handler then return nil, err end _socketlist[ socket ] = handler if not sslctx then @@ -1114,6 +1142,7 @@ return { stats = stats, closeall = closeall, addserver = addserver, + listen = listen, getserver = getserver, setlogger = setlogger, getsettings = getsettings, diff --git a/net/websocket/frames.lua b/net/websocket/frames.lua index ba25d261..86752109 100644 --- a/net/websocket/frames.lua +++ b/net/websocket/frames.lua @@ -9,20 +9,21 @@ local softreq = require "util.dependencies".softreq; local random_bytes = require "util.random".bytes; -local bit = assert(softreq"bit" or softreq"bit32", +local bit = assert(softreq"bit32" or softreq"bit", "No bit module found. See https://prosody.im/doc/depends#bitop"); local band = bit.band; local bor = bit.bor; local bxor = bit.bxor; local lshift = bit.lshift; local rshift = bit.rshift; +local unpack = table.unpack or unpack; -- luacheck: ignore 113 local t_concat = table.concat; local s_byte = string.byte; local s_char= string.char; local s_sub = string.sub; -local s_pack = string.pack; -- luacheck: ignore 143 -local s_unpack = string.unpack; -- luacheck: ignore 143 +local s_pack = string.pack; +local s_unpack = string.unpack; if not s_pack and softreq"struct" then s_pack = softreq"struct".pack; |