aboutsummaryrefslogtreecommitdiffstats
path: root/util
diff options
context:
space:
mode:
Diffstat (limited to 'util')
-rw-r--r--util/adhoc.lua2
-rw-r--r--util/array.lua2
-rw-r--r--util/async.lua233
-rw-r--r--util/cache.lua26
-rw-r--r--util/caps.lua1
-rw-r--r--util/dataforms.lua1
-rw-r--r--util/datamanager.lua10
-rw-r--r--util/datetime.lua1
-rw-r--r--util/debug.lua21
-rw-r--r--util/dependencies.lua10
-rw-r--r--util/envload.lua2
-rw-r--r--util/events.lua14
-rw-r--r--util/filters.lua1
-rw-r--r--util/format.lua20
-rw-r--r--util/import.lua4
-rw-r--r--util/indexedbheap.lua157
-rw-r--r--util/ip.lua250
-rw-r--r--util/iterators.lua4
-rw-r--r--util/jid.lua1
-rw-r--r--util/json.lua10
-rw-r--r--util/logger.lua14
-rw-r--r--util/multitable.lua5
-rw-r--r--util/openssl.lua2
-rw-r--r--util/pluginloader.lua1
-rw-r--r--util/prosodyctl.lua13
-rw-r--r--util/pubsub.lua127
-rw-r--r--util/random.lua4
-rw-r--r--util/sasl.lua9
-rw-r--r--util/sasl/anonymous.lua3
-rw-r--r--util/sasl/digest-md5.lua1
-rw-r--r--util/sasl/external.lua1
-rw-r--r--util/sasl/plain.lua1
-rw-r--r--util/sasl/scram.lua22
-rw-r--r--util/sasl_cyrus.lua1
-rw-r--r--util/serialization.lua1
-rw-r--r--util/set.lua3
-rw-r--r--util/sql.lua50
-rw-r--r--util/sslconfig.lua1
-rw-r--r--util/stanza.lua79
-rw-r--r--util/startup.lua543
-rw-r--r--util/template.lua5
-rw-r--r--util/termcolours.lua1
-rw-r--r--util/throttle.lua1
-rw-r--r--util/timer.lua151
-rw-r--r--util/vcard.lua572
-rw-r--r--util/watchdog.lua1
-rw-r--r--util/x509.lua1
-rw-r--r--util/xml.lua24
-rw-r--r--util/xmppstream.lua15
49 files changed, 2084 insertions, 338 deletions
diff --git a/util/adhoc.lua b/util/adhoc.lua
index 17c9eee5..d81b8242 100644
--- a/util/adhoc.lua
+++ b/util/adhoc.lua
@@ -1,3 +1,5 @@
+-- luacheck: ignore 212/self
+
local function new_simple_form(form, result_handler)
return function(self, data, state)
if state then
diff --git a/util/array.lua b/util/array.lua
index 150b4355..1a8ffec7 100644
--- a/util/array.lua
+++ b/util/array.lua
@@ -19,7 +19,7 @@ local type = type;
local array = {};
local array_base = {};
local array_methods = {};
-local array_mt = { __index = array_methods, __tostring = function (self) return "{"..self:concat(", ").."}"; end };
+local array_mt = { __index = array_methods, __name = "array", __tostring = function (self) return "{"..self:concat(", ").."}"; end };
local function new_array(self, t, _s, _var)
if type(t) == "function" then -- Assume iterator
diff --git a/util/async.lua b/util/async.lua
new file mode 100644
index 00000000..012cfd87
--- /dev/null
+++ b/util/async.lua
@@ -0,0 +1,233 @@
+local log = require "util.logger".init("util.async");
+local new_id = require "util.id".short;
+
+local function checkthread()
+ local thread, main = coroutine.running();
+ if not thread or main then
+ error("Not running in an async context, see https://prosody.im/doc/developers/util/async");
+ end
+ return thread;
+end
+
+local function runner_from_thread(thread)
+ local level = 0;
+ -- Find the 'level' of the top-most function (0 == current level, 1 == caller, ...)
+ while debug.getinfo(thread, level, "") do level = level + 1; end
+ local name, runner = debug.getlocal(thread, level-1, 1);
+ if name ~= "self" or type(runner) ~= "table" or runner.thread ~= thread then
+ return nil;
+ end
+ return runner;
+end
+
+local function call_watcher(runner, watcher_name, ...)
+ local watcher = runner.watchers[watcher_name];
+ if not watcher then
+ return false;
+ end
+ runner:log("debug", "Calling '%s' watcher", watcher_name);
+ local ok, err = pcall(watcher, runner, ...); -- COMPAT: Switch to xpcall after Lua 5.1
+ if not ok then
+ runner:log("error", "Error in '%s' watcher: %s", watcher_name, err);
+ return nil, err;
+ end
+ return true;
+end
+
+local function runner_continue(thread)
+ -- ASSUMPTION: runner is in 'waiting' state (but we don't have the runner to know for sure)
+ if coroutine.status(thread) ~= "suspended" then -- This should suffice
+ log("error", "unexpected async state: thread not suspended");
+ return false;
+ end
+ local ok, state, runner = coroutine.resume(thread);
+ if not ok then
+ local err = state;
+ -- Running the coroutine failed, which means we have to find the runner manually,
+ -- in order to inform the error handler
+ runner = runner_from_thread(thread);
+ if not runner then
+ log("error", "unexpected async state: unable to locate runner during error handling");
+ return false;
+ end
+ call_watcher(runner, "error", debug.traceback(thread, err));
+ runner.state, runner.thread = "ready", nil;
+ return runner:run();
+ elseif state == "ready" then
+ -- If state is 'ready', it is our responsibility to update runner.state from 'waiting'.
+ -- We also have to :run(), because the queue might have further items that will not be
+ -- processed otherwise. FIXME: It's probably best to do this in a nexttick (0 timer).
+ runner.state = "ready";
+ runner:run();
+ end
+ return true;
+end
+
+local function waiter(num)
+ local thread = checkthread();
+ num = num or 1;
+ local waiting;
+ return function ()
+ if num == 0 then return; end -- already done
+ waiting = true;
+ coroutine.yield("wait");
+ end, function ()
+ num = num - 1;
+ if num == 0 and waiting then
+ runner_continue(thread);
+ elseif num < 0 then
+ error("done() called too many times");
+ end
+ end;
+end
+
+local function guarder()
+ local guards = {};
+ local default_id = {};
+ return function (id, func)
+ id = id or default_id;
+ local thread = checkthread();
+ local guard = guards[id];
+ if not guard then
+ guard = {};
+ guards[id] = guard;
+ log("debug", "New guard!");
+ else
+ table.insert(guard, thread);
+ log("debug", "Guarded. %d threads waiting.", #guard)
+ coroutine.yield("wait");
+ end
+ local function exit()
+ local next_waiting = table.remove(guard, 1);
+ if next_waiting then
+ log("debug", "guard: Executing next waiting thread (%d left)", #guard)
+ runner_continue(next_waiting);
+ else
+ log("debug", "Guard off duty.")
+ guards[id] = nil;
+ end
+ end
+ if func then
+ func();
+ exit();
+ return;
+ end
+ return exit;
+ end;
+end
+
+local runner_mt = {};
+runner_mt.__index = runner_mt;
+
+local function runner_create_thread(func, self)
+ local thread = coroutine.create(function (self) -- luacheck: ignore 432/self
+ while true do
+ func(coroutine.yield("ready", self));
+ end
+ end);
+ assert(coroutine.resume(thread, self)); -- Start it up, it will return instantly to wait for the first input
+ return thread;
+end
+
+local function default_error_watcher(runner, err)
+ runner:log("error", "Encountered error: %s", err);
+ error(err);
+end
+local function default_func(f) f(); end
+local function runner(func, watchers, data)
+ return setmetatable({ func = func or default_func, thread = false, state = "ready", notified_state = "ready",
+ queue = {}, watchers = watchers or { error = default_error_watcher }, data = data, id = new_id() }
+ , runner_mt);
+end
+
+-- Add a task item for the runner to process
+function runner_mt:run(input)
+ if input ~= nil then
+ table.insert(self.queue, input);
+ self:log("debug", "queued new work item, %d items queued", #self.queue);
+ end
+ if self.state ~= "ready" then
+ -- The runner is busy. Indicate that the task item has been
+ -- queued, and return information about the current runner state
+ return true, self.state, #self.queue;
+ end
+
+ local q, thread = self.queue, self.thread;
+ if not thread or coroutine.status(thread) == "dead" then
+ self:log("debug", "creating new coroutine");
+ -- Create a new coroutine for this runner
+ thread = runner_create_thread(self.func, self);
+ self.thread = thread;
+ end
+
+ -- Process task item(s) while the queue is not empty, and we're not blocked
+ local n, state, err = #q, self.state, nil;
+ self.state = "running";
+ self:log("debug", "running main loop");
+ while n > 0 and state == "ready" and not err do
+ local consumed;
+ -- Loop through queue items, and attempt to run them
+ for i = 1,n do
+ local queued_input = q[i];
+ local ok, new_state = coroutine.resume(thread, queued_input);
+ if not ok then
+ -- There was an error running the coroutine, save the error, mark runner as ready to begin again
+ consumed, state, err = i, "ready", debug.traceback(thread, new_state);
+ self.thread = nil;
+ break;
+ elseif new_state == "wait" then
+ -- Runner is blocked on waiting for a task item to complete
+ consumed, state = i, "waiting";
+ break;
+ end
+ end
+ -- Loop ended - either queue empty because all tasks passed without blocking (consumed == nil)
+ -- or runner is blocked/errored, and consumed will contain the number of tasks processed so far
+ if not consumed then consumed = n; end
+ -- Remove consumed items from the queue array
+ if q[n+1] ~= nil then
+ n = #q;
+ end
+ for i = 1, n do
+ q[i] = q[consumed+i];
+ end
+ n = #q;
+ end
+ -- Runner processed all items it can, so save current runner state
+ self.state = state;
+ if err or state ~= self.notified_state then
+ self:log("debug", "changed state from %s to %s", self.notified_state, err and ("error ("..state..")") or state);
+ if err then
+ state = "error"
+ else
+ self.notified_state = state;
+ end
+ local handler = self.watchers[state];
+ if handler then handler(self, err); end
+ end
+ if n > 0 then
+ return self:run();
+ end
+ return true, state, n;
+end
+
+-- Add a task item to the queue without invoking the runner, even if it is idle
+function runner_mt:enqueue(input)
+ table.insert(self.queue, input);
+ self:log("debug", "queued new work item, %d items queued", #self.queue);
+end
+
+function runner_mt:log(level, fmt, ...)
+ return log(level, "[runner %s] "..fmt, self.id, ...);
+end
+
+local function ready()
+ return pcall(checkthread);
+end
+
+return {
+ ready = ready;
+ waiter = waiter;
+ guarder = guarder;
+ runner = runner;
+};
diff --git a/util/cache.lua b/util/cache.lua
index 9c141bb6..a5fd5e6d 100644
--- a/util/cache.lua
+++ b/util/cache.lua
@@ -116,6 +116,25 @@ function cache_methods:tail()
return tail.key, tail.value;
end
+function cache_methods:resize(new_size)
+ new_size = assert(tonumber(new_size), "cache size must be a number");
+ new_size = math.floor(new_size);
+ assert(new_size > 0, "cache size must be greater than zero");
+ local on_evict = self._on_evict;
+ while self._count > new_size do
+ local tail = self._tail;
+ local evicted_key, evicted_value = tail.key, tail.value;
+ if on_evict ~= nil and (on_evict == false or on_evict(evicted_key, evicted_value) == false) then
+ -- Cache is full, and we're not allowed to evict
+ return false;
+ end
+ _remove(self, tail);
+ self._data[evicted_key] = nil;
+ end
+ self.size = new_size;
+ return true;
+end
+
function cache_methods:table()
--luacheck: ignore 212/t
if not self.proxy_table then
@@ -139,6 +158,13 @@ function cache_methods:table()
return self.proxy_table;
end
+function cache_methods:clear()
+ self._data = {};
+ self._count = 0;
+ self._head = nil;
+ self._tail = nil;
+end
+
local function new(size, on_evict)
size = assert(tonumber(size), "cache size must be a number");
size = math.floor(size);
diff --git a/util/caps.lua b/util/caps.lua
index cd5ff9c0..de492edb 100644
--- a/util/caps.lua
+++ b/util/caps.lua
@@ -13,6 +13,7 @@ local t_insert, t_sort, t_concat = table.insert, table.sort, table.concat;
local ipairs = ipairs;
local _ENV = nil;
+-- luacheck: std none
local function calculate_hash(disco_info)
local identities, features, extensions = {}, {}, {};
diff --git a/util/dataforms.lua b/util/dataforms.lua
index 469ce976..9f11bed2 100644
--- a/util/dataforms.lua
+++ b/util/dataforms.lua
@@ -14,6 +14,7 @@ local st = require "util.stanza";
local jid_prep = require "util.jid".prep;
local _ENV = nil;
+-- luacheck: std none
local xmlns_forms = 'jabber:x:data';
diff --git a/util/datamanager.lua b/util/datamanager.lua
index bd8fb7bb..cf96887b 100644
--- a/util/datamanager.lua
+++ b/util/datamanager.lua
@@ -40,9 +40,10 @@ pcall(function()
end);
local _ENV = nil;
+-- luacheck: std none
---- utils -----
-local encode, decode;
+local encode, decode, store_encode;
do
local urlcodes = setmetatable({}, { __index = function (t, k) t[k] = char(tonumber(k, 16)); return t[k]; end });
@@ -53,6 +54,12 @@ do
encode = function (s)
return s and (s:gsub("%W", function (c) return format("%%%02x", c:byte()); end));
end
+
+ -- Special encode function for store names, which historically were unencoded.
+ -- All currently known stores use a-z and underscore, so this one preserves underscores.
+ store_encode = function (s)
+ return s and (s:gsub("[^_%w]", function (c) return format("%%%02x", c:byte()); end));
+ end
end
if not atomic_append then
@@ -119,6 +126,7 @@ local function getpath(username, host, datastore, ext, create)
ext = ext or "dat";
host = (host and encode(host)) or "_global";
username = username and encode(username);
+ datastore = store_encode(datastore);
if username then
if create then mkdir(mkdir(mkdir(data_path).."/"..host).."/"..datastore); end
return format("%s/%s/%s/%s.%s", data_path, host, datastore, username, ext);
diff --git a/util/datetime.lua b/util/datetime.lua
index abb4e867..06be9fc2 100644
--- a/util/datetime.lua
+++ b/util/datetime.lua
@@ -15,6 +15,7 @@ local os_difftime = os.difftime;
local tonumber = tonumber;
local _ENV = nil;
+-- luacheck: std none
local function date(t)
return os_date("!%Y-%m-%d", t);
diff --git a/util/debug.lua b/util/debug.lua
index 00f476d0..9a28395a 100644
--- a/util/debug.lua
+++ b/util/debug.lua
@@ -47,6 +47,7 @@ local function get_upvalues_table(func)
for upvalue_num = 1, math.huge do
local name, value = debug.getupvalue(func, upvalue_num);
if not name then break; end
+ if name == "" then name = ("[%d]"):format(upvalue_num); end
table.insert(upvalues, { name = name, value = value });
end
end
@@ -112,7 +113,9 @@ end
local function build_source_boundary_marker(last_source_desc)
local padding = string.rep("-", math.floor(((optimal_line_length - 6) - #last_source_desc)/2));
- return getstring(styles.boundary_padding, "v"..padding).." "..getstring(styles.filename, last_source_desc).." "..getstring(styles.boundary_padding, padding..(#last_source_desc%2==0 and "-v" or "v "));
+ return getstring(styles.boundary_padding, "v"..padding).." "..
+ getstring(styles.filename, last_source_desc).." "..
+ getstring(styles.boundary_padding, padding..(#last_source_desc%2==0 and "-v" or "v "));
end
local function _traceback(thread, message, level)
@@ -142,9 +145,9 @@ local function _traceback(thread, message, level)
local last_source_desc;
local lines = {};
- for nlevel, level in ipairs(levels) do
- local info = level.info;
- local line = "...";
+ for nlevel, current_level in ipairs(levels) do
+ local info = current_level.info;
+ local line;
local func_type = info.namewhat.." ";
local source_desc = (info.short_src == "[C]" and "C code") or info.short_src or "Unknown";
if func_type == " " then func_type = ""; end;
@@ -160,7 +163,9 @@ local function _traceback(thread, message, level)
if func_type == "global " or func_type == "local " then
func_type = func_type.."function ";
end
- line = "[Lua] "..getstring(styles.location, info.short_src.." line "..info.currentline).." in "..func_type..getstring(styles.funcname, name).." (defined on line "..info.linedefined..")";
+ line = "[Lua] "..getstring(styles.location, info.short_src.." line "..
+ info.currentline).." in "..func_type..getstring(styles.funcname, name)..
+ " (defined on line "..info.linedefined..")";
end
if source_desc ~= last_source_desc then -- Venturing into a new source, add marker for previous
last_source_desc = source_desc;
@@ -169,13 +174,13 @@ local function _traceback(thread, message, level)
nlevel = nlevel-1;
table.insert(lines, "\t"..(nlevel==0 and ">" or " ")..getstring(styles.level_num, "("..nlevel..") ")..line);
local npadding = (" "):rep(#tostring(nlevel));
- if level.locals then
- local locals_str = string_from_var_table(level.locals, optimal_line_length, "\t "..npadding);
+ if current_level.locals then
+ local locals_str = string_from_var_table(current_level.locals, optimal_line_length, "\t "..npadding);
if locals_str then
table.insert(lines, "\t "..npadding.."Locals: "..locals_str);
end
end
- local upvalues_str = string_from_var_table(level.upvalues, optimal_line_length, "\t "..npadding);
+ local upvalues_str = string_from_var_table(current_level.upvalues, optimal_line_length, "\t "..npadding);
if upvalues_str then
table.insert(lines, "\t "..npadding.."Upvals: "..upvalues_str);
end
diff --git a/util/dependencies.lua b/util/dependencies.lua
index de840241..9b0afd77 100644
--- a/util/dependencies.lua
+++ b/util/dependencies.lua
@@ -28,7 +28,7 @@ local function missingdep(name, sources, msg)
end
print("");
print(msg or (name.." is required for Prosody to run, so we will now exit."));
- print("More help can be found on our website, at http://prosody.im/doc/depends");
+ print("More help can be found on our website, at https://prosody.im/doc/depends");
print("**************************");
print("");
end
@@ -40,7 +40,7 @@ end
package.preload["util.ztact"] = function ()
if not package.loaded["core.loggingmanager"] then
error("util.ztact has been removed from Prosody and you need to fix your config "
- .."file. More information can be found at http://prosody.im/doc/packagers#ztact", 0);
+ .."file. More information can be found at https://prosody.im/doc/packagers#ztact", 0);
else
error("module 'util.ztact' has been deprecated in Prosody 0.8.");
end
@@ -156,7 +156,7 @@ local function log_warnings()
if ssl then
local major, minor, veryminor, patched = ssl._VERSION:match("(%d+)%.(%d+)%.?(%d*)(M?)");
if not major or ((tonumber(major) == 0 and (tonumber(minor) or 0) <= 3 and (tonumber(veryminor) or 0) <= 2) and patched ~= "M") then
- prosody.log("error", "This version of LuaSec contains a known bug that causes disconnects, see http://prosody.im/doc/depends");
+ prosody.log("error", "This version of LuaSec contains a known bug that causes disconnects, see https://prosody.im/doc/depends");
end
end
local lxp = softreq"lxp";
@@ -165,7 +165,7 @@ local function log_warnings()
prosody.log("error", "The version of LuaExpat on your system leaves Prosody "
.."vulnerable to denial-of-service attacks. You should upgrade to "
.."LuaExpat 1.3.0 or higher as soon as possible. See "
- .."http://prosody.im/doc/depends#luaexpat for more information.");
+ .."https://prosody.im/doc/depends#luaexpat for more information.");
end
if not lxp.new({}).getcurrentbytecount then
prosody.log("error", "The version of LuaExpat on your system does not support "
@@ -173,7 +173,7 @@ local function log_warnings()
.."networks (e.g. the internet) vulnerable to denial-of-service "
.."attacks. You should upgrade to LuaExpat 1.3.0 or higher as "
.."soon as possible. See "
- .."http://prosody.im/doc/depends#luaexpat for more information.");
+ .."https://prosody.im/doc/depends#luaexpat for more information.");
end
end
end
diff --git a/util/envload.lua b/util/envload.lua
index 926f20c0..6182a1f9 100644
--- a/util/envload.lua
+++ b/util/envload.lua
@@ -4,7 +4,7 @@
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
--- luacheck: ignore 113/setfenv
+-- luacheck: ignore 113/setfenv 113/loadstring
local load, loadstring, setfenv = load, loadstring, setfenv;
local io_open = io.open;
diff --git a/util/events.lua b/util/events.lua
index e2943e44..78700d65 100644
--- a/util/events.lua
+++ b/util/events.lua
@@ -15,6 +15,7 @@ local setmetatable = setmetatable;
local next = next;
local _ENV = nil;
+-- luacheck: std none
local function new()
-- Map event name to ordered list of handlers (lazily built): handlers[event_name] = array_of_handler_functions
@@ -26,7 +27,7 @@ local function new()
-- Event map: event_map[handler_function] = priority_number
local event_map = {};
-- Called on-demand to build handlers entries
- local function _rebuild_index(handlers, event)
+ local function _rebuild_index(self, event)
local _handlers = event_map[event];
if not _handlers or next(_handlers) == nil then return; end
local index = {};
@@ -34,7 +35,7 @@ local function new()
t_insert(index, handler);
end
t_sort(index, function(a, b) return _handlers[a] > _handlers[b]; end);
- handlers[event] = index;
+ self[event] = index;
return index;
end;
setmetatable(handlers, { __index = _rebuild_index });
@@ -61,13 +62,13 @@ local function new()
local function get_handlers(event)
return handlers[event];
end;
- local function add_handlers(handlers)
- for event, handler in pairs(handlers) do
+ local function add_handlers(self)
+ for event, handler in pairs(self) do
add_handler(event, handler);
end
end;
- local function remove_handlers(handlers)
- for event, handler in pairs(handlers) do
+ local function remove_handlers(self)
+ for event, handler in pairs(self) do
remove_handler(event, handler);
end
end;
@@ -81,6 +82,7 @@ local function new()
end
end;
local function fire_event(event_name, event_data)
+ -- luacheck: ignore 432/event_name 432/event_data
local w = wrappers[event_name] or global_wrappers;
if w then
local curr_wrapper = #w;
diff --git a/util/filters.lua b/util/filters.lua
index f405c0bd..f30dfd9c 100644
--- a/util/filters.lua
+++ b/util/filters.lua
@@ -9,6 +9,7 @@
local t_insert, t_remove = table.insert, table.remove;
local _ENV = nil;
+-- luacheck: std none
local new_filter_hooks = {};
diff --git a/util/format.lua b/util/format.lua
index 5f2b12be..c5e513fa 100644
--- a/util/format.lua
+++ b/util/format.lua
@@ -4,11 +4,10 @@
local tostring = tostring;
local select = select;
-local assert = assert;
-local unpack = unpack;
+local unpack = table.unpack or unpack; -- luacheck: ignore 113/unpack
local type = type;
-local function format(format, ...)
+local function format(formatstring, ...)
local args, args_length = { ... }, select('#', ...);
-- format specifier spec:
@@ -25,7 +24,7 @@ local function format(format, ...)
-- process each format specifier
local i = 0;
- format = format:gsub("%%[^cdiouxXaAeEfgGqs%%]*[cdiouxXaAeEfgGqs%%]", function(spec)
+ formatstring = formatstring:gsub("%%[^cdiouxXaAeEfgGqs%%]*[cdiouxXaAeEfgGqs%%]", function(spec)
if spec ~= "%%" then
i = i + 1;
local arg = args[i];
@@ -54,21 +53,12 @@ local function format(format, ...)
else
args[i] = tostring(arg);
end
- format = format .. " [%s]"
+ formatstring = formatstring .. " [%s]"
end
- return format:format(unpack(args));
-end
-
-local function test()
- assert(format("%s", "hello") == "hello");
- assert(format("%s") == "<nil>");
- assert(format("%s", true) == "true");
- assert(format("%d", true) == "[true]");
- assert(format("%%", true) == "% [true]");
+ return formatstring:format(unpack(args));
end
return {
format = format;
- test = test;
};
diff --git a/util/import.lua b/util/import.lua
index c2b9dce1..8ecfe43c 100644
--- a/util/import.lua
+++ b/util/import.lua
@@ -8,9 +8,9 @@
-local unpack = table.unpack or unpack; --luacheck: ignore 113
+local unpack = table.unpack or unpack; --luacheck: ignore 113 143
local t_insert = table.insert;
-function import(module, ...)
+function _G.import(module, ...)
local m = package.loaded[module] or require(module);
if type(m) == "table" and ... then
local ret = {};
diff --git a/util/indexedbheap.lua b/util/indexedbheap.lua
new file mode 100644
index 00000000..7f193d54
--- /dev/null
+++ b/util/indexedbheap.lua
@@ -0,0 +1,157 @@
+
+local setmetatable = setmetatable;
+local math_floor = math.floor;
+local t_remove = table.remove;
+
+local function _heap_insert(self, item, sync, item2, index)
+ local pos = #self + 1;
+ while true do
+ local half_pos = math_floor(pos / 2);
+ if half_pos == 0 or item > self[half_pos] then break; end
+ self[pos] = self[half_pos];
+ sync[pos] = sync[half_pos];
+ index[sync[pos]] = pos;
+ pos = half_pos;
+ end
+ self[pos] = item;
+ sync[pos] = item2;
+ index[item2] = pos;
+end
+
+local function _percolate_up(self, k, sync, index)
+ local tmp = self[k];
+ local tmp_sync = sync[k];
+ while k ~= 1 do
+ local parent = math_floor(k/2);
+ if tmp < self[parent] then break; end
+ self[k] = self[parent];
+ sync[k] = sync[parent];
+ index[sync[k]] = k;
+ k = parent;
+ end
+ self[k] = tmp;
+ sync[k] = tmp_sync;
+ index[tmp_sync] = k;
+ return k;
+end
+
+local function _percolate_down(self, k, sync, index)
+ local tmp = self[k];
+ local tmp_sync = sync[k];
+ local size = #self;
+ local child = 2*k;
+ while 2*k <= size do
+ if child ~= size and self[child] > self[child + 1] then
+ child = child + 1;
+ end
+ if tmp > self[child] then
+ self[k] = self[child];
+ sync[k] = sync[child];
+ index[sync[k]] = k;
+ else
+ break;
+ end
+
+ k = child;
+ child = 2*k;
+ end
+ self[k] = tmp;
+ sync[k] = tmp_sync;
+ index[tmp_sync] = k;
+ return k;
+end
+
+local function _heap_pop(self, sync, index)
+ local size = #self;
+ if size == 0 then return nil; end
+
+ local result = self[1];
+ local result_sync = sync[1];
+ index[result_sync] = nil;
+ if size == 1 then
+ self[1] = nil;
+ sync[1] = nil;
+ return result, result_sync;
+ end
+ self[1] = t_remove(self);
+ sync[1] = t_remove(sync);
+ index[sync[1]] = 1;
+
+ _percolate_down(self, 1, sync, index);
+
+ return result, result_sync;
+end
+
+local indexed_heap = {};
+
+function indexed_heap:insert(item, priority, id)
+ if id == nil then
+ id = self.current_id;
+ self.current_id = id + 1;
+ end
+ self.items[id] = item;
+ _heap_insert(self.priorities, priority, self.ids, id, self.index);
+ return id;
+end
+function indexed_heap:pop()
+ local priority, id = _heap_pop(self.priorities, self.ids, self.index);
+ if id then
+ local item = self.items[id];
+ self.items[id] = nil;
+ return priority, item, id;
+ end
+end
+function indexed_heap:peek()
+ return self.priorities[1];
+end
+function indexed_heap:reprioritize(id, priority)
+ local k = self.index[id];
+ if k == nil then return; end
+ self.priorities[k] = priority;
+
+ k = _percolate_up(self.priorities, k, self.ids, self.index);
+ _percolate_down(self.priorities, k, self.ids, self.index);
+end
+function indexed_heap:remove_index(k)
+ local result = self.priorities[k];
+ if result == nil then return; end
+
+ local result_sync = self.ids[k];
+ local item = self.items[result_sync];
+ local size = #self.priorities;
+
+ self.priorities[k] = self.priorities[size];
+ self.ids[k] = self.ids[size];
+ self.index[self.ids[k]] = k;
+
+ t_remove(self.priorities);
+ t_remove(self.ids);
+
+ self.index[result_sync] = nil;
+ self.items[result_sync] = nil;
+
+ if size > k then
+ k = _percolate_up(self.priorities, k, self.ids, self.index);
+ _percolate_down(self.priorities, k, self.ids, self.index);
+ end
+
+ return result, item, result_sync;
+end
+function indexed_heap:remove(id)
+ return self:remove_index(self.index[id]);
+end
+
+local mt = { __index = indexed_heap };
+
+local _M = {
+ create = function()
+ return setmetatable({
+ ids = {}; -- heap of ids, sync'd with priorities
+ items = {}; -- map id->items
+ priorities = {}; -- heap of priorities
+ index = {}; -- map of id->index of id in ids
+ current_id = 1.5
+ }, mt);
+ end
+};
+return _M;
diff --git a/util/ip.lua b/util/ip.lua
index 81a98ef7..0ec9e297 100644
--- a/util/ip.lua
+++ b/util/ip.lua
@@ -5,69 +5,76 @@
-- COPYING file in the source package for more information.
--
+local net = require "util.net";
+local hex = require "util.hex";
+
local ip_methods = {};
-local ip_mt = { __index = function (ip, key) return (ip_methods[key])(ip); end,
- __tostring = function (ip) return ip.addr; end,
- __eq = function (ipA, ipB) return ipA.addr == ipB.addr; end};
-local hex2bits = { ["0"] = "0000", ["1"] = "0001", ["2"] = "0010", ["3"] = "0011", ["4"] = "0100", ["5"] = "0101", ["6"] = "0110", ["7"] = "0111", ["8"] = "1000", ["9"] = "1001", ["A"] = "1010", ["B"] = "1011", ["C"] = "1100", ["D"] = "1101", ["E"] = "1110", ["F"] = "1111" };
+
+local ip_mt = {
+ __index = function (ip, key)
+ local method = ip_methods[key];
+ if not method then return nil; end
+ local ret = method(ip);
+ ip[key] = ret;
+ return ret;
+ end,
+ __tostring = function (ip) return ip.addr; end,
+ __eq = function (ipA, ipB) return ipA.packed == ipB.packed; end
+};
+
+local hex2bits = {
+ ["0"] = "0000", ["1"] = "0001", ["2"] = "0010", ["3"] = "0011",
+ ["4"] = "0100", ["5"] = "0101", ["6"] = "0110", ["7"] = "0111",
+ ["8"] = "1000", ["9"] = "1001", ["A"] = "1010", ["B"] = "1011",
+ ["C"] = "1100", ["D"] = "1101", ["E"] = "1110", ["F"] = "1111",
+};
local function new_ip(ipStr, proto)
- if not proto then
- local sep = ipStr:match("^%x+(.)");
- if sep == ":" or (not(sep) and ipStr:sub(1,1) == ":") then
- proto = "IPv6"
- elseif sep == "." then
- proto = "IPv4"
- end
- if not proto then
- return nil, "invalid address";
- end
- elseif proto ~= "IPv4" and proto ~= "IPv6" then
- return nil, "invalid protocol";
- end
local zone;
- if proto == "IPv6" and ipStr:find('%', 1, true) then
+ if (not proto or proto == "IPv6") and ipStr:find('%', 1, true) then
ipStr, zone = ipStr:match("^(.-)%%(.*)");
end
- if proto == "IPv6" and ipStr:find('.', 1, true) then
- local changed;
- ipStr, changed = ipStr:gsub(":(%d+)%.(%d+)%.(%d+)%.(%d+)$", function(a,b,c,d)
- return (":%04X:%04X"):format(a*256+b,c*256+d);
- end);
- if changed ~= 1 then return nil, "invalid-address"; end
+
+ local packed, err = net.pton(ipStr);
+ if not packed then return packed, err end
+ if proto == "IPv6" and #packed ~= 16 then
+ return nil, "invalid-ipv6";
+ elseif proto == "IPv4" and #packed ~= 4 then
+ return nil, "invalid-ipv4";
+ elseif not proto then
+ if #packed == 16 then
+ proto = "IPv6";
+ elseif #packed == 4 then
+ proto = "IPv4";
+ else
+ return nil, "unknown protocol";
+ end
+ elseif proto ~= "IPv6" and proto ~= "IPv4" then
+ return nil, "invalid protocol";
end
- return setmetatable({ addr = ipStr, proto = proto, zone = zone }, ip_mt);
+ return setmetatable({ addr = ipStr, packed = packed, proto = proto, zone = zone }, ip_mt);
+end
+
+function ip_methods:normal()
+ return net.ntop(self.packed);
end
-local function toBits(ip)
- local result = "";
- local fields = {};
+function ip_methods.bits(ip)
+ return hex.to(ip.packed):upper():gsub(".", hex2bits);
+end
+
+function ip_methods.bits_full(ip)
if ip.proto == "IPv4" then
ip = ip.toV4mapped;
end
- ip = (ip.addr):upper();
- ip:gsub("([^:]*):?", function (c) fields[#fields + 1] = c end);
- if not ip:match(":$") then fields[#fields] = nil; end
- for i, field in ipairs(fields) do
- if field:len() == 0 and i ~= 1 and i ~= #fields then
- for _ = 1, 16 * (9 - #fields) do
- result = result .. "0";
- end
- else
- for _ = 1, 4 - field:len() do
- result = result .. "0000";
- end
- for j = 1, field:len() do
- result = result .. hex2bits[field:sub(j, j)];
- end
- end
- end
- return result;
+ return ip.bits;
end
+local match;
+
local function commonPrefixLength(ipA, ipB)
- ipA, ipB = toBits(ipA), toBits(ipB);
+ ipA, ipB = ipA.bits_full, ipB.bits_full;
for i = 1, 128 do
if ipA:sub(i,i) ~= ipB:sub(i,i) then
return i-1;
@@ -76,56 +83,60 @@ local function commonPrefixLength(ipA, ipB)
return 128;
end
+-- Instantiate once
+local loopback = new_ip("::1");
+local loopback4 = new_ip("127.0.0.0");
+local sixtofour = new_ip("2002::");
+local teredo = new_ip("2001::");
+local linklocal = new_ip("fe80::");
+local linklocal4 = new_ip("169.254.0.0");
+local uniquelocal = new_ip("fc00::");
+local sitelocal = new_ip("fec0::");
+local sixbone = new_ip("3ffe::");
+local defaultunicast = new_ip("::");
+local multicast = new_ip("ff00::");
+local ipv6mapped = new_ip("::ffff:0:0");
+
local function v4scope(ip)
- local fields = {};
- ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end);
- -- Loopback:
- if fields[1] == 127 then
+ if match(ip, loopback4, 8) then
return 0x2;
- -- Link-local unicast:
- elseif fields[1] == 169 and fields[2] == 254 then
+ elseif match(ip, linklocal4) then
return 0x2;
- -- Global unicast:
- else
+ else -- Global unicast
return 0xE;
end
end
local function v6scope(ip)
- -- Loopback:
- if ip:match("^[0:]*1$") then
+ if ip == loopback then
return 0x2;
- -- Link-local unicast:
- elseif ip:match("^[Ff][Ee][89ABab]") then
+ elseif match(ip, linklocal, 10) then
return 0x2;
- -- Site-local unicast:
- elseif ip:match("^[Ff][Ee][CcDdEeFf]") then
+ elseif match(ip, sitelocal, 10) then
return 0x5;
- -- Multicast:
- elseif ip:match("^[Ff][Ff]") then
- return tonumber("0x"..ip:sub(4,4));
- -- Global unicast:
- else
+ elseif match(ip, multicast, 10) then
+ return ip.packed:byte(2) % 0x10;
+ else -- Global unicast
return 0xE;
end
end
local function label(ip)
- if commonPrefixLength(ip, new_ip("::1", "IPv6")) == 128 then
+ if ip == loopback then
return 0;
- elseif commonPrefixLength(ip, new_ip("2002::", "IPv6")) >= 16 then
+ elseif match(ip, sixtofour, 16) then
return 2;
- elseif commonPrefixLength(ip, new_ip("2001::", "IPv6")) >= 32 then
+ elseif match(ip, teredo, 32) then
return 5;
- elseif commonPrefixLength(ip, new_ip("fc00::", "IPv6")) >= 7 then
+ elseif match(ip, uniquelocal, 7) then
return 13;
- elseif commonPrefixLength(ip, new_ip("fec0::", "IPv6")) >= 10 then
+ elseif match(ip, sitelocal, 10) then
return 11;
- elseif commonPrefixLength(ip, new_ip("3ffe::", "IPv6")) >= 16 then
+ elseif match(ip, sixbone, 16) then
return 12;
- elseif commonPrefixLength(ip, new_ip("::", "IPv6")) >= 96 then
+ elseif match(ip, defaultunicast, 96) then
return 3;
- elseif commonPrefixLength(ip, new_ip("::ffff:0:0", "IPv6")) >= 96 then
+ elseif match(ip, ipv6mapped, 96) then
return 4;
else
return 1;
@@ -133,91 +144,67 @@ local function label(ip)
end
local function precedence(ip)
- if commonPrefixLength(ip, new_ip("::1", "IPv6")) == 128 then
+ if ip == loopback then
return 50;
- elseif commonPrefixLength(ip, new_ip("2002::", "IPv6")) >= 16 then
+ elseif match(ip, sixtofour, 16) then
return 30;
- elseif commonPrefixLength(ip, new_ip("2001::", "IPv6")) >= 32 then
+ elseif match(ip, teredo, 32) then
return 5;
- elseif commonPrefixLength(ip, new_ip("fc00::", "IPv6")) >= 7 then
+ elseif match(ip, uniquelocal, 7) then
return 3;
- elseif commonPrefixLength(ip, new_ip("fec0::", "IPv6")) >= 10 then
+ elseif match(ip, sitelocal, 10) then
return 1;
- elseif commonPrefixLength(ip, new_ip("3ffe::", "IPv6")) >= 16 then
+ elseif match(ip, sixbone, 16) then
return 1;
- elseif commonPrefixLength(ip, new_ip("::", "IPv6")) >= 96 then
+ elseif match(ip, defaultunicast, 96) then
return 1;
- elseif commonPrefixLength(ip, new_ip("::ffff:0:0", "IPv6")) >= 96 then
+ elseif match(ip, ipv6mapped, 96) then
return 35;
else
return 40;
end
end
-local function toV4mapped(ip)
- local fields = {};
- local ret = "::ffff:";
- ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end);
- ret = ret .. ("%02x"):format(fields[1]);
- ret = ret .. ("%02x"):format(fields[2]);
- ret = ret .. ":"
- ret = ret .. ("%02x"):format(fields[3]);
- ret = ret .. ("%02x"):format(fields[4]);
- return new_ip(ret, "IPv6");
-end
-
function ip_methods:toV4mapped()
if self.proto ~= "IPv4" then return nil, "No IPv4 address" end
- local value = toV4mapped(self.addr);
- self.toV4mapped = value;
+ local value = new_ip("::ffff:" .. self.normal);
return value;
end
function ip_methods:label()
- local value;
if self.proto == "IPv4" then
- value = label(self.toV4mapped);
+ return label(self.toV4mapped);
else
- value = label(self);
+ return label(self);
end
- self.label = value;
- return value;
end
function ip_methods:precedence()
- local value;
if self.proto == "IPv4" then
- value = precedence(self.toV4mapped);
+ return precedence(self.toV4mapped);
else
- value = precedence(self);
+ return precedence(self);
end
- self.precedence = value;
- return value;
end
function ip_methods:scope()
- local value;
if self.proto == "IPv4" then
- value = v4scope(self.addr);
+ return v4scope(self);
else
- value = v6scope(self.addr);
+ return v6scope(self);
end
- self.scope = value;
- return value;
end
+local rfc1918_8 = new_ip("10.0.0.0");
+local rfc1918_12 = new_ip("172.16.0.0");
+local rfc1918_16 = new_ip("192.168.0.0");
+local rfc6598 = new_ip("100.64.0.0");
+
function ip_methods:private()
local private = self.scope ~= 0xE;
if not private and self.proto == "IPv4" then
- local ip = self.addr;
- local fields = {};
- ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end);
- if fields[1] == 127 or fields[1] == 10 or (fields[1] == 192 and fields[2] == 168)
- or (fields[1] == 172 and (fields[2] >= 16 or fields[2] <= 32)) then
- private = true;
- end
+ return match(self, rfc1918_8, 8) or match(self, rfc1918_12, 12) or match(self, rfc1918_16) or match(self, rfc6598, 10);
end
- self.private = private;
return private;
end
@@ -231,15 +218,26 @@ local function parse_cidr(cidr)
return new_ip(cidr), bits;
end
-local function match(ipA, ipB, bits)
- local common_bits = commonPrefixLength(ipA, ipB);
- if bits and ipB.proto == "IPv4" then
- common_bits = common_bits - 96; -- v6 mapped addresses always share these bits
+function match(ipA, ipB, bits)
+ if not bits or bits >= 128 or ipB.proto == "IPv4" and bits >= 32 then
+ return ipA == ipB;
+ elseif bits < 1 then
+ return true;
+ end
+ if ipA.proto ~= ipB.proto then
+ if ipA.proto == "IPv4" then
+ ipA = ipA.toV4mapped;
+ elseif ipB.proto == "IPv4" then
+ ipB = ipB.toV4mapped;
+ bits = bits + (128 - 32);
+ end
end
- return common_bits >= (bits or 128);
+ return ipA.bits:sub(1, bits) == ipB.bits:sub(1, bits);
end
-return {new_ip = new_ip,
+return {
+ new_ip = new_ip,
commonPrefixLength = commonPrefixLength,
parse_cidr = parse_cidr,
- match=match};
+ match = match,
+};
diff --git a/util/iterators.lua b/util/iterators.lua
index bd150ff2..a152e7be 100644
--- a/util/iterators.lua
+++ b/util/iterators.lua
@@ -12,8 +12,8 @@ local it = {};
local t_insert = table.insert;
local select, next = select, next;
-local unpack = table.unpack or unpack; --luacheck: ignore 113
-local pack = table.pack or function (...) return { n = select("#", ...), ... }; end
+local unpack = table.unpack or unpack; --luacheck: ignore 113 143
+local pack = table.pack or function (...) return { n = select("#", ...), ... }; end -- luacheck: ignore 143
-- Reverse an iterator
function it.reverse(f, s, var)
diff --git a/util/jid.lua b/util/jid.lua
index f402b7f4..37c48193 100644
--- a/util/jid.lua
+++ b/util/jid.lua
@@ -25,6 +25,7 @@ local unescapes = {};
for k,v in pairs(escapes) do unescapes[v] = k; end
local _ENV = nil;
+-- luacheck: std none
local function split(jid)
if not jid then return; end
diff --git a/util/json.lua b/util/json.lua
index cba54e8e..05af709a 100644
--- a/util/json.lua
+++ b/util/json.lua
@@ -27,9 +27,6 @@ module.null = null;
local escapes = {
["\""] = "\\\"", ["\\"] = "\\\\", ["\b"] = "\\b",
["\f"] = "\\f", ["\n"] = "\\n", ["\r"] = "\\r", ["\t"] = "\\t"};
-local unescapes = {
- ["\""] = "\"", ["\\"] = "\\", ["/"] = "/",
- b = "\b", f = "\f", n = "\n", r = "\r", t = "\t"};
for i=0,31 do
local ch = s_char(i);
if not escapes[ch] then escapes[ch] = ("\\u%.4X"):format(i); end
@@ -263,8 +260,9 @@ end
local function _unescape_func(x)
x = x:match("%x%x%x%x", 3);
if x then
- --if x >= 0xD800 and x <= 0xDFFF then _unescape_error = true; end -- bad surrogate pair
- return codepoint_to_utf8(tonumber(x, 16));
+ local codepoint = tonumber(x, 16)
+ if codepoint >= 0xD800 and codepoint <= 0xDFFF then _unescape_error = true; end -- bad surrogate pair
+ return codepoint_to_utf8(codepoint);
end
_unescape_error = true;
end
@@ -276,7 +274,7 @@ function _readstring(json, index)
--if s:find("[%z-\31]") then return nil, "control char in string"; end
-- FIXME handle control characters
_unescape_error = nil;
- --s = s:gsub("\\u[dD][89abAB]%x%x\\u[dD][cdefCDEF]%x%x", _unescape_surrogate_func);
+ s = s:gsub("\\u[dD][89abAB]%x%x\\u[dD][cdefCDEF]%x%x", _unescape_surrogate_func);
-- FIXME handle escapes beyond BMP
s = s:gsub("\\u.?.?.?.?", _unescape_func);
if _unescape_error then return nil, "invalid escape"; end
diff --git a/util/logger.lua b/util/logger.lua
index e72b29bc..20a5cef2 100644
--- a/util/logger.lua
+++ b/util/logger.lua
@@ -8,8 +8,11 @@
-- luacheck: ignore 213/level
local pairs = pairs;
+local ipairs = ipairs;
+local require = require;
local _ENV = nil;
+-- luacheck: std none
local level_sinks = {};
@@ -67,10 +70,21 @@ local function add_level_sink(level, sink_function)
end
end
+local function add_simple_sink(simple_sink_function, levels)
+ local format = require "util.format".format;
+ local function sink_function(name, level, msg, ...)
+ return simple_sink_function(name, level, format(msg, ...));
+ end
+ for _, level in ipairs(levels or {"debug", "info", "warn", "error"}) do
+ add_level_sink(level, sink_function);
+ end
+end
+
return {
init = init;
make_logger = make_logger;
reset = reset;
add_level_sink = add_level_sink;
+ add_simple_sink = add_simple_sink;
new = make_logger;
};
diff --git a/util/multitable.lua b/util/multitable.lua
index e4321d3d..8d32ed8a 100644
--- a/util/multitable.lua
+++ b/util/multitable.lua
@@ -9,9 +9,10 @@
local select = select;
local t_insert = table.insert;
local pairs, next, type = pairs, next, type;
-local unpack = table.unpack or unpack; --luacheck: ignore 113
+local unpack = table.unpack or unpack; --luacheck: ignore 113 143
local _ENV = nil;
+-- luacheck: std none
local function get(self, ...)
local t = self.data;
@@ -132,7 +133,7 @@ local function iter(self, ...)
local maxdepth = select("#", ...);
local stack = { self.data };
local keys = { };
- local function it(self)
+ local function it(self) -- luacheck: ignore 432/self
local depth = #stack;
local key = next(stack[depth], keys[depth]);
if key == nil then -- Go up the stack
diff --git a/util/openssl.lua b/util/openssl.lua
index 703c6d15..32b5aea7 100644
--- a/util/openssl.lua
+++ b/util/openssl.lua
@@ -114,7 +114,7 @@ function ssl_config:add_xmppAddr(host)
s_format("%s;%s", oid_xmppaddr, utf8string(host)));
end
-function ssl_config:from_prosody(hosts, config, certhosts)
+function ssl_config:from_prosody(hosts, config, certhosts) -- luacheck: ignore 431/config
-- TODO Decide if this should go elsewhere
local found_matching_hosts = false;
for i = 1, #certhosts do
diff --git a/util/pluginloader.lua b/util/pluginloader.lua
index 004855f0..9ab8f245 100644
--- a/util/pluginloader.lua
+++ b/util/pluginloader.lua
@@ -5,6 +5,7 @@
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
+-- luacheck: ignore 113/CFG_PLUGINDIR
local dir_sep, path_sep = package.config:match("^(%S+)%s(%S+)");
local plugin_dir = {};
diff --git a/util/prosodyctl.lua b/util/prosodyctl.lua
index 8ae051ae..eee09762 100644
--- a/util/prosodyctl.lua
+++ b/util/prosodyctl.lua
@@ -24,8 +24,6 @@ local io, os = io, os;
local print = print;
local tonumber = tonumber;
-local CFG_SOURCEDIR = _G.CFG_SOURCEDIR;
-
local _G = _G;
local prosody = prosody;
@@ -66,7 +64,10 @@ local function getline()
end
local function getpass()
- local stty_ret = os.execute("stty -echo 2>/dev/null");
+ local stty_ret, _, status_code = os.execute("stty -echo 2>/dev/null");
+ if status_code then -- COMPAT w/ Lua 5.1
+ stty_ret = status_code;
+ end
if stty_ret ~= 0 then
io.write("\027[08m"); -- ANSI 'hidden' text attribute
end
@@ -228,7 +229,7 @@ local function isrunning()
return true, signal.kill(pid, 0) == 0;
end
-local function start()
+local function start(source_dir)
local ok, ret = isrunning();
if not ok then
return ok, ret;
@@ -236,10 +237,10 @@ local function start()
if ret then
return false, "already-running";
end
- if not CFG_SOURCEDIR then
+ if not source_dir then
os.execute("./prosody");
else
- os.execute(CFG_SOURCEDIR.."/../../bin/prosody");
+ os.execute(source_dir.."/../../bin/prosody");
end
return true;
end
diff --git a/util/pubsub.lua b/util/pubsub.lua
index 1db917d8..3a00aae5 100644
--- a/util/pubsub.lua
+++ b/util/pubsub.lua
@@ -1,32 +1,70 @@
local events = require "util.events";
local cache = require "util.cache";
-local service = {};
-local service_mt = { __index = service };
+local service_mt = {};
-local default_config = { __index = {
- itemstore = function (config) return cache.new(tonumber(config["pubsub#max_items"])) end;
+local default_config = {
+ itemstore = function (config, _) return cache.new(config["max_items"]) end;
broadcaster = function () end;
+ itemcheck = function () return true; end;
get_affiliation = function () end;
capabilities = {};
-} };
-local default_node_config = { __index = {
- ["pubsub#max_items"] = "20";
-} };
+};
+local default_config_mt = { __index = default_config };
+
+local default_node_config = {
+ ["persist_items"] = false;
+ ["max_items"] = 20;
+};
+local default_node_config_mt = { __index = default_node_config };
+
+-- Storage helper functions
+
+local function load_node_from_store(nodestore, node_name)
+ local node = nodestore:get(node_name);
+ node.config = setmetatable(node.config or {}, default_node_config_mt);
+ return node;
+end
+
+local function save_node_to_store(nodestore, node)
+ return nodestore:set(node.name, {
+ name = node.name;
+ config = node.config;
+ subscribers = node.subscribers;
+ affiliations = node.affiliations;
+ });
+end
+-- Create and return a new service object
local function new(config)
config = config or {};
- return setmetatable({
- config = setmetatable(config, default_config);
- node_defaults = setmetatable(config.node_defaults or {}, default_node_config);
+
+ local service = setmetatable({
+ config = setmetatable(config, default_config_mt);
+ node_defaults = setmetatable(config.node_defaults or {}, default_node_config_mt);
affiliations = {};
subscriptions = {};
nodes = {};
data = {};
events = events.new();
}, service_mt);
+
+ -- Load nodes from storage, if we have a store and it supports iterating over stored items
+ if config.nodestore and config.nodestore.users then
+ for node_name in config.nodestore:users() do
+ service.nodes[node_name] = load_node_from_store(config.nodestore, node_name);
+ service.data[node_name] = config.itemstore(service.nodes[node_name].config, node_name);
+ end
+ end
+
+ return service;
end
+--- Service methods
+
+local service = {};
+service_mt.__index = service;
+
function service:jids_equal(jid1, jid2)
local normalize = self.config.normalize_jid;
return normalize(jid1) == normalize(jid2);
@@ -176,18 +214,6 @@ function service:remove_subscription(node, actor, jid)
return true;
end
-function service:remove_all_subscriptions(actor, jid)
- local normal_jid = self.config.normalize_jid(jid);
- local subs = self.subscriptions[normal_jid]
- subs = subs and subs[jid];
- if subs then
- for node in pairs(subs) do
- self:remove_subscription(node, true, jid);
- end
- end
- return true;
-end
-
function service:get_subscription(node, actor, jid)
-- Access checking
local cap;
@@ -223,13 +249,24 @@ function service:create(node, actor, options)
config = setmetatable(options or {}, {__index=self.node_defaults});
affiliations = {};
};
- self.data[node] = self.config.itemstore(self.nodes[node].config);
+
+ if self.config.nodestore then
+ local ok, err = save_node_to_store(self.config.nodestore, self.nodes[node]);
+ if not ok then
+ self.nodes[node] = nil;
+ return ok, err;
+ end
+ end
+
+ self.data[node] = self.config.itemstore(self.nodes[node].config, node);
self.events.fire_event("node-created", { node = node, actor = actor });
local ok, err = self:set_affiliation(node, true, actor, "owner");
if not ok then
self.nodes[node] = nil;
self.data[node] = nil;
+ return ok, err;
end
+
return ok, err;
end
@@ -244,6 +281,9 @@ function service:delete(node, actor)
return false, "item-not-found";
end
self.nodes[node] = nil;
+ if self.data[node] and self.data[node].clear then
+ self.data[node]:clear();
+ end
self.data[node] = nil;
self.events.fire_event("node-deleted", { node = node, actor = actor });
self.config.broadcaster("delete", node, node_obj.subscribers);
@@ -267,11 +307,15 @@ function service:publish(node, actor, id, item)
end
node_obj = self.nodes[node];
end
+ if not self.config.itemcheck(item) then
+ return nil, "internal-server-error";
+ end
local node_data = self.data[node];
local ok = node_data:set(id, item);
if not ok then
return nil, "internal-server-error";
end
+ if type(ok) == "string" then id = ok; end
self.events.fire_event("item-published", { node = node, actor = actor, id = id, item = item });
self.config.broadcaster("items", node, node_obj.subscribers, item, actor);
return true;
@@ -308,7 +352,11 @@ function service:purge(node, actor, notify)
if not node_obj then
return false, "item-not-found";
end
- self.data[node] = self.config.itemstore(self.nodes[node].config);
+ if self.data[node] and self.data[node].clear then
+ self.data[node]:clear()
+ else
+ self.data[node] = self.config.itemstore(self.nodes[node].config, node);
+ end
self.events.fire_event("node-purged", { node = node, actor = actor });
if notify then
self.config.broadcaster("purge", node, node_obj.subscribers);
@@ -327,7 +375,11 @@ function service:get_items(node, actor, id)
return false, "item-not-found";
end
if id then -- Restrict results to a single specific item
- return true, { id, [id] = self.data[node]:get(id) };
+ local with_id = self.data[node]:get(id);
+ if not with_id then
+ return true, { };
+ end
+ return true, { id, [id] = with_id };
else
local data = {}
for key, value in self.data[node]:items() do
@@ -338,6 +390,15 @@ function service:get_items(node, actor, id)
end
end
+function service:get_last_item(node, actor)
+ -- Access checking
+ if not self:may(node, actor, "get_items") then
+ return false, "forbidden";
+ end
+ --
+ return true, self.data[node]:tail();
+end
+
function service:get_nodes(actor)
-- Access checking
if not self:may(nil, actor, "get_nodes") then
@@ -421,14 +482,14 @@ function service:set_node_config(node, actor, new_config)
return false, "item-not-found";
end
- for k,v in pairs(new_config) do
- node_obj.config[k] = v;
- end
- local new_data = self.config.itemstore(self.nodes[node].config);
- for key, value in self.data[node]:items() do
- new_data:set(key, value);
+ if new_config["persist_items"] ~= node_obj.config["persist_items"] then
+ self.data[node] = self.config.itemstore(self.nodes[node].config, node);
+ elseif new_config["max_items"] ~= node_obj.config["max_items"] then
+ self.data[node]:resize(new_config["max_items"]);
end
- self.data[node] = new_data;
+
+ node_obj.config = setmetatable(new_config, {__index=self.node_defaults});
+
return true;
end
diff --git a/util/random.lua b/util/random.lua
index b2d0000d..d8a84514 100644
--- a/util/random.lua
+++ b/util/random.lua
@@ -11,9 +11,6 @@ if ok then return crand; end
local urandom, urandom_err = io.open("/dev/urandom", "r");
-local function seed()
-end
-
local function bytes(n)
return urandom:read(n);
end
@@ -25,7 +22,6 @@ if not urandom then
end
return {
- seed = seed;
bytes = bytes;
_source = "/dev/urandom";
};
diff --git a/util/sasl.lua b/util/sasl.lua
index 5845f34a..50851405 100644
--- a/util/sasl.lua
+++ b/util/sasl.lua
@@ -20,6 +20,7 @@ local assert = assert;
local require = require;
local _ENV = nil;
+-- luacheck: std none
--[[
Authentication Backend Prototypes:
@@ -42,7 +43,7 @@ Example:
local method = {};
method.__index = method;
-local mechanisms = {};
+local registered_mechanisms = {};
local backend_mechanism = {};
local mechanism_channelbindings = {};
@@ -52,7 +53,7 @@ local function registerMechanism(name, backends, f, cb_backends)
assert(type(backends) == "string" or type(backends) == "table", "Parameter backends MUST be either a string or a table.");
assert(type(f) == "function", "Parameter f MUST be a function.");
if cb_backends then assert(type(cb_backends) == "table"); end
- mechanisms[name] = f
+ registered_mechanisms[name] = f
if cb_backends then
mechanism_channelbindings[name] = {};
for _, cb_name in ipairs(cb_backends) do
@@ -70,7 +71,7 @@ local function new(realm, profile)
local mechanisms = profile.mechanisms;
if not mechanisms then
mechanisms = {};
- for backend, f in pairs(profile) do
+ for backend in pairs(profile) do
if backend_mechanism[backend] then
for _, mechanism in ipairs(backend_mechanism[backend]) do
mechanisms[mechanism] = true;
@@ -128,7 +129,7 @@ end
-- feed new messages to process into the library
function method:process(message)
--if message == "" or message == nil then return "failure", "malformed-request" end
- return mechanisms[self.selected](self, message);
+ return registered_mechanisms[self.selected](self, message);
end
-- load the mechanisms
diff --git a/util/sasl/anonymous.lua b/util/sasl/anonymous.lua
index 6201db32..88e13b40 100644
--- a/util/sasl/anonymous.lua
+++ b/util/sasl/anonymous.lua
@@ -15,6 +15,7 @@
local generate_uuid = require "util.uuid".generate;
local _ENV = nil;
+-- luacheck: std none
--=========================
--SASL ANONYMOUS according to RFC 4505
@@ -28,7 +29,7 @@ anonymous:
end
]]
-local function anonymous(self, message)
+local function anonymous(self, message) -- luacheck: ignore 212/message
local username;
repeat
username = generate_uuid();
diff --git a/util/sasl/digest-md5.lua b/util/sasl/digest-md5.lua
index 695dd2a3..7542a037 100644
--- a/util/sasl/digest-md5.lua
+++ b/util/sasl/digest-md5.lua
@@ -26,6 +26,7 @@ local generate_uuid = require "util.uuid".generate;
local nodeprep = require "util.encodings".stringprep.nodeprep;
local _ENV = nil;
+-- luacheck: std none
--=========================
--SASL DIGEST-MD5 according to RFC 2831
diff --git a/util/sasl/external.lua b/util/sasl/external.lua
index 5ba90190..ce50743e 100644
--- a/util/sasl/external.lua
+++ b/util/sasl/external.lua
@@ -1,6 +1,7 @@
local saslprep = require "util.encodings".stringprep.saslprep;
local _ENV = nil;
+-- luacheck: std none
local function external(self, message)
message = saslprep(message);
diff --git a/util/sasl/plain.lua b/util/sasl/plain.lua
index cd59b1ac..00c6bd20 100644
--- a/util/sasl/plain.lua
+++ b/util/sasl/plain.lua
@@ -17,6 +17,7 @@ local nodeprep = require "util.encodings".stringprep.nodeprep;
local log = require "util.logger".init("sasl");
local _ENV = nil;
+-- luacheck: std none
-- ================================
-- SASL PLAIN according to RFC 4616
diff --git a/util/sasl/scram.lua b/util/sasl/scram.lua
index 4e20dbb9..15c15d68 100644
--- a/util/sasl/scram.lua
+++ b/util/sasl/scram.lua
@@ -26,6 +26,7 @@ local char = string.char;
local byte = string.byte;
local _ENV = nil;
+-- luacheck: std none
--=========================
--SASL SCRAM-SHA-1 according to RFC 5802
@@ -46,7 +47,18 @@ Supported Channel Binding Backends
local default_i = 4096
-local xor_map = {0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;1;0;3;2;5;4;7;6;9;8;11;10;13;12;15;14;2;3;0;1;6;7;4;5;10;11;8;9;14;15;12;13;3;2;1;0;7;6;5;4;11;10;9;8;15;14;13;12;4;5;6;7;0;1;2;3;12;13;14;15;8;9;10;11;5;4;7;6;1;0;3;2;13;12;15;14;9;8;11;10;6;7;4;5;2;3;0;1;14;15;12;13;10;11;8;9;7;6;5;4;3;2;1;0;15;14;13;12;11;10;9;8;8;9;10;11;12;13;14;15;0;1;2;3;4;5;6;7;9;8;11;10;13;12;15;14;1;0;3;2;5;4;7;6;10;11;8;9;14;15;12;13;2;3;0;1;6;7;4;5;11;10;9;8;15;14;13;12;3;2;1;0;7;6;5;4;12;13;14;15;8;9;10;11;4;5;6;7;0;1;2;3;13;12;15;14;9;8;11;10;5;4;7;6;1;0;3;2;14;15;12;13;10;11;8;9;6;7;4;5;2;3;0;1;15;14;13;12;11;10;9;8;7;6;5;4;3;2;1;0;};
+local xor_map = {
+ 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,1,0,3,2,5,4,7,6,9,8,11,10,
+ 13,12,15,14,2,3,0,1,6,7,4,5,10,11,8,9,14,15,12,13,3,2,1,0,7,6,5,
+ 4,11,10,9,8,15,14,13,12,4,5,6,7,0,1,2,3,12,13,14,15,8,9,10,11,5,
+ 4,7,6,1,0,3,2,13,12,15,14,9,8,11,10,6,7,4,5,2,3,0,1,14,15,12,13,
+ 10,11,8,9,7,6,5,4,3,2,1,0,15,14,13,12,11,10,9,8,8,9,10,11,12,13,
+ 14,15,0,1,2,3,4,5,6,7,9,8,11,10,13,12,15,14,1,0,3,2,5,4,7,6,10,
+ 11,8,9,14,15,12,13,2,3,0,1,6,7,4,5,11,10,9,8,15,14,13,12,3,2,1,
+ 0,7,6,5,4,12,13,14,15,8,9,10,11,4,5,6,7,0,1,2,3,13,12,15,14,9,8,
+ 11,10,5,4,7,6,1,0,3,2,14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1,15,
+ 14,13,12,11,10,9,8,7,6,5,4,3,2,1,0,
+};
local result = {};
local function binaryXOR( a, b )
@@ -237,10 +249,14 @@ end
local function init(registerMechanism)
local function registerSCRAMMechanism(hash_name, hash, hmac_hash)
- registerMechanism("SCRAM-"..hash_name, {"plain", "scram_"..(hashprep(hash_name))}, scram_gen(hash_name:lower(), hash, hmac_hash));
+ registerMechanism("SCRAM-"..hash_name,
+ {"plain", "scram_"..(hashprep(hash_name))},
+ scram_gen(hash_name:lower(), hash, hmac_hash));
-- register channel binding equivalent
- registerMechanism("SCRAM-"..hash_name.."-PLUS", {"plain", "scram_"..(hashprep(hash_name))}, scram_gen(hash_name:lower(), hash, hmac_hash), {"tls-unique"});
+ registerMechanism("SCRAM-"..hash_name.."-PLUS",
+ {"plain", "scram_"..(hashprep(hash_name))},
+ scram_gen(hash_name:lower(), hash, hmac_hash), {"tls-unique"});
end
registerSCRAMMechanism("SHA-1", sha1, hmac_sha1);
diff --git a/util/sasl_cyrus.lua b/util/sasl_cyrus.lua
index 4e9a4af5..a6bd0628 100644
--- a/util/sasl_cyrus.lua
+++ b/util/sasl_cyrus.lua
@@ -61,6 +61,7 @@ local sasl_errstring = {
setmetatable(sasl_errstring, { __index = function() return "undefined error!" end });
local _ENV = nil;
+-- luacheck: std none
local method = {};
method.__index = method;
diff --git a/util/serialization.lua b/util/serialization.lua
index 206f5fbb..54c8110f 100644
--- a/util/serialization.lua
+++ b/util/serialization.lua
@@ -21,6 +21,7 @@ local log = require "util.logger".init("serialization");
local envload = require"util.envload".envload;
local _ENV = nil;
+-- luacheck: std none
local indent = function(i)
return string_rep("\t", i);
diff --git a/util/set.lua b/util/set.lua
index c136a522..a4f20138 100644
--- a/util/set.lua
+++ b/util/set.lua
@@ -11,8 +11,9 @@ local ipairs, pairs, setmetatable, next, tostring =
local t_concat = table.concat;
local _ENV = nil;
+-- luacheck: std none
-local set_mt = {};
+local set_mt = { __name = "set" };
function set_mt.__call(set, _, k)
return next(set._items, k);
end
diff --git a/util/sql.lua b/util/sql.lua
index d964025e..67a5d683 100644
--- a/util/sql.lua
+++ b/util/sql.lua
@@ -1,11 +1,10 @@
local setmetatable, getmetatable = setmetatable, getmetatable;
-local ipairs, unpack, select = ipairs, table.unpack or unpack, select; --luacheck: ignore 113
-local tonumber, tostring = tonumber, tostring;
+local ipairs, unpack, select = ipairs, table.unpack or unpack, select; --luacheck: ignore 113 143
+local tostring = tostring;
local type = type;
local assert, pcall, xpcall, debug_traceback = assert, pcall, xpcall, debug.traceback;
local t_concat = table.concat;
-local s_char = string.char;
local log = require "util.logger".init("sql");
local DBI = require "DBI";
@@ -15,6 +14,7 @@ DBI.Drivers();
local build_url = require "socket.url".build;
local _ENV = nil;
+-- luacheck: std none
local column_mt = {};
local table_mt = {};
@@ -58,9 +58,6 @@ table_mt.__index = {};
function table_mt.__index:create(engine)
return engine:_create_table(self);
end
-function table_mt:__call(...)
- -- TODO
-end
function column_mt:__tostring()
return 'Column{ name="'..self.name..'", type="'..self.type..'" }'
end
@@ -71,31 +68,6 @@ function index_mt:__tostring()
-- return 'Index{ name="'..self.name..'", type="'..self.type..'" }'
end
-local function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end)); end
-local function parse_url(url)
- local scheme, secondpart, database = url:match("^([%w%+]+)://([^/]*)/?(.*)");
- assert(scheme, "Invalid URL format");
- local username, password, host, port;
- local authpart, hostpart = secondpart:match("([^@]+)@([^@+])");
- if not authpart then hostpart = secondpart; end
- if authpart then
- username, password = authpart:match("([^:]*):(.*)");
- username = username or authpart;
- password = password and urldecode(password);
- end
- if hostpart then
- host, port = hostpart:match("([^:]*):(.*)");
- host = host or hostpart;
- port = port and assert(tonumber(port), "Invalid URL format");
- end
- return {
- scheme = scheme:lower();
- username = username; password = password;
- host = host; port = port;
- database = #database > 0 and database or nil;
- };
-end
-
local engine = {};
function engine:connect()
if self.conn then return true; end
@@ -123,7 +95,7 @@ function engine:connect()
end
return true;
end
-function engine:onconnect()
+function engine:onconnect() -- luacheck: ignore 212/self
-- Override from create_engine()
end
@@ -148,6 +120,7 @@ function engine:execute(sql, ...)
prepared[sql] = stmt;
end
+ -- luacheck: ignore 411/success
local success, err = stmt:execute(...);
if not success then return success, err; end
return stmt;
@@ -161,14 +134,14 @@ local result_mt = { __index = {
local function debugquery(where, sql, ...)
local i = 0; local a = {...}
sql = sql:gsub("\n?\t+", " ");
- log("debug", "[%s] %s", where, sql:gsub("%?", function ()
+ log("debug", "[%s] %s", where, (sql:gsub("%?", function ()
i = i + 1;
local v = a[i];
if type(v) == "string" then
v = ("'%s'"):format(v:gsub("'", "''"));
end
return tostring(v);
- end));
+ end)));
end
function engine:execute_query(sql, ...)
@@ -335,7 +308,12 @@ function engine:set_encoding() -- to UTF-8
local charset = "utf8";
if driver == "MySQL" then
self:transaction(function()
- for row in self:select"SELECT \"CHARACTER_SET_NAME\" FROM \"information_schema\".\"CHARACTER_SETS\" WHERE \"CHARACTER_SET_NAME\" LIKE 'utf8%' ORDER BY MAXLEN DESC LIMIT 1;" do
+ for row in self:select[[
+ SELECT "CHARACTER_SET_NAME"
+ FROM "information_schema"."CHARACTER_SETS"
+ WHERE "CHARACTER_SET_NAME" LIKE 'utf8%'
+ ORDER BY MAXLEN DESC LIMIT 1;
+ ]] do
charset = row and row[1] or charset;
end
end);
@@ -379,7 +357,7 @@ local function db2uri(params)
};
end
-local function create_engine(self, params, onconnect)
+local function create_engine(_, params, onconnect)
return setmetatable({ url = db2uri(params), params = params, onconnect = onconnect }, engine_mt);
end
diff --git a/util/sslconfig.lua b/util/sslconfig.lua
index 4c4e1d48..5c685f7d 100644
--- a/util/sslconfig.lua
+++ b/util/sslconfig.lua
@@ -8,6 +8,7 @@ local t_insert = table.insert;
local setmetatable = setmetatable;
local _ENV = nil;
+-- luacheck: std none
local handlers = { };
local finalisers = { };
diff --git a/util/stanza.lua b/util/stanza.lua
index 2191fa8e..4f91d5e9 100644
--- a/util/stanza.lua
+++ b/util/stanza.lua
@@ -7,6 +7,7 @@
--
+local error = error;
local t_insert = table.insert;
local t_remove = table.remove;
local t_concat = table.concat;
@@ -23,6 +24,8 @@ local s_sub = string.sub;
local s_find = string.find;
local os = os;
+local valid_utf8 = require "util.encodings".utf8.valid;
+
local do_pretty_printing = not os.getenv("WINDIR");
local getstyle, getstring;
if do_pretty_printing then
@@ -37,12 +40,52 @@ end
local xmlns_stanzas = "urn:ietf:params:xml:ns:xmpp-stanzas";
local _ENV = nil;
+-- luacheck: std none
-local stanza_mt = { __type = "stanza" };
+local stanza_mt = { __name = "stanza" };
stanza_mt.__index = stanza_mt;
-local function new_stanza(name, attr)
- local stanza = { name = name, attr = attr or {}, tags = {} };
+local function check_name(name, name_type)
+ if type(name) ~= "string" then
+ error("invalid "..name_type.." name: expected string, got "..type(name));
+ elseif #name == 0 then
+ error("invalid "..name_type.." name: empty string");
+ elseif s_find(name, "[<>& '\"]") then
+ error("invalid "..name_type.." name: contains invalid characters");
+ elseif not valid_utf8(name) then
+ error("invalid "..name_type.." name: contains invalid utf8");
+ end
+end
+
+local function check_text(text, text_type)
+ if type(text) ~= "string" then
+ error("invalid "..text_type.." value: expected string, got "..type(text));
+ elseif not valid_utf8(text) then
+ error("invalid "..text_type.." value: contains invalid utf8");
+ end
+end
+
+local function check_attr(attr)
+ if attr ~= nil then
+ if type(attr) ~= "table" then
+ error("invalid attributes, expected table got "..type(attr));
+ end
+ for k, v in pairs(attr) do
+ check_name(k, "attribute");
+ check_text(v, "attribute");
+ if type(v) ~= "string" then
+ error("invalid attribute value for '"..k.."': expected string, got "..type(v));
+ elseif not valid_utf8(v) then
+ error("invalid attribute value for '"..k.."': contains invalid utf8");
+ end
+ end
+ end
+end
+
+local function new_stanza(name, attr, namespaces)
+ check_name(name, "tag");
+ check_attr(attr);
+ local stanza = { name = name, attr = attr or {}, namespaces = namespaces, tags = {} };
return setmetatable(stanza, stanza_mt);
end
@@ -58,8 +101,12 @@ function stanza_mt:body(text, attr)
return self:tag("body", attr):text(text);
end
-function stanza_mt:tag(name, attrs)
- local s = new_stanza(name, attrs);
+function stanza_mt:text_tag(name, text, attr, namespaces)
+ return self:tag(name, attr, namespaces):text(text):up();
+end
+
+function stanza_mt:tag(name, attr, namespaces)
+ local s = new_stanza(name, attr, namespaces);
local last_add = self.last_add;
if not last_add then last_add = {}; self.last_add = last_add; end
(last_add[#last_add] or self):add_direct_child(s);
@@ -68,8 +115,11 @@ function stanza_mt:tag(name, attrs)
end
function stanza_mt:text(text)
- local last_add = self.last_add;
- (last_add and last_add[#last_add] or self):add_direct_child(text);
+ if text ~= nil and text ~= "" then
+ check_text(text, "text");
+ local last_add = self.last_add;
+ (last_add and last_add[#last_add] or self):add_direct_child(text);
+ end
return self;
end
@@ -337,7 +387,12 @@ end
local function clone(stanza)
local attr, tags = {}, {};
for k,v in pairs(stanza.attr) do attr[k] = v; end
- local new = { name = stanza.name, attr = attr, tags = tags };
+ local old_namespaces, namespaces = stanza.namespaces;
+ if old_namespaces then
+ namespaces = {};
+ for k,v in pairs(old_namespaces) do namespaces[k] = v; end
+ end
+ local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags };
for i=1,#stanza do
local child = stanza[i];
if child.name then
@@ -362,7 +417,13 @@ local function iq(attr)
end
local function reply(orig)
- return new_stanza(orig.name, orig.attr and { to = orig.attr.from, from = orig.attr.to, id = orig.attr.id, type = ((orig.name == "iq" and "result") or orig.attr.type) });
+ return new_stanza(orig.name,
+ orig.attr and {
+ to = orig.attr.from,
+ from = orig.attr.to,
+ id = orig.attr.id,
+ type = ((orig.name == "iq" and "result") or orig.attr.type)
+ });
end
local xmpp_stanzas_attr = { xmlns = xmlns_stanzas };
diff --git a/util/startup.lua b/util/startup.lua
new file mode 100644
index 00000000..69511bf9
--- /dev/null
+++ b/util/startup.lua
@@ -0,0 +1,543 @@
+-- Ignore the CFG_* variables
+-- luacheck: ignore 113/CFG_CONFIGDIR 113/CFG_SOURCEDIR 113/CFG_DATADIR 113/CFG_PLUGINDIR
+local startup = {};
+
+local prosody = { events = require "util.events".new() };
+
+local config = require "core.configmanager";
+
+local dependencies = require "util.dependencies";
+
+local original_logging_config;
+
+function startup.read_config()
+ local filenames = {};
+
+ local filename;
+ if arg[1] == "--config" and arg[2] then
+ table.insert(filenames, arg[2]);
+ if CFG_CONFIGDIR then
+ table.insert(filenames, CFG_CONFIGDIR.."/"..arg[2]);
+ end
+ table.remove(arg, 1); table.remove(arg, 1);
+ elseif os.getenv("PROSODY_CONFIG") then -- Passed by prosodyctl
+ table.insert(filenames, os.getenv("PROSODY_CONFIG"));
+ else
+ table.insert(filenames, (CFG_CONFIGDIR or ".").."/prosody.cfg.lua");
+ end
+ for _,_filename in ipairs(filenames) do
+ filename = _filename;
+ local file = io.open(filename);
+ if file then
+ file:close();
+ prosody.config_file = filename;
+ CFG_CONFIGDIR = filename:match("^(.*)[\\/][^\\/]*$"); -- luacheck: ignore 111
+ break;
+ end
+ end
+ prosody.config_file = filename
+ local ok, level, err = config.load(filename);
+ if not ok then
+ print("\n");
+ print("**************************");
+ if level == "parser" then
+ print("A problem occured while reading the config file "..filename);
+ print("");
+ local err_line, err_message = tostring(err):match("%[string .-%]:(%d*): (.*)");
+ if err:match("chunk has too many syntax levels$") then
+ print("An Include statement in a config file is including an already-included");
+ print("file and causing an infinite loop. An Include statement in a config file is...");
+ else
+ print("Error"..(err_line and (" on line "..err_line) or "")..": "..(err_message or tostring(err)));
+ end
+ print("");
+ elseif level == "file" then
+ print("Prosody was unable to find the configuration file.");
+ print("We looked for: "..filename);
+ print("A sample config file is included in the Prosody download called prosody.cfg.lua.dist");
+ print("Copy or rename it to prosody.cfg.lua and edit as necessary.");
+ end
+ print("More help on configuring Prosody can be found at https://prosody.im/doc/configure");
+ print("Good luck!");
+ print("**************************");
+ print("");
+ os.exit(1);
+ end
+end
+
+function startup.check_dependencies()
+ if not dependencies.check_dependencies() then
+ os.exit(1);
+ end
+end
+
+-- luacheck: globals socket server
+
+function startup.load_libraries()
+ -- Load socket framework
+ -- luacheck: ignore 111/server 111/socket
+ socket = require "socket";
+ server = require "net.server"
+end
+
+-- The global log() gets defined by loggingmanager
+-- luacheck: ignore 113/log
+
+function startup.init_logging()
+ -- Initialize logging
+ require "core.loggingmanager"
+end
+
+function startup.log_dependency_warnings()
+ dependencies.log_warnings();
+end
+
+function startup.sanity_check()
+ for host, host_config in pairs(config.getconfig()) do
+ if host ~= "*"
+ and host_config.enabled ~= false
+ and not host_config.component_module then
+ return;
+ end
+ end
+ log("error", "No enabled VirtualHost entries found in the config file.");
+ log("error", "At least one active host is required for Prosody to function. Exiting...");
+ os.exit(1);
+end
+
+function startup.sandbox_require()
+ -- Replace require() with one that doesn't pollute _G, required
+ -- for neat sandboxing of modules
+ -- luacheck: ignore 113/getfenv 111/require
+ local _realG = _G;
+ local _real_require = require;
+ local getfenv = getfenv or function (f)
+ -- FIXME: This is a hack to replace getfenv() in Lua 5.2
+ local name, env = debug.getupvalue(debug.getinfo(f or 1).func, 1);
+ if name == "_ENV" then
+ return env;
+ end
+ end
+ function require(...) -- luacheck: ignore 121
+ local curr_env = getfenv(2);
+ local curr_env_mt = getmetatable(curr_env);
+ local _realG_mt = getmetatable(_realG);
+ if curr_env_mt and curr_env_mt.__index and not curr_env_mt.__newindex and _realG_mt then
+ local old_newindex, old_index;
+ old_newindex, _realG_mt.__newindex = _realG_mt.__newindex, curr_env;
+ old_index, _realG_mt.__index = _realG_mt.__index, function (_G, k) -- luacheck: ignore 212/_G
+ return rawget(curr_env, k);
+ end;
+ local ret = _real_require(...);
+ _realG_mt.__newindex = old_newindex;
+ _realG_mt.__index = old_index;
+ return ret;
+ end
+ return _real_require(...);
+ end
+end
+
+function startup.set_function_metatable()
+ local mt = {};
+ function mt.__index(f, upvalue)
+ local i, name, value = 0;
+ repeat
+ i = i + 1;
+ name, value = debug.getupvalue(f, i);
+ until name == upvalue or name == nil;
+ return value;
+ end
+ function mt.__newindex(f, upvalue, value)
+ local i, name = 0;
+ repeat
+ i = i + 1;
+ name = debug.getupvalue(f, i);
+ until name == upvalue or name == nil;
+ if name then
+ debug.setupvalue(f, i, value);
+ end
+ end
+ function mt.__tostring(f)
+ local info = debug.getinfo(f);
+ return ("function(%s:%d)"):format(info.short_src:match("[^\\/]*$"), info.linedefined);
+ end
+ debug.setmetatable(function() end, mt);
+end
+
+function startup.detect_platform()
+ prosody.platform = "unknown";
+ if os.getenv("WINDIR") then
+ prosody.platform = "windows";
+ elseif package.config:sub(1,1) == "/" then
+ prosody.platform = "posix";
+ end
+end
+
+function startup.detect_installed()
+ prosody.installed = nil;
+ if CFG_SOURCEDIR and (prosody.platform == "windows" or CFG_SOURCEDIR:match("^/")) then
+ prosody.installed = true;
+ end
+end
+
+function startup.init_global_state()
+ -- luacheck: ignore 121
+ prosody.bare_sessions = {};
+ prosody.full_sessions = {};
+ prosody.hosts = {};
+
+ -- COMPAT: These globals are deprecated
+ -- luacheck: ignore 111/bare_sessions 111/full_sessions 111/hosts
+ bare_sessions = prosody.bare_sessions;
+ full_sessions = prosody.full_sessions;
+ hosts = prosody.hosts;
+
+ prosody.paths = { source = CFG_SOURCEDIR, config = CFG_CONFIGDIR or ".",
+ plugins = CFG_PLUGINDIR or "plugins", data = "data" };
+
+ prosody.arg = _G.arg;
+
+ startup.detect_platform();
+ startup.detect_installed();
+ _G.prosody = prosody;
+end
+
+function startup.setup_datadir()
+ prosody.paths.data = config.get("*", "data_path") or CFG_DATADIR or "data";
+end
+
+function startup.setup_plugindir()
+ local custom_plugin_paths = config.get("*", "plugin_paths");
+ if custom_plugin_paths then
+ local path_sep = package.config:sub(3,3);
+ -- path1;path2;path3;defaultpath...
+ -- luacheck: ignore 111
+ CFG_PLUGINDIR = table.concat(custom_plugin_paths, path_sep)..path_sep..(CFG_PLUGINDIR or "plugins");
+ end
+ prosody.paths.plugins = CFG_PLUGINDIR;
+end
+
+function startup.chdir()
+ if prosody.installed then
+ -- Change working directory to data path.
+ require "lfs".chdir(prosody.paths.data);
+ end
+end
+
+function startup.add_global_prosody_functions()
+ -- Function to reload the config file
+ function prosody.reload_config()
+ log("info", "Reloading configuration file");
+ prosody.events.fire_event("reloading-config");
+ local ok, level, err = config.load(prosody.config_file);
+ if not ok then
+ if level == "parser" then
+ log("error", "There was an error parsing the configuration file: %s", tostring(err));
+ elseif level == "file" then
+ log("error", "Couldn't read the config file when trying to reload: %s", tostring(err));
+ end
+ else
+ prosody.events.fire_event("config-reloaded", {
+ filename = prosody.config_file,
+ config = config.getconfig(),
+ });
+ end
+ return ok, (err and tostring(level)..": "..tostring(err)) or nil;
+ end
+
+ -- Function to reopen logfiles
+ function prosody.reopen_logfiles()
+ log("info", "Re-opening log files");
+ prosody.events.fire_event("reopen-log-files");
+ end
+
+ -- Function to initiate prosody shutdown
+ function prosody.shutdown(reason, code)
+ log("info", "Shutting down: %s", reason or "unknown reason");
+ prosody.shutdown_reason = reason;
+ prosody.shutdown_code = code;
+ prosody.events.fire_event("server-stopping", {
+ reason = reason;
+ code = code;
+ });
+ server.setquitting(true);
+ end
+end
+
+function startup.load_secondary_libraries()
+ --- Load and initialise core modules
+ require "util.import"
+ require "util.xmppstream"
+ require "core.stanza_router"
+ require "core.statsmanager"
+ require "core.hostmanager"
+ require "core.portmanager"
+ require "core.modulemanager"
+ require "core.usermanager"
+ require "core.rostermanager"
+ require "core.sessionmanager"
+ package.loaded['core.componentmanager'] = setmetatable({},{__index=function()
+ log("warn", "componentmanager is deprecated: %s", debug.traceback():match("\n[^\n]*\n[ \t]*([^\n]*)"));
+ return function() end
+ end});
+
+ require "util.array"
+ require "util.datetime"
+ require "util.iterators"
+ require "util.timer"
+ require "util.helpers"
+
+ pcall(require, "util.signal") -- Not on Windows
+
+ -- Commented to protect us from
+ -- the second kind of people
+ --[[
+ pcall(require, "remdebug.engine");
+ if remdebug then remdebug.engine.start() end
+ ]]
+
+ require "util.stanza"
+ require "util.jid"
+end
+
+function startup.init_http_client()
+ local http = require "net.http"
+ local config_ssl = config.get("*", "ssl") or {}
+ local https_client = config.get("*", "client_https_ssl")
+ http.default.options.sslctx = require "core.certmanager".create_context("client_https port 0", "client",
+ { capath = config_ssl.capath, cafile = config_ssl.cafile, verify = "peer", }, https_client);
+end
+
+function startup.init_data_store()
+ require "core.storagemanager";
+end
+
+function startup.prepare_to_start()
+ log("info", "Prosody is using the %s backend for connection handling", server.get_backend());
+ -- Signal to modules that we are ready to start
+ prosody.events.fire_event("server-starting");
+ prosody.start_time = os.time();
+end
+
+function startup.init_global_protection()
+ -- Catch global accesses
+ -- luacheck: ignore 212/t
+ local locked_globals_mt = {
+ __index = function (t, k) log("warn", "%s", debug.traceback("Attempt to read a non-existent global '"..tostring(k).."'", 2)); end;
+ __newindex = function (t, k, v) error("Attempt to set a global: "..tostring(k).." = "..tostring(v), 2); end;
+ };
+
+ function prosody.unlock_globals()
+ setmetatable(_G, nil);
+ end
+
+ function prosody.lock_globals()
+ setmetatable(_G, locked_globals_mt);
+ end
+
+ -- And lock now...
+ prosody.lock_globals();
+end
+
+function startup.read_version()
+ -- Try to determine version
+ local version_file = io.open((CFG_SOURCEDIR or ".").."/prosody.version");
+ prosody.version = "unknown";
+ if version_file then
+ prosody.version = version_file:read("*a"):gsub("%s*$", "");
+ version_file:close();
+ if #prosody.version == 12 and prosody.version:match("^[a-f0-9]+$") then
+ prosody.version = "hg:"..prosody.version;
+ end
+ else
+ local hg = require"util.mercurial";
+ local hgid = hg.check_id(CFG_SOURCEDIR or ".");
+ if hgid then prosody.version = "hg:" .. hgid; end
+ end
+end
+
+function startup.log_greeting()
+ log("info", "Hello and welcome to Prosody version %s", prosody.version);
+end
+
+function startup.notify_started()
+ prosody.events.fire_event("server-started");
+end
+
+-- Override logging config (used by prosodyctl)
+function startup.force_console_logging()
+ original_logging_config = config.get("*", "log");
+ config.set("*", "log", { { levels = { min="info" }, to = "console" } });
+end
+
+function startup.switch_user()
+ -- Switch away from root and into the prosody user --
+ -- NOTE: This function is only used by prosodyctl.
+ -- The prosody process is built with the assumption that
+ -- it is already started as the appropriate user.
+
+ local want_pposix_version = "0.4.0";
+ local have_pposix, pposix = pcall(require, "util.pposix");
+
+ if have_pposix and pposix then
+ if pposix._VERSION ~= want_pposix_version then
+ print(string.format("Unknown version (%s) of binary pposix module, expected %s",
+ tostring(pposix._VERSION), want_pposix_version));
+ os.exit(1);
+ end
+ prosody.current_uid = pposix.getuid();
+ local arg_root = arg[1] == "--root";
+ if arg_root then table.remove(arg, 1); end
+ if prosody.current_uid == 0 and config.get("*", "run_as_root") ~= true and not arg_root then
+ -- We haz root!
+ local desired_user = config.get("*", "prosody_user") or "prosody";
+ local desired_group = config.get("*", "prosody_group") or desired_user;
+ local ok, err = pposix.setgid(desired_group);
+ if ok then
+ ok, err = pposix.initgroups(desired_user);
+ end
+ if ok then
+ ok, err = pposix.setuid(desired_user);
+ if ok then
+ -- Yay!
+ prosody.switched_user = true;
+ end
+ end
+ if not prosody.switched_user then
+ -- Boo!
+ print("Warning: Couldn't switch to Prosody user/group '"..tostring(desired_user).."'/'"..tostring(desired_group).."': "..tostring(err));
+ else
+ -- Make sure the Prosody user can read the config
+ local conf, err, errno = io.open(prosody.config_file);
+ if conf then
+ conf:close();
+ else
+ print("The config file is not readable by the '"..desired_user.."' user.");
+ print("Prosody will not be able to read it.");
+ print("Error was "..err);
+ os.exit(1);
+ end
+ end
+ end
+
+ -- Set our umask to protect data files
+ pposix.umask(config.get("*", "umask") or "027");
+ pposix.setenv("HOME", prosody.paths.data);
+ pposix.setenv("PROSODY_CONFIG", prosody.config_file);
+ else
+ print("Error: Unable to load pposix module. Check that Prosody is installed correctly.")
+ print("For more help send the below error to us through https://prosody.im/discuss");
+ print(tostring(pposix))
+ os.exit(1);
+ end
+end
+
+function startup.check_unwriteable()
+ local function test_writeable(filename)
+ local f, err = io.open(filename, "a");
+ if not f then
+ return false, err;
+ end
+ f:close();
+ return true;
+ end
+
+ local unwriteable_files = {};
+ if type(original_logging_config) == "string" and original_logging_config:sub(1,1) ~= "*" then
+ local ok, err = test_writeable(original_logging_config);
+ if not ok then
+ table.insert(unwriteable_files, err);
+ end
+ elseif type(original_logging_config) == "table" then
+ for _, rule in ipairs(original_logging_config) do
+ if rule.filename then
+ local ok, err = test_writeable(rule.filename);
+ if not ok then
+ table.insert(unwriteable_files, err);
+ end
+ end
+ end
+ end
+
+ if #unwriteable_files > 0 then
+ print("One of more of the Prosody log files are not");
+ print("writeable, please correct the errors and try");
+ print("starting prosodyctl again.");
+ print("");
+ for _, err in ipairs(unwriteable_files) do
+ print(err);
+ end
+ print("");
+ os.exit(1);
+ end
+end
+
+function startup.make_host(hostname)
+ return {
+ type = "local",
+ events = prosody.events,
+ modules = {},
+ sessions = {},
+ users = require "core.usermanager".new_null_provider(hostname)
+ };
+end
+
+function startup.make_dummy_hosts()
+ -- When running under prosodyctl, we don't want to
+ -- fully initialize the server, so we populate prosody.hosts
+ -- with just enough things for most code to work correctly
+ -- luacheck: ignore 122/hosts
+ prosody.core_post_stanza = function () end; -- TODO: mod_router!
+
+ for hostname in pairs(config.getconfig()) do
+ hosts[hostname] = startup.make_host(hostname);
+ end
+end
+
+-- prosodyctl only
+function startup.prosodyctl()
+ startup.init_global_state();
+ startup.read_config();
+ startup.setup_plugindir();
+ startup.setup_datadir();
+ startup.chdir();
+ startup.read_version();
+ startup.switch_user();
+ startup.check_dependencies();
+ startup.force_console_logging();
+ startup.init_logging();
+ startup.log_dependency_warnings();
+ startup.check_unwriteable();
+ startup.load_libraries();
+ startup.init_global_protection();
+ startup.init_http_client();
+ startup.make_dummy_hosts();
+end
+
+function startup.prosody()
+ -- These actions are in a strict order, as many depend on
+ -- previous steps to have already been performed
+ startup.init_global_state();
+ startup.read_config();
+ startup.sanity_check();
+ startup.sandbox_require();
+ startup.set_function_metatable();
+ startup.check_dependencies();
+ startup.load_libraries();
+ startup.setup_plugindir();
+ startup.setup_datadir();
+ startup.init_logging();
+ startup.chdir();
+ startup.add_global_prosody_functions();
+ startup.read_version();
+ startup.log_greeting();
+ startup.log_dependency_warnings();
+ startup.load_secondary_libraries();
+ startup.init_http_client();
+ startup.init_data_store();
+ startup.init_global_protection();
+ startup.prepare_to_start();
+ startup.notify_started();
+end
+
+return startup;
diff --git a/util/template.lua b/util/template.lua
index 04ebb93d..c11037c5 100644
--- a/util/template.lua
+++ b/util/template.lua
@@ -4,12 +4,13 @@ local setmetatable = setmetatable;
local pairs = pairs;
local ipairs = ipairs;
local error = error;
-local loadstring = loadstring;
+local envload = require "util.envload".envload;
local debug = debug;
local t_remove = table.remove;
local parse_xml = require "util.xml".parse;
local _ENV = nil;
+-- luacheck: std none
local function trim_xml(stanza)
for i=#stanza,1,-1 do
@@ -72,7 +73,7 @@ local function create_cloner(stanza, chunkname)
src = src.."local _"..i.."="..lookup[i]..";";
end
src = src.."return "..name..";end";
- local f,err = loadstring(src, chunkname);
+ local f,err = envload(src, chunkname);
if not f then error(err); end
return f(setmetatable, stanza_mt);
end
diff --git a/util/termcolours.lua b/util/termcolours.lua
index 23c9156b..829d84af 100644
--- a/util/termcolours.lua
+++ b/util/termcolours.lua
@@ -26,6 +26,7 @@ end
local orig_color = windows and windows.get_consolecolor and windows.get_consolecolor();
local _ENV = nil;
+-- luacheck: std none
local stylemap = {
reset = 0; bright = 1, dim = 2, underscore = 4, blink = 5, reverse = 7, hidden = 8;
diff --git a/util/throttle.lua b/util/throttle.lua
index 1012f78a..d2036e9e 100644
--- a/util/throttle.lua
+++ b/util/throttle.lua
@@ -3,6 +3,7 @@ local gettime = require "util.time".now
local setmetatable = setmetatable;
local _ENV = nil;
+-- luacheck: std none
local throttle = {};
local throttle_mt = { __index = throttle };
diff --git a/util/timer.lua b/util/timer.lua
index 7e2e9414..70678ab2 100644
--- a/util/timer.lua
+++ b/util/timer.lua
@@ -6,78 +6,113 @@
-- COPYING file in the source package for more information.
--
+local indexedbheap = require "util.indexedbheap";
+local log = require "util.logger".init("timer");
local server = require "net.server";
-local math_min = math.min
-local math_huge = math.huge
local get_time = require "util.time".now
-local t_insert = table.insert;
-local pairs = pairs;
+local async = require "util.async";
local type = type;
-
-local data = {};
-local new_data = {};
+local debug_traceback = debug.traceback;
+local tostring = tostring;
+local xpcall = xpcall;
local _ENV = nil;
+-- luacheck: std none
-local _add_task;
-if not server.event then
- function _add_task(delay, callback)
- local current_time = get_time();
- delay = delay + current_time;
- if delay >= current_time then
- t_insert(new_data, {delay, callback});
- else
- local r = callback(current_time);
- if r and type(r) == "number" then
- return _add_task(r, callback);
- end
+local _add_task = server.add_task;
+
+local _server_timer;
+local _active_timers = 0;
+local h = indexedbheap.create();
+local params = {};
+local next_time = nil;
+local _id, _callback, _now, _param;
+local function _call() return _callback(_now, _id, _param); end
+local function _traceback_handler(err) log("error", "Traceback[timer]: %s", debug_traceback(tostring(err), 2)); end
+local function _on_timer(now)
+ local peek;
+ while true do
+ peek = h:peek();
+ if peek == nil or peek > now then break; end
+ local _;
+ _, _callback, _id = h:pop();
+ _now = now;
+ _param = params[_id];
+ params[_id] = nil;
+ --item(now, id, _param); -- FIXME pcall
+ local success, err = xpcall(_call, _traceback_handler);
+ if success and type(err) == "number" then
+ h:insert(_callback, err + now, _id); -- re-add
+ params[_id] = _param;
end
end
- server._addtimer(function()
- local current_time = get_time();
- if #new_data > 0 then
- for _, d in pairs(new_data) do
- t_insert(data, d);
- end
- new_data = {};
- end
+ if peek ~= nil and _active_timers > 1 and peek == next_time then
+ -- Another instance of _on_timer already set next_time to the same value,
+ -- so it should be safe to not renew this timer event
+ peek = nil;
+ else
+ next_time = peek;
+ end
- local next_time = math_huge;
- for i, d in pairs(data) do
- local t, callback = d[1], d[2];
- if t <= current_time then
- data[i] = nil;
- local r = callback(current_time);
- if type(r) == "number" then
- _add_task(r, callback);
- next_time = math_min(next_time, r);
- end
- else
- next_time = math_min(next_time, t - current_time);
- end
- end
- return next_time;
- end);
-else
- local event = server.event;
- local event_base = server.event_base;
- local EVENT_LEAVE = (event.core and event.core.LEAVE) or -1;
+ if peek then
+ -- peek is the time of the next event
+ return peek - now;
+ end
+ _active_timers = _active_timers - 1;
+end
+local function add_task(delay, callback, param)
+ local current_time = get_time();
+ local event_time = current_time + delay;
- function _add_task(delay, callback)
- local event_handle;
- event_handle = event_base:addevent(nil, 0, function ()
- local ret = callback(get_time());
- if ret then
- return 0, ret;
- elseif event_handle then
- return EVENT_LEAVE;
- end
+ local id = h:insert(callback, event_time);
+ params[id] = param;
+ if next_time == nil or event_time < next_time then
+ next_time = event_time;
+ if _server_timer then
+ _server_timer:close();
+ _server_timer = nil;
+ else
+ _active_timers = _active_timers + 1;
end
- , delay);
+ _server_timer = _add_task(next_time - current_time, _on_timer);
end
+ return id;
+end
+local function stop(id)
+ params[id] = nil;
+ local result, item, result_sync = h:remove(id);
+ local peek = h:peek();
+ if peek ~= next_time and _server_timer then
+ next_time = peek;
+ _server_timer:close();
+ if next_time ~= nil then
+ _server_timer = _add_task(next_time - get_time(), _on_timer);
+ end
+ end
+ return result, item, result_sync;
+end
+local function reschedule(id, delay)
+ local current_time = get_time();
+ local event_time = current_time + delay;
+ h:reprioritize(id, delay);
+ if next_time == nil or event_time < next_time then
+ next_time = event_time;
+ _add_task(next_time - current_time, _on_timer);
+ end
+ return id;
+end
+
+local function sleep(s)
+ local wait, done = async.waiter();
+ add_task(s, done);
+ wait();
end
return {
- add_task = _add_task;
+ add_task = add_task;
+ stop = stop;
+ reschedule = reschedule;
+ sleep = sleep;
};
+
diff --git a/util/vcard.lua b/util/vcard.lua
new file mode 100644
index 00000000..51758c41
--- /dev/null
+++ b/util/vcard.lua
@@ -0,0 +1,572 @@
+-- Copyright (C) 2011-2014 Kim Alvefur
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+-- TODO
+-- Fix folding.
+
+local st = require "util.stanza";
+local t_insert, t_concat = table.insert, table.concat;
+local type = type;
+local pairs, ipairs = pairs, ipairs;
+
+local from_text, to_text, from_xep54, to_xep54;
+
+local line_sep = "\n";
+
+local vCard_dtd; -- See end of file
+local vCard4_dtd;
+
+local function vCard_esc(s)
+ return s:gsub("[,:;\\]", "\\%1"):gsub("\n","\\n");
+end
+
+local function vCard_unesc(s)
+ return s:gsub("\\?[\\nt:;,]", {
+ ["\\\\"] = "\\",
+ ["\\n"] = "\n",
+ ["\\r"] = "\r",
+ ["\\t"] = "\t",
+ ["\\:"] = ":", -- FIXME Shouldn't need to espace : in values, just params
+ ["\\;"] = ";",
+ ["\\,"] = ",",
+ [":"] = "\29",
+ [";"] = "\30",
+ [","] = "\31",
+ });
+end
+
+local function item_to_xep54(item)
+ local t = st.stanza(item.name, { xmlns = "vcard-temp" });
+
+ local prop_def = vCard_dtd[item.name];
+ if prop_def == "text" then
+ t:text(item[1]);
+ elseif type(prop_def) == "table" then
+ if prop_def.types and item.TYPE then
+ if type(item.TYPE) == "table" then
+ for _,v in pairs(prop_def.types) do
+ for _,typ in pairs(item.TYPE) do
+ if typ:upper() == v then
+ t:tag(v):up();
+ break;
+ end
+ end
+ end
+ else
+ t:tag(item.TYPE:upper()):up();
+ end
+ end
+
+ if prop_def.props then
+ for _,v in pairs(prop_def.props) do
+ if item[v] then
+ t:tag(v):up();
+ end
+ end
+ end
+
+ if prop_def.value then
+ t:tag(prop_def.value):text(item[1]):up();
+ elseif prop_def.values then
+ local prop_def_values = prop_def.values;
+ local repeat_last = prop_def_values.behaviour == "repeat-last" and prop_def_values[#prop_def_values];
+ for i=1,#item do
+ t:tag(prop_def.values[i] or repeat_last):text(item[i]):up();
+ end
+ end
+ end
+
+ return t;
+end
+
+local function vcard_to_xep54(vCard)
+ local t = st.stanza("vCard", { xmlns = "vcard-temp" });
+ for i=1,#vCard do
+ t:add_child(item_to_xep54(vCard[i]));
+ end
+ return t;
+end
+
+function to_xep54(vCards)
+ if not vCards[1] or vCards[1].name then
+ return vcard_to_xep54(vCards)
+ else
+ local t = st.stanza("xCard", { xmlns = "vcard-temp" });
+ for i=1,#vCards do
+ t:add_child(vcard_to_xep54(vCards[i]));
+ end
+ return t;
+ end
+end
+
+function from_text(data)
+ data = data -- unfold and remove empty lines
+ :gsub("\r\n","\n")
+ :gsub("\n ", "")
+ :gsub("\n\n+","\n");
+ local vCards = {};
+ local current;
+ for line in data:gmatch("[^\n]+") do
+ line = vCard_unesc(line);
+ local name, params, value = line:match("^([-%a]+)(\30?[^\29]*)\29(.*)$");
+ value = value:gsub("\29",":");
+ if #params > 0 then
+ local _params = {};
+ for k,isval,v in params:gmatch("\30([^=]+)(=?)([^\30]*)") do
+ k = k:upper();
+ local _vt = {};
+ for _p in v:gmatch("[^\31]+") do
+ _vt[#_vt+1]=_p
+ _vt[_p]=true;
+ end
+ if isval == "=" then
+ _params[k]=_vt;
+ else
+ _params[k]=true;
+ end
+ end
+ params = _params;
+ end
+ if name == "BEGIN" and value == "VCARD" then
+ current = {};
+ vCards[#vCards+1] = current;
+ elseif name == "END" and value == "VCARD" then
+ current = nil;
+ elseif current and vCard_dtd[name] then
+ local dtd = vCard_dtd[name];
+ local item = { name = name };
+ t_insert(current, item);
+ local up = current;
+ current = item;
+ if dtd.types then
+ for _, t in ipairs(dtd.types) do
+ t = t:lower();
+ if ( params.TYPE and params.TYPE[t] == true)
+ or params[t] == true then
+ current.TYPE=t;
+ end
+ end
+ end
+ if dtd.props then
+ for _, p in ipairs(dtd.props) do
+ if params[p] then
+ if params[p] == true then
+ current[p]=true;
+ else
+ for _, prop in ipairs(params[p]) do
+ current[p]=prop;
+ end
+ end
+ end
+ end
+ end
+ if dtd == "text" or dtd.value then
+ t_insert(current, value);
+ elseif dtd.values then
+ for p in ("\30"..value):gmatch("\30([^\30]*)") do
+ t_insert(current, p);
+ end
+ end
+ current = up;
+ end
+ end
+ return vCards;
+end
+
+local function item_to_text(item)
+ local value = {};
+ for i=1,#item do
+ value[i] = vCard_esc(item[i]);
+ end
+ value = t_concat(value, ";");
+
+ local params = "";
+ for k,v in pairs(item) do
+ if type(k) == "string" and k ~= "name" then
+ params = params .. (";%s=%s"):format(k, type(v) == "table" and t_concat(v,",") or v);
+ end
+ end
+
+ return ("%s%s:%s"):format(item.name, params, value)
+end
+
+local function vcard_to_text(vcard)
+ local t={};
+ t_insert(t, "BEGIN:VCARD")
+ for i=1,#vcard do
+ t_insert(t, item_to_text(vcard[i]));
+ end
+ t_insert(t, "END:VCARD")
+ return t_concat(t, line_sep);
+end
+
+function to_text(vCards)
+ if vCards[1] and vCards[1].name then
+ return vcard_to_text(vCards)
+ else
+ local t = {};
+ for i=1,#vCards do
+ t[i]=vcard_to_text(vCards[i]);
+ end
+ return t_concat(t, line_sep);
+ end
+end
+
+local function from_xep54_item(item)
+ local prop_name = item.name;
+ local prop_def = vCard_dtd[prop_name];
+
+ local prop = { name = prop_name };
+
+ if prop_def == "text" then
+ prop[1] = item:get_text();
+ elseif type(prop_def) == "table" then
+ if prop_def.value then --single item
+ prop[1] = item:get_child_text(prop_def.value) or "";
+ elseif prop_def.values then --array
+ local value_names = prop_def.values;
+ if value_names.behaviour == "repeat-last" then
+ for i=1,#item.tags do
+ t_insert(prop, item.tags[i]:get_text() or "");
+ end
+ else
+ for i=1,#value_names do
+ t_insert(prop, item:get_child_text(value_names[i]) or "");
+ end
+ end
+ elseif prop_def.names then
+ local names = prop_def.names;
+ for i=1,#names do
+ if item:get_child(names[i]) then
+ prop[1] = names[i];
+ break;
+ end
+ end
+ end
+
+ if prop_def.props_verbatim then
+ for k,v in pairs(prop_def.props_verbatim) do
+ prop[k] = v;
+ end
+ end
+
+ if prop_def.types then
+ local types = prop_def.types;
+ prop.TYPE = {};
+ for i=1,#types do
+ if item:get_child(types[i]) then
+ t_insert(prop.TYPE, types[i]:lower());
+ end
+ end
+ if #prop.TYPE == 0 then
+ prop.TYPE = nil;
+ end
+ end
+
+ -- A key-value pair, within a key-value pair?
+ if prop_def.props then
+ local params = prop_def.props;
+ for i=1,#params do
+ local name = params[i]
+ local data = item:get_child_text(name);
+ if data then
+ prop[name] = prop[name] or {};
+ t_insert(prop[name], data);
+ end
+ end
+ end
+ else
+ return nil
+ end
+
+ return prop;
+end
+
+local function from_xep54_vCard(vCard)
+ local tags = vCard.tags;
+ local t = {};
+ for i=1,#tags do
+ t_insert(t, from_xep54_item(tags[i]));
+ end
+ return t
+end
+
+function from_xep54(vCard)
+ if vCard.attr.xmlns ~= "vcard-temp" then
+ return nil, "wrong-xmlns";
+ end
+ if vCard.name == "xCard" then -- A collection of vCards
+ local t = {};
+ local vCards = vCard.tags;
+ for i=1,#vCards do
+ t[i] = from_xep54_vCard(vCards[i]);
+ end
+ return t
+ elseif vCard.name == "vCard" then -- A single vCard
+ return from_xep54_vCard(vCard)
+ end
+end
+
+local vcard4 = { }
+
+function vcard4:text(node, params, value) -- luacheck: ignore 212/params
+ self:tag(node:lower())
+ -- FIXME params
+ if type(value) == "string" then
+ self:tag("text"):text(value):up()
+ elseif vcard4[node] then
+ vcard4[node](value);
+ end
+ self:up();
+end
+
+function vcard4.N(value)
+ for i, k in ipairs(vCard_dtd.N.values) do
+ value:tag(k):text(value[i]):up();
+ end
+end
+
+local xmlns_vcard4 = "urn:ietf:params:xml:ns:vcard-4.0"
+
+local function item_to_vcard4(item)
+ local typ = item.name:lower();
+ local t = st.stanza(typ, { xmlns = xmlns_vcard4 });
+
+ local prop_def = vCard4_dtd[typ];
+ if prop_def == "text" then
+ t:tag("text"):text(item[1]):up();
+ elseif prop_def == "uri" then
+ if item.ENCODING and item.ENCODING[1] == 'b' then
+ t:tag("uri"):text("data:;base64,"):text(item[1]):up();
+ else
+ t:tag("uri"):text(item[1]):up();
+ end
+ elseif type(prop_def) == "table" then
+ if prop_def.values then
+ for i, v in ipairs(prop_def.values) do
+ t:tag(v:lower()):text(item[i] or ""):up();
+ end
+ else
+ t:tag("unsupported",{xmlns="http://zash.se/protocol/vcardlib"})
+ end
+ else
+ t:tag("unsupported",{xmlns="http://zash.se/protocol/vcardlib"})
+ end
+ return t;
+end
+
+local function vcard_to_vcard4xml(vCard)
+ local t = st.stanza("vcard", { xmlns = xmlns_vcard4 });
+ for i=1,#vCard do
+ t:add_child(item_to_vcard4(vCard[i]));
+ end
+ return t;
+end
+
+local function vcards_to_vcard4xml(vCards)
+ if not vCards[1] or vCards[1].name then
+ return vcard_to_vcard4xml(vCards)
+ else
+ local t = st.stanza("vcards", { xmlns = xmlns_vcard4 });
+ for i=1,#vCards do
+ t:add_child(vcard_to_vcard4xml(vCards[i]));
+ end
+ return t;
+ end
+end
+
+-- This was adapted from http://xmpp.org/extensions/xep-0054.html#dtd
+vCard_dtd = {
+ VERSION = "text", --MUST be 3.0, so parsing is redundant
+ FN = "text",
+ N = {
+ values = {
+ "FAMILY",
+ "GIVEN",
+ "MIDDLE",
+ "PREFIX",
+ "SUFFIX",
+ },
+ },
+ NICKNAME = "text",
+ PHOTO = {
+ props_verbatim = { ENCODING = { "b" } },
+ props = { "TYPE" },
+ value = "BINVAL", --{ "EXTVAL", },
+ },
+ BDAY = "text",
+ ADR = {
+ types = {
+ "HOME",
+ "WORK",
+ "POSTAL",
+ "PARCEL",
+ "DOM",
+ "INTL",
+ "PREF",
+ },
+ values = {
+ "POBOX",
+ "EXTADD",
+ "STREET",
+ "LOCALITY",
+ "REGION",
+ "PCODE",
+ "CTRY",
+ }
+ },
+ LABEL = {
+ types = {
+ "HOME",
+ "WORK",
+ "POSTAL",
+ "PARCEL",
+ "DOM",
+ "INTL",
+ "PREF",
+ },
+ value = "LINE",
+ },
+ TEL = {
+ types = {
+ "HOME",
+ "WORK",
+ "VOICE",
+ "FAX",
+ "PAGER",
+ "MSG",
+ "CELL",
+ "VIDEO",
+ "BBS",
+ "MODEM",
+ "ISDN",
+ "PCS",
+ "PREF",
+ },
+ value = "NUMBER",
+ },
+ EMAIL = {
+ types = {
+ "HOME",
+ "WORK",
+ "INTERNET",
+ "PREF",
+ "X400",
+ },
+ value = "USERID",
+ },
+ JABBERID = "text",
+ MAILER = "text",
+ TZ = "text",
+ GEO = {
+ values = {
+ "LAT",
+ "LON",
+ },
+ },
+ TITLE = "text",
+ ROLE = "text",
+ LOGO = "copy of PHOTO",
+ AGENT = "text",
+ ORG = {
+ values = {
+ behaviour = "repeat-last",
+ "ORGNAME",
+ "ORGUNIT",
+ }
+ },
+ CATEGORIES = {
+ values = "KEYWORD",
+ },
+ NOTE = "text",
+ PRODID = "text",
+ REV = "text",
+ SORTSTRING = "text",
+ SOUND = "copy of PHOTO",
+ UID = "text",
+ URL = "text",
+ CLASS = {
+ names = { -- The item.name is the value if it's one of these.
+ "PUBLIC",
+ "PRIVATE",
+ "CONFIDENTIAL",
+ },
+ },
+ KEY = {
+ props = { "TYPE" },
+ value = "CRED",
+ },
+ DESC = "text",
+};
+vCard_dtd.LOGO = vCard_dtd.PHOTO;
+vCard_dtd.SOUND = vCard_dtd.PHOTO;
+
+vCard4_dtd = {
+ source = "uri",
+ kind = "text",
+ xml = "text",
+ fn = "text",
+ n = {
+ values = {
+ "family",
+ "given",
+ "middle",
+ "prefix",
+ "suffix",
+ },
+ },
+ nickname = "text",
+ photo = "uri",
+ bday = "date-and-or-time",
+ anniversary = "date-and-or-time",
+ gender = "text",
+ adr = {
+ values = {
+ "pobox",
+ "ext",
+ "street",
+ "locality",
+ "region",
+ "code",
+ "country",
+ }
+ },
+ tel = "text",
+ email = "text",
+ impp = "uri",
+ lang = "language-tag",
+ tz = "text",
+ geo = "uri",
+ title = "text",
+ role = "text",
+ logo = "uri",
+ org = "text",
+ member = "uri",
+ related = "uri",
+ categories = "text",
+ note = "text",
+ prodid = "text",
+ rev = "timestamp",
+ sound = "uri",
+ uid = "uri",
+ clientpidmap = "number, uuid",
+ url = "uri",
+ version = "text",
+ key = "uri",
+ fburl = "uri",
+ caladruri = "uri",
+ caluri = "uri",
+};
+
+return {
+ from_text = from_text;
+ to_text = to_text;
+
+ from_xep54 = from_xep54;
+ to_xep54 = to_xep54;
+
+ to_vcard4 = vcards_to_vcard4xml;
+};
diff --git a/util/watchdog.lua b/util/watchdog.lua
index aa8c6486..516e60e4 100644
--- a/util/watchdog.lua
+++ b/util/watchdog.lua
@@ -3,6 +3,7 @@ local setmetatable = setmetatable;
local os_time = os.time;
local _ENV = nil;
+-- luacheck: std none
local watchdog_methods = {};
local watchdog_mt = { __index = watchdog_methods };
diff --git a/util/x509.lua b/util/x509.lua
index f228b201..15cc4d3c 100644
--- a/util/x509.lua
+++ b/util/x509.lua
@@ -25,6 +25,7 @@ local log = require "util.logger".init("x509");
local s_format = string.format;
local _ENV = nil;
+-- luacheck: std none
local oid_commonname = "2.5.4.3"; -- [LDAP] 2.3
local oid_subjectaltname = "2.5.29.17"; -- [PKIX] 4.2.1.6
diff --git a/util/xml.lua b/util/xml.lua
index 733d821a..dac3f6fe 100644
--- a/util/xml.lua
+++ b/util/xml.lua
@@ -1,8 +1,11 @@
local st = require "util.stanza";
local lxp = require "lxp";
+local t_insert = table.insert;
+local t_remove = table.remove;
local _ENV = nil;
+-- luacheck: std none
local parse_xml = (function()
local ns_prefixes = {
@@ -14,6 +17,21 @@ local parse_xml = (function()
--luacheck: ignore 212/self
local handler = {};
local stanza = st.stanza("root");
+ local namespaces = {};
+ local prefixes = {};
+ function handler:StartNamespaceDecl(prefix, url)
+ if prefix ~= nil then
+ t_insert(namespaces, url);
+ t_insert(prefixes, prefix);
+ end
+ end
+ function handler:EndNamespaceDecl(prefix)
+ if prefix ~= nil then
+ -- we depend on each StartNamespaceDecl having a paired EndNamespaceDecl
+ t_remove(namespaces);
+ t_remove(prefixes);
+ end
+ end
function handler:StartElement(tagname, attr)
local curr_ns,name = tagname:match(ns_pattern);
if name == "" then
@@ -34,7 +52,11 @@ local parse_xml = (function()
end
end
end
- stanza:tag(name, attr);
+ local n = {}
+ for i=1,#namespaces do
+ n[prefixes[i]] = namespaces[i];
+ end
+ stanza:tag(name, attr, n);
end
function handler:CharacterData(data)
stanza:text(data);
diff --git a/util/xmppstream.lua b/util/xmppstream.lua
index 7be63285..8c7851a5 100644
--- a/util/xmppstream.lua
+++ b/util/xmppstream.lua
@@ -25,6 +25,7 @@ local lxp_supports_bytecount = not not lxp.new({}).getcurrentbytecount;
local default_stanza_size_limit = 1024*1024*10; -- 10MB
local _ENV = nil;
+-- luacheck: std none
local new_parser = lxp.new;
@@ -47,7 +48,10 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
local cb_streamopened = stream_callbacks.streamopened;
local cb_streamclosed = stream_callbacks.streamclosed;
- local cb_error = stream_callbacks.error or function(session, e, stanza) error("XML stream error: "..tostring(e)..(stanza and ": "..tostring(stanza) or ""),2); end;
+ local cb_error = stream_callbacks.error or
+ function(_, e, stanza)
+ error("XML stream error: "..tostring(e)..(stanza and ": "..tostring(stanza) or ""),2);
+ end;
local cb_handlestanza = stream_callbacks.handlestanza;
cb_handleprogress = cb_handleprogress or dummy_cb;
@@ -128,6 +132,9 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
end
if lxp_supports_xmldecl then
function xml_handlers:XmlDecl(version, encoding, standalone)
+ session.xml_version = version;
+ session.xml_encoding = encoding;
+ session.xml_standalone = standalone;
if lxp_supports_bytecount then
cb_handleprogress(self:getcurrentbytecount());
end
@@ -214,7 +221,7 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
stack = {};
end
- local function set_session(stream, new_session)
+ local function set_session(stream, new_session) -- luacheck: ignore 212/stream
session = new_session;
end
@@ -238,7 +245,7 @@ local function new(session, stream_callbacks, stanza_size_limit)
local parser = new_parser(handlers, ns_separator, false);
local parse = parser.parse;
- function session.open_stream(session, from, to)
+ function session.open_stream(session, from, to) -- luacheck: ignore 432/session
local send = session.sends2s or session.send;
local attr = {
@@ -264,7 +271,7 @@ local function new(session, stream_callbacks, stanza_size_limit)
n_outstanding_bytes = 0;
meta.reset();
end,
- feed = function (self, data)
+ feed = function (self, data) -- luacheck: ignore 212/self
if lxp_supports_bytecount then
n_outstanding_bytes = n_outstanding_bytes + #data;
end