aboutsummaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
Diffstat (limited to 'net')
-rw-r--r--net/adns.lua42
-rw-r--r--net/connect.lua8
-rw-r--r--net/connlisteners.lua18
-rw-r--r--net/dns.lua72
-rw-r--r--net/http.lua26
-rw-r--r--net/http/codes.lua2
-rw-r--r--net/http/files.lua149
-rw-r--r--net/http/parser.lua4
-rw-r--r--net/http/server.lua104
-rw-r--r--net/resolvers/basic.lua36
-rw-r--r--net/resolvers/manual.lua1
-rw-r--r--net/resolvers/service.lua29
-rw-r--r--net/server.lua6
-rw-r--r--net/server_epoll.lua367
-rw-r--r--net/server_event.lua68
-rw-r--r--net/server_select.lua155
-rw-r--r--net/unbound.lua208
-rw-r--r--net/websocket.lua7
-rw-r--r--net/websocket/frames.lua8
19 files changed, 1021 insertions, 289 deletions
diff --git a/net/adns.lua b/net/adns.lua
index 560e4b53..8fe05653 100644
--- a/net/adns.lua
+++ b/net/adns.lua
@@ -8,13 +8,17 @@
local server = require "net.server";
local new_resolver = require "net.dns".resolver;
+local promise = require "util.promise";
local log = require "util.logger".init("adns");
-local coroutine, tostring, pcall = coroutine, tostring, pcall;
+log("debug", "Using legacy DNS API (missing lua-unbound?)"); -- TODO write docs about luaunbound
+-- TODO Raise log level once packages are available
+
+local coroutine, pcall = coroutine, pcall;
local setmetatable = setmetatable;
-local function dummy_send(sock, data, i, j) return (j-i)+1; end
+local function dummy_send(sock, data, i, j) return (j-i)+1; end -- luacheck: ignore 212
local _ENV = nil;
-- luacheck: std none
@@ -29,8 +33,7 @@ local function new_async_socket(sock, resolver)
local peername = "<unknown>";
local listener = {};
local handler = {};
- local err;
- function listener.onincoming(conn, data)
+ function listener.onincoming(conn, data) -- luacheck: ignore 212/conn
if data then
resolver:feed(handler, data);
end
@@ -46,9 +49,12 @@ local function new_async_socket(sock, resolver)
resolver:servfail(conn); -- Let the magic commence
end
end
- handler, err = server.wrapclient(sock, "dns", 53, listener);
- if not handler then
- return nil, err;
+ do
+ local err;
+ handler, err = server.wrapclient(sock, "dns", 53, listener);
+ if not handler then
+ return nil, err;
+ end
end
handler.settimeout = function () end
@@ -71,11 +77,11 @@ function async_resolver_methods:lookup(handler, qname, qtype, qclass)
handler(peek);
return;
end
- log("debug", "Records for %s not in cache, sending query (%s)...", qname, tostring(coroutine.running()));
+ log("debug", "Records for %s not in cache, sending query (%s)...", qname, coroutine.running());
local ok, err = resolver:query(qname, qtype, qclass);
if ok then
coroutine.yield(setmetatable({ resolver, qclass or "IN", qtype or "A", qname, coroutine.running()}, query_mt)); -- Wait for reply
- log("debug", "Reply for %s (%s)", qname, tostring(coroutine.running()));
+ log("debug", "Reply for %s (%s)", qname, coroutine.running());
end
if ok then
ok, err = pcall(handler, resolver:peek(qname, qtype, qclass));
@@ -84,13 +90,25 @@ function async_resolver_methods:lookup(handler, qname, qtype, qclass)
ok, err = pcall(handler, nil, err);
end
if not ok then
- log("error", "Error in DNS response handler: %s", tostring(err));
+ log("error", "Error in DNS response handler: %s", err);
end
end)(resolver:peek(qname, qtype, qclass));
end
-function query_methods:cancel(call_handler, reason)
- log("warn", "Cancelling DNS lookup for %s", tostring(self[4]));
+function async_resolver_methods:lookup_promise(qname, qtype, qclass)
+ return promise.new(function (resolve, reject)
+ local function handler(answer)
+ if not answer then
+ return reject();
+ end
+ resolve(answer);
+ end
+ self:lookup(handler, qname, qtype, qclass);
+ end);
+end
+
+function query_methods:cancel(call_handler, reason) -- luacheck: ignore 212/reason
+ log("warn", "Cancelling DNS lookup for %s", self[4]);
self[1].cancel(self[2], self[3], self[4], self[5], call_handler);
end
diff --git a/net/connect.lua b/net/connect.lua
index b812ffcd..d52d3901 100644
--- a/net/connect.lua
+++ b/net/connect.lua
@@ -2,6 +2,12 @@ local server = require "net.server";
local log = require "util.logger".init("net.connect");
local new_id = require "util.id".short;
+-- TODO #1246 Happy Eyeballs
+-- FIXME RFC 6724
+-- FIXME Error propagation from resolvers doesn't work
+-- FIXME #1428 Reuse DNS resolver object between service and basic resolver
+-- FIXME #1429 Close DNS resolver object when done
+
local pending_connection_methods = {};
local pending_connection_mt = {
__name = "pending_connection";
@@ -38,7 +44,7 @@ local function attempt_connection(p)
p:log("debug", "Next target to try is %s:%d", ip, port);
local conn, err = server.addclient(ip, port, pending_connection_listeners, p.options.pattern or "*a", p.options.sslctx, conn_type, extra);
if not conn then
- log("debug", "Connection attempt failed immediately: %s", tostring(err));
+ log("debug", "Connection attempt failed immediately: %s", err);
p.last_error = err or "unknown reason";
return attempt_connection(p);
end
diff --git a/net/connlisteners.lua b/net/connlisteners.lua
deleted file mode 100644
index 9b8f88c3..00000000
--- a/net/connlisteners.lua
+++ /dev/null
@@ -1,18 +0,0 @@
--- COMPAT w/pre-0.9
-local log = require "util.logger".init("net.connlisteners");
-local traceback = debug.traceback;
-
-local _ENV = nil;
--- luacheck: std none
-
-local function fail()
- log("error", "Attempt to use legacy connlisteners API. For more info see https://prosody.im/doc/developers/network");
- log("error", "Legacy connlisteners API usage, %s", traceback("", 2));
-end
-
-return {
- register = fail;
- get = fail;
- start = fail;
- -- epic fail
-};
diff --git a/net/dns.lua b/net/dns.lua
index 3902f95c..17119152 100644
--- a/net/dns.lua
+++ b/net/dns.lua
@@ -13,10 +13,12 @@
local socket = require "socket";
-local timer = require "util.timer";
+local have_timer, timer = pcall(require, "util.timer");
local new_ip = require "util.ip".new_ip;
local have_util_net, util_net = pcall(require, "util.net");
+local log = require "util.logger".init("dns");
+
local _, windows = pcall(require, "util.windows");
local is_windows = (_ and windows) or os.getenv("WINDIR");
@@ -69,7 +71,9 @@ local ztact = { -- public domain 20080404 lua@ztact.com
};
local get, set = ztact.get, ztact.set;
-local default_timeout = 15;
+local default_timeout = 5;
+local default_jitter = 1;
+local default_retry_jitter = 2;
-------------------------------------------------- module dns
local _ENV = nil;
@@ -664,8 +668,10 @@ end
-- socket layer -------------------------------------------------- socket layer
-resolver.delays = { 1, 3 };
+resolver.delays = { 1, 2, 3, 5 };
+resolver.jitter = have_timer and default_jitter or nil;
+resolver.retry_jitter = have_timer and default_retry_jitter or nil;
function resolver:addnameserver(address) -- - - - - - - - - - addnameserver
self.server = self.server or {};
@@ -853,7 +859,10 @@ function resolver:query(qname, qtype, qclass) -- - - - - - - - - - -- query
packet = header..question,
server = self.best_server,
delay = 1,
- retry = socket.gettime() + self.delays[1]
+ retry = socket.gettime() + self.delays[1];
+ qclass = qclass;
+ qtype = qtype;
+ qname = qname;
};
-- remember the query
@@ -864,30 +873,32 @@ function resolver:query(qname, qtype, qclass) -- - - - - - - - - - -- query
if not conn then
return nil, err;
end
- conn:send (o.packet)
+ if self.jitter then
+ timer.add_task(math.random()*self.jitter, function ()
+ conn:send(o.packet);
+ end);
+ else
+ conn:send(o.packet);
+ end
-- remember which coroutine wants the answer
if co then
set(self.wanted, qclass, qtype, qname, co, true);
end
- if timer and self.timeout then
+ if have_timer and self.timeout then
local num_servers = #self.server;
local i = 1;
timer.add_task(self.timeout, function ()
if get(self.wanted, qclass, qtype, qname, co) then
- if i < num_servers then
+ log("debug", "DNS request timeout %d/%d", i, num_servers)
i = i + 1;
- self:servfail(conn);
- o.server = self.best_server;
- conn, err = self:getsocket(o.server);
- if conn then
- conn:send(o.packet);
- return self.timeout;
- end
- end
- -- Tried everything, failed
- self:cancel(qclass, qtype, qname);
+ self:servfail(self.socket[o.server]);
+-- end
+ end
+ -- Still outstanding? (i.e. retried)
+ if get(self.wanted, qclass, qtype, qname, co) then
+ return self.timeout; -- Then wait
end
end)
end
@@ -904,6 +915,7 @@ function resolver:servfail(sock, err)
-- Find all requests to the down server, and retry on the next server
self.time = socket.gettime();
+ log("debug", "servfail %d (of %d)", num, #self.server);
for id,queries in pairs(self.active) do
for question,o in pairs(queries) do
if o.server == num then -- This request was to the broken server
@@ -913,12 +925,27 @@ function resolver:servfail(sock, err)
end
o.retries = (o.retries or 0) + 1;
- if o.retries >= #self.server then
- --print('timeout');
- queries[question] = nil;
- else
+ local retried;
+ if o.retries < #self.server then
sock, err = self:getsocket(o.server);
- if sock then sock:send(o.packet); end
+ if sock then
+ retried = true;
+ if self.retry_jitter then
+ local delay = self.delays[((o.retries-1)%#self.delays)+1] + (math.random()*self.retry_jitter);
+ log("debug", "retry %d in %0.2fs", o.retries, delay);
+ timer.add_task(delay, function ()
+ sock:send(o.packet);
+ end);
+ else
+ log("debug", "retry %d (immediate)", o.retries);
+ sock:send(o.packet);
+ end
+ end
+ end
+ if not retried then
+ log("debug", 'tried all servers, giving up');
+ self:cancel(o.qclass, o.qtype, o.qname);
+ queries[question] = nil;
end
end
end
@@ -1164,6 +1191,7 @@ end
local _resolver = dns.resolver();
dns._resolver = _resolver;
+_resolver.jitter, _resolver.retry_jitter = false, false;
function dns.lookup(...) -- - - - - - - - - - - - - - - - - - - - - lookup
return _resolver:lookup(...);
diff --git a/net/http.lua b/net/http.lua
index 0adac26c..17c755cf 100644
--- a/net/http.lua
+++ b/net/http.lua
@@ -12,6 +12,8 @@ local httpstream_new = require "net.http.parser".new;
local util_http = require "util.http";
local events = require "util.events";
local verify_identity = require"util.x509".verify_identity;
+local promise = require "util.promise";
+local errors = require "util.error";
local basic_resolver = require "net.resolvers.basic";
local connect = require "net.connect".connect;
@@ -40,7 +42,7 @@ local listener = { default_port = 80, default_mode = "*a" };
local function handleerr(err) log("error", "Traceback[http]: %s", traceback(tostring(err), 2)); return err; end
local function log_if_failed(req, ret, ...)
if not ret then
- log("error", "Request '%s': error in callback: %s", req.id, tostring((...)));
+ log("error", "Request '%s': error in callback: %s", req.id, (...));
if not req.suppress_errors then
error(...);
end
@@ -150,7 +152,7 @@ function listener.onincoming(conn, data)
local request = requests[conn];
if not request then
- log("warn", "Received response from connection %s with no request attached!", tostring(conn));
+ log("warn", "Received response from connection %s with no request attached!", conn);
return;
end
@@ -261,7 +263,7 @@ local function request(self, u, ex, callback)
sslctx = ex and ex.sslctx or self.options and self.options.sslctx;
end
- local http_service = basic_resolver.new(host, port_number);
+ local http_service = basic_resolver.new(host, port_number, "tcp", { servername = req.host });
connect(http_service, listener, { sslctx = sslctx }, req);
self.events.fire_event("request", { http = self, request = req, url = u });
@@ -271,7 +273,21 @@ end
local function new(options)
local http = {
options = options;
- request = request;
+ request = function (self, u, ex, callback)
+ if callback ~= nil then
+ return request(self, u, ex, callback);
+ else
+ return promise.new(function (resolve, reject)
+ request(self, u, ex, function (body, code, a, b)
+ if code == 0 then
+ reject(errors.new(body, { request = a }));
+ else
+ resolve({ request = b, response = a });
+ end
+ end);
+ end);
+ end
+ end;
new = options and function (new_options)
local final_options = {};
for k, v in pairs(options) do final_options[k] = v; end
@@ -286,7 +302,7 @@ local function new(options)
end
local default_http = new({
- sslctx = { mode = "client", protocol = "sslv23", options = { "no_sslv2", "no_sslv3" } };
+ sslctx = { mode = "client", protocol = "sslv23", options = { "no_sslv2", "no_sslv3" }, alpn = "http/1.1" };
suppress_errors = true;
});
diff --git a/net/http/codes.lua b/net/http/codes.lua
index 8098b5c3..4327f151 100644
--- a/net/http/codes.lua
+++ b/net/http/codes.lua
@@ -82,5 +82,5 @@ local response_codes = {
-- [512-599] = "Unassigned";
};
-for k,v in pairs(response_codes) do response_codes[k] = k.." "..v; end
+for k,v in pairs(response_codes) do response_codes[k] = ("%03d %s"):format(k, v); end
return setmetatable(response_codes, { __index = function(_, k) return k.." Unassigned"; end })
diff --git a/net/http/files.lua b/net/http/files.lua
new file mode 100644
index 00000000..583f7514
--- /dev/null
+++ b/net/http/files.lua
@@ -0,0 +1,149 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+local server = require"net.http.server";
+local lfs = require "lfs";
+local new_cache = require "util.cache".new;
+local log = require "util.logger".init("net.http.files");
+
+local os_date = os.date;
+local open = io.open;
+local stat = lfs.attributes;
+local build_path = require"socket.url".build_path;
+local path_sep = package.config:sub(1,1);
+
+
+local forbidden_chars_pattern = "[/%z]";
+if package.config:sub(1,1) == "\\" then
+ forbidden_chars_pattern = "[/%z\001-\031\127\"*:<>?|]"
+end
+
+local urldecode = require "util.http".urldecode;
+local function sanitize_path(path) --> util.paths or util.http?
+ if not path then return end
+ local out = {};
+
+ local c = 0;
+ for component in path:gmatch("([^/]+)") do
+ component = urldecode(component);
+ if component:find(forbidden_chars_pattern) then
+ return nil;
+ elseif component == ".." then
+ if c <= 0 then
+ return nil;
+ end
+ out[c] = nil;
+ c = c - 1;
+ elseif component ~= "." then
+ c = c + 1;
+ out[c] = component;
+ end
+ end
+ if path:sub(-1,-1) == "/" then
+ out[c+1] = "";
+ end
+ return "/"..table.concat(out, "/");
+end
+
+local function serve(opts)
+ if type(opts) ~= "table" then -- assume path string
+ opts = { path = opts };
+ end
+ local mime_map = opts.mime_map or { html = "text/html" };
+ local cache = new_cache(opts.cache_size or 256);
+ local cache_max_file_size = tonumber(opts.cache_max_file_size) or 1024
+ -- luacheck: ignore 431
+ local base_path = opts.path;
+ local dir_indices = opts.index_files or { "index.html", "index.htm" };
+ local directory_index = opts.directory_index;
+ local function serve_file(event, path)
+ local request, response = event.request, event.response;
+ local sanitized_path = sanitize_path(path);
+ if path and not sanitized_path then
+ return 400;
+ end
+ path = sanitized_path;
+ local orig_path = sanitize_path(request.path);
+ local full_path = base_path .. (path or ""):gsub("/", path_sep);
+ local attr = stat(full_path:match("^.*[^\\/]")); -- Strip trailing path separator because Windows
+ if not attr then
+ return 404;
+ end
+
+ local request_headers, response_headers = request.headers, response.headers;
+
+ local last_modified = os_date('!%a, %d %b %Y %H:%M:%S GMT', attr.modification);
+ response_headers.last_modified = last_modified;
+
+ local etag = ('"%x-%x-%x"'):format(attr.change or 0, attr.size or 0, attr.modification or 0);
+ response_headers.etag = etag;
+
+ local if_none_match = request_headers.if_none_match
+ local if_modified_since = request_headers.if_modified_since;
+ if etag == if_none_match
+ or (not if_none_match and last_modified == if_modified_since) then
+ return 304;
+ end
+
+ local data = cache:get(orig_path);
+ if data and data.etag == etag then
+ response_headers.content_type = data.content_type;
+ data = data.data;
+ cache:set(orig_path, data);
+ elseif attr.mode == "directory" and path then
+ if full_path:sub(-1) ~= "/" then
+ local dir_path = { is_absolute = true, is_directory = true };
+ for dir in orig_path:gmatch("[^/]+") do dir_path[#dir_path+1]=dir; end
+ response_headers.location = build_path(dir_path);
+ return 301;
+ end
+ for i=1,#dir_indices do
+ if stat(full_path..dir_indices[i], "mode") == "file" then
+ return serve_file(event, path..dir_indices[i]);
+ end
+ end
+
+ if directory_index then
+ data = server._events.fire_event("directory-index", { path = request.path, full_path = full_path });
+ end
+ if not data then
+ return 403;
+ end
+ cache:set(orig_path, { data = data, content_type = mime_map.html; etag = etag; });
+ response_headers.content_type = mime_map.html;
+
+ else
+ local f, err = open(full_path, "rb");
+ if not f then
+ log("debug", "Could not open %s. Error was %s", full_path, err);
+ return 403;
+ end
+ local ext = full_path:match("%.([^./]+)$");
+ local content_type = ext and mime_map[ext];
+ response_headers.content_type = content_type;
+ if attr.size > cache_max_file_size then
+ response_headers.content_length = ("%d"):format(attr.size);
+ log("debug", "%d > cache_max_file_size", attr.size);
+ return response:send_file(f);
+ else
+ data = f:read("*a");
+ f:close();
+ end
+ cache:set(orig_path, { data = data; content_type = content_type; etag = etag });
+ end
+
+ return response:send(data);
+ end
+
+ return serve_file;
+end
+
+return {
+ serve = serve;
+}
+
diff --git a/net/http/parser.lua b/net/http/parser.lua
index 4e4ae9fb..b328bdfe 100644
--- a/net/http/parser.lua
+++ b/net/http/parser.lua
@@ -63,7 +63,8 @@ function httpstream.new(success_cb, error_cb, parser_type, options_cb)
if buftable then buf, buftable = t_concat(buf), false; end
local index = buf:find("\r\n\r\n", nil, true);
if not index then return; end -- not enough data
- local method, path, httpversion, status_code, reason_phrase;
+ -- FIXME was reason_phrase meant to be passed on somewhere?
+ local method, path, httpversion, status_code, reason_phrase; -- luacheck: ignore reason_phrase
local first_line;
local headers = {};
for line in buf:sub(1,index+1):gmatch("([^\r\n]+)\r\n") do -- parse request
@@ -92,6 +93,7 @@ function httpstream.new(success_cb, error_cb, parser_type, options_cb)
chunked = have_body and headers["transfer-encoding"] == "chunked";
len = tonumber(headers["content-length"]); -- TODO check for invalid len
if len and len > bodylimit then error = true; return error_cb("content-length-limit-exceeded"); end
+ -- TODO ask a callback whether to proceed in case of large requests or Expect: 100-continue
if client then
-- FIXME handle '100 Continue' response (by skipping it)
if not have_body then len = 0; end
diff --git a/net/http/server.lua b/net/http/server.lua
index 18704962..f563d330 100644
--- a/net/http/server.lua
+++ b/net/http/server.lua
@@ -13,6 +13,8 @@ local traceback = debug.traceback;
local tostring = tostring;
local cache = require "util.cache";
local codes = require "net.http.codes";
+local promise = require "util.promise";
+local errors = require "util.error";
local blocksize = 2^16;
local _M = {};
@@ -170,6 +172,48 @@ local headerfix = setmetatable({}, {
end
});
+local function handle_result(request, response, result)
+ if result == nil then
+ result = 404;
+ end
+
+ if result == true then
+ return;
+ end
+
+ local body;
+ local result_type = type(result);
+ if result_type == "number" then
+ response.status_code = result;
+ if result >= 400 then
+ body = events.fire_event("http-error", { request = request, response = response, code = result });
+ end
+ elseif result_type == "string" then
+ body = result;
+ elseif errors.is_err(result) then
+ response.status_code = result.code or 500;
+ body = events.fire_event("http-error", { request = request, response = response, code = result.code or 500, error = result });
+ elseif promise.is_promise(result) then
+ result:next(function (ret)
+ handle_result(request, response, ret);
+ end, function (err)
+ handle_result(request, response, err or 500);
+ end);
+ return true;
+ elseif result_type == "table" then
+ for k, v in pairs(result) do
+ if k ~= "headers" then
+ response[k] = v;
+ else
+ for header_name, header_value in pairs(v) do
+ response.headers[header_name] = header_value;
+ end
+ end
+ end
+ end
+ return response:send(body);
+end
+
function _M.hijack_response(response, listener) -- luacheck: ignore
error("TODO");
end
@@ -194,8 +238,11 @@ function handle_request(conn, request, finish_cb)
response_conn_header = httpversion == "1.1" and "close" or nil
end
+ local is_head_request = request.method == "HEAD";
+
local response = {
request = request;
+ is_head_request = is_head_request;
status_code = 200;
headers = { date = date_header, connection = response_conn_header };
persistent = persistent;
@@ -227,6 +274,11 @@ function handle_request(conn, request, finish_cb)
local payload = { request = request, response = response };
log("debug", "Firing event: %s", global_event);
local result = events.fire_event(global_event, payload);
+ if result == nil and is_head_request then
+ local global_head_event = "GET "..request.path:match("[^?]*");
+ log("debug", "Firing event: %s", global_head_event);
+ result = events.fire_event(global_head_event, payload);
+ end
if result == nil then
if not hosts[host] then
if hosts[default_host] then
@@ -247,40 +299,17 @@ function handle_request(conn, request, finish_cb)
local host_event = request.method.." "..host..request.path:match("[^?]*");
log("debug", "Firing event: %s", host_event);
result = events.fire_event(host_event, payload);
- end
- if result ~= nil then
- if result ~= true then
- local body;
- local result_type = type(result);
- if result_type == "number" then
- response.status_code = result;
- if result >= 400 then
- payload.code = result;
- body = events.fire_event("http-error", payload);
- end
- elseif result_type == "string" then
- body = result;
- elseif result_type == "table" then
- for k, v in pairs(result) do
- if k ~= "headers" then
- response[k] = v;
- else
- for header_name, header_value in pairs(v) do
- response.headers[header_name] = header_value;
- end
- end
- end
- end
- response:send(body);
+
+ if result == nil and is_head_request then
+ local host_head_event = "GET "..host..request.path:match("[^?]*");
+ log("debug", "Firing event: %s", host_head_event);
+ result = events.fire_event(host_head_event, payload);
end
- return;
end
- -- if handler not called, return 404
- response.status_code = 404;
- payload.code = 404;
- response:send(events.fire_event("http-error", payload));
+ return handle_result(request, response, result);
end
+
local function prepare_header(response)
local status_line = "HTTP/"..response.request.httpversion.." "..(response.status or codes[response.status_code]);
local headers = response.headers;
@@ -292,16 +321,29 @@ local function prepare_header(response)
return output;
end
_M.prepare_header = prepare_header;
+function _M.send_head_response(response)
+ if response.finished then return; end
+ local output = prepare_header(response);
+ response.conn:write(t_concat(output));
+ response:done();
+end
function _M.send_response(response, body)
if response.finished then return; end
body = body or response.body or "";
- response.headers.content_length = #body;
+ response.headers.content_length = ("%d"):format(#body);
+ if response.is_head_request then
+ return _M.send_head_response(response)
+ end
local output = prepare_header(response);
t_insert(output, body);
response.conn:write(t_concat(output));
response:done();
end
function _M.send_file(response, f)
+ if response.is_head_request then
+ if f.close then f:close(); end
+ return _M.send_head_response(response);
+ end
if response.finished then return; end
local chunked = not response.headers.content_length;
if chunked then response.headers.transfer_encoding = "chunked"; end
diff --git a/net/resolvers/basic.lua b/net/resolvers/basic.lua
index 08c71ef5..6458c638 100644
--- a/net/resolvers/basic.lua
+++ b/net/resolvers/basic.lua
@@ -2,10 +2,13 @@ local adns = require "net.adns";
local inet_pton = require "util.net".pton;
local inet_ntop = require "util.net".ntop;
local idna_to_ascii = require "util.encodings".idna.to_ascii;
+local unpack = table.unpack or unpack; -- luacheck: ignore 113
local methods = {};
local resolver_mt = { __index = methods };
+-- FIXME RFC 6724
+
-- Find the next target to connect to, and
-- pass it to cb()
function methods:next(cb)
@@ -36,23 +39,32 @@ function methods:next(cb)
-- Resolve DNS to target list
local dns_resolver = adns.resolver();
- dns_resolver:lookup(function (answer)
- if answer then
- for _, record in ipairs(answer) do
- table.insert(targets, { self.conn_type.."4", record.a, self.port, self.extra });
+
+ if not self.extra or self.extra.use_ipv4 ~= false then
+ dns_resolver:lookup(function (answer)
+ if answer then
+ for _, record in ipairs(answer) do
+ table.insert(targets, { self.conn_type.."4", record.a, self.port, self.extra });
+ end
end
- end
+ ready();
+ end, self.hostname, "A", "IN");
+ else
ready();
- end, self.hostname, "A", "IN");
+ end
- dns_resolver:lookup(function (answer)
- if answer then
- for _, record in ipairs(answer) do
- table.insert(targets, { self.conn_type.."6", record.aaaa, self.port, self.extra });
+ if not self.extra or self.extra.use_ipv6 ~= false then
+ dns_resolver:lookup(function (answer)
+ if answer then
+ for _, record in ipairs(answer) do
+ table.insert(targets, { self.conn_type.."6", record.aaaa, self.port, self.extra });
+ end
end
- end
+ ready();
+ end, self.hostname, "AAAA", "IN");
+ else
ready();
- end, self.hostname, "AAAA", "IN");
+ end
end
local function new(hostname, port, conn_type, extra)
diff --git a/net/resolvers/manual.lua b/net/resolvers/manual.lua
index c0d4e5d5..dbc40256 100644
--- a/net/resolvers/manual.lua
+++ b/net/resolvers/manual.lua
@@ -1,5 +1,6 @@
local methods = {};
local resolver_mt = { __index = methods };
+local unpack = table.unpack or unpack; -- luacheck: ignore 113
-- Find the next target to connect to, and
-- pass it to cb()
diff --git a/net/resolvers/service.lua b/net/resolvers/service.lua
index 34f14cba..d74adf06 100644
--- a/net/resolvers/service.lua
+++ b/net/resolvers/service.lua
@@ -1,6 +1,8 @@
local adns = require "net.adns";
local basic = require "net.resolvers.basic";
+local inet_pton = require "util.net".pton;
local idna_to_ascii = require "util.encodings".idna.to_ascii;
+local unpack = table.unpack or unpack; -- luacheck: ignore 113
local methods = {};
local resolver_mt = { __index = methods };
@@ -9,14 +11,17 @@ local resolver_mt = { __index = methods };
-- pass it to cb()
function methods:next(cb)
if self.targets then
- if #self.targets == 0 then
- cb(nil);
- return;
+ if not self.resolver then
+ if #self.targets == 0 then
+ cb(nil);
+ return;
+ end
+ local next_target = table.remove(self.targets, 1);
+ self.resolver = basic.new(unpack(next_target, 1, 4));
end
- local next_target = table.remove(self.targets, 1);
- self.resolver = basic.new(unpack(next_target, 1, 4));
self.resolver:next(function (...)
if ... == nil then
+ self.resolver = nil;
self:next(cb);
else
cb(...);
@@ -39,7 +44,11 @@ function methods:next(cb)
-- Resolve DNS to target list
local dns_resolver = adns.resolver();
- dns_resolver:lookup(function (answer)
+ dns_resolver:lookup(function (answer, err)
+ if not answer and not err then
+ -- net.adns returns nil if there are zero records or nxdomain
+ answer = {};
+ end
if answer then
if #answer == 0 then
if self.extra and self.extra.default_port then
@@ -64,6 +73,14 @@ function methods:next(cb)
end
local function new(hostname, service, conn_type, extra)
+ local is_ip = inet_pton(hostname);
+ if not is_ip and hostname:sub(1,1) == '[' then
+ is_ip = inet_pton(hostname:sub(2,-2));
+ end
+ if is_ip and extra and extra.default_port then
+ return basic.new(hostname, extra.default_port, conn_type, extra);
+ end
+
return setmetatable({
hostname = idna_to_ascii(hostname);
service = service;
diff --git a/net/server.lua b/net/server.lua
index abbb421d..f5666594 100644
--- a/net/server.lua
+++ b/net/server.lua
@@ -13,7 +13,11 @@ if not (prosody and prosody.config_loaded) then
end
local log = require "util.logger".init("net.server");
-local server_type = require "core.configmanager".get("*", "network_backend") or "select";
+
+local have_util_poll = pcall(require, "util.poll");
+local default_backend = have_util_poll and "epoll" or "select";
+
+local server_type = require "core.configmanager".get("*", "network_backend") or default_backend;
if require "core.configmanager".get("*", "use_libevent") then
server_type = "event";
diff --git a/net/server_epoll.lua b/net/server_epoll.lua
index 2182d56a..f767db78 100644
--- a/net/server_epoll.lua
+++ b/net/server_epoll.lua
@@ -9,20 +9,24 @@
local t_insert = table.insert;
local t_concat = table.concat;
local setmetatable = setmetatable;
-local tostring = tostring;
local pcall = pcall;
local type = type;
local next = next;
local pairs = pairs;
-local log = require "util.logger".init("server_epoll");
+local traceback = debug.traceback;
+local logger = require "util.logger";
+local log = logger.init("server_epoll");
local socket = require "socket";
local luasec = require "ssl";
-local gettime = require "util.time".now;
+local realtime = require "util.time".now;
+local monotonic = require "util.time".monotonic;
local indexedbheap = require "util.indexedbheap";
local createtable = require "util.table".create;
local inet = require "util.net";
local inet_pton = inet.pton;
local _SOCKETINVALID = socket._SOCKETINVALID or -1;
+local new_id = require "util.id".medium;
+local xpcall = require "util.xpcall".xpcall;
local poller = require "util.poll"
local EEXIST = poller.EEXIST;
@@ -38,7 +42,10 @@ local default_config = { __index = {
read_timeout = 14 * 60;
-- How long to wait for a socket to become writable after queuing data to send
- send_timeout = 60;
+ send_timeout = 180;
+
+ -- How long to wait for a socket to become writable after creation
+ connect_timeout = 20;
-- Some number possibly influencing how many pending connections can be accepted
tcp_backlog = 128;
@@ -58,6 +65,20 @@ local default_config = { __index = {
-- Maximum and minimum amount of time to sleep waiting for events (adjusted for pending timers)
max_wait = 86400;
min_wait = 1e-06;
+
+ -- Enable extra noisy debug logging
+ -- TODO disable once considered stable
+ verbose = true;
+
+ -- EXPERIMENTAL
+ -- Whether to kill connections in case of callback errors.
+ fatal_errors = false;
+
+ -- Or disable protection (like server_select) for potential performance gains
+ protect_listeners = true;
+
+ -- Attempt writes instantly
+ opportunistic_writes = false;
}};
local cfg = default_config.__index;
@@ -68,48 +89,50 @@ local fds = createtable(10, 0); -- FD -> conn
local timers = indexedbheap.create();
local function noop() end
-local function closetimer(t)
- t[1] = 0;
- t[2] = noop;
- timers:remove(t.id);
+local function closetimer(id)
+ timers:remove(id);
end
-local function reschedule(t, time)
- t[1] = time;
- timers:reprioritize(t.id, time);
-end
-
--- Add absolute timer
-local function at(time, f)
- local timer = { time, f, close = closetimer, reschedule = reschedule, id = nil };
- timer.id = timers:insert(timer, time);
- return timer;
+local function reschedule(id, time)
+ time = monotonic() + time;
+ timers:reprioritize(id, time);
end
-- Add relative timer
-local function addtimer(timeout, f)
- return at(gettime() + timeout, f);
+local function addtimer(timeout, f, param)
+ local time = monotonic() + timeout;
+ if param ~= nil then
+ local timer_callback = f
+ function f(current_time, timer_id)
+ local t = timer_callback(current_time, timer_id, param)
+ return t;
+ end
+ end
+ local id = timers:insert(f, time);
+ return id;
end
-- Run callbacks of expired timers
-- Return time until next timeout
local function runtimers(next_delay, min_wait)
-- Any timers at all?
- local now = gettime();
+ local elapsed = monotonic();
+ local now = realtime();
local peek = timers:peek();
while peek do
- if peek > now then
- next_delay = peek - now;
+ if peek > elapsed then
+ next_delay = peek - elapsed;
break;
end
local _, timer, id = timers:pop();
- local ok, ret = pcall(timer[2], now);
+ local ok, ret = xpcall(timer, traceback, now, id);
if ok and type(ret) == "number" then
- local next_time = now+ret;
- timer[1] = next_time;
+ local next_time = elapsed+ret;
timers:insert(timer, next_time);
+ elseif not ok then
+ log("error", "Error in timer: %s", ret);
end
peek = timers:peek();
@@ -138,6 +161,22 @@ function interface_mt:__tostring()
return ("FD %d"):format(self:getfd());
end
+interface.log = log;
+function interface:debug(msg, ...) --luacheck: ignore 212/self
+ self.log("debug", msg, ...);
+end
+
+interface.noise = interface.debug;
+function interface:noise(msg, ...) --luacheck: ignore 212/self
+ if cfg.verbose then
+ return self:debug(msg, ...);
+ end
+end
+
+function interface:error(msg, ...) --luacheck: ignore 212/self
+ self.log("error", msg, ...);
+end
+
-- Replace the listener and tell the old one
function interface:setlistener(listeners, data)
self:on("detach");
@@ -148,21 +187,36 @@ end
-- Call a listener callback
function interface:on(what, ...)
if not self.listeners then
- log("error", "%s has no listeners", self);
+ self:error("Interface is missing listener callbacks");
return;
end
local listener = self.listeners["on"..what];
if not listener then
- -- log("debug", "Missing listener 'on%s'", what); -- uncomment for development and debugging
+ self:noise("Missing listener 'on%s'", what); -- uncomment for development and debugging
return;
end
- local ok, err = pcall(listener, self, ...);
+ if not cfg.protect_listeners then
+ return listener(self, ...);
+ end
+ local onerror = self.listeners.onerror or traceback;
+ local ok, err = xpcall(listener, onerror, self, ...);
if not ok then
- log("error", "Error calling on%s: %s", what, err);
+ if cfg.fatal_errors then
+ self:error("Closing due to error calling on%s: %s", what, err);
+ self:destroy();
+ else
+ self:debug("Error calling on%s: %s", what, err);
+ end
+ return nil, err;
end
return err;
end
+-- Allow this one to be overridden
+function interface:onincoming(...)
+ return self:on("incoming", ...);
+end
+
-- Return the file descriptor number
function interface:getfd()
if self.conn then
@@ -219,19 +273,21 @@ end
function interface:setreadtimeout(t)
if t == false then
if self._readtimeout then
- self._readtimeout:close();
+ closetimer(self._readtimeout);
self._readtimeout = nil;
end
return
end
t = t or cfg.read_timeout;
if self._readtimeout then
- self._readtimeout:reschedule(gettime() + t);
+ reschedule(self._readtimeout, t);
else
self._readtimeout = addtimer(t, function ()
if self:on("readtimeout") then
+ self:noise("Read timeout handled");
return cfg.read_timeout;
else
+ self:debug("Read timeout not handled, disconnecting");
self:on("disconnect", "read timeout");
self:destroy();
end
@@ -243,17 +299,18 @@ end
function interface:setwritetimeout(t)
if t == false then
if self._writetimeout then
- self._writetimeout:close();
+ closetimer(self._writetimeout);
self._writetimeout = nil;
end
return
end
t = t or cfg.send_timeout;
if self._writetimeout then
- self._writetimeout:reschedule(gettime() + t);
+ reschedule(self._writetimeout, t);
else
self._writetimeout = addtimer(t, function ()
- self:on("disconnect", "write timeout");
+ self:noise("Write timeout");
+ self:on("disconnect", self._connected and "write timeout" or "connection timeout");
self:destroy();
end);
end
@@ -269,15 +326,15 @@ function interface:add(r, w)
local ok, err, errno = poll:add(fd, r, w);
if not ok then
if errno == EEXIST then
- log("debug", "%s already registered!", self);
+ self:debug("FD already registered in poller! (EEXIST)");
return self:set(r, w); -- So try to change its flags
end
- log("error", "Could not register %s: %s(%d)", self, err, errno);
+ self:debug("Could not register in poller: %s(%d)", err, errno);
return ok, err;
end
self._wantread, self._wantwrite = r, w;
fds[fd] = self;
- log("debug", "Watching %s", self);
+ self:noise("Registered in poller");
return true;
end
@@ -290,7 +347,7 @@ function interface:set(r, w)
if w == nil then w = self._wantwrite; end
local ok, err, errno = poll:set(fd, r, w);
if not ok then
- log("error", "Could not update poller state %s: %s(%d)", self, err, errno);
+ self:debug("Could not update poller state: %s(%d)", err, errno);
return ok, err;
end
self._wantread, self._wantwrite = r, w;
@@ -307,12 +364,12 @@ function interface:del()
end
local ok, err, errno = poll:del(fd);
if not ok and errno ~= ENOENT then
- log("error", "Could not unregister %s: %s(%d)", self, err, errno);
+ self:debug("Could not unregister: %s(%d)", err, errno);
return ok, err;
end
self._wantread, self._wantwrite = nil, nil;
fds[fd] = nil;
- log("debug", "Unwatched %s", self);
+ self:noise("Unregistered from poller");
return true;
end
@@ -334,7 +391,7 @@ function interface:onreadable()
local data, err, partial = self.conn:receive(self.read_size or cfg.read_size);
if data then
self:onconnect();
- self:on("incoming", data);
+ self:onincoming(data);
else
if err == "wantread" then
self:set(true, nil);
@@ -345,7 +402,7 @@ function interface:onreadable()
end
if partial and partial ~= "" then
self:onconnect();
- self:on("incoming", partial, err);
+ self:onincoming(partial, err);
end
if err ~= "timeout" then
self:on("disconnect", err);
@@ -354,6 +411,14 @@ function interface:onreadable()
end
end
if not self.conn then return; end
+ if self._limit and (data or partial) then
+ local cost = self._limit * #(data or partial);
+ if cost > cfg.min_wait then
+ self:setreadtimeout(false);
+ self:pausefor(cost);
+ return;
+ end
+ end
if self._wantread and self.conn:dirty() then
self:setreadtimeout(false);
self:pausefor(cfg.read_retry_delay);
@@ -367,7 +432,7 @@ function interface:onwritable()
self:onconnect();
if not self.conn then return; end -- could have been closed in onconnect
local buffer = self.writebuffer;
- local data = t_concat(buffer);
+ local data = #buffer == 1 and buffer[1] or t_concat(buffer);
local ok, err, partial = self.conn:send(data);
if ok then
self:set(nil, false);
@@ -378,10 +443,12 @@ function interface:onwritable()
self:ondrain(); -- Be aware of writes in ondrain
return;
elseif partial then
+ self:debug("Sent %d out of %d buffered bytes", partial, #data);
buffer[1] = data:sub(partial+1);
for i = #buffer, 2, -1 do
buffer[i] = nil;
end
+ self:set(nil, true);
self:setwritetimeout();
end
if err == "wantwrite" or err == "timeout" then
@@ -407,8 +474,14 @@ function interface:write(data)
else
self.writebuffer = { data };
end
- self:setwritetimeout();
- self:set(nil, true);
+ if not self._write_lock then
+ if cfg.opportunistic_writes then
+ self:onwritable();
+ return #data;
+ end
+ self:setwritetimeout();
+ self:set(nil, true);
+ end
return #data;
end
interface.send = interface.write;
@@ -418,10 +491,10 @@ function interface:close()
if self.writebuffer and self.writebuffer[1] then
self:set(false, true); -- Flush final buffer contents
self.write, self.send = noop, noop; -- No more writing
- log("debug", "Close %s after writing", self);
+ self:debug("Close after writing remaining buffered data");
self.ondrain = interface.close;
else
- log("debug", "Close %s now", self);
+ self:debug("Closing now");
self.write, self.send = noop, noop;
self.close = noop;
self:on("disconnect");
@@ -450,31 +523,32 @@ function interface:starttls(tls_ctx)
if tls_ctx then self.tls_ctx = tls_ctx; end
self.starttls = false;
if self.writebuffer and self.writebuffer[1] then
- log("debug", "Start TLS on %s after write", self);
+ self:debug("Start TLS after write");
self.ondrain = interface.starttls;
self:set(nil, true); -- make sure wantwrite is set
else
if self.ondrain == interface.starttls then
self.ondrain = nil;
end
- self.onwritable = interface.tlshandskake;
- self.onreadable = interface.tlshandskake;
+ self.onwritable = interface.tlshandshake;
+ self.onreadable = interface.tlshandshake;
self:set(true, true);
- log("debug", "Prepare to start TLS on %s", self);
+ self:debug("Prepared to start TLS");
end
end
-function interface:tlshandskake()
+function interface:tlshandshake()
self:setwritetimeout(false);
self:setreadtimeout(false);
if not self._tls then
self._tls = true;
- log("debug", "Start TLS on %s now", self);
+ self:debug("Starting TLS now");
self:del();
+ self:updatenames(); -- Can't getpeer/sockname after wrap()
local ok, conn, err = pcall(luasec.wrap, self.conn, self.tls_ctx);
if not ok then
conn, err = ok, conn;
- log("error", "Failed to initialize TLS: %s", err);
+ self:debug("Failed to initialize TLS: %s", err);
end
if not conn then
self:on("disconnect", err);
@@ -483,48 +557,71 @@ function interface:tlshandskake()
end
conn:settimeout(0);
self.conn = conn;
+ if conn.sni then
+ if self.servername then
+ conn:sni(self.servername);
+ elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then
+ conn:sni(self._server.hosts, true);
+ end
+ end
self:on("starttls");
self.ondrain = nil;
- self.onwritable = interface.tlshandskake;
- self.onreadable = interface.tlshandskake;
+ self.onwritable = interface.tlshandshake;
+ self.onreadable = interface.tlshandshake;
return self:init();
end
+ self:noise("Continuing TLS handshake");
local ok, err = self.conn:dohandshake();
if ok then
- log("debug", "TLS handshake on %s complete", self);
+ local info = self.conn.info and self.conn:info();
+ if type(info) == "table" then
+ self:debug("TLS handshake complete (%s with %s)", info.protocol, info.cipher);
+ else
+ self:debug("TLS handshake complete");
+ end
self.onwritable = nil;
self.onreadable = nil;
self:on("status", "ssl-handshake-complete");
self:setwritetimeout();
self:set(true, true);
elseif err == "wantread" then
- log("debug", "TLS handshake on %s to wait until readable", self);
+ self:noise("TLS handshake to wait until readable");
self:set(true, false);
self:setreadtimeout(cfg.ssl_handshake_timeout);
elseif err == "wantwrite" then
- log("debug", "TLS handshake on %s to wait until writable", self);
+ self:noise("TLS handshake to wait until writable");
self:set(false, true);
self:setwritetimeout(cfg.ssl_handshake_timeout);
else
- log("debug", "TLS handshake error on %s: %s", self, err);
+ self:debug("TLS handshake error: %s", err);
self:on("disconnect", err);
self:destroy();
end
end
-local function wrapsocket(client, server, read_size, listeners, tls_ctx) -- luasocket object -> interface object
+local function wrapsocket(client, server, read_size, listeners, tls_ctx, extra) -- luasocket object -> interface object
client:settimeout(0);
+ local conn_id = ("conn%s"):format(new_id());
local conn = setmetatable({
conn = client;
_server = server;
- created = gettime();
+ created = realtime();
listeners = listeners;
read_size = read_size or (server and server.read_size);
writebuffer = {};
tls_ctx = tls_ctx or (server and server.tls_ctx);
tls_direct = server and server.tls_direct;
+ id = conn_id;
+ log = logger.init(conn_id);
+ extra = extra;
}, interface_mt);
+ if extra then
+ if extra.servername then
+ conn.servername = extra.servername;
+ end
+ end
+
conn:updatenames();
return conn;
end
@@ -532,12 +629,12 @@ end
function interface:updatenames()
local conn = self.conn;
local ok, peername, peerport = pcall(conn.getpeername, conn);
- if ok then
- self.peername, self.peerport = peername, peerport;
+ if ok and peername then
+ self.peername, self.peerport = peername, peerport or 0;
end
local ok, sockname, sockport = pcall(conn.getsockname, conn);
- if ok then
- self.sockname, self.sockport = sockname, sockport;
+ if ok and sockname then
+ self.sockname, self.sockport = sockname, sockport or 0;
end
end
@@ -546,79 +643,132 @@ end
function interface:onacceptable()
local conn, err = self.conn:accept();
if not conn then
- log("debug", "Error accepting new client: %s, server will be paused for %ds", err, cfg.accept_retry_interval);
+ self:debug("Error accepting new client: %s, server will be paused for %ds", err, cfg.accept_retry_interval);
self:pausefor(cfg.accept_retry_interval);
return;
end
local client = wrapsocket(conn, self, nil, self.listeners);
- log("debug", "New connection %s", tostring(client));
+ client:debug("New connection %s on server %s", client, self);
client:init();
if self.tls_direct then
client:starttls(self.tls_ctx);
+ else
+ client:onconnect();
end
end
-- Initialization
function interface:init()
- self:setwritetimeout();
+ self:setwritetimeout(cfg.connect_timeout);
return self:add(true, true);
end
function interface:pause()
+ self:noise("Pause reading");
return self:set(false);
end
function interface:resume()
+ self:noise("Resume reading");
return self:set(true);
end
-- Pause connection for some time
function interface:pausefor(t)
+ self:noise("Pause for %fs", t);
if self._pausefor then
- self._pausefor:close();
+ closetimer(self._pausefor);
+ self._pausefor = nil;
end
if t == false then return; end
self:set(false);
self._pausefor = addtimer(t, function ()
self._pausefor = nil;
self:set(true);
+ self:noise("Resuming after pause, connection is %s", not self.conn and "missing" or self.conn:dirty() and "dirty" or "clean");
if self.conn and self.conn:dirty() then
self:onreadable();
end
end);
end
+function interface:setlimit(Bps)
+ if Bps > 0 then
+ self._limit = 1/Bps;
+ else
+ self._limit = nil;
+ end
+end
+
+function interface:pause_writes()
+ if self._write_lock then
+ return
+ end
+ self:noise("Pause writes");
+ self._write_lock = true;
+ self:setwritetimeout(false);
+ self:set(nil, false);
+end
+
+function interface:resume_writes()
+ if not self._write_lock then
+ return
+ end
+ self:noise("Resume writes");
+ self._write_lock = nil;
+ if self.writebuffer[1] then
+ self:setwritetimeout();
+ self:set(nil, true);
+ end
+end
+
-- Connected!
function interface:onconnect()
- if self.conn and not self.peername and self.conn.getpeername then
- self.peername, self.peerport = self.conn:getpeername();
- end
+ self._connected = true;
+ self:updatenames();
+ self:debug("Connected (%s)", self);
self.onconnect = noop;
self:on("connect");
end
-local function addserver(addr, port, listeners, read_size, tls_ctx)
- local conn, err = socket.bind(addr, port, cfg.tcp_backlog);
- if not conn then return conn, err; end
- conn:settimeout(0);
+local function wrapserver(conn, addr, port, listeners, config)
local server = setmetatable({
conn = conn;
- created = gettime();
+ created = realtime();
listeners = listeners;
- read_size = read_size;
+ read_size = config and config.read_size;
onreadable = interface.onacceptable;
- tls_ctx = tls_ctx;
- tls_direct = tls_ctx and true or false;
+ tls_ctx = config and config.tls_ctx;
+ tls_direct = config and config.tls_direct;
+ hosts = config and config.sni_hosts;
sockname = addr;
sockport = port;
+ log = logger.init(("serv%s"):format(new_id()));
}, interface_mt);
+ server:debug("Server %s created", server);
server:add(true, false);
return server;
end
+local function listen(addr, port, listeners, config)
+ local conn, err = socket.bind(addr, port, cfg.tcp_backlog);
+ if not conn then return conn, err; end
+ conn:settimeout(0);
+ return wrapserver(conn, addr, port, listeners, config);
+end
+
-- COMPAT
-local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx)
- local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx);
+local function addserver(addr, port, listeners, read_size, tls_ctx)
+ return listen(addr, port, listeners, {
+ read_size = read_size;
+ tls_ctx = tls_ctx;
+ tls_direct = tls_ctx and true or false;
+ });
+end
+
+-- COMPAT
+local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx, extra)
+ local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra);
if not client.peername then
client.peername, client.peerport = addr, port;
end
@@ -631,7 +781,7 @@ local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx)
end
-- New outgoing TCP connection
-local function addclient(addr, port, listeners, read_size, tls_ctx, typ)
+local function addclient(addr, port, listeners, read_size, tls_ctx, typ, extra)
local create;
if not typ then
local n = inet_pton(addr);
@@ -649,13 +799,19 @@ local function addclient(addr, port, listeners, read_size, tls_ctx, typ)
return nil, "invalid socket type";
end
local conn, err = create();
+ if not conn then return conn, err; end
local ok, err = conn:settimeout(0);
if not ok then return ok, err; end
local ok, err = conn:setpeername(addr, port);
if not ok and err ~= "timeout" then return ok, err; end
- local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx)
+ local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra)
local ok, err = client:init();
+ if not client.peername then
+ -- otherwise not set until connected
+ client.peername, client.peerport = addr, port;
+ end
if not ok then return ok, err; end
+ client:debug("Client %s created", client);
if tls_ctx then
client:starttls(tls_ctx);
end
@@ -677,23 +833,23 @@ local function watchfd(fd, onreadable, onwritable)
end;
-- Otherwise it'll need to be something LuaSocket-compatible
end
+ conn.id = new_id();
+ conn.log = logger.init(("fdwatch%s"):format(conn.id));
conn:add(onreadable, onwritable);
return conn;
end;
-- Dump all data from one connection into another
-local function link(from, to)
- from.listeners = setmetatable({
- onincoming = function (_, data)
- from:pause();
- to:write(data);
- end,
- }, {__index=from.listeners});
- to.listeners = setmetatable({
- ondrain = function ()
- from:resume();
- end,
- }, {__index=to.listeners});
+local function link(from, to, read_size)
+ from:debug("Linking to %s", to.id);
+ function from:onincoming(data)
+ self:pause();
+ to:write(data);
+ end
+ function to:ondrain() -- luacheck: ignore 212/self
+ from:resume();
+ end
+ from:set_mode(read_size);
from:set(true, nil);
to:set(nil, true);
end
@@ -752,11 +908,21 @@ return {
addserver = addserver;
addclient = addclient;
add_task = addtimer;
- at = at;
+ timer = {
+ -- API-compatible with util.timer
+ add_task = addtimer;
+ stop = closetimer;
+ reschedule = reschedule;
+ to_absolute_time = function (t)
+ return t-monotonic()+realtime();
+ end;
+ };
+ listen = listen;
loop = loop;
closeall = closeall;
setquitting = setquitting;
wrapclient = wrapclient;
+ wrapserver = wrapserver;
watchfd = watchfd;
link = link;
set_config = function (newconfig)
@@ -766,6 +932,7 @@ return {
-- libevent emulation
event = { EV_READ = "r", EV_WRITE = "w", EV_READWRITE = "rw", EV_LEAVE = -1 };
addevent = function (fd, mode, callback)
+ log("warn", "Using deprecated libevent emulation, please update code to use watchfd API instead");
local function onevent(self)
local ret = self:callback();
if ret == -1 then
@@ -785,6 +952,8 @@ return {
fds[fd] = nil;
end;
}, interface_mt);
+ conn.id = conn:getfd();
+ conn.log = logger.init(("fdwatch%d"):format(conn.id));
local ok, err = conn:add(mode == "r" or mode == "rw", mode == "w" or mode == "rw");
if not ok then return ok, err; end
return conn;
diff --git a/net/server_event.lua b/net/server_event.lua
index 11bd6a29..f7e1f448 100644
--- a/net/server_event.lua
+++ b/net/server_event.lua
@@ -164,6 +164,15 @@ function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed
debug( "fatal error while ssl wrapping:", err )
return false
end
+
+ if self.conn.sni then
+ if self.servername then
+ self.conn:sni(self.servername);
+ elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then
+ self.conn:sni(self._server.hosts, true);
+ end
+ end
+
self.conn:settimeout( 0 ) -- set non blocking
local handshakecallback = coroutine_wrap(function( event )
local _, err
@@ -253,6 +262,7 @@ end
--TODO: Deprecate
function interface_mt:lock_read(switch)
+ log("warn", ":lock_read is deprecated, use :pasue() and :resume()");
if switch then
return self:pause();
else
@@ -272,6 +282,19 @@ function interface_mt:resume()
end
end
+function interface_mt:pause_writes()
+ return self:_lock(self.nointerface, self.noreading, true);
+end
+
+function interface_mt:resume_writes()
+ self:_lock(self.nointerface, self.noreading, false);
+ if self.writecallback and not self.eventwrite then
+ self.eventwrite = addevent( base, self.conn, EV_WRITE, self.writecallback, cfg.WRITE_TIMEOUT ); -- register callback
+ return true;
+ end
+end
+
+
function interface_mt:counter(c)
if c then
self._connections = self._connections + c
@@ -281,7 +304,7 @@ end
-- Public methods
function interface_mt:write(data)
- if self.nowriting then return nil, "locked" end
+ if self.nointerface then return nil, "locked"; end
--vdebug( "try to send data to client, id/data:", self.id, data )
data = tostring( data )
local len = #data
@@ -293,7 +316,7 @@ function interface_mt:write(data)
end
t_insert(self.writebuffer, data) -- new buffer
self.writebufferlen = total
- if not self.eventwrite then -- register new write event
+ if not self.eventwrite and not self.nowriting then -- register new write event
--vdebug( "register new write event" )
self.eventwrite = addevent( base, self.conn, EV_WRITE, self.writecallback, cfg.WRITE_TIMEOUT )
end
@@ -440,10 +463,6 @@ end
function interface_mt:ontimeout()
end
function interface_mt:onreadtimeout()
- self.fatalerror = "timeout during receiving"
- debug( "connection failed:", self.fatalerror )
- self:_close()
- self.eventread = nil
end
function interface_mt:ondrain()
end
@@ -456,7 +475,7 @@ end
-- End of client interface methods
-local function handleclient( client, ip, port, server, pattern, listener, sslctx ) -- creates an client interface
+local function handleclient( client, ip, port, server, pattern, listener, sslctx, extra ) -- creates an client interface
--vdebug("creating client interfacce...")
local interface = {
type = "client";
@@ -492,6 +511,8 @@ local function handleclient( client, ip, port, server, pattern, listener, sslctx
_serverport = (server and server:port() or nil),
_sslctx = sslctx; -- parameters
_usingssl = false; -- client is using ssl;
+ extra = extra;
+ servername = extra and extra.servername;
}
if not has_luasec then interface.starttls = false; end
interface.id = tostring(interface):match("%x+$");
@@ -635,7 +656,7 @@ local function handleclient( client, ip, port, server, pattern, listener, sslctx
return interface
end
-local function handleserver( server, addr, port, pattern, listener, sslctx ) -- creates an server interface
+local function handleserver( server, addr, port, pattern, listener, sslctx, startssl ) -- creates a server interface
debug "creating server interface..."
local interface = {
_connections = 0;
@@ -651,6 +672,7 @@ local function handleserver( server, addr, port, pattern, listener, sslctx ) --
_ip = addr, _port = port, _pattern = pattern,
_sslctx = sslctx;
+ hosts = {};
}
interface.id = tostring(interface):match("%x+$");
interface.readcallback = function( event ) -- server handler, called on incoming connections
@@ -670,6 +692,7 @@ local function handleserver( server, addr, port, pattern, listener, sslctx ) --
end
end
--vdebug("max connection check ok, accepting...")
+ -- luacheck: ignore 231/err
local client, err = server:accept() -- try to accept; TODO: check err
while client do
if interface._connections >= cfg.MAX_CONNECTIONS then
@@ -681,7 +704,7 @@ local function handleserver( server, addr, port, pattern, listener, sslctx ) --
interface._connections = interface._connections + 1 -- increase connection count
local clientinterface = handleclient( client, client_ip, client_port, interface, pattern, listener, sslctx )
--vdebug( "client id:", clientinterface, "startssl:", startssl )
- if has_luasec and sslctx then
+ if has_luasec and startssl then
clientinterface:starttls(sslctx, true)
else
clientinterface:_start_session( true )
@@ -700,9 +723,9 @@ local function handleserver( server, addr, port, pattern, listener, sslctx ) --
return interface
end
-local function addserver( addr, port, listener, pattern, sslctx, startssl ) -- TODO: check arguments
- --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslctx or "nil", startssl or "nil")
- if sslctx and not has_luasec then
+local function listen(addr, port, listener, config)
+ config = config or {}
+ if config.sslctx and not has_luasec then
debug "fatal error: luasec not found"
return nil, "luasec not found"
end
@@ -711,19 +734,28 @@ local function addserver( addr, port, listener, pattern, sslctx, startssl ) --
debug( "creating server socket on "..addr.." port "..port.." failed:", err )
return nil, err
end
- local interface = handleserver( server, addr, port, pattern, listener, sslctx, startssl ) -- new server handler
+ local interface = handleserver( server, addr, port, config.read_size, listener, config.tls_ctx, config.tls_direct) -- new server handler
debug( "new server created with id:", tostring(interface))
return interface
end
-local function wrapclient( client, ip, port, listeners, pattern, sslctx )
- local interface = handleclient( client, ip, port, nil, pattern, listeners, sslctx )
+local function addserver( addr, port, listener, pattern, sslctx ) -- TODO: check arguments
+ --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslctx or "nil", startssl or "nil")
+ return listen( addr, port, listener, {
+ read_size = pattern,
+ tls_ctx = sslctx,
+ tls_direct = not not sslctx,
+ });
+end
+
+local function wrapclient( client, ip, port, listeners, pattern, sslctx, extra )
+ local interface = handleclient( client, ip, port, nil, pattern, listeners, sslctx, extra )
interface:_start_connection(sslctx)
return interface, client
--function handleclient( client, ip, port, server, pattern, listener, _, sslctx ) -- creates an client interface
end
-local function addclient( addr, serverport, listener, pattern, sslctx, typ )
+local function addclient( addr, serverport, listener, pattern, sslctx, typ, extra )
if sslctx and not has_luasec then
debug "need luasec, but not available"
return nil, "luasec not found"
@@ -749,8 +781,9 @@ local function addclient( addr, serverport, listener, pattern, sslctx, typ )
client:settimeout( 0 ) -- set nonblocking
local res, err = client:setpeername( addr, serverport ) -- connect
if res or ( err == "timeout" ) then
+ -- luacheck: ignore 211/port
local ip, port = client:getsockname( )
- local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx )
+ local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx, extra )
debug( "new connection id:", interface.id )
return interface, err
else
@@ -876,6 +909,7 @@ return {
event_base = base,
addevent = newevent,
addserver = addserver,
+ listen = listen,
addclient = addclient,
wrapclient = wrapclient,
setquitting = setquitting,
diff --git a/net/server_select.lua b/net/server_select.lua
index 1a40a6d3..a2515d59 100644
--- a/net/server_select.lua
+++ b/net/server_select.lua
@@ -68,6 +68,7 @@ local idfalse
local closeall
local addsocket
local addserver
+local listen
local addtimer
local getserver
local wrapserver
@@ -123,7 +124,7 @@ local _maxsslhandshake
_server = { } -- key = port, value = table; list of listening servers
_readlist = { } -- array with sockets to read from
-_sendlist = { } -- arrary with sockets to write to
+_sendlist = { } -- array with sockets to write to
_timerlist = { } -- array of timer functions
_socketlist = { } -- key = socket, value = wrapped socket (handlers)
_readtimes = { } -- key = handler, value = timestamp of last data reading
@@ -149,7 +150,7 @@ _checkinterval = 30 -- interval in secs to check idle clients
_sendtimeout = 60000 -- allowed send idle time in secs
_readtimeout = 14 * 60 -- allowed read idle time in secs
-local is_windows = package.config:sub(1,1) == "\\" -- check the directory separator, to detemine whether this is Windows
+local is_windows = package.config:sub(1,1) == "\\" -- check the directory separator, to determine whether this is Windows
_maxfd = (is_windows and math.huge) or luasocket._SETSIZE or 1024 -- max fd number, limit to 1024 by default to prevent glibc buffer overflow, but not on Windows
_maxselectlen = luasocket._SETSIZE or 1024 -- But this still applies on Windows
@@ -157,7 +158,7 @@ _maxsslhandshake = 30 -- max handshake round-trips
----------------------------------// PRIVATE //--
-wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- this function wraps a server -- FIXME Make sure FD < _maxfd
+wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, ssldirect ) -- this function wraps a server -- FIXME Make sure FD < _maxfd
if socket:getfd() >= _maxfd then
out_error("server.lua: Disallowed FD number: "..socket:getfd())
@@ -183,6 +184,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- t
handler.sslctx = function( )
return sslctx
end
+ handler.hosts = {} -- sni
handler.remove = function( )
connections = connections - 1
if handler then
@@ -244,13 +246,13 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- t
local client, err = accept( socket ) -- try to accept
if client then
local ip, clientport = client:getpeername( )
- local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx ) -- wrap new client socket
+ local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx, ssldirect ) -- wrap new client socket
if err then -- error while wrapping ssl socket
return false
end
connections = connections + 1
out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport))
- if dispatch and not sslctx then -- SSL connections will notify onconnect when handshake completes
+ if dispatch and not ssldirect then -- SSL connections will notify onconnect when handshake completes
return dispatch( handler );
end
return;
@@ -264,7 +266,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- t
return handler
end
-wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx ) -- this function wraps a client to a handler object
+wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx, ssldirect, extra ) -- this function wraps a client to a handler object
if socket:getfd() >= _maxfd then
out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent
@@ -287,6 +289,8 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
local ssl
+ local pending
+
local dispatch = listeners.onincoming
local status = listeners.onstatus
local disconnect = listeners.ondisconnect
@@ -314,6 +318,11 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
local handler = bufferqueue -- saves a table ^_^
+ handler.extra = extra
+ if extra then
+ handler.servername = extra.servername
+ end
+
handler.dispatch = function( )
return dispatch
end
@@ -336,6 +345,9 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
listeners.onattach(self, data)
end
end
+ handler._setpending = function( )
+ pending = true
+ end
handler.getstats = function( )
return readtraffic, sendtraffic
end
@@ -372,7 +384,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
_readlistlen = removesocket( _readlist, socket, _readlistlen )
_readtimes[ handler ] = nil
if bufferqueuelen ~= 0 then
- handler.sendbuffer() -- Try now to send any outstanding data
+ handler:sendbuffer() -- Try now to send any outstanding data
if bufferqueuelen ~= 0 then -- Still not empty, so we'll try again later
if handler then
handler.write = nil -- ... but no further writing allowed
@@ -424,9 +436,8 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
bufferlen = bufferlen + #data
if bufferlen > maxsendlen then
_closelist[ handler ] = "send buffer exceeded" -- cannot close the client at the moment, have to wait to the end of the cycle
- handler.write = idfalse -- don't write anymore
return false
- elseif socket and not _sendlist[ socket ] then
+ elseif not nosend and socket and not _sendlist[ socket ] then
_sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
end
bufferqueuelen = bufferqueuelen + 1
@@ -456,49 +467,55 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
maxreadlen = readlen or maxreadlen
return bufferlen, maxreadlen, maxsendlen
end
- --TODO: Deprecate
handler.lock_read = function (self, switch)
+ out_error( "server.lua, lock_read() is deprecated, use pause() and resume()" )
if switch == true then
- local tmp = _readlistlen
- _readlistlen = removesocket( _readlist, socket, _readlistlen )
- _readtimes[ handler ] = nil
- if _readlistlen ~= tmp then
- noread = true
- end
+ return self:pause()
elseif switch == false then
- if noread then
- noread = false
- _readlistlen = addsocket(_readlist, socket, _readlistlen)
- _readtimes[ handler ] = _currenttime
- end
+ return self:resume()
end
return noread
end
handler.pause = function (self)
- return self:lock_read(true);
+ local tmp = _readlistlen
+ _readlistlen = removesocket( _readlist, socket, _readlistlen )
+ _readtimes[ handler ] = nil
+ if _readlistlen ~= tmp then
+ noread = true
+ end
+ return noread;
end
handler.resume = function (self)
- return self:lock_read(false);
+ if noread then
+ noread = false
+ _readlistlen = addsocket(_readlist, socket, _readlistlen)
+ _readtimes[ handler ] = _currenttime
+ end
+ return noread;
end
handler.lock = function( self, switch )
- handler.lock_read (switch)
+ out_error( "server.lua, lock() is deprecated" )
+ handler.lock_read (self, switch)
if switch == true then
- handler.write = idfalse
- local tmp = _sendlistlen
- _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
- _writetimes[ handler ] = nil
- if _sendlistlen ~= tmp then
- nosend = true
- end
+ handler.pause_writes (self)
elseif switch == false then
- handler.write = write
- if nosend then
- nosend = false
- write( "" )
- end
+ handler.resume_writes (self)
end
return noread, nosend
end
+ handler.pause_writes = function (self)
+ local tmp = _sendlistlen
+ _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
+ _writetimes[ handler ] = nil
+ nosend = true
+ end
+ handler.resume_writes = function (self)
+ nosend = false
+ if bufferlen > 0 then
+ _sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
+ end
+ end
+
local _readbuffer = function( ) -- this function reads data
local buffer, err, part = receive( socket, pattern ) -- receive buffer with "pattern"
if not err or (err == "wantread" or err == "timeout") then -- received something
@@ -513,6 +530,12 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
_readtraffic = _readtraffic + count
_readtimes[ handler ] = _currenttime
--out_put( "server.lua: read data '", buffer:gsub("[^%w%p ]", "."), "', error: ", err )
+ if pending then -- connection established
+ pending = nil
+ if listeners.onconnect then
+ listeners.onconnect(handler)
+ end
+ end
return dispatch( handler, buffer, err )
else -- connections was closed or fatal error
out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) )
@@ -523,6 +546,12 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
local _sendbuffer = function( ) -- this function sends data
local succ, err, byte, buffer, count;
if socket then
+ if pending then
+ pending = nil
+ if listeners.onconnect then
+ listeners.onconnect(handler);
+ end
+ end
buffer = table_concat( bufferqueue, "", 1, bufferqueuelen )
succ, err, byte = send( socket, buffer, 1, bufferlen )
count = ( succ or byte or 0 ) * STAT_UNIT
@@ -599,7 +628,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
coroutine_yield( ) -- handshake not finished
end
end
- err = "ssl handshake error: " .. ( err or "handshake too long" );
+ err = ( err or "handshake too long" );
out_put( "server.lua: ", err );
_ = handler and handler:force_close(err)
return false, err -- handshake failed
@@ -619,11 +648,20 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
local oldsocket, err = socket
socket, err = ssl_wrap( socket, sslctx ) -- wrap socket
+
if not socket then
out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") )
return nil, err -- fatal error
end
+ if socket.sni then
+ if self.servername then
+ socket:sni(self.servername);
+ elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then
+ socket:sni(self.server().hosts, true);
+ end
+ end
+
socket:settimeout( 0 )
-- add the new socket to our system
@@ -659,7 +697,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
_socketlist[ socket ] = handler
_readlistlen = addsocket(_readlist, socket, _readlistlen)
- if sslctx and has_luasec then
+ if sslctx and ssldirect and has_luasec then
out_put "server.lua: auto-starting ssl negotiation..."
handler.autostart_ssl = true;
local ok, err = handler:starttls(sslctx);
@@ -714,7 +752,7 @@ local function link(sender, receiver, buffersize)
local sender_locked;
local _sendbuffer = receiver.sendbuffer;
function receiver.sendbuffer()
- _sendbuffer();
+ _sendbuffer(receiver);
if sender_locked and receiver.bufferlen() < buffersize then
sender:lock_read(false); -- Unlock now
sender_locked = nil;
@@ -734,9 +772,13 @@ end
----------------------------------// PUBLIC //--
-addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server
+listen = function ( addr, port, listeners, config )
addr = addr or "*"
+ config = config or {}
local err
+ local sslctx = config.tls_ctx;
+ local ssldirect = config.tls_direct;
+ local pattern = config.read_size;
if type( listeners ) ~= "table" then
err = "invalid listener table"
elseif type ( addr ) ~= "string" then
@@ -757,7 +799,7 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function
out_error( "server.lua, [", addr, "]:", port, ": ", err )
return nil, err
end
- local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx ) -- wrap new server socket
+ local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, ssldirect ) -- wrap new server socket
if not handler then
server:close( )
return nil, err
@@ -770,6 +812,14 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function
return handler
end
+addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server
+ return listen(addr, port, listeners, {
+ read_size = pattern;
+ tls_ctx = sslctx;
+ tls_direct = sslctx and true or false;
+ });
+end
+
getserver = function ( addr, port )
return _server[ addr..":"..port ];
end
@@ -912,7 +962,7 @@ loop = function(once) -- this is the main loop of the program
for _, socket in ipairs( read ) do -- receive data
local handler = _socketlist[ socket ]
if handler then
- handler.readbuffer( )
+ handler:readbuffer( )
else
closesocket( socket )
out_put "server.lua: found no handler and closed socket (readlist)" -- this can happen
@@ -921,7 +971,7 @@ loop = function(once) -- this is the main loop of the program
for _, socket in ipairs( write ) do -- send data waiting in writequeues
local handler = _socketlist[ socket ]
if handler then
- handler.sendbuffer( )
+ handler:sendbuffer( )
else
closesocket( socket )
out_put "server.lua: found no handler and closed socket (writelist)" -- this should not happen
@@ -977,27 +1027,19 @@ end
--// EXPERIMENTAL //--
-local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx )
- local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx )
+local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx, extra )
+ local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx, sslctx, extra)
if not handler then return nil, err end
_socketlist[ socket ] = handler
if not sslctx then
+ handler._setpending()
_readlistlen = addsocket(_readlist, socket, _readlistlen)
_sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
- if listeners.onconnect then
- -- When socket is writeable, call onconnect
- local _sendbuffer = handler.sendbuffer;
- handler.sendbuffer = function ()
- handler.sendbuffer = _sendbuffer;
- listeners.onconnect(handler);
- return _sendbuffer(); -- Send any queued outgoing data
- end
- end
end
return handler, socket
end
-local addclient = function( address, port, listeners, pattern, sslctx, typ )
+local addclient = function( address, port, listeners, pattern, sslctx, typ, extra )
local err
if type( listeners ) ~= "table" then
err = "invalid listener table"
@@ -1034,7 +1076,7 @@ local addclient = function( address, port, listeners, pattern, sslctx, typ )
client:settimeout( 0 )
local ok, err = client:setpeername( address, port )
if ok or err == "timeout" or err == "Operation already in progress" then
- return wrapclient( client, address, port, listeners, pattern, sslctx )
+ return wrapclient( client, address, port, listeners, pattern, sslctx, extra )
else
return nil, err
end
@@ -1114,6 +1156,7 @@ return {
stats = stats,
closeall = closeall,
addserver = addserver,
+ listen = listen,
getserver = getserver,
setlogger = setlogger,
getsettings = getsettings,
diff --git a/net/unbound.lua b/net/unbound.lua
new file mode 100644
index 00000000..49f220e1
--- /dev/null
+++ b/net/unbound.lua
@@ -0,0 +1,208 @@
+-- libunbound based net.adns replacement for Prosody IM
+-- Copyright (C) 2013-2015 Kim Alvefur
+--
+-- This file is MIT licensed.
+--
+-- luacheck: ignore prosody
+
+local setmetatable = setmetatable;
+local tostring = tostring;
+local t_concat = table.concat;
+local s_format = string.format;
+local s_lower = string.lower;
+local s_upper = string.upper;
+local noop = function() end;
+
+local log = require "util.logger".init("unbound");
+local net_server = require "net.server";
+local libunbound = require"lunbound";
+local promise = require"util.promise";
+
+local gettime = require"socket".gettime;
+local dns_utils = require"util.dns";
+local classes, types, errors = dns_utils.classes, dns_utils.types, dns_utils.errors;
+local parsers = dns_utils.parsers;
+
+local function add_defaults(conf)
+ if conf then
+ for option, default in pairs(libunbound.config) do
+ if conf[option] == nil then
+ conf[option] = default;
+ end
+ end
+ end
+ return conf;
+end
+
+local unbound_config;
+if prosody then
+ local config = require"core.configmanager";
+ unbound_config = add_defaults(config.get("*", "unbound"));
+ prosody.events.add_handler("config-reloaded", function()
+ unbound_config = add_defaults(config.get("*", "unbound"));
+ end);
+end
+-- Note: libunbound will default to using root hints if resolvconf is unset
+
+local function connect_server(unbound, server)
+ return server.watchfd(unbound, function ()
+ unbound:process()
+ end);
+end
+
+local unbound = libunbound.new(unbound_config);
+
+local server_conn = connect_server(unbound, net_server);
+
+local answer_mt = {
+ __tostring = function(self)
+ if self._string then return self._string end
+ local h = s_format("Status: %s", errors[self.status]);
+ if self.secure then
+ h = h .. ", Secure";
+ elseif self.bogus then
+ h = h .. s_format(", Bogus: %s", self.bogus);
+ end
+ local t = { h };
+ for i = 1, #self do
+ t[i+1]=self.qname.."\t"..classes[self.qclass].."\t"..types[self.qtype].."\t"..tostring(self[i]);
+ end
+ local _string = t_concat(t, "\n");
+ self._string = _string;
+ return _string;
+ end;
+};
+
+local waiting_queries = {};
+
+local function prep_answer(a)
+ if not a then return end
+ local status = errors[a.rcode];
+ local qclass = classes[a.qclass];
+ local qtype = types[a.qtype];
+ a.status, a.class, a.type = status, qclass, qtype;
+
+ local t = s_lower(qtype);
+ local rr_mt = { __index = a, __tostring = function(self) return tostring(self[t]) end };
+ local parser = parsers[qtype];
+ for i = 1, #a do
+ if a.bogus then
+ -- Discard bogus data
+ a[i] = nil;
+ else
+ a[i] = setmetatable({[t] = parser(a[i])}, rr_mt);
+ end
+ end
+ return setmetatable(a, answer_mt);
+end
+
+local function lookup(callback, qname, qtype, qclass)
+ qtype = qtype and s_upper(qtype) or "A";
+ qclass = qclass and s_upper(qclass) or "IN";
+ local ntype, nclass = types[qtype], classes[qclass];
+ local startedat = gettime();
+ local ret;
+ local function callback_wrapper(a, err)
+ local gotdataat = gettime();
+ waiting_queries[ret] = nil;
+ if a then
+ prep_answer(a);
+ log("debug", "Results for %s %s %s: %s (%s, %f sec)", qname, qclass, qtype, a.rcode == 0 and (#a .. " items") or a.status,
+ a.secure and "Secure" or a.bogus or "Insecure", gotdataat - startedat); -- Insecure as in unsigned
+ else
+ log("error", "Results for %s %s %s: %s", qname, qclass, qtype, tostring(err));
+ end
+ local ok, cerr = pcall(callback, a, err);
+ if not ok then log("error", "Error in callback: %s", cerr); end
+ end
+ log("debug", "Resolve %s %s %s", qname, qclass, qtype);
+ local err;
+ ret, err = unbound:resolve_async(callback_wrapper, qname, ntype, nclass);
+ if ret then
+ waiting_queries[ret] = callback;
+ else
+ log("warn", err);
+ end
+ return ret, err;
+end
+
+local function lookup_sync(qname, qtype, qclass)
+ qtype = qtype and s_upper(qtype) or "A";
+ qclass = qclass and s_upper(qclass) or "IN";
+ local ntype, nclass = types[qtype], classes[qclass];
+ local a, err = unbound:resolve(qname, ntype, nclass);
+ if not a then return a, err; end
+ return prep_answer(a);
+end
+
+local function cancel(id)
+ local cb = waiting_queries[id];
+ unbound:cancel(id);
+ if cb then
+ cb(nil, "canceled");
+ waiting_queries[id] = nil;
+ end
+ return true;
+end
+
+-- Reinitiate libunbound context, drops cache
+local function purge()
+ for id in pairs(waiting_queries) do cancel(id); end
+ if server_conn then server_conn:close(); end
+ unbound = libunbound.new(unbound_config);
+ server_conn = connect_server(unbound, net_server);
+ return true;
+end
+
+local function not_implemented()
+ error "not implemented";
+end
+-- Public API
+local _M = {
+ lookup = lookup;
+ cancel = cancel;
+ new_async_socket = not_implemented;
+ dns = {
+ lookup = lookup_sync;
+ cancel = cancel;
+ cache = noop;
+ socket_wrapper_set = noop;
+ settimeout = noop;
+ query = noop;
+ purge = purge;
+ random = noop;
+ peek = noop;
+
+ types = types;
+ classes = classes;
+ };
+};
+
+local function lookup_promise(_, qname, qtype, qclass)
+ return promise.new(function (resolve, reject)
+ local function callback(answer, err)
+ if err then
+ return reject(err);
+ else
+ return resolve(answer);
+ end
+ end
+ local ret, err = lookup(callback, qname, qtype, qclass)
+ if not ret then reject(err); end
+ end);
+end
+
+local wrapper = {
+ lookup = function (_, callback, qname, qtype, qclass)
+ return lookup(callback, qname, qtype, qclass)
+ end;
+ lookup_promise = lookup_promise;
+ _resolver = {
+ settimeout = function () end;
+ closeall = function () end;
+ };
+}
+
+function _M.resolver() return wrapper; end
+
+return _M;
diff --git a/net/websocket.lua b/net/websocket.lua
index 469c6a58..193cd556 100644
--- a/net/websocket.lua
+++ b/net/websocket.lua
@@ -23,6 +23,7 @@ local websockets = {};
local websocket_listeners = {};
function websocket_listeners.ondisconnect(conn, err)
local s = websockets[conn];
+ if not s then return; end
websockets[conn] = nil;
if s.close_timer then
timer.stop(s.close_timer);
@@ -113,7 +114,7 @@ function websocket_listeners.onincoming(conn, buffer, err) -- luacheck: ignore 2
frame.MASK = true; -- RFC 6455 6.1.5: If the data is being sent by the client, the frame(s) MUST be masked
conn:write(frames.build(frame));
elseif frame.opcode == 0xA then -- Pong frame
- log("debug", "Received unexpected pong frame: " .. tostring(frame.data));
+ log("debug", "Received unexpected pong frame: %s", frame.data);
else
return fail(s, 1002, "Reserved opcode");
end
@@ -131,7 +132,7 @@ end
function websocket_methods:close(code, reason)
if self.readyState < 2 then
code = code or 1000;
- log("debug", "closing WebSocket with code %i: %s" , code , tostring(reason));
+ log("debug", "closing WebSocket with code %i: %s" , code , reason);
self.readyState = 2;
local conn = self.conn;
conn:write(frames.build_close(code, reason, true));
@@ -245,7 +246,7 @@ local function connect(url, ex, listeners)
or (protocol and not protocol[r.headers["sec-websocket-protocol"]])
then
s.readyState = 3;
- log("warn", "WebSocket connection to %s failed: %s", url, tostring(b));
+ log("warn", "WebSocket connection to %s failed: %s", url, b);
if s.onerror then s:onerror("connecting-failed"); end
return;
end
diff --git a/net/websocket/frames.lua b/net/websocket/frames.lua
index ba25d261..5e17df07 100644
--- a/net/websocket/frames.lua
+++ b/net/websocket/frames.lua
@@ -9,20 +9,20 @@
local softreq = require "util.dependencies".softreq;
local random_bytes = require "util.random".bytes;
-local bit = assert(softreq"bit" or softreq"bit32",
- "No bit module found. See https://prosody.im/doc/depends#bitop");
+local bit = require "util.bitcompat";
local band = bit.band;
local bor = bit.bor;
local bxor = bit.bxor;
local lshift = bit.lshift;
local rshift = bit.rshift;
+local unpack = table.unpack or unpack; -- luacheck: ignore 113
local t_concat = table.concat;
local s_byte = string.byte;
local s_char= string.char;
local s_sub = string.sub;
-local s_pack = string.pack; -- luacheck: ignore 143
-local s_unpack = string.unpack; -- luacheck: ignore 143
+local s_pack = string.pack;
+local s_unpack = string.unpack;
if not s_pack and softreq"struct" then
s_pack = softreq"struct".pack;