aboutsummaryrefslogtreecommitdiffstats
path: root/util
diff options
context:
space:
mode:
Diffstat (limited to 'util')
-rw-r--r--util/adhoc.lua2
-rw-r--r--util/array.lua2
-rw-r--r--util/async.lua175
-rw-r--r--util/cache.lua26
-rw-r--r--util/datamanager.lua9
-rw-r--r--util/debug.lua21
-rw-r--r--util/dependencies.lua10
-rw-r--r--util/envload.lua2
-rw-r--r--util/events.lua13
-rw-r--r--util/format.lua20
-rw-r--r--util/import.lua2
-rw-r--r--util/indexedbheap.lua157
-rw-r--r--util/ip.lua250
-rw-r--r--util/iterators.lua4
-rw-r--r--util/json.lua5
-rw-r--r--util/multitable.lua4
-rw-r--r--util/openssl.lua2
-rw-r--r--util/pluginloader.lua1
-rw-r--r--util/pubsub.lua123
-rw-r--r--util/random.lua4
-rw-r--r--util/sasl.lua8
-rw-r--r--util/sasl/anonymous.lua2
-rw-r--r--util/sasl/scram.lua21
-rw-r--r--util/set.lua2
-rw-r--r--util/sql.lua49
-rw-r--r--util/stanza.lua25
-rw-r--r--util/template.lua4
-rw-r--r--util/timer.lua142
-rw-r--r--util/vcard.lua572
-rw-r--r--util/xml.lua23
-rw-r--r--util/xmppstream.lua14
31 files changed, 1367 insertions, 327 deletions
diff --git a/util/adhoc.lua b/util/adhoc.lua
index 17c9eee5..d81b8242 100644
--- a/util/adhoc.lua
+++ b/util/adhoc.lua
@@ -1,3 +1,5 @@
+-- luacheck: ignore 212/self
+
local function new_simple_form(form, result_handler)
return function(self, data, state)
if state then
diff --git a/util/array.lua b/util/array.lua
index 150b4355..1a8ffec7 100644
--- a/util/array.lua
+++ b/util/array.lua
@@ -19,7 +19,7 @@ local type = type;
local array = {};
local array_base = {};
local array_methods = {};
-local array_mt = { __index = array_methods, __tostring = function (self) return "{"..self:concat(", ").."}"; end };
+local array_mt = { __index = array_methods, __name = "array", __tostring = function (self) return "{"..self:concat(", ").."}"; end };
local function new_array(self, t, _s, _var)
if type(t) == "function" then -- Assume iterator
diff --git a/util/async.lua b/util/async.lua
new file mode 100644
index 00000000..992797b8
--- /dev/null
+++ b/util/async.lua
@@ -0,0 +1,175 @@
+local log = require "util.logger".init("util.async");
+
+local function checkthread()
+ local thread, main = coroutine.running();
+ if not thread or main then
+ error("Not running in an async context, see https://prosody.im/doc/developers/util/async");
+ end
+ return thread;
+end
+
+local function runner_continue(thread)
+ -- ASSUMPTION: runner is in 'waiting' state (but we don't have the runner to know for sure)
+ if coroutine.status(thread) ~= "suspended" then -- This should suffice
+ return false;
+ end
+ local ok, state, runner = coroutine.resume(thread);
+ if not ok then
+ -- Running the coroutine failed, which means we have to find the runner manually,
+ -- in order to inform the error handler
+ local level = 0;
+ while debug.getinfo(thread, level, "") do level = level + 1; end
+ ok, runner = debug.getlocal(thread, level-1, 1);
+ local error_handler = runner.watchers.error;
+ if error_handler then error_handler(runner, debug.traceback(thread, state)); end
+ elseif state == "ready" then
+ -- If state is 'ready', it is our responsibility to update runner.state from 'waiting'.
+ -- We also have to :run(), because the queue might have further items that will not be
+ -- processed otherwise. FIXME: It's probably best to do this in a nexttick (0 timer).
+ runner.state = "ready";
+ runner:run();
+ end
+ return true;
+end
+
+local function waiter(num)
+ local thread = checkthread();
+ num = num or 1;
+ local waiting;
+ return function ()
+ if num == 0 then return; end -- already done
+ waiting = true;
+ coroutine.yield("wait");
+ end, function ()
+ num = num - 1;
+ if num == 0 and waiting then
+ runner_continue(thread);
+ elseif num < 0 then
+ error("done() called too many times");
+ end
+ end;
+end
+
+local function guarder()
+ local guards = {};
+ return function (id, func)
+ local thread = checkthread();
+ local guard = guards[id];
+ if not guard then
+ guard = {};
+ guards[id] = guard;
+ log("debug", "New guard!");
+ else
+ table.insert(guard, thread);
+ log("debug", "Guarded. %d threads waiting.", #guard)
+ coroutine.yield("wait");
+ end
+ local function exit()
+ local next_waiting = table.remove(guard, 1);
+ if next_waiting then
+ log("debug", "guard: Executing next waiting thread (%d left)", #guard)
+ runner_continue(next_waiting);
+ else
+ log("debug", "Guard off duty.")
+ guards[id] = nil;
+ end
+ end
+ if func then
+ func();
+ exit();
+ return;
+ end
+ return exit;
+ end;
+end
+
+local runner_mt = {};
+runner_mt.__index = runner_mt;
+
+local function runner_create_thread(func, self)
+ local thread = coroutine.create(function (self) -- luacheck: ignore 432/self
+ while true do
+ func(coroutine.yield("ready", self));
+ end
+ end);
+ assert(coroutine.resume(thread, self)); -- Start it up, it will return instantly to wait for the first input
+ return thread;
+end
+
+local empty_watchers = {};
+local function runner(func, watchers, data)
+ return setmetatable({ func = func, thread = false, state = "ready", notified_state = "ready",
+ queue = {}, watchers = watchers or empty_watchers, data = data }
+ , runner_mt);
+end
+
+-- Add a task item for the runner to process
+function runner_mt:run(input)
+ if input ~= nil then
+ table.insert(self.queue, input);
+ end
+ if self.state ~= "ready" then
+ -- The runner is busy. Indicate that the task item has been
+ -- queued, and return information about the current runner state
+ return true, self.state, #self.queue;
+ end
+
+ local q, thread = self.queue, self.thread;
+ if not thread or coroutine.status(thread) == "dead" then
+ -- Create a new coroutine for this runner
+ thread = runner_create_thread(self.func, self);
+ self.thread = thread;
+ end
+
+ -- Process task item(s) while the queue is not empty, and we're not blocked
+ local n, state, err = #q, self.state, nil;
+ self.state = "running";
+ while n > 0 and state == "ready" do
+ local consumed;
+ -- Loop through queue items, and attempt to run them
+ for i = 1,n do
+ local queued_input = q[i];
+ local ok, new_state = coroutine.resume(thread, queued_input);
+ if not ok then
+ -- There was an error running the coroutine, save the error, mark runner as ready to begin again
+ consumed, state, err = i, "ready", debug.traceback(thread, new_state);
+ self.thread = nil;
+ break;
+ elseif new_state == "wait" then
+ -- Runner is blocked on waiting for a task item to complete
+ consumed, state = i, "waiting";
+ break;
+ end
+ end
+ -- Loop ended - either queue empty because all tasks passed without blocking (consumed == nil)
+ -- or runner is blocked/errored, and consumed will contain the number of tasks processed so far
+ if not consumed then consumed = n; end
+ -- Remove consumed items from the queue array
+ if q[n+1] ~= nil then
+ n = #q;
+ end
+ for i = 1, n do
+ q[i] = q[consumed+i];
+ end
+ n = #q;
+ end
+ -- Runner processed all items it can, so save current runner state
+ self.state = state;
+ if err or state ~= self.notified_state then
+ if err then
+ state = "error"
+ else
+ self.notified_state = state;
+ end
+ local handler = self.watchers[state];
+ if handler then handler(self, err); end
+ end
+ return true, state, n;
+end
+
+-- Add a task item to the queue without invoking the runner, even if it is idle
+function runner_mt:enqueue(input)
+ table.insert(self.queue, input);
+end
+
+return { waiter = waiter, guarder = guarder, runner = runner };
diff --git a/util/cache.lua b/util/cache.lua
index 9c141bb6..a5fd5e6d 100644
--- a/util/cache.lua
+++ b/util/cache.lua
@@ -116,6 +116,25 @@ function cache_methods:tail()
return tail.key, tail.value;
end
+function cache_methods:resize(new_size)
+ new_size = assert(tonumber(new_size), "cache size must be a number");
+ new_size = math.floor(new_size);
+ assert(new_size > 0, "cache size must be greater than zero");
+ local on_evict = self._on_evict;
+ while self._count > new_size do
+ local tail = self._tail;
+ local evicted_key, evicted_value = tail.key, tail.value;
+ if on_evict ~= nil and (on_evict == false or on_evict(evicted_key, evicted_value) == false) then
+ -- Cache is full, and we're not allowed to evict
+ return false;
+ end
+ _remove(self, tail);
+ self._data[evicted_key] = nil;
+ end
+ self.size = new_size;
+ return true;
+end
+
function cache_methods:table()
--luacheck: ignore 212/t
if not self.proxy_table then
@@ -139,6 +158,13 @@ function cache_methods:table()
return self.proxy_table;
end
+function cache_methods:clear()
+ self._data = {};
+ self._count = 0;
+ self._head = nil;
+ self._tail = nil;
+end
+
local function new(size, on_evict)
size = assert(tonumber(size), "cache size must be a number");
size = math.floor(size);
diff --git a/util/datamanager.lua b/util/datamanager.lua
index bd8fb7bb..bd402b51 100644
--- a/util/datamanager.lua
+++ b/util/datamanager.lua
@@ -42,7 +42,7 @@ end);
local _ENV = nil;
---- utils -----
-local encode, decode;
+local encode, decode, store_encode;
do
local urlcodes = setmetatable({}, { __index = function (t, k) t[k] = char(tonumber(k, 16)); return t[k]; end });
@@ -53,6 +53,12 @@ do
encode = function (s)
return s and (s:gsub("%W", function (c) return format("%%%02x", c:byte()); end));
end
+
+ -- Special encode function for store names, which historically were unencoded.
+ -- All currently known stores use a-z and underscore, so this one preserves underscores.
+ store_encode = function (s)
+ return s and (s:gsub("[^_%w]", function (c) return format("%%%02x", c:byte()); end));
+ end
end
if not atomic_append then
@@ -119,6 +125,7 @@ local function getpath(username, host, datastore, ext, create)
ext = ext or "dat";
host = (host and encode(host)) or "_global";
username = username and encode(username);
+ datastore = store_encode(datastore);
if username then
if create then mkdir(mkdir(mkdir(data_path).."/"..host).."/"..datastore); end
return format("%s/%s/%s/%s.%s", data_path, host, datastore, username, ext);
diff --git a/util/debug.lua b/util/debug.lua
index 00f476d0..9a28395a 100644
--- a/util/debug.lua
+++ b/util/debug.lua
@@ -47,6 +47,7 @@ local function get_upvalues_table(func)
for upvalue_num = 1, math.huge do
local name, value = debug.getupvalue(func, upvalue_num);
if not name then break; end
+ if name == "" then name = ("[%d]"):format(upvalue_num); end
table.insert(upvalues, { name = name, value = value });
end
end
@@ -112,7 +113,9 @@ end
local function build_source_boundary_marker(last_source_desc)
local padding = string.rep("-", math.floor(((optimal_line_length - 6) - #last_source_desc)/2));
- return getstring(styles.boundary_padding, "v"..padding).." "..getstring(styles.filename, last_source_desc).." "..getstring(styles.boundary_padding, padding..(#last_source_desc%2==0 and "-v" or "v "));
+ return getstring(styles.boundary_padding, "v"..padding).." "..
+ getstring(styles.filename, last_source_desc).." "..
+ getstring(styles.boundary_padding, padding..(#last_source_desc%2==0 and "-v" or "v "));
end
local function _traceback(thread, message, level)
@@ -142,9 +145,9 @@ local function _traceback(thread, message, level)
local last_source_desc;
local lines = {};
- for nlevel, level in ipairs(levels) do
- local info = level.info;
- local line = "...";
+ for nlevel, current_level in ipairs(levels) do
+ local info = current_level.info;
+ local line;
local func_type = info.namewhat.." ";
local source_desc = (info.short_src == "[C]" and "C code") or info.short_src or "Unknown";
if func_type == " " then func_type = ""; end;
@@ -160,7 +163,9 @@ local function _traceback(thread, message, level)
if func_type == "global " or func_type == "local " then
func_type = func_type.."function ";
end
- line = "[Lua] "..getstring(styles.location, info.short_src.." line "..info.currentline).." in "..func_type..getstring(styles.funcname, name).." (defined on line "..info.linedefined..")";
+ line = "[Lua] "..getstring(styles.location, info.short_src.." line "..
+ info.currentline).." in "..func_type..getstring(styles.funcname, name)..
+ " (defined on line "..info.linedefined..")";
end
if source_desc ~= last_source_desc then -- Venturing into a new source, add marker for previous
last_source_desc = source_desc;
@@ -169,13 +174,13 @@ local function _traceback(thread, message, level)
nlevel = nlevel-1;
table.insert(lines, "\t"..(nlevel==0 and ">" or " ")..getstring(styles.level_num, "("..nlevel..") ")..line);
local npadding = (" "):rep(#tostring(nlevel));
- if level.locals then
- local locals_str = string_from_var_table(level.locals, optimal_line_length, "\t "..npadding);
+ if current_level.locals then
+ local locals_str = string_from_var_table(current_level.locals, optimal_line_length, "\t "..npadding);
if locals_str then
table.insert(lines, "\t "..npadding.."Locals: "..locals_str);
end
end
- local upvalues_str = string_from_var_table(level.upvalues, optimal_line_length, "\t "..npadding);
+ local upvalues_str = string_from_var_table(current_level.upvalues, optimal_line_length, "\t "..npadding);
if upvalues_str then
table.insert(lines, "\t "..npadding.."Upvals: "..upvalues_str);
end
diff --git a/util/dependencies.lua b/util/dependencies.lua
index de840241..9b0afd77 100644
--- a/util/dependencies.lua
+++ b/util/dependencies.lua
@@ -28,7 +28,7 @@ local function missingdep(name, sources, msg)
end
print("");
print(msg or (name.." is required for Prosody to run, so we will now exit."));
- print("More help can be found on our website, at http://prosody.im/doc/depends");
+ print("More help can be found on our website, at https://prosody.im/doc/depends");
print("**************************");
print("");
end
@@ -40,7 +40,7 @@ end
package.preload["util.ztact"] = function ()
if not package.loaded["core.loggingmanager"] then
error("util.ztact has been removed from Prosody and you need to fix your config "
- .."file. More information can be found at http://prosody.im/doc/packagers#ztact", 0);
+ .."file. More information can be found at https://prosody.im/doc/packagers#ztact", 0);
else
error("module 'util.ztact' has been deprecated in Prosody 0.8.");
end
@@ -156,7 +156,7 @@ local function log_warnings()
if ssl then
local major, minor, veryminor, patched = ssl._VERSION:match("(%d+)%.(%d+)%.?(%d*)(M?)");
if not major or ((tonumber(major) == 0 and (tonumber(minor) or 0) <= 3 and (tonumber(veryminor) or 0) <= 2) and patched ~= "M") then
- prosody.log("error", "This version of LuaSec contains a known bug that causes disconnects, see http://prosody.im/doc/depends");
+ prosody.log("error", "This version of LuaSec contains a known bug that causes disconnects, see https://prosody.im/doc/depends");
end
end
local lxp = softreq"lxp";
@@ -165,7 +165,7 @@ local function log_warnings()
prosody.log("error", "The version of LuaExpat on your system leaves Prosody "
.."vulnerable to denial-of-service attacks. You should upgrade to "
.."LuaExpat 1.3.0 or higher as soon as possible. See "
- .."http://prosody.im/doc/depends#luaexpat for more information.");
+ .."https://prosody.im/doc/depends#luaexpat for more information.");
end
if not lxp.new({}).getcurrentbytecount then
prosody.log("error", "The version of LuaExpat on your system does not support "
@@ -173,7 +173,7 @@ local function log_warnings()
.."networks (e.g. the internet) vulnerable to denial-of-service "
.."attacks. You should upgrade to LuaExpat 1.3.0 or higher as "
.."soon as possible. See "
- .."http://prosody.im/doc/depends#luaexpat for more information.");
+ .."https://prosody.im/doc/depends#luaexpat for more information.");
end
end
end
diff --git a/util/envload.lua b/util/envload.lua
index 926f20c0..6182a1f9 100644
--- a/util/envload.lua
+++ b/util/envload.lua
@@ -4,7 +4,7 @@
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
--- luacheck: ignore 113/setfenv
+-- luacheck: ignore 113/setfenv 113/loadstring
local load, loadstring, setfenv = load, loadstring, setfenv;
local io_open = io.open;
diff --git a/util/events.lua b/util/events.lua
index e2943e44..a71d118f 100644
--- a/util/events.lua
+++ b/util/events.lua
@@ -26,7 +26,7 @@ local function new()
-- Event map: event_map[handler_function] = priority_number
local event_map = {};
-- Called on-demand to build handlers entries
- local function _rebuild_index(handlers, event)
+ local function _rebuild_index(self, event)
local _handlers = event_map[event];
if not _handlers or next(_handlers) == nil then return; end
local index = {};
@@ -34,7 +34,7 @@ local function new()
t_insert(index, handler);
end
t_sort(index, function(a, b) return _handlers[a] > _handlers[b]; end);
- handlers[event] = index;
+ self[event] = index;
return index;
end;
setmetatable(handlers, { __index = _rebuild_index });
@@ -61,13 +61,13 @@ local function new()
local function get_handlers(event)
return handlers[event];
end;
- local function add_handlers(handlers)
- for event, handler in pairs(handlers) do
+ local function add_handlers(self)
+ for event, handler in pairs(self) do
add_handler(event, handler);
end
end;
- local function remove_handlers(handlers)
- for event, handler in pairs(handlers) do
+ local function remove_handlers(self)
+ for event, handler in pairs(self) do
remove_handler(event, handler);
end
end;
@@ -81,6 +81,7 @@ local function new()
end
end;
local function fire_event(event_name, event_data)
+ -- luacheck: ignore 432/event_name 432/event_data
local w = wrappers[event_name] or global_wrappers;
if w then
local curr_wrapper = #w;
diff --git a/util/format.lua b/util/format.lua
index 5f2b12be..c5e513fa 100644
--- a/util/format.lua
+++ b/util/format.lua
@@ -4,11 +4,10 @@
local tostring = tostring;
local select = select;
-local assert = assert;
-local unpack = unpack;
+local unpack = table.unpack or unpack; -- luacheck: ignore 113/unpack
local type = type;
-local function format(format, ...)
+local function format(formatstring, ...)
local args, args_length = { ... }, select('#', ...);
-- format specifier spec:
@@ -25,7 +24,7 @@ local function format(format, ...)
-- process each format specifier
local i = 0;
- format = format:gsub("%%[^cdiouxXaAeEfgGqs%%]*[cdiouxXaAeEfgGqs%%]", function(spec)
+ formatstring = formatstring:gsub("%%[^cdiouxXaAeEfgGqs%%]*[cdiouxXaAeEfgGqs%%]", function(spec)
if spec ~= "%%" then
i = i + 1;
local arg = args[i];
@@ -54,21 +53,12 @@ local function format(format, ...)
else
args[i] = tostring(arg);
end
- format = format .. " [%s]"
+ formatstring = formatstring .. " [%s]"
end
- return format:format(unpack(args));
-end
-
-local function test()
- assert(format("%s", "hello") == "hello");
- assert(format("%s") == "<nil>");
- assert(format("%s", true) == "true");
- assert(format("%d", true) == "[true]");
- assert(format("%%", true) == "% [true]");
+ return formatstring:format(unpack(args));
end
return {
format = format;
- test = test;
};
diff --git a/util/import.lua b/util/import.lua
index c2b9dce1..12d1957f 100644
--- a/util/import.lua
+++ b/util/import.lua
@@ -8,7 +8,7 @@
-local unpack = table.unpack or unpack; --luacheck: ignore 113
+local unpack = table.unpack or unpack; --luacheck: ignore 113 143
local t_insert = table.insert;
function import(module, ...)
local m = package.loaded[module] or require(module);
diff --git a/util/indexedbheap.lua b/util/indexedbheap.lua
new file mode 100644
index 00000000..7f193d54
--- /dev/null
+++ b/util/indexedbheap.lua
@@ -0,0 +1,157 @@
+
+local setmetatable = setmetatable;
+local math_floor = math.floor;
+local t_remove = table.remove;
+
+local function _heap_insert(self, item, sync, item2, index)
+ local pos = #self + 1;
+ while true do
+ local half_pos = math_floor(pos / 2);
+ if half_pos == 0 or item > self[half_pos] then break; end
+ self[pos] = self[half_pos];
+ sync[pos] = sync[half_pos];
+ index[sync[pos]] = pos;
+ pos = half_pos;
+ end
+ self[pos] = item;
+ sync[pos] = item2;
+ index[item2] = pos;
+end
+
+local function _percolate_up(self, k, sync, index)
+ local tmp = self[k];
+ local tmp_sync = sync[k];
+ while k ~= 1 do
+ local parent = math_floor(k/2);
+ if tmp < self[parent] then break; end
+ self[k] = self[parent];
+ sync[k] = sync[parent];
+ index[sync[k]] = k;
+ k = parent;
+ end
+ self[k] = tmp;
+ sync[k] = tmp_sync;
+ index[tmp_sync] = k;
+ return k;
+end
+
+local function _percolate_down(self, k, sync, index)
+ local tmp = self[k];
+ local tmp_sync = sync[k];
+ local size = #self;
+ local child = 2*k;
+ while 2*k <= size do
+ if child ~= size and self[child] > self[child + 1] then
+ child = child + 1;
+ end
+ if tmp > self[child] then
+ self[k] = self[child];
+ sync[k] = sync[child];
+ index[sync[k]] = k;
+ else
+ break;
+ end
+
+ k = child;
+ child = 2*k;
+ end
+ self[k] = tmp;
+ sync[k] = tmp_sync;
+ index[tmp_sync] = k;
+ return k;
+end
+
+local function _heap_pop(self, sync, index)
+ local size = #self;
+ if size == 0 then return nil; end
+
+ local result = self[1];
+ local result_sync = sync[1];
+ index[result_sync] = nil;
+ if size == 1 then
+ self[1] = nil;
+ sync[1] = nil;
+ return result, result_sync;
+ end
+ self[1] = t_remove(self);
+ sync[1] = t_remove(sync);
+ index[sync[1]] = 1;
+
+ _percolate_down(self, 1, sync, index);
+
+ return result, result_sync;
+end
+
+local indexed_heap = {};
+
+function indexed_heap:insert(item, priority, id)
+ if id == nil then
+ id = self.current_id;
+ self.current_id = id + 1;
+ end
+ self.items[id] = item;
+ _heap_insert(self.priorities, priority, self.ids, id, self.index);
+ return id;
+end
+function indexed_heap:pop()
+ local priority, id = _heap_pop(self.priorities, self.ids, self.index);
+ if id then
+ local item = self.items[id];
+ self.items[id] = nil;
+ return priority, item, id;
+ end
+end
+function indexed_heap:peek()
+ return self.priorities[1];
+end
+function indexed_heap:reprioritize(id, priority)
+ local k = self.index[id];
+ if k == nil then return; end
+ self.priorities[k] = priority;
+
+ k = _percolate_up(self.priorities, k, self.ids, self.index);
+ _percolate_down(self.priorities, k, self.ids, self.index);
+end
+function indexed_heap:remove_index(k)
+ local result = self.priorities[k];
+ if result == nil then return; end
+
+ local result_sync = self.ids[k];
+ local item = self.items[result_sync];
+ local size = #self.priorities;
+
+ self.priorities[k] = self.priorities[size];
+ self.ids[k] = self.ids[size];
+ self.index[self.ids[k]] = k;
+
+ t_remove(self.priorities);
+ t_remove(self.ids);
+
+ self.index[result_sync] = nil;
+ self.items[result_sync] = nil;
+
+ if size > k then
+ k = _percolate_up(self.priorities, k, self.ids, self.index);
+ _percolate_down(self.priorities, k, self.ids, self.index);
+ end
+
+ return result, item, result_sync;
+end
+function indexed_heap:remove(id)
+ return self:remove_index(self.index[id]);
+end
+
+local mt = { __index = indexed_heap };
+
+local _M = {
+ create = function()
+ return setmetatable({
+ ids = {}; -- heap of ids, sync'd with priorities
+ items = {}; -- map id->items
+ priorities = {}; -- heap of priorities
+ index = {}; -- map of id->index of id in ids
+ current_id = 1.5
+ }, mt);
+ end
+};
+return _M;
diff --git a/util/ip.lua b/util/ip.lua
index 81a98ef7..0ec9e297 100644
--- a/util/ip.lua
+++ b/util/ip.lua
@@ -5,69 +5,76 @@
-- COPYING file in the source package for more information.
--
+local net = require "util.net";
+local hex = require "util.hex";
+
local ip_methods = {};
-local ip_mt = { __index = function (ip, key) return (ip_methods[key])(ip); end,
- __tostring = function (ip) return ip.addr; end,
- __eq = function (ipA, ipB) return ipA.addr == ipB.addr; end};
-local hex2bits = { ["0"] = "0000", ["1"] = "0001", ["2"] = "0010", ["3"] = "0011", ["4"] = "0100", ["5"] = "0101", ["6"] = "0110", ["7"] = "0111", ["8"] = "1000", ["9"] = "1001", ["A"] = "1010", ["B"] = "1011", ["C"] = "1100", ["D"] = "1101", ["E"] = "1110", ["F"] = "1111" };
+
+local ip_mt = {
+ __index = function (ip, key)
+ local method = ip_methods[key];
+ if not method then return nil; end
+ local ret = method(ip);
+ ip[key] = ret;
+ return ret;
+ end,
+ __tostring = function (ip) return ip.addr; end,
+ __eq = function (ipA, ipB) return ipA.packed == ipB.packed; end
+};
+
+local hex2bits = {
+ ["0"] = "0000", ["1"] = "0001", ["2"] = "0010", ["3"] = "0011",
+ ["4"] = "0100", ["5"] = "0101", ["6"] = "0110", ["7"] = "0111",
+ ["8"] = "1000", ["9"] = "1001", ["A"] = "1010", ["B"] = "1011",
+ ["C"] = "1100", ["D"] = "1101", ["E"] = "1110", ["F"] = "1111",
+};
local function new_ip(ipStr, proto)
- if not proto then
- local sep = ipStr:match("^%x+(.)");
- if sep == ":" or (not(sep) and ipStr:sub(1,1) == ":") then
- proto = "IPv6"
- elseif sep == "." then
- proto = "IPv4"
- end
- if not proto then
- return nil, "invalid address";
- end
- elseif proto ~= "IPv4" and proto ~= "IPv6" then
- return nil, "invalid protocol";
- end
local zone;
- if proto == "IPv6" and ipStr:find('%', 1, true) then
+ if (not proto or proto == "IPv6") and ipStr:find('%', 1, true) then
ipStr, zone = ipStr:match("^(.-)%%(.*)");
end
- if proto == "IPv6" and ipStr:find('.', 1, true) then
- local changed;
- ipStr, changed = ipStr:gsub(":(%d+)%.(%d+)%.(%d+)%.(%d+)$", function(a,b,c,d)
- return (":%04X:%04X"):format(a*256+b,c*256+d);
- end);
- if changed ~= 1 then return nil, "invalid-address"; end
+
+ local packed, err = net.pton(ipStr);
+ if not packed then return packed, err end
+ if proto == "IPv6" and #packed ~= 16 then
+ return nil, "invalid-ipv6";
+ elseif proto == "IPv4" and #packed ~= 4 then
+ return nil, "invalid-ipv4";
+ elseif not proto then
+ if #packed == 16 then
+ proto = "IPv6";
+ elseif #packed == 4 then
+ proto = "IPv4";
+ else
+ return nil, "unknown protocol";
+ end
+ elseif proto ~= "IPv6" and proto ~= "IPv4" then
+ return nil, "invalid protocol";
end
- return setmetatable({ addr = ipStr, proto = proto, zone = zone }, ip_mt);
+ return setmetatable({ addr = ipStr, packed = packed, proto = proto, zone = zone }, ip_mt);
+end
+
+function ip_methods:normal()
+ return net.ntop(self.packed);
end
-local function toBits(ip)
- local result = "";
- local fields = {};
+function ip_methods.bits(ip)
+ return hex.to(ip.packed):upper():gsub(".", hex2bits);
+end
+
+function ip_methods.bits_full(ip)
if ip.proto == "IPv4" then
ip = ip.toV4mapped;
end
- ip = (ip.addr):upper();
- ip:gsub("([^:]*):?", function (c) fields[#fields + 1] = c end);
- if not ip:match(":$") then fields[#fields] = nil; end
- for i, field in ipairs(fields) do
- if field:len() == 0 and i ~= 1 and i ~= #fields then
- for _ = 1, 16 * (9 - #fields) do
- result = result .. "0";
- end
- else
- for _ = 1, 4 - field:len() do
- result = result .. "0000";
- end
- for j = 1, field:len() do
- result = result .. hex2bits[field:sub(j, j)];
- end
- end
- end
- return result;
+ return ip.bits;
end
+local match;
+
local function commonPrefixLength(ipA, ipB)
- ipA, ipB = toBits(ipA), toBits(ipB);
+ ipA, ipB = ipA.bits_full, ipB.bits_full;
for i = 1, 128 do
if ipA:sub(i,i) ~= ipB:sub(i,i) then
return i-1;
@@ -76,56 +83,60 @@ local function commonPrefixLength(ipA, ipB)
return 128;
end
+-- Instantiate once
+local loopback = new_ip("::1");
+local loopback4 = new_ip("127.0.0.0");
+local sixtofour = new_ip("2002::");
+local teredo = new_ip("2001::");
+local linklocal = new_ip("fe80::");
+local linklocal4 = new_ip("169.254.0.0");
+local uniquelocal = new_ip("fc00::");
+local sitelocal = new_ip("fec0::");
+local sixbone = new_ip("3ffe::");
+local defaultunicast = new_ip("::");
+local multicast = new_ip("ff00::");
+local ipv6mapped = new_ip("::ffff:0:0");
+
local function v4scope(ip)
- local fields = {};
- ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end);
- -- Loopback:
- if fields[1] == 127 then
+ if match(ip, loopback4, 8) then
return 0x2;
- -- Link-local unicast:
- elseif fields[1] == 169 and fields[2] == 254 then
+ elseif match(ip, linklocal4) then
return 0x2;
- -- Global unicast:
- else
+ else -- Global unicast
return 0xE;
end
end
local function v6scope(ip)
- -- Loopback:
- if ip:match("^[0:]*1$") then
+ if ip == loopback then
return 0x2;
- -- Link-local unicast:
- elseif ip:match("^[Ff][Ee][89ABab]") then
+ elseif match(ip, linklocal, 10) then
return 0x2;
- -- Site-local unicast:
- elseif ip:match("^[Ff][Ee][CcDdEeFf]") then
+ elseif match(ip, sitelocal, 10) then
return 0x5;
- -- Multicast:
- elseif ip:match("^[Ff][Ff]") then
- return tonumber("0x"..ip:sub(4,4));
- -- Global unicast:
- else
+ elseif match(ip, multicast, 10) then
+ return ip.packed:byte(2) % 0x10;
+ else -- Global unicast
return 0xE;
end
end
local function label(ip)
- if commonPrefixLength(ip, new_ip("::1", "IPv6")) == 128 then
+ if ip == loopback then
return 0;
- elseif commonPrefixLength(ip, new_ip("2002::", "IPv6")) >= 16 then
+ elseif match(ip, sixtofour, 16) then
return 2;
- elseif commonPrefixLength(ip, new_ip("2001::", "IPv6")) >= 32 then
+ elseif match(ip, teredo, 32) then
return 5;
- elseif commonPrefixLength(ip, new_ip("fc00::", "IPv6")) >= 7 then
+ elseif match(ip, uniquelocal, 7) then
return 13;
- elseif commonPrefixLength(ip, new_ip("fec0::", "IPv6")) >= 10 then
+ elseif match(ip, sitelocal, 10) then
return 11;
- elseif commonPrefixLength(ip, new_ip("3ffe::", "IPv6")) >= 16 then
+ elseif match(ip, sixbone, 16) then
return 12;
- elseif commonPrefixLength(ip, new_ip("::", "IPv6")) >= 96 then
+ elseif match(ip, defaultunicast, 96) then
return 3;
- elseif commonPrefixLength(ip, new_ip("::ffff:0:0", "IPv6")) >= 96 then
+ elseif match(ip, ipv6mapped, 96) then
return 4;
else
return 1;
@@ -133,91 +144,67 @@ local function label(ip)
end
local function precedence(ip)
- if commonPrefixLength(ip, new_ip("::1", "IPv6")) == 128 then
+ if ip == loopback then
return 50;
- elseif commonPrefixLength(ip, new_ip("2002::", "IPv6")) >= 16 then
+ elseif match(ip, sixtofour, 16) then
return 30;
- elseif commonPrefixLength(ip, new_ip("2001::", "IPv6")) >= 32 then
+ elseif match(ip, teredo, 32) then
return 5;
- elseif commonPrefixLength(ip, new_ip("fc00::", "IPv6")) >= 7 then
+ elseif match(ip, uniquelocal, 7) then
return 3;
- elseif commonPrefixLength(ip, new_ip("fec0::", "IPv6")) >= 10 then
+ elseif match(ip, sitelocal, 10) then
return 1;
- elseif commonPrefixLength(ip, new_ip("3ffe::", "IPv6")) >= 16 then
+ elseif match(ip, sixbone, 16) then
return 1;
- elseif commonPrefixLength(ip, new_ip("::", "IPv6")) >= 96 then
+ elseif match(ip, defaultunicast, 96) then
return 1;
- elseif commonPrefixLength(ip, new_ip("::ffff:0:0", "IPv6")) >= 96 then
+ elseif match(ip, ipv6mapped, 96) then
return 35;
else
return 40;
end
end
-local function toV4mapped(ip)
- local fields = {};
- local ret = "::ffff:";
- ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end);
- ret = ret .. ("%02x"):format(fields[1]);
- ret = ret .. ("%02x"):format(fields[2]);
- ret = ret .. ":"
- ret = ret .. ("%02x"):format(fields[3]);
- ret = ret .. ("%02x"):format(fields[4]);
- return new_ip(ret, "IPv6");
-end
-
function ip_methods:toV4mapped()
if self.proto ~= "IPv4" then return nil, "No IPv4 address" end
- local value = toV4mapped(self.addr);
- self.toV4mapped = value;
+ local value = new_ip("::ffff:" .. self.normal);
return value;
end
function ip_methods:label()
- local value;
if self.proto == "IPv4" then
- value = label(self.toV4mapped);
+ return label(self.toV4mapped);
else
- value = label(self);
+ return label(self);
end
- self.label = value;
- return value;
end
function ip_methods:precedence()
- local value;
if self.proto == "IPv4" then
- value = precedence(self.toV4mapped);
+ return precedence(self.toV4mapped);
else
- value = precedence(self);
+ return precedence(self);
end
- self.precedence = value;
- return value;
end
function ip_methods:scope()
- local value;
if self.proto == "IPv4" then
- value = v4scope(self.addr);
+ return v4scope(self);
else
- value = v6scope(self.addr);
+ return v6scope(self);
end
- self.scope = value;
- return value;
end
+local rfc1918_8 = new_ip("10.0.0.0");
+local rfc1918_12 = new_ip("172.16.0.0");
+local rfc1918_16 = new_ip("192.168.0.0");
+local rfc6598 = new_ip("100.64.0.0");
+
function ip_methods:private()
local private = self.scope ~= 0xE;
if not private and self.proto == "IPv4" then
- local ip = self.addr;
- local fields = {};
- ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end);
- if fields[1] == 127 or fields[1] == 10 or (fields[1] == 192 and fields[2] == 168)
- or (fields[1] == 172 and (fields[2] >= 16 or fields[2] <= 32)) then
- private = true;
- end
+ return match(self, rfc1918_8, 8) or match(self, rfc1918_12, 12) or match(self, rfc1918_16) or match(self, rfc6598, 10);
end
- self.private = private;
return private;
end
@@ -231,15 +218,26 @@ local function parse_cidr(cidr)
return new_ip(cidr), bits;
end
-local function match(ipA, ipB, bits)
- local common_bits = commonPrefixLength(ipA, ipB);
- if bits and ipB.proto == "IPv4" then
- common_bits = common_bits - 96; -- v6 mapped addresses always share these bits
+function match(ipA, ipB, bits)
+ if not bits or bits >= 128 or ipB.proto == "IPv4" and bits >= 32 then
+ return ipA == ipB;
+ elseif bits < 1 then
+ return true;
+ end
+ if ipA.proto ~= ipB.proto then
+ if ipA.proto == "IPv4" then
+ ipA = ipA.toV4mapped;
+ elseif ipB.proto == "IPv4" then
+ ipB = ipB.toV4mapped;
+ bits = bits + (128 - 32);
+ end
end
- return common_bits >= (bits or 128);
+ return ipA.bits:sub(1, bits) == ipB.bits:sub(1, bits);
end
-return {new_ip = new_ip,
+return {
+ new_ip = new_ip,
commonPrefixLength = commonPrefixLength,
parse_cidr = parse_cidr,
- match=match};
+ match = match,
+};
diff --git a/util/iterators.lua b/util/iterators.lua
index bd150ff2..a152e7be 100644
--- a/util/iterators.lua
+++ b/util/iterators.lua
@@ -12,8 +12,8 @@ local it = {};
local t_insert = table.insert;
local select, next = select, next;
-local unpack = table.unpack or unpack; --luacheck: ignore 113
-local pack = table.pack or function (...) return { n = select("#", ...), ... }; end
+local unpack = table.unpack or unpack; --luacheck: ignore 113 143
+local pack = table.pack or function (...) return { n = select("#", ...), ... }; end -- luacheck: ignore 143
-- Reverse an iterator
function it.reverse(f, s, var)
diff --git a/util/json.lua b/util/json.lua
index cba54e8e..c88d4c09 100644
--- a/util/json.lua
+++ b/util/json.lua
@@ -27,9 +27,6 @@ module.null = null;
local escapes = {
["\""] = "\\\"", ["\\"] = "\\\\", ["\b"] = "\\b",
["\f"] = "\\f", ["\n"] = "\\n", ["\r"] = "\\r", ["\t"] = "\\t"};
-local unescapes = {
- ["\""] = "\"", ["\\"] = "\\", ["/"] = "/",
- b = "\b", f = "\f", n = "\n", r = "\r", t = "\t"};
for i=0,31 do
local ch = s_char(i);
if not escapes[ch] then escapes[ch] = ("\\u%.4X"):format(i); end
@@ -249,7 +246,7 @@ local function _readarray(json, index)
end
end
local _unescape_error;
-local function _unescape_surrogate_func(x)
+local function _unescape_surrogate_func(x) -- luacheck: ignore
local lead, trail = tonumber(x:sub(3, 6), 16), tonumber(x:sub(9, 12), 16);
local codepoint = lead * 0x400 + trail - 0x35FDC00;
local a = codepoint % 64;
diff --git a/util/multitable.lua b/util/multitable.lua
index e4321d3d..b790dd7f 100644
--- a/util/multitable.lua
+++ b/util/multitable.lua
@@ -9,7 +9,7 @@
local select = select;
local t_insert = table.insert;
local pairs, next, type = pairs, next, type;
-local unpack = table.unpack or unpack; --luacheck: ignore 113
+local unpack = table.unpack or unpack; --luacheck: ignore 113 143
local _ENV = nil;
@@ -132,7 +132,7 @@ local function iter(self, ...)
local maxdepth = select("#", ...);
local stack = { self.data };
local keys = { };
- local function it(self)
+ local function it(self) -- luacheck: ignore 432/self
local depth = #stack;
local key = next(stack[depth], keys[depth]);
if key == nil then -- Go up the stack
diff --git a/util/openssl.lua b/util/openssl.lua
index 703c6d15..32b5aea7 100644
--- a/util/openssl.lua
+++ b/util/openssl.lua
@@ -114,7 +114,7 @@ function ssl_config:add_xmppAddr(host)
s_format("%s;%s", oid_xmppaddr, utf8string(host)));
end
-function ssl_config:from_prosody(hosts, config, certhosts)
+function ssl_config:from_prosody(hosts, config, certhosts) -- luacheck: ignore 431/config
-- TODO Decide if this should go elsewhere
local found_matching_hosts = false;
for i = 1, #certhosts do
diff --git a/util/pluginloader.lua b/util/pluginloader.lua
index 004855f0..9ab8f245 100644
--- a/util/pluginloader.lua
+++ b/util/pluginloader.lua
@@ -5,6 +5,7 @@
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
+-- luacheck: ignore 113/CFG_PLUGINDIR
local dir_sep, path_sep = package.config:match("^(%S+)%s(%S+)");
local plugin_dir = {};
diff --git a/util/pubsub.lua b/util/pubsub.lua
index 1db917d8..0370bae2 100644
--- a/util/pubsub.lua
+++ b/util/pubsub.lua
@@ -1,32 +1,69 @@
local events = require "util.events";
local cache = require "util.cache";
-local service = {};
-local service_mt = { __index = service };
+local service_mt = {};
-local default_config = { __index = {
- itemstore = function (config) return cache.new(tonumber(config["pubsub#max_items"])) end;
+local default_config = {
+ itemstore = function (config, _) return cache.new(config["max_items"]) end;
broadcaster = function () end;
get_affiliation = function () end;
capabilities = {};
-} };
-local default_node_config = { __index = {
- ["pubsub#max_items"] = "20";
-} };
+};
+local default_config_mt = { __index = default_config };
+
+local default_node_config = {
+ ["persist_items"] = false;
+ ["max_items"] = 20;
+};
+local default_node_config_mt = { __index = default_node_config };
+
+-- Storage helper functions
+
+local function load_node_from_store(nodestore, node_name)
+ local node = nodestore:get(node_name);
+ node.config = setmetatable(node.config or {}, default_node_config_mt);
+ return node;
+end
+
+local function save_node_to_store(nodestore, node)
+ return nodestore:set(node.name, {
+ name = node.name;
+ config = node.config;
+ subscribers = node.subscribers;
+ affiliations = node.affiliations;
+ });
+end
+-- Create and return a new service object
local function new(config)
config = config or {};
- return setmetatable({
- config = setmetatable(config, default_config);
- node_defaults = setmetatable(config.node_defaults or {}, default_node_config);
+
+ local service = setmetatable({
+ config = setmetatable(config, default_config_mt);
+ node_defaults = setmetatable(config.node_defaults or {}, default_node_config_mt);
affiliations = {};
subscriptions = {};
nodes = {};
data = {};
events = events.new();
}, service_mt);
+
+ -- Load nodes from storage, if we have a store and it supports iterating over stored items
+ if config.nodestore and config.nodestore.users then
+ for node_name in config.nodestore:users() do
+ service.nodes[node_name] = load_node_from_store(config.nodestore, node_name);
+ service.data[node_name] = config.itemstore(service.nodes[node_name].config, node_name);
+ end
+ end
+
+ return service;
end
+--- Service methods
+
+local service = {};
+service_mt.__index = service;
+
function service:jids_equal(jid1, jid2)
local normalize = self.config.normalize_jid;
return normalize(jid1) == normalize(jid2);
@@ -176,18 +213,6 @@ function service:remove_subscription(node, actor, jid)
return true;
end
-function service:remove_all_subscriptions(actor, jid)
- local normal_jid = self.config.normalize_jid(jid);
- local subs = self.subscriptions[normal_jid]
- subs = subs and subs[jid];
- if subs then
- for node in pairs(subs) do
- self:remove_subscription(node, true, jid);
- end
- end
- return true;
-end
-
function service:get_subscription(node, actor, jid)
-- Access checking
local cap;
@@ -223,13 +248,24 @@ function service:create(node, actor, options)
config = setmetatable(options or {}, {__index=self.node_defaults});
affiliations = {};
};
- self.data[node] = self.config.itemstore(self.nodes[node].config);
+
+ if self.config.nodestore then
+ local ok, err = save_node_to_store(self.config.nodestore, self.nodes[node]);
+ if not ok then
+ self.nodes[node] = nil;
+ return ok, err;
+ end
+ end
+
+ self.data[node] = self.config.itemstore(self.nodes[node].config, node);
self.events.fire_event("node-created", { node = node, actor = actor });
local ok, err = self:set_affiliation(node, true, actor, "owner");
if not ok then
self.nodes[node] = nil;
self.data[node] = nil;
+ return ok, err;
end
+
return ok, err;
end
@@ -244,6 +280,9 @@ function service:delete(node, actor)
return false, "item-not-found";
end
self.nodes[node] = nil;
+ if self.data[node] and self.data[node].clear then
+ self.data[node]:clear();
+ end
self.data[node] = nil;
self.events.fire_event("node-deleted", { node = node, actor = actor });
self.config.broadcaster("delete", node, node_obj.subscribers);
@@ -272,6 +311,7 @@ function service:publish(node, actor, id, item)
if not ok then
return nil, "internal-server-error";
end
+ if type(ok) == "string" then id = ok; end
self.events.fire_event("item-published", { node = node, actor = actor, id = id, item = item });
self.config.broadcaster("items", node, node_obj.subscribers, item, actor);
return true;
@@ -308,7 +348,11 @@ function service:purge(node, actor, notify)
if not node_obj then
return false, "item-not-found";
end
- self.data[node] = self.config.itemstore(self.nodes[node].config);
+ if self.data[node] and self.data[node].clear then
+ self.data[node]:clear()
+ else
+ self.data[node] = self.config.itemstore(self.nodes[node].config, node);
+ end
self.events.fire_event("node-purged", { node = node, actor = actor });
if notify then
self.config.broadcaster("purge", node, node_obj.subscribers);
@@ -327,7 +371,11 @@ function service:get_items(node, actor, id)
return false, "item-not-found";
end
if id then -- Restrict results to a single specific item
- return true, { id, [id] = self.data[node]:get(id) };
+ local with_id = self.data[node]:get(id);
+ if not with_id then
+ return true, { };
+ end
+ return true, { id, [id] = with_id };
else
local data = {}
for key, value in self.data[node]:items() do
@@ -338,6 +386,15 @@ function service:get_items(node, actor, id)
end
end
+function service:get_last_item(node, actor)
+ -- Access checking
+ if not self:may(node, actor, "get_items") then
+ return false, "forbidden";
+ end
+ --
+ return true, self.data[node]:tail();
+end
+
function service:get_nodes(actor)
-- Access checking
if not self:may(nil, actor, "get_nodes") then
@@ -421,14 +478,14 @@ function service:set_node_config(node, actor, new_config)
return false, "item-not-found";
end
- for k,v in pairs(new_config) do
- node_obj.config[k] = v;
- end
- local new_data = self.config.itemstore(self.nodes[node].config);
- for key, value in self.data[node]:items() do
- new_data:set(key, value);
+ if new_config["persist_items"] ~= node_obj.config["persist_items"] then
+ self.data[node] = self.config.itemstore(self.nodes[node].config, node);
+ elseif new_config["max_items"] ~= node_obj.config["max_items"] then
+ self.data[node]:resize(new_config["max_items"]);
end
- self.data[node] = new_data;
+
+ node_obj.config = setmetatable(new_config, {__index=self.node_defaults});
+
return true;
end
diff --git a/util/random.lua b/util/random.lua
index b2d0000d..d8a84514 100644
--- a/util/random.lua
+++ b/util/random.lua
@@ -11,9 +11,6 @@ if ok then return crand; end
local urandom, urandom_err = io.open("/dev/urandom", "r");
-local function seed()
-end
-
local function bytes(n)
return urandom:read(n);
end
@@ -25,7 +22,6 @@ if not urandom then
end
return {
- seed = seed;
bytes = bytes;
_source = "/dev/urandom";
};
diff --git a/util/sasl.lua b/util/sasl.lua
index 5845f34a..3c5b8be0 100644
--- a/util/sasl.lua
+++ b/util/sasl.lua
@@ -42,7 +42,7 @@ Example:
local method = {};
method.__index = method;
-local mechanisms = {};
+local registered_mechanisms = {};
local backend_mechanism = {};
local mechanism_channelbindings = {};
@@ -52,7 +52,7 @@ local function registerMechanism(name, backends, f, cb_backends)
assert(type(backends) == "string" or type(backends) == "table", "Parameter backends MUST be either a string or a table.");
assert(type(f) == "function", "Parameter f MUST be a function.");
if cb_backends then assert(type(cb_backends) == "table"); end
- mechanisms[name] = f
+ registered_mechanisms[name] = f
if cb_backends then
mechanism_channelbindings[name] = {};
for _, cb_name in ipairs(cb_backends) do
@@ -70,7 +70,7 @@ local function new(realm, profile)
local mechanisms = profile.mechanisms;
if not mechanisms then
mechanisms = {};
- for backend, f in pairs(profile) do
+ for backend in pairs(profile) do
if backend_mechanism[backend] then
for _, mechanism in ipairs(backend_mechanism[backend]) do
mechanisms[mechanism] = true;
@@ -128,7 +128,7 @@ end
-- feed new messages to process into the library
function method:process(message)
--if message == "" or message == nil then return "failure", "malformed-request" end
- return mechanisms[self.selected](self, message);
+ return registered_mechanisms[self.selected](self, message);
end
-- load the mechanisms
diff --git a/util/sasl/anonymous.lua b/util/sasl/anonymous.lua
index 6201db32..acbb35a9 100644
--- a/util/sasl/anonymous.lua
+++ b/util/sasl/anonymous.lua
@@ -28,7 +28,7 @@ anonymous:
end
]]
-local function anonymous(self, message)
+local function anonymous(self, message) -- luacheck: ignore 212/message
local username;
repeat
username = generate_uuid();
diff --git a/util/sasl/scram.lua b/util/sasl/scram.lua
index 4e20dbb9..0163de5d 100644
--- a/util/sasl/scram.lua
+++ b/util/sasl/scram.lua
@@ -46,7 +46,18 @@ Supported Channel Binding Backends
local default_i = 4096
-local xor_map = {0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;1;0;3;2;5;4;7;6;9;8;11;10;13;12;15;14;2;3;0;1;6;7;4;5;10;11;8;9;14;15;12;13;3;2;1;0;7;6;5;4;11;10;9;8;15;14;13;12;4;5;6;7;0;1;2;3;12;13;14;15;8;9;10;11;5;4;7;6;1;0;3;2;13;12;15;14;9;8;11;10;6;7;4;5;2;3;0;1;14;15;12;13;10;11;8;9;7;6;5;4;3;2;1;0;15;14;13;12;11;10;9;8;8;9;10;11;12;13;14;15;0;1;2;3;4;5;6;7;9;8;11;10;13;12;15;14;1;0;3;2;5;4;7;6;10;11;8;9;14;15;12;13;2;3;0;1;6;7;4;5;11;10;9;8;15;14;13;12;3;2;1;0;7;6;5;4;12;13;14;15;8;9;10;11;4;5;6;7;0;1;2;3;13;12;15;14;9;8;11;10;5;4;7;6;1;0;3;2;14;15;12;13;10;11;8;9;6;7;4;5;2;3;0;1;15;14;13;12;11;10;9;8;7;6;5;4;3;2;1;0;};
+local xor_map = {
+ 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,1,0,3,2,5,4,7,6,9,8,11,10,
+ 13,12,15,14,2,3,0,1,6,7,4,5,10,11,8,9,14,15,12,13,3,2,1,0,7,6,5,
+ 4,11,10,9,8,15,14,13,12,4,5,6,7,0,1,2,3,12,13,14,15,8,9,10,11,5,
+ 4,7,6,1,0,3,2,13,12,15,14,9,8,11,10,6,7,4,5,2,3,0,1,14,15,12,13,
+ 10,11,8,9,7,6,5,4,3,2,1,0,15,14,13,12,11,10,9,8,8,9,10,11,12,13,
+ 14,15,0,1,2,3,4,5,6,7,9,8,11,10,13,12,15,14,1,0,3,2,5,4,7,6,10,
+ 11,8,9,14,15,12,13,2,3,0,1,6,7,4,5,11,10,9,8,15,14,13,12,3,2,1,
+ 0,7,6,5,4,12,13,14,15,8,9,10,11,4,5,6,7,0,1,2,3,13,12,15,14,9,8,
+ 11,10,5,4,7,6,1,0,3,2,14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1,15,
+ 14,13,12,11,10,9,8,7,6,5,4,3,2,1,0,
+};
local result = {};
local function binaryXOR( a, b )
@@ -237,10 +248,14 @@ end
local function init(registerMechanism)
local function registerSCRAMMechanism(hash_name, hash, hmac_hash)
- registerMechanism("SCRAM-"..hash_name, {"plain", "scram_"..(hashprep(hash_name))}, scram_gen(hash_name:lower(), hash, hmac_hash));
+ registerMechanism("SCRAM-"..hash_name,
+ {"plain", "scram_"..(hashprep(hash_name))},
+ scram_gen(hash_name:lower(), hash, hmac_hash));
-- register channel binding equivalent
- registerMechanism("SCRAM-"..hash_name.."-PLUS", {"plain", "scram_"..(hashprep(hash_name))}, scram_gen(hash_name:lower(), hash, hmac_hash), {"tls-unique"});
+ registerMechanism("SCRAM-"..hash_name.."-PLUS",
+ {"plain", "scram_"..(hashprep(hash_name))},
+ scram_gen(hash_name:lower(), hash, hmac_hash), {"tls-unique"});
end
registerSCRAMMechanism("SHA-1", sha1, hmac_sha1);
diff --git a/util/set.lua b/util/set.lua
index c136a522..8630638e 100644
--- a/util/set.lua
+++ b/util/set.lua
@@ -12,7 +12,7 @@ local t_concat = table.concat;
local _ENV = nil;
-local set_mt = {};
+local set_mt = { __name = "set" };
function set_mt.__call(set, _, k)
return next(set._items, k);
end
diff --git a/util/sql.lua b/util/sql.lua
index d964025e..9648101a 100644
--- a/util/sql.lua
+++ b/util/sql.lua
@@ -1,11 +1,10 @@
local setmetatable, getmetatable = setmetatable, getmetatable;
-local ipairs, unpack, select = ipairs, table.unpack or unpack, select; --luacheck: ignore 113
-local tonumber, tostring = tonumber, tostring;
+local ipairs, unpack, select = ipairs, table.unpack or unpack, select; --luacheck: ignore 113 143
+local tostring = tostring;
local type = type;
local assert, pcall, xpcall, debug_traceback = assert, pcall, xpcall, debug.traceback;
local t_concat = table.concat;
-local s_char = string.char;
local log = require "util.logger".init("sql");
local DBI = require "DBI";
@@ -58,9 +57,6 @@ table_mt.__index = {};
function table_mt.__index:create(engine)
return engine:_create_table(self);
end
-function table_mt:__call(...)
- -- TODO
-end
function column_mt:__tostring()
return 'Column{ name="'..self.name..'", type="'..self.type..'" }'
end
@@ -71,31 +67,6 @@ function index_mt:__tostring()
-- return 'Index{ name="'..self.name..'", type="'..self.type..'" }'
end
-local function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end)); end
-local function parse_url(url)
- local scheme, secondpart, database = url:match("^([%w%+]+)://([^/]*)/?(.*)");
- assert(scheme, "Invalid URL format");
- local username, password, host, port;
- local authpart, hostpart = secondpart:match("([^@]+)@([^@+])");
- if not authpart then hostpart = secondpart; end
- if authpart then
- username, password = authpart:match("([^:]*):(.*)");
- username = username or authpart;
- password = password and urldecode(password);
- end
- if hostpart then
- host, port = hostpart:match("([^:]*):(.*)");
- host = host or hostpart;
- port = port and assert(tonumber(port), "Invalid URL format");
- end
- return {
- scheme = scheme:lower();
- username = username; password = password;
- host = host; port = port;
- database = #database > 0 and database or nil;
- };
-end
-
local engine = {};
function engine:connect()
if self.conn then return true; end
@@ -123,7 +94,7 @@ function engine:connect()
end
return true;
end
-function engine:onconnect()
+function engine:onconnect() -- luacheck: ignore 212/self
-- Override from create_engine()
end
@@ -148,6 +119,7 @@ function engine:execute(sql, ...)
prepared[sql] = stmt;
end
+ -- luacheck: ignore 411/success
local success, err = stmt:execute(...);
if not success then return success, err; end
return stmt;
@@ -161,14 +133,14 @@ local result_mt = { __index = {
local function debugquery(where, sql, ...)
local i = 0; local a = {...}
sql = sql:gsub("\n?\t+", " ");
- log("debug", "[%s] %s", where, sql:gsub("%?", function ()
+ log("debug", "[%s] %s", where, (sql:gsub("%?", function ()
i = i + 1;
local v = a[i];
if type(v) == "string" then
v = ("'%s'"):format(v:gsub("'", "''"));
end
return tostring(v);
- end));
+ end)));
end
function engine:execute_query(sql, ...)
@@ -335,7 +307,12 @@ function engine:set_encoding() -- to UTF-8
local charset = "utf8";
if driver == "MySQL" then
self:transaction(function()
- for row in self:select"SELECT \"CHARACTER_SET_NAME\" FROM \"information_schema\".\"CHARACTER_SETS\" WHERE \"CHARACTER_SET_NAME\" LIKE 'utf8%' ORDER BY MAXLEN DESC LIMIT 1;" do
+ for row in self:select[[
+ SELECT "CHARACTER_SET_NAME"
+ FROM "information_schema"."CHARACTER_SETS"
+ WHERE "CHARACTER_SET_NAME" LIKE 'utf8%'
+ ORDER BY MAXLEN DESC LIMIT 1;
+ ]] do
charset = row and row[1] or charset;
end
end);
@@ -379,7 +356,7 @@ local function db2uri(params)
};
end
-local function create_engine(self, params, onconnect)
+local function create_engine(_, params, onconnect)
return setmetatable({ url = db2uri(params), params = params, onconnect = onconnect }, engine_mt);
end
diff --git a/util/stanza.lua b/util/stanza.lua
index 2191fa8e..42b6abc3 100644
--- a/util/stanza.lua
+++ b/util/stanza.lua
@@ -38,11 +38,11 @@ local xmlns_stanzas = "urn:ietf:params:xml:ns:xmpp-stanzas";
local _ENV = nil;
-local stanza_mt = { __type = "stanza" };
+local stanza_mt = { __name = "stanza" };
stanza_mt.__index = stanza_mt;
-local function new_stanza(name, attr)
- local stanza = { name = name, attr = attr or {}, tags = {} };
+local function new_stanza(name, attr, namespaces)
+ local stanza = { name = name, attr = attr or {}, namespaces = namespaces, tags = {} };
return setmetatable(stanza, stanza_mt);
end
@@ -58,8 +58,8 @@ function stanza_mt:body(text, attr)
return self:tag("body", attr):text(text);
end
-function stanza_mt:tag(name, attrs)
- local s = new_stanza(name, attrs);
+function stanza_mt:tag(name, attr, namespaces)
+ local s = new_stanza(name, attr, namespaces);
local last_add = self.last_add;
if not last_add then last_add = {}; self.last_add = last_add; end
(last_add[#last_add] or self):add_direct_child(s);
@@ -337,7 +337,12 @@ end
local function clone(stanza)
local attr, tags = {}, {};
for k,v in pairs(stanza.attr) do attr[k] = v; end
- local new = { name = stanza.name, attr = attr, tags = tags };
+ local old_namespaces, namespaces = stanza.namespaces;
+ if old_namespaces then
+ namespaces = {};
+ for k,v in pairs(old_namespaces) do namespaces[k] = v; end
+ end
+ local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags };
for i=1,#stanza do
local child = stanza[i];
if child.name then
@@ -362,7 +367,13 @@ local function iq(attr)
end
local function reply(orig)
- return new_stanza(orig.name, orig.attr and { to = orig.attr.from, from = orig.attr.to, id = orig.attr.id, type = ((orig.name == "iq" and "result") or orig.attr.type) });
+ return new_stanza(orig.name,
+ orig.attr and {
+ to = orig.attr.from,
+ from = orig.attr.to,
+ id = orig.attr.id,
+ type = ((orig.name == "iq" and "result") or orig.attr.type)
+ });
end
local xmpp_stanzas_attr = { xmlns = xmlns_stanzas };
diff --git a/util/template.lua b/util/template.lua
index 04ebb93d..bc56020c 100644
--- a/util/template.lua
+++ b/util/template.lua
@@ -4,7 +4,7 @@ local setmetatable = setmetatable;
local pairs = pairs;
local ipairs = ipairs;
local error = error;
-local loadstring = loadstring;
+local envload = require "util.envload".envload;
local debug = debug;
local t_remove = table.remove;
local parse_xml = require "util.xml".parse;
@@ -72,7 +72,7 @@ local function create_cloner(stanza, chunkname)
src = src.."local _"..i.."="..lookup[i]..";";
end
src = src.."return "..name..";end";
- local f,err = loadstring(src, chunkname);
+ local f,err = envload(src, chunkname);
if not f then error(err); end
return f(setmetatable, stanza_mt);
end
diff --git a/util/timer.lua b/util/timer.lua
index 7e2e9414..c7996bfa 100644
--- a/util/timer.lua
+++ b/util/timer.lua
@@ -6,78 +6,104 @@
-- COPYING file in the source package for more information.
--
+local indexedbheap = require "util.indexedbheap";
+local log = require "util.logger".init("timer");
local server = require "net.server";
-local math_min = math.min
-local math_huge = math.huge
local get_time = require "util.time".now
-local t_insert = table.insert;
-local pairs = pairs;
local type = type;
-
-local data = {};
-local new_data = {};
+local debug_traceback = debug.traceback;
+local tostring = tostring;
+local xpcall = xpcall;
local _ENV = nil;
-local _add_task;
-if not server.event then
- function _add_task(delay, callback)
- local current_time = get_time();
- delay = delay + current_time;
- if delay >= current_time then
- t_insert(new_data, {delay, callback});
- else
- local r = callback(current_time);
- if r and type(r) == "number" then
- return _add_task(r, callback);
- end
+local _add_task = server.add_task;
+
+local _server_timer;
+local _active_timers = 0;
+local h = indexedbheap.create();
+local params = {};
+local next_time = nil;
+local _id, _callback, _now, _param;
+local function _call() return _callback(_now, _id, _param); end
+local function _traceback_handler(err) log("error", "Traceback[timer]: %s", debug_traceback(tostring(err), 2)); end
+local function _on_timer(now)
+ local peek;
+ while true do
+ peek = h:peek();
+ if peek == nil or peek > now then break; end
+ local _;
+ _, _callback, _id = h:pop();
+ _now = now;
+ _param = params[_id];
+ params[_id] = nil;
+ --item(now, id, _param); -- FIXME pcall
+ local success, err = xpcall(_call, _traceback_handler);
+ if success and type(err) == "number" then
+ h:insert(_callback, err + now, _id); -- re-add
+ params[_id] = _param;
end
end
- server._addtimer(function()
- local current_time = get_time();
- if #new_data > 0 then
- for _, d in pairs(new_data) do
- t_insert(data, d);
- end
- new_data = {};
- end
+ if peek ~= nil and _active_timers > 1 and peek == next_time then
+ -- Another instance of _on_timer already set next_time to the same value,
+ -- so it should be safe to not renew this timer event
+ peek = nil;
+ else
+ next_time = peek;
+ end
- local next_time = math_huge;
- for i, d in pairs(data) do
- local t, callback = d[1], d[2];
- if t <= current_time then
- data[i] = nil;
- local r = callback(current_time);
- if type(r) == "number" then
- _add_task(r, callback);
- next_time = math_min(next_time, r);
- end
- else
- next_time = math_min(next_time, t - current_time);
- end
- end
- return next_time;
- end);
-else
- local event = server.event;
- local event_base = server.event_base;
- local EVENT_LEAVE = (event.core and event.core.LEAVE) or -1;
+ if peek then
+ -- peek is the time of the next event
+ return peek - now;
+ end
+ _active_timers = _active_timers - 1;
+end
+local function add_task(delay, callback, param)
+ local current_time = get_time();
+ local event_time = current_time + delay;
- function _add_task(delay, callback)
- local event_handle;
- event_handle = event_base:addevent(nil, 0, function ()
- local ret = callback(get_time());
- if ret then
- return 0, ret;
- elseif event_handle then
- return EVENT_LEAVE;
- end
+ local id = h:insert(callback, event_time);
+ params[id] = param;
+ if next_time == nil or event_time < next_time then
+ next_time = event_time;
+ if _server_timer then
+ _server_timer:close();
+ _server_timer = nil;
+ else
+ _active_timers = _active_timers + 1;
+ end
+ _server_timer = _add_task(next_time - current_time, _on_timer);
+ end
+ return id;
+end
+local function stop(id)
+ params[id] = nil;
+ local result, item, result_sync = h:remove(id);
+ local peek = h:peek();
+ if peek ~= next_time and _server_timer then
+ next_time = peek;
+ _server_timer:close();
+ if next_time ~= nil then
+ _server_timer = _add_task(next_time - get_time(), _on_timer);
end
- , delay);
end
+ return result, item, result_sync;
+end
+local function reschedule(id, delay)
+ local current_time = get_time();
+ local event_time = current_time + delay;
+ h:reprioritize(id, delay);
+ if next_time == nil or event_time < next_time then
+ next_time = event_time;
+ _add_task(next_time - current_time, _on_timer);
+ end
+ return id;
end
return {
- add_task = _add_task;
+ add_task = add_task;
+ stop = stop;
+ reschedule = reschedule;
};
+
diff --git a/util/vcard.lua b/util/vcard.lua
new file mode 100644
index 00000000..51758c41
--- /dev/null
+++ b/util/vcard.lua
@@ -0,0 +1,572 @@
+-- Copyright (C) 2011-2014 Kim Alvefur
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+-- TODO
+-- Fix folding.
+
+local st = require "util.stanza";
+local t_insert, t_concat = table.insert, table.concat;
+local type = type;
+local pairs, ipairs = pairs, ipairs;
+
+local from_text, to_text, from_xep54, to_xep54;
+
+local line_sep = "\n";
+
+local vCard_dtd; -- See end of file
+local vCard4_dtd;
+
+local function vCard_esc(s)
+ return s:gsub("[,:;\\]", "\\%1"):gsub("\n","\\n");
+end
+
+local function vCard_unesc(s)
+ return s:gsub("\\?[\\nt:;,]", {
+ ["\\\\"] = "\\",
+ ["\\n"] = "\n",
+ ["\\r"] = "\r",
+ ["\\t"] = "\t",
+ ["\\:"] = ":", -- FIXME Shouldn't need to espace : in values, just params
+ ["\\;"] = ";",
+ ["\\,"] = ",",
+ [":"] = "\29",
+ [";"] = "\30",
+ [","] = "\31",
+ });
+end
+
+local function item_to_xep54(item)
+ local t = st.stanza(item.name, { xmlns = "vcard-temp" });
+
+ local prop_def = vCard_dtd[item.name];
+ if prop_def == "text" then
+ t:text(item[1]);
+ elseif type(prop_def) == "table" then
+ if prop_def.types and item.TYPE then
+ if type(item.TYPE) == "table" then
+ for _,v in pairs(prop_def.types) do
+ for _,typ in pairs(item.TYPE) do
+ if typ:upper() == v then
+ t:tag(v):up();
+ break;
+ end
+ end
+ end
+ else
+ t:tag(item.TYPE:upper()):up();
+ end
+ end
+
+ if prop_def.props then
+ for _,v in pairs(prop_def.props) do
+ if item[v] then
+ t:tag(v):up();
+ end
+ end
+ end
+
+ if prop_def.value then
+ t:tag(prop_def.value):text(item[1]):up();
+ elseif prop_def.values then
+ local prop_def_values = prop_def.values;
+ local repeat_last = prop_def_values.behaviour == "repeat-last" and prop_def_values[#prop_def_values];
+ for i=1,#item do
+ t:tag(prop_def.values[i] or repeat_last):text(item[i]):up();
+ end
+ end
+ end
+
+ return t;
+end
+
+local function vcard_to_xep54(vCard)
+ local t = st.stanza("vCard", { xmlns = "vcard-temp" });
+ for i=1,#vCard do
+ t:add_child(item_to_xep54(vCard[i]));
+ end
+ return t;
+end
+
+function to_xep54(vCards)
+ if not vCards[1] or vCards[1].name then
+ return vcard_to_xep54(vCards)
+ else
+ local t = st.stanza("xCard", { xmlns = "vcard-temp" });
+ for i=1,#vCards do
+ t:add_child(vcard_to_xep54(vCards[i]));
+ end
+ return t;
+ end
+end
+
+function from_text(data)
+ data = data -- unfold and remove empty lines
+ :gsub("\r\n","\n")
+ :gsub("\n ", "")
+ :gsub("\n\n+","\n");
+ local vCards = {};
+ local current;
+ for line in data:gmatch("[^\n]+") do
+ line = vCard_unesc(line);
+ local name, params, value = line:match("^([-%a]+)(\30?[^\29]*)\29(.*)$");
+ value = value:gsub("\29",":");
+ if #params > 0 then
+ local _params = {};
+ for k,isval,v in params:gmatch("\30([^=]+)(=?)([^\30]*)") do
+ k = k:upper();
+ local _vt = {};
+ for _p in v:gmatch("[^\31]+") do
+ _vt[#_vt+1]=_p
+ _vt[_p]=true;
+ end
+ if isval == "=" then
+ _params[k]=_vt;
+ else
+ _params[k]=true;
+ end
+ end
+ params = _params;
+ end
+ if name == "BEGIN" and value == "VCARD" then
+ current = {};
+ vCards[#vCards+1] = current;
+ elseif name == "END" and value == "VCARD" then
+ current = nil;
+ elseif current and vCard_dtd[name] then
+ local dtd = vCard_dtd[name];
+ local item = { name = name };
+ t_insert(current, item);
+ local up = current;
+ current = item;
+ if dtd.types then
+ for _, t in ipairs(dtd.types) do
+ t = t:lower();
+ if ( params.TYPE and params.TYPE[t] == true)
+ or params[t] == true then
+ current.TYPE=t;
+ end
+ end
+ end
+ if dtd.props then
+ for _, p in ipairs(dtd.props) do
+ if params[p] then
+ if params[p] == true then
+ current[p]=true;
+ else
+ for _, prop in ipairs(params[p]) do
+ current[p]=prop;
+ end
+ end
+ end
+ end
+ end
+ if dtd == "text" or dtd.value then
+ t_insert(current, value);
+ elseif dtd.values then
+ for p in ("\30"..value):gmatch("\30([^\30]*)") do
+ t_insert(current, p);
+ end
+ end
+ current = up;
+ end
+ end
+ return vCards;
+end
+
+local function item_to_text(item)
+ local value = {};
+ for i=1,#item do
+ value[i] = vCard_esc(item[i]);
+ end
+ value = t_concat(value, ";");
+
+ local params = "";
+ for k,v in pairs(item) do
+ if type(k) == "string" and k ~= "name" then
+ params = params .. (";%s=%s"):format(k, type(v) == "table" and t_concat(v,",") or v);
+ end
+ end
+
+ return ("%s%s:%s"):format(item.name, params, value)
+end
+
+local function vcard_to_text(vcard)
+ local t={};
+ t_insert(t, "BEGIN:VCARD")
+ for i=1,#vcard do
+ t_insert(t, item_to_text(vcard[i]));
+ end
+ t_insert(t, "END:VCARD")
+ return t_concat(t, line_sep);
+end
+
+function to_text(vCards)
+ if vCards[1] and vCards[1].name then
+ return vcard_to_text(vCards)
+ else
+ local t = {};
+ for i=1,#vCards do
+ t[i]=vcard_to_text(vCards[i]);
+ end
+ return t_concat(t, line_sep);
+ end
+end
+
+local function from_xep54_item(item)
+ local prop_name = item.name;
+ local prop_def = vCard_dtd[prop_name];
+
+ local prop = { name = prop_name };
+
+ if prop_def == "text" then
+ prop[1] = item:get_text();
+ elseif type(prop_def) == "table" then
+ if prop_def.value then --single item
+ prop[1] = item:get_child_text(prop_def.value) or "";
+ elseif prop_def.values then --array
+ local value_names = prop_def.values;
+ if value_names.behaviour == "repeat-last" then
+ for i=1,#item.tags do
+ t_insert(prop, item.tags[i]:get_text() or "");
+ end
+ else
+ for i=1,#value_names do
+ t_insert(prop, item:get_child_text(value_names[i]) or "");
+ end
+ end
+ elseif prop_def.names then
+ local names = prop_def.names;
+ for i=1,#names do
+ if item:get_child(names[i]) then
+ prop[1] = names[i];
+ break;
+ end
+ end
+ end
+
+ if prop_def.props_verbatim then
+ for k,v in pairs(prop_def.props_verbatim) do
+ prop[k] = v;
+ end
+ end
+
+ if prop_def.types then
+ local types = prop_def.types;
+ prop.TYPE = {};
+ for i=1,#types do
+ if item:get_child(types[i]) then
+ t_insert(prop.TYPE, types[i]:lower());
+ end
+ end
+ if #prop.TYPE == 0 then
+ prop.TYPE = nil;
+ end
+ end
+
+ -- A key-value pair, within a key-value pair?
+ if prop_def.props then
+ local params = prop_def.props;
+ for i=1,#params do
+ local name = params[i]
+ local data = item:get_child_text(name);
+ if data then
+ prop[name] = prop[name] or {};
+ t_insert(prop[name], data);
+ end
+ end
+ end
+ else
+ return nil
+ end
+
+ return prop;
+end
+
+local function from_xep54_vCard(vCard)
+ local tags = vCard.tags;
+ local t = {};
+ for i=1,#tags do
+ t_insert(t, from_xep54_item(tags[i]));
+ end
+ return t
+end
+
+function from_xep54(vCard)
+ if vCard.attr.xmlns ~= "vcard-temp" then
+ return nil, "wrong-xmlns";
+ end
+ if vCard.name == "xCard" then -- A collection of vCards
+ local t = {};
+ local vCards = vCard.tags;
+ for i=1,#vCards do
+ t[i] = from_xep54_vCard(vCards[i]);
+ end
+ return t
+ elseif vCard.name == "vCard" then -- A single vCard
+ return from_xep54_vCard(vCard)
+ end
+end
+
+local vcard4 = { }
+
+function vcard4:text(node, params, value) -- luacheck: ignore 212/params
+ self:tag(node:lower())
+ -- FIXME params
+ if type(value) == "string" then
+ self:tag("text"):text(value):up()
+ elseif vcard4[node] then
+ vcard4[node](value);
+ end
+ self:up();
+end
+
+function vcard4.N(value)
+ for i, k in ipairs(vCard_dtd.N.values) do
+ value:tag(k):text(value[i]):up();
+ end
+end
+
+local xmlns_vcard4 = "urn:ietf:params:xml:ns:vcard-4.0"
+
+local function item_to_vcard4(item)
+ local typ = item.name:lower();
+ local t = st.stanza(typ, { xmlns = xmlns_vcard4 });
+
+ local prop_def = vCard4_dtd[typ];
+ if prop_def == "text" then
+ t:tag("text"):text(item[1]):up();
+ elseif prop_def == "uri" then
+ if item.ENCODING and item.ENCODING[1] == 'b' then
+ t:tag("uri"):text("data:;base64,"):text(item[1]):up();
+ else
+ t:tag("uri"):text(item[1]):up();
+ end
+ elseif type(prop_def) == "table" then
+ if prop_def.values then
+ for i, v in ipairs(prop_def.values) do
+ t:tag(v:lower()):text(item[i] or ""):up();
+ end
+ else
+ t:tag("unsupported",{xmlns="http://zash.se/protocol/vcardlib"})
+ end
+ else
+ t:tag("unsupported",{xmlns="http://zash.se/protocol/vcardlib"})
+ end
+ return t;
+end
+
+local function vcard_to_vcard4xml(vCard)
+ local t = st.stanza("vcard", { xmlns = xmlns_vcard4 });
+ for i=1,#vCard do
+ t:add_child(item_to_vcard4(vCard[i]));
+ end
+ return t;
+end
+
+local function vcards_to_vcard4xml(vCards)
+ if not vCards[1] or vCards[1].name then
+ return vcard_to_vcard4xml(vCards)
+ else
+ local t = st.stanza("vcards", { xmlns = xmlns_vcard4 });
+ for i=1,#vCards do
+ t:add_child(vcard_to_vcard4xml(vCards[i]));
+ end
+ return t;
+ end
+end
+
+-- This was adapted from http://xmpp.org/extensions/xep-0054.html#dtd
+vCard_dtd = {
+ VERSION = "text", --MUST be 3.0, so parsing is redundant
+ FN = "text",
+ N = {
+ values = {
+ "FAMILY",
+ "GIVEN",
+ "MIDDLE",
+ "PREFIX",
+ "SUFFIX",
+ },
+ },
+ NICKNAME = "text",
+ PHOTO = {
+ props_verbatim = { ENCODING = { "b" } },
+ props = { "TYPE" },
+ value = "BINVAL", --{ "EXTVAL", },
+ },
+ BDAY = "text",
+ ADR = {
+ types = {
+ "HOME",
+ "WORK",
+ "POSTAL",
+ "PARCEL",
+ "DOM",
+ "INTL",
+ "PREF",
+ },
+ values = {
+ "POBOX",
+ "EXTADD",
+ "STREET",
+ "LOCALITY",
+ "REGION",
+ "PCODE",
+ "CTRY",
+ }
+ },
+ LABEL = {
+ types = {
+ "HOME",
+ "WORK",
+ "POSTAL",
+ "PARCEL",
+ "DOM",
+ "INTL",
+ "PREF",
+ },
+ value = "LINE",
+ },
+ TEL = {
+ types = {
+ "HOME",
+ "WORK",
+ "VOICE",
+ "FAX",
+ "PAGER",
+ "MSG",
+ "CELL",
+ "VIDEO",
+ "BBS",
+ "MODEM",
+ "ISDN",
+ "PCS",
+ "PREF",
+ },
+ value = "NUMBER",
+ },
+ EMAIL = {
+ types = {
+ "HOME",
+ "WORK",
+ "INTERNET",
+ "PREF",
+ "X400",
+ },
+ value = "USERID",
+ },
+ JABBERID = "text",
+ MAILER = "text",
+ TZ = "text",
+ GEO = {
+ values = {
+ "LAT",
+ "LON",
+ },
+ },
+ TITLE = "text",
+ ROLE = "text",
+ LOGO = "copy of PHOTO",
+ AGENT = "text",
+ ORG = {
+ values = {
+ behaviour = "repeat-last",
+ "ORGNAME",
+ "ORGUNIT",
+ }
+ },
+ CATEGORIES = {
+ values = "KEYWORD",
+ },
+ NOTE = "text",
+ PRODID = "text",
+ REV = "text",
+ SORTSTRING = "text",
+ SOUND = "copy of PHOTO",
+ UID = "text",
+ URL = "text",
+ CLASS = {
+ names = { -- The item.name is the value if it's one of these.
+ "PUBLIC",
+ "PRIVATE",
+ "CONFIDENTIAL",
+ },
+ },
+ KEY = {
+ props = { "TYPE" },
+ value = "CRED",
+ },
+ DESC = "text",
+};
+vCard_dtd.LOGO = vCard_dtd.PHOTO;
+vCard_dtd.SOUND = vCard_dtd.PHOTO;
+
+vCard4_dtd = {
+ source = "uri",
+ kind = "text",
+ xml = "text",
+ fn = "text",
+ n = {
+ values = {
+ "family",
+ "given",
+ "middle",
+ "prefix",
+ "suffix",
+ },
+ },
+ nickname = "text",
+ photo = "uri",
+ bday = "date-and-or-time",
+ anniversary = "date-and-or-time",
+ gender = "text",
+ adr = {
+ values = {
+ "pobox",
+ "ext",
+ "street",
+ "locality",
+ "region",
+ "code",
+ "country",
+ }
+ },
+ tel = "text",
+ email = "text",
+ impp = "uri",
+ lang = "language-tag",
+ tz = "text",
+ geo = "uri",
+ title = "text",
+ role = "text",
+ logo = "uri",
+ org = "text",
+ member = "uri",
+ related = "uri",
+ categories = "text",
+ note = "text",
+ prodid = "text",
+ rev = "timestamp",
+ sound = "uri",
+ uid = "uri",
+ clientpidmap = "number, uuid",
+ url = "uri",
+ version = "text",
+ key = "uri",
+ fburl = "uri",
+ caladruri = "uri",
+ caluri = "uri",
+};
+
+return {
+ from_text = from_text;
+ to_text = to_text;
+
+ from_xep54 = from_xep54;
+ to_xep54 = to_xep54;
+
+ to_vcard4 = vcards_to_vcard4xml;
+};
diff --git a/util/xml.lua b/util/xml.lua
index 733d821a..ec06fb01 100644
--- a/util/xml.lua
+++ b/util/xml.lua
@@ -1,6 +1,8 @@
local st = require "util.stanza";
local lxp = require "lxp";
+local t_insert = table.insert;
+local t_remove = table.remove;
local _ENV = nil;
@@ -14,6 +16,21 @@ local parse_xml = (function()
--luacheck: ignore 212/self
local handler = {};
local stanza = st.stanza("root");
+ local namespaces = {};
+ local prefixes = {};
+ function handler:StartNamespaceDecl(prefix, url)
+ if prefix ~= nil then
+ t_insert(namespaces, url);
+ t_insert(prefixes, prefix);
+ end
+ end
+ function handler:EndNamespaceDecl(prefix)
+ if prefix ~= nil then
+ -- we depend on each StartNamespaceDecl having a paired EndNamespaceDecl
+ t_remove(namespaces);
+ t_remove(prefixes);
+ end
+ end
function handler:StartElement(tagname, attr)
local curr_ns,name = tagname:match(ns_pattern);
if name == "" then
@@ -34,7 +51,11 @@ local parse_xml = (function()
end
end
end
- stanza:tag(name, attr);
+ local n = {}
+ for i=1,#namespaces do
+ n[prefixes[i]] = namespaces[i];
+ end
+ stanza:tag(name, attr, n);
end
function handler:CharacterData(data)
stanza:text(data);
diff --git a/util/xmppstream.lua b/util/xmppstream.lua
index 7be63285..53cb98ae 100644
--- a/util/xmppstream.lua
+++ b/util/xmppstream.lua
@@ -47,7 +47,10 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
local cb_streamopened = stream_callbacks.streamopened;
local cb_streamclosed = stream_callbacks.streamclosed;
- local cb_error = stream_callbacks.error or function(session, e, stanza) error("XML stream error: "..tostring(e)..(stanza and ": "..tostring(stanza) or ""),2); end;
+ local cb_error = stream_callbacks.error or
+ function(_, e, stanza)
+ error("XML stream error: "..tostring(e)..(stanza and ": "..tostring(stanza) or ""),2);
+ end;
local cb_handlestanza = stream_callbacks.handlestanza;
cb_handleprogress = cb_handleprogress or dummy_cb;
@@ -128,6 +131,9 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
end
if lxp_supports_xmldecl then
function xml_handlers:XmlDecl(version, encoding, standalone)
+ session.xml_version = version;
+ session.xml_encoding = encoding;
+ session.xml_standalone = standalone;
if lxp_supports_bytecount then
cb_handleprogress(self:getcurrentbytecount());
end
@@ -214,7 +220,7 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
stack = {};
end
- local function set_session(stream, new_session)
+ local function set_session(stream, new_session) -- luacheck: ignore 212/stream
session = new_session;
end
@@ -238,7 +244,7 @@ local function new(session, stream_callbacks, stanza_size_limit)
local parser = new_parser(handlers, ns_separator, false);
local parse = parser.parse;
- function session.open_stream(session, from, to)
+ function session.open_stream(session, from, to) -- luacheck: ignore 432/session
local send = session.sends2s or session.send;
local attr = {
@@ -264,7 +270,7 @@ local function new(session, stream_callbacks, stanza_size_limit)
n_outstanding_bytes = 0;
meta.reset();
end,
- feed = function (self, data)
+ feed = function (self, data) -- luacheck: ignore 212/self
if lxp_supports_bytecount then
n_outstanding_bytes = n_outstanding_bytes + #data;
end