1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
|
-- Prosody IM
-- Copyright (C) 2012 Florian Zeitz
-- Copyright (C) 2014 Daurnimator
--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
local softreq = require "util.dependencies".softreq;
local log = require "util.logger".init "websocket.frames";
local random_bytes = require "util.random".bytes;
local bit = softreq"bit" or softreq"bit32";
if not bit then log("error", "No bit module found. Either LuaJIT 2, lua-bitop or Lua 5.2 is required"); end
local band = bit.band;
local bor = bit.bor;
local bxor = bit.bxor;
local lshift = bit.lshift;
local rshift = bit.rshift;
local t_concat = table.concat;
local s_byte = string.byte;
local s_char= string.char;
local s_sub = string.sub;
-- Reads an unsigned 16-bit big-endian integer from `str` whose first byte is
-- at position `pos`. Errors if fewer than two bytes are available there.
local function read_uint16be(str, pos)
	local hi, lo = s_byte(str, pos, pos + 1);
	return lo + hi * 256;
end
-- Reads an unsigned 64-bit big-endian integer from `str` whose first byte is
-- at position `pos`. Errors if fewer than eight bytes are available there.
-- The value is assembled with arithmetic instead of the bit library: LuaJIT's
-- `bit` and Lua 5.2's `bit32` operate on 32-bit values only, so shifting by
-- 32 or more bits either wraps (LuaJIT uses the shift count modulo 32) or
-- yields 0 (bit32), and LuaJIT returns signed results for bit 31.
-- FIXME: where numbers are doubles (Lua 5.1/5.2/LuaJIT), values above 2^53
-- lose precision. On Lua 5.3+ integers the result is exact while the most
-- significant bit is 0, which RFC 6455 requires for payload lengths.
local function read_uint64be(str, pos)
	local l1, l2, l3, l4, l5, l6, l7, l8 = s_byte(str, pos, pos+7);
	local value = l1;
	value = value * 256 + l2;
	value = value * 256 + l3;
	value = value * 256 + l4;
	value = value * 256 + l5;
	value = value * 256 + l6;
	value = value * 256 + l7;
	value = value * 256 + l8;
	return value;
end
-- Encodes `x` as two big-endian bytes. Assumes 0 <= x <= 0xFFFF; a larger
-- value makes the high byte exceed 255 and string.char raises an error.
local function pack_uint16be(x)
	local high = rshift(x, 8);
	local low = band(x, 0xFF);
	return s_char(high, low);
end
-- Returns the 8 bits of `x` starting at bit `n` (0 = least significant), as
-- computed by the loaded 32-bit bit library.
local function get_byte(x, n)
	local shifted = rshift(x, n);
	return band(shifted, 0xFF);
end
-- Encodes the non-negative integer `x` as eight big-endian bytes.
-- The bit library only handles 32-bit values (LuaJIT's `bit` takes shift
-- counts modulo 32, `bit32` returns 0 for shifts >= 32), so shifting `x`
-- directly by 32..56 bits corrupts the output. Instead `x` is split
-- arithmetically into high and low 32-bit halves, and each half is encoded
-- with 32-bit-safe shifts. Exact for x < 2^53 on double-only Lua builds.
local function pack_uint64be(x)
	local lo = x % 0x100000000;
	local hi = (x - lo) / 0x100000000; -- exact: x - lo is a multiple of 2^32
	return s_char(get_byte(hi, 24), get_byte(hi, 16), get_byte(hi, 8), band(hi, 0xFF),
		get_byte(lo, 24), get_byte(lo, 16), get_byte(lo, 8), band(lo, 0xFF));
end
-- Decodes the WebSocket frame header at the start of the string `frame`.
-- Returns nothing if `frame` is too short to hold the complete header.
-- Otherwise returns a table with FIN, RSV1, RSV2, RSV3 and MASK (booleans),
-- opcode, length (the payload length, decoded from the 7-bit field or the
-- 16/64-bit extended field), and key (array of 4 mask bytes, only when MASK
-- is set), plus the header size in bytes as a second value.
local function parse_frame_header(frame)
	local frame_len = #frame;
	if frame_len < 2 then return; end
	local first, second = s_byte(frame, 1, 2);
	local header = {
		opcode = band(first, 0x0F);
		FIN = band(first, 0x80) > 0;
		RSV1 = band(first, 0x40) > 0;
		RSV2 = band(first, 0x20) > 0;
		RSV3 = band(first, 0x10) > 0;
		MASK = band(second, 0x80) > 0;
		length = band(second, 0x7F);
	};

	-- 7-bit length values 126 and 127 announce a 2- or 8-byte extended length.
	local ext_len;
	if header.length == 127 then
		ext_len = 8;
	elseif header.length == 126 then
		ext_len = 2;
	else
		ext_len = 0;
	end
	local mask_len = header.MASK and 4 or 0;

	local header_length = 2 + ext_len + mask_len;
	if frame_len < header_length then return; end

	if ext_len == 8 then
		header.length = read_uint64be(frame, 3);
	elseif ext_len == 2 then
		header.length = read_uint16be(frame, 3);
	end

	if mask_len > 0 then
		-- The masking key follows the (optional) extended length field.
		local key_start = 3 + ext_len;
		header.key = { s_byte(frame, key_start, key_start + 3) };
	end
	return header, header_length;
