aboutsummaryrefslogtreecommitdiffstats
path: root/util/stanza.lua
diff options
context:
space:
mode:
Diffstat (limited to 'util/stanza.lua')
-rw-r--r--util/stanza.lua85
1 files changed, 74 insertions, 11 deletions
diff --git a/util/stanza.lua b/util/stanza.lua
index 07365144..11398179 100644
--- a/util/stanza.lua
+++ b/util/stanza.lua
@@ -7,6 +7,7 @@
--
+local error = error;
local t_insert = table.insert;
local t_remove = table.remove;
local t_concat = table.concat;
@@ -23,6 +24,8 @@ local s_sub = string.sub;
local s_find = string.find;
local os = os;
+local valid_utf8 = require "util.encodings".utf8.valid;
+
local do_pretty_printing = not os.getenv("WINDIR");
local getstyle, getstring;
if do_pretty_printing then
@@ -37,12 +40,52 @@ end
local xmlns_stanzas = "urn:ietf:params:xml:ns:xmpp-stanzas";
local _ENV = nil;
+-- luacheck: std none
-local stanza_mt = { __type = "stanza" };
+local stanza_mt = { __name = "stanza" };
stanza_mt.__index = stanza_mt;
-local function new_stanza(name, attr)
- local stanza = { name = name, attr = attr or {}, tags = {} };
+local function check_name(name, name_type)
+ if type(name) ~= "string" then
+ error("invalid "..name_type.." name: expected string, got "..type(name));
+ elseif #name == 0 then
+ error("invalid "..name_type.." name: empty string");
+ elseif s_find(name, "[<>& '\"]") then
+ error("invalid "..name_type.." name: contains invalid characters");
+ elseif not valid_utf8(name) then
+ error("invalid "..name_type.." name: contains invalid utf8");
+ end
+end
+
+local function check_text(text, text_type)
+ if type(text) ~= "string" then
+ error("invalid "..text_type.." value: expected string, got "..type(text));
+ elseif not valid_utf8(text) then
+ error("invalid "..text_type.." value: contains invalid utf8");
+ end
+end
+
+local function check_attr(attr)
+ if attr ~= nil then
+ if type(attr) ~= "table" then
+ error("invalid attributes, expected table got "..type(attr));
+ end
+ for k, v in pairs(attr) do
+ check_name(k, "attribute");
+ check_text(v, "attribute");
+ if type(v) ~= "string" then
+ error("invalid attribute value for '"..k.."': expected string, got "..type(v));
+ elseif not valid_utf8(v) then
+ error("invalid attribute value for '"..k.."': contains invalid utf8");
+ end
+ end
+ end
+end
+
+local function new_stanza(name, attr, namespaces)
+ check_name(name, "tag");
+ check_attr(attr);
+ local stanza = { name = name, attr = attr or {}, namespaces = namespaces, tags = {} };
return setmetatable(stanza, stanza_mt);
end
@@ -58,8 +101,12 @@ function stanza_mt:body(text, attr)
return self:tag("body", attr):text(text);
end
-function stanza_mt:tag(name, attrs)
- local s = new_stanza(name, attrs);
+function stanza_mt:text_tag(name, text, attr, namespaces)
+ return self:tag(name, attr, namespaces):text(text):up();
+end
+
+function stanza_mt:tag(name, attr, namespaces)
+ local s = new_stanza(name, attr, namespaces);
local last_add = self.last_add;
if not last_add then last_add = {}; self.last_add = last_add; end
(last_add[#last_add] or self):add_direct_child(s);
@@ -68,8 +115,10 @@ function stanza_mt:tag(name, attrs)
end
function stanza_mt:text(text)
- local last_add = self.last_add;
- (last_add and last_add[#last_add] or self):add_direct_child(text);
+ if text ~= nil and text ~= "" then
+ local last_add = self.last_add;
+ (last_add and last_add[#last_add] or self):add_direct_child(text);
+ end
return self;
end
@@ -85,10 +134,13 @@ function stanza_mt:reset()
end
function stanza_mt:add_direct_child(child)
- if type(child) == "table" then
+ if is_stanza(child) then
t_insert(self.tags, child);
+ t_insert(self, child);
+ else
+ check_text(child, "text");
+ t_insert(self, child);
end
- t_insert(self, child);
end
function stanza_mt:add_child(child)
@@ -347,7 +399,12 @@ end
local function clone(stanza)
local attr, tags = {}, {};
for k,v in pairs(stanza.attr) do attr[k] = v; end
- local new = { name = stanza.name, attr = attr, tags = tags };
+ local old_namespaces, namespaces = stanza.namespaces;
+ if old_namespaces then
+ namespaces = {};
+ for k,v in pairs(old_namespaces) do namespaces[k] = v; end
+ end
+ local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags };
for i=1,#stanza do
local child = stanza[i];
if child.name then
@@ -372,7 +429,13 @@ local function iq(attr)
end
local function reply(orig)
- return new_stanza(orig.name, orig.attr and { to = orig.attr.from, from = orig.attr.to, id = orig.attr.id, type = ((orig.name == "iq" and "result") or orig.attr.type) });
+ return new_stanza(orig.name,
+ orig.attr and {
+ to = orig.attr.from,
+ from = orig.attr.to,
+ id = orig.attr.id,
+ type = ((orig.name == "iq" and "result") or orig.attr.type)
+ });
end
local xmpp_stanzas_attr = { xmlns = xmlns_stanzas };