diff options
Diffstat (limited to 'net')
-rw-r--r-- | net/adns.lua | 8 | ||||
-rw-r--r-- | net/connect.lua | 52 | ||||
-rw-r--r-- | net/cqueues.lua | 4 | ||||
-rw-r--r-- | net/dns.lua | 14 | ||||
-rw-r--r-- | net/http.lua | 22 | ||||
-rw-r--r-- | net/http/codes.lua | 92 | ||||
-rw-r--r-- | net/http/errors.lua | 4 | ||||
-rw-r--r-- | net/http/files.lua | 8 | ||||
-rw-r--r-- | net/http/parser.lua | 4 | ||||
-rw-r--r-- | net/http/server.lua | 42 | ||||
-rw-r--r-- | net/httpserver.lua | 2 | ||||
-rw-r--r-- | net/resolvers/basic.lua | 172 | ||||
-rw-r--r-- | net/resolvers/manual.lua | 2 | ||||
-rw-r--r-- | net/resolvers/service.lua | 101 | ||||
-rw-r--r-- | net/server.lua | 44 | ||||
-rw-r--r-- | net/server_epoll.lua | 132 | ||||
-rw-r--r-- | net/server_event.lua | 46 | ||||
-rw-r--r-- | net/server_select.lua | 33 | ||||
-rw-r--r-- | net/stun.lua | 16 | ||||
-rw-r--r-- | net/tls_luasec.lua | 114 | ||||
-rw-r--r-- | net/unbound.lua | 12 | ||||
-rw-r--r-- | net/websocket.lua | 14 | ||||
-rw-r--r-- | net/websocket/frames.lua | 11 |
23 files changed, 654 insertions, 295 deletions
diff --git a/net/adns.lua b/net/adns.lua index ae168b9c..2ecd7f53 100644 --- a/net/adns.lua +++ b/net/adns.lua @@ -6,11 +6,11 @@ -- COPYING file in the source package for more information. -- -local server = require "net.server"; -local new_resolver = require "net.dns".resolver; -local promise = require "util.promise"; +local server = require "prosody.net.server"; +local new_resolver = require "prosody.net.dns".resolver; +local promise = require "prosody.util.promise"; -local log = require "util.logger".init("adns"); +local log = require "prosody.util.logger".init("adns"); log("debug", "Using legacy DNS API (missing lua-unbound?)"); -- TODO write docs about luaunbound -- TODO Raise log level once packages are available diff --git a/net/connect.lua b/net/connect.lua index 4b602be4..0e39cf36 100644 --- a/net/connect.lua +++ b/net/connect.lua @@ -1,8 +1,8 @@ -local server = require "net.server"; -local log = require "util.logger".init("net.connect"); -local new_id = require "util.id".short; +local server = require "prosody.net.server"; +local log = require "prosody.util.logger".init("net.connect"); +local new_id = require "prosody.util.id".short; +local timer = require "prosody.util.timer"; --- TODO #1246 Happy Eyeballs -- FIXME RFC 6724 -- FIXME Error propagation from resolvers doesn't work -- FIXME #1428 Reuse DNS resolver object between service and basic resolver @@ -28,16 +28,17 @@ local pending_connection_listeners = {}; local function attempt_connection(p) p:log("debug", "Checking for targets..."); - if p.conn then - pending_connections_map[p.conn] = nil; - p.conn = nil; - end - p.target_resolver:next(function (conn_type, ip, port, extra) + p.target_resolver:next(function (conn_type, ip, port, extra, more_targets_available) if not conn_type then -- No more targets to try p:log("debug", "No more connection targets to try", p.target_resolver.last_error); - if p.listeners.onfail then - p.listeners.onfail(p.data, p.last_error or p.target_resolver.last_error or "unable to resolve service"); + if next(p.conns) == nil then + p:log("debug", "No more targets, no pending connections. Connection failed."); + if p.listeners.onfail then + p.listeners.onfail(p.data, p.last_error or p.target_resolver.last_error or "unable to resolve service"); + end + else + p:log("debug", "One or more connection attempts are still pending. Waiting for now."); end return; end @@ -49,8 +50,16 @@ local function attempt_connection(p) p.last_error = err or "unknown reason"; return attempt_connection(p); end - p.conn = conn; + p.conns[conn] = true; pending_connections_map[conn] = p; + if more_targets_available then + timer.add_task(0.250, function () + if not p.connected then + p:log("debug", "Still not connected, making parallel connection attempt..."); + attempt_connection(p); + end + end); + end end); end @@ -62,6 +71,13 @@ function pending_connection_listeners.onconnect(conn) return; end pending_connections_map[conn] = nil; + if p.connected then + -- We already succeeded in connecting + p.conns[conn] = nil; + conn:close(); + return; + end + p.connected = true; p:log("debug", "Successfully connected"); conn:setlistener(p.listeners, p.data); return p.listeners.onconnect(conn); @@ -73,9 +89,18 @@ function pending_connection_listeners.ondisconnect(conn, reason) log("warn", "Failed connection, but unexpected!"); return; end + p.conns[conn] = nil; + pending_connections_map[conn] = nil; p.last_error = reason or "unknown reason"; p:log("debug", "Connection attempt failed: %s", p.last_error); - attempt_connection(p); + if p.connected then + p:log("debug", "Connection already established, ignoring failure"); + elseif next(p.conns) == nil then + p:log("debug", "No pending connection attempts, and not yet connected"); + attempt_connection(p); + else + p:log("debug", "Other attempts are still pending, ignoring failure"); + end end local function connect(target_resolver, listeners, options, data) @@ -85,6 +110,7 @@ local function connect(target_resolver, listeners, options, data) listeners = assert(listeners); options = options or {}; data = data; + conns = {}; }, pending_connection_mt); p:log("debug", "Starting connection process"); diff --git a/net/cqueues.lua b/net/cqueues.lua index 65d2a019..72d8759f 100644 --- a/net/cqueues.lua +++ b/net/cqueues.lua @@ -7,9 +7,9 @@ -- This module allows you to use cqueues with a net.server mainloop -- -local server = require "net.server"; +local server = require "prosody.net.server"; local cqueues = require "cqueues"; -local timer = require "util.timer"; +local timer = require "prosody.util.timer"; assert(cqueues.VERSION >= 20150113, "cqueues newer than 20150113 required") -- Create a single top level cqueue diff --git a/net/dns.lua b/net/dns.lua index a9846e86..ba432c14 100644 --- a/net/dns.lua +++ b/net/dns.lua @@ -8,18 +8,18 @@ -- todo: cache results of encodeName --- reference: http://tools.ietf.org/html/rfc1035 --- reference: http://tools.ietf.org/html/rfc1876 (LOC) +-- reference: https://www.rfc-editor.org/rfc/rfc1035.html +-- reference: https://www.rfc-editor.org/rfc/rfc1876.html (LOC) local socket = require "socket"; -local have_timer, timer = pcall(require, "util.timer"); -local new_ip = require "util.ip".new_ip; -local have_util_net, util_net = pcall(require, "util.net"); +local have_timer, timer = pcall(require, "prosody.util.timer"); +local new_ip = require "prosody.util.ip".new_ip; +local have_util_net, util_net = pcall(require, "prosody.util.net"); -local log = require "util.logger".init("dns"); +local log = require "prosody.util.logger".init("dns"); -local _, windows = pcall(require, "util.windows"); +local _, windows = pcall(require, "prosody.util.windows"); local is_windows = (_ and windows) or os.getenv("WINDIR"); local coroutine, io, math, string, table = diff --git a/net/http.lua b/net/http.lua index aa435719..07c7d41f 100644 --- a/net/http.lua +++ b/net/http.lua @@ -6,17 +6,17 @@ -- COPYING file in the source package for more information. -- -local b64 = require "util.encodings".base64.encode; +local b64 = require "prosody.util.encodings".base64.encode; local url = require "socket.url" -local httpstream_new = require "net.http.parser".new; -local util_http = require "util.http"; -local events = require "util.events"; -local verify_identity = require"util.x509".verify_identity; -local promise = require "util.promise"; -local http_errors = require "net.http.errors"; +local httpstream_new = require "prosody.net.http.parser".new; +local util_http = require "prosody.util.http"; +local events = require "prosody.util.events"; +local verify_identity = require"prosody.util.x509".verify_identity; +local promise = require "prosody.util.promise"; +local http_errors = require "prosody.net.http.errors"; -local basic_resolver = require "net.resolvers.basic"; -local connect = require "net.connect".connect; +local basic_resolver = require "prosody.net.resolvers.basic"; +local connect = require "prosody.net.connect".connect; local ssl_available = pcall(require, "ssl"); @@ -25,10 +25,10 @@ local pairs = pairs; local tonumber, tostring, traceback = tonumber, tostring, debug.traceback; local os_time = os.time; -local xpcall = require "util.xpcall".xpcall; +local xpcall = require "prosody.util.xpcall".xpcall; local error = error -local log = require "util.logger".init("http"); +local log = require "prosody.util.logger".init("http"); local _ENV = nil; -- luacheck: std none diff --git a/net/http/codes.lua b/net/http/codes.lua index 4327f151..b2949286 100644 --- a/net/http/codes.lua +++ b/net/http/codes.lua @@ -2,62 +2,62 @@ local response_codes = { -- Source: http://www.iana.org/assignments/http-status-codes - [100] = "Continue"; -- RFC7231, Section 6.2.1 - [101] = "Switching Protocols"; -- RFC7231, Section 6.2.2 + [100] = "Continue"; -- RFC9110, Section 15.2.1 + [101] = "Switching Protocols"; -- RFC9110, Section 15.2.2 [102] = "Processing"; [103] = "Early Hints"; -- [104-199] = "Unassigned"; - [200] = "OK"; -- RFC7231, Section 6.3.1 - [201] = "Created"; -- RFC7231, Section 6.3.2 - [202] = "Accepted"; -- RFC7231, Section 6.3.3 - [203] = "Non-Authoritative Information"; -- RFC7231, Section 6.3.4 - [204] = "No Content"; -- RFC7231, Section 6.3.5 - [205] = "Reset Content"; -- RFC7231, Section 6.3.6 - [206] = "Partial Content"; -- RFC7233, Section 4.1 + [200] = "OK"; -- RFC9110, Section 15.3.1 + [201] = "Created"; -- RFC9110, Section 15.3.2 + [202] = "Accepted"; -- RFC9110, Section 15.3.3 + [203] = "Non-Authoritative Information"; -- RFC9110, Section 15.3.4 + [204] = "No Content"; -- RFC9110, Section 15.3.5 + [205] = "Reset Content"; -- RFC9110, Section 15.3.6 + [206] = "Partial Content"; -- RFC9110, Section 15.3.7 [207] = "Multi-Status"; [208] = "Already Reported"; -- [209-225] = "Unassigned"; [226] = "IM Used"; -- [227-299] = "Unassigned"; - [300] = "Multiple Choices"; -- RFC7231, Section 6.4.1 - [301] = "Moved Permanently"; -- RFC7231, Section 6.4.2 - [302] = "Found"; -- RFC7231, Section 6.4.3 - [303] = "See Other"; -- RFC7231, Section 6.4.4 - [304] = "Not Modified"; -- RFC7232, Section 4.1 - [305] = "Use Proxy"; -- RFC7231, Section 6.4.5 - -- [306] = "(Unused)"; -- RFC7231, Section 6.4.6 - [307] = "Temporary Redirect"; -- RFC7231, Section 6.4.7 - [308] = "Permanent Redirect"; + [300] = "Multiple Choices"; -- RFC9110, Section 15.4.1 + [301] = "Moved Permanently"; -- RFC9110, Section 15.4.2 + [302] = "Found"; -- RFC9110, Section 15.4.3 + [303] = "See Other"; -- RFC9110, Section 15.4.4 + [304] = "Not Modified"; -- RFC9110, Section 15.4.5 + [305] = "Use Proxy"; -- RFC9110, Section 15.4.6 + -- [306] = "(Unused)"; -- RFC9110, Section 15.4.7 + [307] = "Temporary Redirect"; -- RFC9110, Section 15.4.8 + [308] = "Permanent Redirect"; -- RFC9110, Section 15.4.9 -- [309-399] = "Unassigned"; - [400] = "Bad Request"; -- RFC7231, Section 6.5.1 - [401] = "Unauthorized"; -- RFC7235, Section 3.1 - [402] = "Payment Required"; -- RFC7231, Section 6.5.2 - [403] = "Forbidden"; -- RFC7231, Section 6.5.3 - [404] = "Not Found"; -- RFC7231, Section 6.5.4 - [405] = "Method Not Allowed"; -- RFC7231, Section 6.5.5 - [406] = "Not Acceptable"; -- RFC7231, Section 6.5.6 - [407] = "Proxy Authentication Required"; -- RFC7235, Section 3.2 - [408] = "Request Timeout"; -- RFC7231, Section 6.5.7 - [409] = "Conflict"; -- RFC7231, Section 6.5.8 - [410] = "Gone"; -- RFC7231, Section 6.5.9 - [411] = "Length Required"; -- RFC7231, Section 6.5.10 - [412] = "Precondition Failed"; -- RFC7232, Section 4.2 - [413] = "Payload Too Large"; -- RFC7231, Section 6.5.11 - [414] = "URI Too Long"; -- RFC7231, Section 6.5.12 - [415] = "Unsupported Media Type"; -- RFC7231, Section 6.5.13 - [416] = "Range Not Satisfiable"; -- RFC7233, Section 4.4 - [417] = "Expectation Failed"; -- RFC7231, Section 6.5.14 + [400] = "Bad Request"; -- RFC9110, Section 15.5.1 + [401] = "Unauthorized"; -- RFC9110, Section 15.5.2 + [402] = "Payment Required"; -- RFC9110, Section 15.5.3 + [403] = "Forbidden"; -- RFC9110, Section 15.5.4 + [404] = "Not Found"; -- RFC9110, Section 15.5.5 + [405] = "Method Not Allowed"; -- RFC9110, Section 15.5.6 + [406] = "Not Acceptable"; -- RFC9110, Section 15.5.7 + [407] = "Proxy Authentication Required"; -- RFC9110, Section 15.5.8 + [408] = "Request Timeout"; -- RFC9110, Section 15.5.9 + [409] = "Conflict"; -- RFC9110, Section 15.5.10 + [410] = "Gone"; -- RFC9110, Section 15.5.11 + [411] = "Length Required"; -- RFC9110, Section 15.5.12 + [412] = "Precondition Failed"; -- RFC9110, Section 15.5.13 + [413] = "Content Too Large"; -- RFC9110, Section 15.5.14 + [414] = "URI Too Long"; -- RFC9110, Section 15.5.15 + [415] = "Unsupported Media Type"; -- RFC9110, Section 15.5.16 + [416] = "Range Not Satisfiable"; -- RFC9110, Section 15.5.17 + [417] = "Expectation Failed"; -- RFC9110, Section 15.5.18 [418] = "I'm a teapot"; -- RFC2324, Section 2.3.2 -- [419-420] = "Unassigned"; - [421] = "Misdirected Request"; -- RFC7540, Section 9.1.2 - [422] = "Unprocessable Entity"; + [421] = "Misdirected Request"; -- RFC9110, Section 15.5.20 + [422] = "Unprocessable Content"; -- RFC9110, Section 15.5.21 [423] = "Locked"; [424] = "Failed Dependency"; [425] = "Too Early"; - [426] = "Upgrade Required"; -- RFC7231, Section 6.5.15 + [426] = "Upgrade Required"; -- RFC9110, Section 15.5.22 -- [427] = "Unassigned"; [428] = "Precondition Required"; [429] = "Too Many Requests"; @@ -67,17 +67,17 @@ local response_codes = { [451] = "Unavailable For Legal Reasons"; -- [452-499] = "Unassigned"; - [500] = "Internal Server Error"; -- RFC7231, Section 6.6.1 - [501] = "Not Implemented"; -- RFC7231, Section 6.6.2 - [502] = "Bad Gateway"; -- RFC7231, Section 6.6.3 - [503] = "Service Unavailable"; -- RFC7231, Section 6.6.4 - [504] = "Gateway Timeout"; -- RFC7231, Section 6.6.5 - [505] = "HTTP Version Not Supported"; -- RFC7231, Section 6.6.6 + [500] = "Internal Server Error"; -- RFC9110, Section 15.6.1 + [501] = "Not Implemented"; -- RFC9110, Section 15.6.2 + [502] = "Bad Gateway"; -- RFC9110, Section 15.6.3 + [503] = "Service Unavailable"; -- RFC9110, Section 15.6.4 + [504] = "Gateway Timeout"; -- RFC9110, Section 15.6.5 + [505] = "HTTP Version Not Supported"; -- RFC9110, Section 15.6.6 [506] = "Variant Also Negotiates"; [507] = "Insufficient Storage"; [508] = "Loop Detected"; -- [509] = "Unassigned"; - [510] = "Not Extended"; + [510] = "Not Extended"; -- (OBSOLETED) [511] = "Network Authentication Required"; -- [512-599] = "Unassigned"; }; diff --git a/net/http/errors.lua b/net/http/errors.lua index 1691e426..ca5f0ccc 100644 --- a/net/http/errors.lua +++ b/net/http/errors.lua @@ -2,8 +2,8 @@ -- and a function to return a util.error object given callback 'code' and 'body' -- parameters. -local codes = require "net.http.codes"; -local util_error = require "util.error"; +local codes = require "prosody.net.http.codes"; +local util_error = require "prosody.util.error"; local error_templates = { -- This code is used by us to report a client-side or connection error. diff --git a/net/http/files.lua b/net/http/files.lua index 583f7514..24d9f204 100644 --- a/net/http/files.lua +++ b/net/http/files.lua @@ -6,10 +6,10 @@ -- COPYING file in the source package for more information. -- -local server = require"net.http.server"; +local server = require"prosody.net.http.server"; local lfs = require "lfs"; -local new_cache = require "util.cache".new; -local log = require "util.logger".init("net.http.files"); +local new_cache = require "prosody.util.cache".new; +local log = require "prosody.util.logger".init("net.http.files"); local os_date = os.date; local open = io.open; @@ -23,7 +23,7 @@ if package.config:sub(1,1) == "\\" then forbidden_chars_pattern = "[/%z\001-\031\127\"*:<>?|]" end -local urldecode = require "util.http".urldecode; +local urldecode = require "prosody.util.http".urldecode; local function sanitize_path(path) --> util.paths or util.http? if not path then return end local out = {}; diff --git a/net/http/parser.lua b/net/http/parser.lua index a6624662..196c3faa 100644 --- a/net/http/parser.lua +++ b/net/http/parser.lua @@ -1,8 +1,8 @@ local tonumber = tonumber; local assert = assert; local url_parse = require "socket.url".parse; -local urldecode = require "util.http".urldecode; -local dbuffer = require "util.dbuffer"; +local urldecode = require "prosody.util.http".urldecode; +local dbuffer = require "prosody.util.dbuffer"; local function preprocess_path(path) path = urldecode((path:gsub("//+", "/"))); diff --git a/net/http/server.lua b/net/http/server.lua index 43e6bc9f..3d7d52b7 100644 --- a/net/http/server.lua +++ b/net/http/server.lua @@ -1,19 +1,21 @@ local t_insert, t_concat = table.insert, table.concat; -local parser_new = require "net.http.parser".new; -local events = require "util.events".new(); -local addserver = require "net.server".addserver; -local log = require "util.logger".init("http.server"); +local parser_new = require "prosody.net.http.parser".new; +local events = require "prosody.util.events".new(); +local addserver = require "prosody.net.server".addserver; +local logger = require "prosody.util.logger"; +local log = logger.init("http.server"); local os_date = os.date; local pairs = pairs; local s_upper = string.upper; local setmetatable = setmetatable; -local cache = require "util.cache"; -local codes = require "net.http.codes"; -local promise = require "util.promise"; -local errors = require "util.error"; +local cache = require "prosody.util.cache"; +local codes = require "prosody.net.http.codes"; +local promise = require "prosody.util.promise"; +local errors = require "prosody.util.error"; local blocksize = 2^16; -local async = require "util.async"; +local async = require "prosody.util.async"; +local id = require"prosody.util.id"; local _M = {}; @@ -105,7 +107,12 @@ end function runner_callbacks:error(err) log("error", "Traceback[httpserver]: %s", err); - self.data.conn:write("HTTP/1.0 500 Internal Server Error\r\n\r\n"..events.fire_event("http-error", { code = 500, private_message = err })); + local response = { headers = { content_type = "text/plain" }; body = "" }; + response.body = events.fire_event("http-error", { code = 500; private_message = err; response = response }); + self.data.conn:write("HTTP/1.0 500 Internal Server Error\r\n\z\ + X-Content-Type-Options: nosniff\r\n\z\ + Content-Type: " .. response.headers.content_type .. "\r\n\r\n"); + self.data.conn:write(response.body); self.data.conn:close(); end @@ -128,6 +135,8 @@ function listener.onconnect(conn) end, runner_callbacks, session); local function success_cb(request) --log("debug", "success_cb: %s", request.path); + request.id = id.short(); + request.log = logger.init("http." .. request.method .. "-" .. request.id); request.ip = ip; request.secure = secure; session.thread:run(request); @@ -232,6 +241,8 @@ function handle_request(conn, request, finish_cb) request.headers = headers; request.conn = conn; + request.log("debug", "%s %s HTTP/%s", request.method, request.path, request.httpversion); + local date_header = os_date('!%a, %d %b %Y %H:%M:%S GMT'); -- FIXME use local conn_header = request.headers.connection; conn_header = conn_header and ","..conn_header:gsub("[ \t]", ""):lower().."," or "" @@ -249,10 +260,12 @@ function handle_request(conn, request, finish_cb) local is_head_request = request.method == "HEAD"; local response = { + id = request.id; + log = request.log; request = request; is_head_request = is_head_request; status_code = 200; - headers = { date = date_header, connection = response_conn_header }; + headers = { date = date_header; connection = response_conn_header; x_request_id = request.id }; persistent = persistent; conn = conn; send = _M.send_response; @@ -281,11 +294,9 @@ function handle_request(conn, request, finish_cb) local global_event = request.method.." "..request.path:match("[^?]*"); local payload = { request = request, response = response }; - log("debug", "Firing event: %s", global_event); local result = events.fire_event(global_event, payload); if result == nil and is_head_request then local global_head_event = "GET "..request.path:match("[^?]*"); - log("debug", "Firing event: %s", global_head_event); result = events.fire_event(global_head_event, payload); end if result == nil then @@ -306,12 +317,10 @@ function handle_request(conn, request, finish_cb) end local host_event = request.method.." "..host..request.path:match("[^?]*"); - log("debug", "Firing event: %s", host_event); result = events.fire_event(host_event, payload); if result == nil and is_head_request then local host_head_event = "GET "..host..request.path:match("[^?]*"); - log("debug", "Firing event: %s", host_head_event); result = events.fire_event(host_head_event, payload); end end @@ -321,6 +330,7 @@ end local function prepare_header(response) local status_line = "HTTP/"..response.request.httpversion.." "..(response.status or codes[response.status_code]); + response.log("debug", "%s", status_line); local headers = response.headers; local output = { status_line }; for k,v in pairs(headers) do @@ -378,11 +388,11 @@ function _M.send_file(response, f) response.conn:write(chunk); else incomplete[response.conn] = nil; + if f.close then f:close(); end if chunked then response.conn:write("0\r\n\r\n"); end -- io.write("\n"); - if f.close then f:close(); end return response:done(); end end diff --git a/net/httpserver.lua b/net/httpserver.lua index 6b14313b..0dfd862e 100644 --- a/net/httpserver.lua +++ b/net/httpserver.lua @@ -1,5 +1,5 @@ -- COMPAT w/pre-0.9 -local log = require "util.logger".init("net.httpserver"); +local log = require "prosody.util.logger".init("net.httpserver"); local traceback = debug.traceback; local _ENV = nil; diff --git a/net/resolvers/basic.lua b/net/resolvers/basic.lua index 305bce76..d9d33310 100644 --- a/net/resolvers/basic.lua +++ b/net/resolvers/basic.lua @@ -1,14 +1,62 @@ -local adns = require "net.adns"; -local inet_pton = require "util.net".pton; -local inet_ntop = require "util.net".ntop; -local idna_to_ascii = require "util.encodings".idna.to_ascii; -local unpack = table.unpack or unpack; -- luacheck: ignore 113 +local adns = require "prosody.net.adns"; +local inet_pton = require "prosody.util.net".pton; +local inet_ntop = require "prosody.util.net".ntop; +local idna_to_ascii = require "prosody.util.encodings".idna.to_ascii; +local promise = require "prosody.util.promise"; +local t_move = require "prosody.util.table".move; local methods = {}; local resolver_mt = { __index = methods }; -- FIXME RFC 6724 +local function do_dns_lookup(self, dns_resolver, record_type, name, allow_insecure) + return promise.new(function (resolve, reject) + local ipv = (record_type == "A" and "4") or (record_type == "AAAA" and "6") or nil; + if ipv and self.extra["use_ipv"..ipv] == false then + return reject(("IPv%s disabled - %s lookup skipped"):format(ipv, record_type)); + elseif record_type == "TLSA" and self.extra.use_dane ~= true then + return reject("DANE disabled - TLSA lookup skipped"); + end + dns_resolver:lookup(function (answer, err) + if not answer then + return reject(err); + elseif answer.bogus then + return reject(("Validation error in %s lookup"):format(record_type)); + elseif not (answer.secure or allow_insecure) then + return reject(("Insecure response in %s lookup"):format(record_type)); + elseif answer.status and #answer == 0 then + return reject(("%s in %s lookup"):format(answer.status, record_type)); + end + + local targets = { secure = answer.secure }; + for _, record in ipairs(answer) do + if ipv then + table.insert(targets, { self.conn_type..ipv, record[record_type:lower()], self.port, self.extra }); + else + table.insert(targets, record[record_type:lower()]); + end + end + return resolve(targets); + end, name, record_type, "IN"); + end); +end + +local function merge_targets(ipv4_targets, ipv6_targets) + local result = { secure = ipv4_targets.secure and ipv6_targets.secure }; + local common_length = math.min(#ipv4_targets, #ipv6_targets); + for i = 1, common_length do + table.insert(result, ipv6_targets[i]); + table.insert(result, ipv4_targets[i]); + end + if common_length < #ipv4_targets then + t_move(ipv4_targets, common_length+1, #ipv4_targets, common_length+1, result); + elseif common_length < #ipv6_targets then + t_move(ipv6_targets, common_length+1, #ipv6_targets, common_length+1, result); + end + return result; +end + -- Find the next target to connect to, and -- pass it to cb() function methods:next(cb) @@ -18,7 +66,7 @@ function methods:next(cb) return; end local next_target = table.remove(self.targets, 1); - cb(unpack(next_target, 1, 4)); + cb(next_target[1], next_target[2], next_target[3], next_target[4], not not self.targets[1]); return; end @@ -28,91 +76,47 @@ function methods:next(cb) return; end - local secure = true; - local tlsa = {}; - local targets = {}; - local n = 3; - local function ready() - n = n - 1; - if n > 0 then return; end - self.targets = targets; + -- Resolve DNS to target list + local dns_resolver = adns.resolver(); + + local dns_lookups = { + ipv4 = do_dns_lookup(self, dns_resolver, "A", self.hostname, true); + ipv6 = do_dns_lookup(self, dns_resolver, "AAAA", self.hostname, true); + tlsa = do_dns_lookup(self, dns_resolver, "TLSA", ("_%d._%s.%s"):format(self.port, self.conn_type, self.hostname)); + }; + + promise.all_settled(dns_lookups):next(function (dns_results) + -- Combine targets, assign to self.targets, self:next(cb) + local have_ipv4 = dns_results.ipv4.status == "fulfilled"; + local have_ipv6 = dns_results.ipv6.status == "fulfilled"; + + if have_ipv4 and have_ipv6 then + self.targets = merge_targets(dns_results.ipv4.value, dns_results.ipv6.value); + elseif have_ipv4 then + self.targets = dns_results.ipv4.value; + elseif have_ipv6 then + self.targets = dns_results.ipv6.value; + else + self.targets = {}; + end + if self.extra and self.extra.use_dane then - if secure and tlsa[1] then - self.extra.tlsa = tlsa; + if self.targets.secure and dns_results.tlsa.status == "fulfilled" then + self.extra.tlsa = dns_results.tlsa.value; self.extra.dane_hostname = self.hostname; else self.extra.tlsa = nil; self.extra.dane_hostname = nil; end + elseif self.extra and self.extra.srv_secure then + self.extra.secure_hostname = self.hostname; end - self:next(cb); - end - -- Resolve DNS to target list - local dns_resolver = adns.resolver(); - - if not self.extra or self.extra.use_ipv4 ~= false then - dns_resolver:lookup(function (answer, err) - if answer then - secure = secure and answer.secure; - for _, record in ipairs(answer) do - table.insert(targets, { self.conn_type.."4", record.a, self.port, self.extra }); - end - if answer.bogus then - self.last_error = "Validation error in A lookup"; - elseif answer.status then - self.last_error = answer.status .. " in A lookup"; - end - else - self.last_error = err; - end - ready(); - end, self.hostname, "A", "IN"); - else - ready(); - end - - if not self.extra or self.extra.use_ipv6 ~= false then - dns_resolver:lookup(function (answer, err) - if answer then - secure = secure and answer.secure; - for _, record in ipairs(answer) do - table.insert(targets, { self.conn_type.."6", record.aaaa, self.port, self.extra }); - end - if answer.bogus then - self.last_error = "Validation error in AAAA lookup"; - elseif answer.status then - self.last_error = answer.status .. " in AAAA lookup"; - end - else - self.last_error = err; - end - ready(); - end, self.hostname, "AAAA", "IN"); - else - ready(); - end - - if self.extra and self.extra.use_dane == true then - dns_resolver:lookup(function (answer, err) - if answer then - secure = secure and answer.secure; - for _, record in ipairs(answer) do - table.insert(tlsa, record.tlsa); - end - if answer.bogus then - self.last_error = "Validation error in TLSA lookup"; - elseif answer.status then - self.last_error = answer.status .. " in TLSA lookup"; - end - else - self.last_error = err; - end - ready(); - end, ("_%d._tcp.%s"):format(self.port, self.hostname), "TLSA", "IN"); - else - ready(); - end + self:next(cb); + end):catch(function (err) + self.last_error = err; + self.targets = {}; + end); end local function new(hostname, port, conn_type, extra) @@ -137,7 +141,7 @@ local function new(hostname, port, conn_type, extra) hostname = ascii_host; port = port; conn_type = conn_type; - extra = extra; + extra = extra or {}; targets = targets; }, resolver_mt); end diff --git a/net/resolvers/manual.lua b/net/resolvers/manual.lua index dbc40256..c766a11f 100644 --- a/net/resolvers/manual.lua +++ b/net/resolvers/manual.lua @@ -1,6 +1,6 @@ local methods = {}; local resolver_mt = { __index = methods }; -local unpack = table.unpack or unpack; -- luacheck: ignore 113 +local unpack = table.unpack; -- Find the next target to connect to, and -- pass it to cb() diff --git a/net/resolvers/service.lua b/net/resolvers/service.lua index 3810cac8..d4318bb5 100644 --- a/net/resolvers/service.lua +++ b/net/resolvers/service.lua @@ -1,24 +1,79 @@ -local adns = require "net.adns"; -local basic = require "net.resolvers.basic"; -local inet_pton = require "util.net".pton; -local idna_to_ascii = require "util.encodings".idna.to_ascii; -local unpack = table.unpack or unpack; -- luacheck: ignore 113 +local adns = require "prosody.net.adns"; +local basic = require "prosody.net.resolvers.basic"; +local inet_pton = require "prosody.util.net".pton; +local idna_to_ascii = require "prosody.util.encodings".idna.to_ascii; local methods = {}; local resolver_mt = { __index = methods }; +local function new_target_selector(rrset) + local rr_count = rrset and #rrset; + if not rr_count or rr_count == 0 then + rrset = nil; + else + table.sort(rrset, function (a, b) return a.srv.priority < b.srv.priority end); + end + local rrset_pos = 1; + local priority_bucket, bucket_total_weight, bucket_len, bucket_used; + return function () + if not rrset then return; end + + if not priority_bucket or bucket_used >= bucket_len then + if rrset_pos > rr_count then return; end -- Used up all records + + -- Going to start on a new priority now. Gather up all the next + -- records with the same priority and add them to priority_bucket + priority_bucket, bucket_total_weight, bucket_len, bucket_used = {}, 0, 0, 0; + local current_priority; + repeat + local curr_record = rrset[rrset_pos].srv; + if not current_priority then + current_priority = curr_record.priority; + elseif current_priority ~= curr_record.priority then + break; + end + table.insert(priority_bucket, curr_record); + bucket_total_weight = bucket_total_weight + curr_record.weight; + bucket_len = bucket_len + 1; + rrset_pos = rrset_pos + 1; + until rrset_pos > rr_count; + end + + bucket_used = bucket_used + 1; + local n, running_total = math.random(0, bucket_total_weight), 0; + local target_record; + for i = 1, bucket_len do + local candidate = priority_bucket[i]; + if candidate then + running_total = running_total + candidate.weight; + if running_total >= n then + target_record = candidate; + bucket_total_weight = bucket_total_weight - candidate.weight; + priority_bucket[i] = nil; + break; + end + end + end + return target_record; + end; +end + -- Find the next target to connect to, and -- pass it to cb() function methods:next(cb) - if self.targets then - if not self.resolver then - if #self.targets == 0 then + if self.resolver or self._get_next_target then + if not self.resolver then -- Do we have a basic resolver currently? + -- We don't, so fetch a new SRV target, create a new basic resolver for it + local next_srv_target = self._get_next_target and self._get_next_target(); + if not next_srv_target then + -- No more SRV targets left cb(nil); return; end - local next_target = table.remove(self.targets, 1); - self.resolver = basic.new(unpack(next_target, 1, 4)); + -- Create a new basic resolver for this SRV target + self.resolver = basic.new(next_srv_target.target, next_srv_target.port, self.conn_type, self.extra); end + -- Look up the next (basic) target from the current target's resolver self.resolver:next(function (...) if self.resolver then self.last_error = self.resolver.last_error; @@ -31,6 +86,9 @@ function methods:next(cb) end end); return; + elseif self.in_progress then + cb(nil); + return; end if not self.hostname then @@ -39,9 +97,9 @@ function methods:next(cb) return; end - local targets = {}; + self.in_progress = true; + local function ready() - self.targets = targets; self:next(cb); end @@ -53,17 +111,23 @@ function methods:next(cb) answer = {}; end if answer then - if self.extra and not answer.secure then - self.extra.use_dane = false; - elseif answer.bogus then + if answer.bogus then self.last_error = "Validation error in SRV lookup"; ready(); return; + elseif not answer.secure then + if self.extra then + -- Insecure results, so no DANE + self.extra.use_dane = false; + end + end + if self.extra then + self.extra.srv_secure = answer.secure; end if #answer == 0 then if self.extra and self.extra.default_port then - table.insert(targets, { self.hostname, self.extra.default_port, self.conn_type, self.extra }); + self.resolver = basic.new(self.hostname, self.extra.default_port, self.conn_type, self.extra); else self.last_error = "zero SRV records found"; end @@ -77,10 +141,7 @@ function methods:next(cb) return; end - table.sort(answer, function (a, b) return a.srv.priority < b.srv.priority end); - for _, record in ipairs(answer) do - table.insert(targets, { record.srv.target, record.srv.port, self.conn_type, self.extra }); - end + self._get_next_target = new_target_selector(answer); else self.last_error = err; end diff --git a/net/server.lua b/net/server.lua index 0696fd52..942aceb2 100644 --- a/net/server.lua +++ b/net/server.lua @@ -6,20 +6,23 @@ -- COPYING file in the source package for more information. -- -if not (prosody and prosody.config_loaded) then - -- This module only supports loading inside Prosody, outside Prosody - -- you should directly require net.server_select or server_event, etc. - error(debug.traceback("Loading outside Prosody or Prosody not yet initialized"), 0); +local function log(level, format, ...) + print("net.server", level, format:format(...)); end -local log = require "util.logger".init("net.server"); +local default_backend = "select"; +local server_type = default_backend; -local default_backend = "epoll"; +if (prosody and prosody.config_loaded) then + default_backend = "epoll"; + log = require"prosody.util.logger".init("net.server"); + server_type = require"prosody.core.configmanager".get("*", "network_backend") or default_backend; -local server_type = require "core.configmanager".get("*", "network_backend") or default_backend; - -if require "core.configmanager".get("*", "use_libevent") then - server_type = "event"; + if require"prosody.core.configmanager".get("*", "use_libevent") then + server_type = "event"; + end +elseif pcall(require, "prosody.util.poll") then + server_type = "epoll"; end if server_type == "event" then @@ -32,7 +35,7 @@ end local server; local set_config; if server_type == "event" then - server = require "net.server_event"; + server = require "prosody.net.server_event"; local defaults = {}; for k,v in pairs(server.cfg) do @@ -61,7 +64,7 @@ if server_type == "event" then elseif server_type == "select" then -- TODO Remove completely. log("warn", "select is deprecated, the new default is epoll. For more info see https://prosody.im/doc/network_backend"); - server = require "net.server_select"; + server = require "prosody.net.server_select"; local defaults = {}; for k,v in pairs(server.getsettings()) do @@ -75,7 +78,7 @@ elseif server_type == "select" then server.changesettings(select_settings); end else - server = require("net.server_"..server_type); + server = require("prosody.net.server_"..server_type); set_config = server.set_config; if not server.get_backend then function server.get_backend() @@ -85,7 +88,7 @@ else end -- If server.hook_signal exists, replace signal.signal() -local has_signal, signal = pcall(require, "util.signal"); +local has_signal, signal = pcall(require, "prosody.util.signal"); if has_signal then if server.hook_signal then function signal.signal(signal_id, handler) @@ -109,7 +112,7 @@ else end if prosody and set_config then - local config_get = require "core.configmanager".get; + local config_get = require "prosody.core.configmanager".get; local function load_config() local settings = config_get("*", "network_settings") or {}; return set_config(settings); @@ -118,6 +121,15 @@ if prosody and set_config then prosody.events.add_handler("config-reloaded", load_config); end --- require "net.server" shall now forever return this, +if prosody and server.tls_builder then + local tls_builder = server.tls_builder; + -- resolving the basedir here avoids util.sslconfig depending on + -- prosody.paths.config + function server.tls_builder() + return tls_builder(prosody.paths.config or "") + end +end + +-- require "prosody.net.server" shall now forever return this, -- ie. server_select or server_event as chosen above. return server; diff --git a/net/server_epoll.lua b/net/server_epoll.lua index fa275d71..e6b69258 100644 --- a/net/server_epoll.lua +++ b/net/server_epoll.lua @@ -15,21 +15,22 @@ local next = next; local pairs = pairs; local ipairs = ipairs; local traceback = debug.traceback; -local logger = require "util.logger"; +local logger = require "prosody.util.logger"; local log = logger.init("server_epoll"); local socket = require "socket"; -local luasec = require "ssl"; -local realtime = require "util.time".now; -local monotonic = require "util.time".monotonic; -local indexedbheap = require "util.indexedbheap"; -local createtable = require "util.table".create; -local inet = require "util.net"; +local realtime = require "prosody.util.time".now; +local monotonic = require "prosody.util.time".monotonic; +local indexedbheap = require "prosody.util.indexedbheap"; +local createtable = require "prosody.util.table".create; +local inet = require "prosody.util.net"; local inet_pton = inet.pton; local _SOCKETINVALID = socket._SOCKETINVALID or -1; -local new_id = require "util.id".short; -local xpcall = require "util.xpcall".xpcall; +local new_id = require "prosody.util.id".short; +local xpcall = require "prosody.util.xpcall".xpcall; +local sslconfig = require "prosody.util.sslconfig"; +local tls_impl = require "prosody.net.tls_luasec"; -local poller = require "util.poll" +local poller = require "prosody.util.poll" local EEXIST = poller.EEXIST; local ENOENT = poller.ENOENT; @@ -91,6 +92,12 @@ local default_config = { __index = { --- How long to wait after getting the shutdown signal before forcefully tearing down every socket shutdown_deadline = 5; + + -- TCP Fast Open + tcp_fastopen = false; + + -- Defer accept until incoming data is available + tcp_defer_accept = false; }}; local cfg = default_config.__index; @@ -614,6 +621,42 @@ function interface:set_sslctx(sslctx) self._sslctx = sslctx; end +function interface:sslctx() + return self.tls_ctx +end + +function interface:ssl_info() + local sock = self.conn; + if not sock.info then return nil, "not-implemented"; end + return sock:info(); +end + +function interface:ssl_peercertificate() + local sock = self.conn; + if not sock.getpeercertificate then return nil, "not-implemented"; end + return sock:getpeercertificate(); +end + +function interface:ssl_peerverification() + local sock = self.conn; + if not sock.getpeerverification then return nil, { { "Chain verification not supported" } }; end + return sock:getpeerverification(); +end + +function interface:ssl_peerfinished() + local sock = self.conn; + if not sock.getpeerfinished then return nil, "not-implemented"; end + return sock:getpeerfinished(); +end + +function interface:ssl_exportkeyingmaterial(label, len, context) + local sock = self.conn; + if sock.exportkeyingmaterial then + return sock:exportkeyingmaterial(label, len, context); + end +end + + function interface:starttls(tls_ctx) if tls_ctx then self.tls_ctx = tls_ctx; end self.starttls = false; @@ -641,11 +684,7 @@ function interface:inittls(tls_ctx, now) self.starttls = false; self:debug("Starting TLS now"); self:updatenames(); -- Can't getpeer/sockname after wrap() - local ok, conn, err = pcall(luasec.wrap, self.conn, self.tls_ctx); - if not ok then - conn, err = ok, conn; - self:debug("Failed to initialize TLS: %s", err); - end + local conn, err = self.tls_ctx:wrap(self.conn); if not conn then self:on("disconnect", err); self:destroy(); @@ -656,8 +695,8 @@ function interface:inittls(tls_ctx, now) if conn.sni then if self.servername then conn:sni(self.servername); - elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then - conn:sni(self._server.hosts, true); + elseif next(self.tls_ctx._sni_contexts) ~= nil then + conn:sni(self.tls_ctx._sni_contexts, true); end end if self.extra and self.extra.tlsa and conn.settlsa then @@ -741,7 +780,6 @@ local function wrapsocket(client, server, read_size, listeners, tls_ctx, extra) end end - conn:updatenames(); return conn; end @@ -767,6 +805,7 @@ function interface:onacceptable() return; end local client = wrapsocket(conn, self, nil, self.listeners); + client:updatenames(); client:debug("New connection %s on server %s", client, self); client:defaultoptions(); client._writable = cfg.opportunistic_writes; @@ -885,6 +924,12 @@ local function wrapserver(conn, addr, port, listeners, config) log = logger.init(("serv%s"):format(new_id())); }, interface_mt); server:debug("Server %s created", server); + if cfg.tcp_fastopen then + server:setoption("tcp-fastopen", cfg.tcp_fastopen); + end + if type(cfg.tcp_defer_accept) == "number" then + server:setoption("tcp-defer-accept", cfg.tcp_defer_accept); + end server:add(true, false); return server; end @@ -908,6 +953,7 @@ end -- COMPAT local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx, extra) local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra); + client:updatenames(); if not client.peername then client.peername, client.peerport = addr, port; end @@ -941,9 +987,13 @@ local function addclient(addr, port, listeners, read_size, tls_ctx, typ, extra) if not conn then return conn, err; end local ok, err = conn:settimeout(0); if not ok then return ok, err; end + local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra) + if cfg.tcp_fastopen then + client:setoption("tcp-fastopen-connect", 1); + end local ok, err = conn:setpeername(addr, port); if not ok and err ~= "timeout" then return ok, err; end - local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra) + client:updatenames(); local ok, err = client:init(); if not client.peername then -- otherwise not set until connected @@ -1032,12 +1082,38 @@ local function setquitting(quit) end end +local function loop_once() + runtimers(); -- Ignore return value because we only do this once + local fd, r, w = poll:wait(0); + if fd then + local conn = fds[fd]; + if conn then + if r then + conn:onreadable(); + end + if w then + conn:onwritable(); + end + else + log("debug", "Removing unknown fd %d", fd); + poll:del(fd); + end + else + return fd, r; + end +end + -- Main loop local function loop(once) - repeat - local t = runtimers(cfg.max_wait, cfg.min_wait); + if once then + return loop_once(); + end + + local t = 0; + while not quitting do local fd, r, w = poll:wait(t); - while fd do + if fd then + t = 0; local conn = fds[fd]; if conn then if r then @@ -1050,12 +1126,12 @@ local function loop(once) log("debug", "Removing unknown fd %d", fd); poll:del(fd); end - fd, r, w = poll:wait(0); - end - if r ~= "timeout" and r ~= "signal" then + elseif r == "timeout" then + t = runtimers(cfg.max_wait, cfg.min_wait); + elseif r ~= "signal" then log("debug", "epoll_wait error: %s[%d]", r, w); end - until once or (quitting and next(fds) == nil); + end return quitting; end @@ -1085,6 +1161,10 @@ return { cfg = setmetatable(newconfig, default_config); end; + tls_builder = function(basedir) + return sslconfig._new(tls_impl.new_context, basedir) + end, + -- libevent emulation event = { EV_READ = "r", EV_WRITE = "w", EV_READWRITE = "rw", EV_LEAVE = -1 }; addevent = function (fd, mode, callback) diff --git a/net/server_event.lua b/net/server_event.lua index c30181b8..3bad3474 100644 --- a/net/server_event.lua +++ b/net/server_event.lua @@ -47,15 +47,17 @@ local s_sub = string.sub local coroutine_wrap = coroutine.wrap local coroutine_yield = coroutine.yield -local has_luasec, ssl = pcall ( require , "ssl" ) +local has_luasec = pcall ( require , "ssl" ) local socket = require "socket" local levent = require "luaevent.core" -local inet = require "util.net"; +local inet = require "prosody.util.net"; local inet_pton = inet.pton; +local sslconfig = require "prosody.util.sslconfig"; +local tls_impl = require "prosody.net.tls_luasec"; local socket_gettime = socket.gettime -local log = require ("util.logger").init("socket") +local log = require ("prosody.util.logger").init("socket") local function debug(...) return log("debug", ("%s "):rep(select('#', ...)), ...) @@ -153,7 +155,7 @@ function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed _ = self.eventwrite and self.eventwrite:close( ) self.eventread, self.eventwrite = nil, nil local err - self.conn, err = ssl.wrap( self.conn, self._sslctx ) + self.conn, err = self._sslctx:wrap(self.conn) if err then self.fatalerror = err self.conn = nil -- cannot be used anymore @@ -168,8 +170,8 @@ function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed if self.conn.sni then if self.servername then self.conn:sni(self.servername); - elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then - self.conn:sni(self._server.hosts, true); + elseif next(self._sslctx._sni_contexts) ~= nil then + self.conn:sni(self._sslctx._sni_contexts, true); end end @@ -274,6 +276,34 @@ function interface_mt:pause() return self:_lock(self.nointerface, true, self.nowriting); end +function interface_mt:sslctx() + return self._sslctx +end + +function interface_mt:ssl_info() + local sock = self.conn; + if not sock.info then return nil, "not-implemented"; end + return sock:info(); +end + +function interface_mt:ssl_peercertificate() + local sock = self.conn; + if not sock.getpeercertificate then return nil, "not-implemented"; end + return sock:getpeercertificate(); +end + +function interface_mt:ssl_peerverification() + local sock = self.conn; + if not sock.getpeerverification then return nil, { { "Chain verification not supported" } }; end + return sock:getpeerverification(); +end + +function interface_mt:ssl_peerfinished() + local sock = self.conn; + if not sock.getpeerfinished then return nil, "not-implemented"; end + return sock:getpeerfinished(); +end + function interface_mt:resume() self:_lock(self.nointerface, false, self.nowriting); if self.readcallback and not self.eventread then @@ -924,6 +954,10 @@ return { add_task = add_task, watchfd = watchfd, + tls_builder = function(basedir) + return sslconfig._new(tls_impl.new_context, basedir) + end, + __NAME = SCRIPT_NAME, __DATE = LAST_MODIFIED, __AUTHOR = SCRIPT_AUTHOR, diff --git a/net/server_select.lua b/net/server_select.lua index eea850ce..cccf7a4e 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -11,7 +11,7 @@ local use = function( what ) return _G[ what ] end -local log, table_concat = require ("util.logger").init("socket"), table.concat; +local log, table_concat = require ("prosody.util.logger").init("socket"), table.concat; local out_put = function (...) return log("debug", table_concat{...}); end local out_error = function (...) return log("warn", table_concat{...}); end @@ -47,15 +47,15 @@ local coroutine_yield = coroutine.yield --// extern libs //-- -local has_luasec, luasec = pcall ( require , "ssl" ) local luasocket = use "socket" or require "socket" local luasocket_gettime = luasocket.gettime -local inet = require "util.net"; +local inet = require "prosody.util.net"; local inet_pton = inet.pton; +local sslconfig = require "prosody.util.sslconfig"; +local has_luasec, tls_impl = pcall(require, "prosody.net.tls_luasec"); --// extern lib methods //-- -local ssl_wrap = ( has_luasec and luasec.wrap ) local socket_bind = luasocket.bind local socket_select = luasocket.select @@ -359,6 +359,21 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport handler.sslctx = function ( ) return sslctx end + handler.ssl_info = function( ) + return socket.info and socket:info() + end + handler.ssl_peercertificate = function( ) + if not socket.getpeercertificate then return nil, "not-implemented"; end + return socket:getpeercertificate() + end + handler.ssl_peerverification = function( ) + if not socket.getpeerverification then return nil, { { "Chain verification not supported" } }; end + return socket:getpeerverification(); + end + handler.ssl_peerfinished = function( ) + if not socket.getpeerfinished then return nil, "not-implemented"; end + return socket:getpeerfinished(); + end handler.send = function( _, data, i, j ) return send( socket, data, i, j ) end @@ -652,7 +667,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport end out_put( "server.lua: attempting to start tls on " .. tostring( socket ) ) local oldsocket, err = socket - socket, err = ssl_wrap( socket, sslctx ) -- wrap socket + socket, err = sslctx:wrap(socket) -- wrap socket if not socket then out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") ) @@ -662,8 +677,8 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport if socket.sni then if self.servername then socket:sni(self.servername); - elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then - socket:sni(self.server().hosts, true); + elseif next(sslctx._sni_contexts) ~= nil then + socket:sni(sslctx._sni_contexts, true); end end @@ -1169,4 +1184,8 @@ return { removeserver = removeserver, get_backend = get_backend, changesettings = changesettings, + + tls_builder = function(basedir) + return sslconfig._new(tls_impl.new_context, basedir) + end, } diff --git a/net/stun.lua b/net/stun.lua index 2c35786f..586f96e2 100644 --- a/net/stun.lua +++ b/net/stun.lua @@ -1,11 +1,11 @@ -local base64 = require "util.encodings".base64; -local hashes = require "util.hashes"; -local net = require "util.net"; -local random = require "util.random"; -local struct = require "util.struct"; -local bit32 = require"util.bitcompat"; -local sxor = require"util.strbitop".sxor; -local new_ip = require "util.ip".new_ip; +local base64 = require "prosody.util.encodings".base64; +local hashes = require "prosody.util.hashes"; +local net = require "prosody.util.net"; +local random = require "prosody.util.random"; +local struct = require "prosody.util.struct"; +local bit32 = require"prosody.util.bitcompat"; +local sxor = require"prosody.util.strbitop".sxor; +local new_ip = require "prosody.util.ip".new_ip; --- Public helpers diff --git a/net/tls_luasec.lua b/net/tls_luasec.lua new file mode 100644 index 00000000..3af2fc6b --- /dev/null +++ b/net/tls_luasec.lua @@ -0,0 +1,114 @@ +-- Prosody IM +-- Copyright (C) 2021 Prosody folks +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +--[[ +This file provides a shim abstraction over LuaSec, consolidating some code +which was previously spread between net.server backends, portmanager and +certmanager. + +The goal is to provide a more or less well-defined API on top of LuaSec which +abstracts away some of the things which are not needed and simplifies usage of +commonly used things (such as SNI contexts). Eventually, network backends +which do not rely on LuaSocket+LuaSec should be able to provide *this* API +instead of having to mimic LuaSec. +]] +local ssl = require "ssl"; +local ssl_newcontext = ssl.newcontext; +local ssl_context = ssl.context or require "ssl.context"; +local io_open = io.open; + +local context_api = {}; +local context_mt = {__index = context_api}; + +function context_api:set_sni_host(host, cert, key) + local ctx, err = self._builder:clone():apply({ + certificate = cert, + key = key, + }):build(); + if not ctx then + return false, err + end + + self._sni_contexts[host] = ctx._inner + + return true, nil +end + +function context_api:remove_sni_host(host) + self._sni_contexts[host] = nil +end + +function context_api:wrap(sock) + local ok, conn, err = pcall(ssl.wrap, sock, self._inner); + if not ok then + return nil, err + end + return conn, nil +end + +local function new_context(cfg, builder) + -- LuaSec expects dhparam to be a callback that takes two arguments. + -- We ignore those because it is mostly used for having a separate + -- set of params for EXPORT ciphers, which we don't have by default. + if type(cfg.dhparam) == "string" then + local f, err = io_open(cfg.dhparam); + if not f then return nil, "Could not open DH parameters: "..err end + local dhparam = f:read("*a"); + f:close(); + cfg.dhparam = function() return dhparam; end + end + + local inner, err = ssl_newcontext(cfg); + if not inner then + return nil, err + end + + -- COMPAT Older LuaSec ignores the cipher list from the config, so we have to take care + -- of it ourselves (W/A for #x) + if inner and cfg.ciphers then + local success; + success, err = ssl_context.setcipher(inner, cfg.ciphers); + if not success then + return nil, err + end + end + + return setmetatable({ + _inner = inner, + _builder = builder, + _sni_contexts = {}, + }, context_mt), nil +end + +-- Feature detection / guessing +local function test_option(option) + return not not ssl_newcontext({mode="server",protocol="sslv23",options={ option }}); +end +local luasec_major, luasec_minor = ssl._VERSION:match("^(%d+)%.(%d+)"); +local luasec_version = tonumber(luasec_major) * 100 + tonumber(luasec_minor); +local luasec_has = ssl.config or { + algorithms = { + ec = luasec_version >= 5; + }; + capabilities = { + curves_list = luasec_version >= 7; + }; + options = { + cipher_server_preference = test_option("cipher_server_preference"); + no_ticket = test_option("no_ticket"); + no_compression = test_option("no_compression"); + single_dh_use = test_option("single_dh_use"); + single_ecdh_use = test_option("single_ecdh_use"); + no_renegotiation = test_option("no_renegotiation"); + }; +}; + +return { + features = luasec_has; + new_context = new_context, + load_certificate = ssl.loadcertificate; +}; diff --git a/net/unbound.lua b/net/unbound.lua index ee742b7c..c3ccb9ea 100644 --- a/net/unbound.lua +++ b/net/unbound.lua @@ -13,15 +13,15 @@ local s_lower = string.lower; local s_upper = string.upper; local noop = function() end; -local logger = require "util.logger"; +local logger = require "prosody.util.logger"; local log = logger.init("unbound"); -local net_server = require "net.server"; +local net_server = require "prosody.net.server"; local libunbound = require"lunbound"; -local promise = require"util.promise"; -local new_id = require "util.id".short; +local promise = require"prosody.util.promise"; +local new_id = require "prosody.util.id".short; local gettime = require"socket".gettime; -local dns_utils = require"util.dns"; +local dns_utils = require"prosody.util.dns"; local classes, types, errors = dns_utils.classes, dns_utils.types, dns_utils.errors; local parsers = dns_utils.parsers; @@ -44,7 +44,7 @@ end local unbound_config; if prosody then - local config = require"core.configmanager"; + local config = require"prosody.core.configmanager"; unbound_config = add_defaults(config.get("*", "unbound")); prosody.events.add_handler("config-reloaded", function() unbound_config = add_defaults(config.get("*", "unbound")); diff --git a/net/websocket.lua b/net/websocket.lua index 193cd556..708a7be7 100644 --- a/net/websocket.lua +++ b/net/websocket.lua @@ -8,13 +8,13 @@ local t_concat = table.concat; -local http = require "net.http"; -local frames = require "net.websocket.frames"; -local base64 = require "util.encodings".base64; -local sha1 = require "util.hashes".sha1; -local random_bytes = require "util.random".bytes; -local timer = require "util.timer"; -local log = require "util.logger".init "websocket"; +local http = require "prosody.net.http"; +local frames = require "prosody.net.websocket.frames"; +local base64 = require "prosody.util.encodings".base64; +local sha1 = require "prosody.util.hashes".sha1; +local random_bytes = require "prosody.util.random".bytes; +local timer = require "prosody.util.timer"; +local log = require "prosody.util.logger".init "websocket"; local close_timeout = 3; -- Seconds to wait after sending close frame until closing connection. diff --git a/net/websocket/frames.lua b/net/websocket/frames.lua index 6a088902..c22a12de 100644 --- a/net/websocket/frames.lua +++ b/net/websocket/frames.lua @@ -6,17 +6,17 @@ -- COPYING file in the source package for more information. -- -local random_bytes = require "util.random".bytes; +local random_bytes = require "prosody.util.random".bytes; -local bit = require "util.bitcompat"; +local bit = require "prosody.util.bitcompat"; local band = bit.band; local bor = bit.bor; -local sbit = require "util.strbitop"; +local sbit = require "prosody.util.strbitop"; local sxor = sbit.sxor; local s_char = string.char; -local s_pack = require"util.struct".pack; -local s_unpack = require"util.struct".unpack; +local s_pack = require"prosody.util.struct".pack; +local s_unpack = require"prosody.util.struct".unpack; local function pack_uint16be(x) return s_pack(">I2", x); @@ -77,7 +77,6 @@ local function parse_frame_header(frame) end -- XORs the string `str` with the array of bytes `key` --- TODO: optimize local function apply_mask(str, key, from, to) return sxor(str:sub(from or 1, to or -1), key); end |