aboutsummaryrefslogtreecommitdiffstats
path: root/net/stun.lua
blob: a718d4ef5b11f9d8b22bd2fe008441137964f88b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
local base64 = require "util.encodings".base64;
local hashes = require "util.hashes";
local net = require "util.net";
local random = require "util.random";
local struct = require "util.struct";
local bit32 = require"util.bitcompat";
local sxor = require"util.strbitop".sxor;

--- Public helpers

-- Following draft-uberti-behave-turn-rest-00, derive a username/password
-- pair from a shared 'secret' that can be used to auth to a TURN server.
--
-- The username is the expiry timestamp (current time + ttl), followed by
-- ":" and opt_username when one is given. The password is the base64
-- encoding of the HMAC-SHA1 of that username, keyed with the secret.
-- ttl defaults to 86400 when not supplied.
--
-- Returns: username, password, ttl (the value actually used)
local function get_user_pass_from_secret(secret, ttl, opt_username)
	local lifetime = ttl or 86400;
	local expires_at = os.time() + lifetime;
	-- Both format results are strings (always truthy), so and/or is safe here
	local username = opt_username
		and ("%d:%s"):format(expires_at, opt_username)
		or ("%d"):format(expires_at);
	local signature = hashes.hmac_sha1(secret, username);
	local password = base64.encode(signature);
	return username, password, lifetime;
end

-- Following RFC 8489 9.2, convert credentials to a HMAC key for signing.
-- The key is the MD5 of "username:realm:password".
local function get_long_term_auth_key(realm, username, password)
	local key_material = username .. ":" .. realm .. ":" .. password;
	return hashes.md5(key_material);
end

--- Packet building/parsing

-- Methods shared by all packet objects; packets reach them through the
-- __index field of their metatable.
local packet_methods = {};
local packet_mt = {};
packet_mt.__index = packet_methods;

-- The four bytes 0x21 0x12 0xA4 0x42 (written here as decimal escapes)
local magic_cookie = "\033\018\164\066";

-- Method name -> numeric method code
local methods = {
	binding = 0x001,
	-- TURN
	allocate = 0x003,
	refresh = 0x004,
	send = 0x006,
	data = 0x007,
	create_permission = 0x008,
	channel_bind = 0x009,
};

-- Two-way map: name -> code and code -> name
local method_lookup = {};
for method_name, method_code in pairs(methods) do
	method_lookup[method_code] = method_name;
	method_lookup[method_name] = method_code;
end

-- Class name -> numeric class code
local classes = {
	request = 0,
	indication = 1,
	success = 2,
	error = 3,
};

-- Two-way map: name -> code and code -> name
local class_lookup = {};
for class_name, class_code in pairs(classes) do
	class_lookup[class_code] = class_name;
	class_lookup[class_name] = class_code;
end

-- Attribute name -> numeric attribute type
local attributes = {
	["mapped-address"] = 0x0001,
	["username"] = 0x0006,
	["message-integrity"] = 0x0008,
	["error-code"] = 0x0009,
	["unknown-attributes"] = 0x000A,
	["realm"] = 0x0014,
	["nonce"] = 0x0015,
	["message-integrity-sha256"] = 0x001C,
	["password-algorithm"] = 0x001D,
	["userhash"] = 0x001E,
	["xor-mapped-address"] = 0x0020,
	["password-algorithms"] = 0x8002,
	["alternate-domains"] = 0x8003,
	["software"] = 0x8022,
	["alternate-server"] = 0x8023,
	["fingerprint"] = 0x8028,

	-- TURN
	["requested-transport"] = 0x0019,
};

-- Two-way map: name -> type and type -> name
local attribute_lookup = {};
for attr_name, attr_code in pairs(attributes) do
	attribute_lookup[attr_code] = attr_name;
	attribute_lookup[attr_name] = attr_code;
end

