aboutsummaryrefslogtreecommitdiffstats
path: root/util
diff options
context:
space:
mode:
Diffstat (limited to 'util')
-rw-r--r--util/adhoc.lua4
-rw-r--r--util/argparse.lua58
-rw-r--r--util/array.lua5
-rw-r--r--util/async.lua16
-rw-r--r--util/bit53.lua7
-rw-r--r--util/bitcompat.lua32
-rw-r--r--util/datamanager.lua5
-rw-r--r--util/dependencies.lua23
-rw-r--r--util/error.lua60
-rw-r--r--util/format.lua27
-rw-r--r--util/hashring.lua88
-rw-r--r--util/hmac.lua9
-rw-r--r--util/http.lua22
-rw-r--r--util/import.lua2
-rw-r--r--util/ip.lua8
-rw-r--r--util/iterators.lua6
-rw-r--r--util/jid.lua12
-rw-r--r--util/jwt.lua50
-rw-r--r--util/mercurial.lua2
-rw-r--r--util/multitable.lua2
-rw-r--r--util/paths.lua16
-rw-r--r--util/pluginloader.lua3
-rw-r--r--util/promise.lua3
-rw-r--r--util/prosodyctl.lua51
-rw-r--r--util/pubsub.lua27
-rw-r--r--util/queue.lua12
-rw-r--r--util/sasl/scram.lua49
-rw-r--r--util/serialization.lua27
-rw-r--r--util/session.lua7
-rw-r--r--util/set.lua6
-rw-r--r--util/sql.lua17
-rw-r--r--util/stanza.lua96
-rw-r--r--util/startup.lua93
-rw-r--r--util/statistics.lua4
-rw-r--r--util/termcolours.lua2
-rw-r--r--util/x509.lua58
-rw-r--r--util/xmppstream.lua6
37 files changed, 703 insertions, 212 deletions
diff --git a/util/adhoc.lua b/util/adhoc.lua
index d81b8242..19c94c34 100644
--- a/util/adhoc.lua
+++ b/util/adhoc.lua
@@ -2,7 +2,7 @@
local function new_simple_form(form, result_handler)
return function(self, data, state)
- if state then
+ if state or data.form then
if data.action == "cancel" then
return { status = "canceled" };
end
@@ -16,7 +16,7 @@ end
local function new_initial_data_form(form, initial_data, result_handler)
return function(self, data, state)
- if state then
+ if state or data.form then
if data.action == "cancel" then
return { status = "canceled" };
end
diff --git a/util/argparse.lua b/util/argparse.lua
new file mode 100644
index 00000000..928fc3eb
--- /dev/null
+++ b/util/argparse.lua
@@ -0,0 +1,58 @@
+local function parse(arg, config)
+ local short_params = config and config.short_params or {};
+ local value_params = config and config.value_params or {};
+
+ local parsed_opts = {};
+
+ if #arg == 0 then
+ return parsed_opts;
+ end
+ while true do
+ local raw_param = arg[1];
+ if not raw_param then
+ break;
+ end
+
+ local prefix = raw_param:match("^%-%-?");
+ if not prefix then
+ break;
+ elseif prefix == "--" and raw_param == "--" then
+ table.remove(arg, 1);
+ break;
+ end
+ local param = table.remove(arg, 1):sub(#prefix+1);
+ if #param == 1 and short_params then
+ param = short_params[param];
+ end
+
+ if not param then
+ print("Unknown command-line option: "..tostring(param));
+ print("Perhaps you meant to use prosodyctl instead?");
+ os.exit(1);
+ end
+
+ local param_k, param_v;
+ if value_params[param] then
+ param_k, param_v = param, table.remove(arg, 1);
+ if not param_v then
+ print("Expected a value to follow command-line option: "..raw_param);
+ os.exit(1);
+ end
+ else
+ param_k, param_v = param:match("^([^=]+)=(.+)$");
+ if not param_k then
+ if param:match("^no%-") then
+ param_k, param_v = param:sub(4), false;
+ else
+ param_k, param_v = param, true;
+ end
+ end
+ end
+ parsed_opts[param_k] = param_v;
+ end
+ return parsed_opts;
+end
+
+return {
+ parse = parse;
+}
diff --git a/util/array.lua b/util/array.lua
index 0b60a4fd..32d2d6a5 100644
--- a/util/array.lua
+++ b/util/array.lua
@@ -10,6 +10,7 @@ local t_insert, t_sort, t_remove, t_concat
= table.insert, table.sort, table.remove, table.concat;
local setmetatable = setmetatable;
+local getmetatable = getmetatable;
local math_random = math.random;
local math_floor = math.floor;
local pairs, ipairs = pairs, ipairs;
@@ -40,6 +41,10 @@ function array_mt.__add(a1, a2)
end
function array_mt.__eq(a, b)
+ if getmetatable(a) ~= array_mt or getmetatable(b) ~= array_mt then
+ -- Lua 5.3+ calls this if both operands are tables, even if metatables differ
+ return false;
+ end
if #a == #b then
for i = 1, #a do
if a[i] ~= b[i] then
diff --git a/util/async.lua b/util/async.lua
index 20397785..d338071f 100644
--- a/util/async.lua
+++ b/util/async.lua
@@ -246,9 +246,25 @@ local function ready()
return pcall(checkthread);
end
+local function wait(promise)
+ local async_wait, async_done = waiter();
+ local ret, err = nil, nil;
+ promise:next(
+ function (r) ret = r; end,
+ function (e) err = e; end)
+ :finally(async_done);
+ async_wait();
+ if ret then
+ return ret;
+ else
+ return nil, err;
+ end
+end
+
return {
ready = ready;
waiter = waiter;
guarder = guarder;
runner = runner;
+ wait = wait;
};
diff --git a/util/bit53.lua b/util/bit53.lua
new file mode 100644
index 00000000..4b5c2e9c
--- /dev/null
+++ b/util/bit53.lua
@@ -0,0 +1,7 @@
+-- Only the operators needed by net.websocket.frames are provided at this point
+return {
+ band = function (a, b) return a & b end;
+ bor = function (a, b) return a | b end;
+ bxor = function (a, b) return a ~ b end;
+};
+
diff --git a/util/bitcompat.lua b/util/bitcompat.lua
new file mode 100644
index 00000000..454181af
--- /dev/null
+++ b/util/bitcompat.lua
@@ -0,0 +1,32 @@
+-- Compatibility layer for bitwise operations
+
+-- First try the bit32 lib
+-- Lua 5.3 has it with compat enabled
+-- Lua 5.2 has it by default
+if _G.bit32 then
+ return _G.bit32;
+else
+ -- Lua 5.1 may have it as a standalone module that can be installed
+ local ok, bitop = pcall(require, "bit32")
+ if ok then
+ return bitop;
+ end
+end
+
+do
+ -- Lua 5.3 and 5.4 would be able to use native infix operators
+ local ok, bitop = pcall(require, "util.bit53")
+ if ok then
+ return bitop;
+ end
+end
+
+do
+ -- Lastly, try the LuaJIT bitop library
+ local ok, bitop = pcall(require, "bit")
+ if ok then
+ return bitop;
+ end
+end
+
+error "No bit module found. See https://prosody.im/doc/depends#bitop";
diff --git a/util/datamanager.lua b/util/datamanager.lua
index 0d7060b7..26dede08 100644
--- a/util/datamanager.lua
+++ b/util/datamanager.lua
@@ -24,7 +24,7 @@ local t_concat = table.concat;
local envloadfile = require"util.envload".envloadfile;
local serialize = require "util.serialization".serialize;
local lfs = require "lfs";
--- Extract directory seperator from package.config (an undocumented string that comes with lua)
+-- Extract directory separator from package.config (an undocumented string that comes with lua)
local path_separator = assert ( package.config:match ( "^([^\n]+)" ) , "package.config not in standard form" )
local prosody = prosody;
@@ -157,7 +157,8 @@ end
local function atomic_store(filename, data)
local scratch = filename.."~";
- local f, ok, msg, errno;
+ local f, ok, msg, errno; -- luacheck: ignore errno
+ -- TODO return util.error with code=errno?
f, msg, errno = io_open(scratch, "w");
if not f then
diff --git a/util/dependencies.lua b/util/dependencies.lua
index 7c7b938e..ede8c6ac 100644
--- a/util/dependencies.lua
+++ b/util/dependencies.lua
@@ -13,7 +13,8 @@ if not softreq "luarocks.loader" then -- LuaRocks 2.x
softreq "luarocks.require"; -- LuaRocks <1.x
end
-local function missingdep(name, sources, msg)
+local function missingdep(name, sources, msg, err) -- luacheck: ignore err
+ -- TODO print something about the underlying error, useful for debugging
print("");
print("**************************");
print("Prosody was unable to find "..tostring(name));
@@ -44,25 +45,25 @@ local function check_dependencies()
local fatal;
- local lxp = softreq "lxp"
+ local lxp, err = softreq "lxp"
if not lxp then
missingdep("luaexpat", {
["Debian/Ubuntu"] = "sudo apt-get install lua-expat";
["luarocks"] = "luarocks install luaexpat";
["Source"] = "http://matthewwild.co.uk/projects/luaexpat/";
- });
+ }, nil, err);
fatal = true;
end
- local socket = softreq "socket"
+ local socket, err = softreq "socket"
if not socket then
missingdep("luasocket", {
["Debian/Ubuntu"] = "sudo apt-get install lua-socket";
["luarocks"] = "luarocks install luasocket";
["Source"] = "http://www.tecgraf.puc-rio.br/~diego/professional/luasocket/";
- });
+ }, nil, err);
fatal = true;
elseif not socket.tcp4 then
-- COMPAT LuaSocket before being IP-version agnostic
@@ -76,28 +77,28 @@ local function check_dependencies()
["luarocks"] = "luarocks install luafilesystem";
["Debian/Ubuntu"] = "sudo apt-get install lua-filesystem";
["Source"] = "http://www.keplerproject.org/luafilesystem/";
- });
+ }, nil, err);
fatal = true;
end
- local ssl = softreq "ssl"
+ local ssl, err = softreq "ssl"
if not ssl then
missingdep("LuaSec", {
["Debian/Ubuntu"] = "sudo apt-get install lua-sec";
["luarocks"] = "luarocks install luasec";
["Source"] = "https://github.com/brunoos/luasec";
- }, "SSL/TLS support will not be available");
+ }, "SSL/TLS support will not be available", err);
end
- local bit = _G.bit32 or softreq"bit";
+ local bit, err = softreq"util.bitcompat";
if not bit then
missingdep("lua-bitops", {
["Debian/Ubuntu"] = "sudo apt-get install lua-bitop";
["luarocks"] = "luarocks install luabitop";
["Source"] = "http://bitop.luajit.org/";
- }, "WebSocket support will not be available");
+ }, "WebSocket support will not be available", err);
end
local encodings, err = softreq "util.encodings"
@@ -140,7 +141,7 @@ local function check_dependencies()
end
local function log_warnings()
- if _VERSION > "Lua 5.2" then
+ if _VERSION > "Lua 5.3" then
prosody.log("warn", "Support for %s is experimental, please report any issues", _VERSION);
end
local ssl = softreq"ssl";
diff --git a/util/error.lua b/util/error.lua
new file mode 100644
index 00000000..ca960dd9
--- /dev/null
+++ b/util/error.lua
@@ -0,0 +1,60 @@
+local error_mt = { __name = "error" };
+
+function error_mt:__tostring()
+ return ("error<%s:%s:%s>"):format(self.type, self.condition, self.text or "");
+end
+
+local function is_err(e)
+ return getmetatable(e) == error_mt;
+end
+
+-- Do we want any more well-known fields?
+-- Or could we just copy all fields from `e`?
+-- Sometimes you want variable details in the `text`, how to handle that?
+-- Translations?
+-- Should the `type` be restricted to the stanza error types or free-form?
+-- What to set `type` to for stream errors or SASL errors? Those don't have a 'type' attr.
+
+local function new(e, context, registry)
+ local template = (registry and registry[e]) or e or {};
+ return setmetatable({
+ type = template.type or "cancel";
+ condition = template.condition or "undefined-condition";
+ text = template.text;
+ code = template.code;
+
+ context = context or template.context or { _error_id = e };
+ }, error_mt);
+end
+
+local function coerce(ok, err, ...)
+ if ok or is_err(err) then
+ return ok, err, ...;
+ end
+
+ local new_err = setmetatable({
+ native = err;
+
+ type = "cancel";
+ condition = "undefined-condition";
+ }, error_mt);
+ return ok, new_err, ...;
+end
+
+local function from_stanza(stanza, context)
+ local error_type, condition, text = stanza:get_error();
+ return setmetatable({
+ type = error_type or "cancel";
+ condition = condition or "undefined-condition";
+ text = text;
+
+ context = context or { stanza = stanza };
+ }, error_mt);
+end
+
+return {
+ new = new;
+ coerce = coerce;
+ is_err = is_err;
+ from_stanza = from_stanza;
+}
diff --git a/util/format.lua b/util/format.lua
index c5e513fa..1ce670f3 100644
--- a/util/format.lua
+++ b/util/format.lua
@@ -3,12 +3,20 @@
--
local tostring = tostring;
-local select = select;
local unpack = table.unpack or unpack; -- luacheck: ignore 113/unpack
+local pack = require "util.table".pack; -- TODO table.pack in 5.2+
local type = type;
+local dump = require "util.serialization".new("debug");
+local num_type = math.type or function (n)
+ return n % 1 == 0 and n <= 9007199254740992 and n >= -9007199254740992 and "integer" or "float";
+end
+
+-- In Lua 5.3+ these formats throw an error if given a float
+local expects_integer = { c = true, d = true, i = true, o = true, u = true, X = true, x = true, };
local function format(formatstring, ...)
- local args, args_length = { ... }, select('#', ...);
+ local args = pack(...);
+ local args_length = args.n;
-- format specifier spec:
-- 1. Start: '%%'
@@ -28,17 +36,22 @@ local function format(formatstring, ...)
if spec ~= "%%" then
i = i + 1;
local arg = args[i];
- if arg == nil then -- special handling for nil
- arg = "<nil>"
- args[i] = "<nil>";
- end
local option = spec:sub(-1);
- if option == "q" or option == "s" then -- arg should be string
+ if arg == nil then
+ args[i] = "nil";
+ spec = "<%s>";
+ elseif option == "q" then
+ args[i] = dump(arg);
+ spec = "%s";
+ elseif option == "s" then
args[i] = tostring(arg);
elseif type(arg) ~= "number" then -- arg isn't number as expected?
args[i] = tostring(arg);
spec = "[%s]";
+ elseif expects_integer[option] and num_type(arg) ~= "integer" then
+ args[i] = tostring(arg);
+ spec = "[%s]";
end
end
return spec;
diff --git a/util/hashring.lua b/util/hashring.lua
new file mode 100644
index 00000000..322bc005
--- /dev/null
+++ b/util/hashring.lua
@@ -0,0 +1,88 @@
+local function generate_ring(nodes, num_replicas, hash)
+ local new_ring = {};
+ for _, node_name in ipairs(nodes) do
+ for replica = 1, num_replicas do
+ local replica_hash = hash(node_name..":"..replica);
+ new_ring[replica_hash] = node_name;
+ table.insert(new_ring, replica_hash);
+ end
+ end
+ table.sort(new_ring);
+ return new_ring;
+end
+
+local hashring_methods = {};
+local hashring_mt = {
+ __index = function (self, k)
+ -- Automatically build self.ring if it's missing
+ if k == "ring" then
+ local ring = generate_ring(self.nodes, self.num_replicas, self.hash);
+ rawset(self, "ring", ring);
+ return ring;
+ end
+ return rawget(hashring_methods, k);
+ end
+};
+
+local function new(num_replicas, hash_function)
+ return setmetatable({ nodes = {}, num_replicas = num_replicas, hash = hash_function }, hashring_mt);
+end;
+
+function hashring_methods:add_node(name)
+ self.ring = nil;
+ self.nodes[name] = true;
+ table.insert(self.nodes, name);
+ return true;
+end
+
+function hashring_methods:add_nodes(nodes)
+ self.ring = nil;
+ for _, node_name in ipairs(nodes) do
+ if not self.nodes[node_name] then
+ self.nodes[node_name] = true;
+ table.insert(self.nodes, node_name);
+ end
+ end
+ return true;
+end
+
+function hashring_methods:remove_node(node_name)
+ self.ring = nil;
+ if self.nodes[node_name] then
+ for i, stored_node_name in ipairs(self.nodes) do
+ if node_name == stored_node_name then
+ self.nodes[node_name] = nil;
+ table.remove(self.nodes, i);
+ return true;
+ end
+ end
+ end
+ return false;
+end
+
+function hashring_methods:remove_nodes(nodes)
+ self.ring = nil;
+ for _, node_name in ipairs(nodes) do
+ self:remove_node(node_name);
+ end
+end
+
+function hashring_methods:clone()
+ local clone_hashring = new(self.num_replicas, self.hash);
+ clone_hashring:add_nodes(self.nodes);
+ return clone_hashring;
+end
+
+function hashring_methods:get_node(key)
+ local key_hash = self.hash(key);
+ for _, replica_hash in ipairs(self.ring) do
+ if key_hash < replica_hash then
+ return self.ring[replica_hash];
+ end
+ end
+ return self.ring[self.ring[1]];
+end
+
+return {
+ new = new;
+}
diff --git a/util/hmac.lua b/util/hmac.lua
index 2c4cc6ef..4cad17cc 100644
--- a/util/hmac.lua
+++ b/util/hmac.lua
@@ -10,6 +10,9 @@
local hashes = require "util.hashes"
-return { md5 = hashes.hmac_md5,
- sha1 = hashes.hmac_sha1,
- sha256 = hashes.hmac_sha256 };
+return {
+ md5 = hashes.hmac_md5,
+ sha1 = hashes.hmac_sha1,
+ sha256 = hashes.hmac_sha256,
+ sha512 = hashes.hmac_sha512,
+};
diff --git a/util/http.lua b/util/http.lua
index cfb89193..3852f91c 100644
--- a/util/http.lua
+++ b/util/http.lua
@@ -6,24 +6,26 @@
--
local format, char = string.format, string.char;
-local pairs, ipairs, tonumber = pairs, ipairs, tonumber;
+local pairs, ipairs = pairs, ipairs;
local t_insert, t_concat = table.insert, table.concat;
+local url_codes = {};
+for i = 0, 255 do
+ local c = char(i);
+ local u = format("%%%02x", i);
+ url_codes[c] = u;
+ url_codes[u] = c;
+ url_codes[u:upper()] = c;
+end
local function urlencode(s)
- return s and (s:gsub("[^a-zA-Z0-9.~_-]", function (c) return format("%%%02x", c:byte()); end));
+ return s and (s:gsub("[^a-zA-Z0-9.~_-]", url_codes));
end
local function urldecode(s)
- return s and (s:gsub("%%(%x%x)", function (c) return char(tonumber(c,16)); end));
+ return s and (s:gsub("%%%x%x", url_codes));
end
local function _formencodepart(s)
- return s and (s:gsub("%W", function (c)
- if c ~= " " then
- return format("%%%02x", c:byte());
- else
- return "+";
- end
- end));
+ return s and (urlencode(s):gsub("%%20", "+"));
end
local function formencode(form)
diff --git a/util/import.lua b/util/import.lua
index 8ecfe43c..1007bc0a 100644
--- a/util/import.lua
+++ b/util/import.lua
@@ -8,7 +8,7 @@
-local unpack = table.unpack or unpack; --luacheck: ignore 113 143
+local unpack = table.unpack or unpack; --luacheck: ignore 113
local t_insert = table.insert;
function _G.import(module, ...)
local m = package.loaded[module] or require(module);
diff --git a/util/ip.lua b/util/ip.lua
index d1808225..d039e475 100644
--- a/util/ip.lua
+++ b/util/ip.lua
@@ -19,8 +19,14 @@ local ip_mt = {
return ret;
end,
__tostring = function (ip) return ip.addr; end,
- __eq = function (ipA, ipB) return ipA.packed == ipB.packed; end
};
+ip_mt.__eq = function (ipA, ipB)
+ if getmetatable(ipA) ~= ip_mt or getmetatable(ipB) ~= ip_mt then
+ -- Lua 5.3+ calls this if both operands are tables, even if metatables differ
+ return false;
+ end
+ return ipA.packed == ipB.packed;
+end
local hex2bits = {
["0"] = "0000", ["1"] = "0001", ["2"] = "0010", ["3"] = "0011",
diff --git a/util/iterators.lua b/util/iterators.lua
index 302cca36..c03c2fd6 100644
--- a/util/iterators.lua
+++ b/util/iterators.lua
@@ -11,9 +11,9 @@
local it = {};
local t_insert = table.insert;
-local select, next = select, next;
-local unpack = table.unpack or unpack; --luacheck: ignore 113 143
-local pack = table.pack or function (...) return { n = select("#", ...), ... }; end -- luacheck: ignore 143
+local next = next;
+local unpack = table.unpack or unpack; --luacheck: ignore 113
+local pack = table.pack or require "util.table".pack;
local type = type;
local table, setmetatable = table, setmetatable;
diff --git a/util/jid.lua b/util/jid.lua
index ec31f180..1ddf33d4 100644
--- a/util/jid.lua
+++ b/util/jid.lua
@@ -45,20 +45,20 @@ local function bare(jid)
return host;
end
-local function prepped_split(jid)
+local function prepped_split(jid, strict)
local node, host, resource = split(jid);
if host and host ~= "." then
if sub(host, -1, -1) == "." then -- Strip empty root label
host = sub(host, 1, -2);
end
- host = nameprep(host);
+ host = nameprep(host, strict);
if not host then return; end
if node then
- node = nodeprep(node);
+ node = nodeprep(node, strict);
if not node then return; end
end
if resource then
- resource = resourceprep(resource);
+ resource = resourceprep(resource, strict);
if not resource then return; end
end
return node, host, resource;
@@ -77,8 +77,8 @@ local function join(node, host, resource)
return host;
end
-local function prep(jid)
- local node, host, resource = prepped_split(jid);
+local function prep(jid, strict)
+ local node, host, resource = prepped_split(jid, strict);
return join(node, host, resource);
end
diff --git a/util/jwt.lua b/util/jwt.lua
new file mode 100644
index 00000000..2b172d38
--- /dev/null
+++ b/util/jwt.lua
@@ -0,0 +1,50 @@
+local s_gsub = string.gsub;
+local json = require "util.json";
+local hashes = require "util.hashes";
+local base64_encode = require "util.encodings".base64.encode;
+local base64_decode = require "util.encodings".base64.decode;
+
+local b64url_rep = { ["+"] = "-", ["/"] = "_", ["="] = "", ["-"] = "+", ["_"] = "/" };
+local function b64url(data)
+ return (s_gsub(base64_encode(data), "[+/=]", b64url_rep));
+end
+local function unb64url(data)
+ return base64_decode(s_gsub(data, "[-_]", b64url_rep).."==");
+end
+
+local static_header = b64url('{"alg":"HS256","typ":"JWT"}') .. '.';
+
+local function sign(key, payload)
+ local encoded_payload = json.encode(payload);
+ local signed = static_header .. b64url(encoded_payload);
+ local signature = hashes.hmac_sha256(key, signed);
+ return signed .. "." .. b64url(signature);
+end
+
+local jwt_pattern = "^(([A-Za-z0-9-_]+)%.([A-Za-z0-9-_]+))%.([A-Za-z0-9-_]+)$"
+local function verify(key, blob)
+ local signed, bheader, bpayload, signature = string.match(blob, jwt_pattern);
+ if not signed then
+ return nil, "invalid-encoding";
+ end
+ local header = json.decode(unb64url(bheader));
+ if not header or type(header) ~= "table" then
+ return nil, "invalid-header";
+ elseif header.alg ~= "HS256" then
+ return nil, "unsupported-algorithm";
+ end
+ if b64url(hashes.hmac_sha256(key, signed)) ~= signature then
+ return false, "signature-mismatch";
+ end
+ local payload, err = json.decode(unb64url(bpayload));
+ if err ~= nil then
+ return nil, "json-decode-error";
+ end
+ return true, payload;
+end
+
+return {
+ sign = sign;
+ verify = verify;
+};
+
diff --git a/util/mercurial.lua b/util/mercurial.lua
index 3f75c4c1..0f2b1d04 100644
--- a/util/mercurial.lua
+++ b/util/mercurial.lua
@@ -19,7 +19,7 @@ function hg.check_id(path)
hg_changelog:close();
end
else
- local hg_archival,e = io.open(path.."/.hg_archival.txt");
+ local hg_archival,e = io.open(path.."/.hg_archival.txt"); -- luacheck: ignore 211/e
if hg_archival then
local repo = hg_archival:read("*l");
local node = hg_archival:read("*l");
diff --git a/util/multitable.lua b/util/multitable.lua
index 8d32ed8a..4f2cd972 100644
--- a/util/multitable.lua
+++ b/util/multitable.lua
@@ -9,7 +9,7 @@
local select = select;
local t_insert = table.insert;
local pairs, next, type = pairs, next, type;
-local unpack = table.unpack or unpack; --luacheck: ignore 113 143
+local unpack = table.unpack or unpack; --luacheck: ignore 113
local _ENV = nil;
-- luacheck: std none
diff --git a/util/paths.lua b/util/paths.lua
index 89f4cad9..036f315a 100644
--- a/util/paths.lua
+++ b/util/paths.lua
@@ -41,4 +41,20 @@ function path_util.join(...)
return t_concat({...}, path_sep);
end
+function path_util.complement_lua_path(installer_plugin_path)
+ -- Checking for duplicates
+ -- The commands using luarocks need the path to the directory that has the /share and /lib folders.
+ local lua_version = _VERSION:match(" (.+)$");
+ local lua_path_sep = package.config:sub(3,3);
+ local dir_sep = package.config:sub(1,1);
+ local sub_path = dir_sep.."lua"..dir_sep..lua_version..dir_sep;
+ if not string.find(package.path, installer_plugin_path, 1, true) then
+ package.path = package.path..lua_path_sep..installer_plugin_path..dir_sep.."share"..sub_path.."?.lua";
+ package.path = package.path..lua_path_sep..installer_plugin_path..dir_sep.."share"..sub_path.."?"..dir_sep.."init.lua";
+ end
+ if not string.find(package.path, installer_plugin_path, 1, true) then
+ package.cpath = package.cpath..lua_path_sep..installer_plugin_path..dir_sep.."lib"..sub_path.."?.so";
+ end
+end
+
return path_util;
diff --git a/util/pluginloader.lua b/util/pluginloader.lua
index 9ab8f245..af0428c4 100644
--- a/util/pluginloader.lua
+++ b/util/pluginloader.lua
@@ -36,12 +36,13 @@ end
local function load_resource(plugin, resource)
resource = resource or "mod_"..plugin..".lua";
-
+ local lua_version = _VERSION:match(" (.+)$");
local names = {
"mod_"..plugin..dir_sep..plugin..dir_sep..resource; -- mod_hello/hello/mod_hello.lua
"mod_"..plugin..dir_sep..resource; -- mod_hello/mod_hello.lua
plugin..dir_sep..resource; -- hello/mod_hello.lua
resource; -- mod_hello.lua
+ "share"..dir_sep.."lua"..dir_sep..lua_version..dir_sep.."mod_"..plugin..dir_sep..resource;
};
return load_file(names);
diff --git a/util/promise.lua b/util/promise.lua
index 07c9c4dc..0b182b54 100644
--- a/util/promise.lua
+++ b/util/promise.lua
@@ -49,6 +49,9 @@ local function promise_settle(promise, new_state, new_next, cbs, value)
for _, cb in ipairs(cbs) do
cb(value);
end
+ -- No need to keep references to callbacks
+ promise._pending_on_fulfilled = nil;
+ promise._pending_on_rejected = nil;
return true;
end
diff --git a/util/prosodyctl.lua b/util/prosodyctl.lua
index 5f0c4d12..ea697ffc 100644
--- a/util/prosodyctl.lua
+++ b/util/prosodyctl.lua
@@ -39,6 +39,16 @@ local function show_usage(usage, desc)
end
end
+local function show_module_configuration_help(mod_name)
+ print("Done.")
+ print("If you installed a prosody plugin, don't forget to add its name under the 'modules_enabled' section inside your configuration file.")
+ print("Depending on the module, there might be further configuration steps required.")
+ print("")
+ print("More info about: ")
+ print(" modules_enabled: https://prosody.im/doc/modules_enabled")
+ print(" "..mod_name..": https://modules.prosody.im/"..mod_name..".html")
+end
+
local function getchar(n)
local stty_ret = os.execute("stty raw -echo 2>/dev/null");
local ok, char;
@@ -124,7 +134,7 @@ end
-- Server control
local function adduser(params)
- local user, host, password = nodeprep(params.user), nameprep(params.host), params.password;
+ local user, host, password = nodeprep(params.user, true), nameprep(params.host), params.password;
if not user then
return false, "invalid-username";
elseif not host then
@@ -200,7 +210,7 @@ local function getpid()
return false, "pidfile-read-failed", err;
end
- local locked, err = lfs.lock(file, "w");
+ local locked, err = lfs.lock(file, "w"); -- luacheck: ignore 211/err
if locked then
file:close();
return false, "pidfile-not-locked";
@@ -217,7 +227,7 @@ local function getpid()
end
local function isrunning()
- local ok, pid, err = getpid();
+ local ok, pid, err = getpid(); -- luacheck: ignore 211/err
if not ok then
if pid == "pidfile-read-failed" or pid == "pidfile-not-locked" then
-- Report as not running, since we can't open the pidfile
@@ -229,7 +239,8 @@ local function isrunning()
return true, signal.kill(pid, 0) == 0;
end
-local function start(source_dir)
+local function start(source_dir, lua)
+ lua = lua and lua .. " " or "";
local ok, ret = isrunning();
if not ok then
return ok, ret;
@@ -238,9 +249,9 @@ local function start(source_dir)
return false, "already-running";
end
if not source_dir then
- os.execute("./prosody");
+ os.execute(lua .. "./prosody -D");
else
- os.execute(source_dir.."/../../bin/prosody");
+ os.execute(lua .. source_dir.."/../../bin/prosody -D");
end
return true;
end
@@ -277,10 +288,36 @@ local function reload()
return true;
end
+local function get_path_custom_plugins(plugin_paths)
+ -- I'm considering that the custom plugins' path is the first one at prosody.paths.plugins
+ -- luacheck: ignore 512
+ for path in plugin_paths:gmatch("[^;]+") do
+ return path;
+ end
+end
+
+local function call_luarocks(mod, operation)
+ local dir = get_path_custom_plugins(prosody.paths.plugins);
+ if operation == "install" then
+ show_message("Installing %s at %s", mod, dir);
+ elseif operation == "remove" then
+ show_message("Removing %s from %s", mod, dir);
+ end
+ if operation == "list" then
+ os.execute("luarocks list --tree='"..dir.."'")
+ else
+ os.execute("luarocks --tree='"..dir.."' --server='http://localhost/' "..operation.." "..mod);
+ end
+ if operation == "install" then
+ show_module_configuration_help(mod);
+ end
+end
+
return {
show_message = show_message;
show_warning = show_message;
show_usage = show_usage;
+ show_module_configuration_help = show_module_configuration_help;
getchar = getchar;
getline = getline;
getpass = getpass;
@@ -296,4 +333,6 @@ return {
start = start;
stop = stop;
reload = reload;
+ get_path_custom_plugins = get_path_custom_plugins;
+ call_luarocks = call_luarocks;
};
diff --git a/util/pubsub.lua b/util/pubsub.lua
index 1674b9a7..cfac7a68 100644
--- a/util/pubsub.lua
+++ b/util/pubsub.lua
@@ -1,5 +1,6 @@
local events = require "util.events";
local cache = require "util.cache";
+local errors = require "util.error";
local service_mt = {};
@@ -280,7 +281,8 @@ function service:set_affiliation(node, actor, jid, affiliation) --> ok, err
node_obj.affiliations[jid] = affiliation;
if self.config.nodestore then
- local ok, err = save_node_to_store(self, node_obj);
+ -- TODO pass the error from storage to caller eg wrapped in an util.error
+ local ok, err = save_node_to_store(self, node_obj); -- luacheck: ignore 211/err
if not ok then
node_obj.affiliations[jid] = old_affiliation;
return ok, "internal-server-error";
@@ -344,7 +346,8 @@ function service:add_subscription(node, actor, jid, options) --> ok, err
end
if self.config.nodestore then
- local ok, err = save_node_to_store(self, node_obj);
+ -- TODO pass the error from storage to caller eg wrapped in an util.error
+ local ok, err = save_node_to_store(self, node_obj); -- luacheck: ignore 211/err
if not ok then
node_obj.subscribers[jid] = old_subscription;
self.subscriptions[normal_jid][jid][node] = old_subscription and true or nil;
@@ -396,7 +399,8 @@ function service:remove_subscription(node, actor, jid) --> ok, err
end
if self.config.nodestore then
- local ok, err = save_node_to_store(self, node_obj);
+ -- TODO pass the error from storage to caller eg wrapped in an util.error
+ local ok, err = save_node_to_store(self, node_obj); -- luacheck: ignore 211/err
if not ok then
node_obj.subscribers[jid] = old_subscription;
self.subscriptions[normal_jid][jid][node] = old_subscription and true or nil;
@@ -454,7 +458,8 @@ function service:create(node, actor, options) --> ok, err
};
if self.config.nodestore then
- local ok, err = save_node_to_store(self, self.nodes[node]);
+ -- TODO pass the error from storage to caller eg wrapped in an util.error
+ local ok, err = save_node_to_store(self, self.nodes[node]); -- luacheck: ignore 211/err
if not ok then
self.nodes[node] = nil;
return ok, "internal-server-error";
@@ -511,7 +516,7 @@ local function check_preconditions(node_config, required_config)
end
for config_field, value in pairs(required_config) do
if node_config[config_field] ~= value then
- return false;
+ return false, config_field;
end
end
return true;
@@ -547,8 +552,13 @@ function service:publish(node, actor, id, item, requested_config) --> ok, err
node_obj = self.nodes[node];
elseif requested_config and not requested_config._defaults_only then
-- Check that node has the requested config before we publish
- if not check_preconditions(node_obj.config, requested_config) then
- return false, "precondition-not-met";
+ local ok, field = check_preconditions(node_obj.config, requested_config);
+ if not ok then
+ local err = errors.new({
+ type = "cancel", condition = "conflict", text = "Field does not match: "..field;
+ });
+ err.pubsub_condition = "precondition-not-met";
+ return false, err;
end
end
if not self.config.itemcheck(item) then
@@ -768,7 +778,8 @@ function service:set_node_config(node, actor, new_config) --> ok, err
node_obj.config = new_config;
if self.config.nodestore then
- local ok, err = save_node_to_store(self, node_obj);
+ -- TODO pass the error from storage to caller eg wrapped in an util.error
+ local ok, err = save_node_to_store(self, node_obj); -- luacheck: ignore 211/err
if not ok then
node_obj.config = old_config;
return ok, "internal-server-error";
diff --git a/util/queue.lua b/util/queue.lua
index 728e905f..66ed098b 100644
--- a/util/queue.lua
+++ b/util/queue.lua
@@ -52,18 +52,20 @@ local function new(size, allow_wrapping)
return t[tail];
end;
items = function (self)
- --luacheck: ignore 431/t
- return function (t, pos)
- if pos >= t:count() then
+ return function (_, pos)
+ if pos >= items then
return nil;
end
local read_pos = tail + pos;
- if read_pos > t.size then
+ if read_pos > self.size then
read_pos = (read_pos%size);
end
- return pos+1, t._items[read_pos];
+ return pos+1, t[read_pos];
end, self, 0;
end;
+ consume = function (self)
+ return self.pop, self;
+ end;
};
end
diff --git a/util/sasl/scram.lua b/util/sasl/scram.lua
index 043f328b..e2ce00f5 100644
--- a/util/sasl/scram.lua
+++ b/util/sasl/scram.lua
@@ -14,9 +14,7 @@
local s_match = string.match;
local type = type
local base64 = require "util.encodings".base64;
-local hmac_sha1 = require "util.hashes".hmac_sha1;
-local sha1 = require "util.hashes".sha1;
-local Hi = require "util.hashes".scram_Hi_sha1;
+local hashes = require "util.hashes";
local generate_uuid = require "util.uuid".generate;
local saslprep = require "util.encodings".stringprep.saslprep;
local nodeprep = require "util.encodings".stringprep.nodeprep;
@@ -99,20 +97,22 @@ local function hashprep(hashname)
return hashname:lower():gsub("-", "_");
end
-local function getAuthenticationDatabaseSHA1(password, salt, iteration_count)
- if type(password) ~= "string" or type(salt) ~= "string" or type(iteration_count) ~= "number" then
- return false, "inappropriate argument types"
- end
- if iteration_count < 4096 then
- log("warn", "Iteration count < 4096 which is the suggested minimum according to RFC 5802.")
+local function get_scram_hasher(H, HMAC, Hi)
+ return function (password, salt, iteration_count)
+ if type(password) ~= "string" or type(salt) ~= "string" or type(iteration_count) ~= "number" then
+ return false, "inappropriate argument types"
+ end
+ if iteration_count < 4096 then
+ log("warn", "Iteration count < 4096 which is the suggested minimum according to RFC 5802.")
+ end
+ local salted_password = Hi(password, salt, iteration_count);
+ local stored_key = H(HMAC(salted_password, "Client Key"))
+ local server_key = HMAC(salted_password, "Server Key");
+ return true, stored_key, server_key
end
- local salted_password = Hi(password, salt, iteration_count);
- local stored_key = sha1(hmac_sha1(salted_password, "Client Key"))
- local server_key = hmac_sha1(salted_password, "Server Key");
- return true, stored_key, server_key
end
-local function scram_gen(hash_name, H_f, HMAC_f)
+local function scram_gen(hash_name, H_f, HMAC_f, get_auth_db)
local profile_name = "scram_" .. hashprep(hash_name);
local function scram_hash(self, message)
local support_channel_binding = false;
@@ -125,6 +125,7 @@ local function scram_gen(hash_name, H_f, HMAC_f)
local client_first_message = message;
-- TODO: fail if authzid is provided, since we don't support them yet
+ -- luacheck: ignore 211/authzid
local gs2_header, gs2_cbind_flag, gs2_cbind_name, authzid, client_first_message_bare, username, clientnonce
= s_match(client_first_message, "^(([pny])=?([^,]*),([^,]*),)(m?=?[^,]*,?n=([^,]*),r=([^,]*),?.*)$");
@@ -177,7 +178,7 @@ local function scram_gen(hash_name, H_f, HMAC_f)
iteration_count = default_i;
local succ;
- succ, stored_key, server_key = getAuthenticationDatabaseSHA1(password, salt, iteration_count);
+ succ, stored_key, server_key = get_auth_db(password, salt, iteration_count);
if not succ then
log("error", "Generating authentication database failed. Reason: %s", stored_key);
return "failure", "temporary-auth-failure";
@@ -190,7 +191,7 @@ local function scram_gen(hash_name, H_f, HMAC_f)
end
local nonce = clientnonce .. generate_uuid();
- local server_first_message = "r="..nonce..",s="..base64.encode(salt)..",i="..iteration_count;
+ local server_first_message = ("r=%s,s=%s,i=%d"):format(nonce, base64.encode(salt), iteration_count);
self.state = {
gs2_header = gs2_header;
gs2_cbind_name = gs2_cbind_name;
@@ -247,22 +248,28 @@ local function scram_gen(hash_name, H_f, HMAC_f)
return scram_hash;
end
+local auth_db_getters = {}
local function init(registerMechanism)
- local function registerSCRAMMechanism(hash_name, hash, hmac_hash)
+ local function registerSCRAMMechanism(hash_name, hash, hmac_hash, pbkdf2)
+ local get_auth_db = get_scram_hasher(hash, hmac_hash, pbkdf2);
+ auth_db_getters[hash_name] = get_auth_db;
registerMechanism("SCRAM-"..hash_name,
{"plain", "scram_"..(hashprep(hash_name))},
- scram_gen(hash_name:lower(), hash, hmac_hash));
+ scram_gen(hash_name:lower(), hash, hmac_hash, get_auth_db));
-- register channel binding equivalent
registerMechanism("SCRAM-"..hash_name.."-PLUS",
{"plain", "scram_"..(hashprep(hash_name))},
- scram_gen(hash_name:lower(), hash, hmac_hash), {"tls-unique"});
+ scram_gen(hash_name:lower(), hash, hmac_hash, get_auth_db), {"tls-unique"});
end
- registerSCRAMMechanism("SHA-1", sha1, hmac_sha1);
+ registerSCRAMMechanism("SHA-1", hashes.sha1, hashes.hmac_sha1, hashes.pbkdf2_hmac_sha1);
+ registerSCRAMMechanism("SHA-256", hashes.sha256, hashes.hmac_sha256, hashes.pbkdf2_hmac_sha256);
end
return {
- getAuthenticationDatabaseSHA1 = getAuthenticationDatabaseSHA1;
+ get_hash = get_scram_hasher;
+ hashers = auth_db_getters;
+ getAuthenticationDatabaseSHA1 = get_scram_hasher(hashes.sha1, hashes.hmac_sha1, hashes.pbkdf2_hmac_sha1); -- COMPAT
init = init;
}
diff --git a/util/serialization.lua b/util/serialization.lua
index 5121a9f9..d70e92ba 100644
--- a/util/serialization.lua
+++ b/util/serialization.lua
@@ -16,22 +16,18 @@ local s_char = string.char;
local s_match = string.match;
local t_concat = table.concat;
+local to_hex = require "util.hex".to;
+
local pcall = pcall;
local envload = require"util.envload".envload;
local pos_inf, neg_inf = math.huge, -math.huge;
--- luacheck: ignore 143/math
local m_type = math.type or function (n)
return n % 1 == 0 and n <= 9007199254740992 and n >= -9007199254740992 and "integer" or "float";
end;
-local char_to_hex = {};
-for i = 0,255 do
- char_to_hex[s_char(i)] = s_format("%02x", i);
-end
-
-local function to_hex(s)
- return (s_gsub(s, ".", char_to_hex));
+local function rawpairs(t)
+ return next, t, nil;
end
local function fatal_error(obj, why)
@@ -123,6 +119,7 @@ local function new(opt)
local freeze = opt.freeze;
local maxdepth = opt.maxdepth or 127;
local multirefs = opt.multiref;
+ local table_pairs = opt.table_iterator or rawpairs;
-- serialize one table, recursively
-- t - table being serialized
@@ -164,7 +161,9 @@ local function new(opt)
local indent = s_rep(indentwith, d);
local numkey = 1;
local ktyp, vtyp;
- for k,v in next,t do
+ local had_items = false;
+ for k,v in table_pairs(t) do
+ had_items = true;
o[l], l = itemstart, l + 1;
o[l], l = indent, l + 1;
ktyp, vtyp = type(k), type(v);
@@ -195,14 +194,10 @@ local function new(opt)
else
o[l], l = ser(v), l + 1;
end
- -- last item?
- if next(t, k) ~= nil then
- o[l], l = itemsep, l + 1;
- else
- o[l], l = itemlast, l + 1;
- end
+ o[l], l = itemsep, l + 1;
end
- if next(t) ~= nil then
+ if had_items then
+ o[l - 1] = itemlast;
o[l], l = s_rep(indentwith, d-1), l + 1;
end
o[l], l = tend, l +1;
diff --git a/util/session.lua b/util/session.lua
index b2a726ce..25b22faf 100644
--- a/util/session.lua
+++ b/util/session.lua
@@ -4,12 +4,13 @@ local logger = require "util.logger";
local function new_session(typ)
local session = {
type = typ .. "_unauthed";
+ base_type = typ;
};
return session;
end
local function set_id(session)
- local id = session.type .. tostring(session):match("%x+$"):lower();
+ local id = session.base_type .. tostring(session):match("%x+$"):lower();
session.id = id;
return session;
end
@@ -30,7 +31,7 @@ local function set_send(session)
local conn = session.conn;
if not conn then
function session.send(data)
- session.log("debug", "Discarding data sent to unconnected session: %s", tostring(data));
+ session.log("debug", "Discarding data sent to unconnected session: %s", data);
return false;
end
return session;
@@ -46,7 +47,7 @@ local function set_send(session)
if t then
local ret, err = w(conn, t);
if not ret then
- session.log("debug", "Error writing to connection: %s", tostring(err));
+ session.log("debug", "Error writing to connection: %s", err);
return false, err;
end
end
diff --git a/util/set.lua b/util/set.lua
index 02fabc6a..827a9158 100644
--- a/util/set.lua
+++ b/util/set.lua
@@ -8,6 +8,7 @@
local ipairs, pairs, setmetatable, next, tostring =
ipairs, pairs, setmetatable, next, tostring;
+local getmetatable = getmetatable;
local t_concat = table.concat;
local _ENV = nil;
@@ -146,6 +147,11 @@ function set_mt.__div(set, func)
return new_set;
end
function set_mt.__eq(set1, set2)
+ if getmetatable(set1) ~= set_mt or getmetatable(set2) ~= set_mt then
+ -- Lua 5.3+ calls this if both operands are tables, even if metatables differ
+ return false;
+ end
+
set1, set2 = set1._items, set2._items;
for item in pairs(set1) do
if not set2[item] then
diff --git a/util/sql.lua b/util/sql.lua
index 00c7b57f..9d1c86ca 100644
--- a/util/sql.lua
+++ b/util/sql.lua
@@ -201,31 +201,31 @@ function engine:_transaction(func, ...)
if not ok then return ok, err; end
end
--assert(not self.__transaction, "Recursive transactions not allowed");
- log("debug", "SQL transaction begin [%s]", tostring(func));
+ log("debug", "SQL transaction begin [%s]", func);
self.__transaction = true;
local success, a, b, c = xpcall(func, handleerr, ...);
self.__transaction = nil;
if success then
- log("debug", "SQL transaction success [%s]", tostring(func));
+ log("debug", "SQL transaction success [%s]", func);
local ok, err = self.conn:commit();
-- LuaDBI doesn't actually return an error message here, just a boolean
if not ok then return ok, err or "commit failed"; end
return success, a, b, c;
else
- log("debug", "SQL transaction failure [%s]: %s", tostring(func), a.err);
+ log("debug", "SQL transaction failure [%s]: %s", func, a.err);
if self.conn then self.conn:rollback(); end
return success, a.err;
end
end
function engine:transaction(...)
- local ok, ret = self:_transaction(...);
+ local ok, ret, b, c = self:_transaction(...);
if not ok then
local conn = self.conn;
if not conn or not conn:ping() then
log("debug", "Database connection was closed. Will reconnect and retry.");
self.conn = nil;
- log("debug", "Retrying SQL transaction [%s]", tostring((...)));
- ok, ret = self:_transaction(...);
+ log("debug", "Retrying SQL transaction [%s]", (...));
+ ok, ret, b, c = self:_transaction(...);
log("debug", "SQL transaction retry %s", ok and "succeeded" or "failed");
else
log("debug", "SQL connection is up, so not retrying");
@@ -234,7 +234,7 @@ function engine:transaction(...)
log("error", "Error in SQL transaction: %s", ret);
end
end
- return ok, ret;
+ return ok, ret, b, c;
end
function engine:_create_index(index)
local sql = "CREATE INDEX \""..index.name.."\" ON \""..index.table.."\" (";
@@ -335,6 +335,9 @@ function engine:set_encoding() -- to UTF-8
local ok, actual_charset = self:transaction(function ()
return self:select"SHOW SESSION VARIABLES LIKE 'character_set_client'";
end);
+ if not ok then
+ return false, "Failed to detect connection encoding";
+ end
local charset_ok = true;
for row in actual_charset do
if row[2] ~= charset then
diff --git a/util/stanza.lua b/util/stanza.lua
index a90d56b3..f5cd5668 100644
--- a/util/stanza.lua
+++ b/util/stanza.lua
@@ -98,7 +98,7 @@ function stanza_mt:query(xmlns)
end
function stanza_mt:body(text, attr)
- return self:tag("body", attr):text(text);
+ return self:text_tag("body", text, attr);
end
function stanza_mt:text_tag(name, text, attr, namespaces)
@@ -270,6 +270,34 @@ function stanza_mt:find(path)
until not self
end
+local function _clone(stanza, only_top)
+ local attr, tags = {}, {};
+ for k,v in pairs(stanza.attr) do attr[k] = v; end
+ local old_namespaces, namespaces = stanza.namespaces;
+ if old_namespaces then
+ namespaces = {};
+ for k,v in pairs(old_namespaces) do namespaces[k] = v; end
+ end
+ local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags };
+ if not only_top then
+ for i=1,#stanza do
+ local child = stanza[i];
+ if child.name then
+ child = _clone(child);
+ t_insert(tags, child);
+ end
+ t_insert(new, child);
+ end
+ end
+ return setmetatable(new, stanza_mt);
+end
+
+local function clone(stanza, only_top)
+ if not is_stanza(stanza) then
+ error("bad argument to clone: expected stanza, got "..type(stanza));
+ end
+ return _clone(stanza, only_top);
+end
local escape_table = { ["'"] = "&apos;", ["\""] = "&quot;", ["<"] = "&lt;", [">"] = "&gt;", ["&"] = "&amp;" };
local function xml_escape(str) return (s_gsub(str, "['&<>\"]", escape_table)); end
@@ -310,11 +338,8 @@ function stanza_mt.__tostring(t)
end
function stanza_mt.top_tag(t)
- local attr_string = "";
- if t.attr then
- for k, v in pairs(t.attr) do if type(k) == "string" then attr_string = attr_string .. s_format(" %s='%s'", k, xml_escape(tostring(v))); end end
- end
- return s_format("<%s%s>", t.name, attr_string);
+ local top_tag_clone = clone(t, true);
+ return tostring(top_tag_clone):sub(1,-3)..">";
end
function stanza_mt.get_text(t)
@@ -388,50 +413,32 @@ local function deserialize(serialized)
end
end
-local function _clone(stanza)
- local attr, tags = {}, {};
- for k,v in pairs(stanza.attr) do attr[k] = v; end
- local old_namespaces, namespaces = stanza.namespaces;
- if old_namespaces then
- namespaces = {};
- for k,v in pairs(old_namespaces) do namespaces[k] = v; end
- end
- local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags };
- for i=1,#stanza do
- local child = stanza[i];
- if child.name then
- child = _clone(child);
- t_insert(tags, child);
- end
- t_insert(new, child);
- end
- return setmetatable(new, stanza_mt);
-end
-
-local function clone(stanza)
- if not is_stanza(stanza) then
- error("bad argument to clone: expected stanza, got "..type(stanza));
- end
- return _clone(stanza);
-end
-
local function message(attr, body)
if not body then
return new_stanza("message", attr);
else
- return new_stanza("message", attr):tag("body"):text(body):up();
+ return new_stanza("message", attr):text_tag("body", body);
end
end
local function iq(attr)
- if not (attr and attr.id) then
+ if not attr then
+ error("iq stanzas require id and type attributes");
+ end
+ if not attr.id then
error("iq stanzas require an id attribute");
end
+ if not attr.type then
+ error("iq stanzas require a type attribute");
+ end
return new_stanza("iq", attr);
end
local function reply(orig)
+ if not is_stanza(orig) then
+ error("bad argument to reply: expected stanza, got "..type(orig));
+ end
return new_stanza(orig.name,
- orig.attr and {
+ {
to = orig.attr.from,
from = orig.attr.to,
id = orig.attr.id,
@@ -440,12 +447,23 @@ local function reply(orig)
end
local xmpp_stanzas_attr = { xmlns = xmlns_stanzas };
-local function error_reply(orig, error_type, condition, error_message)
+local function error_reply(orig, error_type, condition, error_message, error_by)
+ if not is_stanza(orig) then
+ error("bad argument to error_reply: expected stanza, got "..type(orig));
+ elseif orig.attr.type == "error" then
+ error("bad argument to error_reply: got stanza of type error which must not be replied to");
+ end
local t = reply(orig);
t.attr.type = "error";
- t:tag("error", {type = error_type}) --COMPAT: Some day xmlns:stanzas goes here
+ if t.attr.from == error_by then
+ error_by = nil;
+ end
+ if type(error_type) == "table" then -- an util.error or similar object
+ error_type, condition, error_message = error_type.type, error_type.condition, error_type.text;
+ end
+ t:tag("error", {type = error_type, by = error_by}) --COMPAT: Some day xmlns:stanzas goes here
:tag(condition, xmpp_stanzas_attr):up();
- if error_message then t:tag("text", xmpp_stanzas_attr):text(error_message):up(); end
+ if error_message then t:text_tag("text", error_message, xmpp_stanzas_attr); end
return t; -- stanza ready for adding app-specific errors
end
diff --git a/util/startup.lua b/util/startup.lua
index 24ed6026..d45855f2 100644
--- a/util/startup.lua
+++ b/util/startup.lua
@@ -5,8 +5,10 @@ local startup = {};
local prosody = { events = require "util.events".new() };
local logger = require "util.logger";
local log = logger.init("startup");
+local parse_args = require "util.argparse".parse;
local config = require "core.configmanager";
+local config_warnings;
local dependencies = require "util.dependencies";
@@ -16,55 +18,10 @@ local short_params = { D = "daemonize", F = "no-daemonize" };
local value_params = { config = true };
function startup.parse_args()
- local parsed_opts = {};
- prosody.opts = parsed_opts;
-
- if #arg == 0 then
- return;
- end
- while true do
- local raw_param = arg[1];
- if not raw_param then
- break;
- end
-
- local prefix = raw_param:match("^%-%-?");
- if not prefix then
- break;
- elseif prefix == "--" and raw_param == "--" then
- table.remove(arg, 1);
- break;
- end
- local param = table.remove(arg, 1):sub(#prefix+1);
- if #param == 1 then
- param = short_params[param];
- end
-
- if not param then
- print("Unknown command-line option: "..tostring(param));
- print("Perhaps you meant to use prosodyctl instead?");
- os.exit(1);
- end
-
- local param_k, param_v;
- if value_params[param] then
- param_k, param_v = param, table.remove(arg, 1);
- if not param_v then
- print("Expected a value to follow command-line option: "..raw_param);
- os.exit(1);
- end
- else
- param_k, param_v = param:match("^([^=]+)=(.+)$");
- if not param_k then
- if param:match("^no%-") then
- param_k, param_v = param:sub(4), false;
- else
- param_k, param_v = param, true;
- end
- end
- end
- parsed_opts[param_k] = param_v;
- end
+ prosody.opts = parse_args(arg, {
+ short_params = short_params,
+ value_params = value_params,
+ });
end
function startup.read_config()
@@ -119,6 +76,8 @@ function startup.read_config()
print("**************************");
print("");
os.exit(1);
+ elseif err and #err > 0 then
+ config_warnings = err;
end
prosody.config_loaded = true;
end
@@ -151,8 +110,13 @@ function startup.init_logging()
end);
end
-function startup.log_dependency_warnings()
+function startup.log_startup_warnings()
dependencies.log_warnings();
+ if config_warnings then
+ for _, warning in ipairs(config_warnings) do
+ log("warn", "Configuration warning: %s", warning);
+ end
+ end
end
function startup.sanity_check()
@@ -274,8 +238,8 @@ end
function startup.setup_plugindir()
local custom_plugin_paths = config.get("*", "plugin_paths");
+ local path_sep = package.config:sub(3,3);
if custom_plugin_paths then
- local path_sep = package.config:sub(3,3);
-- path1;path2;path3;defaultpath...
-- luacheck: ignore 111
CFG_PLUGINDIR = table.concat(custom_plugin_paths, path_sep)..path_sep..(CFG_PLUGINDIR or "plugins");
@@ -283,6 +247,19 @@ function startup.setup_plugindir()
end
end
+function startup.setup_plugin_install_path()
+ local installer_plugin_path = config.get("*", "installer_plugin_path") or "custom_plugins";
+ local path_sep = package.config:sub(3,3);
+ -- TODO Figure out what this should be relative to, because CWD could be anywhere
+ installer_plugin_path = config.resolve_relative_path(require "lfs".currentdir(), installer_plugin_path);
+ -- TODO Can probably move directory creation to the install command
+ require "lfs".mkdir(installer_plugin_path);
+ require"util.paths".complement_lua_path(installer_plugin_path);
+ -- luacheck: ignore 111
+ CFG_PLUGINDIR = installer_plugin_path..path_sep..(CFG_PLUGINDIR or "plugins");
+ prosody.paths.plugins = CFG_PLUGINDIR;
+end
+
function startup.chdir()
if prosody.installed then
local lfs = require "lfs";
@@ -304,9 +281,9 @@ function startup.add_global_prosody_functions()
local ok, level, err = config.load(prosody.config_file);
if not ok then
if level == "parser" then
- log("error", "There was an error parsing the configuration file: %s", tostring(err));
+ log("error", "There was an error parsing the configuration file: %s", err);
elseif level == "file" then
- log("error", "Couldn't read the config file when trying to reload: %s", tostring(err));
+ log("error", "Couldn't read the config file when trying to reload: %s", err);
end
else
prosody.events.fire_event("config-reloaded", {
@@ -480,7 +457,7 @@ function startup.switch_user()
print("Warning: Couldn't switch to Prosody user/group '"..tostring(desired_user).."'/'"..tostring(desired_group).."': "..tostring(err));
else
-- Make sure the Prosody user can read the config
- local conf, err, errno = io.open(prosody.config_file);
+ local conf, err, errno = io.open(prosody.config_file); --luacheck: ignore 211/errno
if conf then
conf:close();
else
@@ -568,18 +545,20 @@ end
-- prosodyctl only
function startup.prosodyctl()
+ prosody.process_type = "prosodyctl";
startup.parse_args();
startup.init_global_state();
startup.read_config();
startup.force_console_logging();
startup.init_logging();
startup.setup_plugindir();
+ -- startup.setup_plugin_install_path();
startup.setup_datadir();
startup.chdir();
startup.read_version();
startup.switch_user();
startup.check_dependencies();
- startup.log_dependency_warnings();
+ startup.log_startup_warnings();
startup.check_unwriteable();
startup.load_libraries();
startup.init_http_client();
@@ -589,6 +568,7 @@ end
function startup.prosody()
-- These actions are in a strict order, as many depend on
-- previous steps to have already been performed
+ prosody.process_type = "prosody";
startup.parse_args();
startup.init_global_state();
startup.read_config();
@@ -600,12 +580,13 @@ function startup.prosody()
startup.init_logging();
startup.load_libraries();
startup.setup_plugindir();
+ -- startup.setup_plugin_install_path();
startup.setup_datadir();
startup.chdir();
startup.add_global_prosody_functions();
startup.read_version();
startup.log_greeting();
- startup.log_dependency_warnings();
+ startup.log_startup_warnings();
startup.load_secondary_libraries();
startup.init_http_client();
startup.init_data_store();
diff --git a/util/statistics.lua b/util/statistics.lua
index 39954652..0ec88e21 100644
--- a/util/statistics.lua
+++ b/util/statistics.lua
@@ -57,12 +57,14 @@ local function new_registry(config)
end;
end;
rate = function (name)
- local since, n = time(), 0;
+ local since, n, total = time(), 0, 0;
registry[name..":rate"] = function ()
+ total = total + n;
local t = time();
local stats = {
rate = n/(t-since);
count = n;
+ total = total;
};
since, n = t, 0;
return "rate", stats.rate, stats;
diff --git a/util/termcolours.lua b/util/termcolours.lua
index 829d84af..2c13d0ff 100644
--- a/util/termcolours.lua
+++ b/util/termcolours.lua
@@ -83,7 +83,7 @@ end
setmetatable(stylemap, { __index = function(_, style)
if type(style) == "string" and style:find("%x%x%x%x%x%x") == 1 then
local g = style:sub(7) == " background" and "48;5;" or "38;5;";
- return g .. color(hex2rgb(style));
+ return format("%s%d", g, color(hex2rgb(style)));
end
end } );
diff --git a/util/x509.lua b/util/x509.lua
index 15cc4d3c..342dafde 100644
--- a/util/x509.lua
+++ b/util/x509.lua
@@ -20,9 +20,12 @@
local nameprep = require "util.encodings".stringprep.nameprep;
local idna_to_ascii = require "util.encodings".idna.to_ascii;
+local idna_to_unicode = require "util.encodings".idna.to_unicode;
local base64 = require "util.encodings".base64;
local log = require "util.logger".init("x509");
+local mt = require "util.multitable";
local s_format = string.format;
+local ipairs = ipairs;
local _ENV = nil;
-- luacheck: std none
@@ -216,6 +219,60 @@ local function verify_identity(host, service, cert)
return false
end
+-- TODO Support other SANs
+local function get_identities(cert) --> map of names to sets of services
+ if cert.setencode then
+ cert:setencode("utf8");
+ end
+
+ local names = mt.new();
+
+ local ext = cert:extensions();
+ local sans = ext[oid_subjectaltname];
+ if sans then
+ if sans["dNSName"] then -- Valid for any service
+ for _, name in ipairs(sans["dNSName"]) do
+ name = idna_to_unicode(nameprep(name));
+ if name then
+ names:set(name, "*", true);
+ end
+ end
+ end
+ if sans[oid_xmppaddr] then
+ for _, name in ipairs(sans[oid_xmppaddr]) do
+ name = nameprep(name);
+ if name then
+ names:set(name, "xmpp-client", true);
+ names:set(name, "xmpp-server", true);
+ end
+ end
+ end
+ if sans[oid_dnssrv] then
+ for _, srvname in ipairs(sans[oid_dnssrv]) do
+ local srv, name = srvname:match("^_([^.]+)%.(.*)");
+ if srv then
+ name = nameprep(name);
+ if name then
+ names:set(name, srv, true);
+ end
+ end
+ end
+ end
+ end
+
+ local subject = cert:subject();
+ for i = 1, #subject do
+ local dn = subject[i];
+ if dn.oid == oid_commonname then
+ local name = nameprep(dn.value);
+ if name and idna_to_ascii(name) then
+ names:set(name, "*", true);
+ end
+ end
+ end
+ return names.data;
+end
+
local pat = "%-%-%-%-%-BEGIN ([A-Z ]+)%-%-%-%-%-\r?\n"..
"([0-9A-Za-z+/=\r\n]*)\r?\n%-%-%-%-%-END %1%-%-%-%-%-";
@@ -237,6 +294,7 @@ end
return {
verify_identity = verify_identity;
+ get_identities = get_identities;
pem2der = pem2der;
der2pem = der2pem;
};
diff --git a/util/xmppstream.lua b/util/xmppstream.lua
index 58cbd18e..6aa1def3 100644
--- a/util/xmppstream.lua
+++ b/util/xmppstream.lua
@@ -64,6 +64,8 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
local stream_default_ns = stream_callbacks.default_ns;
+ local stream_lang = "en";
+
local stack = {};
local chardata, stanza = {};
local stanza_size = 0;
@@ -101,6 +103,7 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
if session.notopen then
if tagname == stream_tag then
non_streamns_depth = 0;
+ stream_lang = attr["xml:lang"] or stream_lang;
if cb_streamopened then
if lxp_supports_bytecount then
cb_handleprogress(stanza_size);
@@ -178,6 +181,9 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
cb_handleprogress(stanza_size);
end
stanza_size = 0;
+ if stanza.attr["xml:lang"] == nil then
+ stanza.attr["xml:lang"] = stream_lang;
+ end
if tagname ~= stream_error_tag then
cb_handlestanza(session, stanza);
else