diff options
Diffstat (limited to 'util/stanza.lua')
-rw-r--r-- | util/stanza.lua | 161 |
1 files changed, 112 insertions, 49 deletions
diff --git a/util/stanza.lua b/util/stanza.lua index 07365144..a90d56b3 100644 --- a/util/stanza.lua +++ b/util/stanza.lua @@ -7,6 +7,7 @@ -- +local error = error; local t_insert = table.insert; local t_remove = table.remove; local t_concat = table.concat; @@ -23,6 +24,8 @@ local s_sub = string.sub; local s_find = string.find; local os = os; +local valid_utf8 = require "util.encodings".utf8.valid; + local do_pretty_printing = not os.getenv("WINDIR"); local getstyle, getstring; if do_pretty_printing then @@ -37,12 +40,52 @@ end local xmlns_stanzas = "urn:ietf:params:xml:ns:xmpp-stanzas"; local _ENV = nil; +-- luacheck: std none -local stanza_mt = { __type = "stanza" }; +local stanza_mt = { __name = "stanza" }; stanza_mt.__index = stanza_mt; -local function new_stanza(name, attr) - local stanza = { name = name, attr = attr or {}, tags = {} }; +local function check_name(name, name_type) + if type(name) ~= "string" then + error("invalid "..name_type.." name: expected string, got "..type(name)); + elseif #name == 0 then + error("invalid "..name_type.." name: empty string"); + elseif s_find(name, "[<>& '\"]") then + error("invalid "..name_type.." name: contains invalid characters"); + elseif not valid_utf8(name) then + error("invalid "..name_type.." name: contains invalid utf8"); + end +end + +local function check_text(text, text_type) + if type(text) ~= "string" then + error("invalid "..text_type.." value: expected string, got "..type(text)); + elseif not valid_utf8(text) then + error("invalid "..text_type.." value: contains invalid utf8"); + end +end + +local function check_attr(attr) + if attr ~= nil then + if type(attr) ~= "table" then + error("invalid attributes, expected table got "..type(attr)); + end + for k, v in pairs(attr) do + check_name(k, "attribute"); + check_text(v, "attribute"); + if type(v) ~= "string" then + error("invalid attribute value for '"..k.."': expected string, got "..type(v)); + elseif not valid_utf8(v) then + error("invalid attribute value for '"..k.."': contains invalid utf8"); + end + end + end +end + +local function new_stanza(name, attr, namespaces) + check_name(name, "tag"); + check_attr(attr); + local stanza = { name = name, attr = attr or {}, namespaces = namespaces, tags = {} }; return setmetatable(stanza, stanza_mt); end @@ -58,8 +101,12 @@ function stanza_mt:body(text, attr) return self:tag("body", attr):text(text); end -function stanza_mt:tag(name, attrs) - local s = new_stanza(name, attrs); +function stanza_mt:text_tag(name, text, attr, namespaces) + return self:tag(name, attr, namespaces):text(text):up(); +end + +function stanza_mt:tag(name, attr, namespaces) + local s = new_stanza(name, attr, namespaces); local last_add = self.last_add; if not last_add then last_add = {}; self.last_add = last_add; end (last_add[#last_add] or self):add_direct_child(s); @@ -68,8 +115,10 @@ function stanza_mt:tag(name, attrs) end function stanza_mt:text(text) - local last_add = self.last_add; - (last_add and last_add[#last_add] or self):add_direct_child(text); + if text ~= nil and text ~= "" then + local last_add = self.last_add; + (last_add and last_add[#last_add] or self):add_direct_child(text); + end return self; end @@ -85,10 +134,13 @@ function stanza_mt:reset() end function stanza_mt:add_direct_child(child) - if type(child) == "table" then + if is_stanza(child) then t_insert(self.tags, child); + t_insert(self, child); + else + check_text(child, "text"); + t_insert(self, child); end - t_insert(self, child); end function stanza_mt:add_child(child) @@ -165,6 +217,7 @@ end function stanza_mt:maptags(callback) local tags, curr_tag = self.tags, 1; local n_children, n_tags = #self, #tags; + local max_iterations = n_children + 1; local i = 1; while curr_tag <= n_tags and n_tags > 0 do @@ -184,6 +237,11 @@ function stanza_mt:maptags(callback) curr_tag = curr_tag + 1; end i = i + 1; + if i > max_iterations then + -- COMPAT: Hopefully temporary guard against #981 while we + -- figure out the root cause + error("Invalid stanza state! Please report this error."); + end end return self; end @@ -289,12 +347,6 @@ function stanza_mt.get_error(stanza) return error_type, condition or "undefined-condition", text; end -local id = 0; -local function new_id() - id = id + 1; - return "lx"..id; -end - local function preserialize(stanza) local s = { name = stanza.name, attr = stanza.attr }; for _, child in ipairs(stanza) do @@ -307,51 +359,48 @@ local function preserialize(stanza) return s; end -local function deserialize(stanza) +stanza_mt.__freeze = preserialize; + +local function deserialize(serialized) -- Set metatable - if stanza then - local attr = stanza.attr; - for i=1,#attr do attr[i] = nil; end + if serialized then + local attr = serialized.attr; local attrx = {}; - for att in pairs(attr) do - if s_find(att, "|", 1, true) and not s_find(att, "\1", 1, true) then - local ns,na = s_match(att, "^([^|]+)|(.+)$"); - attrx[ns.."\1"..na] = attr[att]; - attr[att] = nil; + for att, val in pairs(attr) do + if type(att) == "string" then + if s_find(att, "|", 1, true) and not s_find(att, "\1", 1, true) then + local ns,na = s_match(att, "^([^|]+)|(.+)$"); + attrx[ns.."\1"..na] = val; + else + attrx[att] = val; + end end end - for a,v in pairs(attrx) do - attr[a] = v; - end - setmetatable(stanza, stanza_mt); - for _, child in ipairs(stanza) do + local stanza = new_stanza(serialized.name, attrx); + for _, child in ipairs(serialized) do if type(child) == "table" then - deserialize(child); - end - end - if not stanza.tags then - -- Rebuild tags - local tags = {}; - for _, child in ipairs(stanza) do - if type(child) == "table" then - t_insert(tags, child); - end + stanza:add_direct_child(deserialize(child)); + elseif type(child) == "string" then + stanza:add_direct_child(child); end - stanza.tags = tags; end + return stanza; end - - return stanza; end -local function clone(stanza) +local function _clone(stanza) local attr, tags = {}, {}; for k,v in pairs(stanza.attr) do attr[k] = v; end - local new = { name = stanza.name, attr = attr, tags = tags }; + local old_namespaces, namespaces = stanza.namespaces; + if old_namespaces then + namespaces = {}; + for k,v in pairs(old_namespaces) do namespaces[k] = v; end + end + local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags }; for i=1,#stanza do local child = stanza[i]; if child.name then - child = clone(child); + child = _clone(child); t_insert(tags, child); end t_insert(new, child); @@ -359,6 +408,13 @@ local function clone(stanza) return setmetatable(new, stanza_mt); end +local function clone(stanza) + if not is_stanza(stanza) then + error("bad argument to clone: expected stanza, got "..type(stanza)); + end + return _clone(stanza); +end + local function message(attr, body) if not body then return new_stanza("message", attr); @@ -367,12 +423,20 @@ local function message(attr, body) end end local function iq(attr) - if attr and not attr.id then attr.id = new_id(); end - return new_stanza("iq", attr or { id = new_id() }); + if not (attr and attr.id) then + error("iq stanzas require an id attribute"); + end + return new_stanza("iq", attr); end local function reply(orig) - return new_stanza(orig.name, orig.attr and { to = orig.attr.from, from = orig.attr.to, id = orig.attr.id, type = ((orig.name == "iq" and "result") or orig.attr.type) }); + return new_stanza(orig.name, + orig.attr and { + to = orig.attr.from, + from = orig.attr.to, + id = orig.attr.id, + type = ((orig.name == "iq" and "result") or orig.attr.type) + }); end local xmpp_stanzas_attr = { xmlns = xmlns_stanzas }; @@ -433,7 +497,6 @@ return { stanza_mt = stanza_mt; stanza = new_stanza; is_stanza = is_stanza; - new_id = new_id; preserialize = preserialize; deserialize = deserialize; clone = clone; |