diff options
Diffstat (limited to 'net')
-rw-r--r-- | net/adns.lua | 44 | ||||
-rw-r--r-- | net/connect.lua | 8 | ||||
-rw-r--r-- | net/connlisteners.lua | 18 | ||||
-rw-r--r-- | net/cqueues.lua | 56 | ||||
-rw-r--r-- | net/dns.lua | 72 | ||||
-rw-r--r-- | net/http.lua | 53 | ||||
-rw-r--r-- | net/http/codes.lua | 2 | ||||
-rw-r--r-- | net/http/errors.lua | 119 | ||||
-rw-r--r-- | net/http/files.lua | 149 | ||||
-rw-r--r-- | net/http/parser.lua | 147 | ||||
-rw-r--r-- | net/http/server.lua | 179 | ||||
-rw-r--r-- | net/resolvers/basic.lua | 65 | ||||
-rw-r--r-- | net/resolvers/manual.lua | 1 | ||||
-rw-r--r-- | net/resolvers/service.lua | 33 | ||||
-rw-r--r-- | net/server.lua | 6 | ||||
-rw-r--r-- | net/server_epoll.lua | 538 | ||||
-rw-r--r-- | net/server_event.lua | 60 | ||||
-rw-r--r-- | net/server_select.lua | 149 | ||||
-rw-r--r-- | net/unbound.lua | 220 | ||||
-rw-r--r-- | net/websocket.lua | 7 | ||||
-rw-r--r-- | net/websocket/frames.lua | 7 |
21 files changed, 1466 insertions, 467 deletions
diff --git a/net/adns.lua b/net/adns.lua index 0bdf6ee3..ae168b9c 100644 --- a/net/adns.lua +++ b/net/adns.lua @@ -8,13 +8,17 @@ local server = require "net.server"; local new_resolver = require "net.dns".resolver; +local promise = require "util.promise"; local log = require "util.logger".init("adns"); -local coroutine, tostring, pcall = coroutine, tostring, pcall; +log("debug", "Using legacy DNS API (missing lua-unbound?)"); -- TODO write docs about luaunbound +-- TODO Raise log level once packages are available + +local coroutine, pcall = coroutine, pcall; local setmetatable = setmetatable; -local function dummy_send(sock, data, i, j) return (j-i)+1; end +local function dummy_send(sock, data, i, j) return (j-i)+1; end -- luacheck: ignore 212 local _ENV = nil; -- luacheck: std none @@ -29,8 +33,7 @@ local function new_async_socket(sock, resolver) local peername = "<unknown>"; local listener = {}; local handler = {}; - local err; - function listener.onincoming(conn, data) + function listener.onincoming(conn, data) -- luacheck: ignore 212/conn if data then resolver:feed(handler, data); end @@ -40,15 +43,18 @@ local function new_async_socket(sock, resolver) log("warn", "DNS socket for %s disconnected: %s", peername, err); local servers = resolver.server; if resolver.socketset[conn] == resolver.best_server and resolver.best_server == #servers then - log("error", "Exhausted all %d configured DNS servers, next lookup will try %s again", #servers, servers[1]); + log("warn", "Exhausted all %d configured DNS servers, next lookup will try %s again", #servers, servers[1]); end resolver:servfail(conn); -- Let the magic commence end end - handler, err = server.wrapclient(sock, "dns", 53, listener); - if not handler then - return nil, err; + do + local err; + handler, err = server.wrapclient(sock, "dns", 53, listener); + if not handler then + return nil, err; + end end if handler.set then -- server_epoll: only watch for incoming data @@ -76,11 +82,11 @@ function async_resolver_methods:lookup(handler, qname, qtype, qclass) handler(peek); return; end - log("debug", "Records for %s not in cache, sending query (%s)...", qname, tostring(coroutine.running())); + log("debug", "Records for %s not in cache, sending query (%s)...", qname, coroutine.running()); local ok, err = resolver:query(qname, qtype, qclass); if ok then coroutine.yield(setmetatable({ resolver, qclass or "IN", qtype or "A", qname, coroutine.running()}, query_mt)); -- Wait for reply - log("debug", "Reply for %s (%s)", qname, tostring(coroutine.running())); + log("debug", "Reply for %s (%s)", qname, coroutine.running()); end if ok then ok, err = pcall(handler, resolver:peek(qname, qtype, qclass)); @@ -89,13 +95,25 @@ function async_resolver_methods:lookup(handler, qname, qtype, qclass) ok, err = pcall(handler, nil, err); end if not ok then - log("error", "Error in DNS response handler: %s", tostring(err)); + log("error", "Error in DNS response handler: %s", err); end end)(resolver:peek(qname, qtype, qclass)); end -function query_methods:cancel(call_handler, reason) - log("warn", "Cancelling DNS lookup for %s", tostring(self[4])); +function async_resolver_methods:lookup_promise(qname, qtype, qclass) + return promise.new(function (resolve, reject) + local function handler(answer) + if not answer then + return reject(); + end + resolve(answer); + end + self:lookup(handler, qname, qtype, qclass); + end); +end + +function query_methods:cancel(call_handler, reason) -- luacheck: ignore 212/reason + log("warn", "Cancelling DNS lookup for %s", self[4]); self[1].cancel(self[2], self[3], self[4], self[5], call_handler); end diff --git a/net/connect.lua b/net/connect.lua index b812ffcd..d52d3901 100644 --- a/net/connect.lua +++ b/net/connect.lua @@ -2,6 +2,12 @@ local server = require "net.server"; local log = require "util.logger".init("net.connect"); local new_id = require "util.id".short; +-- TODO #1246 Happy Eyeballs +-- FIXME RFC 6724 +-- FIXME Error propagation from resolvers doesn't work +-- FIXME #1428 Reuse DNS resolver object between service and basic resolver +-- FIXME #1429 Close DNS resolver object when done + local pending_connection_methods = {}; local pending_connection_mt = { __name = "pending_connection"; @@ -38,7 +44,7 @@ local function attempt_connection(p) p:log("debug", "Next target to try is %s:%d", ip, port); local conn, err = server.addclient(ip, port, pending_connection_listeners, p.options.pattern or "*a", p.options.sslctx, conn_type, extra); if not conn then - log("debug", "Connection attempt failed immediately: %s", tostring(err)); + log("debug", "Connection attempt failed immediately: %s", err); p.last_error = err or "unknown reason"; return attempt_connection(p); end diff --git a/net/connlisteners.lua b/net/connlisteners.lua deleted file mode 100644 index 9b8f88c3..00000000 --- a/net/connlisteners.lua +++ /dev/null @@ -1,18 +0,0 @@ --- COMPAT w/pre-0.9 -local log = require "util.logger".init("net.connlisteners"); -local traceback = debug.traceback; - -local _ENV = nil; --- luacheck: std none - -local function fail() - log("error", "Attempt to use legacy connlisteners API. For more info see https://prosody.im/doc/developers/network"); - log("error", "Legacy connlisteners API usage, %s", traceback("", 2)); -end - -return { - register = fail; - get = fail; - start = fail; - -- epic fail -}; diff --git a/net/cqueues.lua b/net/cqueues.lua index 8c4c756f..65d2a019 100644 --- a/net/cqueues.lua +++ b/net/cqueues.lua @@ -9,6 +9,7 @@ local server = require "net.server"; local cqueues = require "cqueues"; +local timer = require "util.timer"; assert(cqueues.VERSION >= 20150113, "cqueues newer than 20150113 required") -- Create a single top level cqueue @@ -16,55 +17,24 @@ local cq; if server.cq then -- server provides cqueues object cq = server.cq; -elseif server.get_backend() == "select" and server._addtimer then -- server_select +elseif server.watchfd then cq = cqueues.new(); - local function step() + local timeout = timer.add_task(cq:timeout() or 0, function () + -- FIXME It should be enough to reschedule this timeout instead of replacing it, but this does not work. See https://issues.prosody.im/1572 assert(cq:loop(0)); - end - - -- Use wrapclient (as wrapconnection isn't exported) to get server_select to watch cq fd - local handler = server.wrapclient({ - getfd = function() return cq:pollfd(); end; - settimeout = function() end; -- Method just needs to exist - close = function() end; -- Need close method for 'closeall' - }, nil, nil, {}); - - -- Only need to listen for readable; cqueues handles everything under the hood - -- readbuffer is called when `select` notes an fd as readable - handler.readbuffer = step; - - -- Use server_select low lever timer facility, - -- this callback gets called *every* time there is a timeout in the main loop - server._addtimer(function(current_time) - -- This may end up in extra step()'s, but cqueues handles it for us. - step(); return cq:timeout(); end); -elseif server.event and server.base then -- server_event - cq = cqueues.new(); - -- Only need to listen for readable; cqueues handles everything under the hood - local EV_READ = server.event.EV_READ; - -- Convert a cqueues timeout to an acceptable timeout for luaevent - local function luaevent_safe_timeout(cq) + server.watchfd(cq:pollfd(), function () + assert(cq:loop(0)); local t = cq:timeout(); - -- if you give luaevent 0 or nil, it re-uses the previous timeout. - if t == 0 then - t = 0.000001; -- 1 microsecond is the smallest that works (goes into a `struct timeval`) - elseif t == nil then -- pick something big if we don't have one - t = 0x7FFFFFFF; -- largest 32bit int + if t then + timer.stop(timeout); + timeout = timer.add_task(cq:timeout(), function () + assert(cq:loop(0)); + return cq:timeout(); + end); end - return t - end - local event_handle; - event_handle = server.base:addevent(cq:pollfd(), EV_READ, function(e) - -- Need to reference event_handle or this callback will get collected - -- This creates a circular reference that can only be broken if event_handle is manually :close()'d - local _ = event_handle; - -- Run as many cqueues things as possible (with a timeout of 0) - -- If an error is thrown, it will break the libevent loop; but prosody resumes after logging a top level error - assert(cq:loop(0)); - return EV_READ, luaevent_safe_timeout(cq); - end, luaevent_safe_timeout(cq)); + end); else error "NYI" end diff --git a/net/dns.lua b/net/dns.lua index 3902f95c..17119152 100644 --- a/net/dns.lua +++ b/net/dns.lua @@ -13,10 +13,12 @@ local socket = require "socket"; -local timer = require "util.timer"; +local have_timer, timer = pcall(require, "util.timer"); local new_ip = require "util.ip".new_ip; local have_util_net, util_net = pcall(require, "util.net"); +local log = require "util.logger".init("dns"); + local _, windows = pcall(require, "util.windows"); local is_windows = (_ and windows) or os.getenv("WINDIR"); @@ -69,7 +71,9 @@ local ztact = { -- public domain 20080404 lua@ztact.com }; local get, set = ztact.get, ztact.set; -local default_timeout = 15; +local default_timeout = 5; +local default_jitter = 1; +local default_retry_jitter = 2; -------------------------------------------------- module dns local _ENV = nil; @@ -664,8 +668,10 @@ end -- socket layer -------------------------------------------------- socket layer -resolver.delays = { 1, 3 }; +resolver.delays = { 1, 2, 3, 5 }; +resolver.jitter = have_timer and default_jitter or nil; +resolver.retry_jitter = have_timer and default_retry_jitter or nil; function resolver:addnameserver(address) -- - - - - - - - - - addnameserver self.server = self.server or {}; @@ -853,7 +859,10 @@ function resolver:query(qname, qtype, qclass) -- - - - - - - - - - -- query packet = header..question, server = self.best_server, delay = 1, - retry = socket.gettime() + self.delays[1] + retry = socket.gettime() + self.delays[1]; + qclass = qclass; + qtype = qtype; + qname = qname; }; -- remember the query @@ -864,30 +873,32 @@ function resolver:query(qname, qtype, qclass) -- - - - - - - - - - -- query if not conn then return nil, err; end - conn:send (o.packet) + if self.jitter then + timer.add_task(math.random()*self.jitter, function () + conn:send(o.packet); + end); + else + conn:send(o.packet); + end -- remember which coroutine wants the answer if co then set(self.wanted, qclass, qtype, qname, co, true); end - if timer and self.timeout then + if have_timer and self.timeout then local num_servers = #self.server; local i = 1; timer.add_task(self.timeout, function () if get(self.wanted, qclass, qtype, qname, co) then - if i < num_servers then + log("debug", "DNS request timeout %d/%d", i, num_servers) i = i + 1; - self:servfail(conn); - o.server = self.best_server; - conn, err = self:getsocket(o.server); - if conn then - conn:send(o.packet); - return self.timeout; - end - end - -- Tried everything, failed - self:cancel(qclass, qtype, qname); + self:servfail(self.socket[o.server]); +-- end + end + -- Still outstanding? (i.e. retried) + if get(self.wanted, qclass, qtype, qname, co) then + return self.timeout; -- Then wait end end) end @@ -904,6 +915,7 @@ function resolver:servfail(sock, err) -- Find all requests to the down server, and retry on the next server self.time = socket.gettime(); + log("debug", "servfail %d (of %d)", num, #self.server); for id,queries in pairs(self.active) do for question,o in pairs(queries) do if o.server == num then -- This request was to the broken server @@ -913,12 +925,27 @@ function resolver:servfail(sock, err) end o.retries = (o.retries or 0) + 1; - if o.retries >= #self.server then - --print('timeout'); - queries[question] = nil; - else + local retried; + if o.retries < #self.server then sock, err = self:getsocket(o.server); - if sock then sock:send(o.packet); end + if sock then + retried = true; + if self.retry_jitter then + local delay = self.delays[((o.retries-1)%#self.delays)+1] + (math.random()*self.retry_jitter); + log("debug", "retry %d in %0.2fs", o.retries, delay); + timer.add_task(delay, function () + sock:send(o.packet); + end); + else + log("debug", "retry %d (immediate)", o.retries); + sock:send(o.packet); + end + end + end + if not retried then + log("debug", 'tried all servers, giving up'); + self:cancel(o.qclass, o.qtype, o.qname); + queries[question] = nil; end end end @@ -1164,6 +1191,7 @@ end local _resolver = dns.resolver(); dns._resolver = _resolver; +_resolver.jitter, _resolver.retry_jitter = false, false; function dns.lookup(...) -- - - - - - - - - - - - - - - - - - - - - lookup return _resolver:lookup(...); diff --git a/net/http.lua b/net/http.lua index 7335d210..a13c3942 100644 --- a/net/http.lua +++ b/net/http.lua @@ -12,6 +12,8 @@ local httpstream_new = require "net.http.parser".new; local util_http = require "util.http"; local events = require "util.events"; local verify_identity = require"util.x509".verify_identity; +local promise = require "util.promise"; +local http_errors = require "net.http.errors"; local basic_resolver = require "net.resolvers.basic"; local connect = require "net.connect".connect; @@ -22,6 +24,7 @@ local t_insert, t_concat = table.insert, table.concat; local pairs = pairs; local tonumber, tostring, traceback = tonumber, tostring, debug.traceback; +local os_time = os.time; local xpcall = require "util.xpcall".xpcall; local error = error @@ -40,7 +43,7 @@ local listener = { default_port = 80, default_mode = "*a" }; local function handleerr(err) log("error", "Traceback[http]: %s", traceback(tostring(err), 2)); return err; end local function log_if_failed(req, ret, ...) if not ret then - log("error", "Request '%s': error in callback: %s", req.id, tostring((...))); + log("error", "Request '%s': error in callback: %s", req.id, (...)); if not req.suppress_errors then error(...); end @@ -81,7 +84,24 @@ local function request_reader(request, data, err) return; end + local finalize_sink; local function success_cb(r) + if r.partial then + -- Request should be streamed + log("debug", "Request '%s': partial response (%s%s)", + request.id, + r.chunked and "chunked, " or "", + r.body_length and ("%d bytes"):format(r.body_length) or "unknown length" + ); + if request.streaming_handler then + log("debug", "Request '%s': Streaming via handler"); + r.body_sink, finalize_sink = request.streaming_handler(r); + end + return; + elseif finalize_sink then + log("debug", "Request '%s': Finalizing response stream"); + finalize_sink(r); + end if request.callback then request.callback(r.body, r.code, r, request); request.callback = nil; @@ -144,13 +164,11 @@ function listener.onconnect(conn) t_insert(request_line, 4, "?"..req.query); end - conn:write(t_concat(request_line)); - local t = { [2] = ": ", [4] = "\r\n" }; for k, v in pairs(req.headers) do - t[1], t[3] = k, v; - conn:write(t_concat(t)); + t_insert(request_line, k .. ": " .. v .. "\r\n"); end - conn:write("\r\n"); + t_insert(request_line, "\r\n") + conn:write(t_concat(request_line)); if req.body then conn:write(req.body); @@ -161,7 +179,7 @@ function listener.onincoming(conn, data) local request = requests[conn]; if not request then - log("warn", "Received response from connection %s with no request attached!", tostring(conn)); + log("warn", "Received response from connection %s with no request attached!", conn); return; end @@ -202,6 +220,7 @@ local function request(self, u, ex, callback) req.url = u; req.http = self; + req.time = os_time(); if not req.path then req.path = "/"; @@ -254,6 +273,7 @@ local function request(self, u, ex, callback) end req.insecure = ex.insecure; req.suppress_errors = ex.suppress_errors; + req.streaming_handler = ex.streaming_handler; end log("debug", "Making %s %s request '%s' to %s", req.scheme:upper(), method or "GET", req.id, (ex and ex.suppress_url and host_header) or u); @@ -282,7 +302,22 @@ end local function new(options) local http = { options = options; - request = request; + request = function (self, u, ex, callback) + if callback ~= nil then + return request(self, u, ex, callback); + else + return promise.new(function (resolve, reject) + request(self, u, ex, function (body, code, a, b) + if code == 0 then + reject(http_errors.new(body, { request = a })); + else + a.request = b; + resolve(a); + end + end); + end); + end + end; new = options and function (new_options) local final_options = {}; for k, v in pairs(options) do final_options[k] = v; end @@ -297,7 +332,7 @@ local function new(options) end local default_http = new({ - sslctx = { mode = "client", protocol = "sslv23", options = { "no_sslv2", "no_sslv3" } }; + sslctx = { mode = "client", protocol = "sslv23", options = { "no_sslv2", "no_sslv3" }, alpn = "http/1.1" }; suppress_errors = true; }); diff --git a/net/http/codes.lua b/net/http/codes.lua index 8098b5c3..4327f151 100644 --- a/net/http/codes.lua +++ b/net/http/codes.lua @@ -82,5 +82,5 @@ local response_codes = { -- [512-599] = "Unassigned"; }; -for k,v in pairs(response_codes) do response_codes[k] = k.." "..v; end +for k,v in pairs(response_codes) do response_codes[k] = ("%03d %s"):format(k, v); end return setmetatable(response_codes, { __index = function(_, k) return k.." Unassigned"; end }) diff --git a/net/http/errors.lua b/net/http/errors.lua new file mode 100644 index 00000000..1691e426 --- /dev/null +++ b/net/http/errors.lua @@ -0,0 +1,119 @@ +-- This module returns a table that is suitable for use as a util.error registry, +-- and a function to return a util.error object given callback 'code' and 'body' +-- parameters. + +local codes = require "net.http.codes"; +local util_error = require "util.error"; + +local error_templates = { + -- This code is used by us to report a client-side or connection error. + -- Instead of using the code, use the supplied body text to get one of + -- the more detailed errors below. + [0] = { + code = 0, type = "cancel", condition = "internal-server-error"; + text = "Connection or internal error"; + }; + + -- These are net.http built-in errors, they are returned in + -- the body parameter when code == 0 + ["cancelled"] = { + code = 0, type = "cancel", condition = "remote-server-timeout"; + text = "Request cancelled"; + }; + ["connection-closed"] = { + code = 0, type = "wait", condition = "remote-server-timeout"; + text = "Connection closed"; + }; + ["certificate-chain-invalid"] = { + code = 0, type = "cancel", condition = "remote-server-timeout"; + text = "Server certificate not trusted"; + }; + ["certificate-verify-failed"] = { + code = 0, type = "cancel", condition = "remote-server-timeout"; + text = "Server certificate invalid"; + }; + ["connection failed"] = { + code = 0, type = "cancel", condition = "remote-server-not-found"; + text = "Connection failed"; + }; + ["invalid-url"] = { + code = 0, type = "modify", condition = "bad-request"; + text = "Invalid URL"; + }; + ["unable to resolve service"] = { + code = 0, type = "cancel", condition = "remote-server-not-found"; + text = "DNS resolution failed"; + }; + + -- This doesn't attempt to map every single HTTP code (not all have sane mappings), + -- but all the common ones should be covered. XEP-0086 was used as reference for + -- most of these. + [400] = { type = "modify", condition = "bad-request" }; + [401] = { type = "auth", condition = "not-authorized" }; + [402] = { type = "auth", condition = "payment-required" }; + [403] = { type = "auth", condition = "forbidden" }; + [404] = { type = "cancel", condition = "item-not-found" }; + [405] = { type = "cancel", condition = "not-allowed" }; + [406] = { type = "modify", condition = "not-acceptable" }; + [407] = { type = "auth", condition = "registration-required" }; + [408] = { type = "wait", condition = "remote-server-timeout" }; + [409] = { type = "cancel", condition = "conflict" }; + [410] = { type = "cancel", condition = "gone" }; + [411] = { type = "modify", condition = "bad-request" }; + [412] = { type = "cancel", condition = "conflict" }; + [413] = { type = "modify", condition = "resource-constraint" }; + [414] = { type = "modify", condition = "resource-constraint" }; + [415] = { type = "cancel", condition = "feature-not-implemented" }; + [416] = { type = "modify", condition = "bad-request" }; + + [422] = { type = "modify", condition = "bad-request" }; + [423] = { type = "wait", condition = "resource-constraint" }; + + [429] = { type = "wait", condition = "resource-constraint" }; + [431] = { type = "modify", condition = "resource-constraint" }; + [451] = { type = "auth", condition = "forbidden" }; + + [500] = { type = "wait", condition = "internal-server-error" }; + [501] = { type = "cancel", condition = "feature-not-implemented" }; + [502] = { type = "wait", condition = "remote-server-timeout" }; + [503] = { type = "cancel", condition = "service-unavailable" }; + [504] = { type = "wait", condition = "remote-server-timeout" }; + [507] = { type = "wait", condition = "resource-constraint" }; + [511] = { type = "auth", condition = "not-authorized" }; +}; + +for k, v in pairs(codes) do + if error_templates[k] then + error_templates[k].code = k; + error_templates[k].text = v; + else + error_templates[k] = { type = "cancel", condition = "undefined-condition", text = v, code = k }; + end +end + +setmetatable(error_templates, { + __index = function(_, k) + if type(k) ~= "number" then + return nil; + end + return { + type = "cancel"; + condition = "undefined-condition"; + text = codes[k] or (k.." Unassigned"); + code = k; + }; + end +}); + +local function new(code, body, context) + if code == 0 then + return util_error.new(body, context, error_templates); + else + return util_error.new(code, context, error_templates); + end +end + +return { + registry = error_templates; + new = new; +}; diff --git a/net/http/files.lua b/net/http/files.lua new file mode 100644 index 00000000..583f7514 --- /dev/null +++ b/net/http/files.lua @@ -0,0 +1,149 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local server = require"net.http.server"; +local lfs = require "lfs"; +local new_cache = require "util.cache".new; +local log = require "util.logger".init("net.http.files"); + +local os_date = os.date; +local open = io.open; +local stat = lfs.attributes; +local build_path = require"socket.url".build_path; +local path_sep = package.config:sub(1,1); + + +local forbidden_chars_pattern = "[/%z]"; +if package.config:sub(1,1) == "\\" then + forbidden_chars_pattern = "[/%z\001-\031\127\"*:<>?|]" +end + +local urldecode = require "util.http".urldecode; +local function sanitize_path(path) --> util.paths or util.http? + if not path then return end + local out = {}; + + local c = 0; + for component in path:gmatch("([^/]+)") do + component = urldecode(component); + if component:find(forbidden_chars_pattern) then + return nil; + elseif component == ".." then + if c <= 0 then + return nil; + end + out[c] = nil; + c = c - 1; + elseif component ~= "." then + c = c + 1; + out[c] = component; + end + end + if path:sub(-1,-1) == "/" then + out[c+1] = ""; + end + return "/"..table.concat(out, "/"); +end + +local function serve(opts) + if type(opts) ~= "table" then -- assume path string + opts = { path = opts }; + end + local mime_map = opts.mime_map or { html = "text/html" }; + local cache = new_cache(opts.cache_size or 256); + local cache_max_file_size = tonumber(opts.cache_max_file_size) or 1024 + -- luacheck: ignore 431 + local base_path = opts.path; + local dir_indices = opts.index_files or { "index.html", "index.htm" }; + local directory_index = opts.directory_index; + local function serve_file(event, path) + local request, response = event.request, event.response; + local sanitized_path = sanitize_path(path); + if path and not sanitized_path then + return 400; + end + path = sanitized_path; + local orig_path = sanitize_path(request.path); + local full_path = base_path .. (path or ""):gsub("/", path_sep); + local attr = stat(full_path:match("^.*[^\\/]")); -- Strip trailing path separator because Windows + if not attr then + return 404; + end + + local request_headers, response_headers = request.headers, response.headers; + + local last_modified = os_date('!%a, %d %b %Y %H:%M:%S GMT', attr.modification); + response_headers.last_modified = last_modified; + + local etag = ('"%x-%x-%x"'):format(attr.change or 0, attr.size or 0, attr.modification or 0); + response_headers.etag = etag; + + local if_none_match = request_headers.if_none_match + local if_modified_since = request_headers.if_modified_since; + if etag == if_none_match + or (not if_none_match and last_modified == if_modified_since) then + return 304; + end + + local data = cache:get(orig_path); + if data and data.etag == etag then + response_headers.content_type = data.content_type; + data = data.data; + cache:set(orig_path, data); + elseif attr.mode == "directory" and path then + if full_path:sub(-1) ~= "/" then + local dir_path = { is_absolute = true, is_directory = true }; + for dir in orig_path:gmatch("[^/]+") do dir_path[#dir_path+1]=dir; end + response_headers.location = build_path(dir_path); + return 301; + end + for i=1,#dir_indices do + if stat(full_path..dir_indices[i], "mode") == "file" then + return serve_file(event, path..dir_indices[i]); + end + end + + if directory_index then + data = server._events.fire_event("directory-index", { path = request.path, full_path = full_path }); + end + if not data then + return 403; + end + cache:set(orig_path, { data = data, content_type = mime_map.html; etag = etag; }); + response_headers.content_type = mime_map.html; + + else + local f, err = open(full_path, "rb"); + if not f then + log("debug", "Could not open %s. Error was %s", full_path, err); + return 403; + end + local ext = full_path:match("%.([^./]+)$"); + local content_type = ext and mime_map[ext]; + response_headers.content_type = content_type; + if attr.size > cache_max_file_size then + response_headers.content_length = ("%d"):format(attr.size); + log("debug", "%d > cache_max_file_size", attr.size); + return response:send_file(f); + else + data = f:read("*a"); + f:close(); + end + cache:set(orig_path, { data = data; content_type = content_type; etag = etag }); + end + + return response:send(data); + end + + return serve_file; +end + +return { + serve = serve; +} + diff --git a/net/http/parser.lua b/net/http/parser.lua index 4e4ae9fb..7fbade4c 100644 --- a/net/http/parser.lua +++ b/net/http/parser.lua @@ -1,8 +1,8 @@ local tonumber = tonumber; local assert = assert; -local t_insert, t_concat = table.insert, table.concat; local url_parse = require "socket.url".parse; local urldecode = require "util.http".urldecode; +local dbuffer = require "util.dbuffer"; local function preprocess_path(path) path = urldecode((path:gsub("//+", "/"))); @@ -28,10 +28,13 @@ local httpstream = {}; function httpstream.new(success_cb, error_cb, parser_type, options_cb) local client = true; if not parser_type or parser_type == "server" then client = false; else assert(parser_type == "client", "Invalid parser type"); end - local buf, buflen, buftable = {}, 0, true; local bodylimit = tonumber(options_cb and options_cb().body_size_limit) or 10*1024*1024; + -- https://stackoverflow.com/a/686243 + -- Individual headers can be up to 16k? What madness? + local headlimit = tonumber(options_cb and options_cb().head_size_limit) or 10*1024; local buflimit = tonumber(options_cb and options_cb().buffer_size_limit) or bodylimit * 2; - local chunked, chunk_size, chunk_start; + local buffer = dbuffer.new(buflimit); + local chunked; local state = nil; local packet; local len; @@ -41,32 +44,27 @@ function httpstream.new(success_cb, error_cb, parser_type, options_cb) feed = function(_, data) if error then return nil, "parse has failed"; end if not data then -- EOF - if buftable then buf, buftable = t_concat(buf), false; end if state and client and not len then -- reading client body until EOF - packet.body = buf; + buffer:collapse(); + packet.body = buffer:read_chunk() or ""; + packet.partial = nil; success_cb(packet); - elseif buf ~= "" then -- unexpected EOF + state = nil; + elseif buffer:length() ~= 0 then -- unexpected EOF error = true; return error_cb("unexpected-eof"); end return; end - if buftable then - t_insert(buf, data); - else - buf = { buf, data }; - buftable = true; - end - buflen = buflen + #data; - if buflen > buflimit then error = true; return error_cb("max-buffer-size-exceeded"); end - while buflen > 0 do + if not buffer:write(data) then error = true; return error_cb("max-buffer-size-exceeded"); end + while buffer:length() > 0 do if state == nil then -- read request - if buftable then buf, buftable = t_concat(buf), false; end - local index = buf:find("\r\n\r\n", nil, true); + local index = buffer:sub(1, headlimit):find("\r\n\r\n", nil, true); if not index then return; end -- not enough data - local method, path, httpversion, status_code, reason_phrase; + -- FIXME was reason_phrase meant to be passed on somewhere? + local method, path, httpversion, status_code, reason_phrase; -- luacheck: ignore reason_phrase local first_line; local headers = {}; - for line in buf:sub(1,index+1):gmatch("([^\r\n]+)\r\n") do -- parse request + for line in buffer:read(index+3):gmatch("([^\r\n]+)\r\n") do -- parse request if first_line then local key, val = line:match("^([^%s:]+): *(.*)$"); if not key then error = true; return error_cb("invalid-header-line"); end -- TODO handle multi-line and invalid headers @@ -91,7 +89,6 @@ function httpstream.new(success_cb, error_cb, parser_type, options_cb) if not first_line then error = true; return error_cb("invalid-status-line"); end chunked = have_body and headers["transfer-encoding"] == "chunked"; len = tonumber(headers["content-length"]); -- TODO check for invalid len - if len and len > bodylimit then error = true; return error_cb("content-length-limit-exceeded"); end if client then -- FIXME handle '100 Continue' response (by skipping it) if not have_body then len = 0; end @@ -99,7 +96,10 @@ function httpstream.new(success_cb, error_cb, parser_type, options_cb) code = status_code; httpversion = httpversion; headers = headers; - body = have_body and "" or nil; + body = false; + body_length = len; + chunked = chunked; + partial = true; -- COMPAT the properties below are deprecated responseversion = httpversion; responseheaders = headers; @@ -124,60 +124,81 @@ function httpstream.new(success_cb, error_cb, parser_type, options_cb) path = path; httpversion = httpversion; headers = headers; - body = nil; + body = false; + body_sink = nil; + chunked = chunked; + partial = true; }; end - buf = buf:sub(index + 4); - buflen = #buf; + if len and len > bodylimit then + -- Early notification, for redirection + success_cb(packet); + if not packet.body_sink then error = true; return error_cb("content-length-limit-exceeded"); end + end + if chunked and not packet.body_sink then + success_cb(packet); + if not packet.body_sink then + packet.body_buffer = dbuffer.new(buflimit); + end + end state = true; end if state then -- read body - if client then - if chunked then - if chunk_start and buflen - chunk_start - 2 < chunk_size then - return; - end -- not enough data - if buftable then buf, buftable = t_concat(buf), false; end - if not buf:find("\r\n", nil, true) then - return; - end -- not enough data - if not chunk_size then - chunk_size, chunk_start = buf:match("^(%x+)[^\r\n]*\r\n()"); - chunk_size = chunk_size and tonumber(chunk_size, 16); - if not chunk_size then error = true; return error_cb("invalid-chunk-size"); end - end - if chunk_size == 0 and buf:find("\r\n\r\n", chunk_start-2, true) then - state, chunk_size = nil, nil; - buf = buf:gsub("^.-\r\n\r\n", ""); -- This ensure extensions and trailers are stripped - success_cb(packet); - elseif buflen - chunk_start - 2 >= chunk_size then -- we have a chunk - packet.body = packet.body..buf:sub(chunk_start, chunk_start + (chunk_size-1)); - buf = buf:sub(chunk_start + chunk_size + 2); - buflen = buflen - (chunk_start + chunk_size + 2 - 1); - chunk_size, chunk_start = nil, nil; - else -- Partial chunk remaining - break; + if chunked then + local chunk_header = buffer:sub(1, 512); -- XXX How large do chunk headers grow? + local chunk_size, chunk_start = chunk_header:match("^(%x+)[^\r\n]*\r\n()"); + if not chunk_size then return; end + chunk_size = chunk_size and tonumber(chunk_size, 16); + if not chunk_size then error = true; return error_cb("invalid-chunk-size"); end + if chunk_size == 0 and chunk_header:find("\r\n\r\n", chunk_start-2, true) then + local body_buffer = packet.body_buffer; + if body_buffer then + packet.body_buffer = nil; + body_buffer:collapse(); + packet.body = body_buffer:read_chunk() or ""; end - elseif len and buflen >= len then - if buftable then buf, buftable = t_concat(buf), false; end - if packet.code == 101 then - packet.body, buf, buflen, buftable = buf, {}, 0, true; + + buffer:collapse(); + local buf = buffer:read_chunk(); + buf = buf:gsub("^.-\r\n\r\n", ""); -- This ensure extensions and trailers are stripped + buffer:write(buf); + state, chunked = nil, nil; + packet.partial = nil; + success_cb(packet); + elseif buffer:length() - chunk_start - 2 >= chunk_size then -- we have a chunk + buffer:discard(chunk_start - 1); -- TODO verify that it's not off-by-one + (packet.body_sink or packet.body_buffer):write(buffer:read(chunk_size)); + buffer:discard(2); -- CRLF + else -- Partial chunk remaining + break; + end + elseif packet.body_sink then + local chunk = buffer:read_chunk(len); + while chunk and len > 0 do + if packet.body_sink:write(chunk) then + len = len - #chunk; + chunk = buffer:read_chunk(len); else - packet.body, buf = buf:sub(1, len), buf:sub(len + 1); - buflen = #buf; + error = true; + return error_cb("body-sink-write-failure"); end - state = nil; success_cb(packet); - else - break; end - elseif buflen >= len then - if buftable then buf, buftable = t_concat(buf), false; end - packet.body, buf = buf:sub(1, len), buf:sub(len + 1); - buflen = #buf; - state = nil; success_cb(packet); + if len == 0 then + state = nil; + packet.partial = nil; + success_cb(packet); + end + elseif buffer:length() >= len then + assert(not chunked) + packet.body = buffer:read(len) or ""; + state = nil; + packet.partial = nil; + success_cb(packet); else break; end + else + break; end end end; diff --git a/net/http/server.lua b/net/http/server.lua index 3873bbe0..97e15e42 100644 --- a/net/http/server.lua +++ b/net/http/server.lua @@ -1,5 +1,5 @@ -local t_insert, t_remove, t_concat = table.insert, table.remove, table.concat; +local t_insert, t_concat = table.insert, table.concat; local parser_new = require "net.http.parser".new; local events = require "util.events".new(); local addserver = require "net.server".addserver; @@ -8,12 +8,12 @@ local os_date = os.date; local pairs = pairs; local s_upper = string.upper; local setmetatable = setmetatable; -local xpcall = require "util.xpcall".xpcall; -local traceback = debug.traceback; -local tostring = tostring; local cache = require "util.cache"; local codes = require "net.http.codes"; +local promise = require "util.promise"; +local errors = require "util.error"; local blocksize = 2^16; +local async = require "util.async"; local _M = {}; @@ -89,51 +89,60 @@ setmetatable(events._handlers, { local handle_request; -local last_err; -local function _traceback_handler(err) last_err = err; log("error", "Traceback[httpserver]: %s", traceback(tostring(err), 2)); end events.add_handler("http-error", function (error) return "Error processing request: "..codes[error.code]..". Check your error log for more information."; end, -1); +local runner_callbacks = {}; + +function runner_callbacks:ready() + self.data.conn:resume(); +end + +function runner_callbacks:waiting() + self.data.conn:pause(); +end + +function runner_callbacks:error(err) + log("error", "Traceback[httpserver]: %s", err); + self.data.conn:write("HTTP/1.0 500 Internal Server Error\r\n\r\n"..events.fire_event("http-error", { code = 500, private_message = err })); + self.data.conn:close(); +end + +local function noop() end function listener.onconnect(conn) + local session = { conn = conn }; local secure = conn:ssl() and true or nil; - local pending = {}; - local waiting = false; - local function process_next() - if waiting then return; end -- log("debug", "can't process_next, waiting"); - waiting = true; - while sessions[conn] and #pending > 0 do - local request = t_remove(pending); - --log("debug", "process_next: %s", request.path); - if not xpcall(handle_request, _traceback_handler, conn, request, process_next) then - conn:write("HTTP/1.0 500 Internal Server Error\r\n\r\n"..events.fire_event("http-error", { code = 500, private_message = last_err })); - conn:close(); - end + local ip = conn:ip(); + session.thread = async.runner(function (request) + local wait, done; + if request.partial == true then + -- Have the header for a request, we want to receive the rest + -- when we've decided where the data should go. + wait, done = noop, noop; + else -- Got the entire request + -- Hold off on receiving more incoming requests until this one has been handled. + wait, done = async.waiter(); end - --log("debug", "ready for more"); - waiting = false; - end + handle_request(conn, request, done); wait(); + end, runner_callbacks, session); local function success_cb(request) --log("debug", "success_cb: %s", request.path); - if waiting then - log("error", "http connection handler is not reentrant: %s", request.path); - assert(false, "http connection handler is not reentrant"); - end + request.ip = ip; request.secure = secure; - t_insert(pending, request); - process_next(); + session.thread:run(request); end local function error_cb(err) log("debug", "error_cb: %s", err or "<nil>"); -- FIXME don't close immediately, wait until we process current stuff -- FIXME if err, send off a bad-request response - sessions[conn] = nil; conn:close(); end local function options_cb() return options; end - sessions[conn] = parser_new(success_cb, error_cb, "server", options_cb); + session.parser = parser_new(success_cb, error_cb, "server", options_cb); + sessions[conn] = session; end function listener.ondisconnect(conn) @@ -152,7 +161,7 @@ function listener.ondetach(conn) end function listener.onincoming(conn, data) - sessions[conn]:feed(data); + sessions[conn].parser:feed(data); end function listener.ondrain(conn) @@ -170,6 +179,49 @@ local headerfix = setmetatable({}, { end }); +local function handle_result(request, response, result) + if result == nil then + result = 404; + end + + if result == true then + return; + end + + local body; + local result_type = type(result); + if result_type == "number" then + response.status_code = result; + if result >= 400 then + body = events.fire_event("http-error", { request = request, response = response, code = result }); + end + elseif result_type == "string" then + body = result; + elseif errors.is_err(result) then + response.status_code = result.code or 500; + body = events.fire_event("http-error", { request = request, response = response, code = result.code or 500, error = result }); + elseif promise.is_promise(result) then + result:next(function (ret) + handle_result(request, response, ret); + end, function (err) + response.status_code = 500; + handle_result(request, response, err or 500); + end); + return true; + elseif result_type == "table" then + for k, v in pairs(result) do + if k ~= "headers" then + response[k] = v; + else + for header_name, header_value in pairs(v) do + response.headers[header_name] = header_value; + end + end + end + end + return response:send(body); +end + function _M.hijack_response(response, listener) -- luacheck: ignore error("TODO"); end @@ -194,13 +246,17 @@ function handle_request(conn, request, finish_cb) response_conn_header = httpversion == "1.1" and "close" or nil end + local is_head_request = request.method == "HEAD"; + local response = { request = request; + is_head_request = is_head_request; status_code = 200; headers = { date = date_header, connection = response_conn_header }; persistent = persistent; conn = conn; send = _M.send_response; + write_headers = _M.write_headers; send_file = _M.send_file; done = _M.finish_response; finish_cb = finish_cb; @@ -227,6 +283,11 @@ function handle_request(conn, request, finish_cb) local payload = { request = request, response = response }; log("debug", "Firing event: %s", global_event); local result = events.fire_event(global_event, payload); + if result == nil and is_head_request then + local global_head_event = "GET "..request.path:match("[^?]*"); + log("debug", "Firing event: %s", global_head_event); + result = events.fire_event(global_head_event, payload); + end if result == nil then if not hosts[host] then if hosts[default_host] then @@ -247,40 +308,17 @@ function handle_request(conn, request, finish_cb) local host_event = request.method.." "..host..request.path:match("[^?]*"); log("debug", "Firing event: %s", host_event); result = events.fire_event(host_event, payload); - end - if result ~= nil then - if result ~= true then - local body; - local result_type = type(result); - if result_type == "number" then - response.status_code = result; - if result >= 400 then - payload.code = result; - body = events.fire_event("http-error", payload); - end - elseif result_type == "string" then - body = result; - elseif result_type == "table" then - for k, v in pairs(result) do - if k ~= "headers" then - response[k] = v; - else - for header_name, header_value in pairs(v) do - response.headers[header_name] = header_value; - end - end - end - end - response:send(body); + + if result == nil and is_head_request then + local host_head_event = "GET "..host..request.path:match("[^?]*"); + log("debug", "Firing event: %s", host_head_event); + result = events.fire_event(host_head_event, payload); end - return; end - -- if handler not called, return 404 - response.status_code = 404; - payload.code = 404; - response:send(events.fire_event("http-error", payload)); + return handle_result(request, response, result); end + local function prepare_header(response) local status_line = "HTTP/"..response.request.httpversion.." "..(response.status or codes[response.status_code]); local headers = response.headers; @@ -292,12 +330,25 @@ local function prepare_header(response) return output; end _M.prepare_header = prepare_header; +function _M.write_headers(response) + if response.finished then return; end + local output = prepare_header(response); + response.conn:write(t_concat(output)); +end +function _M.send_head_response(response) + if response.finished then return; end + _M.write_headers(response); + response:done(); +end function _M.send_response(response, body) if response.finished then return; end body = body or response.body or ""; -- Per RFC 7230, informational (1xx) and 204 (no content) should have no c-l header if response.status_code > 199 and response.status_code ~= 204 then - response.headers.content_length = #body; + response.headers.content_length = ("%d"):format(#body); + end + if response.is_head_request then + return _M.send_head_response(response) end local output = prepare_header(response); t_insert(output, body); @@ -305,6 +356,10 @@ function _M.send_response(response, body) response:done(); end function _M.send_file(response, f) + if response.is_head_request then + if f.close then f:close(); end + return _M.send_head_response(response); + end if response.finished then return; end local chunked = not response.headers.content_length; if chunked then response.headers.transfer_encoding = "chunked"; end @@ -331,7 +386,7 @@ function _M.send_file(response, f) return response:done(); end end - response.conn:write(t_concat(prepare_header(response))); + _M.write_headers(response); return true; end function _M.finish_response(response) diff --git a/net/resolvers/basic.lua b/net/resolvers/basic.lua index 867ccf60..34f1e1c7 100644 --- a/net/resolvers/basic.lua +++ b/net/resolvers/basic.lua @@ -2,10 +2,13 @@ local adns = require "net.adns"; local inet_pton = require "util.net".pton; local inet_ntop = require "util.net".ntop; local idna_to_ascii = require "util.encodings".idna.to_ascii; +local unpack = table.unpack or unpack; -- luacheck: ignore 113 local methods = {}; local resolver_mt = { __index = methods }; +-- FIXME RFC 6724 + -- Find the next target to connect to, and -- pass it to cb() function methods:next(cb) @@ -25,34 +28,70 @@ function methods:next(cb) return; end + local secure = true; + local tlsa = {}; local targets = {}; - local n = 2; + local n = 3; local function ready() n = n - 1; if n > 0 then return; end self.targets = targets; + if self.extra and self.extra.use_dane then + if secure and tlsa[1] then + self.extra.tlsa = tlsa; + self.extra.dane_hostname = self.hostname; + else + self.extra.tlsa = nil; + self.extra.dane_hostname = nil; + end + end self:next(cb); end -- Resolve DNS to target list local dns_resolver = adns.resolver(); - dns_resolver:lookup(function (answer) - if answer then - for _, record in ipairs(answer) do - table.insert(targets, { self.conn_type.."4", record.a, self.port, self.extra }); + + if not self.extra or self.extra.use_ipv4 ~= false then + dns_resolver:lookup(function (answer) + if answer then + secure = secure and answer.secure; + for _, record in ipairs(answer) do + table.insert(targets, { self.conn_type.."4", record.a, self.port, self.extra }); + end end - end + ready(); + end, self.hostname, "A", "IN"); + else + ready(); + end + + if not self.extra or self.extra.use_ipv6 ~= false then + dns_resolver:lookup(function (answer) + if answer then + secure = secure and answer.secure; + for _, record in ipairs(answer) do + table.insert(targets, { self.conn_type.."6", record.aaaa, self.port, self.extra }); + end + end + ready(); + end, self.hostname, "AAAA", "IN"); + else ready(); - end, self.hostname, "A", "IN"); + end - dns_resolver:lookup(function (answer) - if answer then - for _, record in ipairs(answer) do - table.insert(targets, { self.conn_type.."6", record.aaaa, self.port, self.extra }); + if self.extra and self.extra.use_dane == true then + dns_resolver:lookup(function (answer) + if answer then + secure = secure and answer.secure; + for _, record in ipairs(answer) do + table.insert(tlsa, record.tlsa); + end end - end + ready(); + end, ("_%d._tcp.%s"):format(self.port, self.hostname), "TLSA", "IN"); + else ready(); - end, self.hostname, "AAAA", "IN"); + end end local function new(hostname, port, conn_type, extra) diff --git a/net/resolvers/manual.lua b/net/resolvers/manual.lua index c0d4e5d5..dbc40256 100644 --- a/net/resolvers/manual.lua +++ b/net/resolvers/manual.lua @@ -1,5 +1,6 @@ local methods = {}; local resolver_mt = { __index = methods }; +local unpack = table.unpack or unpack; -- luacheck: ignore 113 -- Find the next target to connect to, and -- pass it to cb() diff --git a/net/resolvers/service.lua b/net/resolvers/service.lua index 34f14cba..204c8a7f 100644 --- a/net/resolvers/service.lua +++ b/net/resolvers/service.lua @@ -1,6 +1,8 @@ local adns = require "net.adns"; local basic = require "net.resolvers.basic"; +local inet_pton = require "util.net".pton; local idna_to_ascii = require "util.encodings".idna.to_ascii; +local unpack = table.unpack or unpack; -- luacheck: ignore 113 local methods = {}; local resolver_mt = { __index = methods }; @@ -9,14 +11,17 @@ local resolver_mt = { __index = methods }; -- pass it to cb() function methods:next(cb) if self.targets then - if #self.targets == 0 then - cb(nil); - return; + if not self.resolver then + if #self.targets == 0 then + cb(nil); + return; + end + local next_target = table.remove(self.targets, 1); + self.resolver = basic.new(unpack(next_target, 1, 4)); end - local next_target = table.remove(self.targets, 1); - self.resolver = basic.new(unpack(next_target, 1, 4)); self.resolver:next(function (...) if ... == nil then + self.resolver = nil; self:next(cb); else cb(...); @@ -39,8 +44,16 @@ function methods:next(cb) -- Resolve DNS to target list local dns_resolver = adns.resolver(); - dns_resolver:lookup(function (answer) + dns_resolver:lookup(function (answer, err) + if not answer and not err then + -- net.adns returns nil if there are zero records or nxdomain + answer = {}; + end if answer then + if self.extra and not answer.secure then + self.extra.use_dane = false; + end + if #answer == 0 then if self.extra and self.extra.default_port then table.insert(targets, { self.hostname, self.extra.default_port, self.conn_type, self.extra }); @@ -64,6 +77,14 @@ function methods:next(cb) end local function new(hostname, service, conn_type, extra) + local is_ip = inet_pton(hostname); + if not is_ip and hostname:sub(1,1) == '[' then + is_ip = inet_pton(hostname:sub(2,-2)); + end + if is_ip and extra and extra.default_port then + return basic.new(hostname, extra.default_port, conn_type, extra); + end + return setmetatable({ hostname = idna_to_ascii(hostname); service = service; diff --git a/net/server.lua b/net/server.lua index abbb421d..f5666594 100644 --- a/net/server.lua +++ b/net/server.lua @@ -13,7 +13,11 @@ if not (prosody and prosody.config_loaded) then end local log = require "util.logger".init("net.server"); -local server_type = require "core.configmanager".get("*", "network_backend") or "select"; + +local have_util_poll = pcall(require, "util.poll"); +local default_backend = have_util_poll and "epoll" or "select"; + +local server_type = require "core.configmanager".get("*", "network_backend") or default_backend; if require "core.configmanager".get("*", "use_libevent") then server_type = "event"; diff --git a/net/server_epoll.lua b/net/server_epoll.lua index 53a67dd5..eb2d0d77 100644 --- a/net/server_epoll.lua +++ b/net/server_epoll.lua @@ -9,20 +9,25 @@ local t_insert = table.insert; local t_concat = table.concat; local setmetatable = setmetatable; -local tostring = tostring; local pcall = pcall; local type = type; local next = next; local pairs = pairs; -local log = require "util.logger".init("server_epoll"); +local ipairs = ipairs; +local traceback = debug.traceback; +local logger = require "util.logger"; +local log = logger.init("server_epoll"); local socket = require "socket"; local luasec = require "ssl"; -local gettime = require "util.time".now; +local realtime = require "util.time".now; +local monotonic = require "util.time".monotonic; local indexedbheap = require "util.indexedbheap"; local createtable = require "util.table".create; local inet = require "util.net"; local inet_pton = inet.pton; local _SOCKETINVALID = socket._SOCKETINVALID or -1; +local new_id = require "util.id".medium; +local xpcall = require "util.xpcall".xpcall; local poller = require "util.poll" local EEXIST = poller.EEXIST; @@ -38,7 +43,10 @@ local default_config = { __index = { read_timeout = 14 * 60; -- How long to wait for a socket to become writable after queuing data to send - send_timeout = 60; + send_timeout = 180; + + -- How long to wait for a socket to become writable after creation + connect_timeout = 20; -- Some number possibly influencing how many pending connections can be accepted tcp_backlog = 128; @@ -46,7 +54,7 @@ local default_config = { __index = { -- If accepting a new incoming connection fails, wait this long before trying again accept_retry_interval = 10; - -- If there is still more data to read from LuaSocktes buffer, wait this long and read again + -- If there is still more data to read from LuaSockets buffer, wait this long and read again read_retry_delay = 1e-06; -- Size of chunks to read from sockets @@ -57,7 +65,30 @@ local default_config = { __index = { -- Maximum and minimum amount of time to sleep waiting for events (adjusted for pending timers) max_wait = 86400; - min_wait = 1e-06; + min_wait = 0.001; + + -- Enable extra noisy debug logging + -- TODO disable once considered stable + verbose = true; + + -- EXPERIMENTAL + -- Whether to kill connections in case of callback errors. + fatal_errors = false; + + -- Or disable protection (like server_select) for potential performance gains + protect_listeners = true; + + -- Attempt writes instantly + opportunistic_writes = false; + + -- TCP Keepalives + tcp_keepalive = false; -- boolean | number + + -- Whether to let the Nagle algorithm stay enabled + nagle = true; + + -- Reuse write buffer tables + keep_buffers = true; }}; local cfg = default_config.__index; @@ -68,54 +99,56 @@ local fds = createtable(10, 0); -- FD -> conn local timers = indexedbheap.create(); local function noop() end -local function closetimer(t) - t[1] = 0; - t[2] = noop; - timers:remove(t.id); +local function closetimer(id) + timers:remove(id); end -local function reschedule(t, time) - t[1] = time; - timers:reprioritize(t.id, time); -end - --- Add absolute timer -local function at(time, f) - local timer = { time, f, close = closetimer, reschedule = reschedule, id = nil }; - timer.id = timers:insert(timer, time); - return timer; +local function reschedule(id, time) + time = monotonic() + time; + timers:reprioritize(id, time); end -- Add relative timer -local function addtimer(timeout, f) - return at(gettime() + timeout, f); +local function addtimer(timeout, f, param) + local time = monotonic() + timeout; + if param ~= nil then + local timer_callback = f + function f(current_time, timer_id) + local t = timer_callback(current_time, timer_id, param) + return t; + end + end + local id = timers:insert(f, time); + return id; end -- Run callbacks of expired timers -- Return time until next timeout local function runtimers(next_delay, min_wait) -- Any timers at all? - local now = gettime(); + local elapsed = monotonic(); + local now = realtime(); local peek = timers:peek(); local readd; while peek do - if peek > now then + if peek > elapsed then break; end local _, timer, id = timers:pop(); - local ok, ret = pcall(timer[2], now); + local ok, ret = xpcall(timer, traceback, now, id); if ok and type(ret) == "number" then - local next_time = now+ret; - timer[1] = next_time; + local next_time = elapsed+ret; -- Delay insertion of timers to be re-added -- so they don't get called again this tick if readd then - readd[id] = timer; + readd[id] = { timer, next_time }; else - readd = { [id] = timer }; + readd = { [id] = { timer, next_time } }; end + elseif not ok then + log("error", "Error in timer: %s", ret); end peek = timers:peek(); @@ -123,7 +156,7 @@ local function runtimers(next_delay, min_wait) if readd then for _, timer in pairs(readd) do - timers:insert(timer, timer[1]); + timers:insert(timer[1], timer[2]); end peek = timers:peek(); end @@ -131,7 +164,7 @@ local function runtimers(next_delay, min_wait) if peek == nil then return next_delay; else - next_delay = peek - now; + next_delay = peek - elapsed; end if next_delay < min_wait then @@ -154,6 +187,22 @@ function interface_mt:__tostring() return ("FD %d"):format(self:getfd()); end +interface.log = log; +function interface:debug(msg, ...) + self.log("debug", msg, ...); +end + +interface.noise = interface.debug; +function interface:noise(msg, ...) + if cfg.verbose then + return self:debug(msg, ...); + end +end + +function interface:error(msg, ...) + self.log("error", msg, ...); +end + -- Replace the listener and tell the old one function interface:setlistener(listeners, data) self:on("detach"); @@ -164,21 +213,36 @@ end -- Call a listener callback function interface:on(what, ...) if not self.listeners then - log("error", "%s has no listeners", self); + self:error("Interface is missing listener callbacks"); return; end local listener = self.listeners["on"..what]; if not listener then - -- log("debug", "Missing listener 'on%s'", what); -- uncomment for development and debugging + self:noise("Missing listener 'on%s'", what); -- uncomment for development and debugging return; end - local ok, err = pcall(listener, self, ...); + if not cfg.protect_listeners then + return listener(self, ...); + end + local onerror = self.listeners.onerror or traceback; + local ok, err = xpcall(listener, onerror, self, ...); if not ok then - log("error", "Error calling on%s: %s", what, err); + if cfg.fatal_errors then + self:error("Closing due to error calling on%s: %s", what, err); + self:destroy(); + else + self:error("Error calling on%s: %s", what, err); + end + return nil, err; end return err; end +-- Allow this one to be overridden +function interface:onincoming(...) + return self:on("incoming", ...); +end + -- Return the file descriptor number function interface:getfd() if self.conn then @@ -226,28 +290,36 @@ end function interface:setoption(k, v) -- LuaSec doesn't expose setoption :( - if self.conn.setoption then - self.conn:setoption(k, v); + local ok, ret, err = pcall(self.conn.setoption, self.conn, k, v); + if not ok then + self:noise("Setting option %q = %q failed: %s", k, v, ret); + return ok, ret; + elseif not ret then + self:noise("Setting option %q = %q failed: %s", k, v, err); + return ret, err; end + return ret; end -- Timeout for detecting dead or idle sockets function interface:setreadtimeout(t) if t == false then if self._readtimeout then - self._readtimeout:close(); + closetimer(self._readtimeout); self._readtimeout = nil; end return end t = t or cfg.read_timeout; if self._readtimeout then - self._readtimeout:reschedule(gettime() + t); + reschedule(self._readtimeout, t); else self._readtimeout = addtimer(t, function () if self:on("readtimeout") then + self:noise("Read timeout handled"); return cfg.read_timeout; else + self:debug("Read timeout not handled, disconnecting"); self:on("disconnect", "read timeout"); self:destroy(); end @@ -259,17 +331,18 @@ end function interface:setwritetimeout(t) if t == false then if self._writetimeout then - self._writetimeout:close(); + closetimer(self._writetimeout); self._writetimeout = nil; end return end t = t or cfg.send_timeout; if self._writetimeout then - self._writetimeout:reschedule(gettime() + t); + reschedule(self._writetimeout, t); else self._writetimeout = addtimer(t, function () - self:on("disconnect", "write timeout"); + self:noise("Write timeout"); + self:on("disconnect", self._connected and "write timeout" or "connection timeout"); self:destroy(); end); end @@ -285,15 +358,15 @@ function interface:add(r, w) local ok, err, errno = poll:add(fd, r, w); if not ok then if errno == EEXIST then - log("debug", "%s already registered!", self); + self:debug("FD already registered in poller! (EEXIST)"); return self:set(r, w); -- So try to change its flags end - log("error", "Could not register %s: %s(%d)", self, err, errno); + self:debug("Could not register in poller: %s(%d)", err, errno); return ok, err; end self._wantread, self._wantwrite = r, w; fds[fd] = self; - log("debug", "Watching %s", self); + self:noise("Registered in poller"); return true; end @@ -306,7 +379,7 @@ function interface:set(r, w) if w == nil then w = self._wantwrite; end local ok, err, errno = poll:set(fd, r, w); if not ok then - log("error", "Could not update poller state %s: %s(%d)", self, err, errno); + self:debug("Could not update poller state: %s(%d)", err, errno); return ok, err; end self._wantread, self._wantwrite = r, w; @@ -323,12 +396,12 @@ function interface:del() end local ok, err, errno = poll:del(fd); if not ok and errno ~= ENOENT then - log("error", "Could not unregister %s: %s(%d)", self, err, errno); + self:debug("Could not unregister: %s(%d)", err, errno); return ok, err; end self._wantread, self._wantwrite = nil, nil; fds[fd] = nil; - log("debug", "Unwatched %s", self); + self:noise("Unregistered from poller"); return true; end @@ -350,27 +423,44 @@ function interface:onreadable() local data, err, partial = self.conn:receive(self.read_size or cfg.read_size); if data then self:onconnect(); - self:on("incoming", data); + self:onincoming(data); else if err == "wantread" then self:set(true, nil); err = "timeout"; elseif err == "wantwrite" then self:set(nil, true); + self:setwritetimeout(); err = "timeout"; + elseif err == "timeout" and not self._connected then + err = "connection timeout"; end if partial and partial ~= "" then self:onconnect(); - self:on("incoming", partial, err); + self:onincoming(partial, err); end if err ~= "timeout" then + if err == "closed" then + self:debug("Connection closed by remote"); + else + self:debug("Read error, closing (%s)", err); + end self:on("disconnect", err); self:destroy() return; end end if not self.conn then return; end - if self._wantread and self.conn:dirty() then + if self._limit and (data or partial) then + local cost = self._limit * #(data or partial); + if cost > cfg.min_wait then + self:setreadtimeout(false); + self:pausefor(cost); + return; + end + end + if not self._wantread then return end + if self.conn:dirty() then self:setreadtimeout(false); self:pausefor(cfg.read_retry_delay); else @@ -382,31 +472,55 @@ end function interface:onwritable() self:onconnect(); if not self.conn then return; end -- could have been closed in onconnect + self:on("predrain"); local buffer = self.writebuffer; - local data = t_concat(buffer); + local data = buffer or ""; + if type(buffer) == "table" then + if buffer[3] then + data = t_concat(data); + elseif buffer[2] then + data = buffer[1] .. buffer[2]; + else + data = buffer[1] or ""; + end + end local ok, err, partial = self.conn:send(data); + self._writable = ok; if ok then self:set(nil, false); - for i = #buffer, 1, -1 do - buffer[i] = nil; + if cfg.keep_buffers and type(buffer) == "table" then + for i = #buffer, 1, -1 do + buffer[i] = nil; + end + else + self.writebuffer = nil; end self:setwritetimeout(false); self:ondrain(); -- Be aware of writes in ondrain - return; + return ok; elseif partial then - buffer[1] = data:sub(partial+1); - for i = #buffer, 2, -1 do - buffer[i] = nil; + self:debug("Sent %d out of %d buffered bytes", partial, #data); + if cfg.keep_buffers and type(buffer) == "table" then + buffer[1] = data:sub(partial+1); + for i = #buffer, 2, -1 do + buffer[i] = nil; + end + else + self.writebuffer = data:sub(partial+1); end + self:set(nil, true); self:setwritetimeout(); end if err == "wantwrite" or err == "timeout" then self:set(nil, true); + self:setwritetimeout(); elseif err == "wantread" then self:set(true, nil); + self:setreadtimeout(); elseif err ~= "timeout" then self:on("disconnect", err); self:destroy(); + return ok, err; end end @@ -418,26 +532,39 @@ end -- Add data to write buffer and set flag for wanting to write function interface:write(data) local buffer = self.writebuffer; - if buffer then + if type(buffer) == "table" then t_insert(buffer, data); - else - self.writebuffer = { data }; + elseif type(buffer) == "string" then + self:noise("Allocating buffer!") + self.writebuffer = { buffer, data }; + elseif buffer == nil then + self.writebuffer = data; + end + if not self._write_lock then + if self._writable and cfg.opportunistic_writes and not self._opportunistic_write then + self._opportunistic_write = true; + local ret, err = self:onwritable(); + self._opportunistic_write = nil; + return ret, err; + end + self:setwritetimeout(); + self:set(nil, true); end - self:setwritetimeout(); - self:set(nil, true); return #data; end interface.send = interface.write; -- Close, possibly after writing is done function interface:close() - if self.writebuffer and self.writebuffer[1] then + if self.writebuffer and (self.writebuffer[1] or type(self.writebuffer) == "string") then self:set(false, true); -- Flush final buffer contents + self:setreadtimeout(false); + self:setwritetimeout(); self.write, self.send = noop, noop; -- No more writing - log("debug", "Close %s after writing", self); + self:debug("Close after writing remaining buffered data"); self.ondrain = interface.close; else - log("debug", "Close %s now", self); + self:debug("Closing now"); self.write, self.send = noop, noop; self.close = noop; self:on("disconnect"); @@ -462,70 +589,108 @@ function interface:ssl() return self._tls; end +function interface:set_sslctx(sslctx) + self._sslctx = sslctx; +end + function interface:starttls(tls_ctx) if tls_ctx then self.tls_ctx = tls_ctx; end self.starttls = false; - if self.writebuffer and self.writebuffer[1] then - log("debug", "Start TLS on %s after write", self); + if self.writebuffer and (self.writebuffer[1] or type(self.writebuffer) == "string") then + self:debug("Start TLS after write"); self.ondrain = interface.starttls; self:set(nil, true); -- make sure wantwrite is set else if self.ondrain == interface.starttls then self.ondrain = nil; end - self.onwritable = interface.tlshandskake; - self.onreadable = interface.tlshandskake; + self.onwritable = interface.inittls; + self.onreadable = interface.inittls; self:set(true, true); - log("debug", "Prepare to start TLS on %s", self); + self:setreadtimeout(false); + self:setwritetimeout(cfg.ssl_handshake_timeout); + self:debug("Prepared to start TLS"); end end -function interface:tlshandskake() - self:setwritetimeout(false); - self:setreadtimeout(false); - if not self._tls then - self._tls = true; - log("debug", "Start TLS on %s now", self); - self:del(); - local ok, conn, err = pcall(luasec.wrap, self.conn, self.tls_ctx); - if not ok then - conn, err = ok, conn; - log("error", "Failed to initialize TLS: %s", err); - end - if not conn then - self:on("disconnect", err); - self:destroy(); - return conn, err; - end - conn:settimeout(0); - self.conn = conn; - if conn.sni and self.servername then +function interface:inittls(tls_ctx, now) + if self._tls then return end + if tls_ctx then self.tls_ctx = tls_ctx; end + self._tls = true; + self:debug("Starting TLS now"); + self:updatenames(); -- Can't getpeer/sockname after wrap() + local ok, conn, err = pcall(luasec.wrap, self.conn, self.tls_ctx); + if not ok then + conn, err = ok, conn; + self:debug("Failed to initialize TLS: %s", err); + end + if not conn then + self:on("disconnect", err); + self:destroy(); + return conn, err; + end + conn:settimeout(0); + self.conn = conn; + if conn.sni then + if self.servername then conn:sni(self.servername); + elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then + conn:sni(self._server.hosts, true); end - self:on("starttls"); - self.ondrain = nil; - self.onwritable = interface.tlshandskake; - self.onreadable = interface.tlshandskake; - return self:init(); end + if self.extra and self.extra.tlsa and conn.settlsa then + -- TODO Error handling + if not conn:setdane(self.servername or self.extra.dane_hostname) then + self:debug("Could not enable DANE on connection"); + else + self:debug("Enabling DANE with %d TLSA records", #self.extra.tlsa); + self:noise("DANE hostname is %q", self.servername or self.extra.dane_hostname); + for _, tlsa in ipairs(self.extra.tlsa) do + self:noise("TLSA: %q", tlsa); + conn:settlsa(tlsa.use, tlsa.select, tlsa.match, tlsa.data); + end + end + end + self:on("starttls"); + self.ondrain = nil; + self.onwritable = interface.tlshandshake; + self.onreadable = interface.tlshandshake; + if now then + return self:tlshandshake() + end + self:setreadtimeout(false); + self:setwritetimeout(cfg.ssl_handshake_timeout); + self:set(true, true); +end + +function interface:tlshandshake() + self:setreadtimeout(false); + self:noise("Continuing TLS handshake"); local ok, err = self.conn:dohandshake(); if ok then - log("debug", "TLS handshake on %s complete", self); + local info = self.conn.info and self.conn:info(); + if type(info) == "table" then + self:debug("TLS handshake complete (%s with %s)", info.protocol, info.cipher); + else + self:debug("TLS handshake complete"); + end + self:setwritetimeout(false); self.onwritable = nil; self.onreadable = nil; self:on("status", "ssl-handshake-complete"); - self:setwritetimeout(); self:set(true, true); + self:onconnect(); + self:onreadable(); elseif err == "wantread" then - log("debug", "TLS handshake on %s to wait until readable", self); + self:noise("TLS handshake to wait until readable"); self:set(true, false); - self:setreadtimeout(cfg.ssl_handshake_timeout); + self:setwritetimeout(cfg.ssl_handshake_timeout); elseif err == "wantwrite" then - log("debug", "TLS handshake on %s to wait until writable", self); + self:noise("TLS handshake to wait until writable"); self:set(false, true); self:setwritetimeout(cfg.ssl_handshake_timeout); else - log("debug", "TLS handshake error on %s: %s", self, err); + self:debug("TLS handshake error: %s", err); self:on("disconnect", err); self:destroy(); end @@ -533,15 +698,18 @@ end local function wrapsocket(client, server, read_size, listeners, tls_ctx, extra) -- luasocket object -> interface object client:settimeout(0); + local conn_id = ("conn%s"):format(new_id()); local conn = setmetatable({ conn = client; _server = server; - created = gettime(); + created = realtime(); listeners = listeners; read_size = read_size or (server and server.read_size); - writebuffer = {}; + writebuffer = nil; tls_ctx = tls_ctx or (server and server.tls_ctx); tls_direct = server and server.tls_direct; + id = conn_id; + log = logger.init(conn_id); extra = extra; }, interface_mt); @@ -558,12 +726,12 @@ end function interface:updatenames() local conn = self.conn; local ok, peername, peerport = pcall(conn.getpeername, conn); - if ok then - self.peername, self.peerport = peername, peerport; + if ok and peername then + self.peername, self.peerport = peername, peerport or 0; end local ok, sockname, sockport = pcall(conn.getsockname, conn); - if ok then - self.sockname, self.sockport = sockname, sockport; + if ok and sockname then + self.sockname, self.sockport = sockname, sockport or 0; end end @@ -572,76 +740,149 @@ end function interface:onacceptable() local conn, err = self.conn:accept(); if not conn then - log("debug", "Error accepting new client: %s, server will be paused for %ds", err, cfg.accept_retry_interval); + self:debug("Error accepting new client: %s, server will be paused for %ds", err, cfg.accept_retry_interval); self:pausefor(cfg.accept_retry_interval); return; end local client = wrapsocket(conn, self, nil, self.listeners); - log("debug", "New connection %s", tostring(client)); - client:init(); + client:debug("New connection %s on server %s", client, self); + client:defaultoptions(); + client._writable = cfg.opportunistic_writes; if self.tls_direct then - client:starttls(self.tls_ctx); + client:add(true, true); + client:inittls(self.tls_ctx, true); + else + client:add(true, false); + client:onconnect(); + client:onreadable(); end end --- Initialization +-- Initialization for outgoing connections function interface:init() - self:setwritetimeout(); + self:setwritetimeout(cfg.connect_timeout); + self:defaultoptions(); return self:add(true, true); end +function interface:defaultoptions() + if cfg.nagle == false then + self:setoption("tcp-nodelay", true); + end + if cfg.tcp_keepalive then + self:setoption("keepalive", true); + if type(cfg.tcp_keepalive) == "number" then + self:setoption("tcp-keepidle", cfg.tcp_keepalive); + end + end +end + function interface:pause() + self:noise("Pause reading"); + self:setreadtimeout(false); return self:set(false); end function interface:resume() + self:noise("Resume reading"); + self:setreadtimeout(); return self:set(true); end -- Pause connection for some time function interface:pausefor(t) + self:noise("Pause for %fs", t); if self._pausefor then - self._pausefor:close(); + closetimer(self._pausefor); + self._pausefor = nil; end if t == false then return; end self:set(false); self._pausefor = addtimer(t, function () self._pausefor = nil; self:set(true); + self:noise("Resuming after pause"); if self.conn and self.conn:dirty() then + self:noise("Have buffered incoming data to process"); self:onreadable(); end end); end +function interface:setlimit(Bps) + if Bps > 0 then + self._limit = 1/Bps; + else + self._limit = nil; + end +end + +function interface:pause_writes() + if self._write_lock then + return + end + self:noise("Pause writes"); + self._write_lock = true; + self:setwritetimeout(false); + self:set(nil, false); +end + +function interface:resume_writes() + if not self._write_lock then + return + end + self:noise("Resume writes"); + self._write_lock = nil; + if self.writebuffer and (self.writebuffer[1] or type(self.writebuffer) == "string") then + self:setwritetimeout(); + self:set(nil, true); + end +end + -- Connected! function interface:onconnect() - if self.conn and not self.peername and self.conn.getpeername then - self.peername, self.peerport = self.conn:getpeername(); - end + self._connected = true; + self:updatenames(); + self:debug("Connected (%s)", self); self.onconnect = noop; self:on("connect"); end -local function addserver(addr, port, listeners, read_size, tls_ctx) - local conn, err = socket.bind(addr, port, cfg.tcp_backlog); - if not conn then return conn, err; end - conn:settimeout(0); +local function wrapserver(conn, addr, port, listeners, config) local server = setmetatable({ conn = conn; - created = gettime(); + created = realtime(); listeners = listeners; - read_size = read_size; + read_size = config and config.read_size; onreadable = interface.onacceptable; - tls_ctx = tls_ctx; - tls_direct = tls_ctx and true or false; + tls_ctx = config and config.tls_ctx; + tls_direct = config and config.tls_direct; + hosts = config and config.sni_hosts; sockname = addr; sockport = port; + log = logger.init(("serv%s"):format(new_id())); }, interface_mt); + server:debug("Server %s created", server); server:add(true, false); return server; end +local function listen(addr, port, listeners, config) + local conn, err = socket.bind(addr, port, cfg.tcp_backlog); + if not conn then return conn, err; end + conn:settimeout(0); + return wrapserver(conn, addr, port, listeners, config); +end + +-- COMPAT +local function addserver(addr, port, listeners, read_size, tls_ctx) + return listen(addr, port, listeners, { + read_size = read_size; + tls_ctx = tls_ctx; + tls_direct = tls_ctx and true or false; + }); +end + -- COMPAT local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx, extra) local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra); @@ -675,13 +916,19 @@ local function addclient(addr, port, listeners, read_size, tls_ctx, typ, extra) return nil, "invalid socket type"; end local conn, err = create(); + if not conn then return conn, err; end local ok, err = conn:settimeout(0); if not ok then return ok, err; end local ok, err = conn:setpeername(addr, port); if not ok and err ~= "timeout" then return ok, err; end local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra) local ok, err = client:init(); + if not client.peername then + -- otherwise not set until connected + client.peername, client.peerport = addr, port; + end if not ok then return ok, err; end + client:debug("Client %s created", client); if tls_ctx then client:starttls(tls_ctx); end @@ -703,23 +950,23 @@ local function watchfd(fd, onreadable, onwritable) end; -- Otherwise it'll need to be something LuaSocket-compatible end + conn.id = new_id(); + conn.log = logger.init(("fdwatch%s"):format(conn.id)); conn:add(onreadable, onwritable); return conn; end; -- Dump all data from one connection into another -local function link(from, to) - from.listeners = setmetatable({ - onincoming = function (_, data) - from:pause(); - to:write(data); - end, - }, {__index=from.listeners}); - to.listeners = setmetatable({ - ondrain = function () - from:resume(); - end, - }, {__index=to.listeners}); +local function link(from, to, read_size) + from:debug("Linking to %s", to.id); + function from:onincoming(data) + self:pause(); + to:write(data); + end + function to:ondrain() -- luacheck: ignore 212/self + from:resume(); + end + from:set_mode(read_size); from:set(true, nil); to:set(nil, true); end @@ -778,11 +1025,21 @@ return { addserver = addserver; addclient = addclient; add_task = addtimer; - at = at; + timer = { + -- API-compatible with util.timer + add_task = addtimer; + stop = closetimer; + reschedule = reschedule; + to_absolute_time = function (t) + return t-monotonic()+realtime(); + end; + }; + listen = listen; loop = loop; closeall = closeall; setquitting = setquitting; wrapclient = wrapclient; + wrapserver = wrapserver; watchfd = watchfd; link = link; set_config = function (newconfig) @@ -792,6 +1049,7 @@ return { -- libevent emulation event = { EV_READ = "r", EV_WRITE = "w", EV_READWRITE = "rw", EV_LEAVE = -1 }; addevent = function (fd, mode, callback) + log("warn", "Using deprecated libevent emulation, please update code to use watchfd API instead"); local function onevent(self) local ret = self:callback(); if ret == -1 then @@ -811,6 +1069,8 @@ return { fds[fd] = nil; end; }, interface_mt); + conn.id = conn:getfd(); + conn.log = logger.init(("fdwatch%d"):format(conn.id)); local ok, err = conn:add(mode == "r" or mode == "rw", mode == "w" or mode == "rw"); if not ok then return ok, err; end return conn; diff --git a/net/server_event.lua b/net/server_event.lua index 746526ce..139c7e5f 100644 --- a/net/server_event.lua +++ b/net/server_event.lua @@ -165,8 +165,12 @@ function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed return false end - if self.conn.sni and self.servername then - self.conn:sni(self.servername); + if self.conn.sni then + if self.servername then + self.conn:sni(self.servername); + elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then + self.conn:sni(self._server.hosts, true); + end end self.conn:settimeout( 0 ) -- set non blocking @@ -258,6 +262,7 @@ end --TODO: Deprecate function interface_mt:lock_read(switch) + log("warn", ":lock_read is deprecated, use :pasue() and :resume()"); if switch then return self:pause(); else @@ -277,6 +282,19 @@ function interface_mt:resume() end end +function interface_mt:pause_writes() + return self:_lock(self.nointerface, self.noreading, true); +end + +function interface_mt:resume_writes() + self:_lock(self.nointerface, self.noreading, false); + if self.writecallback and not self.eventwrite then + self.eventwrite = addevent( base, self.conn, EV_WRITE, self.writecallback, cfg.WRITE_TIMEOUT ); -- register callback + return true; + end +end + + function interface_mt:counter(c) if c then self._connections = self._connections + c @@ -286,7 +304,7 @@ end -- Public methods function interface_mt:write(data) - if self.nowriting then return nil, "locked" end + if self.nointerface then return nil, "locked"; end --vdebug( "try to send data to client, id/data:", self.id, data ) data = tostring( data ) local len = #data @@ -298,7 +316,7 @@ function interface_mt:write(data) end t_insert(self.writebuffer, data) -- new buffer self.writebufferlen = total - if not self.eventwrite then -- register new write event + if not self.eventwrite and not self.nowriting then -- register new write event --vdebug( "register new write event" ) self.eventwrite = addevent( base, self.conn, EV_WRITE, self.writecallback, cfg.WRITE_TIMEOUT ) end @@ -431,6 +449,7 @@ function interface_mt:setlistener(listener, data) self.onstatus = listener.onstatus; self.ondetach = listener.ondetach; self.onattach = listener.onattach; + self.onpredrain = listener.onpredrain; self.ondrain = listener.ondrain; self:onattach(data); end @@ -445,10 +464,8 @@ end function interface_mt:ontimeout() end function interface_mt:onreadtimeout() - self.fatalerror = "timeout during receiving" - debug( "connection failed:", self.fatalerror ) - self:_close() - self.eventread = nil +end +function interface_mt:onpredrain() end function interface_mt:ondrain() end @@ -476,6 +493,7 @@ local function handleclient( client, ip, port, server, pattern, listener, sslctx onincoming = listener.onincoming; -- will be called when client sends data ontimeout = listener.ontimeout; -- called when fatal socket timeout occurs onreadtimeout = listener.onreadtimeout; -- called when socket inactivity timeout occurs + onpredrain = listener.onpredrain; -- called before writes ondrain = listener.ondrain; -- called when writebuffer is empty ondetach = listener.ondetach; -- called when disassociating this listener from this connection onstatus = listener.onstatus; -- called for status changes (e.g. of SSL/TLS) @@ -526,6 +544,7 @@ local function handleclient( client, ip, port, server, pattern, listener, sslctx interface.eventwritetimeout = false end end + interface:onpredrain(); interface.writebuffer = { t_concat(interface.writebuffer) } local succ, err, byte = interface.conn:send( interface.writebuffer[1], 1, interface.writebufferlen ) --vdebug( "write data:", interface.writebuffer, "error:", err, "part:", byte ) @@ -642,7 +661,7 @@ local function handleclient( client, ip, port, server, pattern, listener, sslctx return interface end -local function handleserver( server, addr, port, pattern, listener, sslctx ) -- creates an server interface +local function handleserver( server, addr, port, pattern, listener, sslctx, startssl ) -- creates a server interface debug "creating server interface..." local interface = { _connections = 0; @@ -658,6 +677,7 @@ local function handleserver( server, addr, port, pattern, listener, sslctx ) -- _ip = addr, _port = port, _pattern = pattern, _sslctx = sslctx; + hosts = {}; } interface.id = tostring(interface):match("%x+$"); interface.readcallback = function( event ) -- server handler, called on incoming connections @@ -677,6 +697,7 @@ local function handleserver( server, addr, port, pattern, listener, sslctx ) -- end end --vdebug("max connection check ok, accepting...") + -- luacheck: ignore 231/err local client, err = server:accept() -- try to accept; TODO: check err while client do if interface._connections >= cfg.MAX_CONNECTIONS then @@ -688,7 +709,7 @@ local function handleserver( server, addr, port, pattern, listener, sslctx ) -- interface._connections = interface._connections + 1 -- increase connection count local clientinterface = handleclient( client, client_ip, client_port, interface, pattern, listener, sslctx ) --vdebug( "client id:", clientinterface, "startssl:", startssl ) - if has_luasec and sslctx then + if has_luasec and startssl then clientinterface:starttls(sslctx, true) else clientinterface:_start_session( true ) @@ -707,9 +728,9 @@ local function handleserver( server, addr, port, pattern, listener, sslctx ) -- return interface end -local function addserver( addr, port, listener, pattern, sslctx, startssl ) -- TODO: check arguments - --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslctx or "nil", startssl or "nil") - if sslctx and not has_luasec then +local function listen(addr, port, listener, config) + config = config or {} + if config.sslctx and not has_luasec then debug "fatal error: luasec not found" return nil, "luasec not found" end @@ -718,11 +739,20 @@ local function addserver( addr, port, listener, pattern, sslctx, startssl ) -- debug( "creating server socket on "..addr.." port "..port.." failed:", err ) return nil, err end - local interface = handleserver( server, addr, port, pattern, listener, sslctx, startssl ) -- new server handler + local interface = handleserver( server, addr, port, config.read_size, listener, config.tls_ctx, config.tls_direct) -- new server handler debug( "new server created with id:", tostring(interface)) return interface end +local function addserver( addr, port, listener, pattern, sslctx ) -- TODO: check arguments + --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslctx or "nil", startssl or "nil") + return listen( addr, port, listener, { + read_size = pattern, + tls_ctx = sslctx, + tls_direct = not not sslctx, + }); +end + local function wrapclient( client, ip, port, listeners, pattern, sslctx, extra ) local interface = handleclient( client, ip, port, nil, pattern, listeners, sslctx, extra ) interface:_start_connection(sslctx) @@ -756,6 +786,7 @@ local function addclient( addr, serverport, listener, pattern, sslctx, typ, extr client:settimeout( 0 ) -- set nonblocking local res, err = client:setpeername( addr, serverport ) -- connect if res or ( err == "timeout" ) then + -- luacheck: ignore 211/port local ip, port = client:getsockname( ) local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx, extra ) debug( "new connection id:", interface.id ) @@ -883,6 +914,7 @@ return { event_base = base, addevent = newevent, addserver = addserver, + listen = listen, addclient = addclient, wrapclient = wrapclient, setquitting = setquitting, diff --git a/net/server_select.lua b/net/server_select.lua index deb8fe48..eea850ce 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -68,6 +68,7 @@ local idfalse local closeall local addsocket local addserver +local listen local addtimer local getserver local wrapserver @@ -123,7 +124,7 @@ local _maxsslhandshake _server = { } -- key = port, value = table; list of listening servers _readlist = { } -- array with sockets to read from -_sendlist = { } -- arrary with sockets to write to +_sendlist = { } -- array with sockets to write to _timerlist = { } -- array of timer functions _socketlist = { } -- key = socket, value = wrapped socket (handlers) _readtimes = { } -- key = handler, value = timestamp of last data reading @@ -149,7 +150,7 @@ _checkinterval = 30 -- interval in secs to check idle clients _sendtimeout = 60000 -- allowed send idle time in secs _readtimeout = 14 * 60 -- allowed read idle time in secs -local is_windows = package.config:sub(1,1) == "\\" -- check the directory separator, to detemine whether this is Windows +local is_windows = package.config:sub(1,1) == "\\" -- check the directory separator, to determine whether this is Windows _maxfd = (is_windows and math.huge) or luasocket._SETSIZE or 1024 -- max fd number, limit to 1024 by default to prevent glibc buffer overflow, but not on Windows _maxselectlen = luasocket._SETSIZE or 1024 -- But this still applies on Windows @@ -157,7 +158,7 @@ _maxsslhandshake = 30 -- max handshake round-trips ----------------------------------// PRIVATE //-- -wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- this function wraps a server -- FIXME Make sure FD < _maxfd +wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, ssldirect ) -- this function wraps a server -- FIXME Make sure FD < _maxfd if socket:getfd() >= _maxfd then out_error("server.lua: Disallowed FD number: "..socket:getfd()) @@ -183,6 +184,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- t handler.sslctx = function( ) return sslctx end + handler.hosts = {} -- sni handler.remove = function( ) connections = connections - 1 if handler then @@ -244,13 +246,13 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- t local client, err = accept( socket ) -- try to accept if client then local ip, clientport = client:getpeername( ) - local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx ) -- wrap new client socket + local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx, ssldirect ) -- wrap new client socket if err then -- error while wrapping ssl socket return false end connections = connections + 1 out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport)) - if dispatch and not sslctx then -- SSL connections will notify onconnect when handshake completes + if dispatch and not ssldirect then -- SSL connections will notify onconnect when handshake completes return dispatch( handler ); end return; @@ -264,7 +266,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- t return handler end -wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx, extra ) -- this function wraps a client to a handler object +wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx, ssldirect, extra ) -- this function wraps a client to a handler object if socket:getfd() >= _maxfd then out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent @@ -287,9 +289,12 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport local ssl + local pending + local dispatch = listeners.onincoming local status = listeners.onstatus local disconnect = listeners.ondisconnect + local predrain = listeners.onpredrain local drain = listeners.ondrain local onreadtimeout = listeners.onreadtimeout; local detach = listeners.ondetach @@ -334,6 +339,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport dispatch = listeners.onincoming disconnect = listeners.ondisconnect status = listeners.onstatus + predrain = listeners.onpredrain drain = listeners.ondrain handler.onreadtimeout = listeners.onreadtimeout detach = listeners.ondetach @@ -341,6 +347,9 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport listeners.onattach(self, data) end end + handler._setpending = function( ) + pending = true + end handler.getstats = function( ) return readtraffic, sendtraffic end @@ -377,7 +386,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport _readlistlen = removesocket( _readlist, socket, _readlistlen ) _readtimes[ handler ] = nil if bufferqueuelen ~= 0 then - handler.sendbuffer() -- Try now to send any outstanding data + handler:sendbuffer() -- Try now to send any outstanding data if bufferqueuelen ~= 0 then -- Still not empty, so we'll try again later if handler then handler.write = nil -- ... but no further writing allowed @@ -429,9 +438,8 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport bufferlen = bufferlen + #data if bufferlen > maxsendlen then _closelist[ handler ] = "send buffer exceeded" -- cannot close the client at the moment, have to wait to the end of the cycle - handler.write = idfalse -- don't write anymore return false - elseif socket and not _sendlist[ socket ] then + elseif not nosend and socket and not _sendlist[ socket ] then _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) end bufferqueuelen = bufferqueuelen + 1 @@ -461,49 +469,55 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport maxreadlen = readlen or maxreadlen return bufferlen, maxreadlen, maxsendlen end - --TODO: Deprecate handler.lock_read = function (self, switch) + out_error( "server.lua, lock_read() is deprecated, use pause() and resume()" ) if switch == true then - local tmp = _readlistlen - _readlistlen = removesocket( _readlist, socket, _readlistlen ) - _readtimes[ handler ] = nil - if _readlistlen ~= tmp then - noread = true - end + return self:pause() elseif switch == false then - if noread then - noread = false - _readlistlen = addsocket(_readlist, socket, _readlistlen) - _readtimes[ handler ] = _currenttime - end + return self:resume() end return noread end handler.pause = function (self) - return self:lock_read(true); + local tmp = _readlistlen + _readlistlen = removesocket( _readlist, socket, _readlistlen ) + _readtimes[ handler ] = nil + if _readlistlen ~= tmp then + noread = true + end + return noread; end handler.resume = function (self) - return self:lock_read(false); + if noread then + noread = false + _readlistlen = addsocket(_readlist, socket, _readlistlen) + _readtimes[ handler ] = _currenttime + end + return noread; end handler.lock = function( self, switch ) - handler.lock_read (switch) + out_error( "server.lua, lock() is deprecated" ) + handler.lock_read (self, switch) if switch == true then - handler.write = idfalse - local tmp = _sendlistlen - _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) - _writetimes[ handler ] = nil - if _sendlistlen ~= tmp then - nosend = true - end + handler.pause_writes (self) elseif switch == false then - handler.write = write - if nosend then - nosend = false - write( "" ) - end + handler.resume_writes (self) end return noread, nosend end + handler.pause_writes = function (self) + local tmp = _sendlistlen + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) + _writetimes[ handler ] = nil + nosend = true + end + handler.resume_writes = function (self) + nosend = false + if bufferlen > 0 and socket then + _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) + end + end + local _readbuffer = function( ) -- this function reads data local buffer, err, part = receive( socket, pattern ) -- receive buffer with "pattern" if not err or (err == "wantread" or err == "timeout") then -- received something @@ -518,6 +532,12 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport _readtraffic = _readtraffic + count _readtimes[ handler ] = _currenttime --out_put( "server.lua: read data '", buffer:gsub("[^%w%p ]", "."), "', error: ", err ) + if pending then -- connection established + pending = nil + if listeners.onconnect then + listeners.onconnect(handler) + end + end return dispatch( handler, buffer, err ) else -- connections was closed or fatal error out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) ) @@ -528,6 +548,15 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport local _sendbuffer = function( ) -- this function sends data local succ, err, byte, buffer, count; if socket then + if pending then + pending = nil + if listeners.onconnect then + listeners.onconnect(handler); + end + end + if predrain then + predrain(handler); + end buffer = table_concat( bufferqueue, "", 1, bufferqueuelen ) succ, err, byte = send( socket, buffer, 1, bufferlen ) count = ( succ or byte or 0 ) * STAT_UNIT @@ -604,7 +633,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport coroutine_yield( ) -- handshake not finished end end - err = "ssl handshake error: " .. ( err or "handshake too long" ); + err = ( err or "handshake too long" ); out_put( "server.lua: ", err ); _ = handler and handler:force_close(err) return false, err -- handshake failed @@ -624,13 +653,18 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport out_put( "server.lua: attempting to start tls on " .. tostring( socket ) ) local oldsocket, err = socket socket, err = ssl_wrap( socket, sslctx ) -- wrap socket + if not socket then out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") ) return nil, err -- fatal error end - if socket.sni and self.servername then - socket:sni(self.servername); + if socket.sni then + if self.servername then + socket:sni(self.servername); + elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then + socket:sni(self.server().hosts, true); + end end socket:settimeout( 0 ) @@ -668,7 +702,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport _socketlist[ socket ] = handler _readlistlen = addsocket(_readlist, socket, _readlistlen) - if sslctx and has_luasec then + if sslctx and ssldirect and has_luasec then out_put "server.lua: auto-starting ssl negotiation..." handler.autostart_ssl = true; local ok, err = handler:starttls(sslctx); @@ -723,7 +757,7 @@ local function link(sender, receiver, buffersize) local sender_locked; local _sendbuffer = receiver.sendbuffer; function receiver.sendbuffer() - _sendbuffer(); + _sendbuffer(receiver); if sender_locked and receiver.bufferlen() < buffersize then sender:lock_read(false); -- Unlock now sender_locked = nil; @@ -743,9 +777,13 @@ end ----------------------------------// PUBLIC //-- -addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server +listen = function ( addr, port, listeners, config ) addr = addr or "*" + config = config or {} local err + local sslctx = config.tls_ctx; + local ssldirect = config.tls_direct; + local pattern = config.read_size; if type( listeners ) ~= "table" then err = "invalid listener table" elseif type ( addr ) ~= "string" then @@ -766,7 +804,7 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function out_error( "server.lua, [", addr, "]:", port, ": ", err ) return nil, err end - local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx ) -- wrap new server socket + local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, ssldirect ) -- wrap new server socket if not handler then server:close( ) return nil, err @@ -779,6 +817,14 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function return handler end +addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server + return listen(addr, port, listeners, { + read_size = pattern; + tls_ctx = sslctx; + tls_direct = sslctx and true or false; + }); +end + getserver = function ( addr, port ) return _server[ addr..":"..port ]; end @@ -921,7 +967,7 @@ loop = function(once) -- this is the main loop of the program for _, socket in ipairs( read ) do -- receive data local handler = _socketlist[ socket ] if handler then - handler.readbuffer( ) + handler:readbuffer( ) else closesocket( socket ) out_put "server.lua: found no handler and closed socket (readlist)" -- this can happen @@ -930,7 +976,7 @@ loop = function(once) -- this is the main loop of the program for _, socket in ipairs( write ) do -- send data waiting in writequeues local handler = _socketlist[ socket ] if handler then - handler.sendbuffer( ) + handler:sendbuffer( ) else closesocket( socket ) out_put "server.lua: found no handler and closed socket (writelist)" -- this should not happen @@ -987,21 +1033,13 @@ end --// EXPERIMENTAL //-- local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx, extra ) - local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx, extra) + local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx, sslctx, extra) if not handler then return nil, err end _socketlist[ socket ] = handler if not sslctx then + handler._setpending() _readlistlen = addsocket(_readlist, socket, _readlistlen) _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) - if listeners.onconnect then - -- When socket is writeable, call onconnect - local _sendbuffer = handler.sendbuffer; - handler.sendbuffer = function () - handler.sendbuffer = _sendbuffer; - listeners.onconnect(handler); - return _sendbuffer(); -- Send any queued outgoing data - end - end end return handler, socket end @@ -1123,6 +1161,7 @@ return { stats = stats, closeall = closeall, addserver = addserver, + listen = listen, getserver = getserver, setlogger = setlogger, getsettings = getsettings, diff --git a/net/unbound.lua b/net/unbound.lua new file mode 100644 index 00000000..22b0614b --- /dev/null +++ b/net/unbound.lua @@ -0,0 +1,220 @@ +-- libunbound based net.adns replacement for Prosody IM +-- Copyright (C) 2013-2015 Kim Alvefur +-- +-- This file is MIT licensed. +-- +-- luacheck: ignore prosody + +local setmetatable = setmetatable; +local tostring = tostring; +local t_concat = table.concat; +local s_format = string.format; +local s_lower = string.lower; +local s_upper = string.upper; +local noop = function() end; + +local logger = require "util.logger"; +local log = logger.init("unbound"); +local net_server = require "net.server"; +local libunbound = require"lunbound"; +local promise = require"util.promise"; +local new_id = require "util.id".medium; + +local gettime = require"socket".gettime; +local dns_utils = require"util.dns"; +local classes, types, errors = dns_utils.classes, dns_utils.types, dns_utils.errors; +local parsers = dns_utils.parsers; + +local function add_defaults(conf) + if conf then + for option, default in pairs(libunbound.config) do + if conf[option] == nil then + conf[option] = default; + end + end + end + return conf; +end + +local unbound_config; +if prosody then + local config = require"core.configmanager"; + unbound_config = add_defaults(config.get("*", "unbound")); + prosody.events.add_handler("config-reloaded", function() + unbound_config = add_defaults(config.get("*", "unbound")); + end); +end +-- Note: libunbound will default to using root hints if resolvconf is unset + +local function connect_server(unbound, server) + log("debug", "Setting up net.server event handling for %s", unbound); + return server.watchfd(unbound, function () + log("debug", "Processing queries for %s", unbound); + unbound:process() + end); +end + +local unbound, server_conn; + +local function initialize() + unbound = libunbound.new(unbound_config); + server_conn = connect_server(unbound, net_server); +end +if prosody then + prosody.events.add_handler("server-started", initialize); +end + +local answer_mt = { + __tostring = function(self) + if self._string then return self._string end + local h = s_format("Status: %s", errors[self.status]); + if self.secure then + h = h .. ", Secure"; + elseif self.bogus then + h = h .. s_format(", Bogus: %s", self.bogus); + end + local t = { h }; + for i = 1, #self do + t[i+1]=self.qname.."\t"..classes[self.qclass].."\t"..types[self.qtype].."\t"..tostring(self[i]); + end + local _string = t_concat(t, "\n"); + self._string = _string; + return _string; + end; +}; + +local waiting_queries = {}; + +local function prep_answer(a) + if not a then return end + local status = errors[a.rcode]; + local qclass = classes[a.qclass]; + local qtype = types[a.qtype]; + a.status, a.class, a.type = status, qclass, qtype; + + local t = s_lower(qtype); + local rr_mt = { __index = a, __tostring = function(self) return tostring(self[t]) end }; + local parser = parsers[qtype]; + for i = 1, #a do + if a.bogus then + -- Discard bogus data + a[i] = nil; + else + a[i] = setmetatable({[t] = parser(a[i])}, rr_mt); + end + end + return setmetatable(a, answer_mt); +end + +local function lookup(callback, qname, qtype, qclass) + if not unbound then initialize(); end + qtype = qtype and s_upper(qtype) or "A"; + qclass = qclass and s_upper(qclass) or "IN"; + local ntype, nclass = types[qtype], classes[qclass]; + local startedat = gettime(); + local ret; + local log_query = logger.init("unbound.query"..new_id()); + local function callback_wrapper(a, err) + local gotdataat = gettime(); + waiting_queries[ret] = nil; + if a then + prep_answer(a); + log_query("debug", "Results for %s %s %s: %s (%s, %f sec)", qname, qclass, qtype, a.rcode == 0 and (#a .. " items") or a.status, + a.secure and "Secure" or a.bogus or "Insecure", gotdataat - startedat); -- Insecure as in unsigned + else + log_query("error", "Results for %s %s %s: %s", qname, qclass, qtype, tostring(err)); + end + local ok, cerr = pcall(callback, a, err); + if not ok then log_query("error", "Error in callback: %s", cerr); end + end + log_query("debug", "Resolve %s %s %s", qname, qclass, qtype); + local err; + ret, err = unbound:resolve_async(callback_wrapper, qname, ntype, nclass); + if ret then + waiting_queries[ret] = callback; + else + log_query("warn", "Resolver error: %s", err); + end + return ret, err; +end + +local function lookup_sync(qname, qtype, qclass) + if not unbound then initialize(); end + qtype = qtype and s_upper(qtype) or "A"; + qclass = qclass and s_upper(qclass) or "IN"; + local ntype, nclass = types[qtype], classes[qclass]; + local a, err = unbound:resolve(qname, ntype, nclass); + if not a then return a, err; end + return prep_answer(a); +end + +local function cancel(id) + local cb = waiting_queries[id]; + unbound:cancel(id); + if cb then + cb(nil, "canceled"); + waiting_queries[id] = nil; + end + return true; +end + +-- Reinitiate libunbound context, drops cache +local function purge() + for id in pairs(waiting_queries) do cancel(id); end + if server_conn then server_conn:close(); end + initialize(); + return true; +end + +local function not_implemented() + error "not implemented"; +end +-- Public API +local _M = { + lookup = lookup; + cancel = cancel; + new_async_socket = not_implemented; + dns = { + lookup = lookup_sync; + cancel = cancel; + cache = noop; + socket_wrapper_set = noop; + settimeout = noop; + query = noop; + purge = purge; + random = noop; + peek = noop; + + types = types; + classes = classes; + }; +}; + +local function lookup_promise(_, qname, qtype, qclass) + return promise.new(function (resolve, reject) + local function callback(answer, err) + if err then + return reject(err); + else + return resolve(answer); + end + end + local ret, err = lookup(callback, qname, qtype, qclass) + if not ret then reject(err); end + end); +end + +local wrapper = { + lookup = function (_, callback, qname, qtype, qclass) + return lookup(callback, qname, qtype, qclass) + end; + lookup_promise = lookup_promise; + _resolver = { + settimeout = function () end; + closeall = function () end; + }; +} + +function _M.resolver() return wrapper; end + +return _M; diff --git a/net/websocket.lua b/net/websocket.lua index 469c6a58..193cd556 100644 --- a/net/websocket.lua +++ b/net/websocket.lua @@ -23,6 +23,7 @@ local websockets = {}; local websocket_listeners = {}; function websocket_listeners.ondisconnect(conn, err) local s = websockets[conn]; + if not s then return; end websockets[conn] = nil; if s.close_timer then timer.stop(s.close_timer); @@ -113,7 +114,7 @@ function websocket_listeners.onincoming(conn, buffer, err) -- luacheck: ignore 2 frame.MASK = true; -- RFC 6455 6.1.5: If the data is being sent by the client, the frame(s) MUST be masked conn:write(frames.build(frame)); elseif frame.opcode == 0xA then -- Pong frame - log("debug", "Received unexpected pong frame: " .. tostring(frame.data)); + log("debug", "Received unexpected pong frame: %s", frame.data); else return fail(s, 1002, "Reserved opcode"); end @@ -131,7 +132,7 @@ end function websocket_methods:close(code, reason) if self.readyState < 2 then code = code or 1000; - log("debug", "closing WebSocket with code %i: %s" , code , tostring(reason)); + log("debug", "closing WebSocket with code %i: %s" , code , reason); self.readyState = 2; local conn = self.conn; conn:write(frames.build_close(code, reason, true)); @@ -245,7 +246,7 @@ local function connect(url, ex, listeners) or (protocol and not protocol[r.headers["sec-websocket-protocol"]]) then s.readyState = 3; - log("warn", "WebSocket connection to %s failed: %s", url, tostring(b)); + log("warn", "WebSocket connection to %s failed: %s", url, b); if s.onerror then s:onerror("connecting-failed"); end return; end diff --git a/net/websocket/frames.lua b/net/websocket/frames.lua index 1d0ac06f..03ce21a8 100644 --- a/net/websocket/frames.lua +++ b/net/websocket/frames.lua @@ -9,8 +9,7 @@ local softreq = require "util.dependencies".softreq; local random_bytes = require "util.random".bytes; -local bit = assert(softreq"bit" or softreq"bit32", - "No bit module found. See https://prosody.im/doc/depends#bitop"); +local bit = require "util.bitcompat"; local band = bit.band; local bor = bit.bor; local lshift = bit.lshift; @@ -19,8 +18,8 @@ local sbit = require "util.strbitop"; local sxor = sbit.sxor; local s_char= string.char; -local s_pack = string.pack; -- luacheck: ignore 143 -local s_unpack = string.unpack; -- luacheck: ignore 143 +local s_pack = string.pack; +local s_unpack = string.unpack; if not s_pack and softreq"struct" then s_pack = softreq"struct".pack; |