diff options
-rw-r--r-- | net/websocket.lua | 83 | ||||
-rw-r--r-- | net/websocket/frames.lua | 19 | ||||
-rw-r--r-- | plugins/mod_websocket.lua | 301 |
3 files changed, 354 insertions, 49 deletions
diff --git a/net/websocket.lua b/net/websocket.lua index 32bb1a6e..3c4746b7 100644 --- a/net/websocket.lua +++ b/net/websocket.lua @@ -6,6 +6,8 @@ -- COPYING file in the source package for more information. -- +local t_concat = table.concat; + local http = require "net.http"; local frames = require "net.websocket.frames"; local base64 = require "util.encodings".base64; @@ -76,7 +78,7 @@ function websocket_listeners.onincoming(handler, buffer, err) if frame.FIN then s.databuffer = nil; if s.onmessage then - s:onmessage(table.concat(databuffer), databuffer.type); + s:onmessage(t_concat(databuffer), databuffer.type); end end else -- Control frame @@ -176,28 +178,31 @@ local websocket_metatable = { local function connect(url, ex, listeners) ex = ex or {}; - --[[ RFC 6455 4.1.7: + --[[RFC 6455 4.1.7: The request MUST include a header field with the name - |Sec-WebSocket-Key|. The value of this header field MUST be a - nonce consisting of a randomly selected 16-byte value that has - been base64-encoded (see Section 4 of [RFC4648]). The nonce - MUST be selected randomly for each connection. - ]] + |Sec-WebSocket-Key|. The value of this header field MUST be a + nonce consisting of a randomly selected 16-byte value that has + been base64-encoded (see Section 4 of [RFC4648]). The nonce + MUST be selected randomly for each connection. + ]] local key = base64.encode(random_bytes(16)); -- Either a single protocol string or an array of protocol strings. local protocol = ex.protocol; - if type(protocol) == "table" then - protocol = table.concat(protocol, ", "); + if type(protocol) == "string" then + protocol = { protocol }; + end + for _, v in ipairs(protocol) do + protocol[v] = true; end local headers = { ["Upgrade"] = "websocket"; ["Connection"] = "Upgrade"; ["Sec-WebSocket-Key"] = key; - ["Sec-WebSocket-Protocol"] = protocol; - ["Sec-WebSocket-Version"] = "13"; - ["Sec-WebSocket-Extensions"] = ex.extensions; + ["Sec-WebSocket-Protocol"] = t_concat(protocol, ", "); + ["Sec-WebSocket-Version"] = "13"; + ["Sec-WebSocket-Extensions"] = ex.extensions; } if ex.headers then for k,v in pairs(ex.headers) do @@ -225,36 +230,36 @@ local function connect(url, ex, listeners) local http_url = url:gsub("^(ws)", "http"); local http_req = http.request(http_url, { - method = "GET"; - headers = headers; - sslctx = ex.sslctx; - }, function(b, c, r, http_req) - if c ~= 101 - or r.headers["connection"]:lower() ~= "upgrade" - or r.headers["upgrade"] ~= "websocket" - or r.headers["sec-websocket-accept"] ~= base64.encode(sha1(key .. "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) - -- TODO: check "Sec-WebSocket-Protocol" - then - s.readyState = 3; - log("warn", "WebSocket connection to %s failed: %s", url, tostring(b)); - if s.onerror then s:onerror("connecting-failed"); end - return - end + method = "GET"; + headers = headers; + sslctx = ex.sslctx; + }, function(b, c, r, http_req) + if c ~= 101 + or r.headers["connection"]:lower() ~= "upgrade" + or r.headers["upgrade"] ~= "websocket" + or r.headers["sec-websocket-accept"] ~= base64.encode(sha1(key .. "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) + or not protocol[r.headers["sec-websocket-protocol"]] + then + s.readyState = 3; + log("warn", "WebSocket connection to %s failed: %s", url, tostring(b)); + if s.onerror then s:onerror("connecting-failed"); end + return; + end - s.protocol = r.headers["sec-websocket-protocol"]; + s.protocol = r.headers["sec-websocket-protocol"]; - -- Take possession of socket from http - http_req.conn = nil; - local handler = http_req.handler; - s.handler = handler; - websockets[handler] = s; - handler:setlistener(websocket_listeners); + -- Take possession of socket from http + http_req.conn = nil; + local handler = http_req.handler; + s.handler = handler; + websockets[handler] = s; + handler:setlistener(websocket_listeners); - log("debug", "WebSocket connected successfully to %s", url); - s.readyState = 1; - if s.onopen then s:onopen(); end - websocket_listeners.onincoming(handler, b); - end); + log("debug", "WebSocket connected successfully to %s", url); + s.readyState = 1; + if s.onopen then s:onopen(); end + websocket_listeners.onincoming(handler, b); + end); return s; end diff --git a/net/websocket/frames.lua b/net/websocket/frames.lua index a5fcdad9..8bbddd1c 100644 --- a/net/websocket/frames.lua +++ b/net/websocket/frames.lua @@ -29,7 +29,7 @@ local function read_uint16be(str, pos) local l1, l2 = s_byte(str, pos, pos+1); return l1*256 + l2; end --- TODO: this may lose precision +-- FIXME: this may lose precision local function read_uint64be(str, pos) local l1, l2, l3, l4, l5, l6, l7, l8 = s_byte(str, pos, pos+7); return lshift(l1, 56) + lshift(l2, 48) + lshift(l3, 40) + lshift(l4, 32) @@ -38,12 +38,12 @@ end local function pack_uint16be(x) return s_char(rshift(x, 8), band(x, 0xFF)); end -local function sm(x, n) +local function get_byte(x, n) return band(rshift(x, n), 0xFF); end local function pack_uint64be(x) - return s_char(rshift(x, 56), sm(x, 48), sm(x, 40), sm(x, 32), - sm(x, 24), sm(x, 16), sm(x, 8), band(x, 0xFF)); + return s_char(rshift(x, 56), get_byte(x, 48), get_byte(x, 40), get_byte(x, 32), + get_byte(x, 24), get_byte(x, 16), get_byte(x, 8), band(x, 0xFF)); end local function parse_frame_header(frame) @@ -78,7 +78,7 @@ local function parse_frame_header(frame) end if result.MASK then - result.key = { s_byte(frame, pos+1, pos+4) }; + result.key = { s_byte(frame, length_bytes+3, length_bytes+6) }; end return result, header_length; @@ -131,13 +131,12 @@ local function build_frame(desc) desc.RSV1 and 0x40 or 0, desc.RSV2 and 0x20 or 0, desc.RSV3 and 0x10 or 0); - local b2; - local length_extra - if #data <= 125 then -- 7-bit length - b2 = #data; + local b2 = #data; + local length_extra; + if b2 <= 125 then -- 7-bit length length_extra = ""; - elseif #data <= 0xFFFF then -- 2-byte length + elseif b2 <= 0xFFFF then -- 2-byte length b2 = 126; length_extra = pack_uint16be(#data); else -- 8-byte length diff --git a/plugins/mod_websocket.lua b/plugins/mod_websocket.lua new file mode 100644 index 00000000..313dbd41 --- /dev/null +++ b/plugins/mod_websocket.lua @@ -0,0 +1,301 @@ +-- Prosody IM +-- Copyright (C) 2012-2014 Florian Zeitz +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +module:set_global(); + +local add_filter = require "util.filters".add_filter; +local sha1 = require "util.hashes".sha1; +local base64 = require "util.encodings".base64.encode; +local st = require "util.stanza"; +local parse_xml = require "util.xml".parse; +local portmanager = require "core.portmanager"; +local sm_destroy_session = sessionmanager.destroy_session; +local log = module._log; + +local websocket_frames = require"net.websocket.frames"; +local parse_frame = websocket_frames.parse; +local build_frame = websocket_frames.build; +local build_close = websocket_frames.build_close; +local parse_close = websocket_frames.parse_close; + +local t_concat = table.concat; + +local consider_websocket_secure = module:get_option_boolean("consider_websocket_secure"); +local cross_domain = module:get_option("cross_domain_websocket"); +if cross_domain then + if cross_domain == true then + cross_domain = "*"; + elseif type(cross_domain) == "table" then + cross_domain = t_concat(cross_domain, ", "); + end + if type(cross_domain) ~= "string" then + cross_domain = nil; + end +end + +local xmlns_framing = "urn:ietf:params:xml:ns:xmpp-framing"; +local xmlns_streams = "http://etherx.jabber.org/streams"; +local xmlns_client = "jabber:client"; +local stream_xmlns_attr = {xmlns='urn:ietf:params:xml:ns:xmpp-streams'}; + +module:depends("c2s") +local sessions = module:shared("c2s/sessions"); +local c2s_listener = portmanager.get_service("c2s").listener; + +--- Session methods +local function session_open_stream(session) + local attr = { + xmlns = xmlns_framing, + version = "1.0", + id = session.streamid or "", + from = session.host + }; + session.send(st.stanza("open", attr)); +end + +local function session_close(session, reason) + local log = session.log or log; + if session.conn then + if session.notopen then + session:open_stream(); + end + if reason then -- nil == no err, initiated by us, false == initiated by client + local stream_error = st.stanza("stream:error"); + if type(reason) == "string" then -- assume stream error + stream_error:tag(reason, {xmlns = 'urn:ietf:params:xml:ns:xmpp-streams' }); + elseif type(reason) == "table" then + if reason.condition then + stream_error:tag(reason.condition, stream_xmlns_attr):up(); + if reason.text then + stream_error:tag("text", stream_xmlns_attr):text(reason.text):up(); + end + if reason.extra then + stream_error:add_child(reason.extra); + end + elseif reason.name then -- a stanza + stream_error = reason; + end + end + log("debug", "Disconnecting client, <stream:error> is: %s", tostring(stream_error)); + session.send(stream_error); + end + + session.send(st.stanza("close", { xmlns = xmlns_framing })); + function session.send() return false; end + + local reason = (reason and (reason.name or reason.text or reason.condition)) or reason; + session.log("debug", "c2s stream for %s closed: %s", session.full_jid or ("<"..session.ip..">"), reason or "session closed"); + + -- Authenticated incoming stream may still be sending us stanzas, so wait for </stream:stream> from remote + local conn = session.conn; + if reason == nil and not session.notopen and session.type == "c2s" then + -- Grace time to process data from authenticated cleanly-closed stream + add_task(stream_close_timeout, function () + if not session.destroyed then + session.log("warn", "Failed to receive a stream close response, closing connection anyway..."); + sm_destroy_session(session, reason); + conn:write(build_close(1000, "Stream closed")); + conn:close(); + end + end); + else + sm_destroy_session(session, reason); + conn:write(build_close(1000, "Stream closed")); + conn:close(); + end + end +end + + +--- Filters +local function filter_open_close(data) + if not data:find(xmlns_framing, 1, true) then return data; end + + local oc = parse_xml(data); + if not oc then return data; end + if oc.attr.xmlns ~= xmlns_framing then return data; end + if oc.name == "close" then return "</stream:stream>"; end + if oc.name == "open" then + oc.name = "stream:stream"; + oc.attr.xmlns = nil; + oc.attr["xmlns:stream"] = xmlns_streams; + return oc:top_tag(); + end + + return data; +end +function handle_request(event, path) + local request, response = event.request, event.response; + local conn = response.conn; + + if not request.headers.sec_websocket_key then + response.headers.content_type = "text/html"; + return [[<!DOCTYPE html><html><head><title>Websocket</title></head><body> + <p>It works! Now point your WebSocket client to this URL to connect to Prosody.</p> + </body></html>]]; + end + + local wants_xmpp = false; + (request.headers.sec_websocket_protocol or ""):gsub("([^,]*),?", function (proto) + if proto == "xmpp" then wants_xmpp = true; end + end); + + if not wants_xmpp then + return 501; + end + + local function websocket_close(code, message) + conn:write(build_close(code, message)); + conn:close(); + end + + local dataBuffer; + local function handle_frame(frame) + local opcode = frame.opcode; + local length = frame.length; + module:log("debug", "Websocket received frame: opcode=%0x, %i bytes", frame.opcode, #frame.data); + + -- Error cases + if frame.RSV1 or frame.RSV2 or frame.RSV3 then -- Reserved bits non zero + websocket_close(1002, "Reserved bits not zero"); + return false; + end + + if opcode == 0x8 then -- close frame + if length == 1 then + websocket_close(1002, "Close frame with payload, but too short for status code"); + return false; + elseif length >= 2 then + local status_code = parse_close(frame.data) + if status_code < 1000 then + websocket_close(1002, "Closed with invalid status code"); + return false; + elseif ((status_code > 1003 and status_code < 1007) or status_code > 1011) and status_code < 3000 then + websocket_close(1002, "Closed with reserved status code"); + return false; + end + end + end + + if opcode >= 0x8 then + if length > 125 then -- Control frame with too much payload + websocket_close(1002, "Payload too large"); + return false; + end + + if not frame.FIN then -- Fragmented control frame + websocket_close(1002, "Fragmented control frame"); + return false; + end + end + + if (opcode > 0x2 and opcode < 0x8) or (opcode > 0xA) then + websocket_close(1002, "Reserved opcode"); + return false; + end + + if opcode == 0x0 and not dataBuffer then + websocket_close(1002, "Unexpected continuation frame"); + return false; + end + + if (opcode == 0x1 or opcode == 0x2) and dataBuffer then + websocket_close(1002, "Continuation frame expected"); + return false; + end + + -- Valid cases + if opcode == 0x0 then -- Continuation frame + dataBuffer[#dataBuffer+1] = frame.data; + elseif opcode == 0x1 then -- Text frame + dataBuffer = {frame.data}; + elseif opcode == 0x2 then -- Binary frame + websocket_close(1003, "Only text frames are supported"); + return; + elseif opcode == 0x8 then -- Close request + websocket_close(1000, "Goodbye"); + return; + elseif opcode == 0x9 then -- Ping frame + frame.opcode = 0xA; + conn:write(build_frame(frame)); + return ""; + elseif opcode == 0xA then -- Pong frame + module:log("warn", "Received unexpected pong frame: " .. tostring(frame.data)); + return ""; + else + log("warn", "Received frame with unsupported opcode %i", opcode); + return ""; + end + + if frame.FIN then + local data = t_concat(dataBuffer, ""); + dataBuffer = nil; + return data; + end + return ""; + end + + conn:setlistener(c2s_listener); + c2s_listener.onconnect(conn); + + local session = sessions[conn]; + + session.secure = consider_websocket_secure or session.secure; + + session.open_stream = session_open_stream; + session.close = session_close; + + local frameBuffer = ""; + add_filter(session, "bytes/in", function(data) + local cache = {}; + frameBuffer = frameBuffer .. data; + local frame, length = parse_frame(frameBuffer); + + while frame do + frameBuffer = frameBuffer:sub(length + 1); + local result = handle_frame(frame); + if not result then return; end + cache[#cache+1] = filter_open_close(result); + frame, length = parse_frame(frameBuffer); + end + return t_concat(cache, ""); + end); + + add_filter(session, "stanzas/out", function(stanza) + local attr = stanza.attr; + attr.xmlns = attr.xmlns or xmlns_client; + if stanza.name:find("^stream:") then + attr["xmlns:stream"] = attr["xmlns:stream"] or xmlns_streams; + end + return stanza; + end); + + add_filter(session, "bytes/out", function(data) + return build_frame({ FIN = true, opcode = 0x01, data = tostring(data)}); + end); + + response.status_code = 101; + response.headers.upgrade = "websocket"; + response.headers.connection = "Upgrade"; + response.headers.sec_webSocket_accept = base64(sha1(request.headers.sec_websocket_key .. "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")); + response.headers.sec_webSocket_protocol = "xmpp"; + response.headers.access_control_allow_origin = cross_domain; + + return ""; +end + +function module.add_host(module) + module:depends("http"); + module:provides("http", { + name = "websocket"; + default_path = "xmpp-websocket"; + route = { + ["GET"] = handle_request; + ["GET /"] = handle_request; + }; + }); +end |