aboutsummaryrefslogtreecommitdiffstats
path: root/net/resolvers/service.lua
blob: a7ce76a320b4bab1d0c2ec51a575ffd2c8452637 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
local adns = require "net.adns";
local basic = require "net.resolvers.basic";
local inet_pton = require "util.net".pton;
local idna_to_ascii = require "util.encodings".idna.to_ascii;

-- Method table shared by every service resolver instance
local methods = {};
-- Metatable that routes field lookups on resolver instances to `methods`
local resolver_mt = { __index = methods };

--- Build an iterator that hands out SRV records in connection-attempt order.
-- Records are offered in ascending priority. Within a group of records that
-- share a priority, each call performs a weighted random draw (running-sum
-- over the remaining weights) and removes the drawn record from the group.
-- The rrset passed in is only read; it is never reordered or modified.
-- @param rrset array of DNS records, each with an `srv` table holding
--   numeric `priority` and `weight` fields; may be nil or empty
-- @return function that returns the next `srv` table on each call, and
--   returns nothing once every record has been handed out
local function new_target_selector(rrset)
	local record_count = rrset and #rrset or 0;

	-- Sort a private list of the SRV payloads so that the caller's answer
	-- table (which may be shared or cached elsewhere) keeps its order.
	local records = {};
	for i = 1, record_count do
		records[i] = rrset[i].srv;
	end
	table.sort(records, function (a, b) return a.priority < b.priority end);

	local next_pos = 1; -- index of the first record not yet placed in a bucket
	local bucket, bucket_weight = {}, 0; -- remaining records of the current priority

	return function ()
		if #bucket == 0 then
			if next_pos > record_count then return; end -- Used up all records

			-- Start on the next priority: gather every consecutive record
			-- that shares it into the bucket.
			local priority = records[next_pos].priority;
			bucket_weight = 0;
			while next_pos <= record_count and records[next_pos].priority == priority do
				local record = records[next_pos];
				bucket[#bucket + 1] = record;
				bucket_weight = bucket_weight + record.weight;
				next_pos = next_pos + 1;
			end
		end

		-- Weighted draw: the first record whose running weight total
		-- reaches the random threshold is selected.
		local threshold, running_total = math.random(0, bucket_weight), 0;
		for i = 1, #bucket do
			running_total = running_total + bucket[i].weight;
			if running_total >= threshold then
				-- table.remove keeps the bucket a hole-free sequence
				local chosen = table.remove(bucket, i);
				bucket_weight = bucket_weight - chosen.weight;
				return chosen;
			end
		end
	end;
end

--- Find the next target to connect to, and pass it to cb().
-- The first call starts an SRV lookup for
-- "_<service>._<conn_type>.<hostname>". Once that completes, each SRV target
-- in turn gets its own basic resolver, and whatever that resolver yields is
-- forwarded to cb. When a basic resolver runs dry, the next SRV target is
-- tried. cb(nil) means there is nothing (more) to try; self.last_error then
-- holds the most recently recorded failure reason, if any.
-- @param cb function invoked with the next target's values, or with nil
function methods:next(cb)
	if self.resolver or self._get_next_target then
		if not self.resolver then -- Do we have a basic resolver currently?
			-- We don't, so fetch a new SRV target, create a new basic resolver for it
			local next_srv_target = self._get_next_target and self._get_next_target();
			if not next_srv_target then
				-- No more SRV targets left
				cb(nil);
				return;
			end
			-- Create a new basic resolver for this SRV target
			self.resolver = basic.new(next_srv_target.target, next_srv_target.port, self.conn_type, self.extra);
		end
		-- Look up the next (basic) target from the current target's resolver
		self.resolver:next(function (...)
			if self.resolver then
				self.last_error = self.resolver.last_error;
			end
			if ... == nil then
				-- Current basic resolver is exhausted: drop it and move on
				-- to the next SRV target (if any)
				self.resolver = nil;
				self:next(cb);
			else
				cb(...);
			end
		end);
		return;
	elseif self.in_progress then
		-- The SRV lookup was already started and left us with neither a
		-- resolver nor a target selector, so there is nothing to offer
		cb(nil);
		return;
	end

	if not self.hostname then
		self.last_error = "hostname failed IDNA";
		cb(nil);
		return;
	end

	self.in_progress = true;

	-- Re-enter next() once the lookup result has been recorded on self
	local function ready()
		self:next(cb);
	end

	-- Resolve DNS to target list
	local dns_resolver = adns.resolver();
	dns_resolver:lookup(function (answer, err)
		if not answer and not err then
			-- net.adns returns nil if there are zero records or nxdomain
			answer = {};
		end
		if answer then
			-- Reject answers that failed validation before anything else.
			-- A bogus answer is normally also "not secure", so testing
			-- security first would let it through whenever self.extra is set.
			if answer.bogus then
				self.last_error = "Validation error in SRV lookup";
				ready();
				return;
			elseif self.extra and not answer.secure then
				-- Unvalidated answer: DANE must not be used with it
				self.extra.use_dane = false;
			end

			if #answer == 0 then
				if self.extra and self.extra.default_port then
					-- No SRV records: fall back to the hostname itself on the default port
					self.resolver = basic.new(self.hostname, self.extra.default_port, self.conn_type, self.extra);
				else
					self.last_error = "zero SRV records found";
				end
				ready();
				return;
			end

			if #answer == 1 and answer[1].srv.target == "." then -- No service here
				self.last_error = "service explicitly unavailable";
				ready();
				return;
			end

			self._get_next_target = new_target_selector(answer);
		else
			self.last_error = err;
		end
		ready();
	end, "_" .. self.service .. "._" .. self.conn_type .. "." .. self.hostname, "SRV", "IN");
end

--- Create a resolver for `service` on `hostname`.
-- When the hostname is an IP address (optionally with one enclosing
-- character on each side stripped, as in "[...]") and extra.default_port is
-- set, a basic resolver for that address and port is returned directly.
-- Otherwise a service resolver object is returned.
-- @param hostname string host name or IP address literal
-- @param service string service name used to build the SRV query
-- @param conn_type string connection type; defaults to "tcp" in the
--   service resolver
-- @param extra optional table of options (default_port is read here)
-- @return a resolver object
local function new(hostname, service, conn_type, extra)
	local ip_literal = inet_pton(hostname);
	if not ip_literal and hostname:sub(1,1) == '[' then
		-- Retry with the first and last characters removed
		ip_literal = inet_pton(hostname:sub(2,-2));
	end

	if ip_literal then
		local default_port = extra and extra.default_port;
		if default_port then
			return basic.new(hostname, default_port, conn_type, extra);
		end
	end

	local resolver = {
		hostname = idna_to_ascii(hostname);
		service = service;
		conn_type = conn_type or "tcp";
		extra = extra;
	};
	return setmetatable(resolver, resolver_mt);
end

-- Module exports: only the constructor is public
return {
	new = new;
};