aboutsummaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
authorMatthew Wild <mwild1@gmail.com>2022-03-18 16:09:22 +0000
committerMatthew Wild <mwild1@gmail.com>2022-03-18 16:09:22 +0000
commit668bd38c71bacc78bbd44154f3305be6450e0729 (patch)
treec4809cd7eeafa817e6f098a6a69605d7ef9adfd3 /net
parent47dbb928887f1ece83ca7005eb714a4e937113e5 (diff)
downloadprosody-668bd38c71bacc78bbd44154f3305be6450e0729.tar.gz
prosody-668bd38c71bacc78bbd44154f3305be6450e0729.zip
net.resolvers.basic: Refactor to remove code duplication
...and prepare for Happy Eyeballs
Diffstat (limited to 'net')
-rw-r--r--net/resolvers/basic.lua152
1 files changed, 72 insertions, 80 deletions
diff --git a/net/resolvers/basic.lua b/net/resolvers/basic.lua
index 305bce76..a5ea5dee 100644
--- a/net/resolvers/basic.lua
+++ b/net/resolvers/basic.lua
@@ -2,13 +2,51 @@ local adns = require "net.adns";
local inet_pton = require "util.net".pton;
local inet_ntop = require "util.net".ntop;
local idna_to_ascii = require "util.encodings".idna.to_ascii;
-local unpack = table.unpack or unpack; -- luacheck: ignore 113
+local promise = require "util.promise";
+local t_move = require "util.table".move;
local methods = {};
local resolver_mt = { __index = methods };
-- FIXME RFC 6724
+local function do_dns_lookup(self, dns_resolver, record_type, name)
+ return promise.new(function (resolve, reject)
+ local ipv = (record_type == "A" and "4") or (record_type == "AAAA" and "6") or nil;
+ if ipv and self.extra["use_ipv"..ipv] == false then
+ return reject(("IPv%s disabled - %s lookup skipped"):format(ipv, record_type));
+ elseif record_type == "TLSA" and self.extra.use_dane ~= true then
+ return reject("DANE disabled - TLSA lookup skipped");
+ end
+ dns_resolver:lookup(function (answer, err)
+ if not answer then
+ return reject(err);
+ elseif answer.bogus then
+ return reject(("Validation error in %s lookup"):format(record_type));
+ elseif answer.status and #answer == 0 then
+ return reject(("%s in %s lookup"):format(answer.status, record_type));
+ end
+
+ local targets = { secure = answer.secure };
+ for _, record in ipairs(answer) do
+ if ipv then
+ table.insert(targets, { self.conn_type..ipv, record[record_type:lower()], self.port, self.extra });
+ else
+ table.insert(targets, record[record_type:lower()]);
+ end
+ end
+ return resolve(targets);
+ end, name, record_type, "IN");
+ end);
+end
+
+local function merge_targets(ipv4_targets, ipv6_targets)
+ local result = { secure = ipv4_targets.secure and ipv6_targets.secure };
+ t_move(ipv6_targets, 1, #ipv6_targets, 1, result);
+ t_move(ipv4_targets, 1, #ipv4_targets, #result+1, result);
+ return result;
+end
+
-- Find the next target to connect to, and
-- pass it to cb()
function methods:next(cb)
@@ -18,7 +56,7 @@ function methods:next(cb)
return;
end
local next_target = table.remove(self.targets, 1);
- cb(unpack(next_target, 1, 4));
+ cb(next_target[1], next_target[2], next_target[3], next_target[4]);
return;
end
@@ -28,91 +66,45 @@ function methods:next(cb)
return;
end
- local secure = true;
- local tlsa = {};
- local targets = {};
- local n = 3;
- local function ready()
- n = n - 1;
- if n > 0 then return; end
- self.targets = targets;
+ -- Resolve DNS to target list
+ local dns_resolver = adns.resolver();
+
+ local dns_lookups = {
+ ipv4 = do_dns_lookup(self, dns_resolver, "A", self.hostname);
+ ipv6 = do_dns_lookup(self, dns_resolver, "AAAA", self.hostname);
+ tlsa = do_dns_lookup(self, dns_resolver, "TLSA", ("_%d._%s.%s"):format(self.port, self.conntype, self.hostname));
+ };
+
+ promise.all_settled(dns_lookups):next(function (dns_results)
+ -- Combine targets, assign to self.targets, self:next(cb)
+ local have_ipv4 = dns_results.ipv4.status == "fulfilled";
+ local have_ipv6 = dns_results.ipv6.status == "fulfilled";
+
+ if have_ipv4 and have_ipv6 then
+ self.targets = merge_targets(dns_results.ipv4.value, dns_results.ipv6.value);
+ elseif have_ipv4 then
+ self.targets = dns_results.ipv4.value;
+ elseif have_ipv6 then
+ self.targets = dns_results.ipv6.value;
+ else
+ self.targets = {};
+ end
+
if self.extra and self.extra.use_dane then
- if secure and tlsa[1] then
- self.extra.tlsa = tlsa;
+ if self.targets.secure and dns_results.tlsa.status == "fulfilled" then
+ self.extra.tlsa = dns_results.tlsa.value;
self.extra.dane_hostname = self.hostname;
else
self.extra.tlsa = nil;
self.extra.dane_hostname = nil;
end
end
- self:next(cb);
- end
-
- -- Resolve DNS to target list
- local dns_resolver = adns.resolver();
- if not self.extra or self.extra.use_ipv4 ~= false then
- dns_resolver:lookup(function (answer, err)
- if answer then
- secure = secure and answer.secure;
- for _, record in ipairs(answer) do
- table.insert(targets, { self.conn_type.."4", record.a, self.port, self.extra });
- end
- if answer.bogus then
- self.last_error = "Validation error in A lookup";
- elseif answer.status then
- self.last_error = answer.status .. " in A lookup";
- end
- else
- self.last_error = err;
- end
- ready();
- end, self.hostname, "A", "IN");
- else
- ready();
- end
-
- if not self.extra or self.extra.use_ipv6 ~= false then
- dns_resolver:lookup(function (answer, err)
- if answer then
- secure = secure and answer.secure;
- for _, record in ipairs(answer) do
- table.insert(targets, { self.conn_type.."6", record.aaaa, self.port, self.extra });
- end
- if answer.bogus then
- self.last_error = "Validation error in AAAA lookup";
- elseif answer.status then
- self.last_error = answer.status .. " in AAAA lookup";
- end
- else
- self.last_error = err;
- end
- ready();
- end, self.hostname, "AAAA", "IN");
- else
- ready();
- end
-
- if self.extra and self.extra.use_dane == true then
- dns_resolver:lookup(function (answer, err)
- if answer then
- secure = secure and answer.secure;
- for _, record in ipairs(answer) do
- table.insert(tlsa, record.tlsa);
- end
- if answer.bogus then
- self.last_error = "Validation error in TLSA lookup";
- elseif answer.status then
- self.last_error = answer.status .. " in TLSA lookup";
- end
- else
- self.last_error = err;
- end
- ready();
- end, ("_%d._tcp.%s"):format(self.port, self.hostname), "TLSA", "IN");
- else
- ready();
- end
+ self:next(cb);
+ end):catch(function (err)
+ self.last_error = err;
+ self.targets = {};
+ end);
end
local function new(hostname, port, conn_type, extra)
@@ -137,7 +129,7 @@ local function new(hostname, port, conn_type, extra)
hostname = ascii_host;
port = port;
conn_type = conn_type;
- extra = extra;
+ extra = extra or {};
targets = targets;
}, resolver_mt);
end