diff options
author | Kim Alvefur <zash@zash.se> | 2022-12-12 07:10:54 +0100 |
---|---|---|
committer | Kim Alvefur <zash@zash.se> | 2022-12-12 07:10:54 +0100 |
commit | 080d7974bf0c1da8a1c0578d67c3172facc9d719 (patch) | |
tree | 838d6904e47ab8681928b37701ff4f1c6e89184a /net | |
parent | baff85a52c5fda705e8b3699410c770f015d89ab (diff) | |
parent | c916ce76ee89dca32e7e653dff1ade4732462efc (diff) | |
download | prosody-080d7974bf0c1da8a1c0578d67c3172facc9d719.tar.gz prosody-080d7974bf0c1da8a1c0578d67c3172facc9d719.zip |
Merge 0.12->trunk
Diffstat (limited to 'net')
-rw-r--r-- | net/connect.lua | 46 | ||||
-rw-r--r-- | net/dns.lua | 4 | ||||
-rw-r--r-- | net/http/codes.lua | 92 | ||||
-rw-r--r-- | net/resolvers/basic.lua | 162 | ||||
-rw-r--r-- | net/resolvers/manual.lua | 2 | ||||
-rw-r--r-- | net/resolvers/service.lua | 81 | ||||
-rw-r--r-- | net/server.lua | 7 | ||||
-rw-r--r-- | net/server_epoll.lua | 74 | ||||
-rw-r--r-- | net/server_event.lua | 42 | ||||
-rw-r--r-- | net/server_select.lua | 29 | ||||
-rw-r--r-- | net/tls_luasec.lua | 89 |
11 files changed, 457 insertions, 171 deletions
diff --git a/net/connect.lua b/net/connect.lua index 4b602be4..3cb407a1 100644 --- a/net/connect.lua +++ b/net/connect.lua @@ -1,8 +1,8 @@ local server = require "net.server"; local log = require "util.logger".init("net.connect"); local new_id = require "util.id".short; +local timer = require "util.timer"; --- TODO #1246 Happy Eyeballs -- FIXME RFC 6724 -- FIXME Error propagation from resolvers doesn't work -- FIXME #1428 Reuse DNS resolver object between service and basic resolver @@ -28,16 +28,17 @@ local pending_connection_listeners = {}; local function attempt_connection(p) p:log("debug", "Checking for targets..."); - if p.conn then - pending_connections_map[p.conn] = nil; - p.conn = nil; - end - p.target_resolver:next(function (conn_type, ip, port, extra) + p.target_resolver:next(function (conn_type, ip, port, extra, more_targets_available) if not conn_type then -- No more targets to try p:log("debug", "No more connection targets to try", p.target_resolver.last_error); - if p.listeners.onfail then - p.listeners.onfail(p.data, p.last_error or p.target_resolver.last_error or "unable to resolve service"); + if next(p.conns) == nil then + p:log("debug", "No more targets, no pending connections. Connection failed."); + if p.listeners.onfail then + p.listeners.onfail(p.data, p.last_error or p.target_resolver.last_error or "unable to resolve service"); + end + else + p:log("debug", "One or more connection attempts are still pending. Waiting for now."); end return; end @@ -49,8 +50,16 @@ local function attempt_connection(p) p.last_error = err or "unknown reason"; return attempt_connection(p); end - p.conn = conn; + p.conns[conn] = true; pending_connections_map[conn] = p; + if more_targets_available then + timer.add_task(0.250, function () + if not p.connected then + p:log("debug", "Still not connected, making parallel connection attempt..."); + attempt_connection(p); + end + end); + end end); end @@ -62,6 +71,13 @@ function pending_connection_listeners.onconnect(conn) return; end pending_connections_map[conn] = nil; + if p.connected then + -- We already succeeded in connecting + p.conns[conn] = nil; + conn:close(); + return; + end + p.connected = true; p:log("debug", "Successfully connected"); conn:setlistener(p.listeners, p.data); return p.listeners.onconnect(conn); @@ -73,9 +89,18 @@ function pending_connection_listeners.ondisconnect(conn, reason) log("warn", "Failed connection, but unexpected!"); return; end + p.conns[conn] = nil; + pending_connections_map[conn] = nil; p.last_error = reason or "unknown reason"; p:log("debug", "Connection attempt failed: %s", p.last_error); - attempt_connection(p); + if p.connected then + p:log("debug", "Connection already established, ignoring failure"); + elseif next(p.conns) == nil then + p:log("debug", "No pending connection attempts, and not yet connected"); + attempt_connection(p); + else + p:log("debug", "Other attempts are still pending, ignoring failure"); + end end local function connect(target_resolver, listeners, options, data) @@ -85,6 +110,7 @@ local function connect(target_resolver, listeners, options, data) listeners = assert(listeners); options = options or {}; data = data; + conns = {}; }, pending_connection_mt); p:log("debug", "Starting connection process"); diff --git a/net/dns.lua b/net/dns.lua index a9846e86..e6179637 100644 --- a/net/dns.lua +++ b/net/dns.lua @@ -8,8 +8,8 @@ -- todo: cache results of encodeName --- reference: http://tools.ietf.org/html/rfc1035 --- reference: http://tools.ietf.org/html/rfc1876 (LOC) +-- reference: https://www.rfc-editor.org/rfc/rfc1035.html +-- reference: https://www.rfc-editor.org/rfc/rfc1876.html (LOC) local socket = require "socket"; diff --git a/net/http/codes.lua b/net/http/codes.lua index 4327f151..b2949286 100644 --- a/net/http/codes.lua +++ b/net/http/codes.lua @@ -2,62 +2,62 @@ local response_codes = { -- Source: http://www.iana.org/assignments/http-status-codes - [100] = "Continue"; -- RFC7231, Section 6.2.1 - [101] = "Switching Protocols"; -- RFC7231, Section 6.2.2 + [100] = "Continue"; -- RFC9110, Section 15.2.1 + [101] = "Switching Protocols"; -- RFC9110, Section 15.2.2 [102] = "Processing"; [103] = "Early Hints"; -- [104-199] = "Unassigned"; - [200] = "OK"; -- RFC7231, Section 6.3.1 - [201] = "Created"; -- RFC7231, Section 6.3.2 - [202] = "Accepted"; -- RFC7231, Section 6.3.3 - [203] = "Non-Authoritative Information"; -- RFC7231, Section 6.3.4 - [204] = "No Content"; -- RFC7231, Section 6.3.5 - [205] = "Reset Content"; -- RFC7231, Section 6.3.6 - [206] = "Partial Content"; -- RFC7233, Section 4.1 + [200] = "OK"; -- RFC9110, Section 15.3.1 + [201] = "Created"; -- RFC9110, Section 15.3.2 + [202] = "Accepted"; -- RFC9110, Section 15.3.3 + [203] = "Non-Authoritative Information"; -- RFC9110, Section 15.3.4 + [204] = "No Content"; -- RFC9110, Section 15.3.5 + [205] = "Reset Content"; -- RFC9110, Section 15.3.6 + [206] = "Partial Content"; -- RFC9110, Section 15.3.7 [207] = "Multi-Status"; [208] = "Already Reported"; -- [209-225] = "Unassigned"; [226] = "IM Used"; -- [227-299] = "Unassigned"; - [300] = "Multiple Choices"; -- RFC7231, Section 6.4.1 - [301] = "Moved Permanently"; -- RFC7231, Section 6.4.2 - [302] = "Found"; -- RFC7231, Section 6.4.3 - [303] = "See Other"; -- RFC7231, Section 6.4.4 - [304] = "Not Modified"; -- RFC7232, Section 4.1 - [305] = "Use Proxy"; -- RFC7231, Section 6.4.5 - -- [306] = "(Unused)"; -- RFC7231, Section 6.4.6 - [307] = "Temporary Redirect"; -- RFC7231, Section 6.4.7 - [308] = "Permanent Redirect"; + [300] = "Multiple Choices"; -- RFC9110, Section 15.4.1 + [301] = "Moved Permanently"; -- RFC9110, Section 15.4.2 + [302] = "Found"; -- RFC9110, Section 15.4.3 + [303] = "See Other"; -- RFC9110, Section 15.4.4 + [304] = "Not Modified"; -- RFC9110, Section 15.4.5 + [305] = "Use Proxy"; -- RFC9110, Section 15.4.6 + -- [306] = "(Unused)"; -- RFC9110, Section 15.4.7 + [307] = "Temporary Redirect"; -- RFC9110, Section 15.4.8 + [308] = "Permanent Redirect"; -- RFC9110, Section 15.4.9 -- [309-399] = "Unassigned"; - [400] = "Bad Request"; -- RFC7231, Section 6.5.1 - [401] = "Unauthorized"; -- RFC7235, Section 3.1 - [402] = "Payment Required"; -- RFC7231, Section 6.5.2 - [403] = "Forbidden"; -- RFC7231, Section 6.5.3 - [404] = "Not Found"; -- RFC7231, Section 6.5.4 - [405] = "Method Not Allowed"; -- RFC7231, Section 6.5.5 - [406] = "Not Acceptable"; -- RFC7231, Section 6.5.6 - [407] = "Proxy Authentication Required"; -- RFC7235, Section 3.2 - [408] = "Request Timeout"; -- RFC7231, Section 6.5.7 - [409] = "Conflict"; -- RFC7231, Section 6.5.8 - [410] = "Gone"; -- RFC7231, Section 6.5.9 - [411] = "Length Required"; -- RFC7231, Section 6.5.10 - [412] = "Precondition Failed"; -- RFC7232, Section 4.2 - [413] = "Payload Too Large"; -- RFC7231, Section 6.5.11 - [414] = "URI Too Long"; -- RFC7231, Section 6.5.12 - [415] = "Unsupported Media Type"; -- RFC7231, Section 6.5.13 - [416] = "Range Not Satisfiable"; -- RFC7233, Section 4.4 - [417] = "Expectation Failed"; -- RFC7231, Section 6.5.14 + [400] = "Bad Request"; -- RFC9110, Section 15.5.1 + [401] = "Unauthorized"; -- RFC9110, Section 15.5.2 + [402] = "Payment Required"; -- RFC9110, Section 15.5.3 + [403] = "Forbidden"; -- RFC9110, Section 15.5.4 + [404] = "Not Found"; -- RFC9110, Section 15.5.5 + [405] = "Method Not Allowed"; -- RFC9110, Section 15.5.6 + [406] = "Not Acceptable"; -- RFC9110, Section 15.5.7 + [407] = "Proxy Authentication Required"; -- RFC9110, Section 15.5.8 + [408] = "Request Timeout"; -- RFC9110, Section 15.5.9 + [409] = "Conflict"; -- RFC9110, Section 15.5.10 + [410] = "Gone"; -- RFC9110, Section 15.5.11 + [411] = "Length Required"; -- RFC9110, Section 15.5.12 + [412] = "Precondition Failed"; -- RFC9110, Section 15.5.13 + [413] = "Content Too Large"; -- RFC9110, Section 15.5.14 + [414] = "URI Too Long"; -- RFC9110, Section 15.5.15 + [415] = "Unsupported Media Type"; -- RFC9110, Section 15.5.16 + [416] = "Range Not Satisfiable"; -- RFC9110, Section 15.5.17 + [417] = "Expectation Failed"; -- RFC9110, Section 15.5.18 [418] = "I'm a teapot"; -- RFC2324, Section 2.3.2 -- [419-420] = "Unassigned"; - [421] = "Misdirected Request"; -- RFC7540, Section 9.1.2 - [422] = "Unprocessable Entity"; + [421] = "Misdirected Request"; -- RFC9110, Section 15.5.20 + [422] = "Unprocessable Content"; -- RFC9110, Section 15.5.21 [423] = "Locked"; [424] = "Failed Dependency"; [425] = "Too Early"; - [426] = "Upgrade Required"; -- RFC7231, Section 6.5.15 + [426] = "Upgrade Required"; -- RFC9110, Section 15.5.22 -- [427] = "Unassigned"; [428] = "Precondition Required"; [429] = "Too Many Requests"; @@ -67,17 +67,17 @@ local response_codes = { [451] = "Unavailable For Legal Reasons"; -- [452-499] = "Unassigned"; - [500] = "Internal Server Error"; -- RFC7231, Section 6.6.1 - [501] = "Not Implemented"; -- RFC7231, Section 6.6.2 - [502] = "Bad Gateway"; -- RFC7231, Section 6.6.3 - [503] = "Service Unavailable"; -- RFC7231, Section 6.6.4 - [504] = "Gateway Timeout"; -- RFC7231, Section 6.6.5 - [505] = "HTTP Version Not Supported"; -- RFC7231, Section 6.6.6 + [500] = "Internal Server Error"; -- RFC9110, Section 15.6.1 + [501] = "Not Implemented"; -- RFC9110, Section 15.6.2 + [502] = "Bad Gateway"; -- RFC9110, Section 15.6.3 + [503] = "Service Unavailable"; -- RFC9110, Section 15.6.4 + [504] = "Gateway Timeout"; -- RFC9110, Section 15.6.5 + [505] = "HTTP Version Not Supported"; -- RFC9110, Section 15.6.6 [506] = "Variant Also Negotiates"; [507] = "Insufficient Storage"; [508] = "Loop Detected"; -- [509] = "Unassigned"; - [510] = "Not Extended"; + [510] = "Not Extended"; -- (OBSOLETED) [511] = "Network Authentication Required"; -- [512-599] = "Unassigned"; }; diff --git a/net/resolvers/basic.lua b/net/resolvers/basic.lua index 305bce76..e58165ba 100644 --- a/net/resolvers/basic.lua +++ b/net/resolvers/basic.lua @@ -2,13 +2,61 @@ local adns = require "net.adns"; local inet_pton = require "util.net".pton; local inet_ntop = require "util.net".ntop; local idna_to_ascii = require "util.encodings".idna.to_ascii; -local unpack = table.unpack or unpack; -- luacheck: ignore 113 +local promise = require "util.promise"; +local t_move = require "util.table".move; local methods = {}; local resolver_mt = { __index = methods }; -- FIXME RFC 6724 +local function do_dns_lookup(self, dns_resolver, record_type, name, allow_insecure) + return promise.new(function (resolve, reject) + local ipv = (record_type == "A" and "4") or (record_type == "AAAA" and "6") or nil; + if ipv and self.extra["use_ipv"..ipv] == false then + return reject(("IPv%s disabled - %s lookup skipped"):format(ipv, record_type)); + elseif record_type == "TLSA" and self.extra.use_dane ~= true then + return reject("DANE disabled - TLSA lookup skipped"); + end + dns_resolver:lookup(function (answer, err) + if not answer then + return reject(err); + elseif answer.bogus then + return reject(("Validation error in %s lookup"):format(record_type)); + elseif not (answer.secure or allow_insecure) then + return reject(("Insecure response in %s lookup"):format(record_type)); + elseif answer.status and #answer == 0 then + return reject(("%s in %s lookup"):format(answer.status, record_type)); + end + + local targets = { secure = answer.secure }; + for _, record in ipairs(answer) do + if ipv then + table.insert(targets, { self.conn_type..ipv, record[record_type:lower()], self.port, self.extra }); + else + table.insert(targets, record[record_type:lower()]); + end + end + return resolve(targets); + end, name, record_type, "IN"); + end); +end + +local function merge_targets(ipv4_targets, ipv6_targets) + local result = { secure = ipv4_targets.secure and ipv6_targets.secure }; + local common_length = math.min(#ipv4_targets, #ipv6_targets); + for i = 1, common_length do + table.insert(result, ipv6_targets[i]); + table.insert(result, ipv4_targets[i]); + end + if common_length < #ipv4_targets then + t_move(ipv4_targets, common_length+1, #ipv4_targets, common_length+1, result); + elseif common_length < #ipv6_targets then + t_move(ipv6_targets, common_length+1, #ipv6_targets, common_length+1, result); + end + return result; +end + -- Find the next target to connect to, and -- pass it to cb() function methods:next(cb) @@ -18,7 +66,7 @@ function methods:next(cb) return; end local next_target = table.remove(self.targets, 1); - cb(unpack(next_target, 1, 4)); + cb(next_target[1], next_target[2], next_target[3], next_target[4], not not self.targets[1]); return; end @@ -28,91 +76,45 @@ function methods:next(cb) return; end - local secure = true; - local tlsa = {}; - local targets = {}; - local n = 3; - local function ready() - n = n - 1; - if n > 0 then return; end - self.targets = targets; + -- Resolve DNS to target list + local dns_resolver = adns.resolver(); + + local dns_lookups = { + ipv4 = do_dns_lookup(self, dns_resolver, "A", self.hostname, true); + ipv6 = do_dns_lookup(self, dns_resolver, "AAAA", self.hostname, true); + tlsa = do_dns_lookup(self, dns_resolver, "TLSA", ("_%d._%s.%s"):format(self.port, self.conn_type, self.hostname)); + }; + + promise.all_settled(dns_lookups):next(function (dns_results) + -- Combine targets, assign to self.targets, self:next(cb) + local have_ipv4 = dns_results.ipv4.status == "fulfilled"; + local have_ipv6 = dns_results.ipv6.status == "fulfilled"; + + if have_ipv4 and have_ipv6 then + self.targets = merge_targets(dns_results.ipv4.value, dns_results.ipv6.value); + elseif have_ipv4 then + self.targets = dns_results.ipv4.value; + elseif have_ipv6 then + self.targets = dns_results.ipv6.value; + else + self.targets = {}; + end + if self.extra and self.extra.use_dane then - if secure and tlsa[1] then - self.extra.tlsa = tlsa; + if self.targets.secure and dns_results.tlsa.status == "fulfilled" then + self.extra.tlsa = dns_results.tlsa.value; self.extra.dane_hostname = self.hostname; else self.extra.tlsa = nil; self.extra.dane_hostname = nil; end end - self:next(cb); - end - -- Resolve DNS to target list - local dns_resolver = adns.resolver(); - - if not self.extra or self.extra.use_ipv4 ~= false then - dns_resolver:lookup(function (answer, err) - if answer then - secure = secure and answer.secure; - for _, record in ipairs(answer) do - table.insert(targets, { self.conn_type.."4", record.a, self.port, self.extra }); - end - if answer.bogus then - self.last_error = "Validation error in A lookup"; - elseif answer.status then - self.last_error = answer.status .. " in A lookup"; - end - else - self.last_error = err; - end - ready(); - end, self.hostname, "A", "IN"); - else - ready(); - end - - if not self.extra or self.extra.use_ipv6 ~= false then - dns_resolver:lookup(function (answer, err) - if answer then - secure = secure and answer.secure; - for _, record in ipairs(answer) do - table.insert(targets, { self.conn_type.."6", record.aaaa, self.port, self.extra }); - end - if answer.bogus then - self.last_error = "Validation error in AAAA lookup"; - elseif answer.status then - self.last_error = answer.status .. " in AAAA lookup"; - end - else - self.last_error = err; - end - ready(); - end, self.hostname, "AAAA", "IN"); - else - ready(); - end - - if self.extra and self.extra.use_dane == true then - dns_resolver:lookup(function (answer, err) - if answer then - secure = secure and answer.secure; - for _, record in ipairs(answer) do - table.insert(tlsa, record.tlsa); - end - if answer.bogus then - self.last_error = "Validation error in TLSA lookup"; - elseif answer.status then - self.last_error = answer.status .. " in TLSA lookup"; - end - else - self.last_error = err; - end - ready(); - end, ("_%d._tcp.%s"):format(self.port, self.hostname), "TLSA", "IN"); - else - ready(); - end + self:next(cb); + end):catch(function (err) + self.last_error = err; + self.targets = {}; + end); end local function new(hostname, port, conn_type, extra) @@ -137,7 +139,7 @@ local function new(hostname, port, conn_type, extra) hostname = ascii_host; port = port; conn_type = conn_type; - extra = extra; + extra = extra or {}; targets = targets; }, resolver_mt); end diff --git a/net/resolvers/manual.lua b/net/resolvers/manual.lua index dbc40256..c766a11f 100644 --- a/net/resolvers/manual.lua +++ b/net/resolvers/manual.lua @@ -1,6 +1,6 @@ local methods = {}; local resolver_mt = { __index = methods }; -local unpack = table.unpack or unpack; -- luacheck: ignore 113 +local unpack = table.unpack; -- Find the next target to connect to, and -- pass it to cb() diff --git a/net/resolvers/service.lua b/net/resolvers/service.lua index 3810cac8..a7ce76a3 100644 --- a/net/resolvers/service.lua +++ b/net/resolvers/service.lua @@ -2,23 +2,78 @@ local adns = require "net.adns"; local basic = require "net.resolvers.basic"; local inet_pton = require "util.net".pton; local idna_to_ascii = require "util.encodings".idna.to_ascii; -local unpack = table.unpack or unpack; -- luacheck: ignore 113 local methods = {}; local resolver_mt = { __index = methods }; +local function new_target_selector(rrset) + local rr_count = rrset and #rrset; + if not rr_count or rr_count == 0 then + rrset = nil; + else + table.sort(rrset, function (a, b) return a.srv.priority < b.srv.priority end); + end + local rrset_pos = 1; + local priority_bucket, bucket_total_weight, bucket_len, bucket_used; + return function () + if not rrset then return; end + + if not priority_bucket or bucket_used >= bucket_len then + if rrset_pos > rr_count then return; end -- Used up all records + + -- Going to start on a new priority now. Gather up all the next + -- records with the same priority and add them to priority_bucket + priority_bucket, bucket_total_weight, bucket_len, bucket_used = {}, 0, 0, 0; + local current_priority; + repeat + local curr_record = rrset[rrset_pos].srv; + if not current_priority then + current_priority = curr_record.priority; + elseif current_priority ~= curr_record.priority then + break; + end + table.insert(priority_bucket, curr_record); + bucket_total_weight = bucket_total_weight + curr_record.weight; + bucket_len = bucket_len + 1; + rrset_pos = rrset_pos + 1; + until rrset_pos > rr_count; + end + + bucket_used = bucket_used + 1; + local n, running_total = math.random(0, bucket_total_weight), 0; + local target_record; + for i = 1, bucket_len do + local candidate = priority_bucket[i]; + if candidate then + running_total = running_total + candidate.weight; + if running_total >= n then + target_record = candidate; + bucket_total_weight = bucket_total_weight - candidate.weight; + priority_bucket[i] = nil; + break; + end + end + end + return target_record; + end; +end + -- Find the next target to connect to, and -- pass it to cb() function methods:next(cb) - if self.targets then - if not self.resolver then - if #self.targets == 0 then + if self.resolver or self._get_next_target then + if not self.resolver then -- Do we have a basic resolver currently? + -- We don't, so fetch a new SRV target, create a new basic resolver for it + local next_srv_target = self._get_next_target and self._get_next_target(); + if not next_srv_target then + -- No more SRV targets left cb(nil); return; end - local next_target = table.remove(self.targets, 1); - self.resolver = basic.new(unpack(next_target, 1, 4)); + -- Create a new basic resolver for this SRV target + self.resolver = basic.new(next_srv_target.target, next_srv_target.port, self.conn_type, self.extra); end + -- Look up the next (basic) target from the current target's resolver self.resolver:next(function (...) if self.resolver then self.last_error = self.resolver.last_error; @@ -31,6 +86,9 @@ function methods:next(cb) end end); return; + elseif self.in_progress then + cb(nil); + return; end if not self.hostname then @@ -39,9 +97,9 @@ function methods:next(cb) return; end - local targets = {}; + self.in_progress = true; + local function ready() - self.targets = targets; self:next(cb); end @@ -63,7 +121,7 @@ function methods:next(cb) if #answer == 0 then if self.extra and self.extra.default_port then - table.insert(targets, { self.hostname, self.extra.default_port, self.conn_type, self.extra }); + self.resolver = basic.new(self.hostname, self.extra.default_port, self.conn_type, self.extra); else self.last_error = "zero SRV records found"; end @@ -77,10 +135,7 @@ function methods:next(cb) return; end - table.sort(answer, function (a, b) return a.srv.priority < b.srv.priority end); - for _, record in ipairs(answer) do - table.insert(targets, { record.srv.target, record.srv.port, self.conn_type, self.extra }); - end + self._get_next_target = new_target_selector(answer); else self.last_error = err; end diff --git a/net/server.lua b/net/server.lua index 0696fd52..72272bef 100644 --- a/net/server.lua +++ b/net/server.lua @@ -118,6 +118,13 @@ if prosody and set_config then prosody.events.add_handler("config-reloaded", load_config); end +local tls_builder = server.tls_builder; +-- resolving the basedir here avoids util.sslconfig depending on +-- prosody.paths.config +function server.tls_builder() + return tls_builder(prosody.paths.config or "") +end + -- require "net.server" shall now forever return this, -- ie. server_select or server_event as chosen above. return server; diff --git a/net/server_epoll.lua b/net/server_epoll.lua index fa275d71..b269bd9c 100644 --- a/net/server_epoll.lua +++ b/net/server_epoll.lua @@ -18,7 +18,6 @@ local traceback = debug.traceback; local logger = require "util.logger"; local log = logger.init("server_epoll"); local socket = require "socket"; -local luasec = require "ssl"; local realtime = require "util.time".now; local monotonic = require "util.time".monotonic; local indexedbheap = require "util.indexedbheap"; @@ -28,6 +27,8 @@ local inet_pton = inet.pton; local _SOCKETINVALID = socket._SOCKETINVALID or -1; local new_id = require "util.id".short; local xpcall = require "util.xpcall".xpcall; +local sslconfig = require "util.sslconfig"; +local tls_impl = require "net.tls_luasec"; local poller = require "util.poll" local EEXIST = poller.EEXIST; @@ -91,6 +92,12 @@ local default_config = { __index = { --- How long to wait after getting the shutdown signal before forcefully tearing down every socket shutdown_deadline = 5; + + -- TCP Fast Open + tcp_fastopen = false; + + -- Defer accept until incoming data is available + tcp_defer_accept = false; }}; local cfg = default_config.__index; @@ -614,6 +621,42 @@ function interface:set_sslctx(sslctx) self._sslctx = sslctx; end +function interface:sslctx() + return self.tls_ctx +end + +function interface:ssl_info() + local sock = self.conn; + if not sock.info then return nil, "not-implemented"; end + return sock:info(); +end + +function interface:ssl_peercertificate() + local sock = self.conn; + if not sock.getpeercertificate then return nil, "not-implemented"; end + return sock:getpeercertificate(); +end + +function interface:ssl_peerverification() + local sock = self.conn; + if not sock.getpeerverification then return nil, { { "Chain verification not supported" } }; end + return sock:getpeerverification(); +end + +function interface:ssl_peerfinished() + local sock = self.conn; + if not sock.getpeerfinished then return nil, "not-implemented"; end + return sock:getpeerfinished(); +end + +function interface:ssl_exportkeyingmaterial(label, len, context) + local sock = self.conn; + if sock.exportkeyingmaterial then + return sock:exportkeyingmaterial(label, len, context); + end +end + + function interface:starttls(tls_ctx) if tls_ctx then self.tls_ctx = tls_ctx; end self.starttls = false; @@ -641,11 +684,7 @@ function interface:inittls(tls_ctx, now) self.starttls = false; self:debug("Starting TLS now"); self:updatenames(); -- Can't getpeer/sockname after wrap() - local ok, conn, err = pcall(luasec.wrap, self.conn, self.tls_ctx); - if not ok then - conn, err = ok, conn; - self:debug("Failed to initialize TLS: %s", err); - end + local conn, err = self.tls_ctx:wrap(self.conn); if not conn then self:on("disconnect", err); self:destroy(); @@ -656,8 +695,8 @@ function interface:inittls(tls_ctx, now) if conn.sni then if self.servername then conn:sni(self.servername); - elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then - conn:sni(self._server.hosts, true); + elseif next(self.tls_ctx._sni_contexts) ~= nil then + conn:sni(self.tls_ctx._sni_contexts, true); end end if self.extra and self.extra.tlsa and conn.settlsa then @@ -741,7 +780,6 @@ local function wrapsocket(client, server, read_size, listeners, tls_ctx, extra) end end - conn:updatenames(); return conn; end @@ -767,6 +805,7 @@ function interface:onacceptable() return; end local client = wrapsocket(conn, self, nil, self.listeners); + client:updatenames(); client:debug("New connection %s on server %s", client, self); client:defaultoptions(); client._writable = cfg.opportunistic_writes; @@ -885,6 +924,12 @@ local function wrapserver(conn, addr, port, listeners, config) log = logger.init(("serv%s"):format(new_id())); }, interface_mt); server:debug("Server %s created", server); + if cfg.tcp_fastopen then + server:setoption("tcp-fastopen", cfg.tcp_fastopen); + end + if type(cfg.tcp_defer_accept) == "number" then + server:setoption("tcp-defer-accept", cfg.tcp_defer_accept); + end server:add(true, false); return server; end @@ -908,6 +953,7 @@ end -- COMPAT local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx, extra) local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra); + client:updatenames(); if not client.peername then client.peername, client.peerport = addr, port; end @@ -941,9 +987,13 @@ local function addclient(addr, port, listeners, read_size, tls_ctx, typ, extra) if not conn then return conn, err; end local ok, err = conn:settimeout(0); if not ok then return ok, err; end + local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra) + if cfg.tcp_fastopen then + client:setoption("tcp-fastopen-connect", 1); + end local ok, err = conn:setpeername(addr, port); if not ok and err ~= "timeout" then return ok, err; end - local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra) + client:updatenames(); local ok, err = client:init(); if not client.peername then -- otherwise not set until connected @@ -1085,6 +1135,10 @@ return { cfg = setmetatable(newconfig, default_config); end; + tls_builder = function(basedir) + return sslconfig._new(tls_impl.new_context, basedir) + end, + -- libevent emulation event = { EV_READ = "r", EV_WRITE = "w", EV_READWRITE = "rw", EV_LEAVE = -1 }; addevent = function (fd, mode, callback) diff --git a/net/server_event.lua b/net/server_event.lua index c30181b8..d8f08c8d 100644 --- a/net/server_event.lua +++ b/net/server_event.lua @@ -47,11 +47,13 @@ local s_sub = string.sub local coroutine_wrap = coroutine.wrap local coroutine_yield = coroutine.yield -local has_luasec, ssl = pcall ( require , "ssl" ) +local has_luasec = pcall ( require , "ssl" ) local socket = require "socket" local levent = require "luaevent.core" local inet = require "util.net"; local inet_pton = inet.pton; +local sslconfig = require "util.sslconfig"; +local tls_impl = require "net.tls_luasec"; local socket_gettime = socket.gettime @@ -153,7 +155,7 @@ function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed _ = self.eventwrite and self.eventwrite:close( ) self.eventread, self.eventwrite = nil, nil local err - self.conn, err = ssl.wrap( self.conn, self._sslctx ) + self.conn, err = self._sslctx:wrap(self.conn) if err then self.fatalerror = err self.conn = nil -- cannot be used anymore @@ -168,8 +170,8 @@ function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed if self.conn.sni then if self.servername then self.conn:sni(self.servername); - elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then - self.conn:sni(self._server.hosts, true); + elseif next(self._sslctx._sni_contexts) ~= nil then + self.conn:sni(self._sslctx._sni_contexts, true); end end @@ -274,6 +276,34 @@ function interface_mt:pause() return self:_lock(self.nointerface, true, self.nowriting); end +function interface_mt:sslctx() + return self._sslctx +end + +function interface_mt:ssl_info() + local sock = self.conn; + if not sock.info then return nil, "not-implemented"; end + return sock:info(); +end + +function interface_mt:ssl_peercertificate() + local sock = self.conn; + if not sock.getpeercertificate then return nil, "not-implemented"; end + return sock:getpeercertificate(); +end + +function interface_mt:ssl_peerverification() + local sock = self.conn; + if not sock.getpeerverification then return nil, { { "Chain verification not supported" } }; end + return sock:getpeerverification(); +end + +function interface_mt:ssl_peerfinished() + local sock = self.conn; + if not sock.getpeerfinished then return nil, "not-implemented"; end + return sock:getpeerfinished(); +end + function interface_mt:resume() self:_lock(self.nointerface, false, self.nowriting); if self.readcallback and not self.eventread then @@ -924,6 +954,10 @@ return { add_task = add_task, watchfd = watchfd, + tls_builder = function(basedir) + return sslconfig._new(tls_impl.new_context, basedir) + end, + __NAME = SCRIPT_NAME, __DATE = LAST_MODIFIED, __AUTHOR = SCRIPT_AUTHOR, diff --git a/net/server_select.lua b/net/server_select.lua index eea850ce..651bdfde 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -47,15 +47,15 @@ local coroutine_yield = coroutine.yield --// extern libs //-- -local has_luasec, luasec = pcall ( require , "ssl" ) local luasocket = use "socket" or require "socket" local luasocket_gettime = luasocket.gettime local inet = require "util.net"; local inet_pton = inet.pton; +local sslconfig = require "util.sslconfig"; +local has_luasec, tls_impl = pcall(require, "net.tls_luasec"); --// extern lib methods //-- -local ssl_wrap = ( has_luasec and luasec.wrap ) local socket_bind = luasocket.bind local socket_select = luasocket.select @@ -359,6 +359,21 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport handler.sslctx = function ( ) return sslctx end + handler.ssl_info = function( ) + return socket.info and socket:info() + end + handler.ssl_peercertificate = function( ) + if not socket.getpeercertificate then return nil, "not-implemented"; end + return socket:getpeercertificate() + end + handler.ssl_peerverification = function( ) + if not socket.getpeerverification then return nil, { { "Chain verification not supported" } }; end + return socket:getpeerverification(); + end + handler.ssl_peerfinished = function( ) + if not socket.getpeerfinished then return nil, "not-implemented"; end + return socket:getpeerfinished(); + end handler.send = function( _, data, i, j ) return send( socket, data, i, j ) end @@ -652,7 +667,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport end out_put( "server.lua: attempting to start tls on " .. tostring( socket ) ) local oldsocket, err = socket - socket, err = ssl_wrap( socket, sslctx ) -- wrap socket + socket, err = sslctx:wrap(socket) -- wrap socket if not socket then out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") ) @@ -662,8 +677,8 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport if socket.sni then if self.servername then socket:sni(self.servername); - elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then - socket:sni(self.server().hosts, true); + elseif next(sslctx._sni_contexts) ~= nil then + socket:sni(sslctx._sni_contexts, true); end end @@ -1169,4 +1184,8 @@ return { removeserver = removeserver, get_backend = get_backend, changesettings = changesettings, + + tls_builder = function(basedir) + return sslconfig._new(tls_impl.new_context, basedir) + end, } diff --git a/net/tls_luasec.lua b/net/tls_luasec.lua new file mode 100644 index 00000000..2bedb5ab --- /dev/null +++ b/net/tls_luasec.lua @@ -0,0 +1,89 @@ +-- Prosody IM +-- Copyright (C) 2021 Prosody folks +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +--[[ +This file provides a shim abstraction over LuaSec, consolidating some code +which was previously spread between net.server backends, portmanager and +certmanager. + +The goal is to provide a more or less well-defined API on top of LuaSec which +abstracts away some of the things which are not needed and simplifies usage of +commonly used things (such as SNI contexts). Eventually, network backends +which do not rely on LuaSocket+LuaSec should be able to provide *this* API +instead of having to mimic LuaSec. +]] +local ssl = require "ssl"; +local ssl_newcontext = ssl.newcontext; +local ssl_context = ssl.context or require "ssl.context"; +local io_open = io.open; + +local context_api = {}; +local context_mt = {__index = context_api}; + +function context_api:set_sni_host(host, cert, key) + local ctx, err = self._builder:clone():apply({ + certificate = cert, + key = key, + }):build(); + if not ctx then + return false, err + end + + self._sni_contexts[host] = ctx._inner + + return true, nil +end + +function context_api:remove_sni_host(host) + self._sni_contexts[host] = nil +end + +function context_api:wrap(sock) + local ok, conn, err = pcall(ssl.wrap, sock, self._inner); + if not ok then + return nil, err + end + return conn, nil +end + +local function new_context(cfg, builder) + -- LuaSec expects dhparam to be a callback that takes two arguments. + -- We ignore those because it is mostly used for having a separate + -- set of params for EXPORT ciphers, which we don't have by default. + if type(cfg.dhparam) == "string" then + local f, err = io_open(cfg.dhparam); + if not f then return nil, "Could not open DH parameters: "..err end + local dhparam = f:read("*a"); + f:close(); + cfg.dhparam = function() return dhparam; end + end + + local inner, err = ssl_newcontext(cfg); + if not inner then + return nil, err + end + + -- COMPAT Older LuaSec ignores the cipher list from the config, so we have to take care + -- of it ourselves (W/A for #x) + if inner and cfg.ciphers then + local success; + success, err = ssl_context.setcipher(inner, cfg.ciphers); + if not success then + return nil, err + end + end + + return setmetatable({ + _inner = inner, + _builder = builder, + _sni_contexts = {}, + }, context_mt), nil +end + +return { + new_context = new_context, +}; |