aboutsummaryrefslogtreecommitdiffstats
path: root/util
diff options
context:
space:
mode:
Diffstat (limited to 'util')
-rw-r--r--util/async.lua16
-rw-r--r--util/bit53.lua7
-rw-r--r--util/bitcompat.lua32
-rw-r--r--util/datamanager.lua2
-rw-r--r--util/dependencies.lua4
-rw-r--r--util/error.lua53
-rw-r--r--util/format.lua27
-rw-r--r--util/hashring.lua88
-rw-r--r--util/hmac.lua9
-rw-r--r--util/http.lua22
-rw-r--r--util/import.lua2
-rw-r--r--util/iterators.lua6
-rw-r--r--util/jid.lua12
-rw-r--r--util/multitable.lua2
-rw-r--r--util/paths.lua16
-rw-r--r--util/pluginloader.lua3
-rw-r--r--util/promise.lua3
-rw-r--r--util/prosodyctl.lua47
-rw-r--r--util/pubsub.lua12
-rw-r--r--util/queue.lua12
-rw-r--r--util/sasl/scram.lua48
-rw-r--r--util/serialization.lua27
-rw-r--r--util/session.lua7
-rw-r--r--util/sql.lua14
-rw-r--r--util/stanza.lua76
-rw-r--r--util/startup.lua35
-rw-r--r--util/statistics.lua4
-rw-r--r--util/x509.lua58
-rw-r--r--util/xmppstream.lua6
29 files changed, 513 insertions, 137 deletions
diff --git a/util/async.lua b/util/async.lua
index 20397785..d338071f 100644
--- a/util/async.lua
+++ b/util/async.lua
@@ -246,9 +246,25 @@ local function ready()
return pcall(checkthread);
end
+local function wait(promise)
+ local async_wait, async_done = waiter();
+ local ret, err = nil, nil;
+ promise:next(
+ function (r) ret = r; end,
+ function (e) err = e; end)
+ :finally(async_done);
+ async_wait();
+ if ret then
+ return ret;
+ else
+ return nil, err;
+ end
+end
+
return {
ready = ready;
waiter = waiter;
guarder = guarder;
runner = runner;
+ wait = wait;
};
diff --git a/util/bit53.lua b/util/bit53.lua
new file mode 100644
index 00000000..4b5c2e9c
--- /dev/null
+++ b/util/bit53.lua
@@ -0,0 +1,7 @@
+-- Only the operators needed by net.websocket.frames are provided at this point
+return {
+ band = function (a, b) return a & b end;
+ bor = function (a, b) return a | b end;
+ bxor = function (a, b) return a ~ b end;
+};
+
diff --git a/util/bitcompat.lua b/util/bitcompat.lua
new file mode 100644
index 00000000..454181af
--- /dev/null
+++ b/util/bitcompat.lua
@@ -0,0 +1,32 @@
+-- Compatibility layer for bitwise operations
+
+-- First try the bit32 lib
+-- Lua 5.3 has it with compat enabled
+-- Lua 5.2 has it by default
+if _G.bit32 then
+ return _G.bit32;
+else
+ -- Lua 5.1 may have it as a standalone module that can be installed
+ local ok, bitop = pcall(require, "bit32")
+ if ok then
+ return bitop;
+ end
+end
+
+do
+ -- Lua 5.3 and 5.4 would be able to use native infix operators
+ local ok, bitop = pcall(require, "util.bit53")
+ if ok then
+ return bitop;
+ end
+end
+
+do
+ -- Lastly, try the LuaJIT bitop library
+ local ok, bitop = pcall(require, "bit")
+ if ok then
+ return bitop;
+ end
+end
+
+error "No bit module found. See https://prosody.im/doc/depends#bitop";
diff --git a/util/datamanager.lua b/util/datamanager.lua
index cf96887b..b52c77fa 100644
--- a/util/datamanager.lua
+++ b/util/datamanager.lua
@@ -24,7 +24,7 @@ local t_concat = table.concat;
local envloadfile = require"util.envload".envloadfile;
local serialize = require "util.serialization".serialize;
local lfs = require "lfs";
--- Extract directory seperator from package.config (an undocumented string that comes with lua)
+-- Extract directory separator from package.config (an undocumented string that comes with lua)
local path_separator = assert ( package.config:match ( "^([^\n]+)" ) , "package.config not in standard form" )
local prosody = prosody;
diff --git a/util/dependencies.lua b/util/dependencies.lua
index 7c7b938e..22b66d7c 100644
--- a/util/dependencies.lua
+++ b/util/dependencies.lua
@@ -90,7 +90,7 @@ local function check_dependencies()
}, "SSL/TLS support will not be available");
end
- local bit = _G.bit32 or softreq"bit";
+ local bit = softreq"util.bitcompat";
if not bit then
missingdep("lua-bitops", {
@@ -140,7 +140,7 @@ local function check_dependencies()
end
local function log_warnings()
- if _VERSION > "Lua 5.2" then
+ if _VERSION > "Lua 5.3" then
prosody.log("warn", "Support for %s is experimental, please report any issues", _VERSION);
end
local ssl = softreq"ssl";
diff --git a/util/error.lua b/util/error.lua
new file mode 100644
index 00000000..9ebfa6ab
--- /dev/null
+++ b/util/error.lua
@@ -0,0 +1,53 @@
+local error_mt = { __name = "error" };
+
+function error_mt:__tostring()
+ return ("error<%s:%s:%s>"):format(self.type, self.condition, self.text or "");
+end
+
+local function is_err(e)
+ return getmetatable(e) == error_mt;
+end
+
+local function new(e, context, registry)
+ local template = (registry and registry[e]) or e or {};
+ return setmetatable({
+ type = template.type or "cancel";
+ condition = template.condition or "undefined-condition";
+ text = template.text;
+ code = template.code or 500;
+
+ context = context or template.context or { _error_id = e };
+ }, error_mt);
+end
+
+local function coerce(ok, err, ...)
+ if ok or is_err(err) then
+ return ok, err, ...;
+ end
+
+ local new_err = setmetatable({
+ native = err;
+
+ type = "cancel";
+ condition = "undefined-condition";
+ }, error_mt);
+ return ok, new_err, ...;
+end
+
+local function from_stanza(stanza, context)
+ local error_type, condition, text = stanza:get_error();
+ return setmetatable({
+ type = error_type or "cancel";
+ condition = condition or "undefined-condition";
+ text = text;
+
+ context = context or { stanza = stanza };
+ }, error_mt);
+end
+
+return {
+ new = new;
+ coerce = coerce;
+ is_err = is_err;
+ from_stanza = from_stanza;
+}
diff --git a/util/format.lua b/util/format.lua
index c5e513fa..1ce670f3 100644
--- a/util/format.lua
+++ b/util/format.lua
@@ -3,12 +3,20 @@
--
local tostring = tostring;
-local select = select;
local unpack = table.unpack or unpack; -- luacheck: ignore 113/unpack
+local pack = require "util.table".pack; -- TODO table.pack in 5.2+
local type = type;
+local dump = require "util.serialization".new("debug");
+local num_type = math.type or function (n)
+ return n % 1 == 0 and n <= 9007199254740992 and n >= -9007199254740992 and "integer" or "float";
+end
+
+-- In Lua 5.3+ these formats throw an error if given a float
+local expects_integer = { c = true, d = true, i = true, o = true, u = true, X = true, x = true, };
local function format(formatstring, ...)
- local args, args_length = { ... }, select('#', ...);
+ local args = pack(...);
+ local args_length = args.n;
-- format specifier spec:
-- 1. Start: '%%'
@@ -28,17 +36,22 @@ local function format(formatstring, ...)
if spec ~= "%%" then
i = i + 1;
local arg = args[i];
- if arg == nil then -- special handling for nil
- arg = "<nil>"
- args[i] = "<nil>";
- end
local option = spec:sub(-1);
- if option == "q" or option == "s" then -- arg should be string
+ if arg == nil then
+ args[i] = "nil";
+ spec = "<%s>";
+ elseif option == "q" then
+ args[i] = dump(arg);
+ spec = "%s";
+ elseif option == "s" then
args[i] = tostring(arg);
elseif type(arg) ~= "number" then -- arg isn't number as expected?
args[i] = tostring(arg);
spec = "[%s]";
+ elseif expects_integer[option] and num_type(arg) ~= "integer" then
+ args[i] = tostring(arg);
+ spec = "[%s]";
end
end
return spec;
diff --git a/util/hashring.lua b/util/hashring.lua
new file mode 100644
index 00000000..322bc005
--- /dev/null
+++ b/util/hashring.lua
@@ -0,0 +1,88 @@
+local function generate_ring(nodes, num_replicas, hash)
+ local new_ring = {};
+ for _, node_name in ipairs(nodes) do
+ for replica = 1, num_replicas do
+ local replica_hash = hash(node_name..":"..replica);
+ new_ring[replica_hash] = node_name;
+ table.insert(new_ring, replica_hash);
+ end
+ end
+ table.sort(new_ring);
+ return new_ring;
+end
+
+local hashring_methods = {};
+local hashring_mt = {
+ __index = function (self, k)
+ -- Automatically build self.ring if it's missing
+ if k == "ring" then
+ local ring = generate_ring(self.nodes, self.num_replicas, self.hash);
+ rawset(self, "ring", ring);
+ return ring;
+ end
+ return rawget(hashring_methods, k);
+ end
+};
+
+local function new(num_replicas, hash_function)
+ return setmetatable({ nodes = {}, num_replicas = num_replicas, hash = hash_function }, hashring_mt);
+end;
+
+function hashring_methods:add_node(name)
+ self.ring = nil;
+ self.nodes[name] = true;
+ table.insert(self.nodes, name);
+ return true;
+end
+
+function hashring_methods:add_nodes(nodes)
+ self.ring = nil;
+ for _, node_name in ipairs(nodes) do
+ if not self.nodes[node_name] then
+ self.nodes[node_name] = true;
+ table.insert(self.nodes, node_name);
+ end
+ end
+ return true;
+end
+
+function hashring_methods:remove_node(node_name)
+ self.ring = nil;
+ if self.nodes[node_name] then
+ for i, stored_node_name in ipairs(self.nodes) do
+ if node_name == stored_node_name then
+ self.nodes[node_name] = nil;
+ table.remove(self.nodes, i);
+ return true;
+ end
+ end
+ end
+ return false;
+end
+
+function hashring_methods:remove_nodes(nodes)
+ self.ring = nil;
+ for _, node_name in ipairs(nodes) do
+ self:remove_node(node_name);
+ end
+end
+
+function hashring_methods:clone()
+ local clone_hashring = new(self.num_replicas, self.hash);
+ clone_hashring:add_nodes(self.nodes);
+ return clone_hashring;
+end
+
+function hashring_methods:get_node(key)
+ local key_hash = self.hash(key);
+ for _, replica_hash in ipairs(self.ring) do
+ if key_hash < replica_hash then
+ return self.ring[replica_hash];
+ end
+ end
+ return self.ring[self.ring[1]];
+end
+
+return {
+ new = new;
+}
diff --git a/util/hmac.lua b/util/hmac.lua
index 2c4cc6ef..4cad17cc 100644
--- a/util/hmac.lua
+++ b/util/hmac.lua
@@ -10,6 +10,9 @@
local hashes = require "util.hashes"
-return { md5 = hashes.hmac_md5,
- sha1 = hashes.hmac_sha1,
- sha256 = hashes.hmac_sha256 };
+return {
+ md5 = hashes.hmac_md5,
+ sha1 = hashes.hmac_sha1,
+ sha256 = hashes.hmac_sha256,
+ sha512 = hashes.hmac_sha512,
+};
diff --git a/util/http.lua b/util/http.lua
index cfb89193..3852f91c 100644
--- a/util/http.lua
+++ b/util/http.lua
@@ -6,24 +6,26 @@
--
local format, char = string.format, string.char;
-local pairs, ipairs, tonumber = pairs, ipairs, tonumber;
+local pairs, ipairs = pairs, ipairs;
local t_insert, t_concat = table.insert, table.concat;
+local url_codes = {};
+for i = 0, 255 do
+ local c = char(i);
+ local u = format("%%%02x", i);
+ url_codes[c] = u;
+ url_codes[u] = c;
+ url_codes[u:upper()] = c;
+end
local function urlencode(s)
- return s and (s:gsub("[^a-zA-Z0-9.~_-]", function (c) return format("%%%02x", c:byte()); end));
+ return s and (s:gsub("[^a-zA-Z0-9.~_-]", url_codes));
end
local function urldecode(s)
- return s and (s:gsub("%%(%x%x)", function (c) return char(tonumber(c,16)); end));
+ return s and (s:gsub("%%%x%x", url_codes));
end
local function _formencodepart(s)
- return s and (s:gsub("%W", function (c)
- if c ~= " " then
- return format("%%%02x", c:byte());
- else
- return "+";
- end
- end));
+ return s and (urlencode(s):gsub("%%20", "+"));
end
local function formencode(form)
diff --git a/util/import.lua b/util/import.lua
index 8ecfe43c..1007bc0a 100644
--- a/util/import.lua
+++ b/util/import.lua
@@ -8,7 +8,7 @@
-local unpack = table.unpack or unpack; --luacheck: ignore 113 143
+local unpack = table.unpack or unpack; --luacheck: ignore 113
local t_insert = table.insert;
function _G.import(module, ...)
local m = package.loaded[module] or require(module);
diff --git a/util/iterators.lua b/util/iterators.lua
index 302cca36..c03c2fd6 100644
--- a/util/iterators.lua
+++ b/util/iterators.lua
@@ -11,9 +11,9 @@
local it = {};
local t_insert = table.insert;
-local select, next = select, next;
-local unpack = table.unpack or unpack; --luacheck: ignore 113 143
-local pack = table.pack or function (...) return { n = select("#", ...), ... }; end -- luacheck: ignore 143
+local next = next;
+local unpack = table.unpack or unpack; --luacheck: ignore 113
+local pack = table.pack or require "util.table".pack;
local type = type;
local table, setmetatable = table, setmetatable;
diff --git a/util/jid.lua b/util/jid.lua
index ec31f180..1ddf33d4 100644
--- a/util/jid.lua
+++ b/util/jid.lua
@@ -45,20 +45,20 @@ local function bare(jid)
return host;
end
-local function prepped_split(jid)
+local function prepped_split(jid, strict)
local node, host, resource = split(jid);
if host and host ~= "." then
if sub(host, -1, -1) == "." then -- Strip empty root label
host = sub(host, 1, -2);
end
- host = nameprep(host);
+ host = nameprep(host, strict);
if not host then return; end
if node then
- node = nodeprep(node);
+ node = nodeprep(node, strict);
if not node then return; end
end
if resource then
- resource = resourceprep(resource);
+ resource = resourceprep(resource, strict);
if not resource then return; end
end
return node, host, resource;
@@ -77,8 +77,8 @@ local function join(node, host, resource)
return host;
end
-local function prep(jid)
- local node, host, resource = prepped_split(jid);
+local function prep(jid, strict)
+ local node, host, resource = prepped_split(jid, strict);
return join(node, host, resource);
end
diff --git a/util/multitable.lua b/util/multitable.lua
index 8d32ed8a..4f2cd972 100644
--- a/util/multitable.lua
+++ b/util/multitable.lua
@@ -9,7 +9,7 @@
local select = select;
local t_insert = table.insert;
local pairs, next, type = pairs, next, type;
-local unpack = table.unpack or unpack; --luacheck: ignore 113 143
+local unpack = table.unpack or unpack; --luacheck: ignore 113
local _ENV = nil;
-- luacheck: std none
diff --git a/util/paths.lua b/util/paths.lua
index 89f4cad9..036f315a 100644
--- a/util/paths.lua
+++ b/util/paths.lua
@@ -41,4 +41,20 @@ function path_util.join(...)
return t_concat({...}, path_sep);
end
+function path_util.complement_lua_path(installer_plugin_path)
+ -- Checking for duplicates
+ -- The commands using luarocks need the path to the directory that has the /share and /lib folders.
+ local lua_version = _VERSION:match(" (.+)$");
+ local lua_path_sep = package.config:sub(3,3);
+ local dir_sep = package.config:sub(1,1);
+ local sub_path = dir_sep.."lua"..dir_sep..lua_version..dir_sep;
+ if not string.find(package.path, installer_plugin_path, 1, true) then
+ package.path = package.path..lua_path_sep..installer_plugin_path..dir_sep.."share"..sub_path.."?.lua";
+ package.path = package.path..lua_path_sep..installer_plugin_path..dir_sep.."share"..sub_path.."?"..dir_sep.."init.lua";
+ end
+ if not string.find(package.path, installer_plugin_path, 1, true) then
+ package.cpath = package.cpath..lua_path_sep..installer_plugin_path..dir_sep.."lib"..sub_path.."?.so";
+ end
+end
+
return path_util;
diff --git a/util/pluginloader.lua b/util/pluginloader.lua
index 9ab8f245..af0428c4 100644
--- a/util/pluginloader.lua
+++ b/util/pluginloader.lua
@@ -36,12 +36,13 @@ end
local function load_resource(plugin, resource)
resource = resource or "mod_"..plugin..".lua";
-
+ local lua_version = _VERSION:match(" (.+)$");
local names = {
"mod_"..plugin..dir_sep..plugin..dir_sep..resource; -- mod_hello/hello/mod_hello.lua
"mod_"..plugin..dir_sep..resource; -- mod_hello/mod_hello.lua
plugin..dir_sep..resource; -- hello/mod_hello.lua
resource; -- mod_hello.lua
+ "share"..dir_sep.."lua"..dir_sep..lua_version..dir_sep.."mod_"..plugin..dir_sep..resource;
};
return load_file(names);
diff --git a/util/promise.lua b/util/promise.lua
index 07c9c4dc..0b182b54 100644
--- a/util/promise.lua
+++ b/util/promise.lua
@@ -49,6 +49,9 @@ local function promise_settle(promise, new_state, new_next, cbs, value)
for _, cb in ipairs(cbs) do
cb(value);
end
+ -- No need to keep references to callbacks
+ promise._pending_on_fulfilled = nil;
+ promise._pending_on_rejected = nil;
return true;
end
diff --git a/util/prosodyctl.lua b/util/prosodyctl.lua
index 5f0c4d12..586802d3 100644
--- a/util/prosodyctl.lua
+++ b/util/prosodyctl.lua
@@ -39,6 +39,16 @@ local function show_usage(usage, desc)
end
end
+local function show_module_configuration_help(mod_name)
+ print("Done.")
+ print("If you installed a prosody plugin, don't forget to add its name under the 'modules_enabled' section inside your configuration file.")
+ print("Depending on the module, there might be further configuration steps required.")
+ print("")
+ print("More info about: ")
+ print(" modules_enabled: https://prosody.im/doc/modules_enabled")
+ print(" "..mod_name..": https://modules.prosody.im/"..mod_name..".html")
+end
+
local function getchar(n)
local stty_ret = os.execute("stty raw -echo 2>/dev/null");
local ok, char;
@@ -124,7 +134,7 @@ end
-- Server control
local function adduser(params)
- local user, host, password = nodeprep(params.user), nameprep(params.host), params.password;
+ local user, host, password = nodeprep(params.user, true), nameprep(params.host), params.password;
if not user then
return false, "invalid-username";
elseif not host then
@@ -229,7 +239,8 @@ local function isrunning()
return true, signal.kill(pid, 0) == 0;
end
-local function start(source_dir)
+local function start(source_dir, lua)
+ lua = lua and lua .. " " or "";
local ok, ret = isrunning();
if not ok then
return ok, ret;
@@ -238,9 +249,9 @@ local function start(source_dir)
return false, "already-running";
end
if not source_dir then
- os.execute("./prosody");
+ os.execute(lua .. "./prosody");
else
- os.execute(source_dir.."/../../bin/prosody");
+ os.execute(lua .. source_dir.."/../../bin/prosody");
end
return true;
end
@@ -277,10 +288,36 @@ local function reload()
return true;
end
+local function get_path_custom_plugins(plugin_paths)
+ -- I'm considering that the custom plugins' path is the first one at prosody.paths.plugins
+ -- luacheck: ignore 512
+ for path in plugin_paths:gmatch("[^;]+") do
+ return path;
+ end
+end
+
+local function call_luarocks(mod, operation)
+ local dir = get_path_custom_plugins(prosody.paths.plugins);
+ if operation == "install" then
+ show_message("Installing %s at %s", mod, dir);
+ elseif operation == "remove" then
+ show_message("Removing %s from %s", mod, dir);
+ end
+ if operation == "list" then
+ os.execute("luarocks list --tree='"..dir.."'")
+ else
+ os.execute("luarocks --tree='"..dir.."' --server='http://localhost/' "..operation.." "..mod);
+ end
+ if operation == "install" then
+ show_module_configuration_help(mod);
+ end
+end
+
return {
show_message = show_message;
show_warning = show_message;
show_usage = show_usage;
+ show_module_configuration_help = show_module_configuration_help;
getchar = getchar;
getline = getline;
getpass = getpass;
@@ -296,4 +333,6 @@ return {
start = start;
stop = stop;
reload = reload;
+ get_path_custom_plugins = get_path_custom_plugins;
+ call_luarocks = call_luarocks;
};
diff --git a/util/pubsub.lua b/util/pubsub.lua
index e5e0cb7c..8a07c669 100644
--- a/util/pubsub.lua
+++ b/util/pubsub.lua
@@ -1,5 +1,6 @@
local events = require "util.events";
local cache = require "util.cache";
+local errors = require "util.error";
local service_mt = {};
@@ -510,7 +511,7 @@ local function check_preconditions(node_config, required_config)
end
for config_field, value in pairs(required_config) do
if node_config[config_field] ~= value then
- return false;
+ return false, config_field;
end
end
return true;
@@ -546,8 +547,13 @@ function service:publish(node, actor, id, item, requested_config) --> ok, err
node_obj = self.nodes[node];
elseif requested_config and not requested_config._defaults_only then
-- Check that node has the requested config before we publish
- if not check_preconditions(node_obj.config, requested_config) then
- return false, "precondition-not-met";
+ local ok, field = check_preconditions(node_obj.config, requested_config);
+ if not ok then
+ local err = errors.new({
+ type = "cancel", condition = "conflict", text = "Field does not match: "..field;
+ });
+ err.pubsub_condition = "precondition-not-met";
+ return false, err;
end
end
if not self.config.itemcheck(item) then
diff --git a/util/queue.lua b/util/queue.lua
index 728e905f..66ed098b 100644
--- a/util/queue.lua
+++ b/util/queue.lua
@@ -52,18 +52,20 @@ local function new(size, allow_wrapping)
return t[tail];
end;
items = function (self)
- --luacheck: ignore 431/t
- return function (t, pos)
- if pos >= t:count() then
+ return function (_, pos)
+ if pos >= items then
return nil;
end
local read_pos = tail + pos;
- if read_pos > t.size then
+ if read_pos > self.size then
read_pos = (read_pos%size);
end
- return pos+1, t._items[read_pos];
+ return pos+1, t[read_pos];
end, self, 0;
end;
+ consume = function (self)
+ return self.pop, self;
+ end;
};
end
diff --git a/util/sasl/scram.lua b/util/sasl/scram.lua
index 043f328b..1d1590e8 100644
--- a/util/sasl/scram.lua
+++ b/util/sasl/scram.lua
@@ -14,9 +14,7 @@
local s_match = string.match;
local type = type
local base64 = require "util.encodings".base64;
-local hmac_sha1 = require "util.hashes".hmac_sha1;
-local sha1 = require "util.hashes".sha1;
-local Hi = require "util.hashes".scram_Hi_sha1;
+local hashes = require "util.hashes";
local generate_uuid = require "util.uuid".generate;
local saslprep = require "util.encodings".stringprep.saslprep;
local nodeprep = require "util.encodings".stringprep.nodeprep;
@@ -99,20 +97,22 @@ local function hashprep(hashname)
return hashname:lower():gsub("-", "_");
end
-local function getAuthenticationDatabaseSHA1(password, salt, iteration_count)
- if type(password) ~= "string" or type(salt) ~= "string" or type(iteration_count) ~= "number" then
- return false, "inappropriate argument types"
- end
- if iteration_count < 4096 then
- log("warn", "Iteration count < 4096 which is the suggested minimum according to RFC 5802.")
+local function get_scram_hasher(H, HMAC, Hi)
+ return function (password, salt, iteration_count)
+ if type(password) ~= "string" or type(salt) ~= "string" or type(iteration_count) ~= "number" then
+ return false, "inappropriate argument types"
+ end
+ if iteration_count < 4096 then
+ log("warn", "Iteration count < 4096 which is the suggested minimum according to RFC 5802.")
+ end
+ local salted_password = Hi(password, salt, iteration_count);
+ local stored_key = H(HMAC(salted_password, "Client Key"))
+ local server_key = HMAC(salted_password, "Server Key");
+ return true, stored_key, server_key
end
- local salted_password = Hi(password, salt, iteration_count);
- local stored_key = sha1(hmac_sha1(salted_password, "Client Key"))
- local server_key = hmac_sha1(salted_password, "Server Key");
- return true, stored_key, server_key
end
-local function scram_gen(hash_name, H_f, HMAC_f)
+local function scram_gen(hash_name, H_f, HMAC_f, get_auth_db)
local profile_name = "scram_" .. hashprep(hash_name);
local function scram_hash(self, message)
local support_channel_binding = false;
@@ -177,7 +177,7 @@ local function scram_gen(hash_name, H_f, HMAC_f)
iteration_count = default_i;
local succ;
- succ, stored_key, server_key = getAuthenticationDatabaseSHA1(password, salt, iteration_count);
+ succ, stored_key, server_key = get_auth_db(password, salt, iteration_count);
if not succ then
log("error", "Generating authentication database failed. Reason: %s", stored_key);
return "failure", "temporary-auth-failure";
@@ -190,7 +190,7 @@ local function scram_gen(hash_name, H_f, HMAC_f)
end
local nonce = clientnonce .. generate_uuid();
- local server_first_message = "r="..nonce..",s="..base64.encode(salt)..",i="..iteration_count;
+ local server_first_message = ("r=%s,s=%s,i=%d"):format(nonce, base64.encode(salt), iteration_count);
self.state = {
gs2_header = gs2_header;
gs2_cbind_name = gs2_cbind_name;
@@ -247,22 +247,28 @@ local function scram_gen(hash_name, H_f, HMAC_f)
return scram_hash;
end
+local auth_db_getters = {}
local function init(registerMechanism)
- local function registerSCRAMMechanism(hash_name, hash, hmac_hash)
+ local function registerSCRAMMechanism(hash_name, hash, hmac_hash, pbkdf2)
+ local get_auth_db = get_scram_hasher(hash, hmac_hash, pbkdf2);
+ auth_db_getters[hash_name] = get_auth_db;
registerMechanism("SCRAM-"..hash_name,
{"plain", "scram_"..(hashprep(hash_name))},
- scram_gen(hash_name:lower(), hash, hmac_hash));
+ scram_gen(hash_name:lower(), hash, hmac_hash, get_auth_db));
-- register channel binding equivalent
registerMechanism("SCRAM-"..hash_name.."-PLUS",
{"plain", "scram_"..(hashprep(hash_name))},
- scram_gen(hash_name:lower(), hash, hmac_hash), {"tls-unique"});
+ scram_gen(hash_name:lower(), hash, hmac_hash, get_auth_db), {"tls-unique"});
end
- registerSCRAMMechanism("SHA-1", sha1, hmac_sha1);
+ registerSCRAMMechanism("SHA-1", hashes.sha1, hashes.hmac_sha1, hashes.pbkdf2_hmac_sha1);
+ registerSCRAMMechanism("SHA-256", hashes.sha256, hashes.hmac_sha256, hashes.pbkdf2_hmac_sha256);
end
return {
- getAuthenticationDatabaseSHA1 = getAuthenticationDatabaseSHA1;
+ get_hash = get_scram_hasher;
+ hashers = auth_db_getters;
+ getAuthenticationDatabaseSHA1 = get_scram_hasher(hashes.sha1, hashes.hmac_sha1, hashes.pbkdf2_hmac_sha1); -- COMPAT
init = init;
}
diff --git a/util/serialization.lua b/util/serialization.lua
index 5121a9f9..d70e92ba 100644
--- a/util/serialization.lua
+++ b/util/serialization.lua
@@ -16,22 +16,18 @@ local s_char = string.char;
local s_match = string.match;
local t_concat = table.concat;
+local to_hex = require "util.hex".to;
+
local pcall = pcall;
local envload = require"util.envload".envload;
local pos_inf, neg_inf = math.huge, -math.huge;
--- luacheck: ignore 143/math
local m_type = math.type or function (n)
return n % 1 == 0 and n <= 9007199254740992 and n >= -9007199254740992 and "integer" or "float";
end;
-local char_to_hex = {};
-for i = 0,255 do
- char_to_hex[s_char(i)] = s_format("%02x", i);
-end
-
-local function to_hex(s)
- return (s_gsub(s, ".", char_to_hex));
+local function rawpairs(t)
+ return next, t, nil;
end
local function fatal_error(obj, why)
@@ -123,6 +119,7 @@ local function new(opt)
local freeze = opt.freeze;
local maxdepth = opt.maxdepth or 127;
local multirefs = opt.multiref;
+ local table_pairs = opt.table_iterator or rawpairs;
-- serialize one table, recursively
-- t - table being serialized
@@ -164,7 +161,9 @@ local function new(opt)
local indent = s_rep(indentwith, d);
local numkey = 1;
local ktyp, vtyp;
- for k,v in next,t do
+ local had_items = false;
+ for k,v in table_pairs(t) do
+ had_items = true;
o[l], l = itemstart, l + 1;
o[l], l = indent, l + 1;
ktyp, vtyp = type(k), type(v);
@@ -195,14 +194,10 @@ local function new(opt)
else
o[l], l = ser(v), l + 1;
end
- -- last item?
- if next(t, k) ~= nil then
- o[l], l = itemsep, l + 1;
- else
- o[l], l = itemlast, l + 1;
- end
+ o[l], l = itemsep, l + 1;
end
- if next(t) ~= nil then
+ if had_items then
+ o[l - 1] = itemlast;
o[l], l = s_rep(indentwith, d-1), l + 1;
end
o[l], l = tend, l +1;
diff --git a/util/session.lua b/util/session.lua
index b2a726ce..25b22faf 100644
--- a/util/session.lua
+++ b/util/session.lua
@@ -4,12 +4,13 @@ local logger = require "util.logger";
local function new_session(typ)
local session = {
type = typ .. "_unauthed";
+ base_type = typ;
};
return session;
end
local function set_id(session)
- local id = session.type .. tostring(session):match("%x+$"):lower();
+ local id = session.base_type .. tostring(session):match("%x+$"):lower();
session.id = id;
return session;
end
@@ -30,7 +31,7 @@ local function set_send(session)
local conn = session.conn;
if not conn then
function session.send(data)
- session.log("debug", "Discarding data sent to unconnected session: %s", tostring(data));
+ session.log("debug", "Discarding data sent to unconnected session: %s", data);
return false;
end
return session;
@@ -46,7 +47,7 @@ local function set_send(session)
if t then
local ret, err = w(conn, t);
if not ret then
- session.log("debug", "Error writing to connection: %s", tostring(err));
+ session.log("debug", "Error writing to connection: %s", err);
return false, err;
end
end
diff --git a/util/sql.lua b/util/sql.lua
index 00c7b57f..86740b1c 100644
--- a/util/sql.lua
+++ b/util/sql.lua
@@ -201,31 +201,31 @@ function engine:_transaction(func, ...)
if not ok then return ok, err; end
end
--assert(not self.__transaction, "Recursive transactions not allowed");
- log("debug", "SQL transaction begin [%s]", tostring(func));
+ log("debug", "SQL transaction begin [%s]", func);
self.__transaction = true;
local success, a, b, c = xpcall(func, handleerr, ...);
self.__transaction = nil;
if success then
- log("debug", "SQL transaction success [%s]", tostring(func));
+ log("debug", "SQL transaction success [%s]", func);
local ok, err = self.conn:commit();
-- LuaDBI doesn't actually return an error message here, just a boolean
if not ok then return ok, err or "commit failed"; end
return success, a, b, c;
else
- log("debug", "SQL transaction failure [%s]: %s", tostring(func), a.err);
+ log("debug", "SQL transaction failure [%s]: %s", func, a.err);
if self.conn then self.conn:rollback(); end
return success, a.err;
end
end
function engine:transaction(...)
- local ok, ret = self:_transaction(...);
+ local ok, ret, b, c = self:_transaction(...);
if not ok then
local conn = self.conn;
if not conn or not conn:ping() then
log("debug", "Database connection was closed. Will reconnect and retry.");
self.conn = nil;
- log("debug", "Retrying SQL transaction [%s]", tostring((...)));
- ok, ret = self:_transaction(...);
+ log("debug", "Retrying SQL transaction [%s]", (...));
+ ok, ret, b, c = self:_transaction(...);
log("debug", "SQL transaction retry %s", ok and "succeeded" or "failed");
else
log("debug", "SQL connection is up, so not retrying");
@@ -234,7 +234,7 @@ function engine:transaction(...)
log("error", "Error in SQL transaction: %s", ret);
end
end
- return ok, ret;
+ return ok, ret, b, c;
end
function engine:_create_index(index)
local sql = "CREATE INDEX \""..index.name.."\" ON \""..index.table.."\" (";
diff --git a/util/stanza.lua b/util/stanza.lua
index a90d56b3..55c38c73 100644
--- a/util/stanza.lua
+++ b/util/stanza.lua
@@ -98,7 +98,7 @@ function stanza_mt:query(xmlns)
end
function stanza_mt:body(text, attr)
- return self:tag("body", attr):text(text);
+ return self:text_tag("body", text, attr);
end
function stanza_mt:text_tag(name, text, attr, namespaces)
@@ -270,6 +270,34 @@ function stanza_mt:find(path)
until not self
end
+local function _clone(stanza, only_top)
+ local attr, tags = {}, {};
+ for k,v in pairs(stanza.attr) do attr[k] = v; end
+ local old_namespaces, namespaces = stanza.namespaces;
+ if old_namespaces then
+ namespaces = {};
+ for k,v in pairs(old_namespaces) do namespaces[k] = v; end
+ end
+ local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags };
+ if not only_top then
+ for i=1,#stanza do
+ local child = stanza[i];
+ if child.name then
+ child = _clone(child);
+ t_insert(tags, child);
+ end
+ t_insert(new, child);
+ end
+ end
+ return setmetatable(new, stanza_mt);
+end
+
+local function clone(stanza, only_top)
+ if not is_stanza(stanza) then
+ error("bad argument to clone: expected stanza, got "..type(stanza));
+ end
+ return _clone(stanza, only_top);
+end
local escape_table = { ["'"] = "&apos;", ["\""] = "&quot;", ["<"] = "&lt;", [">"] = "&gt;", ["&"] = "&amp;" };
local function xml_escape(str) return (s_gsub(str, "['&<>\"]", escape_table)); end
@@ -310,11 +338,8 @@ function stanza_mt.__tostring(t)
end
function stanza_mt.top_tag(t)
- local attr_string = "";
- if t.attr then
- for k, v in pairs(t.attr) do if type(k) == "string" then attr_string = attr_string .. s_format(" %s='%s'", k, xml_escape(tostring(v))); end end
- end
- return s_format("<%s%s>", t.name, attr_string);
+ local top_tag_clone = clone(t, true);
+ return tostring(top_tag_clone):sub(1,-3)..">";
end
function stanza_mt.get_text(t)
@@ -388,44 +413,23 @@ local function deserialize(serialized)
end
end
-local function _clone(stanza)
- local attr, tags = {}, {};
- for k,v in pairs(stanza.attr) do attr[k] = v; end
- local old_namespaces, namespaces = stanza.namespaces;
- if old_namespaces then
- namespaces = {};
- for k,v in pairs(old_namespaces) do namespaces[k] = v; end
- end
- local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags };
- for i=1,#stanza do
- local child = stanza[i];
- if child.name then
- child = _clone(child);
- t_insert(tags, child);
- end
- t_insert(new, child);
- end
- return setmetatable(new, stanza_mt);
-end
-
-local function clone(stanza)
- if not is_stanza(stanza) then
- error("bad argument to clone: expected stanza, got "..type(stanza));
- end
- return _clone(stanza);
-end
-
local function message(attr, body)
if not body then
return new_stanza("message", attr);
else
- return new_stanza("message", attr):tag("body"):text(body):up();
+ return new_stanza("message", attr):text_tag("body", body);
end
end
local function iq(attr)
- if not (attr and attr.id) then
+ if not attr then
+ error("iq stanzas require id and type attributes");
+ end
+ if not attr.id then
error("iq stanzas require an id attribute");
end
+ if not attr.type then
+ error("iq stanzas require a type attribute");
+ end
return new_stanza("iq", attr);
end
@@ -445,7 +449,7 @@ local function error_reply(orig, error_type, condition, error_message)
t.attr.type = "error";
t:tag("error", {type = error_type}) --COMPAT: Some day xmlns:stanzas goes here
:tag(condition, xmpp_stanzas_attr):up();
- if error_message then t:tag("text", xmpp_stanzas_attr):text(error_message):up(); end
+ if error_message then t:text_tag("text", error_message, xmpp_stanzas_attr); end
return t; -- stanza ready for adding app-specific errors
end
diff --git a/util/startup.lua b/util/startup.lua
index e88ed709..8e6d89e6 100644
--- a/util/startup.lua
+++ b/util/startup.lua
@@ -7,6 +7,7 @@ local logger = require "util.logger";
local log = logger.init("startup");
local config = require "core.configmanager";
+local config_warnings;
local dependencies = require "util.dependencies";
@@ -65,6 +66,8 @@ function startup.read_config()
print("**************************");
print("");
os.exit(1);
+ elseif err and #err > 0 then
+ config_warnings = err;
end
prosody.config_loaded = true;
end
@@ -97,8 +100,13 @@ function startup.init_logging()
end);
end
-function startup.log_dependency_warnings()
+function startup.log_startup_warnings()
dependencies.log_warnings();
+ if config_warnings then
+ for _, warning in ipairs(config_warnings) do
+ log("warn", "Configuration warning: %s", warning);
+ end
+ end
end
function startup.sanity_check()
@@ -220,8 +228,8 @@ end
function startup.setup_plugindir()
local custom_plugin_paths = config.get("*", "plugin_paths");
+ local path_sep = package.config:sub(3,3);
if custom_plugin_paths then
- local path_sep = package.config:sub(3,3);
-- path1;path2;path3;defaultpath...
-- luacheck: ignore 111
CFG_PLUGINDIR = table.concat(custom_plugin_paths, path_sep)..path_sep..(CFG_PLUGINDIR or "plugins");
@@ -229,6 +237,19 @@ function startup.setup_plugindir()
end
end
+function startup.setup_plugin_install_path()
+ local installer_plugin_path = config.get("*", "installer_plugin_path") or "custom_plugins";
+ local path_sep = package.config:sub(3,3);
+ -- TODO Figure out what this should be relative to, because CWD could be anywhere
+ installer_plugin_path = config.resolve_relative_path(require "lfs".currentdir(), installer_plugin_path);
+ -- TODO Can probably move directory creation to the install command
+ require "lfs".mkdir(installer_plugin_path);
+ require"util.paths".complement_lua_path(installer_plugin_path);
+ -- luacheck: ignore 111
+ CFG_PLUGINDIR = installer_plugin_path..path_sep..(CFG_PLUGINDIR or "plugins");
+ prosody.paths.plugins = CFG_PLUGINDIR;
+end
+
function startup.chdir()
if prosody.installed then
local lfs = require "lfs";
@@ -250,9 +271,9 @@ function startup.add_global_prosody_functions()
local ok, level, err = config.load(prosody.config_file);
if not ok then
if level == "parser" then
- log("error", "There was an error parsing the configuration file: %s", tostring(err));
+ log("error", "There was an error parsing the configuration file: %s", err);
elseif level == "file" then
- log("error", "Couldn't read the config file when trying to reload: %s", tostring(err));
+ log("error", "Couldn't read the config file when trying to reload: %s", err);
end
else
prosody.events.fire_event("config-reloaded", {
@@ -520,12 +541,13 @@ function startup.prosodyctl()
startup.force_console_logging();
startup.init_logging();
startup.setup_plugindir();
+ -- startup.setup_plugin_install_path();
startup.setup_datadir();
startup.chdir();
startup.read_version();
startup.switch_user();
startup.check_dependencies();
- startup.log_dependency_warnings();
+ startup.log_startup_warnings();
startup.check_unwriteable();
startup.load_libraries();
startup.init_http_client();
@@ -545,12 +567,13 @@ function startup.prosody()
startup.init_logging();
startup.load_libraries();
startup.setup_plugindir();
+ -- startup.setup_plugin_install_path();
startup.setup_datadir();
startup.chdir();
startup.add_global_prosody_functions();
startup.read_version();
startup.log_greeting();
- startup.log_dependency_warnings();
+ startup.log_startup_warnings();
startup.load_secondary_libraries();
startup.init_http_client();
startup.init_data_store();
diff --git a/util/statistics.lua b/util/statistics.lua
index 39954652..0ec88e21 100644
--- a/util/statistics.lua
+++ b/util/statistics.lua
@@ -57,12 +57,14 @@ local function new_registry(config)
end;
end;
rate = function (name)
- local since, n = time(), 0;
+ local since, n, total = time(), 0, 0;
registry[name..":rate"] = function ()
+ total = total + n;
local t = time();
local stats = {
rate = n/(t-since);
count = n;
+ total = total;
};
since, n = t, 0;
return "rate", stats.rate, stats;
diff --git a/util/x509.lua b/util/x509.lua
index 15cc4d3c..fe6e4b79 100644
--- a/util/x509.lua
+++ b/util/x509.lua
@@ -20,9 +20,12 @@
local nameprep = require "util.encodings".stringprep.nameprep;
local idna_to_ascii = require "util.encodings".idna.to_ascii;
+local idna_to_unicode = require "util.encodings".idna.to_unicode;
local base64 = require "util.encodings".base64;
local log = require "util.logger".init("x509");
+local mt = require "util.multitable";
local s_format = string.format;
+local ipairs = ipairs;
local _ENV = nil;
-- luacheck: std none
@@ -216,6 +219,60 @@ local function verify_identity(host, service, cert)
return false
end
+-- TODO Support other SANs
+local function get_identities(cert) --> map of names to sets of services
+ if cert.setencode then
+ cert:setencode("utf8");
+ end
+
+ local names = mt.new();
+
+ local ext = cert:extensions();
+ local sans = ext[oid_subjectaltname];
+ if sans then
+ if sans["dNSName"] then -- Valid for any service
+ for _, name in ipairs(sans["dNSName"]) do
+ name = idna_to_unicode(nameprep(name));
+ if name then
+ names:set(name, "*", true);
+ end
+ end
+ end
+ if sans[oid_xmppaddr] then
+ for _, name in ipairs(sans[oid_xmppaddr]) do
+ name = nameprep(name);
+ if name then
+ names:set(name, "xmpp-client", true);
+ names:set(name, "xmpp-server", true);
+ end
+ end
+ end
+ if sans[oid_dnssrv] then
+ for _, srvname in ipairs(sans[oid_dnssrv]) do
+ local srv, name = srvname:match("^_([^.]+)%.(.*)");
+ if srv then
+ name = nameprep(name);
+ if name then
+ names:set(name, srv, true);
+ end
+ end
+ end
+ end
+ end
+
+ local subject = cert:subject();
+ for i = 1, #subject do
+ local dn = subject[i];
+ if dn.oid == oid_commonname then
+ local name = nameprep(dn.value);
+ if name and idna_to_ascii(name) then
+ names:set("*", name, true);
+ end
+ end
+ end
+ return names.data;
+end
+
local pat = "%-%-%-%-%-BEGIN ([A-Z ]+)%-%-%-%-%-\r?\n"..
"([0-9A-Za-z+/=\r\n]*)\r?\n%-%-%-%-%-END %1%-%-%-%-%-";
@@ -237,6 +294,7 @@ end
return {
verify_identity = verify_identity;
+ get_identities = get_identities;
pem2der = pem2der;
der2pem = der2pem;
};
diff --git a/util/xmppstream.lua b/util/xmppstream.lua
index 58cbd18e..6aa1def3 100644
--- a/util/xmppstream.lua
+++ b/util/xmppstream.lua
@@ -64,6 +64,8 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
local stream_default_ns = stream_callbacks.default_ns;
+ local stream_lang = "en";
+
local stack = {};
local chardata, stanza = {};
local stanza_size = 0;
@@ -101,6 +103,7 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
if session.notopen then
if tagname == stream_tag then
non_streamns_depth = 0;
+ stream_lang = attr["xml:lang"] or stream_lang;
if cb_streamopened then
if lxp_supports_bytecount then
cb_handleprogress(stanza_size);
@@ -178,6 +181,9 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
cb_handleprogress(stanza_size);
end
stanza_size = 0;
+ if stanza.attr["xml:lang"] == nil then
+ stanza.attr["xml:lang"] = stream_lang;
+ end
if tagname ~= stream_error_tag then
cb_handlestanza(session, stanza);
else