diff options
Diffstat (limited to 'net')
-rw-r--r-- | net/websocket.lua | 83 | ||||
-rw-r--r-- | net/websocket/frames.lua | 19 |
2 files changed, 53 insertions, 49 deletions
diff --git a/net/websocket.lua b/net/websocket.lua index 32bb1a6e..3c4746b7 100644 --- a/net/websocket.lua +++ b/net/websocket.lua @@ -6,6 +6,8 @@ -- COPYING file in the source package for more information. -- +local t_concat = table.concat; + local http = require "net.http"; local frames = require "net.websocket.frames"; local base64 = require "util.encodings".base64; @@ -76,7 +78,7 @@ function websocket_listeners.onincoming(handler, buffer, err) if frame.FIN then s.databuffer = nil; if s.onmessage then - s:onmessage(table.concat(databuffer), databuffer.type); + s:onmessage(t_concat(databuffer), databuffer.type); end end else -- Control frame @@ -176,28 +178,31 @@ local websocket_metatable = { local function connect(url, ex, listeners) ex = ex or {}; - --[[ RFC 6455 4.1.7: + --[[RFC 6455 4.1.7: The request MUST include a header field with the name - |Sec-WebSocket-Key|. The value of this header field MUST be a - nonce consisting of a randomly selected 16-byte value that has - been base64-encoded (see Section 4 of [RFC4648]). The nonce - MUST be selected randomly for each connection. - ]] + |Sec-WebSocket-Key|. The value of this header field MUST be a + nonce consisting of a randomly selected 16-byte value that has + been base64-encoded (see Section 4 of [RFC4648]). The nonce + MUST be selected randomly for each connection. + ]] local key = base64.encode(random_bytes(16)); -- Either a single protocol string or an array of protocol strings. local protocol = ex.protocol; - if type(protocol) == "table" then - protocol = table.concat(protocol, ", "); + if type(protocol) == "string" then + protocol = { protocol }; + end + for _, v in ipairs(protocol) do + protocol[v] = true; end local headers = { ["Upgrade"] = "websocket"; ["Connection"] = "Upgrade"; ["Sec-WebSocket-Key"] = key; - ["Sec-WebSocket-Protocol"] = protocol; - ["Sec-WebSocket-Version"] = "13"; - ["Sec-WebSocket-Extensions"] = ex.extensions; + ["Sec-WebSocket-Protocol"] = t_concat(protocol, ", "); + ["Sec-WebSocket-Version"] = "13"; + ["Sec-WebSocket-Extensions"] = ex.extensions; } if ex.headers then for k,v in pairs(ex.headers) do @@ -225,36 +230,36 @@ local function connect(url, ex, listeners) local http_url = url:gsub("^(ws)", "http"); local http_req = http.request(http_url, { - method = "GET"; - headers = headers; - sslctx = ex.sslctx; - }, function(b, c, r, http_req) - if c ~= 101 - or r.headers["connection"]:lower() ~= "upgrade" - or r.headers["upgrade"] ~= "websocket" - or r.headers["sec-websocket-accept"] ~= base64.encode(sha1(key .. "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) - -- TODO: check "Sec-WebSocket-Protocol" - then - s.readyState = 3; - log("warn", "WebSocket connection to %s failed: %s", url, tostring(b)); - if s.onerror then s:onerror("connecting-failed"); end - return - end + method = "GET"; + headers = headers; + sslctx = ex.sslctx; + }, function(b, c, r, http_req) + if c ~= 101 + or r.headers["connection"]:lower() ~= "upgrade" + or r.headers["upgrade"] ~= "websocket" + or r.headers["sec-websocket-accept"] ~= base64.encode(sha1(key .. "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) + or not protocol[r.headers["sec-websocket-protocol"]] + then + s.readyState = 3; + log("warn", "WebSocket connection to %s failed: %s", url, tostring(b)); + if s.onerror then s:onerror("connecting-failed"); end + return; + end - s.protocol = r.headers["sec-websocket-protocol"]; + s.protocol = r.headers["sec-websocket-protocol"]; - -- Take possession of socket from http - http_req.conn = nil; - local handler = http_req.handler; - s.handler = handler; - websockets[handler] = s; - handler:setlistener(websocket_listeners); + -- Take possession of socket from http + http_req.conn = nil; + local handler = http_req.handler; + s.handler = handler; + websockets[handler] = s; + handler:setlistener(websocket_listeners); - log("debug", "WebSocket connected successfully to %s", url); - s.readyState = 1; - if s.onopen then s:onopen(); end - websocket_listeners.onincoming(handler, b); - end); + log("debug", "WebSocket connected successfully to %s", url); + s.readyState = 1; + if s.onopen then s:onopen(); end + websocket_listeners.onincoming(handler, b); + end); return s; end diff --git a/net/websocket/frames.lua b/net/websocket/frames.lua index a5fcdad9..8bbddd1c 100644 --- a/net/websocket/frames.lua +++ b/net/websocket/frames.lua @@ -29,7 +29,7 @@ local function read_uint16be(str, pos) local l1, l2 = s_byte(str, pos, pos+1); return l1*256 + l2; end --- TODO: this may lose precision +-- FIXME: this may lose precision local function read_uint64be(str, pos) local l1, l2, l3, l4, l5, l6, l7, l8 = s_byte(str, pos, pos+7); return lshift(l1, 56) + lshift(l2, 48) + lshift(l3, 40) + lshift(l4, 32) @@ -38,12 +38,12 @@ end local function pack_uint16be(x) return s_char(rshift(x, 8), band(x, 0xFF)); end -local function sm(x, n) +local function get_byte(x, n) return band(rshift(x, n), 0xFF); end local function pack_uint64be(x) - return s_char(rshift(x, 56), sm(x, 48), sm(x, 40), sm(x, 32), - sm(x, 24), sm(x, 16), sm(x, 8), band(x, 0xFF)); + return s_char(rshift(x, 56), get_byte(x, 48), get_byte(x, 40), get_byte(x, 32), + get_byte(x, 24), get_byte(x, 16), get_byte(x, 8), band(x, 0xFF)); end local function parse_frame_header(frame) @@ -78,7 +78,7 @@ local function parse_frame_header(frame) end if result.MASK then - result.key = { s_byte(frame, pos+1, pos+4) }; + result.key = { s_byte(frame, length_bytes+3, length_bytes+6) }; end return result, header_length; @@ -131,13 +131,12 @@ local function build_frame(desc) desc.RSV1 and 0x40 or 0, desc.RSV2 and 0x20 or 0, desc.RSV3 and 0x10 or 0); - local b2; - local length_extra - if #data <= 125 then -- 7-bit length - b2 = #data; + local b2 = #data; + local length_extra; + if b2 <= 125 then -- 7-bit length length_extra = ""; - elseif #data <= 0xFFFF then -- 2-byte length + elseif b2 <= 0xFFFF then -- 2-byte length b2 = 126; length_extra = pack_uint16be(#data); else -- 8-byte length |