end
-- XORs bytes `from`..`to` of the string `str` with the byte array `key`,
-- cycling through the key so that byte `from` is combined with key[1].
-- `from` defaults to 1 and `to` defaults to #str; negative positions count
-- back from the end of `str`. Returns the resulting bytes as a new string.
-- TODO: optimize (one single-character string is created per byte)
local function apply_mask(str, key, from, to)
	local str_len = #str;
	from = from or 1;
	to = to or str_len;
	if from < 0 then from = str_len + from + 1; end -- negative indices
	if to < 0 then to = str_len + to + 1; end
	local key_len = #key;
	local out = {};
	for i = from, to do
		local offset = i - from; -- 0-based position within the masked range
		out[offset + 1] = s_char(bxor(key[offset % key_len + 1], s_byte(str, i)));
	end
	return t_concat(out);
end
-- Extracts the payload of a frame whose header has already been decoded.
-- `pos` is the index of the first payload byte within `frame`; `header.length`
-- bytes are taken from there, and unmasked with `header.key` when
-- `header.MASK` is set.
local function parse_frame_body(frame, header, pos)
	local last = pos + header.length - 1;
	if not header.MASK then
		return frame:sub(pos, last);
	end
	return apply_mask(frame, header.key, pos, last);
end
-- Parses one complete frame from the start of the string `frame`.
-- Returns nothing if the header or payload is not yet fully present.
-- Otherwise returns the header table (with the unmasked payload stored in
-- `.data`) and the total number of bytes the frame occupies.
local function parse_frame(frame)
	local header, header_length = parse_frame_header(frame);
	if header == nil then return; end
	local frame_length = header_length + header.length;
	if #frame < frame_length then return; end
	header.data = parse_frame_body(frame, header, header_length + 1);
	return header, frame_length;
end
-- Serializes the frame described by the table `desc` into a string.
-- Fields of `desc`:
--   opcode  integer 0x0-0xF (required)
--   FIN, RSV1, RSV2, RSV3  set the corresponding header bits when truthy
--   MASK    when truthy, the payload is masked and the key is sent
--   key     optional array of 4 byte values used as the masking key;
--           4 random bytes are used when absent
--   data    payload string (defaults to "")
-- Raises an error for an invalid opcode, or for a control frame
-- (opcode >= 0x8) whose payload exceeds 125 bytes.
local function build_frame(desc)
	local data = desc.data or "";
	assert(desc.opcode and desc.opcode >= 0 and desc.opcode <= 0xF, "Invalid WebSocket opcode");
	if desc.opcode >= 0x8 then
		-- RFC 6455 5.5
		assert(#data <= 125, "WebSocket control frames MUST have a payload length of 125 bytes or less.");
	end
	local b1 = bor(desc.opcode,
		desc.FIN and 0x80 or 0,
		desc.RSV1 and 0x40 or 0,
		desc.RSV2 and 0x20 or 0,
		desc.RSV3 and 0x10 or 0);
	local b2 = #data;
	local length_extra;
	if b2 <= 125 then -- 7-bit length
		length_extra = "";
	elseif b2 <= 0xFFFF then -- 2-byte length
		b2 = 126;
		length_extra = pack_uint16be(#data);
	else -- 8-byte length
		b2 = 127;
		length_extra = pack_uint64be(#data);
	end
	local key = ""
	if desc.MASK then
		local key_a = desc.key
		if key_a then
			-- Index explicitly instead of calling the global `unpack`, which
			-- is not available in Lua 5.3+ (it became table.unpack).
			key = s_char(key_a[1], key_a[2], key_a[3], key_a[4]);
		else
			key = random_bytes(4);
		end
		-- Derive the mask bytes from the 4-byte key actually placed in the
		-- header, so the applied mask always matches the transmitted key.
		key_a = { s_byte(key, 1, 4) };
		b2 = bor(b2, 0x80);
		data = apply_mask(data, key_a);
	end
	return s_char(b1, b2) .. length_extra .. key .. data
end
-- Splits the payload of a Close frame into its status code and reason.
-- Returns two values: the 16-bit big-endian code (nil when the payload is
-- shorter than 2 bytes) and the reason text following it (nil when absent).
local function parse_close(data)
	if #data < 2 then
		return nil, nil;
	end
	local code = read_uint16be(data, 1);
	local message;
	if #data > 2 then
		message = s_sub(data, 3);
	end
	return code, message;
end
-- Builds a Close control frame (opcode 0x8, FIN set) and returns it as a string.
-- code: optional 16-bit status code. When nil, the frame carries an empty
--   body, which RFC 6455 5.5.1 permits. This mirrors parse_close, which
--   returns a nil code for an empty body.
-- message: optional close reason appended after the code. It is limited to
--   123 bytes so that code + reason fit the 125-byte control-frame limit.
--   A reason cannot be sent without a code.
-- mask: used as the frame's MASK flag.
local function build_close(code, message, mask)
	local data = "";
	if code ~= nil then
		data = pack_uint16be(code);
		if message then
			assert(#message<=123, "Close reason must be <=123 bytes");
			data = data .. message;
		end
	else
		assert(not message, "Close reason requires a status code");
	end
	return build_frame({
		opcode = 0x8;
		FIN = true;
		MASK = mask;
		data = data;
	});
end
-- Public interface of the module: frame header/body parsing, whole-frame
-- parsing and building, and Close-frame helpers.
local frames = {
	parse_header = parse_frame_header,
	parse_body = parse_frame_body,
	parse = parse_frame,
	build = build_frame,
	parse_close = parse_close,
	build_close = build_close,
};
return frames;
|