aboutsummaryrefslogtreecommitdiffstats
path: root/util
diff options
context:
space:
mode:
Diffstat (limited to 'util')
-rw-r--r--util/adhoc.lua2
-rw-r--r--util/array.lua2
-rw-r--r--util/async.lua261
-rw-r--r--util/cache.lua26
-rw-r--r--util/caps.lua1
-rw-r--r--util/dataforms.lua40
-rw-r--r--util/datamanager.lua10
-rw-r--r--util/datetime.lua1
-rw-r--r--util/debug.lua21
-rw-r--r--util/dependencies.lua10
-rw-r--r--util/envload.lua2
-rw-r--r--util/events.lua14
-rw-r--r--util/filters.lua1
-rw-r--r--util/format.lua20
-rw-r--r--util/import.lua4
-rw-r--r--util/indexedbheap.lua157
-rw-r--r--util/ip.lua250
-rw-r--r--util/iterators.lua50
-rw-r--r--util/jid.lua1
-rw-r--r--util/json.lua10
-rw-r--r--util/logger.lua14
-rw-r--r--util/multitable.lua5
-rw-r--r--util/openssl.lua2
-rw-r--r--util/pluginloader.lua1
-rw-r--r--util/presence.lua1
-rw-r--r--util/prosodyctl.lua17
-rw-r--r--util/pubsub.lua252
-rw-r--r--util/random.lua4
-rw-r--r--util/sasl.lua9
-rw-r--r--util/sasl/anonymous.lua7
-rw-r--r--util/sasl/digest-md5.lua1
-rw-r--r--util/sasl/external.lua1
-rw-r--r--util/sasl/plain.lua1
-rw-r--r--util/sasl/scram.lua24
-rw-r--r--util/sasl_cyrus.lua1
-rw-r--r--util/serialization.lua295
-rw-r--r--util/set.lua3
-rw-r--r--util/sql.lua50
-rw-r--r--util/sslconfig.lua1
-rw-r--r--util/stanza.lua85
-rw-r--r--util/startup.lua551
-rw-r--r--util/template.lua5
-rw-r--r--util/termcolours.lua1
-rw-r--r--util/throttle.lua1
-rw-r--r--util/timer.lua144
-rw-r--r--util/vcard.lua574
-rw-r--r--util/watchdog.lua1
-rw-r--r--util/x509.lua1
-rw-r--r--util/xml.lua24
-rw-r--r--util/xmppstream.lua32
50 files changed, 2546 insertions, 445 deletions
diff --git a/util/adhoc.lua b/util/adhoc.lua
index 17c9eee5..d81b8242 100644
--- a/util/adhoc.lua
+++ b/util/adhoc.lua
@@ -1,3 +1,5 @@
+-- luacheck: ignore 212/self
+
local function new_simple_form(form, result_handler)
return function(self, data, state)
if state then
diff --git a/util/array.lua b/util/array.lua
index 150b4355..1a8ffec7 100644
--- a/util/array.lua
+++ b/util/array.lua
@@ -19,7 +19,7 @@ local type = type;
local array = {};
local array_base = {};
local array_methods = {};
-local array_mt = { __index = array_methods, __tostring = function (self) return "{"..self:concat(", ").."}"; end };
+local array_mt = { __index = array_methods, __name = "array", __tostring = function (self) return "{"..self:concat(", ").."}"; end };
local function new_array(self, t, _s, _var)
if type(t) == "function" then -- Assume iterator
diff --git a/util/async.lua b/util/async.lua
new file mode 100644
index 00000000..d1d78fa5
--- /dev/null
+++ b/util/async.lua
@@ -0,0 +1,261 @@
+local logger = require "util.logger";
+local log = logger.init("util.async");
+local timer = require "util.timer";
+local new_id = require "util.id".short;
+
+local function checkthread()
+ local thread, main = coroutine.running();
+ if not thread or main then
+ error("Not running in an async context, see https://prosody.im/doc/developers/util/async");
+ end
+ return thread;
+end
+
+local function runner_from_thread(thread)
+ local level = 0;
+ -- Find the 'level' of the top-most function (0 == current level, 1 == caller, ...)
+ while debug.getinfo(thread, level, "") do level = level + 1; end
+ local name, runner = debug.getlocal(thread, level-1, 1);
+ if name ~= "self" or type(runner) ~= "table" or runner.thread ~= thread then
+ return nil;
+ end
+ return runner;
+end
+
+local function call_watcher(runner, watcher_name, ...)
+ local watcher = runner.watchers[watcher_name];
+ if not watcher then
+ return false;
+ end
+ runner:log("debug", "Calling '%s' watcher", watcher_name);
+ local ok, err = pcall(watcher, runner, ...); -- COMPAT: Switch to xpcall after Lua 5.1
+ if not ok then
+ runner:log("error", "Error in '%s' watcher: %s", watcher_name, err);
+ return nil, err;
+ end
+ return true;
+end
+
+local function runner_continue(thread)
+ -- ASSUMPTION: runner is in 'waiting' state (but we don't have the runner to know for sure)
+ if coroutine.status(thread) ~= "suspended" then -- This should suffice
+ log("error", "unexpected async state: thread not suspended");
+ return false;
+ end
+ local ok, state, runner = coroutine.resume(thread);
+ if not ok then
+ local err = state;
+ -- Running the coroutine failed, which means we have to find the runner manually,
+ -- in order to inform the error handler
+ runner = runner_from_thread(thread);
+ if not runner then
+ log("error", "unexpected async state: unable to locate runner during error handling");
+ return false;
+ end
+ call_watcher(runner, "error", debug.traceback(thread, err));
+ runner.state, runner.thread = "ready", nil;
+ return runner:run();
+ elseif state == "ready" then
+ -- If state is 'ready', it is our responsibility to update runner.state from 'waiting'.
+ -- We also have to :run(), because the queue might have further items that will not be
+ -- processed otherwise. FIXME: It's probably best to do this in a nexttick (0 timer).
+ runner.state = "ready";
+ runner:run();
+ end
+ return true;
+end
+
+local function waiter(num)
+ local thread = checkthread();
+ num = num or 1;
+ local waiting;
+ return function ()
+ if num == 0 then return; end -- already done
+ waiting = true;
+ coroutine.yield("wait");
+ end, function ()
+ num = num - 1;
+ if num == 0 and waiting then
+ runner_continue(thread);
+ elseif num < 0 then
+ error("done() called too many times");
+ end
+ end;
+end
+
+local function guarder()
+ local guards = {};
+ local default_id = {};
+ return function (id, func)
+ id = id or default_id;
+ local thread = checkthread();
+ local guard = guards[id];
+ if not guard then
+ guard = {};
+ guards[id] = guard;
+ log("debug", "New guard!");
+ else
+ table.insert(guard, thread);
+ log("debug", "Guarded. %d threads waiting.", #guard)
+ coroutine.yield("wait");
+ end
+ local function exit()
+ local next_waiting = table.remove(guard, 1);
+ if next_waiting then
+ log("debug", "guard: Executing next waiting thread (%d left)", #guard)
+ runner_continue(next_waiting);
+ else
+ log("debug", "Guard off duty.")
+ guards[id] = nil;
+ end
+ end
+ if func then
+ func();
+ exit();
+ return;
+ end
+ return exit;
+ end;
+end
+
+local runner_mt = {};
+runner_mt.__index = runner_mt;
+
+local function runner_create_thread(func, self)
+ local thread = coroutine.create(function (self) -- luacheck: ignore 432/self
+ while true do
+ func(coroutine.yield("ready", self));
+ end
+ end);
+ debug.sethook(thread, debug.gethook());
+ assert(coroutine.resume(thread, self)); -- Start it up, it will return instantly to wait for the first input
+ return thread;
+end
+
+local function default_error_watcher(runner, err)
+ runner:log("error", "Encountered error: %s", err);
+ error(err);
+end
+local function default_func(f) f(); end
+local function runner(func, watchers, data)
+ local id = new_id();
+ local _log = logger.init("runner" .. id);
+ return setmetatable({ func = func or default_func, thread = false, state = "ready", notified_state = "ready",
+ queue = {}, watchers = watchers or { error = default_error_watcher }, data = data, id = id, _log = _log; }
+ , runner_mt);
+end
+
+-- Add a task item for the runner to process
+function runner_mt:run(input)
+ if input ~= nil then
+ table.insert(self.queue, input);
+ --self:log("debug", "queued new work item, %d items queued", #self.queue);
+ end
+ if self.state ~= "ready" then
+ -- The runner is busy. Indicate that the task item has been
+ -- queued, and return information about the current runner state
+ return true, self.state, #self.queue;
+ end
+
+ local q, thread = self.queue, self.thread;
+ if not thread or coroutine.status(thread) == "dead" then
+ self:log("debug", "creating new coroutine");
+ -- Create a new coroutine for this runner
+ thread = runner_create_thread(self.func, self);
+ self.thread = thread;
+ end
+
+ -- Process task item(s) while the queue is not empty, and we're not blocked
+ local n, state, err = #q, self.state, nil;
+ self.state = "running";
+ --self:log("debug", "running main loop");
+ while n > 0 and state == "ready" and not err do
+ local consumed;
+ -- Loop through queue items, and attempt to run them
+ for i = 1,n do
+ local queued_input = q[i];
+ local ok, new_state = coroutine.resume(thread, queued_input);
+ if not ok then
+ -- There was an error running the coroutine, save the error, mark runner as ready to begin again
+ consumed, state, err = i, "ready", debug.traceback(thread, new_state);
+ self.thread = nil;
+ break;
+ elseif new_state == "wait" then
+ -- Runner is blocked on waiting for a task item to complete
+ consumed, state = i, "waiting";
+ break;
+ end
+ end
+ -- Loop ended - either queue empty because all tasks passed without blocking (consumed == nil)
+ -- or runner is blocked/errored, and consumed will contain the number of tasks processed so far
+ if not consumed then consumed = n; end
+ -- Remove consumed items from the queue array
+ if q[n+1] ~= nil then
+ n = #q;
+ end
+ for i = 1, n do
+ q[i] = q[consumed+i];
+ end
+ n = #q;
+ end
+ -- Runner processed all items it can, so save current runner state
+ self.state = state;
+ if err or state ~= self.notified_state then
+ self:log("debug", "changed state from %s to %s", self.notified_state, err and ("error ("..state..")") or state);
+ if err then
+ state = "error"
+ else
+ self.notified_state = state;
+ end
+ local handler = self.watchers[state];
+ if handler then handler(self, err); end
+ end
+ if n > 0 then
+ return self:run();
+ end
+ return true, state, n;
+end
+
+-- Add a task item to the queue without invoking the runner, even if it is idle
+function runner_mt:enqueue(input)
+ table.insert(self.queue, input);
+ self:log("debug", "queued new work item, %d items queued", #self.queue);
+ return self;
+end
+
+function runner_mt:log(level, fmt, ...)
+ return self._log(level, fmt, ...);
+end
+
+function runner_mt:onready(f)
+ self.watchers.ready = f;
+ return self;
+end
+
+function runner_mt:onwaiting(f)
+ self.watchers.waiting = f;
+ return self;
+end
+
+function runner_mt:onerror(f)
+ self.watchers.error = f;
+ return self;
+end
+
+local function ready()
+ return pcall(checkthread);
+end
+
+local function sleep(s)
+ local wait, done = waiter();
+ timer.add_task(s, done);
+ wait();
+end
+
+return {
+ ready = ready;
+ waiter = waiter;
+ guarder = guarder;
+ runner = runner;
+ sleep = sleep;
+};
diff --git a/util/cache.lua b/util/cache.lua
index 9c141bb6..a5fd5e6d 100644
--- a/util/cache.lua
+++ b/util/cache.lua
@@ -116,6 +116,25 @@ function cache_methods:tail()
return tail.key, tail.value;
end
+function cache_methods:resize(new_size)
+ new_size = assert(tonumber(new_size), "cache size must be a number");
+ new_size = math.floor(new_size);
+ assert(new_size > 0, "cache size must be greater than zero");
+ local on_evict = self._on_evict;
+ while self._count > new_size do
+ local tail = self._tail;
+ local evicted_key, evicted_value = tail.key, tail.value;
+ if on_evict ~= nil and (on_evict == false or on_evict(evicted_key, evicted_value) == false) then
+ -- Cache is full, and we're not allowed to evict
+ return false;
+ end
+ _remove(self, tail);
+ self._data[evicted_key] = nil;
+ end
+ self.size = new_size;
+ return true;
+end
+
function cache_methods:table()
--luacheck: ignore 212/t
if not self.proxy_table then
@@ -139,6 +158,13 @@ function cache_methods:table()
return self.proxy_table;
end
+function cache_methods:clear()
+ self._data = {};
+ self._count = 0;
+ self._head = nil;
+ self._tail = nil;
+end
+
local function new(size, on_evict)
size = assert(tonumber(size), "cache size must be a number");
size = math.floor(size);
diff --git a/util/caps.lua b/util/caps.lua
index cd5ff9c0..de492edb 100644
--- a/util/caps.lua
+++ b/util/caps.lua
@@ -13,6 +13,7 @@ local t_insert, t_sort, t_concat = table.insert, table.sort, table.concat;
local ipairs = ipairs;
local _ENV = nil;
+-- luacheck: std none
local function calculate_hash(disco_info)
local identities, features, extensions = {}, {}, {};
diff --git a/util/dataforms.lua b/util/dataforms.lua
index 469ce976..6152c9f3 100644
--- a/util/dataforms.lua
+++ b/util/dataforms.lua
@@ -8,12 +8,13 @@
local setmetatable = setmetatable;
local ipairs = ipairs;
-local tostring, type, next = tostring, type, next;
+local type, next = type, next;
local t_concat = table.concat;
local st = require "util.stanza";
local jid_prep = require "util.jid".prep;
local _ENV = nil;
+-- luacheck: std none
local xmlns_forms = 'jabber:x:data';
@@ -37,9 +38,18 @@ function form_t.form(layout, data, formtype)
-- Add field tag
form:tag("field", { type = field_type, var = field.name, label = field.label });
- local value = (data and data[field.name]) or field.value;
+ if field.desc then
+ form:text_tag("desc", field.desc);
+ end
+
+ local value;
+ if data and data[field.name] ~= nil then
+ value = data[field.name];
+ else
+ value = field.value;
+ end
- if value then
+ if value ~= nil then
-- Add value, depending on type
if field_type == "hidden" then
if type(value) == "table" then
@@ -48,7 +58,7 @@ function form_t.form(layout, data, formtype)
:add_child(value)
:up();
else
- form:tag("value"):text(tostring(value)):up();
+ form:tag("value"):text(value):up();
end
elseif field_type == "boolean" then
form:tag("value"):text((value and "1") or "0"):up();
@@ -78,7 +88,7 @@ function form_t.form(layout, data, formtype)
has_default = true;
end
else
- form:tag("option", { label= val }):tag("value"):text(tostring(val)):up():up();
+ form:tag("option", { label= val }):tag("value"):text(val):up():up();
end
end
end
@@ -94,7 +104,7 @@ function form_t.form(layout, data, formtype)
form:tag("value"):text(val.value):up();
end
else
- form:tag("option", { label= val }):tag("value"):text(tostring(val)):up():up();
+ form:tag("option", { label= val }):tag("value"):text(val):up():up();
end
end
end
@@ -145,7 +155,7 @@ function form_t.data(layout, stanza)
if field.required then
errors[field.name] = "Required value missing";
end
- else
+ elseif field.name then
present[field.name] = true;
local reader = field_readers[field.type];
if reader then
@@ -248,8 +258,24 @@ field_readers["hidden"] =
return field_tag:get_child_text("value");
end
+
+local function get_form_type(form)
+ if not st.is_stanza(form) then
+ return nil, "not a stanza object";
+ elseif form.attr.xmlns ~= "jabber:x:data" or form.name ~= "x" then
+ return nil, "not a dataform element";
+ end
+ for field in form:childtags("field") do
+ if field.attr.var == "FORM_TYPE" then
+ return field:get_child_text("value");
+ end
+ end
+ return "";
+end
+
return {
new = new;
+ get_type = get_form_type;
};
diff --git a/util/datamanager.lua b/util/datamanager.lua
index bd8fb7bb..cf96887b 100644
--- a/util/datamanager.lua
+++ b/util/datamanager.lua
@@ -40,9 +40,10 @@ pcall(function()
end);
local _ENV = nil;
+-- luacheck: std none
---- utils -----
-local encode, decode;
+local encode, decode, store_encode;
do
local urlcodes = setmetatable({}, { __index = function (t, k) t[k] = char(tonumber(k, 16)); return t[k]; end });
@@ -53,6 +54,12 @@ do
encode = function (s)
return s and (s:gsub("%W", function (c) return format("%%%02x", c:byte()); end));
end
+
+ -- Special encode function for store names, which historically were unencoded.
+ -- All currently known stores use a-z and underscore, so this one preserves underscores.
+ store_encode = function (s)
+ return s and (s:gsub("[^_%w]", function (c) return format("%%%02x", c:byte()); end));
+ end
end
if not atomic_append then
@@ -119,6 +126,7 @@ local function getpath(username, host, datastore, ext, create)
ext = ext or "dat";
host = (host and encode(host)) or "_global";
username = username and encode(username);
+ datastore = store_encode(datastore);
if username then
if create then mkdir(mkdir(mkdir(data_path).."/"..host).."/"..datastore); end
return format("%s/%s/%s/%s.%s", data_path, host, datastore, username, ext);
diff --git a/util/datetime.lua b/util/datetime.lua
index abb4e867..06be9fc2 100644
--- a/util/datetime.lua
+++ b/util/datetime.lua
@@ -15,6 +15,7 @@ local os_difftime = os.difftime;
local tonumber = tonumber;
local _ENV = nil;
+-- luacheck: std none
local function date(t)
return os_date("!%Y-%m-%d", t);
diff --git a/util/debug.lua b/util/debug.lua
index 00f476d0..9a28395a 100644
--- a/util/debug.lua
+++ b/util/debug.lua
@@ -47,6 +47,7 @@ local function get_upvalues_table(func)
for upvalue_num = 1, math.huge do
local name, value = debug.getupvalue(func, upvalue_num);
if not name then break; end
+ if name == "" then name = ("[%d]"):format(upvalue_num); end
table.insert(upvalues, { name = name, value = value });
end
end
@@ -112,7 +113,9 @@ end
local function build_source_boundary_marker(last_source_desc)
local padding = string.rep("-", math.floor(((optimal_line_length - 6) - #last_source_desc)/2));
- return getstring(styles.boundary_padding, "v"..padding).." "..getstring(styles.filename, last_source_desc).." "..getstring(styles.boundary_padding, padding..(#last_source_desc%2==0 and "-v" or "v "));
+ return getstring(styles.boundary_padding, "v"..padding).." "..
+ getstring(styles.filename, last_source_desc).." "..
+ getstring(styles.boundary_padding, padding..(#last_source_desc%2==0 and "-v" or "v "));
end
local function _traceback(thread, message, level)
@@ -142,9 +145,9 @@ local function _traceback(thread, message, level)
local last_source_desc;
local lines = {};
- for nlevel, level in ipairs(levels) do
- local info = level.info;
- local line = "...";
+ for nlevel, current_level in ipairs(levels) do
+ local info = current_level.info;
+ local line;
local func_type = info.namewhat.." ";
local source_desc = (info.short_src == "[C]" and "C code") or info.short_src or "Unknown";
if func_type == " " then func_type = ""; end;
@@ -160,7 +163,9 @@ local function _traceback(thread, message, level)
if func_type == "global " or func_type == "local " then
func_type = func_type.."function ";
end
- line = "[Lua] "..getstring(styles.location, info.short_src.." line "..info.currentline).." in "..func_type..getstring(styles.funcname, name).." (defined on line "..info.linedefined..")";
+ line = "[Lua] "..getstring(styles.location, info.short_src.." line "..
+ info.currentline).." in "..func_type..getstring(styles.funcname, name)..
+ " (defined on line "..info.linedefined..")";
end
if source_desc ~= last_source_desc then -- Venturing into a new source, add marker for previous
last_source_desc = source_desc;
@@ -169,13 +174,13 @@ local function _traceback(thread, message, level)
nlevel = nlevel-1;
table.insert(lines, "\t"..(nlevel==0 and ">" or " ")..getstring(styles.level_num, "("..nlevel..") ")..line);
local npadding = (" "):rep(#tostring(nlevel));
- if level.locals then
- local locals_str = string_from_var_table(level.locals, optimal_line_length, "\t "..npadding);
+ if current_level.locals then
+ local locals_str = string_from_var_table(current_level.locals, optimal_line_length, "\t "..npadding);
if locals_str then
table.insert(lines, "\t "..npadding.."Locals: "..locals_str);
end
end
- local upvalues_str = string_from_var_table(level.upvalues, optimal_line_length, "\t "..npadding);
+ local upvalues_str = string_from_var_table(current_level.upvalues, optimal_line_length, "\t "..npadding);
if upvalues_str then
table.insert(lines, "\t "..npadding.."Upvals: "..upvalues_str);
end
diff --git a/util/dependencies.lua b/util/dependencies.lua
index de840241..9b0afd77 100644
--- a/util/dependencies.lua
+++ b/util/dependencies.lua
@@ -28,7 +28,7 @@ local function missingdep(name, sources, msg)
end
print("");
print(msg or (name.." is required for Prosody to run, so we will now exit."));
- print("More help can be found on our website, at http://prosody.im/doc/depends");
+ print("More help can be found on our website, at https://prosody.im/doc/depends");
print("**************************");
print("");
end
@@ -40,7 +40,7 @@ end
package.preload["util.ztact"] = function ()
if not package.loaded["core.loggingmanager"] then
error("util.ztact has been removed from Prosody and you need to fix your config "
- .."file. More information can be found at http://prosody.im/doc/packagers#ztact", 0);
+ .."file. More information can be found at https://prosody.im/doc/packagers#ztact", 0);
else
error("module 'util.ztact' has been deprecated in Prosody 0.8.");
end
@@ -156,7 +156,7 @@ local function log_warnings()
if ssl then
local major, minor, veryminor, patched = ssl._VERSION:match("(%d+)%.(%d+)%.?(%d*)(M?)");
if not major or ((tonumber(major) == 0 and (tonumber(minor) or 0) <= 3 and (tonumber(veryminor) or 0) <= 2) and patched ~= "M") then
- prosody.log("error", "This version of LuaSec contains a known bug that causes disconnects, see http://prosody.im/doc/depends");
+ prosody.log("error", "This version of LuaSec contains a known bug that causes disconnects, see https://prosody.im/doc/depends");
end
end
local lxp = softreq"lxp";
@@ -165,7 +165,7 @@ local function log_warnings()
prosody.log("error", "The version of LuaExpat on your system leaves Prosody "
.."vulnerable to denial-of-service attacks. You should upgrade to "
.."LuaExpat 1.3.0 or higher as soon as possible. See "
- .."http://prosody.im/doc/depends#luaexpat for more information.");
+ .."https://prosody.im/doc/depends#luaexpat for more information.");
end
if not lxp.new({}).getcurrentbytecount then
prosody.log("error", "The version of LuaExpat on your system does not support "
@@ -173,7 +173,7 @@ local function log_warnings()
.."networks (e.g. the internet) vulnerable to denial-of-service "
.."attacks. You should upgrade to LuaExpat 1.3.0 or higher as "
.."soon as possible. See "
- .."http://prosody.im/doc/depends#luaexpat for more information.");
+ .."https://prosody.im/doc/depends#luaexpat for more information.");
end
end
end
diff --git a/util/envload.lua b/util/envload.lua
index 926f20c0..6182a1f9 100644
--- a/util/envload.lua
+++ b/util/envload.lua
@@ -4,7 +4,7 @@
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
--- luacheck: ignore 113/setfenv
+-- luacheck: ignore 113/setfenv 113/loadstring
local load, loadstring, setfenv = load, loadstring, setfenv;
local io_open = io.open;
diff --git a/util/events.lua b/util/events.lua
index 6e13619c..0bf0ddcb 100644
--- a/util/events.lua
+++ b/util/events.lua
@@ -15,6 +15,7 @@ local setmetatable = setmetatable;
local next = next;
local _ENV = nil;
+-- luacheck: std none
local function new()
-- Map event name to ordered list of handlers (lazily built): handlers[event_name] = array_of_handler_functions
@@ -26,7 +27,7 @@ local function new()
-- Event map: event_map[handler_function] = priority_number
local event_map = {};
-- Called on-demand to build handlers entries
- local function _rebuild_index(handlers, event)
+ local function _rebuild_index(self, event)
local _handlers = event_map[event];
if not _handlers or next(_handlers) == nil then return; end
local index = {};
@@ -34,7 +35,7 @@ local function new()
t_insert(index, handler);
end
t_sort(index, function(a, b) return _handlers[a] > _handlers[b]; end);
- handlers[event] = index;
+ self[event] = index;
return index;
end;
setmetatable(handlers, { __index = _rebuild_index });
@@ -61,13 +62,13 @@ local function new()
local function get_handlers(event)
return handlers[event];
end;
- local function add_handlers(handlers)
- for event, handler in pairs(handlers) do
+ local function add_handlers(self)
+ for event, handler in pairs(self) do
add_handler(event, handler);
end
end;
- local function remove_handlers(handlers)
- for event, handler in pairs(handlers) do
+ local function remove_handlers(self)
+ for event, handler in pairs(self) do
remove_handler(event, handler);
end
end;
@@ -81,6 +82,7 @@ local function new()
end
end;
local function fire_event(event_name, event_data)
+ -- luacheck: ignore 432/event_name 432/event_data
local w = wrappers[event_name] or global_wrappers;
if w then
local curr_wrapper = #w;
diff --git a/util/filters.lua b/util/filters.lua
index f405c0bd..f30dfd9c 100644
--- a/util/filters.lua
+++ b/util/filters.lua
@@ -9,6 +9,7 @@
local t_insert, t_remove = table.insert, table.remove;
local _ENV = nil;
+-- luacheck: std none
local new_filter_hooks = {};
diff --git a/util/format.lua b/util/format.lua
index 5f2b12be..c5e513fa 100644
--- a/util/format.lua
+++ b/util/format.lua
@@ -4,11 +4,10 @@
local tostring = tostring;
local select = select;
-local assert = assert;
-local unpack = unpack;
+local unpack = table.unpack or unpack; -- luacheck: ignore 113/unpack
local type = type;
-local function format(format, ...)
+local function format(formatstring, ...)
local args, args_length = { ... }, select('#', ...);
-- format specifier spec:
@@ -25,7 +24,7 @@ local function format(format, ...)
-- process each format specifier
local i = 0;
- format = format:gsub("%%[^cdiouxXaAeEfgGqs%%]*[cdiouxXaAeEfgGqs%%]", function(spec)
+ formatstring = formatstring:gsub("%%[^cdiouxXaAeEfgGqs%%]*[cdiouxXaAeEfgGqs%%]", function(spec)
if spec ~= "%%" then
i = i + 1;
local arg = args[i];
@@ -54,21 +53,12 @@ local function format(format, ...)
else
args[i] = tostring(arg);
end
- format = format .. " [%s]"
+ formatstring = formatstring .. " [%s]"
end
- return format:format(unpack(args));
-end
-
-local function test()
- assert(format("%s", "hello") == "hello");
- assert(format("%s") == "<nil>");
- assert(format("%s", true) == "true");
- assert(format("%d", true) == "[true]");
- assert(format("%%", true) == "% [true]");
+ return formatstring:format(unpack(args));
end
return {
format = format;
- test = test;
};
diff --git a/util/import.lua b/util/import.lua
index c2b9dce1..8ecfe43c 100644
--- a/util/import.lua
+++ b/util/import.lua
@@ -8,9 +8,9 @@
-local unpack = table.unpack or unpack; --luacheck: ignore 113
+local unpack = table.unpack or unpack; --luacheck: ignore 113 143
local t_insert = table.insert;
-function import(module, ...)
+function _G.import(module, ...)
local m = package.loaded[module] or require(module);
if type(m) == "table" and ... then
local ret = {};
diff --git a/util/indexedbheap.lua b/util/indexedbheap.lua
new file mode 100644
index 00000000..7f193d54
--- /dev/null
+++ b/util/indexedbheap.lua
@@ -0,0 +1,157 @@
+
+local setmetatable = setmetatable;
+local math_floor = math.floor;
+local t_remove = table.remove;
+
+local function _heap_insert(self, item, sync, item2, index)
+ local pos = #self + 1;
+ while true do
+ local half_pos = math_floor(pos / 2);
+ if half_pos == 0 or item > self[half_pos] then break; end
+ self[pos] = self[half_pos];
+ sync[pos] = sync[half_pos];
+ index[sync[pos]] = pos;
+ pos = half_pos;
+ end
+ self[pos] = item;
+ sync[pos] = item2;
+ index[item2] = pos;
+end
+
+local function _percolate_up(self, k, sync, index)
+ local tmp = self[k];
+ local tmp_sync = sync[k];
+ while k ~= 1 do
+ local parent = math_floor(k/2);
+ if tmp < self[parent] then break; end
+ self[k] = self[parent];
+ sync[k] = sync[parent];
+ index[sync[k]] = k;
+ k = parent;
+ end
+ self[k] = tmp;
+ sync[k] = tmp_sync;
+ index[tmp_sync] = k;
+ return k;
+end
+
+local function _percolate_down(self, k, sync, index)
+ local tmp = self[k];
+ local tmp_sync = sync[k];
+ local size = #self;
+ local child = 2*k;
+ while 2*k <= size do
+ if child ~= size and self[child] > self[child + 1] then
+ child = child + 1;
+ end
+ if tmp > self[child] then
+ self[k] = self[child];
+ sync[k] = sync[child];
+ index[sync[k]] = k;
+ else
+ break;
+ end
+
+ k = child;
+ child = 2*k;
+ end
+ self[k] = tmp;
+ sync[k] = tmp_sync;
+ index[tmp_sync] = k;
+ return k;
+end
+
+local function _heap_pop(self, sync, index)
+ local size = #self;
+ if size == 0 then return nil; end
+
+ local result = self[1];
+ local result_sync = sync[1];
+ index[result_sync] = nil;
+ if size == 1 then
+ self[1] = nil;
+ sync[1] = nil;
+ return result, result_sync;
+ end
+ self[1] = t_remove(self);
+ sync[1] = t_remove(sync);
+ index[sync[1]] = 1;
+
+ _percolate_down(self, 1, sync, index);
+
+ return result, result_sync;
+end
+
+local indexed_heap = {};
+
+function indexed_heap:insert(item, priority, id)
+ if id == nil then
+ id = self.current_id;
+ self.current_id = id + 1;
+ end
+ self.items[id] = item;
+ _heap_insert(self.priorities, priority, self.ids, id, self.index);
+ return id;
+end
+function indexed_heap:pop()
+ local priority, id = _heap_pop(self.priorities, self.ids, self.index);
+ if id then
+ local item = self.items[id];
+ self.items[id] = nil;
+ return priority, item, id;
+ end
+end
+function indexed_heap:peek()
+ return self.priorities[1];
+end
+function indexed_heap:reprioritize(id, priority)
+ local k = self.index[id];
+ if k == nil then return; end
+ self.priorities[k] = priority;
+
+ k = _percolate_up(self.priorities, k, self.ids, self.index);
+ _percolate_down(self.priorities, k, self.ids, self.index);
+end
+function indexed_heap:remove_index(k)
+ local result = self.priorities[k];
+ if result == nil then return; end
+
+ local result_sync = self.ids[k];
+ local item = self.items[result_sync];
+ local size = #self.priorities;
+
+ self.priorities[k] = self.priorities[size];
+ self.ids[k] = self.ids[size];
+ self.index[self.ids[k]] = k;
+
+ t_remove(self.priorities);
+ t_remove(self.ids);
+
+ self.index[result_sync] = nil;
+ self.items[result_sync] = nil;
+
+ if size > k then
+ k = _percolate_up(self.priorities, k, self.ids, self.index);
+ _percolate_down(self.priorities, k, self.ids, self.index);
+ end
+
+ return result, item, result_sync;
+end
+function indexed_heap:remove(id)
+ return self:remove_index(self.index[id]);
+end
+
+local mt = { __index = indexed_heap };
+
+local _M = {
+ create = function()
+ return setmetatable({
+ ids = {}; -- heap of ids, sync'd with priorities
+ items = {}; -- map id->items
+ priorities = {}; -- heap of priorities
+ index = {}; -- map of id->index of id in ids
+ current_id = 1.5
+ }, mt);
+ end
+};
+return _M;
diff --git a/util/ip.lua b/util/ip.lua
index 81a98ef7..0ec9e297 100644
--- a/util/ip.lua
+++ b/util/ip.lua
@@ -5,69 +5,76 @@
-- COPYING file in the source package for more information.
--
+local net = require "util.net";
+local hex = require "util.hex";
+
local ip_methods = {};
-local ip_mt = { __index = function (ip, key) return (ip_methods[key])(ip); end,
- __tostring = function (ip) return ip.addr; end,
- __eq = function (ipA, ipB) return ipA.addr == ipB.addr; end};
-local hex2bits = { ["0"] = "0000", ["1"] = "0001", ["2"] = "0010", ["3"] = "0011", ["4"] = "0100", ["5"] = "0101", ["6"] = "0110", ["7"] = "0111", ["8"] = "1000", ["9"] = "1001", ["A"] = "1010", ["B"] = "1011", ["C"] = "1100", ["D"] = "1101", ["E"] = "1110", ["F"] = "1111" };
+
+local ip_mt = {
+ __index = function (ip, key)
+ local method = ip_methods[key];
+ if not method then return nil; end
+ local ret = method(ip);
+ ip[key] = ret;
+ return ret;
+ end,
+ __tostring = function (ip) return ip.addr; end,
+ __eq = function (ipA, ipB) return ipA.packed == ipB.packed; end
+};
+
+local hex2bits = {
+ ["0"] = "0000", ["1"] = "0001", ["2"] = "0010", ["3"] = "0011",
+ ["4"] = "0100", ["5"] = "0101", ["6"] = "0110", ["7"] = "0111",
+ ["8"] = "1000", ["9"] = "1001", ["A"] = "1010", ["B"] = "1011",
+ ["C"] = "1100", ["D"] = "1101", ["E"] = "1110", ["F"] = "1111",
+};
local function new_ip(ipStr, proto)
- if not proto then
- local sep = ipStr:match("^%x+(.)");
- if sep == ":" or (not(sep) and ipStr:sub(1,1) == ":") then
- proto = "IPv6"
- elseif sep == "." then
- proto = "IPv4"
- end
- if not proto then
- return nil, "invalid address";
- end
- elseif proto ~= "IPv4" and proto ~= "IPv6" then
- return nil, "invalid protocol";
- end
local zone;
- if proto == "IPv6" and ipStr:find('%', 1, true) then
+ if (not proto or proto == "IPv6") and ipStr:find('%', 1, true) then
ipStr, zone = ipStr:match("^(.-)%%(.*)");
end
- if proto == "IPv6" and ipStr:find('.', 1, true) then
- local changed;
- ipStr, changed = ipStr:gsub(":(%d+)%.(%d+)%.(%d+)%.(%d+)$", function(a,b,c,d)
- return (":%04X:%04X"):format(a*256+b,c*256+d);
- end);
- if changed ~= 1 then return nil, "invalid-address"; end
+
+ local packed, err = net.pton(ipStr);
+ if not packed then return packed, err end
+ if proto == "IPv6" and #packed ~= 16 then
+ return nil, "invalid-ipv6";
+ elseif proto == "IPv4" and #packed ~= 4 then
+ return nil, "invalid-ipv4";
+ elseif not proto then
+ if #packed == 16 then
+ proto = "IPv6";
+ elseif #packed == 4 then
+ proto = "IPv4";
+ else
+ return nil, "unknown protocol";
+ end
+ elseif proto ~= "IPv6" and proto ~= "IPv4" then
+ return nil, "invalid protocol";
end
- return setmetatable({ addr = ipStr, proto = proto, zone = zone }, ip_mt);
+ return setmetatable({ addr = ipStr, packed = packed, proto = proto, zone = zone }, ip_mt);
+end
+
+function ip_methods:normal()
+ return net.ntop(self.packed);
end
-local function toBits(ip)
- local result = "";
- local fields = {};
+function ip_methods.bits(ip)
+ return hex.to(ip.packed):upper():gsub(".", hex2bits);
+end
+
+function ip_methods.bits_full(ip)
if ip.proto == "IPv4" then
ip = ip.toV4mapped;
end
- ip = (ip.addr):upper();
- ip:gsub("([^:]*):?", function (c) fields[#fields + 1] = c end);
- if not ip:match(":$") then fields[#fields] = nil; end
- for i, field in ipairs(fields) do
- if field:len() == 0 and i ~= 1 and i ~= #fields then
- for _ = 1, 16 * (9 - #fields) do
- result = result .. "0";
- end
- else
- for _ = 1, 4 - field:len() do
- result = result .. "0000";
- end
- for j = 1, field:len() do
- result = result .. hex2bits[field:sub(j, j)];
- end
- end
- end
- return result;
+ return ip.bits;
end
+local match;
+
local function commonPrefixLength(ipA, ipB)
- ipA, ipB = toBits(ipA), toBits(ipB);
+ ipA, ipB = ipA.bits_full, ipB.bits_full;
for i = 1, 128 do
if ipA:sub(i,i) ~= ipB:sub(i,i) then
return i-1;
@@ -76,56 +83,60 @@ local function commonPrefixLength(ipA, ipB)
return 128;
end
+-- Instantiate once
+local loopback = new_ip("::1");
+local loopback4 = new_ip("127.0.0.0");
+local sixtofour = new_ip("2002::");
+local teredo = new_ip("2001::");
+local linklocal = new_ip("fe80::");
+local linklocal4 = new_ip("169.254.0.0");
+local uniquelocal = new_ip("fc00::");
+local sitelocal = new_ip("fec0::");
+local sixbone = new_ip("3ffe::");
+local defaultunicast = new_ip("::");
+local multicast = new_ip("ff00::");
+local ipv6mapped = new_ip("::ffff:0:0");
+
local function v4scope(ip)
- local fields = {};
- ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end);
- -- Loopback:
- if fields[1] == 127 then
+ if match(ip, loopback4, 8) then
return 0x2;
- -- Link-local unicast:
- elseif fields[1] == 169 and fields[2] == 254 then
+ elseif match(ip, linklocal4) then
return 0x2;
- -- Global unicast:
- else
+ else -- Global unicast
return 0xE;
end
end
local function v6scope(ip)
- -- Loopback:
- if ip:match("^[0:]*1$") then
+ if ip == loopback then
return 0x2;
- -- Link-local unicast:
- elseif ip:match("^[Ff][Ee][89ABab]") then
+ elseif match(ip, linklocal, 10) then
return 0x2;
- -- Site-local unicast:
- elseif ip:match("^[Ff][Ee][CcDdEeFf]") then
+ elseif match(ip, sitelocal, 10) then
return 0x5;
- -- Multicast:
- elseif ip:match("^[Ff][Ff]") then
- return tonumber("0x"..ip:sub(4,4));
- -- Global unicast:
- else
+ elseif match(ip, multicast, 10) then
+ return ip.packed:byte(2) % 0x10;
+ else -- Global unicast
return 0xE;
end
end
local function label(ip)
- if commonPrefixLength(ip, new_ip("::1", "IPv6")) == 128 then
+ if ip == loopback then
return 0;
- elseif commonPrefixLength(ip, new_ip("2002::", "IPv6")) >= 16 then
+ elseif match(ip, sixtofour, 16) then
return 2;
- elseif commonPrefixLength(ip, new_ip("2001::", "IPv6")) >= 32 then
+ elseif match(ip, teredo, 32) then
return 5;
- elseif commonPrefixLength(ip, new_ip("fc00::", "IPv6")) >= 7 then
+ elseif match(ip, uniquelocal, 7) then
return 13;
- elseif commonPrefixLength(ip, new_ip("fec0::", "IPv6")) >= 10 then
+ elseif match(ip, sitelocal, 10) then
return 11;
- elseif commonPrefixLength(ip, new_ip("3ffe::", "IPv6")) >= 16 then
+ elseif match(ip, sixbone, 16) then
return 12;
- elseif commonPrefixLength(ip, new_ip("::", "IPv6")) >= 96 then
+ elseif match(ip, defaultunicast, 96) then
return 3;
- elseif commonPrefixLength(ip, new_ip("::ffff:0:0", "IPv6")) >= 96 then
+ elseif match(ip, ipv6mapped, 96) then
return 4;
else
return 1;
@@ -133,91 +144,67 @@ local function label(ip)
end
local function precedence(ip)
- if commonPrefixLength(ip, new_ip("::1", "IPv6")) == 128 then
+ if ip == loopback then
return 50;
- elseif commonPrefixLength(ip, new_ip("2002::", "IPv6")) >= 16 then
+ elseif match(ip, sixtofour, 16) then
return 30;
- elseif commonPrefixLength(ip, new_ip("2001::", "IPv6")) >= 32 then
+ elseif match(ip, teredo, 32) then
return 5;
- elseif commonPrefixLength(ip, new_ip("fc00::", "IPv6")) >= 7 then
+ elseif match(ip, uniquelocal, 7) then
return 3;
- elseif commonPrefixLength(ip, new_ip("fec0::", "IPv6")) >= 10 then
+ elseif match(ip, sitelocal, 10) then
return 1;
- elseif commonPrefixLength(ip, new_ip("3ffe::", "IPv6")) >= 16 then
+ elseif match(ip, sixbone, 16) then
return 1;
- elseif commonPrefixLength(ip, new_ip("::", "IPv6")) >= 96 then
+ elseif match(ip, defaultunicast, 96) then
return 1;
- elseif commonPrefixLength(ip, new_ip("::ffff:0:0", "IPv6")) >= 96 then
+ elseif match(ip, ipv6mapped, 96) then
return 35;
else
return 40;
end
end
-local function toV4mapped(ip)
- local fields = {};
- local ret = "::ffff:";
- ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end);
- ret = ret .. ("%02x"):format(fields[1]);
- ret = ret .. ("%02x"):format(fields[2]);
- ret = ret .. ":"
- ret = ret .. ("%02x"):format(fields[3]);
- ret = ret .. ("%02x"):format(fields[4]);
- return new_ip(ret, "IPv6");
-end
-
function ip_methods:toV4mapped()
if self.proto ~= "IPv4" then return nil, "No IPv4 address" end
- local value = toV4mapped(self.addr);
- self.toV4mapped = value;
+ local value = new_ip("::ffff:" .. self.normal);
return value;
end
function ip_methods:label()
- local value;
if self.proto == "IPv4" then
- value = label(self.toV4mapped);
+ return label(self.toV4mapped);
else
- value = label(self);
+ return label(self);
end
- self.label = value;
- return value;
end
function ip_methods:precedence()
- local value;
if self.proto == "IPv4" then
- value = precedence(self.toV4mapped);
+ return precedence(self.toV4mapped);
else
- value = precedence(self);
+ return precedence(self);
end
- self.precedence = value;
- return value;
end
function ip_methods:scope()
- local value;
if self.proto == "IPv4" then
- value = v4scope(self.addr);
+ return v4scope(self);
else
- value = v6scope(self.addr);
+ return v6scope(self);
end
- self.scope = value;
- return value;
end
+local rfc1918_8 = new_ip("10.0.0.0");
+local rfc1918_12 = new_ip("172.16.0.0");
+local rfc1918_16 = new_ip("192.168.0.0");
+local rfc6598 = new_ip("100.64.0.0");
+
function ip_methods:private()
local private = self.scope ~= 0xE;
if not private and self.proto == "IPv4" then
- local ip = self.addr;
- local fields = {};
- ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end);
- if fields[1] == 127 or fields[1] == 10 or (fields[1] == 192 and fields[2] == 168)
- or (fields[1] == 172 and (fields[2] >= 16 or fields[2] <= 32)) then
- private = true;
- end
+ return match(self, rfc1918_8, 8) or match(self, rfc1918_12, 12) or match(self, rfc1918_16) or match(self, rfc6598, 10);
end
- self.private = private;
return private;
end
@@ -231,15 +218,26 @@ local function parse_cidr(cidr)
return new_ip(cidr), bits;
end
-local function match(ipA, ipB, bits)
- local common_bits = commonPrefixLength(ipA, ipB);
- if bits and ipB.proto == "IPv4" then
- common_bits = common_bits - 96; -- v6 mapped addresses always share these bits
+function match(ipA, ipB, bits)
+ if not bits or bits >= 128 or ipB.proto == "IPv4" and bits >= 32 then
+ return ipA == ipB;
+ elseif bits < 1 then
+ return true;
+ end
+ if ipA.proto ~= ipB.proto then
+ if ipA.proto == "IPv4" then
+ ipA = ipA.toV4mapped;
+ elseif ipB.proto == "IPv4" then
+ ipB = ipB.toV4mapped;
+ bits = bits + (128 - 32);
+ end
end
- return common_bits >= (bits or 128);
+ return ipA.bits:sub(1, bits) == ipB.bits:sub(1, bits);
end
-return {new_ip = new_ip,
+return {
+ new_ip = new_ip,
commonPrefixLength = commonPrefixLength,
parse_cidr = parse_cidr,
- match=match};
+ match = match,
+};
diff --git a/util/iterators.lua b/util/iterators.lua
index bd150ff2..5d16d8c1 100644
--- a/util/iterators.lua
+++ b/util/iterators.lua
@@ -12,8 +12,13 @@ local it = {};
local t_insert = table.insert;
local select, next = select, next;
-local unpack = table.unpack or unpack; --luacheck: ignore 113
-local pack = table.pack or function (...) return { n = select("#", ...), ... }; end
+local unpack = table.unpack or unpack; --luacheck: ignore 113 143
+local pack = table.pack or function (...) return { n = select("#", ...), ... }; end -- luacheck: ignore 143
+local type = type;
+local table, setmetatable = table, setmetatable;
+
+local _ENV = nil;
+--luacheck: std none
-- Reverse an iterator
function it.reverse(f, s, var)
@@ -184,4 +189,45 @@ function it.to_table(f, s, var)
return t;
end
+local function _join_iter(j_s, j_var)
+ local iterators, current_idx = j_s[1], j_s[2];
+ local f, s, var = unpack(iterators[current_idx], 1, 3);
+ if j_var ~= nil then
+ var = j_var;
+ end
+ local ret = pack(f(s, var));
+ local var1 = ret[1];
+ if var1 == nil then
+ -- End of this iterator, advance to next
+ if current_idx == #iterators then
+ -- No more iterators, return nil
+ return;
+ end
+ j_s[2] = current_idx + 1;
+ return _join_iter(j_s);
+ end
+ return unpack(ret, 1, ret.n);
+end
+local join_methods = {};
+local join_mt = {
+ __index = join_methods;
+ __call = function (t, s, var) --luacheck: ignore 212/t
+ return _join_iter(s, var);
+ end;
+};
+
+function join_methods:append(f, s, var)
+ table.insert(self, { f, s, var });
+ return self, { self, 1 };
+end
+
+function join_methods:prepend(f, s, var)
+ table.insert(self, { f, s, var }, 1);
+ return self, { self, 1 };
+end
+
+function it.join(f, s, var)
+ return setmetatable({ {f, s, var} }, join_mt);
+end
+
return it;
diff --git a/util/jid.lua b/util/jid.lua
index f402b7f4..37c48193 100644
--- a/util/jid.lua
+++ b/util/jid.lua
@@ -25,6 +25,7 @@ local unescapes = {};
for k,v in pairs(escapes) do unescapes[v] = k; end
local _ENV = nil;
+-- luacheck: std none
local function split(jid)
if not jid then return; end
diff --git a/util/json.lua b/util/json.lua
index cba54e8e..05af709a 100644
--- a/util/json.lua
+++ b/util/json.lua
@@ -27,9 +27,6 @@ module.null = null;
local escapes = {
["\""] = "\\\"", ["\\"] = "\\\\", ["\b"] = "\\b",
["\f"] = "\\f", ["\n"] = "\\n", ["\r"] = "\\r", ["\t"] = "\\t"};
-local unescapes = {
- ["\""] = "\"", ["\\"] = "\\", ["/"] = "/",
- b = "\b", f = "\f", n = "\n", r = "\r", t = "\t"};
for i=0,31 do
local ch = s_char(i);
if not escapes[ch] then escapes[ch] = ("\\u%.4X"):format(i); end
@@ -263,8 +260,9 @@ end
local function _unescape_func(x)
x = x:match("%x%x%x%x", 3);
if x then
- --if x >= 0xD800 and x <= 0xDFFF then _unescape_error = true; end -- bad surrogate pair
- return codepoint_to_utf8(tonumber(x, 16));
+ local codepoint = tonumber(x, 16)
+ if codepoint >= 0xD800 and codepoint <= 0xDFFF then _unescape_error = true; end -- bad surrogate pair
+ return codepoint_to_utf8(codepoint);
end
_unescape_error = true;
end
@@ -276,7 +274,7 @@ function _readstring(json, index)
--if s:find("[%z-\31]") then return nil, "control char in string"; end
-- FIXME handle control characters
_unescape_error = nil;
- --s = s:gsub("\\u[dD][89abAB]%x%x\\u[dD][cdefCDEF]%x%x", _unescape_surrogate_func);
+ s = s:gsub("\\u[dD][89abAB]%x%x\\u[dD][cdefCDEF]%x%x", _unescape_surrogate_func);
-- FIXME handle escapes beyond BMP
s = s:gsub("\\u.?.?.?.?", _unescape_func);
if _unescape_error then return nil, "invalid escape"; end
diff --git a/util/logger.lua b/util/logger.lua
index e72b29bc..20a5cef2 100644
--- a/util/logger.lua
+++ b/util/logger.lua
@@ -8,8 +8,11 @@
-- luacheck: ignore 213/level
local pairs = pairs;
+local ipairs = ipairs;
+local require = require;
local _ENV = nil;
+-- luacheck: std none
local level_sinks = {};
@@ -67,10 +70,21 @@ local function add_level_sink(level, sink_function)
end
end
+local function add_simple_sink(simple_sink_function, levels)
+ local format = require "util.format".format;
+ local function sink_function(name, level, msg, ...)
+ return simple_sink_function(name, level, format(msg, ...));
+ end
+ for _, level in ipairs(levels or {"debug", "info", "warn", "error"}) do
+ add_level_sink(level, sink_function);
+ end
+end
+
return {
init = init;
make_logger = make_logger;
reset = reset;
add_level_sink = add_level_sink;
+ add_simple_sink = add_simple_sink;
new = make_logger;
};
diff --git a/util/multitable.lua b/util/multitable.lua
index e4321d3d..8d32ed8a 100644
--- a/util/multitable.lua
+++ b/util/multitable.lua
@@ -9,9 +9,10 @@
local select = select;
local t_insert = table.insert;
local pairs, next, type = pairs, next, type;
-local unpack = table.unpack or unpack; --luacheck: ignore 113
+local unpack = table.unpack or unpack; --luacheck: ignore 113 143
local _ENV = nil;
+-- luacheck: std none
local function get(self, ...)
local t = self.data;
@@ -132,7 +133,7 @@ local function iter(self, ...)
local maxdepth = select("#", ...);
local stack = { self.data };
local keys = { };
- local function it(self)
+ local function it(self) -- luacheck: ignore 432/self
local depth = #stack;
local key = next(stack[depth], keys[depth]);
if key == nil then -- Go up the stack
diff --git a/util/openssl.lua b/util/openssl.lua
index 703c6d15..32b5aea7 100644
--- a/util/openssl.lua
+++ b/util/openssl.lua
@@ -114,7 +114,7 @@ function ssl_config:add_xmppAddr(host)
s_format("%s;%s", oid_xmppaddr, utf8string(host)));
end
-function ssl_config:from_prosody(hosts, config, certhosts)
+function ssl_config:from_prosody(hosts, config, certhosts) -- luacheck: ignore 431/config
-- TODO Decide if this should go elsewhere
local found_matching_hosts = false;
for i = 1, #certhosts do
diff --git a/util/pluginloader.lua b/util/pluginloader.lua
index 004855f0..9ab8f245 100644
--- a/util/pluginloader.lua
+++ b/util/pluginloader.lua
@@ -5,6 +5,7 @@
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
+-- luacheck: ignore 113/CFG_PLUGINDIR
local dir_sep, path_sep = package.config:match("^(%S+)%s(%S+)");
local plugin_dir = {};
diff --git a/util/presence.lua b/util/presence.lua
index f6370354..8d1ae2d9 100644
--- a/util/presence.lua
+++ b/util/presence.lua
@@ -13,7 +13,6 @@ local function select_top_resources(user)
local recipients = {};
for _, session in pairs(user.sessions) do -- find resource with greatest priority
if session.presence then
- -- TODO check active privacy list for session
local p = session.priority;
if p > priority then
priority = p;
diff --git a/util/prosodyctl.lua b/util/prosodyctl.lua
index 8ae051ae..5f0c4d12 100644
--- a/util/prosodyctl.lua
+++ b/util/prosodyctl.lua
@@ -24,8 +24,6 @@ local io, os = io, os;
local print = print;
local tonumber = tonumber;
-local CFG_SOURCEDIR = _G.CFG_SOURCEDIR;
-
local _G = _G;
local prosody = prosody;
@@ -66,7 +64,10 @@ local function getline()
end
local function getpass()
- local stty_ret = os.execute("stty -echo 2>/dev/null");
+ local stty_ret, _, status_code = os.execute("stty -echo 2>/dev/null");
+ if status_code then -- COMPAT w/ Lua 5.1
+ stty_ret = status_code;
+ end
if stty_ret ~= 0 then
io.write("\027[08m"); -- ANSI 'hidden' text attribute
end
@@ -189,8 +190,8 @@ local function getpid()
pidfile = config.resolve_relative_path(prosody.paths.data, pidfile);
- local modules_enabled = set.new(config.get("*", "modules_disabled"));
- if prosody.platform ~= "posix" or modules_enabled:contains("posix") then
+ local modules_disabled = set.new(config.get("*", "modules_disabled"));
+ if prosody.platform ~= "posix" or modules_disabled:contains("posix") then
return false, "no-posix";
end
@@ -228,7 +229,7 @@ local function isrunning()
return true, signal.kill(pid, 0) == 0;
end
-local function start()
+local function start(source_dir)
local ok, ret = isrunning();
if not ok then
return ok, ret;
@@ -236,10 +237,10 @@ local function start()
if ret then
return false, "already-running";
end
- if not CFG_SOURCEDIR then
+ if not source_dir then
os.execute("./prosody");
else
- os.execute(CFG_SOURCEDIR.."/../../bin/prosody");
+ os.execute(source_dir.."/../../bin/prosody");
end
return true;
end
diff --git a/util/pubsub.lua b/util/pubsub.lua
index 1db917d8..b7f89844 100644
--- a/util/pubsub.lua
+++ b/util/pubsub.lua
@@ -1,32 +1,75 @@
local events = require "util.events";
local cache = require "util.cache";
-local service = {};
-local service_mt = { __index = service };
+local service_mt = {};
-local default_config = { __index = {
- itemstore = function (config) return cache.new(tonumber(config["pubsub#max_items"])) end;
+local default_config = {
+ itemstore = function (config, _) return cache.new(config["max_items"]) end;
broadcaster = function () end;
+ itemcheck = function () return true; end;
get_affiliation = function () end;
+ normalize_jid = function (jid) return jid; end;
capabilities = {};
-} };
-local default_node_config = { __index = {
- ["pubsub#max_items"] = "20";
-} };
+};
+local default_config_mt = { __index = default_config };
+
+local default_node_config = {
+ ["persist_items"] = false;
+ ["max_items"] = 20;
+};
+local default_node_config_mt = { __index = default_node_config };
+
+-- Storage helper functions
+
+local function load_node_from_store(service, node_name)
+ local node = service.config.nodestore:get(node_name);
+ node.config = setmetatable(node.config or {}, {__index=service.node_defaults});
+ return node;
+end
+
+local function save_node_to_store(service, node)
+ return service.config.nodestore:set(node.name, {
+ name = node.name;
+ config = node.config;
+ subscribers = node.subscribers;
+ affiliations = node.affiliations;
+ });
+end
+
+local function delete_node_in_store(service, node_name)
+ return service.config.nodestore:set(node_name, nil);
+end
+-- Create and return a new service object
local function new(config)
config = config or {};
- return setmetatable({
- config = setmetatable(config, default_config);
- node_defaults = setmetatable(config.node_defaults or {}, default_node_config);
+
+ local service = setmetatable({
+ config = setmetatable(config, default_config_mt);
+ node_defaults = setmetatable(config.node_defaults or {}, default_node_config_mt);
affiliations = {};
subscriptions = {};
nodes = {};
data = {};
events = events.new();
}, service_mt);
+
+ -- Load nodes from storage, if we have a store and it supports iterating over stored items
+ if config.nodestore and config.nodestore.users then
+ for node_name in config.nodestore:users() do
+ service.nodes[node_name] = load_node_from_store(service, node_name);
+ service.data[node_name] = config.itemstore(service.nodes[node_name].config, node_name);
+ end
+ end
+
+ return service;
end
+--- Service methods
+
+local service = {};
+service_mt.__index = service;
+
function service:jids_equal(jid1, jid2)
local normalize = self.config.normalize_jid;
return normalize(jid1) == normalize(jid2);
@@ -36,7 +79,8 @@ function service:may(node, actor, action)
if actor == true then return true; end
local node_obj = self.nodes[node];
- local node_aff = node_obj and node_obj.affiliations[actor];
+ local node_aff = node_obj and (node_obj.affiliations[actor]
+ or node_obj.affiliations[self.config.normalize_jid(actor)]);
local service_aff = self.affiliations[actor]
or self.config.get_affiliation(actor, node, action)
or "none";
@@ -76,7 +120,18 @@ function service:set_affiliation(node, actor, jid, affiliation)
if not node_obj then
return false, "item-not-found";
end
+ jid = self.config.normalize_jid(jid);
+ local old_affiliation = node_obj.affiliations[jid];
node_obj.affiliations[jid] = affiliation;
+
+ if self.config.nodestore then
+ local ok, err = save_node_to_store(self, node_obj);
+ if not ok then
+ node_obj.affiliations[jid] = old_affiliation;
+ return ok, "internal-server-error";
+ end
+ end
+
local _, jid_sub = self:get_subscription(node, true, jid);
if not jid_sub and not self:may(node, jid, "be_unsubscribed") then
local ok, err = self:add_subscription(node, true, jid);
@@ -119,6 +174,7 @@ function service:add_subscription(node, actor, jid, options)
node_obj = self.nodes[node];
end
end
+ local old_subscription = node_obj.subscribers[jid];
node_obj.subscribers[jid] = options or true;
local normal_jid = self.config.normalize_jid(jid);
local subs = self.subscriptions[normal_jid];
@@ -131,6 +187,16 @@ function service:add_subscription(node, actor, jid, options)
else
self.subscriptions[normal_jid] = { [jid] = { [node] = true } };
end
+
+ if self.config.nodestore then
+ local ok, err = save_node_to_store(self, node_obj);
+ if not ok then
+ node_obj.subscribers[jid] = old_subscription;
+ self.subscriptions[normal_jid][jid][node] = old_subscription and true or nil;
+ return ok, "internal-server-error";
+ end
+ end
+
self.events.fire_event("subscription-added", { node = node, jid = jid, normalized_jid = normal_jid, options = options });
return true;
end
@@ -157,6 +223,7 @@ function service:remove_subscription(node, actor, jid)
if not node_obj.subscribers[jid] then
return false, "not-subscribed";
end
+ local old_subscription = node_obj.subscribers[jid];
node_obj.subscribers[jid] = nil;
local normal_jid = self.config.normalize_jid(jid);
local subs = self.subscriptions[normal_jid];
@@ -172,19 +239,17 @@ function service:remove_subscription(node, actor, jid)
self.subscriptions[normal_jid] = nil;
end
end
- self.events.fire_event("subscription-removed", { node = node, jid = jid, normalized_jid = normal_jid });
- return true;
-end
-function service:remove_all_subscriptions(actor, jid)
- local normal_jid = self.config.normalize_jid(jid);
- local subs = self.subscriptions[normal_jid]
- subs = subs and subs[jid];
- if subs then
- for node in pairs(subs) do
- self:remove_subscription(node, true, jid);
+ if self.config.nodestore then
+ local ok, err = save_node_to_store(self, node_obj);
+ if not ok then
+ node_obj.subscribers[jid] = old_subscription;
+ self.subscriptions[normal_jid][jid][node] = old_subscription and true or nil;
+ return ok, "internal-server-error";
end
end
+
+ self.events.fire_event("subscription-removed", { node = node, jid = jid, normalized_jid = normal_jid });
return true;
end
@@ -223,14 +288,27 @@ function service:create(node, actor, options)
config = setmetatable(options or {}, {__index=self.node_defaults});
affiliations = {};
};
- self.data[node] = self.config.itemstore(self.nodes[node].config);
+
+ if self.config.nodestore then
+ local ok, err = save_node_to_store(self, self.nodes[node]);
+ if not ok then
+ self.nodes[node] = nil;
+ return ok, "internal-server-error";
+ end
+ end
+
+ self.data[node] = self.config.itemstore(self.nodes[node].config, node);
self.events.fire_event("node-created", { node = node, actor = actor });
- local ok, err = self:set_affiliation(node, true, actor, "owner");
- if not ok then
- self.nodes[node] = nil;
- self.data[node] = nil;
+ if actor ~= true then
+ local ok, err = self:set_affiliation(node, true, actor, "owner");
+ if not ok then
+ self.nodes[node] = nil;
+ self.data[node] = nil;
+ return ok, err;
+ end
end
- return ok, err;
+
+ return true;
end
function service:delete(node, actor)
@@ -244,9 +322,21 @@ function service:delete(node, actor)
return false, "item-not-found";
end
self.nodes[node] = nil;
+ if self.data[node] and self.data[node].clear then
+ self.data[node]:clear();
+ end
self.data[node] = nil;
+
+ if self.config.nodestore then
+ local ok, err = delete_node_in_store(self, node);
+ if not ok then
+ self.nodes[node] = nil;
+ return ok, err;
+ end
+ end
+
self.events.fire_event("node-deleted", { node = node, actor = actor });
- self.config.broadcaster("delete", node, node_obj.subscribers);
+ self.config.broadcaster("delete", node, node_obj.subscribers, nil, actor, node_obj, self);
return true;
end
@@ -267,13 +357,17 @@ function service:publish(node, actor, id, item)
end
node_obj = self.nodes[node];
end
+ if not self.config.itemcheck(item) then
+ return nil, "internal-server-error";
+ end
local node_data = self.data[node];
local ok = node_data:set(id, item);
if not ok then
return nil, "internal-server-error";
end
+ if type(ok) == "string" then id = ok; end
self.events.fire_event("item-published", { node = node, actor = actor, id = id, item = item });
- self.config.broadcaster("items", node, node_obj.subscribers, item, actor);
+ self.config.broadcaster("items", node, node_obj.subscribers, item, actor, node_obj, self);
return true;
end
@@ -293,7 +387,7 @@ function service:retract(node, actor, id, retract)
end
self.events.fire_event("item-retracted", { node = node, actor = actor, id = id });
if retract then
- self.config.broadcaster("items", node, node_obj.subscribers, retract);
+ self.config.broadcaster("items", node, node_obj.subscribers, retract, actor, node_obj, self);
end
return true
end
@@ -308,10 +402,14 @@ function service:purge(node, actor, notify)
if not node_obj then
return false, "item-not-found";
end
- self.data[node] = self.config.itemstore(self.nodes[node].config);
+ if self.data[node] and self.data[node].clear then
+ self.data[node]:clear()
+ else
+ self.data[node] = self.config.itemstore(self.nodes[node].config, node);
+ end
self.events.fire_event("node-purged", { node = node, actor = actor });
if notify then
- self.config.broadcaster("purge", node, node_obj.subscribers);
+ self.config.broadcaster("purge", node, node_obj.subscribers, nil, actor, node_obj, self);
end
return true
end
@@ -327,7 +425,11 @@ function service:get_items(node, actor, id)
return false, "item-not-found";
end
if id then -- Restrict results to a single specific item
- return true, { id, [id] = self.data[node]:get(id) };
+ local with_id = self.data[node]:get(id);
+ if not with_id then
+ return true, { };
+ end
+ return true, { id, [id] = with_id };
else
local data = {}
for key, value in self.data[node]:items() do
@@ -338,6 +440,15 @@ function service:get_items(node, actor, id)
end
end
+function service:get_last_item(node, actor)
+ -- Access checking
+ if not self:may(node, actor, "get_items") then
+ return false, "forbidden";
+ end
+ --
+ return true, self.data[node]:tail();
+end
+
function service:get_nodes(actor)
-- Access checking
if not self:may(nil, actor, "get_nodes") then
@@ -347,6 +458,29 @@ function service:get_nodes(actor)
return true, self.nodes;
end
+local function flatten_subscriptions(ret, serv, subs, node, node_obj)
+ for subscribed_jid, subscribed_nodes in pairs(subs) do
+ if node then -- Return only subscriptions to this node
+ if subscribed_nodes[node] then
+ ret[#ret+1] = {
+ node = node;
+ jid = subscribed_jid;
+ subscription = node_obj.subscribers[subscribed_jid];
+ };
+ end
+ else -- Return subscriptions to all nodes
+ local nodes = serv.nodes;
+ for subscribed_node in pairs(subscribed_nodes) do
+ ret[#ret+1] = {
+ node = subscribed_node;
+ jid = subscribed_jid;
+ subscription = nodes[subscribed_node].subscribers[subscribed_jid];
+ };
+ end
+ end
+ end
+end
+
function service:get_subscriptions(node, actor, jid)
-- Access checking
local cap;
@@ -366,32 +500,19 @@ function service:get_subscriptions(node, actor, jid)
return false, "item-not-found";
end
end
+ local ret = {};
+ if jid == nil then
+ for _, subs in pairs(self.subscriptions) do
+ flatten_subscriptions(ret, self, subs, node, node_obj)
+ end
+ return true, ret;
+ end
local normal_jid = self.config.normalize_jid(jid);
local subs = self.subscriptions[normal_jid];
-- We return the subscription object from the node to save
-- a get_subscription() call for each node.
- local ret = {};
if subs then
- for subscribed_jid, subscribed_nodes in pairs(subs) do
- if node then -- Return only subscriptions to this node
- if subscribed_nodes[node] then
- ret[#ret+1] = {
- node = node;
- jid = subscribed_jid;
- subscription = node_obj.subscribers[subscribed_jid];
- };
- end
- else -- Return subscriptions to all nodes
- local nodes = self.nodes;
- for subscribed_node in pairs(subscribed_nodes) do
- ret[#ret+1] = {
- node = subscribed_node;
- jid = subscribed_jid;
- subscription = nodes[subscribed_node].subscribers[subscribed_jid];
- };
- end
- end
- end
+ flatten_subscriptions(ret, self, subs, node, node_obj)
end
return true, ret;
end
@@ -421,14 +542,23 @@ function service:set_node_config(node, actor, new_config)
return false, "item-not-found";
end
- for k,v in pairs(new_config) do
- node_obj.config[k] = v;
+ local old_config = node_obj.config;
+ node_obj.config = setmetatable(new_config, {__index=self.node_defaults});
+
+ if self.config.nodestore then
+ local ok, err = save_node_to_store(self, node_obj);
+ if not ok then
+ node_obj.config = old_config;
+ return ok, "internal-server-error";
+ end
end
- local new_data = self.config.itemstore(self.nodes[node].config);
- for key, value in self.data[node]:items() do
- new_data:set(key, value);
+
+ if old_config["persist_items"] ~= node_obj.config["persist_items"] then
+ self.data[node] = self.config.itemstore(self.nodes[node].config, node);
+ elseif old_config["max_items"] ~= node_obj.config["max_items"] then
+ self.data[node]:resize(self.nodes[node].config["max_items"]);
end
- self.data[node] = new_data;
+
return true;
end
diff --git a/util/random.lua b/util/random.lua
index b2d0000d..d8a84514 100644
--- a/util/random.lua
+++ b/util/random.lua
@@ -11,9 +11,6 @@ if ok then return crand; end
local urandom, urandom_err = io.open("/dev/urandom", "r");
-local function seed()
-end
-
local function bytes(n)
return urandom:read(n);
end
@@ -25,7 +22,6 @@ if not urandom then
end
return {
- seed = seed;
bytes = bytes;
_source = "/dev/urandom";
};
diff --git a/util/sasl.lua b/util/sasl.lua
index 5845f34a..50851405 100644
--- a/util/sasl.lua
+++ b/util/sasl.lua
@@ -20,6 +20,7 @@ local assert = assert;
local require = require;
local _ENV = nil;
+-- luacheck: std none
--[[
Authentication Backend Prototypes:
@@ -42,7 +43,7 @@ Example:
local method = {};
method.__index = method;
-local mechanisms = {};
+local registered_mechanisms = {};
local backend_mechanism = {};
local mechanism_channelbindings = {};
@@ -52,7 +53,7 @@ local function registerMechanism(name, backends, f, cb_backends)
assert(type(backends) == "string" or type(backends) == "table", "Parameter backends MUST be either a string or a table.");
assert(type(f) == "function", "Parameter f MUST be a function.");
if cb_backends then assert(type(cb_backends) == "table"); end
- mechanisms[name] = f
+ registered_mechanisms[name] = f
if cb_backends then
mechanism_channelbindings[name] = {};
for _, cb_name in ipairs(cb_backends) do
@@ -70,7 +71,7 @@ local function new(realm, profile)
local mechanisms = profile.mechanisms;
if not mechanisms then
mechanisms = {};
- for backend, f in pairs(profile) do
+ for backend in pairs(profile) do
if backend_mechanism[backend] then
for _, mechanism in ipairs(backend_mechanism[backend]) do
mechanisms[mechanism] = true;
@@ -128,7 +129,7 @@ end
-- feed new messages to process into the library
function method:process(message)
--if message == "" or message == nil then return "failure", "malformed-request" end
- return mechanisms[self.selected](self, message);
+ return registered_mechanisms[self.selected](self, message);
end
-- load the mechanisms
diff --git a/util/sasl/anonymous.lua b/util/sasl/anonymous.lua
index 6201db32..de98a5e2 100644
--- a/util/sasl/anonymous.lua
+++ b/util/sasl/anonymous.lua
@@ -12,9 +12,10 @@
-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-local generate_uuid = require "util.uuid".generate;
+local generate_random_id = require "util.id".medium;
local _ENV = nil;
+-- luacheck: std none
--=========================
--SASL ANONYMOUS according to RFC 4505
@@ -28,10 +29,10 @@ anonymous:
end
]]
-local function anonymous(self, message)
+local function anonymous(self, message) -- luacheck: ignore 212/message
local username;
repeat
- username = generate_uuid();
+ username = generate_random_id():lower();
until self.profile.anonymous(self, username, self.realm);
self.username = username;
return "success"
diff --git a/util/sasl/digest-md5.lua b/util/sasl/digest-md5.lua
index 695dd2a3..7542a037 100644
--- a/util/sasl/digest-md5.lua
+++ b/util/sasl/digest-md5.lua
@@ -26,6 +26,7 @@ local generate_uuid = require "util.uuid".generate;
local nodeprep = require "util.encodings".stringprep.nodeprep;
local _ENV = nil;
+-- luacheck: std none
--=========================
--SASL DIGEST-MD5 according to RFC 2831
diff --git a/util/sasl/external.lua b/util/sasl/external.lua
index 5ba90190..ce50743e 100644
--- a/util/sasl/external.lua
+++ b/util/sasl/external.lua
@@ -1,6 +1,7 @@
local saslprep = require "util.encodings".stringprep.saslprep;
local _ENV = nil;
+-- luacheck: std none
local function external(self, message)
message = saslprep(message);
diff --git a/util/sasl/plain.lua b/util/sasl/plain.lua
index cd59b1ac..00c6bd20 100644
--- a/util/sasl/plain.lua
+++ b/util/sasl/plain.lua
@@ -17,6 +17,7 @@ local nodeprep = require "util.encodings".stringprep.nodeprep;
local log = require "util.logger".init("sasl");
local _ENV = nil;
+-- luacheck: std none
-- ================================
-- SASL PLAIN according to RFC 4616
diff --git a/util/sasl/scram.lua b/util/sasl/scram.lua
index 4e20dbb9..043f328b 100644
--- a/util/sasl/scram.lua
+++ b/util/sasl/scram.lua
@@ -26,6 +26,7 @@ local char = string.char;
local byte = string.byte;
local _ENV = nil;
+-- luacheck: std none
--=========================
--SASL SCRAM-SHA-1 according to RFC 5802
@@ -46,7 +47,18 @@ Supported Channel Binding Backends
local default_i = 4096
-local xor_map = {0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;1;0;3;2;5;4;7;6;9;8;11;10;13;12;15;14;2;3;0;1;6;7;4;5;10;11;8;9;14;15;12;13;3;2;1;0;7;6;5;4;11;10;9;8;15;14;13;12;4;5;6;7;0;1;2;3;12;13;14;15;8;9;10;11;5;4;7;6;1;0;3;2;13;12;15;14;9;8;11;10;6;7;4;5;2;3;0;1;14;15;12;13;10;11;8;9;7;6;5;4;3;2;1;0;15;14;13;12;11;10;9;8;8;9;10;11;12;13;14;15;0;1;2;3;4;5;6;7;9;8;11;10;13;12;15;14;1;0;3;2;5;4;7;6;10;11;8;9;14;15;12;13;2;3;0;1;6;7;4;5;11;10;9;8;15;14;13;12;3;2;1;0;7;6;5;4;12;13;14;15;8;9;10;11;4;5;6;7;0;1;2;3;13;12;15;14;9;8;11;10;5;4;7;6;1;0;3;2;14;15;12;13;10;11;8;9;6;7;4;5;2;3;0;1;15;14;13;12;11;10;9;8;7;6;5;4;3;2;1;0;};
+local xor_map = {
+ 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,1,0,3,2,5,4,7,6,9,8,11,10,
+ 13,12,15,14,2,3,0,1,6,7,4,5,10,11,8,9,14,15,12,13,3,2,1,0,7,6,5,
+ 4,11,10,9,8,15,14,13,12,4,5,6,7,0,1,2,3,12,13,14,15,8,9,10,11,5,
+ 4,7,6,1,0,3,2,13,12,15,14,9,8,11,10,6,7,4,5,2,3,0,1,14,15,12,13,
+ 10,11,8,9,7,6,5,4,3,2,1,0,15,14,13,12,11,10,9,8,8,9,10,11,12,13,
+ 14,15,0,1,2,3,4,5,6,7,9,8,11,10,13,12,15,14,1,0,3,2,5,4,7,6,10,
+ 11,8,9,14,15,12,13,2,3,0,1,6,7,4,5,11,10,9,8,15,14,13,12,3,2,1,
+ 0,7,6,5,4,12,13,14,15,8,9,10,11,4,5,6,7,0,1,2,3,13,12,15,14,9,8,
+ 11,10,5,4,7,6,1,0,3,2,14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1,15,
+ 14,13,12,11,10,9,8,7,6,5,4,3,2,1,0,
+};
local result = {};
local function binaryXOR( a, b )
@@ -148,7 +160,7 @@ local function scram_gen(hash_name, H_f, HMAC_f)
end
self.username = username;
- -- retreive credentials
+ -- retrieve credentials
local stored_key, server_key, salt, iteration_count;
if self.profile.plain then
local password, status = self.profile.plain(self, username, self.realm)
@@ -237,10 +249,14 @@ end
local function init(registerMechanism)
local function registerSCRAMMechanism(hash_name, hash, hmac_hash)
- registerMechanism("SCRAM-"..hash_name, {"plain", "scram_"..(hashprep(hash_name))}, scram_gen(hash_name:lower(), hash, hmac_hash));
+ registerMechanism("SCRAM-"..hash_name,
+ {"plain", "scram_"..(hashprep(hash_name))},
+ scram_gen(hash_name:lower(), hash, hmac_hash));
-- register channel binding equivalent
- registerMechanism("SCRAM-"..hash_name.."-PLUS", {"plain", "scram_"..(hashprep(hash_name))}, scram_gen(hash_name:lower(), hash, hmac_hash), {"tls-unique"});
+ registerMechanism("SCRAM-"..hash_name.."-PLUS",
+ {"plain", "scram_"..(hashprep(hash_name))},
+ scram_gen(hash_name:lower(), hash, hmac_hash), {"tls-unique"});
end
registerSCRAMMechanism("SHA-1", sha1, hmac_sha1);
diff --git a/util/sasl_cyrus.lua b/util/sasl_cyrus.lua
index 4e9a4af5..a6bd0628 100644
--- a/util/sasl_cyrus.lua
+++ b/util/sasl_cyrus.lua
@@ -61,6 +61,7 @@ local sasl_errstring = {
setmetatable(sasl_errstring, { __index = function() return "undefined error!" end });
local _ENV = nil;
+-- luacheck: std none
local method = {};
method.__index = method;
diff --git a/util/serialization.lua b/util/serialization.lua
index 206f5fbb..960794f2 100644
--- a/util/serialization.lua
+++ b/util/serialization.lua
@@ -1,83 +1,262 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
+-- Copyright (C) 2018 Kim Alvefur
--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
-local string_rep = string.rep;
-local type = type;
-local tostring = tostring;
-local t_insert = table.insert;
+local getmetatable = getmetatable;
+local next, type = next, type;
+local s_format = string.format;
+local s_gsub = string.gsub;
+local s_rep = string.rep;
+local s_char = string.char;
+local s_match = string.match;
local t_concat = table.concat;
-local pairs = pairs;
-local next = next;
local pcall = pcall;
-
-local debug_traceback = debug.traceback;
-local log = require "util.logger".init("serialization");
local envload = require"util.envload".envload;
-local _ENV = nil;
+local pos_inf, neg_inf = math.huge, -math.huge;
+local m_log = math.log;
+local m_log10 = math.log10 or function (n)
+ return m_log(n, 10);
+end
+local m_floor = math.floor;
+-- luacheck: ignore 143/math
+local m_type = math.type or function (n)
+ return n % 1 == 0 and n <= 9007199254740992 and n >= -9007199254740992 and "integer" or "float";
+end;
+
+local char_to_hex = {};
+for i = 0,255 do
+ char_to_hex[s_char(i)] = s_format("%02x", i);
+end
+
+local function to_hex(s)
+ return (s_gsub(s, ".", char_to_hex));
+end
+
+local function fatal_error(obj, why)
+ error("Can't serialize "..type(obj) .. (why and ": ".. why or ""));
+end
-local indent = function(i)
- return string_rep("\t", i);
+local function default_fallback(x, why)
+ return s_format("nil --[[%s: %s]]", type(x), why or "fail");
end
-local function basicSerialize (o)
- if type(o) == "number" or type(o) == "boolean" then
- -- no need to check for NaN, as that's not a valid table index
- if o == 1/0 then return "(1/0)";
- elseif o == -1/0 then return "(-1/0)";
- else return tostring(o); end
- else -- assume it is a string -- FIXME make sure it's a string. throw an error otherwise.
- return (("%q"):format(tostring(o)):gsub("\\\n", "\\n"));
+
+local string_escapes = {
+ ['\a'] = [[\a]]; ['\b'] = [[\b]];
+ ['\f'] = [[\f]]; ['\n'] = [[\n]];
+ ['\r'] = [[\r]]; ['\t'] = [[\t]];
+ ['\v'] = [[\v]]; ['\\'] = [[\\]];
+ ['\"'] = [[\"]]; ['\''] = [[\']];
+}
+
+for i = 0, 255 do
+ local c = s_char(i);
+ if not string_escapes[c] then
+ string_escapes[c] = s_format("\\%03d", i);
end
end
-local function _simplesave(o, ind, t, func)
- if type(o) == "number" then
- if o ~= o then func(t, "(0/0)");
- elseif o == 1/0 then func(t, "(1/0)");
- elseif o == -1/0 then func(t, "(-1/0)");
- else func(t, tostring(o)); end
- elseif type(o) == "string" then
- func(t, (("%q"):format(o):gsub("\\\n", "\\n")));
- elseif type(o) == "table" then
- if next(o) ~= nil then
- func(t, "{\n");
- for k,v in pairs(o) do
- func(t, indent(ind));
- func(t, "[");
- func(t, basicSerialize(k));
- func(t, "] = ");
- if ind == 0 then
- _simplesave(v, 0, t, func);
+
+local default_keywords = {
+ ["do"] = true; ["and"] = true; ["else"] = true; ["break"] = true;
+ ["if"] = true; ["end"] = true; ["goto"] = true; ["false"] = true;
+ ["in"] = true; ["for"] = true; ["then"] = true; ["local"] = true;
+ ["or"] = true; ["nil"] = true; ["true"] = true; ["until"] = true;
+ ["elseif"] = true; ["function"] = true; ["not"] = true;
+ ["repeat"] = true; ["return"] = true; ["while"] = true;
+};
+
+local function new(opt)
+ if type(opt) ~= "table" then
+ opt = { preset = opt };
+ end
+
+ local types = {
+ table = true;
+ string = true;
+ number = true;
+ boolean = true;
+ ["nil"] = true;
+ };
+
+ -- presets
+ if opt.preset == "debug" then
+ opt.preset = "oneline";
+ opt.freeze = true;
+ opt.fatal = false;
+ opt.fallback = default_fallback;
+ end
+ if opt.preset == "oneline" then
+ opt.indentwith = opt.indentwith or "";
+ opt.itemstart = opt.itemstart or " ";
+ opt.itemlast = opt.itemlast or "";
+ opt.tend = opt.tend or " }";
+ elseif opt.preset == "compact" then
+ opt.indentwith = opt.indentwith or "";
+ opt.itemstart = opt.itemstart or "";
+ opt.itemlast = opt.itemlast or "";
+ opt.equals = opt.equals or "=";
+ end
+
+ local fallback = opt.fatal and fatal_error or opt.fallback or default_fallback;
+
+ local function ser(v)
+ return (types[type(v)] or fallback)(v);
+ end
+
+ local keywords = opt.keywords or default_keywords;
+
+ -- indented
+ local indentwith = opt.indentwith or "\t";
+ local itemstart = opt.itemstart or "\n";
+ local itemsep = opt.itemsep or ";";
+ local itemlast = opt.itemlast or ";\n";
+ local tstart = opt.tstart or "{";
+ local tend = opt.tend or "}";
+ local kstart = opt.kstart or "[";
+ local kend = opt.kend or "]";
+ local equals = opt.equals or " = ";
+ local unquoted = opt.unquoted == nil and "^[%a_][%w_]*$" or opt.unquoted;
+ local hex = opt.hex;
+ local freeze = opt.freeze;
+ local precision = opt.precision or 10;
+
+ -- serialize one table, recursively
+ -- t - table being serialized
+ -- o - array where tokens are added, concatenate to get final result
+ -- - also used to detect cycles
+ -- l - position in o of where to insert next token
+ -- d - depth, used for indentation
+ local function serialize_table(t, o, l, d)
+ if o[t] or d > 127 then
+ o[l], l = fallback(t, "recursion"), l + 1;
+ return l;
+ end
+
+ o[t] = true;
+ if freeze then
+ -- opportunity to do pre-serialization
+ local mt = getmetatable(t);
+ local fr = (freeze ~= true and freeze[mt]);
+ local mf = mt and mt.__freeze;
+ local tag;
+ if type(fr) == "string" then
+ tag = fr;
+ fr = mf;
+ elseif mt then
+ tag = mt.__type;
+ end
+ if type(fr) == "function" then
+ t = fr(t);
+ if type(tag) == "string" then
+ o[l], l = tag, l + 1;
+ end
+ end
+ end
+ o[l], l = tstart, l + 1;
+ local indent = s_rep(indentwith, d);
+ local numkey = 1;
+ local ktyp, vtyp;
+ for k,v in next,t do
+ o[l], l = itemstart, l + 1;
+ o[l], l = indent, l + 1;
+ ktyp, vtyp = type(k), type(v);
+ if k == numkey then
+ -- next index in array part
+ -- assuming that these are found in order
+ numkey = numkey + 1;
+ elseif unquoted and ktyp == "string" and
+ not keywords[k] and s_match(k, unquoted) then
+ -- unquoted keys
+ o[l], l = k, l + 1;
+ o[l], l = equals, l + 1;
+ else
+ -- quoted keys
+ o[l], l = kstart, l + 1;
+ if ktyp == "table" then
+ l = serialize_table(k, o, l, d+1);
else
- _simplesave(v, ind+1, t, func);
+ o[l], l = ser(k), l + 1;
end
- func(t, ";\n");
+ -- =
+ o[l], o[l+1], l = kend, equals, l + 2;
end
- func(t, indent(ind-1));
- func(t, "}");
- else
- func(t, "{}");
+
+ -- the value
+ if vtyp == "table" then
+ l = serialize_table(v, o, l, d+1);
+ else
+ o[l], l = ser(v), l + 1;
+ end
+ -- last item?
+ if next(t, k) ~= nil then
+ o[l], l = itemsep, l + 1;
+ else
+ o[l], l = itemlast, l + 1;
+ end
+ end
+ if next(t) ~= nil then
+ o[l], l = s_rep(indentwith, d-1), l + 1;
+ end
+ o[l], l = tend, l +1;
+ return l;
+ end
+
+ function types.table(t)
+ local o = {};
+ serialize_table(t, o, 1, 1);
+ return t_concat(o);
+ end
+
+ local function serialize_string(s)
+ return '"' .. s_gsub(s, "[%z\1-\31\"\'\\\127-\255]", string_escapes) .. '"';
+ end
+
+ if hex then
+ function types.string(s)
+ local esc = serialize_string(s);
+ if #esc > (#s*2+2+#hex) then
+ return hex .. '"' .. to_hex(s) .. '"';
+ end
+ return esc;
end
- elseif type(o) == "boolean" then
- func(t, (o and "true" or "false"));
else
- log("error", "cannot serialize a %s: %s", type(o), debug_traceback())
- func(t, "nil");
+ types.string = serialize_string;
end
-end
-local function append(t, o)
- _simplesave(o, 1, t, t.write or t_insert);
- return t;
-end
+ function types.number(t)
+ if m_type(t) == "integer" then
+ return s_format("%d", t);
+ elseif t == pos_inf then
+ return "(1/0)";
+ elseif t == neg_inf then
+ return "(-1/0)";
+ elseif t ~= t then
+ return "(0/0)";
+ end
+ local log = m_floor(m_log10(t));
+ if log > precision then
+ return s_format("%.18e", t);
+ else
+ return s_format("%.18g", t);
+ end
+ end
+
+ -- Are these faster than tostring?
+ types["nil"] = function()
+ return "nil";
+ end
+
+ function types.boolean(t)
+ return t and "true" or "false";
+ end
-local function serialize(o)
- return t_concat(append({}, o));
+ return ser;
end
local function deserialize(str)
@@ -91,7 +270,9 @@ local function deserialize(str)
end
return {
- append = append;
- serialize = serialize;
+ new = new;
+ serialize = function (x, opt)
+ return new(opt)(x);
+ end;
deserialize = deserialize;
};
diff --git a/util/set.lua b/util/set.lua
index c136a522..a4f20138 100644
--- a/util/set.lua
+++ b/util/set.lua
@@ -11,8 +11,9 @@ local ipairs, pairs, setmetatable, next, tostring =
local t_concat = table.concat;
local _ENV = nil;
+-- luacheck: std none
-local set_mt = {};
+local set_mt = { __name = "set" };
function set_mt.__call(set, _, k)
return next(set._items, k);
end
diff --git a/util/sql.lua b/util/sql.lua
index d964025e..67a5d683 100644
--- a/util/sql.lua
+++ b/util/sql.lua
@@ -1,11 +1,10 @@
local setmetatable, getmetatable = setmetatable, getmetatable;
-local ipairs, unpack, select = ipairs, table.unpack or unpack, select; --luacheck: ignore 113
-local tonumber, tostring = tonumber, tostring;
+local ipairs, unpack, select = ipairs, table.unpack or unpack, select; --luacheck: ignore 113 143
+local tostring = tostring;
local type = type;
local assert, pcall, xpcall, debug_traceback = assert, pcall, xpcall, debug.traceback;
local t_concat = table.concat;
-local s_char = string.char;
local log = require "util.logger".init("sql");
local DBI = require "DBI";
@@ -15,6 +14,7 @@ DBI.Drivers();
local build_url = require "socket.url".build;
local _ENV = nil;
+-- luacheck: std none
local column_mt = {};
local table_mt = {};
@@ -58,9 +58,6 @@ table_mt.__index = {};
function table_mt.__index:create(engine)
return engine:_create_table(self);
end
-function table_mt:__call(...)
- -- TODO
-end
function column_mt:__tostring()
return 'Column{ name="'..self.name..'", type="'..self.type..'" }'
end
@@ -71,31 +68,6 @@ function index_mt:__tostring()
-- return 'Index{ name="'..self.name..'", type="'..self.type..'" }'
end
-local function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end)); end
-local function parse_url(url)
- local scheme, secondpart, database = url:match("^([%w%+]+)://([^/]*)/?(.*)");
- assert(scheme, "Invalid URL format");
- local username, password, host, port;
- local authpart, hostpart = secondpart:match("([^@]+)@([^@+])");
- if not authpart then hostpart = secondpart; end
- if authpart then
- username, password = authpart:match("([^:]*):(.*)");
- username = username or authpart;
- password = password and urldecode(password);
- end
- if hostpart then
- host, port = hostpart:match("([^:]*):(.*)");
- host = host or hostpart;
- port = port and assert(tonumber(port), "Invalid URL format");
- end
- return {
- scheme = scheme:lower();
- username = username; password = password;
- host = host; port = port;
- database = #database > 0 and database or nil;
- };
-end
-
local engine = {};
function engine:connect()
if self.conn then return true; end
@@ -123,7 +95,7 @@ function engine:connect()
end
return true;
end
-function engine:onconnect()
+function engine:onconnect() -- luacheck: ignore 212/self
-- Override from create_engine()
end
@@ -148,6 +120,7 @@ function engine:execute(sql, ...)
prepared[sql] = stmt;
end
+ -- luacheck: ignore 411/success
local success, err = stmt:execute(...);
if not success then return success, err; end
return stmt;
@@ -161,14 +134,14 @@ local result_mt = { __index = {
local function debugquery(where, sql, ...)
local i = 0; local a = {...}
sql = sql:gsub("\n?\t+", " ");
- log("debug", "[%s] %s", where, sql:gsub("%?", function ()
+ log("debug", "[%s] %s", where, (sql:gsub("%?", function ()
i = i + 1;
local v = a[i];
if type(v) == "string" then
v = ("'%s'"):format(v:gsub("'", "''"));
end
return tostring(v);
- end));
+ end)));
end
function engine:execute_query(sql, ...)
@@ -335,7 +308,12 @@ function engine:set_encoding() -- to UTF-8
local charset = "utf8";
if driver == "MySQL" then
self:transaction(function()
- for row in self:select"SELECT \"CHARACTER_SET_NAME\" FROM \"information_schema\".\"CHARACTER_SETS\" WHERE \"CHARACTER_SET_NAME\" LIKE 'utf8%' ORDER BY MAXLEN DESC LIMIT 1;" do
+ for row in self:select[[
+ SELECT "CHARACTER_SET_NAME"
+ FROM "information_schema"."CHARACTER_SETS"
+ WHERE "CHARACTER_SET_NAME" LIKE 'utf8%'
+ ORDER BY MAXLEN DESC LIMIT 1;
+ ]] do
charset = row and row[1] or charset;
end
end);
@@ -379,7 +357,7 @@ local function db2uri(params)
};
end
-local function create_engine(self, params, onconnect)
+local function create_engine(_, params, onconnect)
return setmetatable({ url = db2uri(params), params = params, onconnect = onconnect }, engine_mt);
end
diff --git a/util/sslconfig.lua b/util/sslconfig.lua
index 4c4e1d48..5c685f7d 100644
--- a/util/sslconfig.lua
+++ b/util/sslconfig.lua
@@ -8,6 +8,7 @@ local t_insert = table.insert;
local setmetatable = setmetatable;
local _ENV = nil;
+-- luacheck: std none
local handlers = { };
local finalisers = { };
diff --git a/util/stanza.lua b/util/stanza.lua
index 07365144..11398179 100644
--- a/util/stanza.lua
+++ b/util/stanza.lua
@@ -7,6 +7,7 @@
--
+local error = error;
local t_insert = table.insert;
local t_remove = table.remove;
local t_concat = table.concat;
@@ -23,6 +24,8 @@ local s_sub = string.sub;
local s_find = string.find;
local os = os;
+local valid_utf8 = require "util.encodings".utf8.valid;
+
local do_pretty_printing = not os.getenv("WINDIR");
local getstyle, getstring;
if do_pretty_printing then
@@ -37,12 +40,52 @@ end
local xmlns_stanzas = "urn:ietf:params:xml:ns:xmpp-stanzas";
local _ENV = nil;
+-- luacheck: std none
-local stanza_mt = { __type = "stanza" };
+local stanza_mt = { __name = "stanza" };
stanza_mt.__index = stanza_mt;
-local function new_stanza(name, attr)
- local stanza = { name = name, attr = attr or {}, tags = {} };
+local function check_name(name, name_type)
+ if type(name) ~= "string" then
+ error("invalid "..name_type.." name: expected string, got "..type(name));
+ elseif #name == 0 then
+ error("invalid "..name_type.." name: empty string");
+ elseif s_find(name, "[<>& '\"]") then
+ error("invalid "..name_type.." name: contains invalid characters");
+ elseif not valid_utf8(name) then
+ error("invalid "..name_type.." name: contains invalid utf8");
+ end
+end
+
+local function check_text(text, text_type)
+ if type(text) ~= "string" then
+ error("invalid "..text_type.." value: expected string, got "..type(text));
+ elseif not valid_utf8(text) then
+ error("invalid "..text_type.." value: contains invalid utf8");
+ end
+end
+
+local function check_attr(attr)
+ if attr ~= nil then
+ if type(attr) ~= "table" then
+ error("invalid attributes, expected table got "..type(attr));
+ end
+ for k, v in pairs(attr) do
+ check_name(k, "attribute");
+ check_text(v, "attribute");
+ if type(v) ~= "string" then
+ error("invalid attribute value for '"..k.."': expected string, got "..type(v));
+ elseif not valid_utf8(v) then
+ error("invalid attribute value for '"..k.."': contains invalid utf8");
+ end
+ end
+ end
+end
+
+local function new_stanza(name, attr, namespaces)
+ check_name(name, "tag");
+ check_attr(attr);
+ local stanza = { name = name, attr = attr or {}, namespaces = namespaces, tags = {} };
return setmetatable(stanza, stanza_mt);
end
@@ -58,8 +101,12 @@ function stanza_mt:body(text, attr)
return self:tag("body", attr):text(text);
end
-function stanza_mt:tag(name, attrs)
- local s = new_stanza(name, attrs);
+function stanza_mt:text_tag(name, text, attr, namespaces)
+ return self:tag(name, attr, namespaces):text(text):up();
+end
+
+function stanza_mt:tag(name, attr, namespaces)
+ local s = new_stanza(name, attr, namespaces);
local last_add = self.last_add;
if not last_add then last_add = {}; self.last_add = last_add; end
(last_add[#last_add] or self):add_direct_child(s);
@@ -68,8 +115,10 @@ function stanza_mt:tag(name, attrs)
end
function stanza_mt:text(text)
- local last_add = self.last_add;
- (last_add and last_add[#last_add] or self):add_direct_child(text);
+ if text ~= nil and text ~= "" then
+ local last_add = self.last_add;
+ (last_add and last_add[#last_add] or self):add_direct_child(text);
+ end
return self;
end
@@ -85,10 +134,13 @@ function stanza_mt:reset()
end
function stanza_mt:add_direct_child(child)
- if type(child) == "table" then
+ if is_stanza(child) then
t_insert(self.tags, child);
+ t_insert(self, child);
+ else
+ check_text(child, "text");
+ t_insert(self, child);
end
- t_insert(self, child);
end
function stanza_mt:add_child(child)
@@ -347,7 +399,12 @@ end
local function clone(stanza)
local attr, tags = {}, {};
for k,v in pairs(stanza.attr) do attr[k] = v; end
- local new = { name = stanza.name, attr = attr, tags = tags };
+ local old_namespaces, namespaces = stanza.namespaces;
+ if old_namespaces then
+ namespaces = {};
+ for k,v in pairs(old_namespaces) do namespaces[k] = v; end
+ end
+ local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags };
for i=1,#stanza do
local child = stanza[i];
if child.name then
@@ -372,7 +429,13 @@ local function iq(attr)
end
local function reply(orig)
- return new_stanza(orig.name, orig.attr and { to = orig.attr.from, from = orig.attr.to, id = orig.attr.id, type = ((orig.name == "iq" and "result") or orig.attr.type) });
+ return new_stanza(orig.name,
+ orig.attr and {
+ to = orig.attr.from,
+ from = orig.attr.to,
+ id = orig.attr.id,
+ type = ((orig.name == "iq" and "result") or orig.attr.type)
+ });
end
local xmpp_stanzas_attr = { xmlns = xmlns_stanzas };
diff --git a/util/startup.lua b/util/startup.lua
new file mode 100644
index 00000000..4ae45bdf
--- /dev/null
+++ b/util/startup.lua
@@ -0,0 +1,551 @@
+-- Ignore the CFG_* variables
+-- luacheck: ignore 113/CFG_CONFIGDIR 113/CFG_SOURCEDIR 113/CFG_DATADIR 113/CFG_PLUGINDIR
+local startup = {};
+
+local prosody = { events = require "util.events".new() };
+local logger = require "util.logger";
+local log = logger.init("startup");
+
+local config = require "core.configmanager";
+
+local dependencies = require "util.dependencies";
+
+local original_logging_config;
+
+function startup.read_config()
+ local filenames = {};
+
+ local filename;
+ if arg[1] == "--config" and arg[2] then
+ table.insert(filenames, arg[2]);
+ if CFG_CONFIGDIR then
+ table.insert(filenames, CFG_CONFIGDIR.."/"..arg[2]);
+ end
+ table.remove(arg, 1); table.remove(arg, 1);
+ elseif os.getenv("PROSODY_CONFIG") then -- Passed by prosodyctl
+ table.insert(filenames, os.getenv("PROSODY_CONFIG"));
+ else
+ table.insert(filenames, (CFG_CONFIGDIR or ".").."/prosody.cfg.lua");
+ end
+ for _,_filename in ipairs(filenames) do
+ filename = _filename;
+ local file = io.open(filename);
+ if file then
+ file:close();
+ prosody.config_file = filename;
+ CFG_CONFIGDIR = filename:match("^(.*)[\\/][^\\/]*$"); -- luacheck: ignore 111
+ break;
+ end
+ end
+ prosody.config_file = filename
+ local ok, level, err = config.load(filename);
+ if not ok then
+ print("\n");
+ print("**************************");
+ if level == "parser" then
+ print("A problem occurred while reading the config file "..filename);
+ print("");
+ local err_line, err_message = tostring(err):match("%[string .-%]:(%d*): (.*)");
+ if err:match("chunk has too many syntax levels$") then
+ print("An Include statement in a config file is including an already-included");
+ print("file and causing an infinite loop. An Include statement in a config file is...");
+ else
+ print("Error"..(err_line and (" on line "..err_line) or "")..": "..(err_message or tostring(err)));
+ end
+ print("");
+ elseif level == "file" then
+ print("Prosody was unable to find the configuration file.");
+ print("We looked for: "..filename);
+ print("A sample config file is included in the Prosody download called prosody.cfg.lua.dist");
+ print("Copy or rename it to prosody.cfg.lua and edit as necessary.");
+ end
+ print("More help on configuring Prosody can be found at https://prosody.im/doc/configure");
+ print("Good luck!");
+ print("**************************");
+ print("");
+ os.exit(1);
+ end
+end
+
+function startup.check_dependencies()
+ if not dependencies.check_dependencies() then
+ os.exit(1);
+ end
+end
+
+-- luacheck: globals socket server
+
+function startup.load_libraries()
+ -- Load socket framework
+ -- luacheck: ignore 111/server 111/socket
+ socket = require "socket";
+ server = require "net.server"
+end
+
+function startup.init_logging()
+ -- Initialize logging
+ local loggingmanager = require "core.loggingmanager"
+ loggingmanager.reload_logging();
+ prosody.events.add_handler("reopen-log-files", function ()
+ loggingmanager.reload_logging();
+ prosody.events.fire_event("logging-reloaded");
+ end);
+end
+
+function startup.log_dependency_warnings()
+ dependencies.log_warnings();
+end
+
+function startup.sanity_check()
+ for host, host_config in pairs(config.getconfig()) do
+ if host ~= "*"
+ and host_config.enabled ~= false
+ and not host_config.component_module then
+ return;
+ end
+ end
+ log("error", "No enabled VirtualHost entries found in the config file.");
+ log("error", "At least one active host is required for Prosody to function. Exiting...");
+ os.exit(1);
+end
+
+function startup.sandbox_require()
+ -- Replace require() with one that doesn't pollute _G, required
+ -- for neat sandboxing of modules
+ -- luacheck: ignore 113/getfenv 111/require
+ local _realG = _G;
+ local _real_require = require;
+ local getfenv = getfenv or function (f)
+ -- FIXME: This is a hack to replace getfenv() in Lua 5.2
+ local name, env = debug.getupvalue(debug.getinfo(f or 1).func, 1);
+ if name == "_ENV" then
+ return env;
+ end
+ end
+ function require(...) -- luacheck: ignore 121
+ local curr_env = getfenv(2);
+ local curr_env_mt = getmetatable(curr_env);
+ local _realG_mt = getmetatable(_realG);
+ if curr_env_mt and curr_env_mt.__index and not curr_env_mt.__newindex and _realG_mt then
+ local old_newindex, old_index;
+ old_newindex, _realG_mt.__newindex = _realG_mt.__newindex, curr_env;
+ old_index, _realG_mt.__index = _realG_mt.__index, function (_G, k) -- luacheck: ignore 212/_G
+ return rawget(curr_env, k);
+ end;
+ local ret = _real_require(...);
+ _realG_mt.__newindex = old_newindex;
+ _realG_mt.__index = old_index;
+ return ret;
+ end
+ return _real_require(...);
+ end
+end
+
+function startup.set_function_metatable()
+ local mt = {};
+ function mt.__index(f, upvalue)
+ local i, name, value = 0;
+ repeat
+ i = i + 1;
+ name, value = debug.getupvalue(f, i);
+ until name == upvalue or name == nil;
+ return value;
+ end
+ function mt.__newindex(f, upvalue, value)
+ local i, name = 0;
+ repeat
+ i = i + 1;
+ name = debug.getupvalue(f, i);
+ until name == upvalue or name == nil;
+ if name then
+ debug.setupvalue(f, i, value);
+ end
+ end
+ function mt.__tostring(f)
+ local info = debug.getinfo(f);
+ return ("function(%s:%d)"):format(info.short_src:match("[^\\/]*$"), info.linedefined);
+ end
+ debug.setmetatable(function() end, mt);
+end
+
+function startup.detect_platform()
+ prosody.platform = "unknown";
+ if os.getenv("WINDIR") then
+ prosody.platform = "windows";
+ elseif package.config:sub(1,1) == "/" then
+ prosody.platform = "posix";
+ end
+end
+
+function startup.detect_installed()
+ prosody.installed = nil;
+ if CFG_SOURCEDIR and (prosody.platform == "windows" or CFG_SOURCEDIR:match("^/")) then
+ prosody.installed = true;
+ end
+end
+
+function startup.init_global_state()
+ -- luacheck: ignore 121
+ prosody.bare_sessions = {};
+ prosody.full_sessions = {};
+ prosody.hosts = {};
+
+ -- COMPAT: These globals are deprecated
+ -- luacheck: ignore 111/bare_sessions 111/full_sessions 111/hosts
+ bare_sessions = prosody.bare_sessions;
+ full_sessions = prosody.full_sessions;
+ hosts = prosody.hosts;
+
+ prosody.paths = { source = CFG_SOURCEDIR, config = CFG_CONFIGDIR or ".",
+ plugins = CFG_PLUGINDIR or "plugins", data = "data" };
+
+ prosody.arg = _G.arg;
+
+ _G.log = logger.init("general");
+ prosody.log = logger.init("general");
+
+ startup.detect_platform();
+ startup.detect_installed();
+ _G.prosody = prosody;
+end
+
+function startup.setup_datadir()
+ prosody.paths.data = config.get("*", "data_path") or CFG_DATADIR or "data";
+end
+
+function startup.setup_plugindir()
+ local custom_plugin_paths = config.get("*", "plugin_paths");
+ if custom_plugin_paths then
+ local path_sep = package.config:sub(3,3);
+ -- path1;path2;path3;defaultpath...
+ -- luacheck: ignore 111
+ CFG_PLUGINDIR = table.concat(custom_plugin_paths, path_sep)..path_sep..(CFG_PLUGINDIR or "plugins");
+ prosody.paths.plugins = CFG_PLUGINDIR;
+ end
+end
+
+function startup.chdir()
+ if prosody.installed then
+ -- Change working directory to data path.
+ require "lfs".chdir(prosody.paths.data);
+ end
+end
+
+function startup.add_global_prosody_functions()
+ -- Function to reload the config file
+ function prosody.reload_config()
+ log("info", "Reloading configuration file");
+ prosody.events.fire_event("reloading-config");
+ local ok, level, err = config.load(prosody.config_file);
+ if not ok then
+ if level == "parser" then
+ log("error", "There was an error parsing the configuration file: %s", tostring(err));
+ elseif level == "file" then
+ log("error", "Couldn't read the config file when trying to reload: %s", tostring(err));
+ end
+ else
+ prosody.events.fire_event("config-reloaded", {
+ filename = prosody.config_file,
+ config = config.getconfig(),
+ });
+ end
+ return ok, (err and tostring(level)..": "..tostring(err)) or nil;
+ end
+
+ -- Function to reopen logfiles
+ function prosody.reopen_logfiles()
+ log("info", "Re-opening log files");
+ prosody.events.fire_event("reopen-log-files");
+ end
+
+ -- Function to initiate prosody shutdown
+ function prosody.shutdown(reason, code)
+ log("info", "Shutting down: %s", reason or "unknown reason");
+ prosody.shutdown_reason = reason;
+ prosody.shutdown_code = code;
+ prosody.events.fire_event("server-stopping", {
+ reason = reason;
+ code = code;
+ });
+ server.setquitting(true);
+ end
+end
+
+function startup.load_secondary_libraries()
+ --- Load and initialise core modules
+ require "util.import"
+ require "util.xmppstream"
+ require "core.stanza_router"
+ require "core.statsmanager"
+ require "core.hostmanager"
+ require "core.portmanager"
+ require "core.modulemanager"
+ require "core.usermanager"
+ require "core.rostermanager"
+ require "core.sessionmanager"
+ package.loaded['core.componentmanager'] = setmetatable({},{__index=function()
+ -- COMPAT which version?
+ log("warn", "componentmanager is deprecated: %s", debug.traceback():match("\n[^\n]*\n[ \t]*([^\n]*)"));
+ return function() end
+ end});
+
+ require "util.array"
+ require "util.datetime"
+ require "util.iterators"
+ require "util.timer"
+ require "util.helpers"
+
+ pcall(require, "util.signal") -- Not on Windows
+
+ -- Commented to protect us from
+ -- the second kind of people
+ --[[
+ pcall(require, "remdebug.engine");
+ if remdebug then remdebug.engine.start() end
+ ]]
+
+ require "util.stanza"
+ require "util.jid"
+end
+
+function startup.init_http_client()
+ local http = require "net.http"
+ local config_ssl = config.get("*", "ssl") or {}
+ local https_client = config.get("*", "client_https_ssl")
+ http.default.options.sslctx = require "core.certmanager".create_context("client_https port 0", "client",
+ { capath = config_ssl.capath, cafile = config_ssl.cafile, verify = "peer", }, https_client);
+end
+
+function startup.init_data_store()
+ require "core.storagemanager";
+end
+
+function startup.prepare_to_start()
+ log("info", "Prosody is using the %s backend for connection handling", server.get_backend());
+ -- Signal to modules that we are ready to start
+ prosody.events.fire_event("server-starting");
+ prosody.start_time = os.time();
+end
+
+function startup.init_global_protection()
+ -- Catch global accesses
+ -- luacheck: ignore 212/t
+ local locked_globals_mt = {
+ __index = function (t, k) log("warn", "%s", debug.traceback("Attempt to read a non-existent global '"..tostring(k).."'", 2)); end;
+ __newindex = function (t, k, v) error("Attempt to set a global: "..tostring(k).." = "..tostring(v), 2); end;
+ };
+
+ function prosody.unlock_globals()
+ setmetatable(_G, nil);
+ end
+
+ function prosody.lock_globals()
+ setmetatable(_G, locked_globals_mt);
+ end
+
+ -- And lock now...
+ prosody.lock_globals();
+end
+
+function startup.read_version()
+ -- Try to determine version
+ local version_file = io.open((CFG_SOURCEDIR or ".").."/prosody.version");
+ prosody.version = "unknown";
+ if version_file then
+ prosody.version = version_file:read("*a"):gsub("%s*$", "");
+ version_file:close();
+ if #prosody.version == 12 and prosody.version:match("^[a-f0-9]+$") then
+ prosody.version = "hg:"..prosody.version;
+ end
+ else
+ local hg = require"util.mercurial";
+ local hgid = hg.check_id(CFG_SOURCEDIR or ".");
+ if hgid then prosody.version = "hg:" .. hgid; end
+ end
+end
+
+function startup.log_greeting()
+ log("info", "Hello and welcome to Prosody version %s", prosody.version);
+end
+
+function startup.notify_started()
+ prosody.events.fire_event("server-started");
+end
+
+-- Override logging config (used by prosodyctl)
+function startup.force_console_logging()
+ original_logging_config = config.get("*", "log");
+ config.set("*", "log", { { levels = { min = os.getenv("PROSODYCTL_LOG_LEVEL") or "info" }, to = "console" } });
+end
+
+function startup.switch_user()
+ -- Switch away from root and into the prosody user --
+ -- NOTE: This function is only used by prosodyctl.
+ -- The prosody process is built with the assumption that
+ -- it is already started as the appropriate user.
+
+ local want_pposix_version = "0.4.0";
+ local have_pposix, pposix = pcall(require, "util.pposix");
+
+ if have_pposix and pposix then
+ if pposix._VERSION ~= want_pposix_version then
+ print(string.format("Unknown version (%s) of binary pposix module, expected %s",
+ tostring(pposix._VERSION), want_pposix_version));
+ os.exit(1);
+ end
+ prosody.current_uid = pposix.getuid();
+ local arg_root = arg[1] == "--root";
+ if arg_root then table.remove(arg, 1); end
+ if prosody.current_uid == 0 and config.get("*", "run_as_root") ~= true and not arg_root then
+ -- We haz root!
+ local desired_user = config.get("*", "prosody_user") or "prosody";
+ local desired_group = config.get("*", "prosody_group") or desired_user;
+ local ok, err = pposix.setgid(desired_group);
+ if ok then
+ ok, err = pposix.initgroups(desired_user);
+ end
+ if ok then
+ ok, err = pposix.setuid(desired_user);
+ if ok then
+ -- Yay!
+ prosody.switched_user = true;
+ end
+ end
+ if not prosody.switched_user then
+ -- Boo!
+ print("Warning: Couldn't switch to Prosody user/group '"..tostring(desired_user).."'/'"..tostring(desired_group).."': "..tostring(err));
+ else
+ -- Make sure the Prosody user can read the config
+ local conf, err, errno = io.open(prosody.config_file);
+ if conf then
+ conf:close();
+ else
+ print("The config file is not readable by the '"..desired_user.."' user.");
+ print("Prosody will not be able to read it.");
+ print("Error was "..err);
+ os.exit(1);
+ end
+ end
+ end
+
+ -- Set our umask to protect data files
+ pposix.umask(config.get("*", "umask") or "027");
+ pposix.setenv("HOME", prosody.paths.data);
+ pposix.setenv("PROSODY_CONFIG", prosody.config_file);
+ else
+ print("Error: Unable to load pposix module. Check that Prosody is installed correctly.")
+ print("For more help send the below error to us through https://prosody.im/discuss");
+ print(tostring(pposix))
+ os.exit(1);
+ end
+end
+
+function startup.check_unwriteable()
+ local function test_writeable(filename)
+ local f, err = io.open(filename, "a");
+ if not f then
+ return false, err;
+ end
+ f:close();
+ return true;
+ end
+
+ local unwriteable_files = {};
+ if type(original_logging_config) == "string" and original_logging_config:sub(1,1) ~= "*" then
+ local ok, err = test_writeable(original_logging_config);
+ if not ok then
+ table.insert(unwriteable_files, err);
+ end
+ elseif type(original_logging_config) == "table" then
+ for _, rule in ipairs(original_logging_config) do
+ if rule.filename then
+ local ok, err = test_writeable(rule.filename);
+ if not ok then
+ table.insert(unwriteable_files, err);
+ end
+ end
+ end
+ end
+
+ if #unwriteable_files > 0 then
+ print("One of more of the Prosody log files are not");
+ print("writeable, please correct the errors and try");
+ print("starting prosodyctl again.");
+ print("");
+ for _, err in ipairs(unwriteable_files) do
+ print(err);
+ end
+ print("");
+ os.exit(1);
+ end
+end
+
+function startup.make_host(hostname)
+ return {
+ type = "local",
+ events = prosody.events,
+ modules = {},
+ sessions = {},
+ users = require "core.usermanager".new_null_provider(hostname)
+ };
+end
+
+function startup.make_dummy_hosts()
+ -- When running under prosodyctl, we don't want to
+ -- fully initialize the server, so we populate prosody.hosts
+ -- with just enough things for most code to work correctly
+ -- luacheck: ignore 122/hosts
+ prosody.core_post_stanza = function () end; -- TODO: mod_router!
+
+ for hostname in pairs(config.getconfig()) do
+ prosody.hosts[hostname] = startup.make_host(hostname);
+ end
+end
+
+-- prosodyctl only
+function startup.prosodyctl()
+ startup.init_global_state();
+ startup.read_config();
+ startup.force_console_logging();
+ startup.init_logging();
+ startup.setup_plugindir();
+ startup.setup_datadir();
+ startup.chdir();
+ startup.read_version();
+ startup.switch_user();
+ startup.check_dependencies();
+ startup.log_dependency_warnings();
+ startup.check_unwriteable();
+ startup.load_libraries();
+ startup.init_http_client();
+ startup.make_dummy_hosts();
+end
+
+function startup.prosody()
+ -- These actions are in a strict order, as many depend on
+ -- previous steps to have already been performed
+ startup.init_global_state();
+ startup.read_config();
+ startup.init_logging();
+ startup.sanity_check();
+ startup.sandbox_require();
+ startup.set_function_metatable();
+ startup.check_dependencies();
+ startup.init_logging();
+ startup.load_libraries();
+ startup.setup_plugindir();
+ startup.setup_datadir();
+ startup.chdir();
+ startup.add_global_prosody_functions();
+ startup.read_version();
+ startup.log_greeting();
+ startup.log_dependency_warnings();
+ startup.load_secondary_libraries();
+ startup.init_http_client();
+ startup.init_data_store();
+ startup.init_global_protection();
+ startup.prepare_to_start();
+ startup.notify_started();
+end
+
+return startup;
diff --git a/util/template.lua b/util/template.lua
index 04ebb93d..c11037c5 100644
--- a/util/template.lua
+++ b/util/template.lua
@@ -4,12 +4,13 @@ local setmetatable = setmetatable;
local pairs = pairs;
local ipairs = ipairs;
local error = error;
-local loadstring = loadstring;
+local envload = require "util.envload".envload;
local debug = debug;
local t_remove = table.remove;
local parse_xml = require "util.xml".parse;
local _ENV = nil;
+-- luacheck: std none
local function trim_xml(stanza)
for i=#stanza,1,-1 do
@@ -72,7 +73,7 @@ local function create_cloner(stanza, chunkname)
src = src.."local _"..i.."="..lookup[i]..";";
end
src = src.."return "..name..";end";
- local f,err = loadstring(src, chunkname);
+ local f,err = envload(src, chunkname);
if not f then error(err); end
return f(setmetatable, stanza_mt);
end
diff --git a/util/termcolours.lua b/util/termcolours.lua
index 23c9156b..829d84af 100644
--- a/util/termcolours.lua
+++ b/util/termcolours.lua
@@ -26,6 +26,7 @@ end
local orig_color = windows and windows.get_consolecolor and windows.get_consolecolor();
local _ENV = nil;
+-- luacheck: std none
local stylemap = {
reset = 0; bright = 1, dim = 2, underscore = 4, blink = 5, reverse = 7, hidden = 8;
diff --git a/util/throttle.lua b/util/throttle.lua
index 1012f78a..d2036e9e 100644
--- a/util/throttle.lua
+++ b/util/throttle.lua
@@ -3,6 +3,7 @@ local gettime = require "util.time".now
local setmetatable = setmetatable;
local _ENV = nil;
+-- luacheck: std none
local throttle = {};
local throttle_mt = { __index = throttle };
diff --git a/util/timer.lua b/util/timer.lua
index 7e2e9414..22f547df 100644
--- a/util/timer.lua
+++ b/util/timer.lua
@@ -6,78 +6,106 @@
-- COPYING file in the source package for more information.
--
+local indexedbheap = require "util.indexedbheap";
+local log = require "util.logger".init("timer");
local server = require "net.server";
-local math_min = math.min
-local math_huge = math.huge
local get_time = require "util.time".now
-local t_insert = table.insert;
-local pairs = pairs;
local type = type;
-
-local data = {};
-local new_data = {};
+local debug_traceback = debug.traceback;
+local tostring = tostring;
+local xpcall = xpcall;
+local math_max = math.max;
local _ENV = nil;
+-- luacheck: std none
-local _add_task;
-if not server.event then
- function _add_task(delay, callback)
- local current_time = get_time();
- delay = delay + current_time;
- if delay >= current_time then
- t_insert(new_data, {delay, callback});
- else
- local r = callback(current_time);
- if r and type(r) == "number" then
- return _add_task(r, callback);
- end
+local _add_task = server.add_task;
+
+local _server_timer;
+local _active_timers = 0;
+local h = indexedbheap.create();
+local params = {};
+local next_time = nil;
+local _id, _callback, _now, _param;
+local function _call() return _callback(_now, _id, _param); end
+local function _traceback_handler(err) log("error", "Traceback[timer]: %s", debug_traceback(tostring(err), 2)); end
+local function _on_timer(now)
+ local peek;
+ while true do
+ peek = h:peek();
+ if peek == nil or peek > now then break; end
+ local _;
+ _, _callback, _id = h:pop();
+ _now = now;
+ _param = params[_id];
+ params[_id] = nil;
+ --item(now, id, _param); -- FIXME pcall
+ local success, err = xpcall(_call, _traceback_handler);
+ if success and type(err) == "number" then
+ h:insert(_callback, err + now, _id); -- re-add
+ params[_id] = _param;
end
end
- server._addtimer(function()
- local current_time = get_time();
- if #new_data > 0 then
- for _, d in pairs(new_data) do
- t_insert(data, d);
- end
- new_data = {};
- end
+ if peek ~= nil and _active_timers > 1 and peek == next_time then
+ -- Another instance of _on_timer already set next_time to the same value,
+ -- so it should be safe to not renew this timer event
+ peek = nil;
+ else
+ next_time = peek;
+ end
- local next_time = math_huge;
- for i, d in pairs(data) do
- local t, callback = d[1], d[2];
- if t <= current_time then
- data[i] = nil;
- local r = callback(current_time);
- if type(r) == "number" then
- _add_task(r, callback);
- next_time = math_min(next_time, r);
- end
- else
- next_time = math_min(next_time, t - current_time);
- end
- end
- return next_time;
- end);
-else
- local event = server.event;
- local event_base = server.event_base;
- local EVENT_LEAVE = (event.core and event.core.LEAVE) or -1;
+ if peek then
+ -- peek is the time of the next event
+ return peek - now;
+ end
+ _active_timers = _active_timers - 1;
+end
+local function add_task(delay, callback, param)
+ local current_time = get_time();
+ local event_time = current_time + delay;
- function _add_task(delay, callback)
- local event_handle;
- event_handle = event_base:addevent(nil, 0, function ()
- local ret = callback(get_time());
- if ret then
- return 0, ret;
- elseif event_handle then
- return EVENT_LEAVE;
- end
+ local id = h:insert(callback, event_time);
+ params[id] = param;
+ if next_time == nil or event_time < next_time then
+ next_time = event_time;
+ if _server_timer then
+ _server_timer:close();
+ _server_timer = nil;
+ else
+ _active_timers = _active_timers + 1;
+ end
+ _server_timer = _add_task(next_time - current_time, _on_timer);
+ end
+ return id;
+end
+local function stop(id)
+ params[id] = nil;
+ local result, item, result_sync = h:remove(id);
+ local peek = h:peek();
+ if peek ~= next_time and _server_timer then
+ next_time = peek;
+ _server_timer:close();
+ if next_time ~= nil then
+ _server_timer = _add_task(math_max(next_time - get_time(), 0), _on_timer);
end
- , delay);
end
+ return result, item, result_sync;
+end
+local function reschedule(id, delay)
+ local current_time = get_time();
+ local event_time = current_time + delay;
+ h:reprioritize(id, delay);
+ if next_time == nil or event_time < next_time then
+ next_time = event_time;
+ _add_task(next_time - current_time, _on_timer);
+ end
+ return id;
end
return {
- add_task = _add_task;
+ add_task = add_task;
+ stop = stop;
+ reschedule = reschedule;
};
+
diff --git a/util/vcard.lua b/util/vcard.lua
new file mode 100644
index 00000000..bb299fab
--- /dev/null
+++ b/util/vcard.lua
@@ -0,0 +1,574 @@
+-- Copyright (C) 2011-2014 Kim Alvefur
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+-- TODO
+-- Fix folding.
+
+local st = require "util.stanza";
+local t_insert, t_concat = table.insert, table.concat;
+local type = type;
+local pairs, ipairs = pairs, ipairs;
+
+local from_text, to_text, from_xep54, to_xep54;
+
+local line_sep = "\n";
+
+local vCard_dtd; -- See end of file
+local vCard4_dtd;
+
+local function vCard_esc(s)
+ return s:gsub("[,:;\\]", "\\%1"):gsub("\n","\\n");
+end
+
+local function vCard_unesc(s)
+ return s:gsub("\\?[\\nt:;,]", {
+ ["\\\\"] = "\\",
+ ["\\n"] = "\n",
+ ["\\r"] = "\r",
+ ["\\t"] = "\t",
+ ["\\:"] = ":", -- FIXME Shouldn't need to espace : in values, just params
+ ["\\;"] = ";",
+ ["\\,"] = ",",
+ [":"] = "\29",
+ [";"] = "\30",
+ [","] = "\31",
+ });
+end
+
+local function item_to_xep54(item)
+ local t = st.stanza(item.name, { xmlns = "vcard-temp" });
+
+ local prop_def = vCard_dtd[item.name];
+ if prop_def == "text" then
+ t:text(item[1]);
+ elseif type(prop_def) == "table" then
+ if prop_def.types and item.TYPE then
+ if type(item.TYPE) == "table" then
+ for _,v in pairs(prop_def.types) do
+ for _,typ in pairs(item.TYPE) do
+ if typ:upper() == v then
+ t:tag(v):up();
+ break;
+ end
+ end
+ end
+ else
+ t:tag(item.TYPE:upper()):up();
+ end
+ end
+
+ if prop_def.props then
+ for _,prop in pairs(prop_def.props) do
+ if item[prop] then
+ for _, v in ipairs(item[prop]) do
+ t:text_tag(prop, v);
+ end
+ end
+ end
+ end
+
+ if prop_def.value then
+ t:text_tag(prop_def.value, item[1]);
+ elseif prop_def.values then
+ local prop_def_values = prop_def.values;
+ local repeat_last = prop_def_values.behaviour == "repeat-last" and prop_def_values[#prop_def_values];
+ for i=1,#item do
+ t:text_tag(prop_def.values[i] or repeat_last, item[i]);
+ end
+ end
+ end
+
+ return t;
+end
+
+local function vcard_to_xep54(vCard)
+ local t = st.stanza("vCard", { xmlns = "vcard-temp" });
+ for i=1,#vCard do
+ t:add_child(item_to_xep54(vCard[i]));
+ end
+ return t;
+end
+
+function to_xep54(vCards)
+ if not vCards[1] or vCards[1].name then
+ return vcard_to_xep54(vCards)
+ else
+ local t = st.stanza("xCard", { xmlns = "vcard-temp" });
+ for i=1,#vCards do
+ t:add_child(vcard_to_xep54(vCards[i]));
+ end
+ return t;
+ end
+end
+
+function from_text(data)
+ data = data -- unfold and remove empty lines
+ :gsub("\r\n","\n")
+ :gsub("\n ", "")
+ :gsub("\n\n+","\n");
+ local vCards = {};
+ local current;
+ for line in data:gmatch("[^\n]+") do
+ line = vCard_unesc(line);
+ local name, params, value = line:match("^([-%a]+)(\30?[^\29]*)\29(.*)$");
+ value = value:gsub("\29",":");
+ if #params > 0 then
+ local _params = {};
+ for k,isval,v in params:gmatch("\30([^=]+)(=?)([^\30]*)") do
+ k = k:upper();
+ local _vt = {};
+ for _p in v:gmatch("[^\31]+") do
+ _vt[#_vt+1]=_p
+ _vt[_p]=true;
+ end
+ if isval == "=" then
+ _params[k]=_vt;
+ else
+ _params[k]=true;
+ end
+ end
+ params = _params;
+ end
+ if name == "BEGIN" and value == "VCARD" then
+ current = {};
+ vCards[#vCards+1] = current;
+ elseif name == "END" and value == "VCARD" then
+ current = nil;
+ elseif current and vCard_dtd[name] then
+ local dtd = vCard_dtd[name];
+ local item = { name = name };
+ t_insert(current, item);
+ local up = current;
+ current = item;
+ if dtd.types then
+ for _, t in ipairs(dtd.types) do
+ t = t:lower();
+ if ( params.TYPE and params.TYPE[t] == true)
+ or params[t] == true then
+ current.TYPE=t;
+ end
+ end
+ end
+ if dtd.props then
+ for _, p in ipairs(dtd.props) do
+ if params[p] then
+ if params[p] == true then
+ current[p]=true;
+ else
+ for _, prop in ipairs(params[p]) do
+ current[p]=prop;
+ end
+ end
+ end
+ end
+ end
+ if dtd == "text" or dtd.value then
+ t_insert(current, value);
+ elseif dtd.values then
+ for p in ("\30"..value):gmatch("\30([^\30]*)") do
+ t_insert(current, p);
+ end
+ end
+ current = up;
+ end
+ end
+ return vCards;
+end
+
+local function item_to_text(item)
+ local value = {};
+ for i=1,#item do
+ value[i] = vCard_esc(item[i]);
+ end
+ value = t_concat(value, ";");
+
+ local params = "";
+ for k,v in pairs(item) do
+ if type(k) == "string" and k ~= "name" then
+ params = params .. (";%s=%s"):format(k, type(v) == "table" and t_concat(v,",") or v);
+ end
+ end
+
+ return ("%s%s:%s"):format(item.name, params, value)
+end
+
+local function vcard_to_text(vcard)
+ local t={};
+ t_insert(t, "BEGIN:VCARD")
+ for i=1,#vcard do
+ t_insert(t, item_to_text(vcard[i]));
+ end
+ t_insert(t, "END:VCARD")
+ return t_concat(t, line_sep);
+end
+
+function to_text(vCards)
+ if vCards[1] and vCards[1].name then
+ return vcard_to_text(vCards)
+ else
+ local t = {};
+ for i=1,#vCards do
+ t[i]=vcard_to_text(vCards[i]);
+ end
+ return t_concat(t, line_sep);
+ end
+end
+
+local function from_xep54_item(item)
+ local prop_name = item.name;
+ local prop_def = vCard_dtd[prop_name];
+
+ local prop = { name = prop_name };
+
+ if prop_def == "text" then
+ prop[1] = item:get_text();
+ elseif type(prop_def) == "table" then
+ if prop_def.value then --single item
+ prop[1] = item:get_child_text(prop_def.value) or "";
+ elseif prop_def.values then --array
+ local value_names = prop_def.values;
+ if value_names.behaviour == "repeat-last" then
+ for i=1,#item.tags do
+ t_insert(prop, item.tags[i]:get_text() or "");
+ end
+ else
+ for i=1,#value_names do
+ t_insert(prop, item:get_child_text(value_names[i]) or "");
+ end
+ end
+ elseif prop_def.names then
+ local names = prop_def.names;
+ for i=1,#names do
+ if item:get_child(names[i]) then
+ prop[1] = names[i];
+ break;
+ end
+ end
+ end
+
+ if prop_def.props_verbatim then
+ for k,v in pairs(prop_def.props_verbatim) do
+ prop[k] = v;
+ end
+ end
+
+ if prop_def.types then
+ local types = prop_def.types;
+ prop.TYPE = {};
+ for i=1,#types do
+ if item:get_child(types[i]) then
+ t_insert(prop.TYPE, types[i]:lower());
+ end
+ end
+ if #prop.TYPE == 0 then
+ prop.TYPE = nil;
+ end
+ end
+
+ -- A key-value pair, within a key-value pair?
+ if prop_def.props then
+ local params = prop_def.props;
+ for i=1,#params do
+ local name = params[i]
+ local data = item:get_child_text(name);
+ if data then
+ prop[name] = prop[name] or {};
+ t_insert(prop[name], data);
+ end
+ end
+ end
+ else
+ return nil
+ end
+
+ return prop;
+end
+
+local function from_xep54_vCard(vCard)
+ local tags = vCard.tags;
+ local t = {};
+ for i=1,#tags do
+ t_insert(t, from_xep54_item(tags[i]));
+ end
+ return t
+end
+
+function from_xep54(vCard)
+ if vCard.attr.xmlns ~= "vcard-temp" then
+ return nil, "wrong-xmlns";
+ end
+ if vCard.name == "xCard" then -- A collection of vCards
+ local t = {};
+ local vCards = vCard.tags;
+ for i=1,#vCards do
+ t[i] = from_xep54_vCard(vCards[i]);
+ end
+ return t
+ elseif vCard.name == "vCard" then -- A single vCard
+ return from_xep54_vCard(vCard)
+ end
+end
+
+local vcard4 = { }
+
+function vcard4:text(node, params, value) -- luacheck: ignore 212/params
+ self:tag(node:lower())
+ -- FIXME params
+ if type(value) == "string" then
+ self:text_tag("text", value);
+ elseif vcard4[node] then
+ vcard4[node](value);
+ end
+ self:up();
+end
+
+function vcard4.N(value)
+ for i, k in ipairs(vCard_dtd.N.values) do
+ value:text_tag(k, value[i]);
+ end
+end
+
+local xmlns_vcard4 = "urn:ietf:params:xml:ns:vcard-4.0"
+
+local function item_to_vcard4(item)
+ local typ = item.name:lower();
+ local t = st.stanza(typ, { xmlns = xmlns_vcard4 });
+
+ local prop_def = vCard4_dtd[typ];
+ if prop_def == "text" then
+ t:text_tag("text", item[1]);
+ elseif prop_def == "uri" then
+ if item.ENCODING and item.ENCODING[1] == 'b' then
+ t:text_tag("uri", "data:;base64," .. item[1]);
+ else
+ t:text_tag("uri", item[1]);
+ end
+ elseif type(prop_def) == "table" then
+ if prop_def.values then
+ for i, v in ipairs(prop_def.values) do
+ t:text_tag(v:lower(), item[i]);
+ end
+ else
+ t:tag("unsupported",{xmlns="http://zash.se/protocol/vcardlib"})
+ end
+ else
+ t:tag("unsupported",{xmlns="http://zash.se/protocol/vcardlib"})
+ end
+ return t;
+end
+
+local function vcard_to_vcard4xml(vCard)
+ local t = st.stanza("vcard", { xmlns = xmlns_vcard4 });
+ for i=1,#vCard do
+ t:add_child(item_to_vcard4(vCard[i]));
+ end
+ return t;
+end
+
+local function vcards_to_vcard4xml(vCards)
+ if not vCards[1] or vCards[1].name then
+ return vcard_to_vcard4xml(vCards)
+ else
+ local t = st.stanza("vcards", { xmlns = xmlns_vcard4 });
+ for i=1,#vCards do
+ t:add_child(vcard_to_vcard4xml(vCards[i]));
+ end
+ return t;
+ end
+end
+
+-- This was adapted from http://xmpp.org/extensions/xep-0054.html#dtd
+vCard_dtd = {
+ VERSION = "text", --MUST be 3.0, so parsing is redundant
+ FN = "text",
+ N = {
+ values = {
+ "FAMILY",
+ "GIVEN",
+ "MIDDLE",
+ "PREFIX",
+ "SUFFIX",
+ },
+ },
+ NICKNAME = "text",
+ PHOTO = {
+ props_verbatim = { ENCODING = { "b" } },
+ props = { "TYPE" },
+ value = "BINVAL", --{ "EXTVAL", },
+ },
+ BDAY = "text",
+ ADR = {
+ types = {
+ "HOME",
+ "WORK",
+ "POSTAL",
+ "PARCEL",
+ "DOM",
+ "INTL",
+ "PREF",
+ },
+ values = {
+ "POBOX",
+ "EXTADD",
+ "STREET",
+ "LOCALITY",
+ "REGION",
+ "PCODE",
+ "CTRY",
+ }
+ },
+ LABEL = {
+ types = {
+ "HOME",
+ "WORK",
+ "POSTAL",
+ "PARCEL",
+ "DOM",
+ "INTL",
+ "PREF",
+ },
+ value = "LINE",
+ },
+ TEL = {
+ types = {
+ "HOME",
+ "WORK",
+ "VOICE",
+ "FAX",
+ "PAGER",
+ "MSG",
+ "CELL",
+ "VIDEO",
+ "BBS",
+ "MODEM",
+ "ISDN",
+ "PCS",
+ "PREF",
+ },
+ value = "NUMBER",
+ },
+ EMAIL = {
+ types = {
+ "HOME",
+ "WORK",
+ "INTERNET",
+ "PREF",
+ "X400",
+ },
+ value = "USERID",
+ },
+ JABBERID = "text",
+ MAILER = "text",
+ TZ = "text",
+ GEO = {
+ values = {
+ "LAT",
+ "LON",
+ },
+ },
+ TITLE = "text",
+ ROLE = "text",
+ LOGO = "copy of PHOTO",
+ AGENT = "text",
+ ORG = {
+ values = {
+ behaviour = "repeat-last",
+ "ORGNAME",
+ "ORGUNIT",
+ }
+ },
+ CATEGORIES = {
+ values = "KEYWORD",
+ },
+ NOTE = "text",
+ PRODID = "text",
+ REV = "text",
+ SORTSTRING = "text",
+ SOUND = "copy of PHOTO",
+ UID = "text",
+ URL = "text",
+ CLASS = {
+ names = { -- The item.name is the value if it's one of these.
+ "PUBLIC",
+ "PRIVATE",
+ "CONFIDENTIAL",
+ },
+ },
+ KEY = {
+ props = { "TYPE" },
+ value = "CRED",
+ },
+ DESC = "text",
+};
+vCard_dtd.LOGO = vCard_dtd.PHOTO;
+vCard_dtd.SOUND = vCard_dtd.PHOTO;
+
+vCard4_dtd = {
+ source = "uri",
+ kind = "text",
+ xml = "text",
+ fn = "text",
+ n = {
+ values = {
+ "family",
+ "given",
+ "middle",
+ "prefix",
+ "suffix",
+ },
+ },
+ nickname = "text",
+ photo = "uri",
+ bday = "date-and-or-time",
+ anniversary = "date-and-or-time",
+ gender = "text",
+ adr = {
+ values = {
+ "pobox",
+ "ext",
+ "street",
+ "locality",
+ "region",
+ "code",
+ "country",
+ }
+ },
+ tel = "text",
+ email = "text",
+ impp = "uri",
+ lang = "language-tag",
+ tz = "text",
+ geo = "uri",
+ title = "text",
+ role = "text",
+ logo = "uri",
+ org = "text",
+ member = "uri",
+ related = "uri",
+ categories = "text",
+ note = "text",
+ prodid = "text",
+ rev = "timestamp",
+ sound = "uri",
+ uid = "uri",
+ clientpidmap = "number, uuid",
+ url = "uri",
+ version = "text",
+ key = "uri",
+ fburl = "uri",
+ caladruri = "uri",
+ caluri = "uri",
+};
+
+return {
+ from_text = from_text;
+ to_text = to_text;
+
+ from_xep54 = from_xep54;
+ to_xep54 = to_xep54;
+
+ to_vcard4 = vcards_to_vcard4xml;
+};
diff --git a/util/watchdog.lua b/util/watchdog.lua
index aa8c6486..516e60e4 100644
--- a/util/watchdog.lua
+++ b/util/watchdog.lua
@@ -3,6 +3,7 @@ local setmetatable = setmetatable;
local os_time = os.time;
local _ENV = nil;
+-- luacheck: std none
local watchdog_methods = {};
local watchdog_mt = { __index = watchdog_methods };
diff --git a/util/x509.lua b/util/x509.lua
index f228b201..15cc4d3c 100644
--- a/util/x509.lua
+++ b/util/x509.lua
@@ -25,6 +25,7 @@ local log = require "util.logger".init("x509");
local s_format = string.format;
local _ENV = nil;
+-- luacheck: std none
local oid_commonname = "2.5.4.3"; -- [LDAP] 2.3
local oid_subjectaltname = "2.5.29.17"; -- [PKIX] 4.2.1.6
diff --git a/util/xml.lua b/util/xml.lua
index 733d821a..dac3f6fe 100644
--- a/util/xml.lua
+++ b/util/xml.lua
@@ -1,8 +1,11 @@
local st = require "util.stanza";
local lxp = require "lxp";
+local t_insert = table.insert;
+local t_remove = table.remove;
local _ENV = nil;
+-- luacheck: std none
local parse_xml = (function()
local ns_prefixes = {
@@ -14,6 +17,21 @@ local parse_xml = (function()
--luacheck: ignore 212/self
local handler = {};
local stanza = st.stanza("root");
+ local namespaces = {};
+ local prefixes = {};
+ function handler:StartNamespaceDecl(prefix, url)
+ if prefix ~= nil then
+ t_insert(namespaces, url);
+ t_insert(prefixes, prefix);
+ end
+ end
+ function handler:EndNamespaceDecl(prefix)
+ if prefix ~= nil then
+ -- we depend on each StartNamespaceDecl having a paired EndNamespaceDecl
+ t_remove(namespaces);
+ t_remove(prefixes);
+ end
+ end
function handler:StartElement(tagname, attr)
local curr_ns,name = tagname:match(ns_pattern);
if name == "" then
@@ -34,7 +52,11 @@ local parse_xml = (function()
end
end
end
- stanza:tag(name, attr);
+ local n = {}
+ for i=1,#namespaces do
+ n[prefixes[i]] = namespaces[i];
+ end
+ stanza:tag(name, attr, n);
end
function handler:CharacterData(data)
stanza:text(data);
diff --git a/util/xmppstream.lua b/util/xmppstream.lua
index 7be63285..f245afbf 100644
--- a/util/xmppstream.lua
+++ b/util/xmppstream.lua
@@ -25,6 +25,7 @@ local lxp_supports_bytecount = not not lxp.new({}).getcurrentbytecount;
local default_stanza_size_limit = 1024*1024*10; -- 10MB
local _ENV = nil;
+-- luacheck: std none
local new_parser = lxp.new;
@@ -47,7 +48,10 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
local cb_streamopened = stream_callbacks.streamopened;
local cb_streamclosed = stream_callbacks.streamclosed;
- local cb_error = stream_callbacks.error or function(session, e, stanza) error("XML stream error: "..tostring(e)..(stanza and ": "..tostring(stanza) or ""),2); end;
+ local cb_error = stream_callbacks.error or
+ function(_, e, stanza)
+ error("XML stream error: "..tostring(e)..(stanza and ": "..tostring(stanza) or ""),2);
+ end;
local cb_handlestanza = stream_callbacks.handlestanza;
cb_handleprogress = cb_handleprogress or dummy_cb;
@@ -126,13 +130,7 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
t_insert(oldstanza.tags, stanza);
end
end
- if lxp_supports_xmldecl then
- function xml_handlers:XmlDecl(version, encoding, standalone)
- if lxp_supports_bytecount then
- cb_handleprogress(self:getcurrentbytecount());
- end
- end
- end
+
function xml_handlers:StartCdataSection()
if lxp_supports_bytecount then
if stanza then
@@ -203,6 +201,18 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
end
end
+ if lxp_supports_xmldecl then
+ function xml_handlers:XmlDecl(version, encoding, standalone)
+ if lxp_supports_bytecount then
+ cb_handleprogress(self:getcurrentbytecount());
+ end
+ if (encoding and encoding:lower() ~= "utf-8")
+ or (standalone == "no")
+ or (version and version ~= "1.0") then
+ return restricted_handler(self);
+ end
+ end
+ end
if lxp_supports_doctype then
xml_handlers.StartDoctypeDecl = restricted_handler;
end
@@ -214,7 +224,7 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
stack = {};
end
- local function set_session(stream, new_session)
+ local function set_session(stream, new_session) -- luacheck: ignore 212/stream
session = new_session;
end
@@ -238,7 +248,7 @@ local function new(session, stream_callbacks, stanza_size_limit)
local parser = new_parser(handlers, ns_separator, false);
local parse = parser.parse;
- function session.open_stream(session, from, to)
+ function session.open_stream(session, from, to) -- luacheck: ignore 432/session
local send = session.sends2s or session.send;
local attr = {
@@ -264,7 +274,7 @@ local function new(session, stream_callbacks, stanza_size_limit)
n_outstanding_bytes = 0;
meta.reset();
end,
- feed = function (self, data)
+ feed = function (self, data) -- luacheck: ignore 212/self
if lxp_supports_bytecount then
n_outstanding_bytes = n_outstanding_bytes + #data;
end