From 8a24ec9653c09224a5e2d011e61067ddb11c1fe0 Mon Sep 17 00:00:00 2001 From: Matthew Wild Date: Thu, 17 Sep 2020 13:00:19 +0100 Subject: net.websocket.frames: Allow all methods to work on non-string objects Instead of using the string library, use methods from the passed object, which are assumed to be equivalent. This provides compatibility with objects from util.ringbuffer and util.dbuffer, for example. --- net/websocket/frames.lua | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/net/websocket/frames.lua b/net/websocket/frames.lua index ba25d261..e1b5527a 100644 --- a/net/websocket/frames.lua +++ b/net/websocket/frames.lua @@ -16,13 +16,12 @@ local bor = bit.bor; local bxor = bit.bxor; local lshift = bit.lshift; local rshift = bit.rshift; +local unpack = table.unpack or unpack; -- luacheck: ignore 113 local t_concat = table.concat; -local s_byte = string.byte; local s_char= string.char; -local s_sub = string.sub; -local s_pack = string.pack; -- luacheck: ignore 143 -local s_unpack = string.unpack; -- luacheck: ignore 143 +local s_pack = string.pack; +local s_unpack = string.unpack; if not s_pack and softreq"struct" then s_pack = softreq"struct".pack; @@ -30,12 +29,12 @@ if not s_pack and softreq"struct" then end local function read_uint16be(str, pos) - local l1, l2 = s_byte(str, pos, pos+1); + local l1, l2 = str:byte(pos, pos+1); return l1*256 + l2; end -- FIXME: this may lose precision local function read_uint64be(str, pos) - local l1, l2, l3, l4, l5, l6, l7, l8 = s_byte(str, pos, pos+7); + local l1, l2, l3, l4, l5, l6, l7, l8 = str:byte(pos, pos+7); local h = lshift(l1, 24) + lshift(l2, 16) + lshift(l3, 8) + l4; local l = lshift(l5, 24) + lshift(l6, 16) + lshift(l7, 8) + l8; return h * 2^32 + l; @@ -63,9 +62,15 @@ end if s_unpack then function read_uint16be(str, pos) + if type(str) ~= "string" then + str, pos = str:sub(pos, pos+1), 1; + end return s_unpack(">I2", str, pos); end function read_uint64be(str, pos) + if type(str) ~= "string" then + str, pos = str:sub(pos, pos+7), 1; + end return s_unpack(">I8", str, pos); end end @@ -73,7 +78,7 @@ end local function parse_frame_header(frame) if #frame < 2 then return; end - local byte1, byte2 = s_byte(frame, 1, 2); + local byte1, byte2 = frame:byte(1, 2); local result = { FIN = band(byte1, 0x80) > 0; RSV1 = band(byte1, 0x40) > 0; @@ -102,7 +107,7 @@ local function parse_frame_header(frame) end if result.MASK then - result.key = { s_byte(frame, length_bytes+3, length_bytes+6) }; + result.key = { frame:byte(length_bytes+3, length_bytes+6) }; end return result, header_length; @@ -121,7 +126,7 @@ local function apply_mask(str, key, from, to) for i = from, to do local key_index = counter%key_len + 1; counter = counter + 1; - data[counter] = s_char(bxor(key[key_index], s_byte(str, i))); + data[counter] = s_char(bxor(key[key_index], str:byte(i))); end return t_concat(data); end @@ -189,7 +194,7 @@ local function parse_close(data) if #data >= 2 then code = read_uint16be(data, 1); if #data > 2 then - message = s_sub(data, 3); + message = data:sub(3); end end return code, message -- cgit v1.2.3