aboutsummaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
Diffstat (limited to 'net')
-rw-r--r--net/adns.lua16
-rw-r--r--net/connlisteners.lua18
-rw-r--r--net/http/files.lua149
-rw-r--r--net/resolvers/basic.lua1
-rw-r--r--net/resolvers/manual.lua1
-rw-r--r--net/resolvers/service.lua1
-rw-r--r--net/server_epoll.lua139
-rw-r--r--net/server_event.lua50
-rw-r--r--net/server_select.lua107
-rw-r--r--net/websocket/frames.lua7
10 files changed, 377 insertions, 112 deletions
diff --git a/net/adns.lua b/net/adns.lua
index 560e4b53..4fa01f8a 100644
--- a/net/adns.lua
+++ b/net/adns.lua
@@ -14,7 +14,7 @@ local log = require "util.logger".init("adns");
local coroutine, tostring, pcall = coroutine, tostring, pcall;
local setmetatable = setmetatable;
-local function dummy_send(sock, data, i, j) return (j-i)+1; end
+local function dummy_send(sock, data, i, j) return (j-i)+1; end -- luacheck: ignore 212
local _ENV = nil;
-- luacheck: std none
@@ -29,8 +29,7 @@ local function new_async_socket(sock, resolver)
local peername = "<unknown>";
local listener = {};
local handler = {};
- local err;
- function listener.onincoming(conn, data)
+ function listener.onincoming(conn, data) -- luacheck: ignore 212/conn
if data then
resolver:feed(handler, data);
end
@@ -46,9 +45,12 @@ local function new_async_socket(sock, resolver)
resolver:servfail(conn); -- Let the magic commence
end
end
- handler, err = server.wrapclient(sock, "dns", 53, listener);
- if not handler then
- return nil, err;
+ do
+ local err;
+ handler, err = server.wrapclient(sock, "dns", 53, listener);
+ if not handler then
+ return nil, err;
+ end
end
handler.settimeout = function () end
@@ -89,7 +91,7 @@ function async_resolver_methods:lookup(handler, qname, qtype, qclass)
end)(resolver:peek(qname, qtype, qclass));
end
-function query_methods:cancel(call_handler, reason)
+function query_methods:cancel(call_handler, reason) -- luacheck: ignore 212/reason
log("warn", "Cancelling DNS lookup for %s", tostring(self[4]));
self[1].cancel(self[2], self[3], self[4], self[5], call_handler);
end
diff --git a/net/connlisteners.lua b/net/connlisteners.lua
deleted file mode 100644
index 9b8f88c3..00000000
--- a/net/connlisteners.lua
+++ /dev/null
@@ -1,18 +0,0 @@
--- COMPAT w/pre-0.9
-local log = require "util.logger".init("net.connlisteners");
-local traceback = debug.traceback;
-
-local _ENV = nil;
--- luacheck: std none
-
-local function fail()
- log("error", "Attempt to use legacy connlisteners API. For more info see https://prosody.im/doc/developers/network");
- log("error", "Legacy connlisteners API usage, %s", traceback("", 2));
-end
-
-return {
- register = fail;
- get = fail;
- start = fail;
- -- epic fail
-};
diff --git a/net/http/files.lua b/net/http/files.lua
new file mode 100644
index 00000000..7ff81fc8
--- /dev/null
+++ b/net/http/files.lua
@@ -0,0 +1,149 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+local server = require"net.http.server";
+local lfs = require "lfs";
+local new_cache = require "util.cache".new;
+local log = require "util.logger".init("net.http.files");
+
+local os_date = os.date;
+local open = io.open;
+local stat = lfs.attributes;
+local build_path = require"socket.url".build_path;
+local path_sep = package.config:sub(1,1);
+
+
+local forbidden_chars_pattern = "[/%z]";
+if package.config:sub(1,1) == "\\" then
+ forbidden_chars_pattern = "[/%z\001-\031\127\"*:<>?|]"
+end
+
+local urldecode = require "util.http".urldecode;
+local function sanitize_path(path) --> util.paths or util.http?
+ if not path then return end
+ local out = {};
+
+ local c = 0;
+ for component in path:gmatch("([^/]+)") do
+ component = urldecode(component);
+ if component:find(forbidden_chars_pattern) then
+ return nil;
+ elseif component == ".." then
+ if c <= 0 then
+ return nil;
+ end
+ out[c] = nil;
+ c = c - 1;
+ elseif component ~= "." then
+ c = c + 1;
+ out[c] = component;
+ end
+ end
+ if path:sub(-1,-1) == "/" then
+ out[c+1] = "";
+ end
+ return "/"..table.concat(out, "/");
+end
+
+local function serve(opts)
+ if type(opts) ~= "table" then -- assume path string
+ opts = { path = opts };
+ end
+ local mime_map = opts.mime_map or { html = "text/html" };
+ local cache = new_cache(opts.cache_size or 256);
+ local cache_max_file_size = tonumber(opts.cache_max_file_size) or 1024
+ -- luacheck: ignore 431
+ local base_path = opts.path;
+ local dir_indices = opts.index_files or { "index.html", "index.htm" };
+ local directory_index = opts.directory_index;
+ local function serve_file(event, path)
+ local request, response = event.request, event.response;
+ local sanitized_path = sanitize_path(path);
+ if path and not sanitized_path then
+ return 400;
+ end
+ path = sanitized_path;
+ local orig_path = sanitize_path(request.path);
+ local full_path = base_path .. (path or ""):gsub("/", path_sep);
+ local attr = stat(full_path:match("^.*[^\\/]")); -- Strip trailing path separator because Windows
+ if not attr then
+ return 404;
+ end
+
+ local request_headers, response_headers = request.headers, response.headers;
+
+ local last_modified = os_date('!%a, %d %b %Y %H:%M:%S GMT', attr.modification);
+ response_headers.last_modified = last_modified;
+
+ local etag = ('"%02x-%x-%x-%x"'):format(attr.dev or 0, attr.ino or 0, attr.size or 0, attr.modification or 0);
+ response_headers.etag = etag;
+
+ local if_none_match = request_headers.if_none_match
+ local if_modified_since = request_headers.if_modified_since;
+ if etag == if_none_match
+ or (not if_none_match and last_modified == if_modified_since) then
+ return 304;
+ end
+
+ local data = cache:get(orig_path);
+ if data and data.etag == etag then
+ response_headers.content_type = data.content_type;
+ data = data.data;
+ cache:set(orig_path, data);
+ elseif attr.mode == "directory" and path then
+ if full_path:sub(-1) ~= "/" then
+ local dir_path = { is_absolute = true, is_directory = true };
+ for dir in orig_path:gmatch("[^/]+") do dir_path[#dir_path+1]=dir; end
+ response_headers.location = build_path(dir_path);
+ return 301;
+ end
+ for i=1,#dir_indices do
+ if stat(full_path..dir_indices[i], "mode") == "file" then
+ return serve_file(event, path..dir_indices[i]);
+ end
+ end
+
+ if directory_index then
+ data = server._events.fire_event("directory-index", { path = request.path, full_path = full_path });
+ end
+ if not data then
+ return 403;
+ end
+ cache:set(orig_path, { data = data, content_type = mime_map.html; etag = etag; });
+ response_headers.content_type = mime_map.html;
+
+ else
+ local f, err = open(full_path, "rb");
+ if not f then
+ log("debug", "Could not open %s. Error was %s", full_path, err);
+ return 403;
+ end
+ local ext = full_path:match("%.([^./]+)$");
+ local content_type = ext and mime_map[ext];
+ response_headers.content_type = content_type;
+ if attr.size > cache_max_file_size then
+ response_headers.content_length = attr.size;
+ log("debug", "%d > cache_max_file_size", attr.size);
+ return response:send_file(f);
+ else
+ data = f:read("*a");
+ f:close();
+ end
+ cache:set(orig_path, { data = data; content_type = content_type; etag = etag });
+ end
+
+ return response:send(data);
+ end
+
+ return serve_file;
+end
+
+return {
+ serve = serve;
+}
+
diff --git a/net/resolvers/basic.lua b/net/resolvers/basic.lua
index 9a3c9952..56f9c77d 100644
--- a/net/resolvers/basic.lua
+++ b/net/resolvers/basic.lua
@@ -1,5 +1,6 @@
local adns = require "net.adns";
local inet_pton = require "util.net".pton;
+local unpack = table.unpack or unpack; -- luacheck: ignore 113
local methods = {};
local resolver_mt = { __index = methods };
diff --git a/net/resolvers/manual.lua b/net/resolvers/manual.lua
index c0d4e5d5..dbc40256 100644
--- a/net/resolvers/manual.lua
+++ b/net/resolvers/manual.lua
@@ -1,5 +1,6 @@
local methods = {};
local resolver_mt = { __index = methods };
+local unpack = table.unpack or unpack; -- luacheck: ignore 113
-- Find the next target to connect to, and
-- pass it to cb()
diff --git a/net/resolvers/service.lua b/net/resolvers/service.lua
index b5a2d821..d1b8556c 100644
--- a/net/resolvers/service.lua
+++ b/net/resolvers/service.lua
@@ -1,5 +1,6 @@
local adns = require "net.adns";
local basic = require "net.resolvers.basic";
+local unpack = table.unpack or unpack; -- luacheck: ignore 113
local methods = {};
local resolver_mt = { __index = methods };
diff --git a/net/server_epoll.lua b/net/server_epoll.lua
index 0c03ae15..ccf46928 100644
--- a/net/server_epoll.lua
+++ b/net/server_epoll.lua
@@ -9,12 +9,12 @@
local t_insert = table.insert;
local t_concat = table.concat;
local setmetatable = setmetatable;
-local tostring = tostring;
local pcall = pcall;
local type = type;
local next = next;
local pairs = pairs;
-local log = require "util.logger".init("server_epoll");
+local logger = require "util.logger";
+local log = logger.init("server_epoll");
local socket = require "socket";
local luasec = require "ssl";
local gettime = require "util.time".now;
@@ -23,6 +23,7 @@ local createtable = require "util.table".create;
local inet = require "util.net";
local inet_pton = inet.pton;
local _SOCKETINVALID = socket._SOCKETINVALID or -1;
+local new_id = require "util.id".medium;
local poller = require "util.poll"
local EEXIST = poller.EEXIST;
@@ -38,7 +39,10 @@ local default_config = { __index = {
read_timeout = 14 * 60;
-- How long to wait for a socket to become writable after queuing data to send
- send_timeout = 60;
+ send_timeout = 180;
+
+ -- How long to wait for a socket to become writable after creation
+ connect_timeout = 20;
-- Some number possibly influencing how many pending connections can be accepted
tcp_backlog = 128;
@@ -58,6 +62,10 @@ local default_config = { __index = {
-- Maximum and minimum amount of time to sleep waiting for events (adjusted for pending timers)
max_wait = 86400;
min_wait = 1e-06;
+
+ -- EXPERIMENTAL
+ -- Whether to kill connections in case of callback errors.
+ fatal_errors = false;
}};
local cfg = default_config.__index;
@@ -102,7 +110,7 @@ local function runtimers(next_delay, min_wait)
if peek > now then
next_delay = peek - now;
break;
- end
+ end
local _, timer, id = timers:pop();
local ok, ret = pcall(timer[2], now);
@@ -110,10 +118,10 @@ local function runtimers(next_delay, min_wait)
local next_time = now+ret;
timer[1] = next_time;
timers:insert(timer, next_time);
- end
+ end
peek = timers:peek();
- end
+ end
if peek == nil then
return next_delay;
end
@@ -138,6 +146,15 @@ function interface_mt:__tostring()
return ("FD %d"):format(self:getfd());
end
+interface.log = log;
+function interface:debug(msg, ...) --luacheck: ignore 212/self
+ self.log("debug", msg, ...);
+end
+
+function interface:error(msg, ...) --luacheck: ignore 212/self
+ self.log("error", msg, ...);
+end
+
-- Replace the listener and tell the old one
function interface:setlistener(listeners, data)
self:on("detach");
@@ -148,17 +165,23 @@ end
-- Call a listener callback
function interface:on(what, ...)
if not self.listeners then
- log("error", "%s has no listeners", self);
+ self:debug("Interface is missing listener callbacks");
return;
end
local listener = self.listeners["on"..what];
if not listener then
- -- log("debug", "Missing listener 'on%s'", what); -- uncomment for development and debugging
+ -- self:debug("Missing listener 'on%s'", what); -- uncomment for development and debugging
return;
end
local ok, err = pcall(listener, self, ...);
if not ok then
- log("error", "Error calling on%s: %s", what, err);
+ if cfg.fatal_errors then
+ self:debug("Closing due to error calling on%s: %s", what, err);
+ self:destroy();
+ else
+ self:debug("Error calling on%s: %s", what, err);
+ end
+ return nil, err;
end
return err;
end
@@ -269,15 +292,15 @@ function interface:add(r, w)
local ok, err, errno = poll:add(fd, r, w);
if not ok then
if errno == EEXIST then
- log("debug", "%s already registered!", self);
+ self:debug("FD already registered in poller! (EEXIST)");
return self:set(r, w); -- So try to change its flags
end
- log("error", "Could not register %s: %s(%d)", self, err, errno);
+ self:debug("Could not register in poller: %s(%d)", err, errno);
return ok, err;
end
self._wantread, self._wantwrite = r, w;
fds[fd] = self;
- log("debug", "Watching %s", self);
+ self:debug("Registered in poller");
return true;
end
@@ -290,7 +313,7 @@ function interface:set(r, w)
if w == nil then w = self._wantwrite; end
local ok, err, errno = poll:set(fd, r, w);
if not ok then
- log("error", "Could not update poller state %s: %s(%d)", self, err, errno);
+ self:debug("Could not update poller state: %s(%d)", err, errno);
return ok, err;
end
self._wantread, self._wantwrite = r, w;
@@ -307,12 +330,12 @@ function interface:del()
end
local ok, err, errno = poll:del(fd);
if not ok and errno ~= ENOENT then
- log("error", "Could not unregister %s: %s(%d)", self, err, errno);
+ self:debug("Could not unregister: %s(%d)", err, errno);
return ok, err;
end
self._wantread, self._wantwrite = nil, nil;
fds[fd] = nil;
- log("debug", "Unwatched %s", self);
+ self:debug("Unregistered from poller");
return true;
end
@@ -407,8 +430,10 @@ function interface:write(data)
else
self.writebuffer = { data };
end
- self:setwritetimeout();
- self:set(nil, true);
+ if not self._write_lock then
+ self:setwritetimeout();
+ self:set(nil, true);
+ end
return #data;
end
interface.send = interface.write;
@@ -418,10 +443,10 @@ function interface:close()
if self.writebuffer and self.writebuffer[1] then
self:set(false, true); -- Flush final buffer contents
self.write, self.send = noop, noop; -- No more writing
- log("debug", "Close %s after writing", self);
+ self:debug("Close after writing");
self.ondrain = interface.close;
else
- log("debug", "Close %s now", self);
+ self:debug("Closing now");
self.write, self.send = noop, noop;
self.close = noop;
self:on("disconnect");
@@ -450,7 +475,7 @@ function interface:starttls(tls_ctx)
if tls_ctx then self.tls_ctx = tls_ctx; end
self.starttls = false;
if self.writebuffer and self.writebuffer[1] then
- log("debug", "Start TLS on %s after write", self);
+ self:debug("Start TLS after write");
self.ondrain = interface.starttls;
self:set(nil, true); -- make sure wantwrite is set
else
@@ -460,7 +485,7 @@ function interface:starttls(tls_ctx)
self.onwritable = interface.tlshandskake;
self.onreadable = interface.tlshandskake;
self:set(true, true);
- log("debug", "Prepare to start TLS on %s", self);
+ self:debug("Prepared to start TLS");
end
end
@@ -469,12 +494,12 @@ function interface:tlshandskake()
self:setreadtimeout(false);
if not self._tls then
self._tls = true;
- log("debug", "Start TLS on %s now", self);
+ self:debug("Starting TLS now");
self:del();
local ok, conn, err = pcall(luasec.wrap, self.conn, self.tls_ctx);
if not ok then
conn, err = ok, conn;
- log("error", "Failed to initialize TLS: %s", err);
+ self:debug("Failed to initialize TLS: %s", err);
end
if not conn then
self:on("disconnect", err);
@@ -483,6 +508,13 @@ function interface:tlshandskake()
end
conn:settimeout(0);
self.conn = conn;
+ if conn.sni then
+ if self.servername then
+ conn:sni(self.servername);
+ elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then
+ conn:sni(self._server.hosts, true);
+ end
+ end
self:on("starttls");
self.ondrain = nil;
self.onwritable = interface.tlshandskake;
@@ -491,22 +523,22 @@ function interface:tlshandskake()
end
local ok, err = self.conn:dohandshake();
if ok then
- log("debug", "TLS handshake on %s complete", self);
+ self:debug("TLS handshake complete");
self.onwritable = nil;
self.onreadable = nil;
self:on("status", "ssl-handshake-complete");
self:setwritetimeout();
self:set(true, true);
elseif err == "wantread" then
- log("debug", "TLS handshake on %s to wait until readable", self);
+ self:debug("TLS handshake to wait until readable");
self:set(true, false);
self:setreadtimeout(cfg.ssl_handshake_timeout);
elseif err == "wantwrite" then
- log("debug", "TLS handshake on %s to wait until writable", self);
+ self:debug("TLS handshake to wait until writable");
self:set(false, true);
self:setwritetimeout(cfg.ssl_handshake_timeout);
else
- log("debug", "TLS handshake error on %s: %s", self, err);
+ self:debug("TLS handshake error: %s", err);
self:on("disconnect", err);
self:destroy();
end
@@ -523,6 +555,7 @@ local function wrapsocket(client, server, read_size, listeners, tls_ctx) -- luas
writebuffer = {};
tls_ctx = tls_ctx or (server and server.tls_ctx);
tls_direct = server and server.tls_direct;
+ log = logger.init(("conn%s"):format(new_id()));
}, interface_mt);
conn:updatenames();
@@ -546,21 +579,23 @@ end
function interface:onacceptable()
local conn, err = self.conn:accept();
if not conn then
- log("debug", "Error accepting new client: %s, server will be paused for %ds", err, cfg.accept_retry_interval);
+ self:debug("Error accepting new client: %s, server will be paused for %ds", err, cfg.accept_retry_interval);
self:pausefor(cfg.accept_retry_interval);
return;
end
local client = wrapsocket(conn, self, nil, self.listeners);
- log("debug", "New connection %s", tostring(client));
+ client:debug("New connection %s on server %s", client, self);
client:init();
if self.tls_direct then
client:starttls(self.tls_ctx);
+ else
+ client:onconnect();
end
end
-- Initialization
function interface:init()
- self:setwritetimeout();
+ self:setwritetimeout(cfg.connect_timeout);
return self:add(true, true);
end
@@ -588,16 +623,28 @@ function interface:pausefor(t)
end);
end
+function interface:pause_writes()
+ self._write_lock = true;
+ self:setwritetimeout(false);
+ self:set(nil, false);
+end
+
+function interface:resume_writes()
+ self._write_lock = nil;
+ if self.writebuffer[1] then
+ self:setwritetimeout();
+ self:set(nil, true);
+ end
+end
+
-- Connected!
function interface:onconnect()
- if self.conn and not self.peername and self.conn.getpeername then
- self.peername, self.peerport = self.conn:getpeername();
- end
+ self:updatenames();
self.onconnect = noop;
self:on("connect");
end
-local function addserver(addr, port, listeners, read_size, tls_ctx)
+local function listen(addr, port, listeners, config)
local conn, err = socket.bind(addr, port, cfg.tcp_backlog);
if not conn then return conn, err; end
conn:settimeout(0);
@@ -605,18 +652,30 @@ local function addserver(addr, port, listeners, read_size, tls_ctx)
conn = conn;
created = gettime();
listeners = listeners;
- read_size = read_size;
+ read_size = config and config.read_size;
onreadable = interface.onacceptable;
- tls_ctx = tls_ctx;
- tls_direct = tls_ctx and true or false;
+ tls_ctx = config and config.tls_ctx;
+ tls_direct = config and config.tls_direct;
+ hosts = config and config.sni_hosts;
sockname = addr;
sockport = port;
+ log = logger.init(("serv%s"):format(new_id()));
}, interface_mt);
+ server:debug("Server %s created", server);
server:add(true, false);
return server;
end
-- COMPAT
+local function addserver(addr, port, listeners, read_size, tls_ctx)
+ return listen(addr, port, listeners, {
+ read_size = read_size;
+ tls_ctx = tls_ctx;
+ tls_direct = tls_ctx and true or false;
+ });
+end
+
+-- COMPAT
local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx)
local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx);
if not client.peername then
@@ -649,6 +708,7 @@ local function addclient(addr, port, listeners, read_size, tls_ctx, typ)
return nil, "invalid socket type";
end
local conn, err = create();
+ if not conn then return conn, err; end
local ok, err = conn:settimeout(0);
if not ok then return ok, err; end
local ok, err = conn:setpeername(addr, port);
@@ -659,6 +719,7 @@ local function addclient(addr, port, listeners, read_size, tls_ctx, typ)
if tls_ctx then
client:starttls(tls_ctx);
end
+ client:debug("Client %s created", client);
return client, conn;
end
@@ -677,6 +738,7 @@ local function watchfd(fd, onreadable, onwritable)
end;
-- Otherwise it'll need to be something LuaSocket-compatible
end
+ conn.log = logger.init(("fdwatch%s"):format(new_id()));
conn:add(onreadable, onwritable);
return conn;
end;
@@ -752,6 +814,7 @@ return {
addserver = addserver;
addclient = addclient;
add_task = addtimer;
+ listen = listen;
at = at;
loop = loop;
closeall = closeall;
@@ -766,6 +829,7 @@ return {
-- libevent emulation
event = { EV_READ = "r", EV_WRITE = "w", EV_READWRITE = "rw", EV_LEAVE = -1 };
addevent = function (fd, mode, callback)
+ log("warn", "Using deprecated libevent emulation, please update code to use watchfd API instead");
local function onevent(self)
local ret = self:callback();
if ret == -1 then
@@ -785,6 +849,7 @@ return {
fds[fd] = nil;
end;
}, interface_mt);
+ conn.log = logger.init(("fdwatch%d"):format(conn:getfd()));
local ok, err = conn:add(mode == "r" or mode == "rw", mode == "w" or mode == "rw");
if not ok then return ok, err; end
return conn;
diff --git a/net/server_event.lua b/net/server_event.lua
index 11bd6a29..fde79d86 100644
--- a/net/server_event.lua
+++ b/net/server_event.lua
@@ -164,6 +164,15 @@ function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed
debug( "fatal error while ssl wrapping:", err )
return false
end
+
+ if self.conn.sni then
+ if self.servername then
+ self.conn:sni(self.servername);
+ elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then
+ self.conn:sni(self._server.hosts, true);
+ end
+ end
+
self.conn:settimeout( 0 ) -- set non blocking
local handshakecallback = coroutine_wrap(function( event )
local _, err
@@ -253,6 +262,7 @@ end
--TODO: Deprecate
function interface_mt:lock_read(switch)
+ log("warn", ":lock_read is deprecated, use :pasue() and :resume()");
if switch then
return self:pause();
else
@@ -272,6 +282,19 @@ function interface_mt:resume()
end
end
+function interface_mt:pause_writes()
+ return self:_lock(self.nointerface, self.noreading, true);
+end
+
+function interface_mt:resume_writes()
+ self:_lock(self.nointerface, self.noreading, false);
+ if self.writecallback and not self.eventwrite then
+ self.eventwrite = addevent( base, self.conn, EV_WRITE, self.writecallback, cfg.WRITE_TIMEOUT ); -- register callback
+ return true;
+ end
+end
+
+
function interface_mt:counter(c)
if c then
self._connections = self._connections + c
@@ -281,7 +304,7 @@ end
-- Public methods
function interface_mt:write(data)
- if self.nowriting then return nil, "locked" end
+ if self.nointerface then return nil, "locked"; end
--vdebug( "try to send data to client, id/data:", self.id, data )
data = tostring( data )
local len = #data
@@ -293,7 +316,7 @@ function interface_mt:write(data)
end
t_insert(self.writebuffer, data) -- new buffer
self.writebufferlen = total
- if not self.eventwrite then -- register new write event
+ if not self.eventwrite and not self.nowriting then -- register new write event
--vdebug( "register new write event" )
self.eventwrite = addevent( base, self.conn, EV_WRITE, self.writecallback, cfg.WRITE_TIMEOUT )
end
@@ -635,7 +658,7 @@ local function handleclient( client, ip, port, server, pattern, listener, sslctx
return interface
end
-local function handleserver( server, addr, port, pattern, listener, sslctx ) -- creates an server interface
+local function handleserver( server, addr, port, pattern, listener, sslctx, startssl ) -- creates a server interface
debug "creating server interface..."
local interface = {
_connections = 0;
@@ -651,6 +674,7 @@ local function handleserver( server, addr, port, pattern, listener, sslctx ) --
_ip = addr, _port = port, _pattern = pattern,
_sslctx = sslctx;
+ hosts = {};
}
interface.id = tostring(interface):match("%x+$");
interface.readcallback = function( event ) -- server handler, called on incoming connections
@@ -681,7 +705,7 @@ local function handleserver( server, addr, port, pattern, listener, sslctx ) --
interface._connections = interface._connections + 1 -- increase connection count
local clientinterface = handleclient( client, client_ip, client_port, interface, pattern, listener, sslctx )
--vdebug( "client id:", clientinterface, "startssl:", startssl )
- if has_luasec and sslctx then
+ if has_luasec and startssl then
clientinterface:starttls(sslctx, true)
else
clientinterface:_start_session( true )
@@ -700,9 +724,9 @@ local function handleserver( server, addr, port, pattern, listener, sslctx ) --
return interface
end
-local function addserver( addr, port, listener, pattern, sslctx, startssl ) -- TODO: check arguments
- --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslctx or "nil", startssl or "nil")
- if sslctx and not has_luasec then
+local function listen(addr, port, listener, config)
+ config = config or {}
+ if config.sslctx and not has_luasec then
debug "fatal error: luasec not found"
return nil, "luasec not found"
end
@@ -711,11 +735,20 @@ local function addserver( addr, port, listener, pattern, sslctx, startssl ) --
debug( "creating server socket on "..addr.." port "..port.." failed:", err )
return nil, err
end
- local interface = handleserver( server, addr, port, pattern, listener, sslctx, startssl ) -- new server handler
+ local interface = handleserver( server, addr, port, config.read_size, listener, config.tls_ctx, config.tls_direct) -- new server handler
debug( "new server created with id:", tostring(interface))
return interface
end
+local function addserver( addr, port, listener, pattern, sslctx ) -- TODO: check arguments
+ --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslctx or "nil", startssl or "nil")
+ return listen( addr, port, listener, {
+ read_size = pattern,
+ tls_ctx = sslctx,
+ tls_direct = not not sslctx,
+ });
+end
+
local function wrapclient( client, ip, port, listeners, pattern, sslctx )
local interface = handleclient( client, ip, port, nil, pattern, listeners, sslctx )
interface:_start_connection(sslctx)
@@ -876,6 +909,7 @@ return {
event_base = base,
addevent = newevent,
addserver = addserver,
+ listen = listen,
addclient = addclient,
wrapclient = wrapclient,
setquitting = setquitting,
diff --git a/net/server_select.lua b/net/server_select.lua
index 1a40a6d3..e14c126e 100644
--- a/net/server_select.lua
+++ b/net/server_select.lua
@@ -68,6 +68,7 @@ local idfalse
local closeall
local addsocket
local addserver
+local listen
local addtimer
local getserver
local wrapserver
@@ -123,7 +124,7 @@ local _maxsslhandshake
_server = { } -- key = port, value = table; list of listening servers
_readlist = { } -- array with sockets to read from
-_sendlist = { } -- arrary with sockets to write to
+_sendlist = { } -- array with sockets to write to
_timerlist = { } -- array of timer functions
_socketlist = { } -- key = socket, value = wrapped socket (handlers)
_readtimes = { } -- key = handler, value = timestamp of last data reading
@@ -149,7 +150,7 @@ _checkinterval = 30 -- interval in secs to check idle clients
_sendtimeout = 60000 -- allowed send idle time in secs
_readtimeout = 14 * 60 -- allowed read idle time in secs
-local is_windows = package.config:sub(1,1) == "\\" -- check the directory separator, to detemine whether this is Windows
+local is_windows = package.config:sub(1,1) == "\\" -- check the directory separator, to determine whether this is Windows
_maxfd = (is_windows and math.huge) or luasocket._SETSIZE or 1024 -- max fd number, limit to 1024 by default to prevent glibc buffer overflow, but not on Windows
_maxselectlen = luasocket._SETSIZE or 1024 -- But this still applies on Windows
@@ -157,7 +158,7 @@ _maxsslhandshake = 30 -- max handshake round-trips
----------------------------------// PRIVATE //--
-wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- this function wraps a server -- FIXME Make sure FD < _maxfd
+wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, ssldirect ) -- this function wraps a server -- FIXME Make sure FD < _maxfd
if socket:getfd() >= _maxfd then
out_error("server.lua: Disallowed FD number: "..socket:getfd())
@@ -183,6 +184,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- t
handler.sslctx = function( )
return sslctx
end
+ handler.hosts = {} -- sni
handler.remove = function( )
connections = connections - 1
if handler then
@@ -244,13 +246,13 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- t
local client, err = accept( socket ) -- try to accept
if client then
local ip, clientport = client:getpeername( )
- local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx ) -- wrap new client socket
+ local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx, ssldirect ) -- wrap new client socket
if err then -- error while wrapping ssl socket
return false
end
connections = connections + 1
out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport))
- if dispatch and not sslctx then -- SSL connections will notify onconnect when handshake completes
+ if dispatch and not ssldirect then -- SSL connections will notify onconnect when handshake completes
return dispatch( handler );
end
return;
@@ -264,7 +266,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- t
return handler
end
-wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx ) -- this function wraps a client to a handler object
+wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx, ssldirect ) -- this function wraps a client to a handler object
if socket:getfd() >= _maxfd then
out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent
@@ -424,9 +426,8 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
bufferlen = bufferlen + #data
if bufferlen > maxsendlen then
_closelist[ handler ] = "send buffer exceeded" -- cannot close the client at the moment, have to wait to the end of the cycle
- handler.write = idfalse -- don't write anymore
return false
- elseif socket and not _sendlist[ socket ] then
+ elseif not nosend and socket and not _sendlist[ socket ] then
_sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
end
bufferqueuelen = bufferqueuelen + 1
@@ -456,49 +457,55 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
maxreadlen = readlen or maxreadlen
return bufferlen, maxreadlen, maxsendlen
end
- --TODO: Deprecate
handler.lock_read = function (self, switch)
+ out_error( "server.lua, lock_read() is deprecated, use pause() and resume()" )
if switch == true then
- local tmp = _readlistlen
- _readlistlen = removesocket( _readlist, socket, _readlistlen )
- _readtimes[ handler ] = nil
- if _readlistlen ~= tmp then
- noread = true
- end
+ return self:pause()
elseif switch == false then
- if noread then
- noread = false
- _readlistlen = addsocket(_readlist, socket, _readlistlen)
- _readtimes[ handler ] = _currenttime
- end
+ return self:resume()
end
return noread
end
handler.pause = function (self)
- return self:lock_read(true);
+ local tmp = _readlistlen
+ _readlistlen = removesocket( _readlist, socket, _readlistlen )
+ _readtimes[ handler ] = nil
+ if _readlistlen ~= tmp then
+ noread = true
+ end
+ return noread;
end
handler.resume = function (self)
- return self:lock_read(false);
+ if noread then
+ noread = false
+ _readlistlen = addsocket(_readlist, socket, _readlistlen)
+ _readtimes[ handler ] = _currenttime
+ end
+ return noread;
end
handler.lock = function( self, switch )
- handler.lock_read (switch)
+ out_error( "server.lua, lock() is deprecated" )
+ handler.lock_read (self, switch)
if switch == true then
- handler.write = idfalse
- local tmp = _sendlistlen
- _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
- _writetimes[ handler ] = nil
- if _sendlistlen ~= tmp then
- nosend = true
- end
+ handler.pause_writes (self)
elseif switch == false then
- handler.write = write
- if nosend then
- nosend = false
- write( "" )
- end
+ handler.resume_writes (self)
end
return noread, nosend
end
+ handler.pause_writes = function (self)
+ local tmp = _sendlistlen
+ _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
+ _writetimes[ handler ] = nil
+ nosend = true
+ end
+ handler.resume_writes = function (self)
+ nosend = false
+ if bufferlen > 0 then
+ _sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
+ end
+ end
+
local _readbuffer = function( ) -- this function reads data
local buffer, err, part = receive( socket, pattern ) -- receive buffer with "pattern"
if not err or (err == "wantread" or err == "timeout") then -- received something
@@ -619,11 +626,20 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
local oldsocket, err = socket
socket, err = ssl_wrap( socket, sslctx ) -- wrap socket
+
if not socket then
out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") )
return nil, err -- fatal error
end
+ if socket.sni then
+ if self.servername then
+ socket:sni(self.servername);
+ elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then
+ socket:sni(self.server().hosts, true);
+ end
+ end
+
socket:settimeout( 0 )
-- add the new socket to our system
@@ -659,7 +675,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
_socketlist[ socket ] = handler
_readlistlen = addsocket(_readlist, socket, _readlistlen)
- if sslctx and has_luasec then
+ if sslctx and ssldirect and has_luasec then
out_put "server.lua: auto-starting ssl negotiation..."
handler.autostart_ssl = true;
local ok, err = handler:starttls(sslctx);
@@ -734,9 +750,13 @@ end
----------------------------------// PUBLIC //--
-addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server
+listen = function ( addr, port, listeners, config )
addr = addr or "*"
+ config = config or {}
local err
+ local sslctx = config.tls_ctx;
+ local ssldirect = config.tls_direct;
+ local pattern = config.read_size;
if type( listeners ) ~= "table" then
err = "invalid listener table"
elseif type ( addr ) ~= "string" then
@@ -757,7 +777,7 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function
out_error( "server.lua, [", addr, "]:", port, ": ", err )
return nil, err
end
- local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx ) -- wrap new server socket
+ local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, ssldirect ) -- wrap new server socket
if not handler then
server:close( )
return nil, err
@@ -770,6 +790,14 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function
return handler
end
+addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server
+ return listen(addr, port, listeners, {
+ read_size = pattern;
+ tls_ctx = sslctx;
+ tls_direct = sslctx and true or false;
+ });
+end
+
getserver = function ( addr, port )
return _server[ addr..":"..port ];
end
@@ -978,7 +1006,7 @@ end
--// EXPERIMENTAL //--
local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx )
- local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx )
+ local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx, sslctx)
if not handler then return nil, err end
_socketlist[ socket ] = handler
if not sslctx then
@@ -1114,6 +1142,7 @@ return {
stats = stats,
closeall = closeall,
addserver = addserver,
+ listen = listen,
getserver = getserver,
setlogger = setlogger,
getsettings = getsettings,
diff --git a/net/websocket/frames.lua b/net/websocket/frames.lua
index ba25d261..86752109 100644
--- a/net/websocket/frames.lua
+++ b/net/websocket/frames.lua
@@ -9,20 +9,21 @@
local softreq = require "util.dependencies".softreq;
local random_bytes = require "util.random".bytes;
-local bit = assert(softreq"bit" or softreq"bit32",
+local bit = assert(softreq"bit32" or softreq"bit",
"No bit module found. See https://prosody.im/doc/depends#bitop");
local band = bit.band;
local bor = bit.bor;
local bxor = bit.bxor;
local lshift = bit.lshift;
local rshift = bit.rshift;
+local unpack = table.unpack or unpack; -- luacheck: ignore 113
local t_concat = table.concat;
local s_byte = string.byte;
local s_char= string.char;
local s_sub = string.sub;
-local s_pack = string.pack; -- luacheck: ignore 143
-local s_unpack = string.unpack; -- luacheck: ignore 143
+local s_pack = string.pack;
+local s_unpack = string.unpack;
if not s_pack and softreq"struct" then
s_pack = softreq"struct".pack;