diff options
Diffstat (limited to 'util/stanza.lua')
-rw-r--r-- | util/stanza.lua | 106 |
1 files changed, 86 insertions, 20 deletions
diff --git a/util/stanza.lua b/util/stanza.lua index 07365144..85c89d43 100644 --- a/util/stanza.lua +++ b/util/stanza.lua @@ -7,6 +7,7 @@ -- +local error = error; local t_insert = table.insert; local t_remove = table.remove; local t_concat = table.concat; @@ -23,6 +24,8 @@ local s_sub = string.sub; local s_find = string.find; local os = os; +local valid_utf8 = require "util.encodings".utf8.valid; + local do_pretty_printing = not os.getenv("WINDIR"); local getstyle, getstring; if do_pretty_printing then @@ -37,12 +40,52 @@ end local xmlns_stanzas = "urn:ietf:params:xml:ns:xmpp-stanzas"; local _ENV = nil; +-- luacheck: std none -local stanza_mt = { __type = "stanza" }; +local stanza_mt = { __name = "stanza" }; stanza_mt.__index = stanza_mt; -local function new_stanza(name, attr) - local stanza = { name = name, attr = attr or {}, tags = {} }; +local function check_name(name, name_type) + if type(name) ~= "string" then + error("invalid "..name_type.." name: expected string, got "..type(name)); + elseif #name == 0 then + error("invalid "..name_type.." name: empty string"); + elseif s_find(name, "[<>& '\"]") then + error("invalid "..name_type.." name: contains invalid characters"); + elseif not valid_utf8(name) then + error("invalid "..name_type.." name: contains invalid utf8"); + end +end + +local function check_text(text, text_type) + if type(text) ~= "string" then + error("invalid "..text_type.." value: expected string, got "..type(text)); + elseif not valid_utf8(text) then + error("invalid "..text_type.." value: contains invalid utf8"); + end +end + +local function check_attr(attr) + if attr ~= nil then + if type(attr) ~= "table" then + error("invalid attributes, expected table got "..type(attr)); + end + for k, v in pairs(attr) do + check_name(k, "attribute"); + check_text(v, "attribute"); + if type(v) ~= "string" then + error("invalid attribute value for '"..k.."': expected string, got "..type(v)); + elseif not valid_utf8(v) then + error("invalid attribute value for '"..k.."': contains invalid utf8"); + end + end + end +end + +local function new_stanza(name, attr, namespaces) + check_name(name, "tag"); + check_attr(attr); + local stanza = { name = name, attr = attr or {}, namespaces = namespaces, tags = {} }; return setmetatable(stanza, stanza_mt); end @@ -58,8 +101,12 @@ function stanza_mt:body(text, attr) return self:tag("body", attr):text(text); end -function stanza_mt:tag(name, attrs) - local s = new_stanza(name, attrs); +function stanza_mt:text_tag(name, text, attr, namespaces) + return self:tag(name, attr, namespaces):text(text):up(); +end + +function stanza_mt:tag(name, attr, namespaces) + local s = new_stanza(name, attr, namespaces); local last_add = self.last_add; if not last_add then last_add = {}; self.last_add = last_add; end (last_add[#last_add] or self):add_direct_child(s); @@ -68,8 +115,10 @@ function stanza_mt:tag(name, attrs) end function stanza_mt:text(text) - local last_add = self.last_add; - (last_add and last_add[#last_add] or self):add_direct_child(text); + if text ~= nil and text ~= "" then + local last_add = self.last_add; + (last_add and last_add[#last_add] or self):add_direct_child(text); + end return self; end @@ -85,10 +134,13 @@ function stanza_mt:reset() end function stanza_mt:add_direct_child(child) - if type(child) == "table" then + if is_stanza(child) then t_insert(self.tags, child); + t_insert(self, child); + else + check_text(child, "text"); + t_insert(self, child); end - t_insert(self, child); end function stanza_mt:add_child(child) @@ -165,6 +217,7 @@ end function stanza_mt:maptags(callback) local tags, curr_tag = self.tags, 1; local n_children, n_tags = #self, #tags; + local max_iterations = n_children + 1; local i = 1; while curr_tag <= n_tags and n_tags > 0 do @@ -184,6 +237,11 @@ function stanza_mt:maptags(callback) curr_tag = curr_tag + 1; end i = i + 1; + if i > max_iterations then + -- COMPAT: Hopefully temporary guard against #981 while we + -- figure out the root cause + error("Invalid stanza state! Please report this error."); + end end return self; end @@ -289,12 +347,6 @@ function stanza_mt.get_error(stanza) return error_type, condition or "undefined-condition", text; end -local id = 0; -local function new_id() - id = id + 1; - return "lx"..id; -end - local function preserialize(stanza) local s = { name = stanza.name, attr = stanza.attr }; for _, child in ipairs(stanza) do @@ -307,6 +359,8 @@ local function preserialize(stanza) return s; end +stanza_mt.__freeze = preserialize; + local function deserialize(stanza) -- Set metatable if stanza then @@ -347,7 +401,12 @@ end local function clone(stanza) local attr, tags = {}, {}; for k,v in pairs(stanza.attr) do attr[k] = v; end - local new = { name = stanza.name, attr = attr, tags = tags }; + local old_namespaces, namespaces = stanza.namespaces; + if old_namespaces then + namespaces = {}; + for k,v in pairs(old_namespaces) do namespaces[k] = v; end + end + local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags }; for i=1,#stanza do local child = stanza[i]; if child.name then @@ -367,12 +426,20 @@ local function message(attr, body) end end local function iq(attr) - if attr and not attr.id then attr.id = new_id(); end - return new_stanza("iq", attr or { id = new_id() }); + if not (attr and attr.id) then + error("iq stanzas require an id attribute"); + end + return new_stanza("iq", attr); end local function reply(orig) - return new_stanza(orig.name, orig.attr and { to = orig.attr.from, from = orig.attr.to, id = orig.attr.id, type = ((orig.name == "iq" and "result") or orig.attr.type) }); + return new_stanza(orig.name, + orig.attr and { + to = orig.attr.from, + from = orig.attr.to, + id = orig.attr.id, + type = ((orig.name == "iq" and "result") or orig.attr.type) + }); end local xmpp_stanzas_attr = { xmlns = xmlns_stanzas }; @@ -433,7 +500,6 @@ return { stanza_mt = stanza_mt; stanza = new_stanza; is_stanza = is_stanza; - new_id = new_id; preserialize = preserialize; deserialize = deserialize; clone = clone; |