aboutsummaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
Diffstat (limited to 'net')
-rw-r--r--net/adns.lua44
-rw-r--r--net/connect.lua8
-rw-r--r--net/connlisteners.lua18
-rw-r--r--net/cqueues.lua56
-rw-r--r--net/dns.lua72
-rw-r--r--net/http.lua45
-rw-r--r--net/http/codes.lua2
-rw-r--r--net/http/errors.lua119
-rw-r--r--net/http/files.lua149
-rw-r--r--net/http/parser.lua147
-rw-r--r--net/http/server.lua172
-rw-r--r--net/resolvers/basic.lua64
-rw-r--r--net/resolvers/manual.lua1
-rw-r--r--net/resolvers/service.lua29
-rw-r--r--net/server.lua6
-rw-r--r--net/server_epoll.lua376
-rw-r--r--net/server_event.lua55
-rw-r--r--net/server_select.lua144
-rw-r--r--net/unbound.lua220
-rw-r--r--net/websocket.lua7
-rw-r--r--net/websocket/frames.lua7
21 files changed, 1320 insertions, 421 deletions
diff --git a/net/adns.lua b/net/adns.lua
index 0bdf6ee3..ae168b9c 100644
--- a/net/adns.lua
+++ b/net/adns.lua
@@ -8,13 +8,17 @@
local server = require "net.server";
local new_resolver = require "net.dns".resolver;
+local promise = require "util.promise";
local log = require "util.logger".init("adns");
-local coroutine, tostring, pcall = coroutine, tostring, pcall;
+log("debug", "Using legacy DNS API (missing lua-unbound?)"); -- TODO write docs about luaunbound
+-- TODO Raise log level once packages are available
+
+local coroutine, pcall = coroutine, pcall;
local setmetatable = setmetatable;
-local function dummy_send(sock, data, i, j) return (j-i)+1; end
+local function dummy_send(sock, data, i, j) return (j-i)+1; end -- luacheck: ignore 212
local _ENV = nil;
-- luacheck: std none
@@ -29,8 +33,7 @@ local function new_async_socket(sock, resolver)
local peername = "<unknown>";
local listener = {};
local handler = {};
- local err;
- function listener.onincoming(conn, data)
+ function listener.onincoming(conn, data) -- luacheck: ignore 212/conn
if data then
resolver:feed(handler, data);
end
@@ -40,15 +43,18 @@ local function new_async_socket(sock, resolver)
log("warn", "DNS socket for %s disconnected: %s", peername, err);
local servers = resolver.server;
if resolver.socketset[conn] == resolver.best_server and resolver.best_server == #servers then
- log("error", "Exhausted all %d configured DNS servers, next lookup will try %s again", #servers, servers[1]);
+ log("warn", "Exhausted all %d configured DNS servers, next lookup will try %s again", #servers, servers[1]);
end
resolver:servfail(conn); -- Let the magic commence
end
end
- handler, err = server.wrapclient(sock, "dns", 53, listener);
- if not handler then
- return nil, err;
+ do
+ local err;
+ handler, err = server.wrapclient(sock, "dns", 53, listener);
+ if not handler then
+ return nil, err;
+ end
end
if handler.set then
-- server_epoll: only watch for incoming data
@@ -76,11 +82,11 @@ function async_resolver_methods:lookup(handler, qname, qtype, qclass)
handler(peek);
return;
end
- log("debug", "Records for %s not in cache, sending query (%s)...", qname, tostring(coroutine.running()));
+ log("debug", "Records for %s not in cache, sending query (%s)...", qname, coroutine.running());
local ok, err = resolver:query(qname, qtype, qclass);
if ok then
coroutine.yield(setmetatable({ resolver, qclass or "IN", qtype or "A", qname, coroutine.running()}, query_mt)); -- Wait for reply
- log("debug", "Reply for %s (%s)", qname, tostring(coroutine.running()));
+ log("debug", "Reply for %s (%s)", qname, coroutine.running());
end
if ok then
ok, err = pcall(handler, resolver:peek(qname, qtype, qclass));
@@ -89,13 +95,25 @@ function async_resolver_methods:lookup(handler, qname, qtype, qclass)
ok, err = pcall(handler, nil, err);
end
if not ok then
- log("error", "Error in DNS response handler: %s", tostring(err));
+ log("error", "Error in DNS response handler: %s", err);
end
end)(resolver:peek(qname, qtype, qclass));
end
-function query_methods:cancel(call_handler, reason)
- log("warn", "Cancelling DNS lookup for %s", tostring(self[4]));
+function async_resolver_methods:lookup_promise(qname, qtype, qclass)
+ return promise.new(function (resolve, reject)
+ local function handler(answer)
+ if not answer then
+ return reject();
+ end
+ resolve(answer);
+ end
+ self:lookup(handler, qname, qtype, qclass);
+ end);
+end
+
+function query_methods:cancel(call_handler, reason) -- luacheck: ignore 212/reason
+ log("warn", "Cancelling DNS lookup for %s", self[4]);
self[1].cancel(self[2], self[3], self[4], self[5], call_handler);
end
diff --git a/net/connect.lua b/net/connect.lua
index b812ffcd..d52d3901 100644
--- a/net/connect.lua
+++ b/net/connect.lua
@@ -2,6 +2,12 @@ local server = require "net.server";
local log = require "util.logger".init("net.connect");
local new_id = require "util.id".short;
+-- TODO #1246 Happy Eyeballs
+-- FIXME RFC 6724
+-- FIXME Error propagation from resolvers doesn't work
+-- FIXME #1428 Reuse DNS resolver object between service and basic resolver
+-- FIXME #1429 Close DNS resolver object when done
+
local pending_connection_methods = {};
local pending_connection_mt = {
__name = "pending_connection";
@@ -38,7 +44,7 @@ local function attempt_connection(p)
p:log("debug", "Next target to try is %s:%d", ip, port);
local conn, err = server.addclient(ip, port, pending_connection_listeners, p.options.pattern or "*a", p.options.sslctx, conn_type, extra);
if not conn then
- log("debug", "Connection attempt failed immediately: %s", tostring(err));
+ log("debug", "Connection attempt failed immediately: %s", err);
p.last_error = err or "unknown reason";
return attempt_connection(p);
end
diff --git a/net/connlisteners.lua b/net/connlisteners.lua
deleted file mode 100644
index 9b8f88c3..00000000
--- a/net/connlisteners.lua
+++ /dev/null
@@ -1,18 +0,0 @@
--- COMPAT w/pre-0.9
-local log = require "util.logger".init("net.connlisteners");
-local traceback = debug.traceback;
-
-local _ENV = nil;
--- luacheck: std none
-
-local function fail()
- log("error", "Attempt to use legacy connlisteners API. For more info see https://prosody.im/doc/developers/network");
- log("error", "Legacy connlisteners API usage, %s", traceback("", 2));
-end
-
-return {
- register = fail;
- get = fail;
- start = fail;
- -- epic fail
-};
diff --git a/net/cqueues.lua b/net/cqueues.lua
index 8c4c756f..65d2a019 100644
--- a/net/cqueues.lua
+++ b/net/cqueues.lua
@@ -9,6 +9,7 @@
local server = require "net.server";
local cqueues = require "cqueues";
+local timer = require "util.timer";
assert(cqueues.VERSION >= 20150113, "cqueues newer than 20150113 required")
-- Create a single top level cqueue
@@ -16,55 +17,24 @@ local cq;
if server.cq then -- server provides cqueues object
cq = server.cq;
-elseif server.get_backend() == "select" and server._addtimer then -- server_select
+elseif server.watchfd then
cq = cqueues.new();
- local function step()
+ local timeout = timer.add_task(cq:timeout() or 0, function ()
+ -- FIXME It should be enough to reschedule this timeout instead of replacing it, but this does not work. See https://issues.prosody.im/1572
assert(cq:loop(0));
- end
-
- -- Use wrapclient (as wrapconnection isn't exported) to get server_select to watch cq fd
- local handler = server.wrapclient({
- getfd = function() return cq:pollfd(); end;
- settimeout = function() end; -- Method just needs to exist
- close = function() end; -- Need close method for 'closeall'
- }, nil, nil, {});
-
- -- Only need to listen for readable; cqueues handles everything under the hood
- -- readbuffer is called when `select` notes an fd as readable
- handler.readbuffer = step;
-
- -- Use server_select low lever timer facility,
- -- this callback gets called *every* time there is a timeout in the main loop
- server._addtimer(function(current_time)
- -- This may end up in extra step()'s, but cqueues handles it for us.
- step();
return cq:timeout();
end);
-elseif server.event and server.base then -- server_event
- cq = cqueues.new();
- -- Only need to listen for readable; cqueues handles everything under the hood
- local EV_READ = server.event.EV_READ;
- -- Convert a cqueues timeout to an acceptable timeout for luaevent
- local function luaevent_safe_timeout(cq)
+ server.watchfd(cq:pollfd(), function ()
+ assert(cq:loop(0));
local t = cq:timeout();
- -- if you give luaevent 0 or nil, it re-uses the previous timeout.
- if t == 0 then
- t = 0.000001; -- 1 microsecond is the smallest that works (goes into a `struct timeval`)
- elseif t == nil then -- pick something big if we don't have one
- t = 0x7FFFFFFF; -- largest 32bit int
+ if t then
+ timer.stop(timeout);
+ timeout = timer.add_task(cq:timeout(), function ()
+ assert(cq:loop(0));
+ return cq:timeout();
+ end);
end
- return t
- end
- local event_handle;
- event_handle = server.base:addevent(cq:pollfd(), EV_READ, function(e)
- -- Need to reference event_handle or this callback will get collected
- -- This creates a circular reference that can only be broken if event_handle is manually :close()'d
- local _ = event_handle;
- -- Run as many cqueues things as possible (with a timeout of 0)
- -- If an error is thrown, it will break the libevent loop; but prosody resumes after logging a top level error
- assert(cq:loop(0));
- return EV_READ, luaevent_safe_timeout(cq);
- end, luaevent_safe_timeout(cq));
+ end);
else
error "NYI"
end
diff --git a/net/dns.lua b/net/dns.lua
index 3902f95c..17119152 100644
--- a/net/dns.lua
+++ b/net/dns.lua
@@ -13,10 +13,12 @@
local socket = require "socket";
-local timer = require "util.timer";
+local have_timer, timer = pcall(require, "util.timer");
local new_ip = require "util.ip".new_ip;
local have_util_net, util_net = pcall(require, "util.net");
+local log = require "util.logger".init("dns");
+
local _, windows = pcall(require, "util.windows");
local is_windows = (_ and windows) or os.getenv("WINDIR");
@@ -69,7 +71,9 @@ local ztact = { -- public domain 20080404 lua@ztact.com
};
local get, set = ztact.get, ztact.set;
-local default_timeout = 15;
+local default_timeout = 5;
+local default_jitter = 1;
+local default_retry_jitter = 2;
-------------------------------------------------- module dns
local _ENV = nil;
@@ -664,8 +668,10 @@ end
-- socket layer -------------------------------------------------- socket layer
-resolver.delays = { 1, 3 };
+resolver.delays = { 1, 2, 3, 5 };
+resolver.jitter = have_timer and default_jitter or nil;
+resolver.retry_jitter = have_timer and default_retry_jitter or nil;
function resolver:addnameserver(address) -- - - - - - - - - - addnameserver
self.server = self.server or {};
@@ -853,7 +859,10 @@ function resolver:query(qname, qtype, qclass) -- - - - - - - - - - -- query
packet = header..question,
server = self.best_server,
delay = 1,
- retry = socket.gettime() + self.delays[1]
+ retry = socket.gettime() + self.delays[1];
+ qclass = qclass;
+ qtype = qtype;
+ qname = qname;
};
-- remember the query
@@ -864,30 +873,32 @@ function resolver:query(qname, qtype, qclass) -- - - - - - - - - - -- query
if not conn then
return nil, err;
end
- conn:send (o.packet)
+ if self.jitter then
+ timer.add_task(math.random()*self.jitter, function ()
+ conn:send(o.packet);
+ end);
+ else
+ conn:send(o.packet);
+ end
-- remember which coroutine wants the answer
if co then
set(self.wanted, qclass, qtype, qname, co, true);
end
- if timer and self.timeout then
+ if have_timer and self.timeout then
local num_servers = #self.server;
local i = 1;
timer.add_task(self.timeout, function ()
if get(self.wanted, qclass, qtype, qname, co) then
- if i < num_servers then
+ log("debug", "DNS request timeout %d/%d", i, num_servers)
i = i + 1;
- self:servfail(conn);
- o.server = self.best_server;
- conn, err = self:getsocket(o.server);
- if conn then
- conn:send(o.packet);
- return self.timeout;
- end
- end
- -- Tried everything, failed
- self:cancel(qclass, qtype, qname);
+ self:servfail(self.socket[o.server]);
+-- end
+ end
+ -- Still outstanding? (i.e. retried)
+ if get(self.wanted, qclass, qtype, qname, co) then
+ return self.timeout; -- Then wait
end
end)
end
@@ -904,6 +915,7 @@ function resolver:servfail(sock, err)
-- Find all requests to the down server, and retry on the next server
self.time = socket.gettime();
+ log("debug", "servfail %d (of %d)", num, #self.server);
for id,queries in pairs(self.active) do
for question,o in pairs(queries) do
if o.server == num then -- This request was to the broken server
@@ -913,12 +925,27 @@ function resolver:servfail(sock, err)
end
o.retries = (o.retries or 0) + 1;
- if o.retries >= #self.server then
- --print('timeout');
- queries[question] = nil;
- else
+ local retried;
+ if o.retries < #self.server then
sock, err = self:getsocket(o.server);
- if sock then sock:send(o.packet); end
+ if sock then
+ retried = true;
+ if self.retry_jitter then
+ local delay = self.delays[((o.retries-1)%#self.delays)+1] + (math.random()*self.retry_jitter);
+ log("debug", "retry %d in %0.2fs", o.retries, delay);
+ timer.add_task(delay, function ()
+ sock:send(o.packet);
+ end);
+ else
+ log("debug", "retry %d (immediate)", o.retries);
+ sock:send(o.packet);
+ end
+ end
+ end
+ if not retried then
+ log("debug", 'tried all servers, giving up');
+ self:cancel(o.qclass, o.qtype, o.qname);
+ queries[question] = nil;
end
end
end
@@ -1164,6 +1191,7 @@ end
local _resolver = dns.resolver();
dns._resolver = _resolver;
+_resolver.jitter, _resolver.retry_jitter = false, false;
function dns.lookup(...) -- - - - - - - - - - - - - - - - - - - - - lookup
return _resolver:lookup(...);
diff --git a/net/http.lua b/net/http.lua
index 7335d210..3481389a 100644
--- a/net/http.lua
+++ b/net/http.lua
@@ -12,6 +12,8 @@ local httpstream_new = require "net.http.parser".new;
local util_http = require "util.http";
local events = require "util.events";
local verify_identity = require"util.x509".verify_identity;
+local promise = require "util.promise";
+local http_errors = require "net.http.errors";
local basic_resolver = require "net.resolvers.basic";
local connect = require "net.connect".connect;
@@ -22,6 +24,7 @@ local t_insert, t_concat = table.insert, table.concat;
local pairs = pairs;
local tonumber, tostring, traceback =
tonumber, tostring, debug.traceback;
+local os_time = os.time;
local xpcall = require "util.xpcall".xpcall;
local error = error
@@ -40,7 +43,7 @@ local listener = { default_port = 80, default_mode = "*a" };
local function handleerr(err) log("error", "Traceback[http]: %s", traceback(tostring(err), 2)); return err; end
local function log_if_failed(req, ret, ...)
if not ret then
- log("error", "Request '%s': error in callback: %s", req.id, tostring((...)));
+ log("error", "Request '%s': error in callback: %s", req.id, (...));
if not req.suppress_errors then
error(...);
end
@@ -81,7 +84,24 @@ local function request_reader(request, data, err)
return;
end
+ local finalize_sink;
local function success_cb(r)
+ if r.partial then
+ -- Request should be streamed
+ log("debug", "Request '%s': partial response (%s%s)",
+ request.id,
+ r.chunked and "chunked, " or "",
+ r.body_length and ("%d bytes"):format(r.body_length) or "unknown length"
+ );
+ if request.streaming_handler then
+ log("debug", "Request '%s': Streaming via handler");
+ r.body_sink, finalize_sink = request.streaming_handler(r);
+ end
+ return;
+ elseif finalize_sink then
+ log("debug", "Request '%s': Finalizing response stream");
+ finalize_sink(r);
+ end
if request.callback then
request.callback(r.body, r.code, r, request);
request.callback = nil;
@@ -161,7 +181,7 @@ function listener.onincoming(conn, data)
local request = requests[conn];
if not request then
- log("warn", "Received response from connection %s with no request attached!", tostring(conn));
+ log("warn", "Received response from connection %s with no request attached!", conn);
return;
end
@@ -202,6 +222,7 @@ local function request(self, u, ex, callback)
req.url = u;
req.http = self;
+ req.time = os_time();
if not req.path then
req.path = "/";
@@ -254,6 +275,7 @@ local function request(self, u, ex, callback)
end
req.insecure = ex.insecure;
req.suppress_errors = ex.suppress_errors;
+ req.streaming_handler = ex.streaming_handler;
end
log("debug", "Making %s %s request '%s' to %s", req.scheme:upper(), method or "GET", req.id, (ex and ex.suppress_url and host_header) or u);
@@ -282,7 +304,22 @@ end
local function new(options)
local http = {
options = options;
- request = request;
+ request = function (self, u, ex, callback)
+ if callback ~= nil then
+ return request(self, u, ex, callback);
+ else
+ return promise.new(function (resolve, reject)
+ request(self, u, ex, function (body, code, a, b)
+ if code == 0 then
+ reject(http_errors.new(body, { request = a }));
+ else
+ a.request = b;
+ resolve(a);
+ end
+ end);
+ end);
+ end
+ end;
new = options and function (new_options)
local final_options = {};
for k, v in pairs(options) do final_options[k] = v; end
@@ -297,7 +334,7 @@ local function new(options)
end
local default_http = new({
- sslctx = { mode = "client", protocol = "sslv23", options = { "no_sslv2", "no_sslv3" } };
+ sslctx = { mode = "client", protocol = "sslv23", options = { "no_sslv2", "no_sslv3" }, alpn = "http/1.1" };
suppress_errors = true;
});
diff --git a/net/http/codes.lua b/net/http/codes.lua
index 8098b5c3..4327f151 100644
--- a/net/http/codes.lua
+++ b/net/http/codes.lua
@@ -82,5 +82,5 @@ local response_codes = {
-- [512-599] = "Unassigned";
};
-for k,v in pairs(response_codes) do response_codes[k] = k.." "..v; end
+for k,v in pairs(response_codes) do response_codes[k] = ("%03d %s"):format(k, v); end
return setmetatable(response_codes, { __index = function(_, k) return k.." Unassigned"; end })
diff --git a/net/http/errors.lua b/net/http/errors.lua
new file mode 100644
index 00000000..1691e426
--- /dev/null
+++ b/net/http/errors.lua
@@ -0,0 +1,119 @@
+-- This module returns a table that is suitable for use as a util.error registry,
+-- and a function to return a util.error object given callback 'code' and 'body'
+-- parameters.
+
+local codes = require "net.http.codes";
+local util_error = require "util.error";
+
+local error_templates = {
+ -- This code is used by us to report a client-side or connection error.
+ -- Instead of using the code, use the supplied body text to get one of
+ -- the more detailed errors below.
+ [0] = {
+ code = 0, type = "cancel", condition = "internal-server-error";
+ text = "Connection or internal error";
+ };
+
+ -- These are net.http built-in errors, they are returned in
+ -- the body parameter when code == 0
+ ["cancelled"] = {
+ code = 0, type = "cancel", condition = "remote-server-timeout";
+ text = "Request cancelled";
+ };
+ ["connection-closed"] = {
+ code = 0, type = "wait", condition = "remote-server-timeout";
+ text = "Connection closed";
+ };
+ ["certificate-chain-invalid"] = {
+ code = 0, type = "cancel", condition = "remote-server-timeout";
+ text = "Server certificate not trusted";
+ };
+ ["certificate-verify-failed"] = {
+ code = 0, type = "cancel", condition = "remote-server-timeout";
+ text = "Server certificate invalid";
+ };
+ ["connection failed"] = {
+ code = 0, type = "cancel", condition = "remote-server-not-found";
+ text = "Connection failed";
+ };
+ ["invalid-url"] = {
+ code = 0, type = "modify", condition = "bad-request";
+ text = "Invalid URL";
+ };
+ ["unable to resolve service"] = {
+ code = 0, type = "cancel", condition = "remote-server-not-found";
+ text = "DNS resolution failed";
+ };
+
+ -- This doesn't attempt to map every single HTTP code (not all have sane mappings),
+ -- but all the common ones should be covered. XEP-0086 was used as reference for
+ -- most of these.
+ [400] = { type = "modify", condition = "bad-request" };
+ [401] = { type = "auth", condition = "not-authorized" };
+ [402] = { type = "auth", condition = "payment-required" };
+ [403] = { type = "auth", condition = "forbidden" };
+ [404] = { type = "cancel", condition = "item-not-found" };
+ [405] = { type = "cancel", condition = "not-allowed" };
+ [406] = { type = "modify", condition = "not-acceptable" };
+ [407] = { type = "auth", condition = "registration-required" };
+ [408] = { type = "wait", condition = "remote-server-timeout" };
+ [409] = { type = "cancel", condition = "conflict" };
+ [410] = { type = "cancel", condition = "gone" };
+ [411] = { type = "modify", condition = "bad-request" };
+ [412] = { type = "cancel", condition = "conflict" };
+ [413] = { type = "modify", condition = "resource-constraint" };
+ [414] = { type = "modify", condition = "resource-constraint" };
+ [415] = { type = "cancel", condition = "feature-not-implemented" };
+ [416] = { type = "modify", condition = "bad-request" };
+
+ [422] = { type = "modify", condition = "bad-request" };
+ [423] = { type = "wait", condition = "resource-constraint" };
+
+ [429] = { type = "wait", condition = "resource-constraint" };
+ [431] = { type = "modify", condition = "resource-constraint" };
+ [451] = { type = "auth", condition = "forbidden" };
+
+ [500] = { type = "wait", condition = "internal-server-error" };
+ [501] = { type = "cancel", condition = "feature-not-implemented" };
+ [502] = { type = "wait", condition = "remote-server-timeout" };
+ [503] = { type = "cancel", condition = "service-unavailable" };
+ [504] = { type = "wait", condition = "remote-server-timeout" };
+ [507] = { type = "wait", condition = "resource-constraint" };
+ [511] = { type = "auth", condition = "not-authorized" };
+};
+
+for k, v in pairs(codes) do
+ if error_templates[k] then
+ error_templates[k].code = k;
+ error_templates[k].text = v;
+ else
+ error_templates[k] = { type = "cancel", condition = "undefined-condition", text = v, code = k };
+ end
+end
+
+setmetatable(error_templates, {
+ __index = function(_, k)
+ if type(k) ~= "number" then
+ return nil;
+ end
+ return {
+ type = "cancel";
+ condition = "undefined-condition";
+ text = codes[k] or (k.." Unassigned");
+ code = k;
+ };
+ end
+});
+
+local function new(code, body, context)
+ if code == 0 then
+ return util_error.new(body, context, error_templates);
+ else
+ return util_error.new(code, context, error_templates);
+ end
+end
+
+return {
+ registry = error_templates;
+ new = new;
+};
diff --git a/net/http/files.lua b/net/http/files.lua
new file mode 100644
index 00000000..583f7514
--- /dev/null
+++ b/net/http/files.lua
@@ -0,0 +1,149 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+local server = require"net.http.server";
+local lfs = require "lfs";
+local new_cache = require "util.cache".new;
+local log = require "util.logger".init("net.http.files");
+
+local os_date = os.date;
+local open = io.open;
+local stat = lfs.attributes;
+local build_path = require"socket.url".build_path;
+local path_sep = package.config:sub(1,1);
+
+
+local forbidden_chars_pattern = "[/%z]";
+if package.config:sub(1,1) == "\\" then
+ forbidden_chars_pattern = "[/%z\001-\031\127\"*:<>?|]"
+end
+
+local urldecode = require "util.http".urldecode;
+local function sanitize_path(path) --> util.paths or util.http?
+ if not path then return end
+ local out = {};
+
+ local c = 0;
+ for component in path:gmatch("([^/]+)") do
+ component = urldecode(component);
+ if component:find(forbidden_chars_pattern) then
+ return nil;
+ elseif component == ".." then
+ if c <= 0 then
+ return nil;
+ end
+ out[c] = nil;
+ c = c - 1;
+ elseif component ~= "." then
+ c = c + 1;
+ out[c] = component;
+ end
+ end
+ if path:sub(-1,-1) == "/" then
+ out[c+1] = "";
+ end
+ return "/"..table.concat(out, "/");
+end
+
+local function serve(opts)
+ if type(opts) ~= "table" then -- assume path string
+ opts = { path = opts };
+ end
+ local mime_map = opts.mime_map or { html = "text/html" };
+ local cache = new_cache(opts.cache_size or 256);
+ local cache_max_file_size = tonumber(opts.cache_max_file_size) or 1024
+ -- luacheck: ignore 431
+ local base_path = opts.path;
+ local dir_indices = opts.index_files or { "index.html", "index.htm" };
+ local directory_index = opts.directory_index;
+ local function serve_file(event, path)
+ local request, response = event.request, event.response;
+ local sanitized_path = sanitize_path(path);
+ if path and not sanitized_path then
+ return 400;
+ end
+ path = sanitized_path;
+ local orig_path = sanitize_path(request.path);
+ local full_path = base_path .. (path or ""):gsub("/", path_sep);
+ local attr = stat(full_path:match("^.*[^\\/]")); -- Strip trailing path separator because Windows
+ if not attr then
+ return 404;
+ end
+
+ local request_headers, response_headers = request.headers, response.headers;
+
+ local last_modified = os_date('!%a, %d %b %Y %H:%M:%S GMT', attr.modification);
+ response_headers.last_modified = last_modified;
+
+ local etag = ('"%x-%x-%x"'):format(attr.change or 0, attr.size or 0, attr.modification or 0);
+ response_headers.etag = etag;
+
+ local if_none_match = request_headers.if_none_match
+ local if_modified_since = request_headers.if_modified_since;
+ if etag == if_none_match
+ or (not if_none_match and last_modified == if_modified_since) then
+ return 304;
+ end
+
+ local data = cache:get(orig_path);
+ if data and data.etag == etag then
+ response_headers.content_type = data.content_type;
+ data = data.data;
+ cache:set(orig_path, data);
+ elseif attr.mode == "directory" and path then
+ if full_path:sub(-1) ~= "/" then
+ local dir_path = { is_absolute = true, is_directory = true };
+ for dir in orig_path:gmatch("[^/]+") do dir_path[#dir_path+1]=dir; end
+ response_headers.location = build_path(dir_path);
+ return 301;
+ end
+ for i=1,#dir_indices do
+ if stat(full_path..dir_indices[i], "mode") == "file" then
+ return serve_file(event, path..dir_indices[i]);
+ end
+ end
+
+ if directory_index then
+ data = server._events.fire_event("directory-index", { path = request.path, full_path = full_path });
+ end
+ if not data then
+ return 403;
+ end
+ cache:set(orig_path, { data = data, content_type = mime_map.html; etag = etag; });
+ response_headers.content_type = mime_map.html;
+
+ else
+ local f, err = open(full_path, "rb");
+ if not f then
+ log("debug", "Could not open %s. Error was %s", full_path, err);
+ return 403;
+ end
+ local ext = full_path:match("%.([^./]+)$");
+ local content_type = ext and mime_map[ext];
+ response_headers.content_type = content_type;
+ if attr.size > cache_max_file_size then
+ response_headers.content_length = ("%d"):format(attr.size);
+ log("debug", "%d > cache_max_file_size", attr.size);
+ return response:send_file(f);
+ else
+ data = f:read("*a");
+ f:close();
+ end
+ cache:set(orig_path, { data = data; content_type = content_type; etag = etag });
+ end
+
+ return response:send(data);
+ end
+
+ return serve_file;
+end
+
+return {
+ serve = serve;
+}
+
diff --git a/net/http/parser.lua b/net/http/parser.lua
index 4e4ae9fb..96f17fdb 100644
--- a/net/http/parser.lua
+++ b/net/http/parser.lua
@@ -1,8 +1,8 @@
local tonumber = tonumber;
local assert = assert;
-local t_insert, t_concat = table.insert, table.concat;
local url_parse = require "socket.url".parse;
local urldecode = require "util.http".urldecode;
+local dbuffer = require "util.dbuffer";
local function preprocess_path(path)
path = urldecode((path:gsub("//+", "/")));
@@ -28,10 +28,13 @@ local httpstream = {};
function httpstream.new(success_cb, error_cb, parser_type, options_cb)
local client = true;
if not parser_type or parser_type == "server" then client = false; else assert(parser_type == "client", "Invalid parser type"); end
- local buf, buflen, buftable = {}, 0, true;
local bodylimit = tonumber(options_cb and options_cb().body_size_limit) or 10*1024*1024;
+ -- https://stackoverflow.com/a/686243
+ -- Indiviual headers can be up to 16k? What madness?
+ local headlimit = tonumber(options_cb and options_cb().head_size_limit) or 10*1024;
local buflimit = tonumber(options_cb and options_cb().buffer_size_limit) or bodylimit * 2;
- local chunked, chunk_size, chunk_start;
+ local buffer = dbuffer.new(buflimit);
+ local chunked;
local state = nil;
local packet;
local len;
@@ -41,32 +44,27 @@ function httpstream.new(success_cb, error_cb, parser_type, options_cb)
feed = function(_, data)
if error then return nil, "parse has failed"; end
if not data then -- EOF
- if buftable then buf, buftable = t_concat(buf), false; end
if state and client and not len then -- reading client body until EOF
- packet.body = buf;
+ buffer:collapse();
+ packet.body = buffer:read_chunk() or "";
+ packet.partial = nil;
success_cb(packet);
- elseif buf ~= "" then -- unexpected EOF
+ state = nil;
+ elseif buffer:length() ~= 0 then -- unexpected EOF
error = true; return error_cb("unexpected-eof");
end
return;
end
- if buftable then
- t_insert(buf, data);
- else
- buf = { buf, data };
- buftable = true;
- end
- buflen = buflen + #data;
- if buflen > buflimit then error = true; return error_cb("max-buffer-size-exceeded"); end
- while buflen > 0 do
+ if not buffer:write(data) then error = true; return error_cb("max-buffer-size-exceeded"); end
+ while buffer:length() > 0 do
if state == nil then -- read request
- if buftable then buf, buftable = t_concat(buf), false; end
- local index = buf:find("\r\n\r\n", nil, true);
+ local index = buffer:sub(1, headlimit):find("\r\n\r\n", nil, true);
if not index then return; end -- not enough data
- local method, path, httpversion, status_code, reason_phrase;
+ -- FIXME was reason_phrase meant to be passed on somewhere?
+ local method, path, httpversion, status_code, reason_phrase; -- luacheck: ignore reason_phrase
local first_line;
local headers = {};
- for line in buf:sub(1,index+1):gmatch("([^\r\n]+)\r\n") do -- parse request
+ for line in buffer:read(index+3):gmatch("([^\r\n]+)\r\n") do -- parse request
if first_line then
local key, val = line:match("^([^%s:]+): *(.*)$");
if not key then error = true; return error_cb("invalid-header-line"); end -- TODO handle multi-line and invalid headers
@@ -91,7 +89,6 @@ function httpstream.new(success_cb, error_cb, parser_type, options_cb)
if not first_line then error = true; return error_cb("invalid-status-line"); end
chunked = have_body and headers["transfer-encoding"] == "chunked";
len = tonumber(headers["content-length"]); -- TODO check for invalid len
- if len and len > bodylimit then error = true; return error_cb("content-length-limit-exceeded"); end
if client then
-- FIXME handle '100 Continue' response (by skipping it)
if not have_body then len = 0; end
@@ -99,7 +96,10 @@ function httpstream.new(success_cb, error_cb, parser_type, options_cb)
code = status_code;
httpversion = httpversion;
headers = headers;
- body = have_body and "" or nil;
+ body = false;
+ body_length = len;
+ chunked = chunked;
+ partial = true;
-- COMPAT the properties below are deprecated
responseversion = httpversion;
responseheaders = headers;
@@ -124,60 +124,81 @@ function httpstream.new(success_cb, error_cb, parser_type, options_cb)
path = path;
httpversion = httpversion;
headers = headers;
- body = nil;
+ body = false;
+ body_sink = nil;
+ chunked = chunked;
+ partial = true;
};
end
- buf = buf:sub(index + 4);
- buflen = #buf;
+ if len and len > bodylimit then
+ -- Early notification, for redirection
+ success_cb(packet);
+ if not packet.body_sink then error = true; return error_cb("content-length-limit-exceeded"); end
+ end
+ if chunked and not packet.body_sink then
+ success_cb(packet);
+ if not packet.body_sink then
+ packet.body_buffer = dbuffer.new(buflimit);
+ end
+ end
state = true;
end
if state then -- read body
- if client then
- if chunked then
- if chunk_start and buflen - chunk_start - 2 < chunk_size then
- return;
- end -- not enough data
- if buftable then buf, buftable = t_concat(buf), false; end
- if not buf:find("\r\n", nil, true) then
- return;
- end -- not enough data
- if not chunk_size then
- chunk_size, chunk_start = buf:match("^(%x+)[^\r\n]*\r\n()");
- chunk_size = chunk_size and tonumber(chunk_size, 16);
- if not chunk_size then error = true; return error_cb("invalid-chunk-size"); end
- end
- if chunk_size == 0 and buf:find("\r\n\r\n", chunk_start-2, true) then
- state, chunk_size = nil, nil;
- buf = buf:gsub("^.-\r\n\r\n", ""); -- This ensure extensions and trailers are stripped
- success_cb(packet);
- elseif buflen - chunk_start - 2 >= chunk_size then -- we have a chunk
- packet.body = packet.body..buf:sub(chunk_start, chunk_start + (chunk_size-1));
- buf = buf:sub(chunk_start + chunk_size + 2);
- buflen = buflen - (chunk_start + chunk_size + 2 - 1);
- chunk_size, chunk_start = nil, nil;
- else -- Partial chunk remaining
- break;
+ if chunked then
+ local chunk_header = buffer:sub(1, 512); -- XXX How large do chunk headers grow?
+ local chunk_size, chunk_start = chunk_header:match("^(%x+)[^\r\n]*\r\n()");
+ if not chunk_size then return; end
+ chunk_size = chunk_size and tonumber(chunk_size, 16);
+ if not chunk_size then error = true; return error_cb("invalid-chunk-size"); end
+ if chunk_size == 0 and chunk_header:find("\r\n\r\n", chunk_start-2, true) then
+ local body_buffer = packet.body_buffer;
+ if body_buffer then
+ packet.body_buffer = nil;
+ body_buffer:collapse();
+ packet.body = body_buffer:read_chunk() or "";
end
- elseif len and buflen >= len then
- if buftable then buf, buftable = t_concat(buf), false; end
- if packet.code == 101 then
- packet.body, buf, buflen, buftable = buf, {}, 0, true;
+
+ buffer:collapse();
+ local buf = buffer:read_chunk();
+ buf = buf:gsub("^.-\r\n\r\n", ""); -- This ensure extensions and trailers are stripped
+ buffer:write(buf);
+ state, chunked = nil, nil;
+ packet.partial = nil;
+ success_cb(packet);
+ elseif buffer:length() - chunk_start - 2 >= chunk_size then -- we have a chunk
+ buffer:discard(chunk_start - 1); -- TODO verify that it's not off-by-one
+ (packet.body_sink or packet.body_buffer):write(buffer:read(chunk_size));
+ buffer:discard(2); -- CRLF
+ else -- Partial chunk remaining
+ break;
+ end
+ elseif packet.body_sink then
+ local chunk = buffer:read_chunk(len);
+ while chunk and len > 0 do
+ if packet.body_sink:write(chunk) then
+ len = len - #chunk;
+ chunk = buffer:read_chunk(len);
else
- packet.body, buf = buf:sub(1, len), buf:sub(len + 1);
- buflen = #buf;
+ error = true;
+ return error_cb("body-sink-write-failure");
end
- state = nil; success_cb(packet);
- else
- break;
end
- elseif buflen >= len then
- if buftable then buf, buftable = t_concat(buf), false; end
- packet.body, buf = buf:sub(1, len), buf:sub(len + 1);
- buflen = #buf;
- state = nil; success_cb(packet);
+ if len == 0 then
+ state = nil;
+ packet.partial = nil;
+ success_cb(packet);
+ end
+ elseif buffer:length() >= len then
+ assert(not chunked)
+ packet.body = buffer:read(len) or "";
+ state = nil;
+ packet.partial = nil;
+ success_cb(packet);
else
break;
end
+ else
+ break;
end
end
end;
diff --git a/net/http/server.lua b/net/http/server.lua
index 3873bbe0..ab71dbc9 100644
--- a/net/http/server.lua
+++ b/net/http/server.lua
@@ -1,5 +1,5 @@
-local t_insert, t_remove, t_concat = table.insert, table.remove, table.concat;
+local t_insert, t_concat = table.insert, table.concat;
local parser_new = require "net.http.parser".new;
local events = require "util.events".new();
local addserver = require "net.server".addserver;
@@ -8,12 +8,12 @@ local os_date = os.date;
local pairs = pairs;
local s_upper = string.upper;
local setmetatable = setmetatable;
-local xpcall = require "util.xpcall".xpcall;
-local traceback = debug.traceback;
-local tostring = tostring;
local cache = require "util.cache";
local codes = require "net.http.codes";
+local promise = require "util.promise";
+local errors = require "util.error";
local blocksize = 2^16;
+local async = require "util.async";
local _M = {};
@@ -89,51 +89,60 @@ setmetatable(events._handlers, {
local handle_request;
-local last_err;
-local function _traceback_handler(err) last_err = err; log("error", "Traceback[httpserver]: %s", traceback(tostring(err), 2)); end
events.add_handler("http-error", function (error)
return "Error processing request: "..codes[error.code]..". Check your error log for more information.";
end, -1);
+local runner_callbacks = {};
+
+function runner_callbacks:ready()
+ self.data.conn:resume();
+end
+
+function runner_callbacks:waiting()
+ self.data.conn:pause();
+end
+
+function runner_callbacks:error(err)
+ log("error", "Traceback[httpserver]: %s", err);
+ self.data.conn:write("HTTP/1.0 500 Internal Server Error\r\n\r\n"..events.fire_event("http-error", { code = 500, private_message = err }));
+ self.data.conn:close();
+end
+
+local function noop() end
function listener.onconnect(conn)
+ local session = { conn = conn };
local secure = conn:ssl() and true or nil;
- local pending = {};
- local waiting = false;
- local function process_next()
- if waiting then return; end -- log("debug", "can't process_next, waiting");
- waiting = true;
- while sessions[conn] and #pending > 0 do
- local request = t_remove(pending);
- --log("debug", "process_next: %s", request.path);
- if not xpcall(handle_request, _traceback_handler, conn, request, process_next) then
- conn:write("HTTP/1.0 500 Internal Server Error\r\n\r\n"..events.fire_event("http-error", { code = 500, private_message = last_err }));
- conn:close();
- end
+ local ip = conn:ip();
+ session.thread = async.runner(function (request)
+ local wait, done;
+ if request.partial == true then
+ -- Have the header for a request, we want to receive the rest
+ -- when we've decided where the data should go.
+ wait, done = noop, noop;
+ else -- Got the entire request
+ -- Hold off on receiving more incoming requests until this one has been handled.
+ wait, done = async.waiter();
end
- --log("debug", "ready for more");
- waiting = false;
- end
+ handle_request(conn, request, done); wait();
+ end, runner_callbacks, session);
local function success_cb(request)
--log("debug", "success_cb: %s", request.path);
- if waiting then
- log("error", "http connection handler is not reentrant: %s", request.path);
- assert(false, "http connection handler is not reentrant");
- end
+ request.ip = ip;
request.secure = secure;
- t_insert(pending, request);
- process_next();
+ session.thread:run(request);
end
local function error_cb(err)
log("debug", "error_cb: %s", err or "<nil>");
-- FIXME don't close immediately, wait until we process current stuff
-- FIXME if err, send off a bad-request response
- sessions[conn] = nil;
conn:close();
end
local function options_cb()
return options;
end
- sessions[conn] = parser_new(success_cb, error_cb, "server", options_cb);
+ session.parser = parser_new(success_cb, error_cb, "server", options_cb);
+ sessions[conn] = session;
end
function listener.ondisconnect(conn)
@@ -152,7 +161,7 @@ function listener.ondetach(conn)
end
function listener.onincoming(conn, data)
- sessions[conn]:feed(data);
+ sessions[conn].parser:feed(data);
end
function listener.ondrain(conn)
@@ -170,6 +179,49 @@ local headerfix = setmetatable({}, {
end
});
+local function handle_result(request, response, result)
+ if result == nil then
+ result = 404;
+ end
+
+ if result == true then
+ return;
+ end
+
+ local body;
+ local result_type = type(result);
+ if result_type == "number" then
+ response.status_code = result;
+ if result >= 400 then
+ body = events.fire_event("http-error", { request = request, response = response, code = result });
+ end
+ elseif result_type == "string" then
+ body = result;
+ elseif errors.is_err(result) then
+ response.status_code = result.code or 500;
+ body = events.fire_event("http-error", { request = request, response = response, code = result.code or 500, error = result });
+ elseif promise.is_promise(result) then
+ result:next(function (ret)
+ handle_result(request, response, ret);
+ end, function (err)
+ response.status_code = 500;
+ handle_result(request, response, err or 500);
+ end);
+ return true;
+ elseif result_type == "table" then
+ for k, v in pairs(result) do
+ if k ~= "headers" then
+ response[k] = v;
+ else
+ for header_name, header_value in pairs(v) do
+ response.headers[header_name] = header_value;
+ end
+ end
+ end
+ end
+ return response:send(body);
+end
+
function _M.hijack_response(response, listener) -- luacheck: ignore
error("TODO");
end
@@ -194,8 +246,11 @@ function handle_request(conn, request, finish_cb)
response_conn_header = httpversion == "1.1" and "close" or nil
end
+ local is_head_request = request.method == "HEAD";
+
local response = {
request = request;
+ is_head_request = is_head_request;
status_code = 200;
headers = { date = date_header, connection = response_conn_header };
persistent = persistent;
@@ -227,6 +282,11 @@ function handle_request(conn, request, finish_cb)
local payload = { request = request, response = response };
log("debug", "Firing event: %s", global_event);
local result = events.fire_event(global_event, payload);
+ if result == nil and is_head_request then
+ local global_head_event = "GET "..request.path:match("[^?]*");
+ log("debug", "Firing event: %s", global_head_event);
+ result = events.fire_event(global_head_event, payload);
+ end
if result == nil then
if not hosts[host] then
if hosts[default_host] then
@@ -247,40 +307,17 @@ function handle_request(conn, request, finish_cb)
local host_event = request.method.." "..host..request.path:match("[^?]*");
log("debug", "Firing event: %s", host_event);
result = events.fire_event(host_event, payload);
- end
- if result ~= nil then
- if result ~= true then
- local body;
- local result_type = type(result);
- if result_type == "number" then
- response.status_code = result;
- if result >= 400 then
- payload.code = result;
- body = events.fire_event("http-error", payload);
- end
- elseif result_type == "string" then
- body = result;
- elseif result_type == "table" then
- for k, v in pairs(result) do
- if k ~= "headers" then
- response[k] = v;
- else
- for header_name, header_value in pairs(v) do
- response.headers[header_name] = header_value;
- end
- end
- end
- end
- response:send(body);
+
+ if result == nil and is_head_request then
+ local host_head_event = "GET "..host..request.path:match("[^?]*");
+ log("debug", "Firing event: %s", host_head_event);
+ result = events.fire_event(host_head_event, payload);
end
- return;
end
- -- if handler not called, return 404
- response.status_code = 404;
- payload.code = 404;
- response:send(events.fire_event("http-error", payload));
+ return handle_result(request, response, result);
end
+
local function prepare_header(response)
local status_line = "HTTP/"..response.request.httpversion.." "..(response.status or codes[response.status_code]);
local headers = response.headers;
@@ -292,12 +329,21 @@ local function prepare_header(response)
return output;
end
_M.prepare_header = prepare_header;
+function _M.send_head_response(response)
+ if response.finished then return; end
+ local output = prepare_header(response);
+ response.conn:write(t_concat(output));
+ response:done();
+end
function _M.send_response(response, body)
if response.finished then return; end
body = body or response.body or "";
-- Per RFC 7230, informational (1xx) and 204 (no content) should have no c-l header
if response.status_code > 199 and response.status_code ~= 204 then
- response.headers.content_length = #body;
+ response.headers.content_length = ("%d"):format(#body);
+ end
+ if response.is_head_request then
+ return _M.send_head_response(response)
end
local output = prepare_header(response);
t_insert(output, body);
@@ -305,6 +351,10 @@ function _M.send_response(response, body)
response:done();
end
function _M.send_file(response, f)
+ if response.is_head_request then
+ if f.close then f:close(); end
+ return _M.send_head_response(response);
+ end
if response.finished then return; end
local chunked = not response.headers.content_length;
if chunked then response.headers.transfer_encoding = "chunked"; end
diff --git a/net/resolvers/basic.lua b/net/resolvers/basic.lua
index 867ccf60..3c0e69f5 100644
--- a/net/resolvers/basic.lua
+++ b/net/resolvers/basic.lua
@@ -2,10 +2,13 @@ local adns = require "net.adns";
local inet_pton = require "util.net".pton;
local inet_ntop = require "util.net".ntop;
local idna_to_ascii = require "util.encodings".idna.to_ascii;
+local unpack = table.unpack or unpack; -- luacheck: ignore 113
local methods = {};
local resolver_mt = { __index = methods };
+-- FIXME RFC 6724
+
-- Find the next target to connect to, and
-- pass it to cb()
function methods:next(cb)
@@ -25,34 +28,69 @@ function methods:next(cb)
return;
end
+ local secure = true;
+ local tlsa = {};
local targets = {};
- local n = 2;
+ local n = 3;
local function ready()
n = n - 1;
if n > 0 then return; end
self.targets = targets;
+ --[[
+ -- TODO stash tlsa somewhere per connection
+ -- FIXME 'extra' here is not per connection
+ if self.extra and self.extra.use_dane then
+ if secure and tlsa[1] then
+ end
+ end
+ --]]
self:next(cb);
end
-- Resolve DNS to target list
local dns_resolver = adns.resolver();
- dns_resolver:lookup(function (answer)
- if answer then
- for _, record in ipairs(answer) do
- table.insert(targets, { self.conn_type.."4", record.a, self.port, self.extra });
+
+ if not self.extra or self.extra.use_ipv4 ~= false then
+ dns_resolver:lookup(function (answer)
+ if answer then
+ secure = secure and answer.secure;
+ for _, record in ipairs(answer) do
+ table.insert(targets, { self.conn_type.."4", record.a, self.port, self.extra });
+ end
end
- end
+ ready();
+ end, self.hostname, "A", "IN");
+ else
+ ready();
+ end
+
+ if not self.extra or self.extra.use_ipv6 ~= false then
+ dns_resolver:lookup(function (answer)
+ if answer then
+ secure = secure and answer.secure;
+ for _, record in ipairs(answer) do
+ table.insert(targets, { self.conn_type.."6", record.aaaa, self.port, self.extra });
+ end
+ end
+ ready();
+ end, self.hostname, "AAAA", "IN");
+ else
ready();
- end, self.hostname, "A", "IN");
+ end
- dns_resolver:lookup(function (answer)
- if answer then
- for _, record in ipairs(answer) do
- table.insert(targets, { self.conn_type.."6", record.aaaa, self.port, self.extra });
+ if self.extra and self.extra.use_dane == true then
+ dns_resolver:lookup(function (answer)
+ if answer then
+ secure = secure and answer.secure;
+ for _, record in ipairs(answer) do
+ table.insert(tlsa, record.tlsa);
+ end
end
- end
+ ready();
+ end, ("_%d._tcp.%s"):format(self.port, self.hostname), "TLSA", "IN");
+ else
ready();
- end, self.hostname, "AAAA", "IN");
+ end
end
local function new(hostname, port, conn_type, extra)
diff --git a/net/resolvers/manual.lua b/net/resolvers/manual.lua
index c0d4e5d5..dbc40256 100644
--- a/net/resolvers/manual.lua
+++ b/net/resolvers/manual.lua
@@ -1,5 +1,6 @@
local methods = {};
local resolver_mt = { __index = methods };
+local unpack = table.unpack or unpack; -- luacheck: ignore 113
-- Find the next target to connect to, and
-- pass it to cb()
diff --git a/net/resolvers/service.lua b/net/resolvers/service.lua
index 34f14cba..d74adf06 100644
--- a/net/resolvers/service.lua
+++ b/net/resolvers/service.lua
@@ -1,6 +1,8 @@
local adns = require "net.adns";
local basic = require "net.resolvers.basic";
+local inet_pton = require "util.net".pton;
local idna_to_ascii = require "util.encodings".idna.to_ascii;
+local unpack = table.unpack or unpack; -- luacheck: ignore 113
local methods = {};
local resolver_mt = { __index = methods };
@@ -9,14 +11,17 @@ local resolver_mt = { __index = methods };
-- pass it to cb()
function methods:next(cb)
if self.targets then
- if #self.targets == 0 then
- cb(nil);
- return;
+ if not self.resolver then
+ if #self.targets == 0 then
+ cb(nil);
+ return;
+ end
+ local next_target = table.remove(self.targets, 1);
+ self.resolver = basic.new(unpack(next_target, 1, 4));
end
- local next_target = table.remove(self.targets, 1);
- self.resolver = basic.new(unpack(next_target, 1, 4));
self.resolver:next(function (...)
if ... == nil then
+ self.resolver = nil;
self:next(cb);
else
cb(...);
@@ -39,7 +44,11 @@ function methods:next(cb)
-- Resolve DNS to target list
local dns_resolver = adns.resolver();
- dns_resolver:lookup(function (answer)
+ dns_resolver:lookup(function (answer, err)
+ if not answer and not err then
+ -- net.adns returns nil if there are zero records or nxdomain
+ answer = {};
+ end
if answer then
if #answer == 0 then
if self.extra and self.extra.default_port then
@@ -64,6 +73,14 @@ function methods:next(cb)
end
local function new(hostname, service, conn_type, extra)
+ local is_ip = inet_pton(hostname);
+ if not is_ip and hostname:sub(1,1) == '[' then
+ is_ip = inet_pton(hostname:sub(2,-2));
+ end
+ if is_ip and extra and extra.default_port then
+ return basic.new(hostname, extra.default_port, conn_type, extra);
+ end
+
return setmetatable({
hostname = idna_to_ascii(hostname);
service = service;
diff --git a/net/server.lua b/net/server.lua
index abbb421d..f5666594 100644
--- a/net/server.lua
+++ b/net/server.lua
@@ -13,7 +13,11 @@ if not (prosody and prosody.config_loaded) then
end
local log = require "util.logger".init("net.server");
-local server_type = require "core.configmanager".get("*", "network_backend") or "select";
+
+local have_util_poll = pcall(require, "util.poll");
+local default_backend = have_util_poll and "epoll" or "select";
+
+local server_type = require "core.configmanager".get("*", "network_backend") or default_backend;
if require "core.configmanager".get("*", "use_libevent") then
server_type = "event";
diff --git a/net/server_epoll.lua b/net/server_epoll.lua
index 53a67dd5..34a11c03 100644
--- a/net/server_epoll.lua
+++ b/net/server_epoll.lua
@@ -9,20 +9,25 @@
local t_insert = table.insert;
local t_concat = table.concat;
local setmetatable = setmetatable;
-local tostring = tostring;
local pcall = pcall;
local type = type;
local next = next;
local pairs = pairs;
-local log = require "util.logger".init("server_epoll");
+local ipairs = ipairs;
+local traceback = debug.traceback;
+local logger = require "util.logger";
+local log = logger.init("server_epoll");
local socket = require "socket";
local luasec = require "ssl";
-local gettime = require "util.time".now;
+local realtime = require "util.time".now;
+local monotonic = require "util.time".monotonic;
local indexedbheap = require "util.indexedbheap";
local createtable = require "util.table".create;
local inet = require "util.net";
local inet_pton = inet.pton;
local _SOCKETINVALID = socket._SOCKETINVALID or -1;
+local new_id = require "util.id".medium;
+local xpcall = require "util.xpcall".xpcall;
local poller = require "util.poll"
local EEXIST = poller.EEXIST;
@@ -38,7 +43,10 @@ local default_config = { __index = {
read_timeout = 14 * 60;
-- How long to wait for a socket to become writable after queuing data to send
- send_timeout = 60;
+ send_timeout = 180;
+
+ -- How long to wait for a socket to become writable after creation
+ connect_timeout = 20;
-- Some number possibly influencing how many pending connections can be accepted
tcp_backlog = 128;
@@ -58,6 +66,20 @@ local default_config = { __index = {
-- Maximum and minimum amount of time to sleep waiting for events (adjusted for pending timers)
max_wait = 86400;
min_wait = 1e-06;
+
+ -- Enable extra noisy debug logging
+ -- TODO disable once considered stable
+ verbose = true;
+
+ -- EXPERIMENTAL
+ -- Whether to kill connections in case of callback errors.
+ fatal_errors = false;
+
+ -- Or disable protection (like server_select) for potential performance gains
+ protect_listeners = true;
+
+ -- Attempt writes instantly
+ opportunistic_writes = false;
}};
local cfg = default_config.__index;
@@ -68,54 +90,56 @@ local fds = createtable(10, 0); -- FD -> conn
local timers = indexedbheap.create();
local function noop() end
-local function closetimer(t)
- t[1] = 0;
- t[2] = noop;
- timers:remove(t.id);
+local function closetimer(id)
+ timers:remove(id);
end
-local function reschedule(t, time)
- t[1] = time;
- timers:reprioritize(t.id, time);
-end
-
--- Add absolute timer
-local function at(time, f)
- local timer = { time, f, close = closetimer, reschedule = reschedule, id = nil };
- timer.id = timers:insert(timer, time);
- return timer;
+local function reschedule(id, time)
+ time = monotonic() + time;
+ timers:reprioritize(id, time);
end
-- Add relative timer
-local function addtimer(timeout, f)
- return at(gettime() + timeout, f);
+local function addtimer(timeout, f, param)
+ local time = monotonic() + timeout;
+ if param ~= nil then
+ local timer_callback = f
+ function f(current_time, timer_id)
+ local t = timer_callback(current_time, timer_id, param)
+ return t;
+ end
+ end
+ local id = timers:insert(f, time);
+ return id;
end
-- Run callbacks of expired timers
-- Return time until next timeout
local function runtimers(next_delay, min_wait)
-- Any timers at all?
- local now = gettime();
+ local elapsed = monotonic();
+ local now = realtime();
local peek = timers:peek();
local readd;
while peek do
- if peek > now then
+ if peek > elapsed then
break;
end
local _, timer, id = timers:pop();
- local ok, ret = pcall(timer[2], now);
+ local ok, ret = xpcall(timer, traceback, now, id);
if ok and type(ret) == "number" then
- local next_time = now+ret;
- timer[1] = next_time;
+ local next_time = elapsed+ret;
-- Delay insertion of timers to be re-added
-- so they don't get called again this tick
if readd then
- readd[id] = timer;
+ readd[id] = { timer, next_time };
else
- readd = { [id] = timer };
+ readd = { [id] = { timer, next_time } };
end
+ elseif not ok then
+ log("error", "Error in timer: %s", ret);
end
peek = timers:peek();
@@ -123,7 +147,7 @@ local function runtimers(next_delay, min_wait)
if readd then
for _, timer in pairs(readd) do
- timers:insert(timer, timer[1]);
+ timers:insert(timer[1], timer[2]);
end
peek = timers:peek();
end
@@ -131,7 +155,7 @@ local function runtimers(next_delay, min_wait)
if peek == nil then
return next_delay;
else
- next_delay = peek - now;
+ next_delay = peek - elapsed;
end
if next_delay < min_wait then
@@ -154,6 +178,22 @@ function interface_mt:__tostring()
return ("FD %d"):format(self:getfd());
end
+interface.log = log;
+function interface:debug(msg, ...)
+ self.log("debug", msg, ...);
+end
+
+interface.noise = interface.debug;
+function interface:noise(msg, ...)
+ if cfg.verbose then
+ return self:debug(msg, ...);
+ end
+end
+
+function interface:error(msg, ...)
+ self.log("error", msg, ...);
+end
+
-- Replace the listener and tell the old one
function interface:setlistener(listeners, data)
self:on("detach");
@@ -164,21 +204,36 @@ end
-- Call a listener callback
function interface:on(what, ...)
if not self.listeners then
- log("error", "%s has no listeners", self);
+ self:error("Interface is missing listener callbacks");
return;
end
local listener = self.listeners["on"..what];
if not listener then
- -- log("debug", "Missing listener 'on%s'", what); -- uncomment for development and debugging
+ self:noise("Missing listener 'on%s'", what); -- uncomment for development and debugging
return;
end
- local ok, err = pcall(listener, self, ...);
+ if not cfg.protect_listeners then
+ return listener(self, ...);
+ end
+ local onerror = self.listeners.onerror or traceback;
+ local ok, err = xpcall(listener, onerror, self, ...);
if not ok then
- log("error", "Error calling on%s: %s", what, err);
+ if cfg.fatal_errors then
+ self:error("Closing due to error calling on%s: %s", what, err);
+ self:destroy();
+ else
+ self:error("Error calling on%s: %s", what, err);
+ end
+ return nil, err;
end
return err;
end
+-- Allow this one to be overridden
+function interface:onincoming(...)
+ return self:on("incoming", ...);
+end
+
-- Return the file descriptor number
function interface:getfd()
if self.conn then
@@ -235,19 +290,21 @@ end
function interface:setreadtimeout(t)
if t == false then
if self._readtimeout then
- self._readtimeout:close();
+ closetimer(self._readtimeout);
self._readtimeout = nil;
end
return
end
t = t or cfg.read_timeout;
if self._readtimeout then
- self._readtimeout:reschedule(gettime() + t);
+ reschedule(self._readtimeout, t);
else
self._readtimeout = addtimer(t, function ()
if self:on("readtimeout") then
+ self:noise("Read timeout handled");
return cfg.read_timeout;
else
+ self:debug("Read timeout not handled, disconnecting");
self:on("disconnect", "read timeout");
self:destroy();
end
@@ -259,17 +316,18 @@ end
function interface:setwritetimeout(t)
if t == false then
if self._writetimeout then
- self._writetimeout:close();
+ closetimer(self._writetimeout);
self._writetimeout = nil;
end
return
end
t = t or cfg.send_timeout;
if self._writetimeout then
- self._writetimeout:reschedule(gettime() + t);
+ reschedule(self._writetimeout, t);
else
self._writetimeout = addtimer(t, function ()
- self:on("disconnect", "write timeout");
+ self:noise("Write timeout");
+ self:on("disconnect", self._connected and "write timeout" or "connection timeout");
self:destroy();
end);
end
@@ -285,15 +343,15 @@ function interface:add(r, w)
local ok, err, errno = poll:add(fd, r, w);
if not ok then
if errno == EEXIST then
- log("debug", "%s already registered!", self);
+ self:debug("FD already registered in poller! (EEXIST)");
return self:set(r, w); -- So try to change its flags
end
- log("error", "Could not register %s: %s(%d)", self, err, errno);
+ self:debug("Could not register in poller: %s(%d)", err, errno);
return ok, err;
end
self._wantread, self._wantwrite = r, w;
fds[fd] = self;
- log("debug", "Watching %s", self);
+ self:noise("Registered in poller");
return true;
end
@@ -306,7 +364,7 @@ function interface:set(r, w)
if w == nil then w = self._wantwrite; end
local ok, err, errno = poll:set(fd, r, w);
if not ok then
- log("error", "Could not update poller state %s: %s(%d)", self, err, errno);
+ self:debug("Could not update poller state: %s(%d)", err, errno);
return ok, err;
end
self._wantread, self._wantwrite = r, w;
@@ -323,12 +381,12 @@ function interface:del()
end
local ok, err, errno = poll:del(fd);
if not ok and errno ~= ENOENT then
- log("error", "Could not unregister %s: %s(%d)", self, err, errno);
+ self:debug("Could not unregister: %s(%d)", err, errno);
return ok, err;
end
self._wantread, self._wantwrite = nil, nil;
fds[fd] = nil;
- log("debug", "Unwatched %s", self);
+ self:noise("Unregistered from poller");
return true;
end
@@ -350,7 +408,7 @@ function interface:onreadable()
local data, err, partial = self.conn:receive(self.read_size or cfg.read_size);
if data then
self:onconnect();
- self:on("incoming", data);
+ self:onincoming(data);
else
if err == "wantread" then
self:set(true, nil);
@@ -361,15 +419,28 @@ function interface:onreadable()
end
if partial and partial ~= "" then
self:onconnect();
- self:on("incoming", partial, err);
+ self:onincoming(partial, err);
end
if err ~= "timeout" then
+ if err == "closed" then
+ self:debug("Connection closed by remote");
+ else
+ self:debug("Read error, closing (%s)", err);
+ end
self:on("disconnect", err);
self:destroy()
return;
end
end
if not self.conn then return; end
+ if self._limit and (data or partial) then
+ local cost = self._limit * #(data or partial);
+ if cost > cfg.min_wait then
+ self:setreadtimeout(false);
+ self:pausefor(cost);
+ return;
+ end
+ end
if self._wantread and self.conn:dirty() then
self:setreadtimeout(false);
self:pausefor(cfg.read_retry_delay);
@@ -383,7 +454,7 @@ function interface:onwritable()
self:onconnect();
if not self.conn then return; end -- could have been closed in onconnect
local buffer = self.writebuffer;
- local data = t_concat(buffer);
+ local data = #buffer == 1 and buffer[1] or t_concat(buffer);
local ok, err, partial = self.conn:send(data);
if ok then
self:set(nil, false);
@@ -394,10 +465,12 @@ function interface:onwritable()
self:ondrain(); -- Be aware of writes in ondrain
return;
elseif partial then
+ self:debug("Sent %d out of %d buffered bytes", partial, #data);
buffer[1] = data:sub(partial+1);
for i = #buffer, 2, -1 do
buffer[i] = nil;
end
+ self:set(nil, true);
self:setwritetimeout();
end
if err == "wantwrite" or err == "timeout" then
@@ -423,8 +496,14 @@ function interface:write(data)
else
self.writebuffer = { data };
end
- self:setwritetimeout();
- self:set(nil, true);
+ if not self._write_lock then
+ if cfg.opportunistic_writes then
+ self:onwritable();
+ return #data;
+ end
+ self:setwritetimeout();
+ self:set(nil, true);
+ end
return #data;
end
interface.send = interface.write;
@@ -434,10 +513,10 @@ function interface:close()
if self.writebuffer and self.writebuffer[1] then
self:set(false, true); -- Flush final buffer contents
self.write, self.send = noop, noop; -- No more writing
- log("debug", "Close %s after writing", self);
+ self:debug("Close after writing remaining buffered data");
self.ondrain = interface.close;
else
- log("debug", "Close %s now", self);
+ self:debug("Closing now");
self.write, self.send = noop, noop;
self.close = noop;
self:on("disconnect");
@@ -466,31 +545,32 @@ function interface:starttls(tls_ctx)
if tls_ctx then self.tls_ctx = tls_ctx; end
self.starttls = false;
if self.writebuffer and self.writebuffer[1] then
- log("debug", "Start TLS on %s after write", self);
+ self:debug("Start TLS after write");
self.ondrain = interface.starttls;
self:set(nil, true); -- make sure wantwrite is set
else
if self.ondrain == interface.starttls then
self.ondrain = nil;
end
- self.onwritable = interface.tlshandskake;
- self.onreadable = interface.tlshandskake;
+ self.onwritable = interface.tlshandshake;
+ self.onreadable = interface.tlshandshake;
self:set(true, true);
- log("debug", "Prepare to start TLS on %s", self);
+ self:debug("Prepared to start TLS");
end
end
-function interface:tlshandskake()
+function interface:tlshandshake()
self:setwritetimeout(false);
self:setreadtimeout(false);
if not self._tls then
self._tls = true;
- log("debug", "Start TLS on %s now", self);
+ self:debug("Starting TLS now");
self:del();
+ self:updatenames(); -- Can't getpeer/sockname after wrap()
local ok, conn, err = pcall(luasec.wrap, self.conn, self.tls_ctx);
if not ok then
conn, err = ok, conn;
- log("error", "Failed to initialize TLS: %s", err);
+ self:debug("Failed to initialize TLS: %s", err);
end
if not conn then
self:on("disconnect", err);
@@ -499,33 +579,56 @@ function interface:tlshandskake()
end
conn:settimeout(0);
self.conn = conn;
- if conn.sni and self.servername then
- conn:sni(self.servername);
+ if conn.sni then
+ if self.servername then
+ conn:sni(self.servername);
+ elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then
+ conn:sni(self._server.hosts, true);
+ end
+ end
+ if self.extra and self.extra.tlsa and conn.settlsa then
+ -- TODO Error handling
+ if not conn:setdane(self.servername or self.extra.dane_hostname) then
+ self:debug("Could not enable DANE on connection");
+ else
+ self:debug("Enabling DANE with %d TLSA records", #self.extra.tlsa);
+ self:noise("DANE hostname is %q", self.servername or self.extra.dane_hostname);
+ for _, tlsa in ipairs(self.extra.tlsa) do
+ self:noise("TLSA: %q", tlsa);
+ conn:settlsa(tlsa.use, tlsa.select, tlsa.match, tlsa.data);
+ end
+ end
end
self:on("starttls");
self.ondrain = nil;
- self.onwritable = interface.tlshandskake;
- self.onreadable = interface.tlshandskake;
+ self.onwritable = interface.tlshandshake;
+ self.onreadable = interface.tlshandshake;
return self:init();
end
+ self:noise("Continuing TLS handshake");
local ok, err = self.conn:dohandshake();
if ok then
- log("debug", "TLS handshake on %s complete", self);
+ local info = self.conn.info and self.conn:info();
+ if type(info) == "table" then
+ self:debug("TLS handshake complete (%s with %s)", info.protocol, info.cipher);
+ else
+ self:debug("TLS handshake complete");
+ end
self.onwritable = nil;
self.onreadable = nil;
self:on("status", "ssl-handshake-complete");
self:setwritetimeout();
self:set(true, true);
elseif err == "wantread" then
- log("debug", "TLS handshake on %s to wait until readable", self);
+ self:noise("TLS handshake to wait until readable");
self:set(true, false);
self:setreadtimeout(cfg.ssl_handshake_timeout);
elseif err == "wantwrite" then
- log("debug", "TLS handshake on %s to wait until writable", self);
+ self:noise("TLS handshake to wait until writable");
self:set(false, true);
self:setwritetimeout(cfg.ssl_handshake_timeout);
else
- log("debug", "TLS handshake error on %s: %s", self, err);
+ self:debug("TLS handshake error: %s", err);
self:on("disconnect", err);
self:destroy();
end
@@ -533,15 +636,18 @@ end
local function wrapsocket(client, server, read_size, listeners, tls_ctx, extra) -- luasocket object -> interface object
client:settimeout(0);
+ local conn_id = ("conn%s"):format(new_id());
local conn = setmetatable({
conn = client;
_server = server;
- created = gettime();
+ created = realtime();
listeners = listeners;
read_size = read_size or (server and server.read_size);
writebuffer = {};
tls_ctx = tls_ctx or (server and server.tls_ctx);
tls_direct = server and server.tls_direct;
+ id = conn_id;
+ log = logger.init(conn_id);
extra = extra;
}, interface_mt);
@@ -558,12 +664,12 @@ end
function interface:updatenames()
local conn = self.conn;
local ok, peername, peerport = pcall(conn.getpeername, conn);
- if ok then
- self.peername, self.peerport = peername, peerport;
+ if ok and peername then
+ self.peername, self.peerport = peername, peerport or 0;
end
local ok, sockname, sockport = pcall(conn.getsockname, conn);
- if ok then
- self.sockname, self.sockport = sockname, sockport;
+ if ok and sockname then
+ self.sockname, self.sockport = sockname, sockport or 0;
end
end
@@ -572,76 +678,129 @@ end
function interface:onacceptable()
local conn, err = self.conn:accept();
if not conn then
- log("debug", "Error accepting new client: %s, server will be paused for %ds", err, cfg.accept_retry_interval);
+ self:debug("Error accepting new client: %s, server will be paused for %ds", err, cfg.accept_retry_interval);
self:pausefor(cfg.accept_retry_interval);
return;
end
local client = wrapsocket(conn, self, nil, self.listeners);
- log("debug", "New connection %s", tostring(client));
+ client:debug("New connection %s on server %s", client, self);
client:init();
if self.tls_direct then
client:starttls(self.tls_ctx);
+ else
+ client:onconnect();
end
end
-- Initialization
function interface:init()
- self:setwritetimeout();
+ self:setwritetimeout(cfg.connect_timeout);
return self:add(true, true);
end
function interface:pause()
+ self:noise("Pause reading");
return self:set(false);
end
function interface:resume()
+ self:noise("Resume reading");
return self:set(true);
end
-- Pause connection for some time
function interface:pausefor(t)
+ self:noise("Pause for %fs", t);
if self._pausefor then
- self._pausefor:close();
+ closetimer(self._pausefor);
+ self._pausefor = nil;
end
if t == false then return; end
self:set(false);
self._pausefor = addtimer(t, function ()
self._pausefor = nil;
self:set(true);
+ self:noise("Resuming after pause, connection is %s", not self.conn and "missing" or self.conn:dirty() and "dirty" or "clean");
if self.conn and self.conn:dirty() then
self:onreadable();
end
end);
end
+function interface:setlimit(Bps)
+ if Bps > 0 then
+ self._limit = 1/Bps;
+ else
+ self._limit = nil;
+ end
+end
+
+function interface:pause_writes()
+ if self._write_lock then
+ return
+ end
+ self:noise("Pause writes");
+ self._write_lock = true;
+ self:setwritetimeout(false);
+ self:set(nil, false);
+end
+
+function interface:resume_writes()
+ if not self._write_lock then
+ return
+ end
+ self:noise("Resume writes");
+ self._write_lock = nil;
+ if self.writebuffer[1] then
+ self:setwritetimeout();
+ self:set(nil, true);
+ end
+end
+
-- Connected!
function interface:onconnect()
- if self.conn and not self.peername and self.conn.getpeername then
- self.peername, self.peerport = self.conn:getpeername();
- end
+ self._connected = true;
+ self:updatenames();
+ self:debug("Connected (%s)", self);
self.onconnect = noop;
self:on("connect");
end
-local function addserver(addr, port, listeners, read_size, tls_ctx)
- local conn, err = socket.bind(addr, port, cfg.tcp_backlog);
- if not conn then return conn, err; end
- conn:settimeout(0);
+local function wrapserver(conn, addr, port, listeners, config)
local server = setmetatable({
conn = conn;
- created = gettime();
+ created = realtime();
listeners = listeners;
- read_size = read_size;
+ read_size = config and config.read_size;
onreadable = interface.onacceptable;
- tls_ctx = tls_ctx;
- tls_direct = tls_ctx and true or false;
+ tls_ctx = config and config.tls_ctx;
+ tls_direct = config and config.tls_direct;
+ hosts = config and config.sni_hosts;
sockname = addr;
sockport = port;
+ log = logger.init(("serv%s"):format(new_id()));
}, interface_mt);
+ server:debug("Server %s created", server);
server:add(true, false);
return server;
end
+local function listen(addr, port, listeners, config)
+ local conn, err = socket.bind(addr, port, cfg.tcp_backlog);
+ if not conn then return conn, err; end
+ conn:settimeout(0);
+ return wrapserver(conn, addr, port, listeners, config);
+end
+
+-- COMPAT
+local function addserver(addr, port, listeners, read_size, tls_ctx)
+ return listen(addr, port, listeners, {
+ read_size = read_size;
+ tls_ctx = tls_ctx;
+ tls_direct = tls_ctx and true or false;
+ });
+end
+
-- COMPAT
local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx, extra)
local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra);
@@ -675,13 +834,19 @@ local function addclient(addr, port, listeners, read_size, tls_ctx, typ, extra)
return nil, "invalid socket type";
end
local conn, err = create();
+ if not conn then return conn, err; end
local ok, err = conn:settimeout(0);
if not ok then return ok, err; end
local ok, err = conn:setpeername(addr, port);
if not ok and err ~= "timeout" then return ok, err; end
local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra)
local ok, err = client:init();
+ if not client.peername then
+ -- otherwise not set until connected
+ client.peername, client.peerport = addr, port;
+ end
if not ok then return ok, err; end
+ client:debug("Client %s created", client);
if tls_ctx then
client:starttls(tls_ctx);
end
@@ -703,23 +868,23 @@ local function watchfd(fd, onreadable, onwritable)
end;
-- Otherwise it'll need to be something LuaSocket-compatible
end
+ conn.id = new_id();
+ conn.log = logger.init(("fdwatch%s"):format(conn.id));
conn:add(onreadable, onwritable);
return conn;
end;
-- Dump all data from one connection into another
-local function link(from, to)
- from.listeners = setmetatable({
- onincoming = function (_, data)
- from:pause();
- to:write(data);
- end,
- }, {__index=from.listeners});
- to.listeners = setmetatable({
- ondrain = function ()
- from:resume();
- end,
- }, {__index=to.listeners});
+local function link(from, to, read_size)
+ from:debug("Linking to %s", to.id);
+ function from:onincoming(data)
+ self:pause();
+ to:write(data);
+ end
+ function to:ondrain() -- luacheck: ignore 212/self
+ from:resume();
+ end
+ from:set_mode(read_size);
from:set(true, nil);
to:set(nil, true);
end
@@ -778,11 +943,21 @@ return {
addserver = addserver;
addclient = addclient;
add_task = addtimer;
- at = at;
+ timer = {
+ -- API-compatible with util.timer
+ add_task = addtimer;
+ stop = closetimer;
+ reschedule = reschedule;
+ to_absolute_time = function (t)
+ return t-monotonic()+realtime();
+ end;
+ };
+ listen = listen;
loop = loop;
closeall = closeall;
setquitting = setquitting;
wrapclient = wrapclient;
+ wrapserver = wrapserver;
watchfd = watchfd;
link = link;
set_config = function (newconfig)
@@ -792,6 +967,7 @@ return {
-- libevent emulation
event = { EV_READ = "r", EV_WRITE = "w", EV_READWRITE = "rw", EV_LEAVE = -1 };
addevent = function (fd, mode, callback)
+ log("warn", "Using deprecated libevent emulation, please update code to use watchfd API instead");
local function onevent(self)
local ret = self:callback();
if ret == -1 then
@@ -811,6 +987,8 @@ return {
fds[fd] = nil;
end;
}, interface_mt);
+ conn.id = conn:getfd();
+ conn.log = logger.init(("fdwatch%d"):format(conn.id));
local ok, err = conn:add(mode == "r" or mode == "rw", mode == "w" or mode == "rw");
if not ok then return ok, err; end
return conn;
diff --git a/net/server_event.lua b/net/server_event.lua
index 746526ce..f7e1f448 100644
--- a/net/server_event.lua
+++ b/net/server_event.lua
@@ -165,8 +165,12 @@ function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed
return false
end
- if self.conn.sni and self.servername then
- self.conn:sni(self.servername);
+ if self.conn.sni then
+ if self.servername then
+ self.conn:sni(self.servername);
+ elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then
+ self.conn:sni(self._server.hosts, true);
+ end
end
self.conn:settimeout( 0 ) -- set non blocking
@@ -258,6 +262,7 @@ end
--TODO: Deprecate
function interface_mt:lock_read(switch)
+ log("warn", ":lock_read is deprecated, use :pasue() and :resume()");
if switch then
return self:pause();
else
@@ -277,6 +282,19 @@ function interface_mt:resume()
end
end
+function interface_mt:pause_writes()
+ return self:_lock(self.nointerface, self.noreading, true);
+end
+
+function interface_mt:resume_writes()
+ self:_lock(self.nointerface, self.noreading, false);
+ if self.writecallback and not self.eventwrite then
+ self.eventwrite = addevent( base, self.conn, EV_WRITE, self.writecallback, cfg.WRITE_TIMEOUT ); -- register callback
+ return true;
+ end
+end
+
+
function interface_mt:counter(c)
if c then
self._connections = self._connections + c
@@ -286,7 +304,7 @@ end
-- Public methods
function interface_mt:write(data)
- if self.nowriting then return nil, "locked" end
+ if self.nointerface then return nil, "locked"; end
--vdebug( "try to send data to client, id/data:", self.id, data )
data = tostring( data )
local len = #data
@@ -298,7 +316,7 @@ function interface_mt:write(data)
end
t_insert(self.writebuffer, data) -- new buffer
self.writebufferlen = total
- if not self.eventwrite then -- register new write event
+ if not self.eventwrite and not self.nowriting then -- register new write event
--vdebug( "register new write event" )
self.eventwrite = addevent( base, self.conn, EV_WRITE, self.writecallback, cfg.WRITE_TIMEOUT )
end
@@ -445,10 +463,6 @@ end
function interface_mt:ontimeout()
end
function interface_mt:onreadtimeout()
- self.fatalerror = "timeout during receiving"
- debug( "connection failed:", self.fatalerror )
- self:_close()
- self.eventread = nil
end
function interface_mt:ondrain()
end
@@ -642,7 +656,7 @@ local function handleclient( client, ip, port, server, pattern, listener, sslctx
return interface
end
-local function handleserver( server, addr, port, pattern, listener, sslctx ) -- creates an server interface
+local function handleserver( server, addr, port, pattern, listener, sslctx, startssl ) -- creates a server interface
debug "creating server interface..."
local interface = {
_connections = 0;
@@ -658,6 +672,7 @@ local function handleserver( server, addr, port, pattern, listener, sslctx ) --
_ip = addr, _port = port, _pattern = pattern,
_sslctx = sslctx;
+ hosts = {};
}
interface.id = tostring(interface):match("%x+$");
interface.readcallback = function( event ) -- server handler, called on incoming connections
@@ -677,6 +692,7 @@ local function handleserver( server, addr, port, pattern, listener, sslctx ) --
end
end
--vdebug("max connection check ok, accepting...")
+ -- luacheck: ignore 231/err
local client, err = server:accept() -- try to accept; TODO: check err
while client do
if interface._connections >= cfg.MAX_CONNECTIONS then
@@ -688,7 +704,7 @@ local function handleserver( server, addr, port, pattern, listener, sslctx ) --
interface._connections = interface._connections + 1 -- increase connection count
local clientinterface = handleclient( client, client_ip, client_port, interface, pattern, listener, sslctx )
--vdebug( "client id:", clientinterface, "startssl:", startssl )
- if has_luasec and sslctx then
+ if has_luasec and startssl then
clientinterface:starttls(sslctx, true)
else
clientinterface:_start_session( true )
@@ -707,9 +723,9 @@ local function handleserver( server, addr, port, pattern, listener, sslctx ) --
return interface
end
-local function addserver( addr, port, listener, pattern, sslctx, startssl ) -- TODO: check arguments
- --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslctx or "nil", startssl or "nil")
- if sslctx and not has_luasec then
+local function listen(addr, port, listener, config)
+ config = config or {}
+ if config.sslctx and not has_luasec then
debug "fatal error: luasec not found"
return nil, "luasec not found"
end
@@ -718,11 +734,20 @@ local function addserver( addr, port, listener, pattern, sslctx, startssl ) --
debug( "creating server socket on "..addr.." port "..port.." failed:", err )
return nil, err
end
- local interface = handleserver( server, addr, port, pattern, listener, sslctx, startssl ) -- new server handler
+ local interface = handleserver( server, addr, port, config.read_size, listener, config.tls_ctx, config.tls_direct) -- new server handler
debug( "new server created with id:", tostring(interface))
return interface
end
+local function addserver( addr, port, listener, pattern, sslctx ) -- TODO: check arguments
+ --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslctx or "nil", startssl or "nil")
+ return listen( addr, port, listener, {
+ read_size = pattern,
+ tls_ctx = sslctx,
+ tls_direct = not not sslctx,
+ });
+end
+
local function wrapclient( client, ip, port, listeners, pattern, sslctx, extra )
local interface = handleclient( client, ip, port, nil, pattern, listeners, sslctx, extra )
interface:_start_connection(sslctx)
@@ -756,6 +781,7 @@ local function addclient( addr, serverport, listener, pattern, sslctx, typ, extr
client:settimeout( 0 ) -- set nonblocking
local res, err = client:setpeername( addr, serverport ) -- connect
if res or ( err == "timeout" ) then
+ -- luacheck: ignore 211/port
local ip, port = client:getsockname( )
local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx, extra )
debug( "new connection id:", interface.id )
@@ -883,6 +909,7 @@ return {
event_base = base,
addevent = newevent,
addserver = addserver,
+ listen = listen,
addclient = addclient,
wrapclient = wrapclient,
setquitting = setquitting,
diff --git a/net/server_select.lua b/net/server_select.lua
index deb8fe48..09c1c027 100644
--- a/net/server_select.lua
+++ b/net/server_select.lua
@@ -68,6 +68,7 @@ local idfalse
local closeall
local addsocket
local addserver
+local listen
local addtimer
local getserver
local wrapserver
@@ -123,7 +124,7 @@ local _maxsslhandshake
_server = { } -- key = port, value = table; list of listening servers
_readlist = { } -- array with sockets to read from
-_sendlist = { } -- arrary with sockets to write to
+_sendlist = { } -- array with sockets to write to
_timerlist = { } -- array of timer functions
_socketlist = { } -- key = socket, value = wrapped socket (handlers)
_readtimes = { } -- key = handler, value = timestamp of last data reading
@@ -149,7 +150,7 @@ _checkinterval = 30 -- interval in secs to check idle clients
_sendtimeout = 60000 -- allowed send idle time in secs
_readtimeout = 14 * 60 -- allowed read idle time in secs
-local is_windows = package.config:sub(1,1) == "\\" -- check the directory separator, to detemine whether this is Windows
+local is_windows = package.config:sub(1,1) == "\\" -- check the directory separator, to determine whether this is Windows
_maxfd = (is_windows and math.huge) or luasocket._SETSIZE or 1024 -- max fd number, limit to 1024 by default to prevent glibc buffer overflow, but not on Windows
_maxselectlen = luasocket._SETSIZE or 1024 -- But this still applies on Windows
@@ -157,7 +158,7 @@ _maxsslhandshake = 30 -- max handshake round-trips
----------------------------------// PRIVATE //--
-wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- this function wraps a server -- FIXME Make sure FD < _maxfd
+wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, ssldirect ) -- this function wraps a server -- FIXME Make sure FD < _maxfd
if socket:getfd() >= _maxfd then
out_error("server.lua: Disallowed FD number: "..socket:getfd())
@@ -183,6 +184,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- t
handler.sslctx = function( )
return sslctx
end
+ handler.hosts = {} -- sni
handler.remove = function( )
connections = connections - 1
if handler then
@@ -244,13 +246,13 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- t
local client, err = accept( socket ) -- try to accept
if client then
local ip, clientport = client:getpeername( )
- local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx ) -- wrap new client socket
+ local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx, ssldirect ) -- wrap new client socket
if err then -- error while wrapping ssl socket
return false
end
connections = connections + 1
out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport))
- if dispatch and not sslctx then -- SSL connections will notify onconnect when handshake completes
+ if dispatch and not ssldirect then -- SSL connections will notify onconnect when handshake completes
return dispatch( handler );
end
return;
@@ -264,7 +266,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- t
return handler
end
-wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx, extra ) -- this function wraps a client to a handler object
+wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx, ssldirect, extra ) -- this function wraps a client to a handler object
if socket:getfd() >= _maxfd then
out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent
@@ -287,6 +289,8 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
local ssl
+ local pending
+
local dispatch = listeners.onincoming
local status = listeners.onstatus
local disconnect = listeners.ondisconnect
@@ -341,6 +345,9 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
listeners.onattach(self, data)
end
end
+ handler._setpending = function( )
+ pending = true
+ end
handler.getstats = function( )
return readtraffic, sendtraffic
end
@@ -377,7 +384,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
_readlistlen = removesocket( _readlist, socket, _readlistlen )
_readtimes[ handler ] = nil
if bufferqueuelen ~= 0 then
- handler.sendbuffer() -- Try now to send any outstanding data
+ handler:sendbuffer() -- Try now to send any outstanding data
if bufferqueuelen ~= 0 then -- Still not empty, so we'll try again later
if handler then
handler.write = nil -- ... but no further writing allowed
@@ -429,9 +436,8 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
bufferlen = bufferlen + #data
if bufferlen > maxsendlen then
_closelist[ handler ] = "send buffer exceeded" -- cannot close the client at the moment, have to wait to the end of the cycle
- handler.write = idfalse -- don't write anymore
return false
- elseif socket and not _sendlist[ socket ] then
+ elseif not nosend and socket and not _sendlist[ socket ] then
_sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
end
bufferqueuelen = bufferqueuelen + 1
@@ -461,49 +467,55 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
maxreadlen = readlen or maxreadlen
return bufferlen, maxreadlen, maxsendlen
end
- --TODO: Deprecate
handler.lock_read = function (self, switch)
+ out_error( "server.lua, lock_read() is deprecated, use pause() and resume()" )
if switch == true then
- local tmp = _readlistlen
- _readlistlen = removesocket( _readlist, socket, _readlistlen )
- _readtimes[ handler ] = nil
- if _readlistlen ~= tmp then
- noread = true
- end
+ return self:pause()
elseif switch == false then
- if noread then
- noread = false
- _readlistlen = addsocket(_readlist, socket, _readlistlen)
- _readtimes[ handler ] = _currenttime
- end
+ return self:resume()
end
return noread
end
handler.pause = function (self)
- return self:lock_read(true);
+ local tmp = _readlistlen
+ _readlistlen = removesocket( _readlist, socket, _readlistlen )
+ _readtimes[ handler ] = nil
+ if _readlistlen ~= tmp then
+ noread = true
+ end
+ return noread;
end
handler.resume = function (self)
- return self:lock_read(false);
+ if noread then
+ noread = false
+ _readlistlen = addsocket(_readlist, socket, _readlistlen)
+ _readtimes[ handler ] = _currenttime
+ end
+ return noread;
end
handler.lock = function( self, switch )
- handler.lock_read (switch)
+ out_error( "server.lua, lock() is deprecated" )
+ handler.lock_read (self, switch)
if switch == true then
- handler.write = idfalse
- local tmp = _sendlistlen
- _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
- _writetimes[ handler ] = nil
- if _sendlistlen ~= tmp then
- nosend = true
- end
+ handler.pause_writes (self)
elseif switch == false then
- handler.write = write
- if nosend then
- nosend = false
- write( "" )
- end
+ handler.resume_writes (self)
end
return noread, nosend
end
+ handler.pause_writes = function (self)
+ local tmp = _sendlistlen
+ _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
+ _writetimes[ handler ] = nil
+ nosend = true
+ end
+ handler.resume_writes = function (self)
+ nosend = false
+ if bufferlen > 0 and socket then
+ _sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
+ end
+ end
+
local _readbuffer = function( ) -- this function reads data
local buffer, err, part = receive( socket, pattern ) -- receive buffer with "pattern"
if not err or (err == "wantread" or err == "timeout") then -- received something
@@ -518,6 +530,12 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
_readtraffic = _readtraffic + count
_readtimes[ handler ] = _currenttime
--out_put( "server.lua: read data '", buffer:gsub("[^%w%p ]", "."), "', error: ", err )
+ if pending then -- connection established
+ pending = nil
+ if listeners.onconnect then
+ listeners.onconnect(handler)
+ end
+ end
return dispatch( handler, buffer, err )
else -- connections was closed or fatal error
out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) )
@@ -528,6 +546,12 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
local _sendbuffer = function( ) -- this function sends data
local succ, err, byte, buffer, count;
if socket then
+ if pending then
+ pending = nil
+ if listeners.onconnect then
+ listeners.onconnect(handler);
+ end
+ end
buffer = table_concat( bufferqueue, "", 1, bufferqueuelen )
succ, err, byte = send( socket, buffer, 1, bufferlen )
count = ( succ or byte or 0 ) * STAT_UNIT
@@ -604,7 +628,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
coroutine_yield( ) -- handshake not finished
end
end
- err = "ssl handshake error: " .. ( err or "handshake too long" );
+ err = ( err or "handshake too long" );
out_put( "server.lua: ", err );
_ = handler and handler:force_close(err)
return false, err -- handshake failed
@@ -624,13 +648,18 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
local oldsocket, err = socket
socket, err = ssl_wrap( socket, sslctx ) -- wrap socket
+
if not socket then
out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") )
return nil, err -- fatal error
end
- if socket.sni and self.servername then
- socket:sni(self.servername);
+ if socket.sni then
+ if self.servername then
+ socket:sni(self.servername);
+ elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then
+ socket:sni(self.server().hosts, true);
+ end
end
socket:settimeout( 0 )
@@ -668,7 +697,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
_socketlist[ socket ] = handler
_readlistlen = addsocket(_readlist, socket, _readlistlen)
- if sslctx and has_luasec then
+ if sslctx and ssldirect and has_luasec then
out_put "server.lua: auto-starting ssl negotiation..."
handler.autostart_ssl = true;
local ok, err = handler:starttls(sslctx);
@@ -723,7 +752,7 @@ local function link(sender, receiver, buffersize)
local sender_locked;
local _sendbuffer = receiver.sendbuffer;
function receiver.sendbuffer()
- _sendbuffer();
+ _sendbuffer(receiver);
if sender_locked and receiver.bufferlen() < buffersize then
sender:lock_read(false); -- Unlock now
sender_locked = nil;
@@ -743,9 +772,13 @@ end
----------------------------------// PUBLIC //--
-addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server
+listen = function ( addr, port, listeners, config )
addr = addr or "*"
+ config = config or {}
local err
+ local sslctx = config.tls_ctx;
+ local ssldirect = config.tls_direct;
+ local pattern = config.read_size;
if type( listeners ) ~= "table" then
err = "invalid listener table"
elseif type ( addr ) ~= "string" then
@@ -766,7 +799,7 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function
out_error( "server.lua, [", addr, "]:", port, ": ", err )
return nil, err
end
- local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx ) -- wrap new server socket
+ local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, ssldirect ) -- wrap new server socket
if not handler then
server:close( )
return nil, err
@@ -779,6 +812,14 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function
return handler
end
+addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server
+ return listen(addr, port, listeners, {
+ read_size = pattern;
+ tls_ctx = sslctx;
+ tls_direct = sslctx and true or false;
+ });
+end
+
getserver = function ( addr, port )
return _server[ addr..":"..port ];
end
@@ -921,7 +962,7 @@ loop = function(once) -- this is the main loop of the program
for _, socket in ipairs( read ) do -- receive data
local handler = _socketlist[ socket ]
if handler then
- handler.readbuffer( )
+ handler:readbuffer( )
else
closesocket( socket )
out_put "server.lua: found no handler and closed socket (readlist)" -- this can happen
@@ -930,7 +971,7 @@ loop = function(once) -- this is the main loop of the program
for _, socket in ipairs( write ) do -- send data waiting in writequeues
local handler = _socketlist[ socket ]
if handler then
- handler.sendbuffer( )
+ handler:sendbuffer( )
else
closesocket( socket )
out_put "server.lua: found no handler and closed socket (writelist)" -- this should not happen
@@ -987,21 +1028,13 @@ end
--// EXPERIMENTAL //--
local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx, extra )
- local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx, extra)
+ local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx, sslctx, extra)
if not handler then return nil, err end
_socketlist[ socket ] = handler
if not sslctx then
+ handler._setpending()
_readlistlen = addsocket(_readlist, socket, _readlistlen)
_sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
- if listeners.onconnect then
- -- When socket is writeable, call onconnect
- local _sendbuffer = handler.sendbuffer;
- handler.sendbuffer = function ()
- handler.sendbuffer = _sendbuffer;
- listeners.onconnect(handler);
- return _sendbuffer(); -- Send any queued outgoing data
- end
- end
end
return handler, socket
end
@@ -1123,6 +1156,7 @@ return {
stats = stats,
closeall = closeall,
addserver = addserver,
+ listen = listen,
getserver = getserver,
setlogger = setlogger,
getsettings = getsettings,
diff --git a/net/unbound.lua b/net/unbound.lua
new file mode 100644
index 00000000..22b0614b
--- /dev/null
+++ b/net/unbound.lua
@@ -0,0 +1,220 @@
+-- libunbound based net.adns replacement for Prosody IM
+-- Copyright (C) 2013-2015 Kim Alvefur
+--
+-- This file is MIT licensed.
+--
+-- luacheck: ignore prosody
+
+local setmetatable = setmetatable;
+local tostring = tostring;
+local t_concat = table.concat;
+local s_format = string.format;
+local s_lower = string.lower;
+local s_upper = string.upper;
+local noop = function() end;
+
+local logger = require "util.logger";
+local log = logger.init("unbound");
+local net_server = require "net.server";
+local libunbound = require"lunbound";
+local promise = require"util.promise";
+local new_id = require "util.id".medium;
+
+local gettime = require"socket".gettime;
+local dns_utils = require"util.dns";
+local classes, types, errors = dns_utils.classes, dns_utils.types, dns_utils.errors;
+local parsers = dns_utils.parsers;
+
+local function add_defaults(conf)
+ if conf then
+ for option, default in pairs(libunbound.config) do
+ if conf[option] == nil then
+ conf[option] = default;
+ end
+ end
+ end
+ return conf;
+end
+
+local unbound_config;
+if prosody then
+ local config = require"core.configmanager";
+ unbound_config = add_defaults(config.get("*", "unbound"));
+ prosody.events.add_handler("config-reloaded", function()
+ unbound_config = add_defaults(config.get("*", "unbound"));
+ end);
+end
+-- Note: libunbound will default to using root hints if resolvconf is unset
+
+local function connect_server(unbound, server)
+ log("debug", "Setting up net.server event handling for %s", unbound);
+ return server.watchfd(unbound, function ()
+ log("debug", "Processing queries for %s", unbound);
+ unbound:process()
+ end);
+end
+
+local unbound, server_conn;
+
+local function initialize()
+ unbound = libunbound.new(unbound_config);
+ server_conn = connect_server(unbound, net_server);
+end
+if prosody then
+ prosody.events.add_handler("server-started", initialize);
+end
+
+local answer_mt = {
+ __tostring = function(self)
+ if self._string then return self._string end
+ local h = s_format("Status: %s", errors[self.status]);
+ if self.secure then
+ h = h .. ", Secure";
+ elseif self.bogus then
+ h = h .. s_format(", Bogus: %s", self.bogus);
+ end
+ local t = { h };
+ for i = 1, #self do
+ t[i+1]=self.qname.."\t"..classes[self.qclass].."\t"..types[self.qtype].."\t"..tostring(self[i]);
+ end
+ local _string = t_concat(t, "\n");
+ self._string = _string;
+ return _string;
+ end;
+};
+
+local waiting_queries = {};
+
+local function prep_answer(a)
+ if not a then return end
+ local status = errors[a.rcode];
+ local qclass = classes[a.qclass];
+ local qtype = types[a.qtype];
+ a.status, a.class, a.type = status, qclass, qtype;
+
+ local t = s_lower(qtype);
+ local rr_mt = { __index = a, __tostring = function(self) return tostring(self[t]) end };
+ local parser = parsers[qtype];
+ for i = 1, #a do
+ if a.bogus then
+ -- Discard bogus data
+ a[i] = nil;
+ else
+ a[i] = setmetatable({[t] = parser(a[i])}, rr_mt);
+ end
+ end
+ return setmetatable(a, answer_mt);
+end
+
+local function lookup(callback, qname, qtype, qclass)
+ if not unbound then initialize(); end
+ qtype = qtype and s_upper(qtype) or "A";
+ qclass = qclass and s_upper(qclass) or "IN";
+ local ntype, nclass = types[qtype], classes[qclass];
+ local startedat = gettime();
+ local ret;
+ local log_query = logger.init("unbound.query"..new_id());
+ local function callback_wrapper(a, err)
+ local gotdataat = gettime();
+ waiting_queries[ret] = nil;
+ if a then
+ prep_answer(a);
+ log_query("debug", "Results for %s %s %s: %s (%s, %f sec)", qname, qclass, qtype, a.rcode == 0 and (#a .. " items") or a.status,
+ a.secure and "Secure" or a.bogus or "Insecure", gotdataat - startedat); -- Insecure as in unsigned
+ else
+ log_query("error", "Results for %s %s %s: %s", qname, qclass, qtype, tostring(err));
+ end
+ local ok, cerr = pcall(callback, a, err);
+ if not ok then log_query("error", "Error in callback: %s", cerr); end
+ end
+ log_query("debug", "Resolve %s %s %s", qname, qclass, qtype);
+ local err;
+ ret, err = unbound:resolve_async(callback_wrapper, qname, ntype, nclass);
+ if ret then
+ waiting_queries[ret] = callback;
+ else
+ log_query("warn", "Resolver error: %s", err);
+ end
+ return ret, err;
+end
+
+local function lookup_sync(qname, qtype, qclass)
+ if not unbound then initialize(); end
+ qtype = qtype and s_upper(qtype) or "A";
+ qclass = qclass and s_upper(qclass) or "IN";
+ local ntype, nclass = types[qtype], classes[qclass];
+ local a, err = unbound:resolve(qname, ntype, nclass);
+ if not a then return a, err; end
+ return prep_answer(a);
+end
+
+local function cancel(id)
+ local cb = waiting_queries[id];
+ unbound:cancel(id);
+ if cb then
+ cb(nil, "canceled");
+ waiting_queries[id] = nil;
+ end
+ return true;
+end
+
+-- Reinitiate libunbound context, drops cache
+local function purge()
+ for id in pairs(waiting_queries) do cancel(id); end
+ if server_conn then server_conn:close(); end
+ initialize();
+ return true;
+end
+
+local function not_implemented()
+ error "not implemented";
+end
+-- Public API
+local _M = {
+ lookup = lookup;
+ cancel = cancel;
+ new_async_socket = not_implemented;
+ dns = {
+ lookup = lookup_sync;
+ cancel = cancel;
+ cache = noop;
+ socket_wrapper_set = noop;
+ settimeout = noop;
+ query = noop;
+ purge = purge;
+ random = noop;
+ peek = noop;
+
+ types = types;
+ classes = classes;
+ };
+};
+
+local function lookup_promise(_, qname, qtype, qclass)
+ return promise.new(function (resolve, reject)
+ local function callback(answer, err)
+ if err then
+ return reject(err);
+ else
+ return resolve(answer);
+ end
+ end
+ local ret, err = lookup(callback, qname, qtype, qclass)
+ if not ret then reject(err); end
+ end);
+end
+
+local wrapper = {
+ lookup = function (_, callback, qname, qtype, qclass)
+ return lookup(callback, qname, qtype, qclass)
+ end;
+ lookup_promise = lookup_promise;
+ _resolver = {
+ settimeout = function () end;
+ closeall = function () end;
+ };
+}
+
+function _M.resolver() return wrapper; end
+
+return _M;
diff --git a/net/websocket.lua b/net/websocket.lua
index 469c6a58..193cd556 100644
--- a/net/websocket.lua
+++ b/net/websocket.lua
@@ -23,6 +23,7 @@ local websockets = {};
local websocket_listeners = {};
function websocket_listeners.ondisconnect(conn, err)
local s = websockets[conn];
+ if not s then return; end
websockets[conn] = nil;
if s.close_timer then
timer.stop(s.close_timer);
@@ -113,7 +114,7 @@ function websocket_listeners.onincoming(conn, buffer, err) -- luacheck: ignore 2
frame.MASK = true; -- RFC 6455 6.1.5: If the data is being sent by the client, the frame(s) MUST be masked
conn:write(frames.build(frame));
elseif frame.opcode == 0xA then -- Pong frame
- log("debug", "Received unexpected pong frame: " .. tostring(frame.data));
+ log("debug", "Received unexpected pong frame: %s", frame.data);
else
return fail(s, 1002, "Reserved opcode");
end
@@ -131,7 +132,7 @@ end
function websocket_methods:close(code, reason)
if self.readyState < 2 then
code = code or 1000;
- log("debug", "closing WebSocket with code %i: %s" , code , tostring(reason));
+ log("debug", "closing WebSocket with code %i: %s" , code , reason);
self.readyState = 2;
local conn = self.conn;
conn:write(frames.build_close(code, reason, true));
@@ -245,7 +246,7 @@ local function connect(url, ex, listeners)
or (protocol and not protocol[r.headers["sec-websocket-protocol"]])
then
s.readyState = 3;
- log("warn", "WebSocket connection to %s failed: %s", url, tostring(b));
+ log("warn", "WebSocket connection to %s failed: %s", url, b);
if s.onerror then s:onerror("connecting-failed"); end
return;
end
diff --git a/net/websocket/frames.lua b/net/websocket/frames.lua
index 1d0ac06f..03ce21a8 100644
--- a/net/websocket/frames.lua
+++ b/net/websocket/frames.lua
@@ -9,8 +9,7 @@
local softreq = require "util.dependencies".softreq;
local random_bytes = require "util.random".bytes;
-local bit = assert(softreq"bit" or softreq"bit32",
- "No bit module found. See https://prosody.im/doc/depends#bitop");
+local bit = require "util.bitcompat";
local band = bit.band;
local bor = bit.bor;
local lshift = bit.lshift;
@@ -19,8 +18,8 @@ local sbit = require "util.strbitop";
local sxor = sbit.sxor;
local s_char= string.char;
-local s_pack = string.pack; -- luacheck: ignore 143
-local s_unpack = string.unpack; -- luacheck: ignore 143
+local s_pack = string.pack;
+local s_unpack = string.unpack;
if not s_pack and softreq"struct" then
s_pack = softreq"struct".pack;