aboutsummaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
Diffstat (limited to 'net')
-rw-r--r--net/websocket.lua83
-rw-r--r--net/websocket/frames.lua19
2 files changed, 53 insertions, 49 deletions
diff --git a/net/websocket.lua b/net/websocket.lua
index 32bb1a6e..3c4746b7 100644
--- a/net/websocket.lua
+++ b/net/websocket.lua
@@ -6,6 +6,8 @@
-- COPYING file in the source package for more information.
--
+local t_concat = table.concat;
+
local http = require "net.http";
local frames = require "net.websocket.frames";
local base64 = require "util.encodings".base64;
@@ -76,7 +78,7 @@ function websocket_listeners.onincoming(handler, buffer, err)
if frame.FIN then
s.databuffer = nil;
if s.onmessage then
- s:onmessage(table.concat(databuffer), databuffer.type);
+ s:onmessage(t_concat(databuffer), databuffer.type);
end
end
else -- Control frame
@@ -176,28 +178,31 @@ local websocket_metatable = {
local function connect(url, ex, listeners)
ex = ex or {};
- --[[ RFC 6455 4.1.7:
+ --[[RFC 6455 4.1.7:
The request MUST include a header field with the name
- |Sec-WebSocket-Key|. The value of this header field MUST be a
- nonce consisting of a randomly selected 16-byte value that has
- been base64-encoded (see Section 4 of [RFC4648]). The nonce
- MUST be selected randomly for each connection.
- ]]
+ |Sec-WebSocket-Key|. The value of this header field MUST be a
+ nonce consisting of a randomly selected 16-byte value that has
+ been base64-encoded (see Section 4 of [RFC4648]). The nonce
+ MUST be selected randomly for each connection.
+ ]]
local key = base64.encode(random_bytes(16));
-- Either a single protocol string or an array of protocol strings.
local protocol = ex.protocol;
- if type(protocol) == "table" then
- protocol = table.concat(protocol, ", ");
+ if type(protocol) == "string" then
+ protocol = { protocol };
+ end
+ for _, v in ipairs(protocol) do
+ protocol[v] = true;
end
local headers = {
["Upgrade"] = "websocket";
["Connection"] = "Upgrade";
["Sec-WebSocket-Key"] = key;
- ["Sec-WebSocket-Protocol"] = protocol;
- ["Sec-WebSocket-Version"] = "13";
- ["Sec-WebSocket-Extensions"] = ex.extensions;
+ ["Sec-WebSocket-Protocol"] = t_concat(protocol, ", ");
+ ["Sec-WebSocket-Version"] = "13";
+ ["Sec-WebSocket-Extensions"] = ex.extensions;
}
if ex.headers then
for k,v in pairs(ex.headers) do
@@ -225,36 +230,36 @@ local function connect(url, ex, listeners)
local http_url = url:gsub("^(ws)", "http");
local http_req = http.request(http_url, {
- method = "GET";
- headers = headers;
- sslctx = ex.sslctx;
- }, function(b, c, r, http_req)
- if c ~= 101
- or r.headers["connection"]:lower() ~= "upgrade"
- or r.headers["upgrade"] ~= "websocket"
- or r.headers["sec-websocket-accept"] ~= base64.encode(sha1(key .. "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
- -- TODO: check "Sec-WebSocket-Protocol"
- then
- s.readyState = 3;
- log("warn", "WebSocket connection to %s failed: %s", url, tostring(b));
- if s.onerror then s:onerror("connecting-failed"); end
- return
- end
+ method = "GET";
+ headers = headers;
+ sslctx = ex.sslctx;
+ }, function(b, c, r, http_req)
+ if c ~= 101
+ or r.headers["connection"]:lower() ~= "upgrade"
+ or r.headers["upgrade"] ~= "websocket"
+ or r.headers["sec-websocket-accept"] ~= base64.encode(sha1(key .. "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
+ or not protocol[r.headers["sec-websocket-protocol"]]
+ then
+ s.readyState = 3;
+ log("warn", "WebSocket connection to %s failed: %s", url, tostring(b));
+ if s.onerror then s:onerror("connecting-failed"); end
+ return;
+ end
- s.protocol = r.headers["sec-websocket-protocol"];
+ s.protocol = r.headers["sec-websocket-protocol"];
- -- Take possession of socket from http
- http_req.conn = nil;
- local handler = http_req.handler;
- s.handler = handler;
- websockets[handler] = s;
- handler:setlistener(websocket_listeners);
+ -- Take possession of socket from http
+ http_req.conn = nil;
+ local handler = http_req.handler;
+ s.handler = handler;
+ websockets[handler] = s;
+ handler:setlistener(websocket_listeners);
- log("debug", "WebSocket connected successfully to %s", url);
- s.readyState = 1;
- if s.onopen then s:onopen(); end
- websocket_listeners.onincoming(handler, b);
- end);
+ log("debug", "WebSocket connected successfully to %s", url);
+ s.readyState = 1;
+ if s.onopen then s:onopen(); end
+ websocket_listeners.onincoming(handler, b);
+ end);
return s;
end
diff --git a/net/websocket/frames.lua b/net/websocket/frames.lua
index a5fcdad9..8bbddd1c 100644
--- a/net/websocket/frames.lua
+++ b/net/websocket/frames.lua
@@ -29,7 +29,7 @@ local function read_uint16be(str, pos)
local l1, l2 = s_byte(str, pos, pos+1);
return l1*256 + l2;
end
--- TODO: this may lose precision
+-- FIXME: this may lose precision
local function read_uint64be(str, pos)
local l1, l2, l3, l4, l5, l6, l7, l8 = s_byte(str, pos, pos+7);
return lshift(l1, 56) + lshift(l2, 48) + lshift(l3, 40) + lshift(l4, 32)
@@ -38,12 +38,12 @@ end
local function pack_uint16be(x)
return s_char(rshift(x, 8), band(x, 0xFF));
end
-local function sm(x, n)
+local function get_byte(x, n)
return band(rshift(x, n), 0xFF);
end
local function pack_uint64be(x)
- return s_char(rshift(x, 56), sm(x, 48), sm(x, 40), sm(x, 32),
- sm(x, 24), sm(x, 16), sm(x, 8), band(x, 0xFF));
+ return s_char(rshift(x, 56), get_byte(x, 48), get_byte(x, 40), get_byte(x, 32),
+ get_byte(x, 24), get_byte(x, 16), get_byte(x, 8), band(x, 0xFF));
end
local function parse_frame_header(frame)
@@ -78,7 +78,7 @@ local function parse_frame_header(frame)
end
if result.MASK then
- result.key = { s_byte(frame, pos+1, pos+4) };
+ result.key = { s_byte(frame, length_bytes+3, length_bytes+6) };
end
return result, header_length;
@@ -131,13 +131,12 @@ local function build_frame(desc)
desc.RSV1 and 0x40 or 0,
desc.RSV2 and 0x20 or 0,
desc.RSV3 and 0x10 or 0);
- local b2;
- local length_extra
- if #data <= 125 then -- 7-bit length
- b2 = #data;
+ local b2 = #data;
+ local length_extra;
+ if b2 <= 125 then -- 7-bit length
length_extra = "";
- elseif #data <= 0xFFFF then -- 2-byte length
+ elseif b2 <= 0xFFFF then -- 2-byte length
b2 = 126;
length_extra = pack_uint16be(#data);
else -- 8-byte length