aboutsummaryrefslogtreecommitdiffstats
path: root/util/stanza.lua
diff options
context:
space:
mode:
Diffstat (limited to 'util/stanza.lua')
-rw-r--r--util/stanza.lua161
1 files changed, 112 insertions, 49 deletions
diff --git a/util/stanza.lua b/util/stanza.lua
index 07365144..a90d56b3 100644
--- a/util/stanza.lua
+++ b/util/stanza.lua
@@ -7,6 +7,7 @@
--
+local error = error;
local t_insert = table.insert;
local t_remove = table.remove;
local t_concat = table.concat;
@@ -23,6 +24,8 @@ local s_sub = string.sub;
local s_find = string.find;
local os = os;
+local valid_utf8 = require "util.encodings".utf8.valid;
+
local do_pretty_printing = not os.getenv("WINDIR");
local getstyle, getstring;
if do_pretty_printing then
@@ -37,12 +40,52 @@ end
local xmlns_stanzas = "urn:ietf:params:xml:ns:xmpp-stanzas";
local _ENV = nil;
+-- luacheck: std none
-local stanza_mt = { __type = "stanza" };
+local stanza_mt = { __name = "stanza" };
stanza_mt.__index = stanza_mt;
-local function new_stanza(name, attr)
- local stanza = { name = name, attr = attr or {}, tags = {} };
+local function check_name(name, name_type)
+ if type(name) ~= "string" then
+ error("invalid "..name_type.." name: expected string, got "..type(name));
+ elseif #name == 0 then
+ error("invalid "..name_type.." name: empty string");
+ elseif s_find(name, "[<>& '\"]") then
+ error("invalid "..name_type.." name: contains invalid characters");
+ elseif not valid_utf8(name) then
+ error("invalid "..name_type.." name: contains invalid utf8");
+ end
+end
+
+local function check_text(text, text_type)
+ if type(text) ~= "string" then
+ error("invalid "..text_type.." value: expected string, got "..type(text));
+ elseif not valid_utf8(text) then
+ error("invalid "..text_type.." value: contains invalid utf8");
+ end
+end
+
+local function check_attr(attr)
+ if attr ~= nil then
+ if type(attr) ~= "table" then
+ error("invalid attributes, expected table got "..type(attr));
+ end
+ for k, v in pairs(attr) do
+ check_name(k, "attribute");
+ check_text(v, "attribute");
+ if type(v) ~= "string" then
+ error("invalid attribute value for '"..k.."': expected string, got "..type(v));
+ elseif not valid_utf8(v) then
+ error("invalid attribute value for '"..k.."': contains invalid utf8");
+ end
+ end
+ end
+end
+
+local function new_stanza(name, attr, namespaces)
+ check_name(name, "tag");
+ check_attr(attr);
+ local stanza = { name = name, attr = attr or {}, namespaces = namespaces, tags = {} };
return setmetatable(stanza, stanza_mt);
end
@@ -58,8 +101,12 @@ function stanza_mt:body(text, attr)
return self:tag("body", attr):text(text);
end
-function stanza_mt:tag(name, attrs)
- local s = new_stanza(name, attrs);
+function stanza_mt:text_tag(name, text, attr, namespaces)
+ return self:tag(name, attr, namespaces):text(text):up();
+end
+
+function stanza_mt:tag(name, attr, namespaces)
+ local s = new_stanza(name, attr, namespaces);
local last_add = self.last_add;
if not last_add then last_add = {}; self.last_add = last_add; end
(last_add[#last_add] or self):add_direct_child(s);
@@ -68,8 +115,10 @@ function stanza_mt:tag(name, attrs)
end
function stanza_mt:text(text)
- local last_add = self.last_add;
- (last_add and last_add[#last_add] or self):add_direct_child(text);
+ if text ~= nil and text ~= "" then
+ local last_add = self.last_add;
+ (last_add and last_add[#last_add] or self):add_direct_child(text);
+ end
return self;
end
@@ -85,10 +134,13 @@ function stanza_mt:reset()
end
function stanza_mt:add_direct_child(child)
- if type(child) == "table" then
+ if is_stanza(child) then
t_insert(self.tags, child);
+ t_insert(self, child);
+ else
+ check_text(child, "text");
+ t_insert(self, child);
end
- t_insert(self, child);
end
function stanza_mt:add_child(child)
@@ -165,6 +217,7 @@ end
function stanza_mt:maptags(callback)
local tags, curr_tag = self.tags, 1;
local n_children, n_tags = #self, #tags;
+ local max_iterations = n_children + 1;
local i = 1;
while curr_tag <= n_tags and n_tags > 0 do
@@ -184,6 +237,11 @@ function stanza_mt:maptags(callback)
curr_tag = curr_tag + 1;
end
i = i + 1;
+ if i > max_iterations then
+ -- COMPAT: Hopefully temporary guard against #981 while we
+ -- figure out the root cause
+ error("Invalid stanza state! Please report this error.");
+ end
end
return self;
end
@@ -289,12 +347,6 @@ function stanza_mt.get_error(stanza)
return error_type, condition or "undefined-condition", text;
end
-local id = 0;
-local function new_id()
- id = id + 1;
- return "lx"..id;
-end
-
local function preserialize(stanza)
local s = { name = stanza.name, attr = stanza.attr };
for _, child in ipairs(stanza) do
@@ -307,51 +359,48 @@ local function preserialize(stanza)
return s;
end
-local function deserialize(stanza)
+stanza_mt.__freeze = preserialize;
+
+local function deserialize(serialized)
-- Set metatable
- if stanza then
- local attr = stanza.attr;
- for i=1,#attr do attr[i] = nil; end
+ if serialized then
+ local attr = serialized.attr;
local attrx = {};
- for att in pairs(attr) do
- if s_find(att, "|", 1, true) and not s_find(att, "\1", 1, true) then
- local ns,na = s_match(att, "^([^|]+)|(.+)$");
- attrx[ns.."\1"..na] = attr[att];
- attr[att] = nil;
+ for att, val in pairs(attr) do
+ if type(att) == "string" then
+ if s_find(att, "|", 1, true) and not s_find(att, "\1", 1, true) then
+ local ns,na = s_match(att, "^([^|]+)|(.+)$");
+ attrx[ns.."\1"..na] = val;
+ else
+ attrx[att] = val;
+ end
end
end
- for a,v in pairs(attrx) do
- attr[a] = v;
- end
- setmetatable(stanza, stanza_mt);
- for _, child in ipairs(stanza) do
+ local stanza = new_stanza(serialized.name, attrx);
+ for _, child in ipairs(serialized) do
if type(child) == "table" then
- deserialize(child);
- end
- end
- if not stanza.tags then
- -- Rebuild tags
- local tags = {};
- for _, child in ipairs(stanza) do
- if type(child) == "table" then
- t_insert(tags, child);
- end
+ stanza:add_direct_child(deserialize(child));
+ elseif type(child) == "string" then
+ stanza:add_direct_child(child);
end
- stanza.tags = tags;
end
+ return stanza;
end
-
- return stanza;
end
-local function clone(stanza)
+local function _clone(stanza)
local attr, tags = {}, {};
for k,v in pairs(stanza.attr) do attr[k] = v; end
- local new = { name = stanza.name, attr = attr, tags = tags };
+ local old_namespaces, namespaces = stanza.namespaces;
+ if old_namespaces then
+ namespaces = {};
+ for k,v in pairs(old_namespaces) do namespaces[k] = v; end
+ end
+ local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags };
for i=1,#stanza do
local child = stanza[i];
if child.name then
- child = clone(child);
+ child = _clone(child);
t_insert(tags, child);
end
t_insert(new, child);
@@ -359,6 +408,13 @@ local function clone(stanza)
return setmetatable(new, stanza_mt);
end
+local function clone(stanza)
+ if not is_stanza(stanza) then
+ error("bad argument to clone: expected stanza, got "..type(stanza));
+ end
+ return _clone(stanza);
+end
+
local function message(attr, body)
if not body then
return new_stanza("message", attr);
@@ -367,12 +423,20 @@ local function message(attr, body)
end
end
local function iq(attr)
- if attr and not attr.id then attr.id = new_id(); end
- return new_stanza("iq", attr or { id = new_id() });
+ if not (attr and attr.id) then
+ error("iq stanzas require an id attribute");
+ end
+ return new_stanza("iq", attr);
end
local function reply(orig)
- return new_stanza(orig.name, orig.attr and { to = orig.attr.from, from = orig.attr.to, id = orig.attr.id, type = ((orig.name == "iq" and "result") or orig.attr.type) });
+ return new_stanza(orig.name,
+ orig.attr and {
+ to = orig.attr.from,
+ from = orig.attr.to,
+ id = orig.attr.id,
+ type = ((orig.name == "iq" and "result") or orig.attr.type)
+ });
end
local xmpp_stanzas_attr = { xmlns = xmlns_stanzas };
@@ -433,7 +497,6 @@ return {
stanza_mt = stanza_mt;
stanza = new_stanza;
is_stanza = is_stanza;
- new_id = new_id;
preserialize = preserialize;
deserialize = deserialize;
clone = clone;