-- Build the packet header: the type and the given length packed as two
-- big-endian 16-bit integers, followed by the magic cookie and the
-- packet's transaction id.
-- Raises an error unless the transaction id is exactly 12 bytes long.
function packet_methods:serialize_header(length)
	local txn_id = self.transaction_id;
	assert(#txn_id == 12, "invalid transaction id length");
	local type_and_length = struct.pack(">I2I2", self.type, length);
	return type_and_length .. magic_cookie .. txn_id;
end

-- Serialize the whole packet: the header (carrying the byte length of the
-- concatenated attributes) followed by the attributes themselves.
function packet_methods:serialize()
	local body = table.concat(self.attributes);
	local header = self:serialize_header(#body);
	return header .. body;
end

-- True when both class bits (0x0100 and 0x0010) of the type are clear
function packet_methods:is_request()
	local class_bits = bit32.band(self.type, 0x0110);
	return class_bits == 0x0000;
end

-- True when, of the two class bits (0x0100 and 0x0010), only 0x0010 is set
function packet_methods:is_indication()
	local class_bits = bit32.band(self.type, 0x0110);
	return class_bits == 0x0010;
end

-- True when, of the two class bits (0x0100 and 0x0010), only 0x0100 is set
function packet_methods:is_success_resp()
	local class_bits = bit32.band(self.type, 0x0110);
	return class_bits == 0x0100;
end

-- True when both class bits (0x0100 and 0x0010) of the type are set
function packet_methods:is_err_resp()
	local class_bits = bit32.band(self.type, 0x0110);
	return class_bits == 0x0110;
end

-- Extract the method number from the packet type. The method bits are
-- interleaved with the two class bits, so the three groups are masked out
-- separately and shifted down to close the gaps.
-- Returns: the numeric method, and its name (nil if not a known method)
function packet_methods:get_method()
	local msg_type = self.type;
	local high_bits = bit32.rshift(bit32.band(msg_type, 0x3E00), 2);
	local mid_bits = bit32.rshift(bit32.band(msg_type, 0x00E0), 1);
	local low_bits = bit32.band(msg_type, 0x000F);
	local method = bit32.bor(high_bits, mid_bits, low_bits);
	return method, method_lookup[method];
end

-- Extract the class number from the packet type: type bit 0x0100 becomes
-- class bit 1 and type bit 0x0010 becomes class bit 0.
-- Returns: the numeric class (0-3), and its name
function packet_methods:get_class()
	local msg_type = self.type;
	local upper_bit = bit32.rshift(bit32.band(msg_type, 0x0100), 7);
	local lower_bit = bit32.rshift(bit32.band(msg_type, 0x0010), 4);
	local class = bit32.bor(upper_bit, lower_bit);
	return class, class_lookup[class];
end

-- Set the packet type from a method and a class.
--
-- method: a numeric method code, or a method name (case-insensitive)
-- class:  a numeric class code (0-3), or a class name (case-insensitive)
--
-- Raises an error if a name is given that is not known.
--
-- The method is 12 bits wide and is interleaved with the two class bits:
--   method bits 0-3  -> type bits 0-3
--   class bit 0      -> type bit 4  (0x0010)
--   method bits 4-6  -> type bits 5-7
--   class bit 1      -> type bit 8  (0x0100)
--   method bits 7-11 -> type bits 9-13
-- This is the inverse of get_method()/get_class().
function packet_methods:set_type(method, class)
	if type(method) == "string" then
		method = assert(method_lookup[method:lower()], "unknown method: "..method);
	end
	if type(class) == "string" then
		-- Match names case-insensitively, the same way as method names
		class = assert(class_lookup[class:lower()], "unknown class: "..class);
	end
	self.type = bit32.bor(
		-- Mask is 0x0F80 (method bits 7-11): anything higher would spill
		-- into the top two bits of the 16-bit type field, which
		-- get_method() does not read back
		bit32.lshift(bit32.band(method, 0x0F80), 2),
		bit32.lshift(bit32.band(method, 0x0070), 1),
		bit32.band(method, 0x000F),
		bit32.lshift(bit32.band(class, 0x0002), 7),
		bit32.lshift(bit32.band(class, 0x0001), 4)
	);
end

-- Encode one attribute: type and value length as two big-endian 16-bit
-- integers, then the value, then zero bytes to pad the value up to a
-- multiple of four bytes. The length field does not count the padding.
local function _serialize_attribute(attr_type, value)
	local value_len = #value;
	local pad_len = (4 - value_len) % 4;
	local header = struct.pack(">I2I2", attr_type, value_len);
	return header .. value .. ("\0"):rep(pad_len);
end

-- Append an attribute to the packet.
-- attr_type: a numeric attribute type, or a known attribute name
-- value: the raw attribute value (string)
-- Raises an error if attr_type is a name that is not known.
function packet_methods:add_attribute(attr_type, value)
	local attr_code = attr_type;
	if type(attr_type) == "string" then
		attr_code = assert(attributes[attr_type], "unknown attribute: "..attr_type);
	end
	local encoded = _serialize_attribute(attr_code, value);
	local list = self.attributes;
	list[#list + 1] = encoded;
end

-- Parse a serialized packet into this packet object, replacing its type,
-- transaction id and attributes.
--
-- bytes: the complete packet as a string (20 byte header + attributes)
--
-- Each attribute is stored as its 4-byte header followed by its value,
-- without the trailing padding.
--
-- Returns: self
-- Raises an error if the packet is shorter than the header, if the length
-- field does not match the data, if the magic cookie is wrong, or if an
-- attribute is truncated.
function packet_methods:deserialize(bytes)
	assert(#bytes >= 20, "packet truncated in header");
	local msg_type, len, cookie = struct.unpack(">I2I2I4", bytes);
	assert(#bytes == (len + 20), "incorrect packet length");
	assert(cookie == 0x2112A442, "invalid magic cookie");
	self.type = msg_type;
	self.transaction_id = bytes:sub(9, 20);
	self.attributes = {};
	local pos = 21;
	while pos < #bytes do
		local attr_hdr = bytes:sub(pos, pos+3);
		assert(#attr_hdr == 4, "packet truncated in attribute header");
		local _, attr_len = struct.unpack(">I2I2", attr_hdr);
		-- For a zero-length attribute this yields the empty string, so no
		-- special case is needed and the position advances by exactly the
		-- 4 header bytes
		local data = bytes:sub(pos + 4, pos + 3 + attr_len);
		assert(#data == attr_len, "packet truncated in attribute value");
		table.insert(self.attributes, attr_hdr..data);
		-- Values are padded to a multiple of four bytes; skip the padding
		local n_padding = (4 - attr_len)%4;
		pos = pos + 4 + attr_len + n_padding;
	end
	return self;
end

-- Find the first attribute of the given type and return its value.
--
-- attr_type: a numeric attribute type, or an attribute name
--            (case-insensitive)
--
-- Returns: the attribute value as a string, or nothing if the packet has
-- no such attribute.
-- Raises an error if attr_type is a name that is not known.
function packet_methods:get_attribute(attr_type)
	if type(attr_type) == "string" then
		attr_type = assert(attribute_lookup[attr_type:lower()], "unknown attribute: "..attr_type);
	end
	for _, attribute in ipairs(self.attributes) do
		local this_type, value_len = struct.unpack(">I2I2", attribute);
		if this_type == attr_type then
			-- Use the encoded length rather than "everything after the
			-- header": attributes added via add_attribute() carry trailing
			-- zero padding which is not part of the value
			return attribute:sub(5, 4 + value_len);
		end
	end
end

-- Indexed by the family byte of an address attribute (1 and 2)
local addr_families = { "IPv4", "IPv6" };

-- Decode the "mapped-address" attribute, if present.
-- Returns nothing when the packet has no such attribute, otherwise a table
-- with fields: family ("IPv4", "IPv6" or "unknown"), port, and address
-- (the result of net.ntop() on the address bytes).
function packet_methods:get_mapped_address()
	local attr = self:get_attribute("mapped-address");
	if not attr then return; end
	-- Layout: one skipped byte, family byte, big-endian 16-bit port, address
	local family_code, port = struct.unpack("x>BI2", attr);
	local packed_addr = attr:sub(5);
	local result = {};
	result.family = addr_families[family_code] or "unknown";
	result.port = port;
	result.address = net.ntop(packed_addr);
	return result;
end

-- Decode the "xor-mapped-address" attribute, if present.
-- The port is XORed with 0x2112, and the address bytes are XORed (via sxor)
-- with the magic cookie followed by the transaction id.
-- Returns nothing when the packet has no such attribute, otherwise a table
-- with fields: family ("IPv4", "IPv6" or "unknown"), port, address (the
-- result of net.ntop() on the un-XORed bytes) and address_raw (the address
-- bytes exactly as they appear in the attribute).
function packet_methods:get_xor_mapped_address()
	local attr = self:get_attribute("xor-mapped-address");
	if not attr then return; end
	-- Layout: one skipped byte, family byte, big-endian 16-bit port, address
	local family_code, xored_port = struct.unpack("x>BI2", attr);
	local raw_addr = attr:sub(5);
	local xor_key = magic_cookie..self.transaction_id;
	local plain_addr = sxor(raw_addr, xor_key);
	local result = {};
	result.family = addr_families[family_code] or "unknown";
	result.port = bit32.bxor(xored_port, 0x2112);
	result.address = net.ntop(plain_addr);
	result.address_raw = raw_addr;
	return result;
end

-- Sign the packet: append a "message-integrity" attribute holding the
-- HMAC-SHA1 (keyed with 'key') of the packet as serialized so far.
-- Raises an error if the computed hash is not 20 bytes long.
function packet_methods:add_message_integrity(key)
	-- A placeholder attribute of the final size is added first, so that the
	-- length in the serialized header already accounts for it
	local placeholder = ("\0"):rep(20);
	self:add_attribute("message-integrity", placeholder);
	-- Drop the last 24 bytes (4 byte attribute header + 20 byte value),
	-- i.e. the placeholder itself, from the data to be signed
	local serialized = self:serialize();
	local signed_part = serialized:sub(1, -25);
	local digest = hashes.hmac_sha1(key, signed_part, false);
	-- Swap the placeholder for the real value
	table.remove(self.attributes);
	assert(#digest == 20, "invalid hash length");
	self:add_attribute("message-integrity", digest);
end

do
	-- Transport name -> code placed in the first byte of the attribute
	local transports = { udp = 0x11 };

	-- Append a "requested-transport" attribute for the named transport.
	-- The value is the transport code followed by three zero bytes.
	-- Raises an error for any transport not listed above.
	function packet_methods:add_requested_transport(transport)
		local code = assert(transports[transport], "unsupported transport: "..tostring(transport));
		local value = string.char(code, 0x00, 0x00, 0x00);
		self:add_attribute("requested-transport", value);
	end
end

-- Decode the "error-code" attribute, if present.
-- Returns nil when the packet has no such attribute, otherwise the error
-- code (class * 100 + number) and the text that follows the first four
-- bytes of the attribute.
function packet_methods:get_error()
	local attr = self:get_attribute("error-code");
	if attr then
		-- The class is the low three bits of byte 3, the number is byte 4
		local class_byte, number = attr:byte(3, 4);
		local class = bit32.band(0x07, class_byte);
		local text = attr:sub(5);
		return (class*100)+number, text;
	end
	return nil;
end

-- Create a new packet with 12 bytes from random.bytes() as its transaction
-- id and no attributes.
-- method defaults to "binding" and class to "request"; both are passed on
-- to set_type().
local function new_packet(method, class)
	local packet = {
		transaction_id = random.bytes(12);
		length = 0;
		attributes = {};
	};
	setmetatable(packet, packet_mt);
	packet:set_type(method or "binding", class or "request");
	return packet;
end

-- Module exports: the packet constructor and the credential helpers
return {
	get_long_term_auth_key = get_long_term_auth_key,
	get_user_pass_from_secret = get_user_pass_from_secret,
	new_packet = new_packet,
};