aboutsummaryrefslogtreecommitdiffstats
path: root/util
diff options
context:
space:
mode:
Diffstat (limited to 'util')
-rw-r--r--util/adhoc.lua4
-rw-r--r--util/argparse.lua58
-rw-r--r--util/array.lua5
-rw-r--r--util/async.lua16
-rw-r--r--util/bit53.lua7
-rw-r--r--util/bitcompat.lua32
-rw-r--r--util/datamanager.lua5
-rw-r--r--util/dependencies.lua23
-rw-r--r--util/error.lua60
-rw-r--r--util/format.lua27
-rw-r--r--util/hashring.lua88
-rw-r--r--util/hmac.lua9
-rw-r--r--util/http.lua22
-rw-r--r--util/import.lua2
-rw-r--r--util/ip.lua8
-rw-r--r--util/iterators.lua6
-rw-r--r--util/jid.lua12
-rw-r--r--util/jwt.lua50
-rw-r--r--util/mercurial.lua2
-rw-r--r--util/multitable.lua2
-rw-r--r--util/paths.lua16
-rw-r--r--util/pluginloader.lua3
-rw-r--r--util/promise.lua3
-rw-r--r--util/prosodyctl.lua51
-rw-r--r--util/pubsub.lua27
-rw-r--r--util/queue.lua12
-rw-r--r--util/rsm.lua23
-rw-r--r--util/sasl.lua1
-rw-r--r--util/sasl/digest-md5.lua251
-rw-r--r--util/sasl/scram.lua53
-rw-r--r--util/serialization.lua27
-rw-r--r--util/session.lua7
-rw-r--r--util/set.lua6
-rw-r--r--util/sql.lua17
-rw-r--r--util/stanza.lua126
-rw-r--r--util/startup.lua93
-rw-r--r--util/statistics.lua4
-rw-r--r--util/termcolours.lua2
-rw-r--r--util/x509.lua58
-rw-r--r--util/xmppstream.lua6
40 files changed, 754 insertions, 470 deletions
diff --git a/util/adhoc.lua b/util/adhoc.lua
index d81b8242..19c94c34 100644
--- a/util/adhoc.lua
+++ b/util/adhoc.lua
@@ -2,7 +2,7 @@
local function new_simple_form(form, result_handler)
return function(self, data, state)
- if state then
+ if state or data.form then
if data.action == "cancel" then
return { status = "canceled" };
end
@@ -16,7 +16,7 @@ end
local function new_initial_data_form(form, initial_data, result_handler)
return function(self, data, state)
- if state then
+ if state or data.form then
if data.action == "cancel" then
return { status = "canceled" };
end
diff --git a/util/argparse.lua b/util/argparse.lua
new file mode 100644
index 00000000..928fc3eb
--- /dev/null
+++ b/util/argparse.lua
@@ -0,0 +1,58 @@
+local function parse(arg, config)
+ local short_params = config and config.short_params or {};
+ local value_params = config and config.value_params or {};
+
+ local parsed_opts = {};
+
+ if #arg == 0 then
+ return parsed_opts;
+ end
+ while true do
+ local raw_param = arg[1];
+ if not raw_param then
+ break;
+ end
+
+ local prefix = raw_param:match("^%-%-?");
+ if not prefix then
+ break;
+ elseif prefix == "--" and raw_param == "--" then
+ table.remove(arg, 1);
+ break;
+ end
+ local param = table.remove(arg, 1):sub(#prefix+1);
+ if #param == 1 and short_params then
+ param = short_params[param];
+ end
+
+ if not param then
+ print("Unknown command-line option: "..tostring(param));
+ print("Perhaps you meant to use prosodyctl instead?");
+ os.exit(1);
+ end
+
+ local param_k, param_v;
+ if value_params[param] then
+ param_k, param_v = param, table.remove(arg, 1);
+ if not param_v then
+ print("Expected a value to follow command-line option: "..raw_param);
+ os.exit(1);
+ end
+ else
+ param_k, param_v = param:match("^([^=]+)=(.+)$");
+ if not param_k then
+ if param:match("^no%-") then
+ param_k, param_v = param:sub(4), false;
+ else
+ param_k, param_v = param, true;
+ end
+ end
+ end
+ parsed_opts[param_k] = param_v;
+ end
+ return parsed_opts;
+end
+
+return {
+ parse = parse;
+}
diff --git a/util/array.lua b/util/array.lua
index 0b60a4fd..32d2d6a5 100644
--- a/util/array.lua
+++ b/util/array.lua
@@ -10,6 +10,7 @@ local t_insert, t_sort, t_remove, t_concat
= table.insert, table.sort, table.remove, table.concat;
local setmetatable = setmetatable;
+local getmetatable = getmetatable;
local math_random = math.random;
local math_floor = math.floor;
local pairs, ipairs = pairs, ipairs;
@@ -40,6 +41,10 @@ function array_mt.__add(a1, a2)
end
function array_mt.__eq(a, b)
+ if getmetatable(a) ~= array_mt or getmetatable(b) ~= array_mt then
+ -- Lua 5.3+ calls this if both operands are tables, even if metatables differ
+ return false;
+ end
if #a == #b then
for i = 1, #a do
if a[i] ~= b[i] then
diff --git a/util/async.lua b/util/async.lua
index 20397785..d338071f 100644
--- a/util/async.lua
+++ b/util/async.lua
@@ -246,9 +246,25 @@ local function ready()
return pcall(checkthread);
end
+local function wait(promise)
+ local async_wait, async_done = waiter();
+ local ret, err = nil, nil;
+ promise:next(
+ function (r) ret = r; end,
+ function (e) err = e; end)
+ :finally(async_done);
+ async_wait();
+ if ret then
+ return ret;
+ else
+ return nil, err;
+ end
+end
+
return {
ready = ready;
waiter = waiter;
guarder = guarder;
runner = runner;
+ wait = wait;
};
diff --git a/util/bit53.lua b/util/bit53.lua
new file mode 100644
index 00000000..4b5c2e9c
--- /dev/null
+++ b/util/bit53.lua
@@ -0,0 +1,7 @@
+-- Only the operators needed by net.websocket.frames are provided at this point
+return {
+ band = function (a, b) return a & b end;
+ bor = function (a, b) return a | b end;
+ bxor = function (a, b) return a ~ b end;
+};
+
diff --git a/util/bitcompat.lua b/util/bitcompat.lua
new file mode 100644
index 00000000..454181af
--- /dev/null
+++ b/util/bitcompat.lua
@@ -0,0 +1,32 @@
+-- Compatibility layer for bitwise operations
+
+-- First try the bit32 lib
+-- Lua 5.3 has it with compat enabled
+-- Lua 5.2 has it by default
+if _G.bit32 then
+ return _G.bit32;
+else
+ -- Lua 5.1 may have it as a standalone module that can be installed
+ local ok, bitop = pcall(require, "bit32")
+ if ok then
+ return bitop;
+ end
+end
+
+do
+ -- Lua 5.3 and 5.4 would be able to use native infix operators
+ local ok, bitop = pcall(require, "util.bit53")
+ if ok then
+ return bitop;
+ end
+end
+
+do
+ -- Lastly, try the LuaJIT bitop library
+ local ok, bitop = pcall(require, "bit")
+ if ok then
+ return bitop;
+ end
+end
+
+error "No bit module found. See https://prosody.im/doc/depends#bitop";
diff --git a/util/datamanager.lua b/util/datamanager.lua
index 0d7060b7..26dede08 100644
--- a/util/datamanager.lua
+++ b/util/datamanager.lua
@@ -24,7 +24,7 @@ local t_concat = table.concat;
local envloadfile = require"util.envload".envloadfile;
local serialize = require "util.serialization".serialize;
local lfs = require "lfs";
--- Extract directory seperator from package.config (an undocumented string that comes with lua)
+-- Extract directory separator from package.config (an undocumented string that comes with lua)
local path_separator = assert ( package.config:match ( "^([^\n]+)" ) , "package.config not in standard form" )
local prosody = prosody;
@@ -157,7 +157,8 @@ end
local function atomic_store(filename, data)
local scratch = filename.."~";
- local f, ok, msg, errno;
+ local f, ok, msg, errno; -- luacheck: ignore errno
+ -- TODO return util.error with code=errno?
f, msg, errno = io_open(scratch, "w");
if not f then
diff --git a/util/dependencies.lua b/util/dependencies.lua
index 7c7b938e..ede8c6ac 100644
--- a/util/dependencies.lua
+++ b/util/dependencies.lua
@@ -13,7 +13,8 @@ if not softreq "luarocks.loader" then -- LuaRocks 2.x
softreq "luarocks.require"; -- LuaRocks <1.x
end
-local function missingdep(name, sources, msg)
+local function missingdep(name, sources, msg, err) -- luacheck: ignore err
+ -- TODO print something about the underlying error, useful for debugging
print("");
print("**************************");
print("Prosody was unable to find "..tostring(name));
@@ -44,25 +45,25 @@ local function check_dependencies()
local fatal;
- local lxp = softreq "lxp"
+ local lxp, err = softreq "lxp"
if not lxp then
missingdep("luaexpat", {
["Debian/Ubuntu"] = "sudo apt-get install lua-expat";
["luarocks"] = "luarocks install luaexpat";
["Source"] = "http://matthewwild.co.uk/projects/luaexpat/";
- });
+ }, nil, err);
fatal = true;
end
- local socket = softreq "socket"
+ local socket, err = softreq "socket"
if not socket then
missingdep("luasocket", {
["Debian/Ubuntu"] = "sudo apt-get install lua-socket";
["luarocks"] = "luarocks install luasocket";
["Source"] = "http://www.tecgraf.puc-rio.br/~diego/professional/luasocket/";
- });
+ }, nil, err);
fatal = true;
elseif not socket.tcp4 then
-- COMPAT LuaSocket before being IP-version agnostic
@@ -76,28 +77,28 @@ local function check_dependencies()
["luarocks"] = "luarocks install luafilesystem";
["Debian/Ubuntu"] = "sudo apt-get install lua-filesystem";
["Source"] = "http://www.keplerproject.org/luafilesystem/";
- });
+ }, nil, err);
fatal = true;
end
- local ssl = softreq "ssl"
+ local ssl, err = softreq "ssl"
if not ssl then
missingdep("LuaSec", {
["Debian/Ubuntu"] = "sudo apt-get install lua-sec";
["luarocks"] = "luarocks install luasec";
["Source"] = "https://github.com/brunoos/luasec";
- }, "SSL/TLS support will not be available");
+ }, "SSL/TLS support will not be available", err);
end
- local bit = _G.bit32 or softreq"bit";
+ local bit, err = softreq"util.bitcompat";
if not bit then
missingdep("lua-bitops", {
["Debian/Ubuntu"] = "sudo apt-get install lua-bitop";
["luarocks"] = "luarocks install luabitop";
["Source"] = "http://bitop.luajit.org/";
- }, "WebSocket support will not be available");
+ }, "WebSocket support will not be available", err);
end
local encodings, err = softreq "util.encodings"
@@ -140,7 +141,7 @@ local function check_dependencies()
end
local function log_warnings()
- if _VERSION > "Lua 5.2" then
+ if _VERSION > "Lua 5.3" then
prosody.log("warn", "Support for %s is experimental, please report any issues", _VERSION);
end
local ssl = softreq"ssl";
diff --git a/util/error.lua b/util/error.lua
new file mode 100644
index 00000000..ca960dd9
--- /dev/null
+++ b/util/error.lua
@@ -0,0 +1,60 @@
+local error_mt = { __name = "error" };
+
+function error_mt:__tostring()
+ return ("error<%s:%s:%s>"):format(self.type, self.condition, self.text or "");
+end
+
+local function is_err(e)
+ return getmetatable(e) == error_mt;
+end
+
+-- Do we want any more well-known fields?
+-- Or could we just copy all fields from `e`?
+-- Sometimes you want variable details in the `text`, how to handle that?
+-- Translations?
+-- Should the `type` be restricted to the stanza error types or free-form?
+-- What to set `type` to for stream errors or SASL errors? Those don't have a 'type' attr.
+
+local function new(e, context, registry)
+ local template = (registry and registry[e]) or e or {};
+ return setmetatable({
+ type = template.type or "cancel";
+ condition = template.condition or "undefined-condition";
+ text = template.text;
+ code = template.code;
+
+ context = context or template.context or { _error_id = e };
+ }, error_mt);
+end
+
+local function coerce(ok, err, ...)
+ if ok or is_err(err) then
+ return ok, err, ...;
+ end
+
+ local new_err = setmetatable({
+ native = err;
+
+ type = "cancel";
+ condition = "undefined-condition";
+ }, error_mt);
+ return ok, new_err, ...;
+end
+
+local function from_stanza(stanza, context)
+ local error_type, condition, text = stanza:get_error();
+ return setmetatable({
+ type = error_type or "cancel";
+ condition = condition or "undefined-condition";
+ text = text;
+
+ context = context or { stanza = stanza };
+ }, error_mt);
+end
+
+return {
+ new = new;
+ coerce = coerce;
+ is_err = is_err;
+ from_stanza = from_stanza;
+}
diff --git a/util/format.lua b/util/format.lua
index c5e513fa..1ce670f3 100644
--- a/util/format.lua
+++ b/util/format.lua
@@ -3,12 +3,20 @@
--
local tostring = tostring;
-local select = select;
local unpack = table.unpack or unpack; -- luacheck: ignore 113/unpack
+local pack = require "util.table".pack; -- TODO table.pack in 5.2+
local type = type;
+local dump = require "util.serialization".new("debug");
+local num_type = math.type or function (n)
+ return n % 1 == 0 and n <= 9007199254740992 and n >= -9007199254740992 and "integer" or "float";
+end
+
+-- In Lua 5.3+ these formats throw an error if given a float
+local expects_integer = { c = true, d = true, i = true, o = true, u = true, X = true, x = true, };
local function format(formatstring, ...)
- local args, args_length = { ... }, select('#', ...);
+ local args = pack(...);
+ local args_length = args.n;
-- format specifier spec:
-- 1. Start: '%%'
@@ -28,17 +36,22 @@ local function format(formatstring, ...)
if spec ~= "%%" then
i = i + 1;
local arg = args[i];
- if arg == nil then -- special handling for nil
- arg = "<nil>"
- args[i] = "<nil>";
- end
local option = spec:sub(-1);
- if option == "q" or option == "s" then -- arg should be string
+ if arg == nil then
+ args[i] = "nil";
+ spec = "<%s>";
+ elseif option == "q" then
+ args[i] = dump(arg);
+ spec = "%s";
+ elseif option == "s" then
args[i] = tostring(arg);
elseif type(arg) ~= "number" then -- arg isn't number as expected?
args[i] = tostring(arg);
spec = "[%s]";
+ elseif expects_integer[option] and num_type(arg) ~= "integer" then
+ args[i] = tostring(arg);
+ spec = "[%s]";
end
end
return spec;
diff --git a/util/hashring.lua b/util/hashring.lua
new file mode 100644
index 00000000..322bc005
--- /dev/null
+++ b/util/hashring.lua
@@ -0,0 +1,88 @@
+local function generate_ring(nodes, num_replicas, hash)
+ local new_ring = {};
+ for _, node_name in ipairs(nodes) do
+ for replica = 1, num_replicas do
+ local replica_hash = hash(node_name..":"..replica);
+ new_ring[replica_hash] = node_name;
+ table.insert(new_ring, replica_hash);
+ end
+ end
+ table.sort(new_ring);
+ return new_ring;
+end
+
+local hashring_methods = {};
+local hashring_mt = {
+ __index = function (self, k)
+ -- Automatically build self.ring if it's missing
+ if k == "ring" then
+ local ring = generate_ring(self.nodes, self.num_replicas, self.hash);
+ rawset(self, "ring", ring);
+ return ring;
+ end
+ return rawget(hashring_methods, k);
+ end
+};
+
+local function new(num_replicas, hash_function)
+ return setmetatable({ nodes = {}, num_replicas = num_replicas, hash = hash_function }, hashring_mt);
+end;
+
+function hashring_methods:add_node(name)
+ self.ring = nil;
+ self.nodes[name] = true;
+ table.insert(self.nodes, name);
+ return true;
+end
+
+function hashring_methods:add_nodes(nodes)
+ self.ring = nil;
+ for _, node_name in ipairs(nodes) do
+ if not self.nodes[node_name] then
+ self.nodes[node_name] = true;
+ table.insert(self.nodes, node_name);
+ end
+ end
+ return true;
+end
+
+function hashring_methods:remove_node(node_name)
+ self.ring = nil;
+ if self.nodes[node_name] then
+ for i, stored_node_name in ipairs(self.nodes) do
+ if node_name == stored_node_name then
+ self.nodes[node_name] = nil;
+ table.remove(self.nodes, i);
+ return true;
+ end
+ end
+ end
+ return false;
+end
+
+function hashring_methods:remove_nodes(nodes)
+ self.ring = nil;
+ for _, node_name in ipairs(nodes) do
+ self:remove_node(node_name);
+ end
+end
+
+function hashring_methods:clone()
+ local clone_hashring = new(self.num_replicas, self.hash);
+ clone_hashring:add_nodes(self.nodes);
+ return clone_hashring;
+end
+
+function hashring_methods:get_node(key)
+ local key_hash = self.hash(key);
+ for _, replica_hash in ipairs(self.ring) do
+ if key_hash < replica_hash then
+ return self.ring[replica_hash];
+ end
+ end
+ return self.ring[self.ring[1]];
+end
+
+return {
+ new = new;
+}
diff --git a/util/hmac.lua b/util/hmac.lua
index 2c4cc6ef..4cad17cc 100644
--- a/util/hmac.lua
+++ b/util/hmac.lua
@@ -10,6 +10,9 @@
local hashes = require "util.hashes"
-return { md5 = hashes.hmac_md5,
- sha1 = hashes.hmac_sha1,
- sha256 = hashes.hmac_sha256 };
+return {
+ md5 = hashes.hmac_md5,
+ sha1 = hashes.hmac_sha1,
+ sha256 = hashes.hmac_sha256,
+ sha512 = hashes.hmac_sha512,
+};
diff --git a/util/http.lua b/util/http.lua
index cfb89193..3852f91c 100644
--- a/util/http.lua
+++ b/util/http.lua
@@ -6,24 +6,26 @@
--
local format, char = string.format, string.char;
-local pairs, ipairs, tonumber = pairs, ipairs, tonumber;
+local pairs, ipairs = pairs, ipairs;
local t_insert, t_concat = table.insert, table.concat;
+local url_codes = {};
+for i = 0, 255 do
+ local c = char(i);
+ local u = format("%%%02x", i);
+ url_codes[c] = u;
+ url_codes[u] = c;
+ url_codes[u:upper()] = c;
+end
local function urlencode(s)
- return s and (s:gsub("[^a-zA-Z0-9.~_-]", function (c) return format("%%%02x", c:byte()); end));
+ return s and (s:gsub("[^a-zA-Z0-9.~_-]", url_codes));
end
local function urldecode(s)
- return s and (s:gsub("%%(%x%x)", function (c) return char(tonumber(c,16)); end));
+ return s and (s:gsub("%%%x%x", url_codes));
end
local function _formencodepart(s)
- return s and (s:gsub("%W", function (c)
- if c ~= " " then
- return format("%%%02x", c:byte());
- else
- return "+";
- end
- end));
+ return s and (urlencode(s):gsub("%%20", "+"));
end
local function formencode(form)
diff --git a/util/import.lua b/util/import.lua
index 8ecfe43c..1007bc0a 100644
--- a/util/import.lua
+++ b/util/import.lua
@@ -8,7 +8,7 @@
-local unpack = table.unpack or unpack; --luacheck: ignore 113 143
+local unpack = table.unpack or unpack; --luacheck: ignore 113
local t_insert = table.insert;
function _G.import(module, ...)
local m = package.loaded[module] or require(module);
diff --git a/util/ip.lua b/util/ip.lua
index d1808225..d039e475 100644
--- a/util/ip.lua
+++ b/util/ip.lua
@@ -19,8 +19,14 @@ local ip_mt = {
return ret;
end,
__tostring = function (ip) return ip.addr; end,
- __eq = function (ipA, ipB) return ipA.packed == ipB.packed; end
};
+ip_mt.__eq = function (ipA, ipB)
+ if getmetatable(ipA) ~= ip_mt or getmetatable(ipB) ~= ip_mt then
+ -- Lua 5.3+ calls this if both operands are tables, even if metatables differ
+ return false;
+ end
+ return ipA.packed == ipB.packed;
+end
local hex2bits = {
["0"] = "0000", ["1"] = "0001", ["2"] = "0010", ["3"] = "0011",
diff --git a/util/iterators.lua b/util/iterators.lua
index 302cca36..c03c2fd6 100644
--- a/util/iterators.lua
+++ b/util/iterators.lua
@@ -11,9 +11,9 @@
local it = {};
local t_insert = table.insert;
-local select, next = select, next;
-local unpack = table.unpack or unpack; --luacheck: ignore 113 143
-local pack = table.pack or function (...) return { n = select("#", ...), ... }; end -- luacheck: ignore 143
+local next = next;
+local unpack = table.unpack or unpack; --luacheck: ignore 113
+local pack = table.pack or require "util.table".pack;
local type = type;
local table, setmetatable = table, setmetatable;
diff --git a/util/jid.lua b/util/jid.lua
index ec31f180..1ddf33d4 100644
--- a/util/jid.lua
+++ b/util/jid.lua
@@ -45,20 +45,20 @@ local function bare(jid)
return host;
end
-local function prepped_split(jid)
+local function prepped_split(jid, strict)
local node, host, resource = split(jid);
if host and host ~= "." then
if sub(host, -1, -1) == "." then -- Strip empty root label
host = sub(host, 1, -2);
end
- host = nameprep(host);
+ host = nameprep(host, strict);
if not host then return; end
if node then
- node = nodeprep(node);
+ node = nodeprep(node, strict);
if not node then return; end
end
if resource then
- resource = resourceprep(resource);
+ resource = resourceprep(resource, strict);
if not resource then return; end
end
return node, host, resource;
@@ -77,8 +77,8 @@ local function join(node, host, resource)
return host;
end
-local function prep(jid)
- local node, host, resource = prepped_split(jid);
+local function prep(jid, strict)
+ local node, host, resource = prepped_split(jid, strict);
return join(node, host, resource);
end
diff --git a/util/jwt.lua b/util/jwt.lua
new file mode 100644
index 00000000..2b172d38
--- /dev/null
+++ b/util/jwt.lua
@@ -0,0 +1,50 @@
+local s_gsub = string.gsub;
+local json = require "util.json";
+local hashes = require "util.hashes";
+local base64_encode = require "util.encodings".base64.encode;
+local base64_decode = require "util.encodings".base64.decode;
+
+local b64url_rep = { ["+"] = "-", ["/"] = "_", ["="] = "", ["-"] = "+", ["_"] = "/" };
+local function b64url(data)
+ return (s_gsub(base64_encode(data), "[+/=]", b64url_rep));
+end
+local function unb64url(data)
+ return base64_decode(s_gsub(data, "[-_]", b64url_rep).."==");
+end
+
+local static_header = b64url('{"alg":"HS256","typ":"JWT"}') .. '.';
+
+local function sign(key, payload)
+ local encoded_payload = json.encode(payload);
+ local signed = static_header .. b64url(encoded_payload);
+ local signature = hashes.hmac_sha256(key, signed);
+ return signed .. "." .. b64url(signature);
+end
+
+local jwt_pattern = "^(([A-Za-z0-9-_]+)%.([A-Za-z0-9-_]+))%.([A-Za-z0-9-_]+)$"
+local function verify(key, blob)
+ local signed, bheader, bpayload, signature = string.match(blob, jwt_pattern);
+ if not signed then
+ return nil, "invalid-encoding";
+ end
+ local header = json.decode(unb64url(bheader));
+ if not header or type(header) ~= "table" then
+ return nil, "invalid-header";
+ elseif header.alg ~= "HS256" then
+ return nil, "unsupported-algorithm";
+ end
+ if b64url(hashes.hmac_sha256(key, signed)) ~= signature then
+ return false, "signature-mismatch";
+ end
+ local payload, err = json.decode(unb64url(bpayload));
+ if err ~= nil then
+ return nil, "json-decode-error";
+ end
+ return true, payload;
+end
+
+return {
+ sign = sign;
+ verify = verify;
+};
+
diff --git a/util/mercurial.lua b/util/mercurial.lua
index 3f75c4c1..0f2b1d04 100644
--- a/util/mercurial.lua
+++ b/util/mercurial.lua
@@ -19,7 +19,7 @@ function hg.check_id(path)
hg_changelog:close();
end
else
- local hg_archival,e = io.open(path.."/.hg_archival.txt");
+ local hg_archival,e = io.open(path.."/.hg_archival.txt"); -- luacheck: ignore 211/e
if hg_archival then
local repo = hg_archival:read("*l");
local node = hg_archival:read("*l");
diff --git a/util/multitable.lua b/util/multitable.lua
index 8d32ed8a..4f2cd972 100644
--- a/util/multitable.lua
+++ b/util/multitable.lua
@@ -9,7 +9,7 @@
local select = select;
local t_insert = table.insert;
local pairs, next, type = pairs, next, type;
-local unpack = table.unpack or unpack; --luacheck: ignore 113 143
+local unpack = table.unpack or unpack; --luacheck: ignore 113
local _ENV = nil;
-- luacheck: std none
diff --git a/util/paths.lua b/util/paths.lua
index 89f4cad9..036f315a 100644
--- a/util/paths.lua
+++ b/util/paths.lua
@@ -41,4 +41,20 @@ function path_util.join(...)
return t_concat({...}, path_sep);
end
+function path_util.complement_lua_path(installer_plugin_path)
+ -- Checking for duplicates
+ -- The commands using luarocks need the path to the directory that has the /share and /lib folders.
+ local lua_version = _VERSION:match(" (.+)$");
+ local lua_path_sep = package.config:sub(3,3);
+ local dir_sep = package.config:sub(1,1);
+ local sub_path = dir_sep.."lua"..dir_sep..lua_version..dir_sep;
+ if not string.find(package.path, installer_plugin_path, 1, true) then
+ package.path = package.path..lua_path_sep..installer_plugin_path..dir_sep.."share"..sub_path.."?.lua";
+ package.path = package.path..lua_path_sep..installer_plugin_path..dir_sep.."share"..sub_path.."?"..dir_sep.."init.lua";
+ end
+ if not string.find(package.path, installer_plugin_path, 1, true) then
+ package.cpath = package.cpath..lua_path_sep..installer_plugin_path..dir_sep.."lib"..sub_path.."?.so";
+ end
+end
+
return path_util;
diff --git a/util/pluginloader.lua b/util/pluginloader.lua
index 9ab8f245..af0428c4 100644
--- a/util/pluginloader.lua
+++ b/util/pluginloader.lua
@@ -36,12 +36,13 @@ end
local function load_resource(plugin, resource)
resource = resource or "mod_"..plugin..".lua";
-
+ local lua_version = _VERSION:match(" (.+)$");
local names = {
"mod_"..plugin..dir_sep..plugin..dir_sep..resource; -- mod_hello/hello/mod_hello.lua
"mod_"..plugin..dir_sep..resource; -- mod_hello/mod_hello.lua
plugin..dir_sep..resource; -- hello/mod_hello.lua
resource; -- mod_hello.lua
+ "share"..dir_sep.."lua"..dir_sep..lua_version..dir_sep.."mod_"..plugin..dir_sep..resource;
};
return load_file(names);
diff --git a/util/promise.lua b/util/promise.lua
index 07c9c4dc..0b182b54 100644
--- a/util/promise.lua
+++ b/util/promise.lua
@@ -49,6 +49,9 @@ local function promise_settle(promise, new_state, new_next, cbs, value)
for _, cb in ipairs(cbs) do
cb(value);
end
+ -- No need to keep references to callbacks
+ promise._pending_on_fulfilled = nil;
+ promise._pending_on_rejected = nil;
return true;
end
diff --git a/util/prosodyctl.lua b/util/prosodyctl.lua
index 6c84ab6e..ea697ffc 100644
--- a/util/prosodyctl.lua
+++ b/util/prosodyctl.lua
@@ -39,6 +39,16 @@ local function show_usage(usage, desc)
end
end
+local function show_module_configuration_help(mod_name)
+ print("Done.")
+ print("If you installed a prosody plugin, don't forget to add its name under the 'modules_enabled' section inside your configuration file.")
+ print("Depending on the module, there might be further configuration steps required.")
+ print("")
+ print("More info about: ")
+ print(" modules_enabled: https://prosody.im/doc/modules_enabled")
+ print(" "..mod_name..": https://modules.prosody.im/"..mod_name..".html")
+end
+
local function getchar(n)
local stty_ret = os.execute("stty raw -echo 2>/dev/null");
local ok, char;
@@ -124,7 +134,7 @@ end
-- Server control
local function adduser(params)
- local user, host, password = nodeprep(params.user), nameprep(params.host), params.password;
+ local user, host, password = nodeprep(params.user, true), nameprep(params.host), params.password;
if not user then
return false, "invalid-username";
elseif not host then
@@ -200,7 +210,7 @@ local function getpid()
return false, "pidfile-read-failed", err;
end
- local locked, err = lfs.lock(file, "w");
+ local locked, err = lfs.lock(file, "w"); -- luacheck: ignore 211/err
if locked then
file:close();
return false, "pidfile-not-locked";
@@ -217,7 +227,7 @@ local function getpid()
end
local function isrunning()
- local ok, pid, err = getpid();
+ local ok, pid, err = getpid(); -- luacheck: ignore 211/err
if not ok then
if pid == "pidfile-read-failed" or pid == "pidfile-not-locked" then
-- Report as not running, since we can't open the pidfile
@@ -229,7 +239,8 @@ local function isrunning()
return true, signal.kill(pid, 0) == 0;
end
-local function start(source_dir)
+local function start(source_dir, lua)
+ lua = lua and lua .. " " or "";
local ok, ret = isrunning();
if not ok then
return ok, ret;
@@ -238,9 +249,9 @@ local function start(source_dir)
return false, "already-running";
end
if not source_dir then
- os.execute("./prosody -D");
+ os.execute(lua .. "./prosody -D");
else
- os.execute(source_dir.."/../../bin/prosody -D");
+ os.execute(lua .. source_dir.."/../../bin/prosody -D");
end
return true;
end
@@ -277,10 +288,36 @@ local function reload()
return true;
end
+local function get_path_custom_plugins(plugin_paths)
+ -- I'm considering that the custom plugins' path is the first one at prosody.paths.plugins
+ -- luacheck: ignore 512
+ for path in plugin_paths:gmatch("[^;]+") do
+ return path;
+ end
+end
+
+local function call_luarocks(mod, operation)
+ local dir = get_path_custom_plugins(prosody.paths.plugins);
+ if operation == "install" then
+ show_message("Installing %s at %s", mod, dir);
+ elseif operation == "remove" then
+ show_message("Removing %s from %s", mod, dir);
+ end
+ if operation == "list" then
+ os.execute("luarocks list --tree='"..dir.."'")
+ else
+ os.execute("luarocks --tree='"..dir.."' --server='http://localhost/' "..operation.." "..mod);
+ end
+ if operation == "install" then
+ show_module_configuration_help(mod);
+ end
+end
+
return {
show_message = show_message;
show_warning = show_message;
show_usage = show_usage;
+ show_module_configuration_help = show_module_configuration_help;
getchar = getchar;
getline = getline;
getpass = getpass;
@@ -296,4 +333,6 @@ return {
start = start;
stop = stop;
reload = reload;
+ get_path_custom_plugins = get_path_custom_plugins;
+ call_luarocks = call_luarocks;
};
diff --git a/util/pubsub.lua b/util/pubsub.lua
index 1674b9a7..cfac7a68 100644
--- a/util/pubsub.lua
+++ b/util/pubsub.lua
@@ -1,5 +1,6 @@
local events = require "util.events";
local cache = require "util.cache";
+local errors = require "util.error";
local service_mt = {};
@@ -280,7 +281,8 @@ function service:set_affiliation(node, actor, jid, affiliation) --> ok, err
node_obj.affiliations[jid] = affiliation;
if self.config.nodestore then
- local ok, err = save_node_to_store(self, node_obj);
+ -- TODO pass the error from storage to caller eg wrapped in an util.error
+ local ok, err = save_node_to_store(self, node_obj); -- luacheck: ignore 211/err
if not ok then
node_obj.affiliations[jid] = old_affiliation;
return ok, "internal-server-error";
@@ -344,7 +346,8 @@ function service:add_subscription(node, actor, jid, options) --> ok, err
end
if self.config.nodestore then
- local ok, err = save_node_to_store(self, node_obj);
+ -- TODO pass the error from storage to caller eg wrapped in an util.error
+ local ok, err = save_node_to_store(self, node_obj); -- luacheck: ignore 211/err
if not ok then
node_obj.subscribers[jid] = old_subscription;
self.subscriptions[normal_jid][jid][node] = old_subscription and true or nil;
@@ -396,7 +399,8 @@ function service:remove_subscription(node, actor, jid) --> ok, err
end
if self.config.nodestore then
- local ok, err = save_node_to_store(self, node_obj);
+ -- TODO pass the error from storage to caller eg wrapped in an util.error
+ local ok, err = save_node_to_store(self, node_obj); -- luacheck: ignore 211/err
if not ok then
node_obj.subscribers[jid] = old_subscription;
self.subscriptions[normal_jid][jid][node] = old_subscription and true or nil;
@@ -454,7 +458,8 @@ function service:create(node, actor, options) --> ok, err
};
if self.config.nodestore then
- local ok, err = save_node_to_store(self, self.nodes[node]);
+ -- TODO pass the error from storage to caller eg wrapped in an util.error
+ local ok, err = save_node_to_store(self, self.nodes[node]); -- luacheck: ignore 211/err
if not ok then
self.nodes[node] = nil;
return ok, "internal-server-error";
@@ -511,7 +516,7 @@ local function check_preconditions(node_config, required_config)
end
for config_field, value in pairs(required_config) do
if node_config[config_field] ~= value then
- return false;
+ return false, config_field;
end
end
return true;
@@ -547,8 +552,13 @@ function service:publish(node, actor, id, item, requested_config) --> ok, err
node_obj = self.nodes[node];
elseif requested_config and not requested_config._defaults_only then
-- Check that node has the requested config before we publish
- if not check_preconditions(node_obj.config, requested_config) then
- return false, "precondition-not-met";
+ local ok, field = check_preconditions(node_obj.config, requested_config);
+ if not ok then
+ local err = errors.new({
+ type = "cancel", condition = "conflict", text = "Field does not match: "..field;
+ });
+ err.pubsub_condition = "precondition-not-met";
+ return false, err;
end
end
if not self.config.itemcheck(item) then
@@ -768,7 +778,8 @@ function service:set_node_config(node, actor, new_config) --> ok, err
node_obj.config = new_config;
if self.config.nodestore then
- local ok, err = save_node_to_store(self, node_obj);
+ -- TODO pass the error from storage to caller eg wrapped in an util.error
+ local ok, err = save_node_to_store(self, node_obj); -- luacheck: ignore 211/err
if not ok then
node_obj.config = old_config;
return ok, "internal-server-error";
diff --git a/util/queue.lua b/util/queue.lua
index 728e905f..66ed098b 100644
--- a/util/queue.lua
+++ b/util/queue.lua
@@ -52,18 +52,20 @@ local function new(size, allow_wrapping)
return t[tail];
end;
items = function (self)
- --luacheck: ignore 431/t
- return function (t, pos)
- if pos >= t:count() then
+ return function (_, pos)
+ if pos >= items then
return nil;
end
local read_pos = tail + pos;
- if read_pos > t.size then
+ if read_pos > self.size then
read_pos = (read_pos%size);
end
- return pos+1, t._items[read_pos];
+ return pos+1, t[read_pos];
end, self, 0;
end;
+ consume = function (self)
+ return self.pop, self;
+ end;
};
end
diff --git a/util/rsm.lua b/util/rsm.lua
index 40a78fb5..ad725d76 100644
--- a/util/rsm.lua
+++ b/util/rsm.lua
@@ -10,10 +10,15 @@
--
local stanza = require"util.stanza".stanza;
-local tostring, tonumber = tostring, tonumber;
+local tonumber = tonumber;
+local s_format = string.format;
local type = type;
local pairs = pairs;
+local function inttostr(n)
+ return s_format("%d", n);
+end
+
local xmlns_rsm = 'http://jabber.org/protocol/rsm';
local element_parsers = {};
@@ -45,22 +50,28 @@ end
local element_generators = setmetatable({
first = function(st, data)
if type(data) == "table" then
- st:tag("first", { index = data.index }):text(data[1]):up();
+ st:tag("first", { index = inttostr(data.index) }):text(data[1]):up();
else
- st:tag("first"):text(tostring(data)):up();
+ st:tag("first"):text(data):up();
end
end;
before = function(st, data)
if data == true then
st:tag("before"):up();
else
- st:tag("before"):text(tostring(data)):up();
+ st:tag("before"):text(data):up();
end
- end
+ end;
+ max = function (st, data)
+ st:tag("max"):text(inttostr(data)):up();
+ end;
+ count = function (st, data)
+ st:tag("count"):text(inttostr(data)):up();
+ end;
}, {
__index = function(_, name)
return function(st, data)
- st:tag(name):text(tostring(data)):up();
+ st:tag(name):text(data):up();
end
end;
});
diff --git a/util/sasl.lua b/util/sasl.lua
index 50851405..fc2abdf3 100644
--- a/util/sasl.lua
+++ b/util/sasl.lua
@@ -134,7 +134,6 @@ end
-- load the mechanisms
require "util.sasl.plain" .init(registerMechanism);
-require "util.sasl.digest-md5".init(registerMechanism);
require "util.sasl.anonymous" .init(registerMechanism);
require "util.sasl.scram" .init(registerMechanism);
require "util.sasl.external" .init(registerMechanism);
diff --git a/util/sasl/digest-md5.lua b/util/sasl/digest-md5.lua
deleted file mode 100644
index 7542a037..00000000
--- a/util/sasl/digest-md5.lua
+++ /dev/null
@@ -1,251 +0,0 @@
--- sasl.lua v0.4
--- Copyright (C) 2008-2010 Tobias Markmann
---
--- All rights reserved.
---
--- Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
---
--- * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
--- * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
--- * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
---
--- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-local tostring = tostring;
-local type = type;
-
-local s_gmatch = string.gmatch;
-local s_match = string.match;
-local t_concat = table.concat;
-local t_insert = table.insert;
-local to_byte, to_char = string.byte, string.char;
-
-local md5 = require "util.hashes".md5;
-local log = require "util.logger".init("sasl");
-local generate_uuid = require "util.uuid".generate;
-local nodeprep = require "util.encodings".stringprep.nodeprep;
-
-local _ENV = nil;
--- luacheck: std none
-
---=========================
---SASL DIGEST-MD5 according to RFC 2831
-
---[[
-Supported Authentication Backends
-
-digest_md5:
- function(username, domain, realm, encoding) -- domain and realm are usually the same; for some broken
- -- implementations it's not
- return digesthash, state;
- end
-
-digest_md5_test:
- function(username, domain, realm, encoding, digesthash)
- return true or false, state;
- end
-]]
-
-local function digest(self, message)
- --TODO complete support for authzid
-
- local function serialize(message)
- local data = ""
-
- -- testing all possible values
- if message["realm"] then data = data..[[realm="]]..message.realm..[[",]] end
- if message["nonce"] then data = data..[[nonce="]]..message.nonce..[[",]] end
- if message["qop"] then data = data..[[qop="]]..message.qop..[[",]] end
- if message["charset"] then data = data..[[charset=]]..message.charset.."," end
- if message["algorithm"] then data = data..[[algorithm=]]..message.algorithm.."," end
- if message["rspauth"] then data = data..[[rspauth=]]..message.rspauth.."," end
- data = data:gsub(",$", "")
- return data
- end
-
- local function utf8tolatin1ifpossible(passwd)
- local i = 1;
- while i <= #passwd do
- local passwd_i = to_byte(passwd:sub(i, i));
- if passwd_i > 0x7F then
- if passwd_i < 0xC0 or passwd_i > 0xC3 then
- return passwd;
- end
- i = i + 1;
- passwd_i = to_byte(passwd:sub(i, i));
- if passwd_i < 0x80 or passwd_i > 0xBF then
- return passwd;
- end
- end
- i = i + 1;
- end
-
- local p = {};
- local j = 0;
- i = 1;
- while (i <= #passwd) do
- local passwd_i = to_byte(passwd:sub(i, i));
- if passwd_i > 0x7F then
- i = i + 1;
- local passwd_i_1 = to_byte(passwd:sub(i, i));
- t_insert(p, to_char(passwd_i%4*64 + passwd_i_1%64)); -- I'm so clever
- else
- t_insert(p, to_char(passwd_i));
- end
- i = i + 1;
- end
- return t_concat(p);
- end
- local function latin1toutf8(str)
- local p = {};
- for ch in s_gmatch(str, ".") do
- ch = to_byte(ch);
- if (ch < 0x80) then
- t_insert(p, to_char(ch));
- elseif (ch < 0xC0) then
- t_insert(p, to_char(0xC2, ch));
- else
- t_insert(p, to_char(0xC3, ch - 64));
- end
- end
- return t_concat(p);
- end
- local function parse(data)
- local message = {}
- -- COMPAT: %z in the pattern to work around jwchat bug (sends "charset=utf-8\0")
- for k, v in s_gmatch(data, [[([%w%-]+)="?([^",%z]*)"?,?]]) do -- FIXME The hacky regex makes me shudder
- message[k] = v;
- end
- return message;
- end
-
- if not self.nonce then
- self.nonce = generate_uuid();
- self.step = 0;
- self.nonce_count = {};
- end
-
- self.step = self.step + 1;
- if (self.step == 1) then
- local challenge = serialize({ nonce = self.nonce,
- qop = "auth",
- charset = "utf-8",
- algorithm = "md5-sess",
- realm = self.realm});
- return "challenge", challenge;
- elseif (self.step == 2) then
- local response = parse(message);
- -- check for replay attack
- if response["nc"] then
- if self.nonce_count[response["nc"]] then return "failure", "not-authorized" end
- end
-
- -- check for username, it's REQUIRED by RFC 2831
- local username = response["username"];
- local _nodeprep = self.profile.nodeprep;
- if username and _nodeprep ~= false then
- username = (_nodeprep or nodeprep)(username); -- FIXME charset
- end
- if not username or username == "" then
- return "failure", "malformed-request";
- end
- self.username = username;
-
- -- check for nonce, ...
- if not response["nonce"] then
- return "failure", "malformed-request";
- else
- -- check if it's the right nonce
- if response["nonce"] ~= tostring(self.nonce) then return "failure", "malformed-request" end
- end
-
- if not response["cnonce"] then return "failure", "malformed-request", "Missing entry for cnonce in SASL message." end
- if not response["qop"] then response["qop"] = "auth" end
-
- if response["realm"] == nil or response["realm"] == "" then
- response["realm"] = "";
- elseif response["realm"] ~= self.realm then
- return "failure", "not-authorized", "Incorrect realm value";
- end
-
- local decoder;
- if response["charset"] == nil then
- decoder = utf8tolatin1ifpossible;
- elseif response["charset"] ~= "utf-8" then
- return "failure", "incorrect-encoding", "The client's response uses "..response["charset"].." for encoding with isn't supported by sasl.lua. Supported encodings are latin or utf-8.";
- end
-
- local domain = "";
- local protocol = "";
- if response["digest-uri"] then
- protocol, domain = response["digest-uri"]:match("(%w+)/(.*)$");
- if protocol == nil or domain == nil then return "failure", "malformed-request" end
- else
- return "failure", "malformed-request", "Missing entry for digest-uri in SASL message."
- end
-
- --TODO maybe realm support
- local Y, state;
- if self.profile.plain then
- local password, state = self.profile.plain(self, response["username"], self.realm)
- if state == nil then return "failure", "not-authorized"
- elseif state == false then return "failure", "account-disabled" end
- Y = md5(response["username"]..":"..response["realm"]..":"..password);
- elseif self.profile["digest-md5"] then
- Y, state = self.profile["digest-md5"](self, response["username"], self.realm, response["realm"], response["charset"])
- if state == nil then return "failure", "not-authorized"
- elseif state == false then return "failure", "account-disabled" end
- elseif self.profile["digest-md5-test"] then
- -- TODO
- end
- --local password_encoding, Y = self.credentials_handler("DIGEST-MD5", response["username"], self.realm, response["realm"], decoder);
- --if Y == nil then return "failure", "not-authorized"
- --elseif Y == false then return "failure", "account-disabled" end
- local A1 = "";
- if response.authzid then
- if response.authzid == self.username or response.authzid == self.username.."@"..self.realm then
- -- COMPAT
- log("warn", "Client is violating RFC 3920 (section 6.1, point 7).");
- A1 = Y..":"..response["nonce"]..":"..response["cnonce"]..":"..response.authzid;
- else
- return "failure", "invalid-authzid";
- end
- else
- A1 = Y..":"..response["nonce"]..":"..response["cnonce"];
- end
- local A2 = "AUTHENTICATE:"..protocol.."/"..domain;
-
- local HA1 = md5(A1, true);
- local HA2 = md5(A2, true);
-
- local KD = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":"..HA2;
- local response_value = md5(KD, true);
-
- if response_value == response["response"] then
- -- calculate rspauth
- A2 = ":"..protocol.."/"..domain;
-
- HA1 = md5(A1, true);
- HA2 = md5(A2, true);
-
- KD = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":"..HA2
- local rspauth = md5(KD, true);
- self.authenticated = true;
- --TODO: considering sending the rspauth in a success node for saving one roundtrip; allowed according to http://tools.ietf.org/html/draft-saintandre-rfc3920bis-09#section-7.3.6
- return "challenge", serialize({rspauth = rspauth});
- else
- return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated."
- end
- elseif self.step == 3 then
- if self.authenticated ~= nil then return "success"
- else return "failure", "malformed-request" end
- end
-end
-
-local function init(registerMechanism)
- registerMechanism("DIGEST-MD5", {"plain"}, digest);
-end
-
-return {
- init = init;
-}
diff --git a/util/sasl/scram.lua b/util/sasl/scram.lua
index 043f328b..b3370d4b 100644
--- a/util/sasl/scram.lua
+++ b/util/sasl/scram.lua
@@ -14,9 +14,7 @@
local s_match = string.match;
local type = type
local base64 = require "util.encodings".base64;
-local hmac_sha1 = require "util.hashes".hmac_sha1;
-local sha1 = require "util.hashes".sha1;
-local Hi = require "util.hashes".scram_Hi_sha1;
+local hashes = require "util.hashes";
local generate_uuid = require "util.uuid".generate;
local saslprep = require "util.encodings".stringprep.saslprep;
local nodeprep = require "util.encodings".stringprep.nodeprep;
@@ -99,20 +97,22 @@ local function hashprep(hashname)
return hashname:lower():gsub("-", "_");
end
-local function getAuthenticationDatabaseSHA1(password, salt, iteration_count)
- if type(password) ~= "string" or type(salt) ~= "string" or type(iteration_count) ~= "number" then
- return false, "inappropriate argument types"
- end
- if iteration_count < 4096 then
- log("warn", "Iteration count < 4096 which is the suggested minimum according to RFC 5802.")
+local function get_scram_hasher(H, HMAC, Hi)
+ return function (password, salt, iteration_count)
+ if type(password) ~= "string" or type(salt) ~= "string" or type(iteration_count) ~= "number" then
+ return false, "inappropriate argument types"
+ end
+ if iteration_count < 4096 then
+ log("warn", "Iteration count < 4096 which is the suggested minimum according to RFC 5802.")
+ end
+ local salted_password = Hi(password, salt, iteration_count);
+ local stored_key = H(HMAC(salted_password, "Client Key"))
+ local server_key = HMAC(salted_password, "Server Key");
+ return true, stored_key, server_key
end
- local salted_password = Hi(password, salt, iteration_count);
- local stored_key = sha1(hmac_sha1(salted_password, "Client Key"))
- local server_key = hmac_sha1(salted_password, "Server Key");
- return true, stored_key, server_key
end
-local function scram_gen(hash_name, H_f, HMAC_f)
+local function scram_gen(hash_name, H_f, HMAC_f, get_auth_db, expect_cb)
local profile_name = "scram_" .. hashprep(hash_name);
local function scram_hash(self, message)
local support_channel_binding = false;
@@ -125,6 +125,7 @@ local function scram_gen(hash_name, H_f, HMAC_f)
local client_first_message = message;
-- TODO: fail if authzid is provided, since we don't support them yet
+ -- luacheck: ignore 211/authzid
local gs2_header, gs2_cbind_flag, gs2_cbind_name, authzid, client_first_message_bare, username, clientnonce
= s_match(client_first_message, "^(([pny])=?([^,]*),([^,]*),)(m?=?[^,]*,?n=([^,]*),r=([^,]*),?.*)$");
@@ -140,6 +141,10 @@ local function scram_gen(hash_name, H_f, HMAC_f)
if gs2_cbind_flag == "n" then
-- "n" -> client doesn't support channel binding.
+ if expect_cb then
+ log("debug", "Client unexpectedly doesn't support channel binding");
+ -- XXX Is it sensible to abort if the client starts -PLUS but doesn't use channel binding?
+ end
support_channel_binding = false;
end
@@ -177,7 +182,7 @@ local function scram_gen(hash_name, H_f, HMAC_f)
iteration_count = default_i;
local succ;
- succ, stored_key, server_key = getAuthenticationDatabaseSHA1(password, salt, iteration_count);
+ succ, stored_key, server_key = get_auth_db(password, salt, iteration_count);
if not succ then
log("error", "Generating authentication database failed. Reason: %s", stored_key);
return "failure", "temporary-auth-failure";
@@ -190,7 +195,7 @@ local function scram_gen(hash_name, H_f, HMAC_f)
end
local nonce = clientnonce .. generate_uuid();
- local server_first_message = "r="..nonce..",s="..base64.encode(salt)..",i="..iteration_count;
+ local server_first_message = ("r=%s,s=%s,i=%d"):format(nonce, base64.encode(salt), iteration_count);
self.state = {
gs2_header = gs2_header;
gs2_cbind_name = gs2_cbind_name;
@@ -247,22 +252,28 @@ local function scram_gen(hash_name, H_f, HMAC_f)
return scram_hash;
end
+local auth_db_getters = {}
local function init(registerMechanism)
- local function registerSCRAMMechanism(hash_name, hash, hmac_hash)
+ local function registerSCRAMMechanism(hash_name, hash, hmac_hash, pbkdf2)
+ local get_auth_db = get_scram_hasher(hash, hmac_hash, pbkdf2);
+ auth_db_getters[hash_name] = get_auth_db;
registerMechanism("SCRAM-"..hash_name,
{"plain", "scram_"..(hashprep(hash_name))},
- scram_gen(hash_name:lower(), hash, hmac_hash));
+ scram_gen(hash_name:lower(), hash, hmac_hash, get_auth_db));
-- register channel binding equivalent
registerMechanism("SCRAM-"..hash_name.."-PLUS",
{"plain", "scram_"..(hashprep(hash_name))},
- scram_gen(hash_name:lower(), hash, hmac_hash), {"tls-unique"});
+ scram_gen(hash_name:lower(), hash, hmac_hash, get_auth_db, true), {"tls-unique"});
end
- registerSCRAMMechanism("SHA-1", sha1, hmac_sha1);
+ registerSCRAMMechanism("SHA-1", hashes.sha1, hashes.hmac_sha1, hashes.pbkdf2_hmac_sha1);
+ registerSCRAMMechanism("SHA-256", hashes.sha256, hashes.hmac_sha256, hashes.pbkdf2_hmac_sha256);
end
return {
- getAuthenticationDatabaseSHA1 = getAuthenticationDatabaseSHA1;
+ get_hash = get_scram_hasher;
+ hashers = auth_db_getters;
+ getAuthenticationDatabaseSHA1 = get_scram_hasher(hashes.sha1, hashes.hmac_sha1, hashes.pbkdf2_hmac_sha1); -- COMPAT
init = init;
}
diff --git a/util/serialization.lua b/util/serialization.lua
index 5121a9f9..d70e92ba 100644
--- a/util/serialization.lua
+++ b/util/serialization.lua
@@ -16,22 +16,18 @@ local s_char = string.char;
local s_match = string.match;
local t_concat = table.concat;
+local to_hex = require "util.hex".to;
+
local pcall = pcall;
local envload = require"util.envload".envload;
local pos_inf, neg_inf = math.huge, -math.huge;
--- luacheck: ignore 143/math
local m_type = math.type or function (n)
return n % 1 == 0 and n <= 9007199254740992 and n >= -9007199254740992 and "integer" or "float";
end;
-local char_to_hex = {};
-for i = 0,255 do
- char_to_hex[s_char(i)] = s_format("%02x", i);
-end
-
-local function to_hex(s)
- return (s_gsub(s, ".", char_to_hex));
+local function rawpairs(t)
+ return next, t, nil;
end
local function fatal_error(obj, why)
@@ -123,6 +119,7 @@ local function new(opt)
local freeze = opt.freeze;
local maxdepth = opt.maxdepth or 127;
local multirefs = opt.multiref;
+ local table_pairs = opt.table_iterator or rawpairs;
-- serialize one table, recursively
-- t - table being serialized
@@ -164,7 +161,9 @@ local function new(opt)
local indent = s_rep(indentwith, d);
local numkey = 1;
local ktyp, vtyp;
- for k,v in next,t do
+ local had_items = false;
+ for k,v in table_pairs(t) do
+ had_items = true;
o[l], l = itemstart, l + 1;
o[l], l = indent, l + 1;
ktyp, vtyp = type(k), type(v);
@@ -195,14 +194,10 @@ local function new(opt)
else
o[l], l = ser(v), l + 1;
end
- -- last item?
- if next(t, k) ~= nil then
- o[l], l = itemsep, l + 1;
- else
- o[l], l = itemlast, l + 1;
- end
+ o[l], l = itemsep, l + 1;
end
- if next(t) ~= nil then
+ if had_items then
+ o[l - 1] = itemlast;
o[l], l = s_rep(indentwith, d-1), l + 1;
end
o[l], l = tend, l +1;
diff --git a/util/session.lua b/util/session.lua
index b2a726ce..25b22faf 100644
--- a/util/session.lua
+++ b/util/session.lua
@@ -4,12 +4,13 @@ local logger = require "util.logger";
local function new_session(typ)
local session = {
type = typ .. "_unauthed";
+ base_type = typ;
};
return session;
end
local function set_id(session)
- local id = session.type .. tostring(session):match("%x+$"):lower();
+ local id = session.base_type .. tostring(session):match("%x+$"):lower();
session.id = id;
return session;
end
@@ -30,7 +31,7 @@ local function set_send(session)
local conn = session.conn;
if not conn then
function session.send(data)
- session.log("debug", "Discarding data sent to unconnected session: %s", tostring(data));
+ session.log("debug", "Discarding data sent to unconnected session: %s", data);
return false;
end
return session;
@@ -46,7 +47,7 @@ local function set_send(session)
if t then
local ret, err = w(conn, t);
if not ret then
- session.log("debug", "Error writing to connection: %s", tostring(err));
+ session.log("debug", "Error writing to connection: %s", err);
return false, err;
end
end
diff --git a/util/set.lua b/util/set.lua
index 02fabc6a..827a9158 100644
--- a/util/set.lua
+++ b/util/set.lua
@@ -8,6 +8,7 @@
local ipairs, pairs, setmetatable, next, tostring =
ipairs, pairs, setmetatable, next, tostring;
+local getmetatable = getmetatable;
local t_concat = table.concat;
local _ENV = nil;
@@ -146,6 +147,11 @@ function set_mt.__div(set, func)
return new_set;
end
function set_mt.__eq(set1, set2)
+ if getmetatable(set1) ~= set_mt or getmetatable(set2) ~= set_mt then
+ -- Lua 5.3+ calls this if both operands are tables, even if metatables differ
+ return false;
+ end
+
set1, set2 = set1._items, set2._items;
for item in pairs(set1) do
if not set2[item] then
diff --git a/util/sql.lua b/util/sql.lua
index 00c7b57f..9d1c86ca 100644
--- a/util/sql.lua
+++ b/util/sql.lua
@@ -201,31 +201,31 @@ function engine:_transaction(func, ...)
if not ok then return ok, err; end
end
--assert(not self.__transaction, "Recursive transactions not allowed");
- log("debug", "SQL transaction begin [%s]", tostring(func));
+ log("debug", "SQL transaction begin [%s]", func);
self.__transaction = true;
local success, a, b, c = xpcall(func, handleerr, ...);
self.__transaction = nil;
if success then
- log("debug", "SQL transaction success [%s]", tostring(func));
+ log("debug", "SQL transaction success [%s]", func);
local ok, err = self.conn:commit();
-- LuaDBI doesn't actually return an error message here, just a boolean
if not ok then return ok, err or "commit failed"; end
return success, a, b, c;
else
- log("debug", "SQL transaction failure [%s]: %s", tostring(func), a.err);
+ log("debug", "SQL transaction failure [%s]: %s", func, a.err);
if self.conn then self.conn:rollback(); end
return success, a.err;
end
end
function engine:transaction(...)
- local ok, ret = self:_transaction(...);
+ local ok, ret, b, c = self:_transaction(...);
if not ok then
local conn = self.conn;
if not conn or not conn:ping() then
log("debug", "Database connection was closed. Will reconnect and retry.");
self.conn = nil;
- log("debug", "Retrying SQL transaction [%s]", tostring((...)));
- ok, ret = self:_transaction(...);
+ log("debug", "Retrying SQL transaction [%s]", (...));
+ ok, ret, b, c = self:_transaction(...);
log("debug", "SQL transaction retry %s", ok and "succeeded" or "failed");
else
log("debug", "SQL connection is up, so not retrying");
@@ -234,7 +234,7 @@ function engine:transaction(...)
log("error", "Error in SQL transaction: %s", ret);
end
end
- return ok, ret;
+ return ok, ret, b, c;
end
function engine:_create_index(index)
local sql = "CREATE INDEX \""..index.name.."\" ON \""..index.table.."\" (";
@@ -335,6 +335,9 @@ function engine:set_encoding() -- to UTF-8
local ok, actual_charset = self:transaction(function ()
return self:select"SHOW SESSION VARIABLES LIKE 'character_set_client'";
end);
+ if not ok then
+ return false, "Failed to detect connection encoding";
+ end
local charset_ok = true;
for row in actual_charset do
if row[2] ~= charset then
diff --git a/util/stanza.lua b/util/stanza.lua
index a90d56b3..a8a417ab 100644
--- a/util/stanza.lua
+++ b/util/stanza.lua
@@ -98,7 +98,7 @@ function stanza_mt:query(xmlns)
end
function stanza_mt:body(text, attr)
- return self:tag("body", attr):text(text);
+ return self:text_tag("body", text, attr);
end
function stanza_mt:text_tag(name, text, attr, namespaces)
@@ -270,6 +270,34 @@ function stanza_mt:find(path)
until not self
end
+local function _clone(stanza, only_top)
+ local attr, tags = {}, {};
+ for k,v in pairs(stanza.attr) do attr[k] = v; end
+ local old_namespaces, namespaces = stanza.namespaces;
+ if old_namespaces then
+ namespaces = {};
+ for k,v in pairs(old_namespaces) do namespaces[k] = v; end
+ end
+ local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags };
+ if not only_top then
+ for i=1,#stanza do
+ local child = stanza[i];
+ if child.name then
+ child = _clone(child);
+ t_insert(tags, child);
+ end
+ t_insert(new, child);
+ end
+ end
+ return setmetatable(new, stanza_mt);
+end
+
+local function clone(stanza, only_top)
+ if not is_stanza(stanza) then
+ error("bad argument to clone: expected stanza, got "..type(stanza));
+ end
+ return _clone(stanza, only_top);
+end
local escape_table = { ["'"] = "&apos;", ["\""] = "&quot;", ["<"] = "&lt;", [">"] = "&gt;", ["&"] = "&amp;" };
local function xml_escape(str) return (s_gsub(str, "['&<>\"]", escape_table)); end
@@ -310,11 +338,8 @@ function stanza_mt.__tostring(t)
end
function stanza_mt.top_tag(t)
- local attr_string = "";
- if t.attr then
- for k, v in pairs(t.attr) do if type(k) == "string" then attr_string = attr_string .. s_format(" %s='%s'", k, xml_escape(tostring(v))); end end
- end
- return s_format("<%s%s>", t.name, attr_string);
+ local top_tag_clone = clone(t, true);
+ return tostring(top_tag_clone):sub(1,-3)..">";
end
function stanza_mt.get_text(t)
@@ -388,50 +413,32 @@ local function deserialize(serialized)
end
end
-local function _clone(stanza)
- local attr, tags = {}, {};
- for k,v in pairs(stanza.attr) do attr[k] = v; end
- local old_namespaces, namespaces = stanza.namespaces;
- if old_namespaces then
- namespaces = {};
- for k,v in pairs(old_namespaces) do namespaces[k] = v; end
- end
- local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags };
- for i=1,#stanza do
- local child = stanza[i];
- if child.name then
- child = _clone(child);
- t_insert(tags, child);
- end
- t_insert(new, child);
- end
- return setmetatable(new, stanza_mt);
-end
-
-local function clone(stanza)
- if not is_stanza(stanza) then
- error("bad argument to clone: expected stanza, got "..type(stanza));
- end
- return _clone(stanza);
-end
-
local function message(attr, body)
if not body then
return new_stanza("message", attr);
else
- return new_stanza("message", attr):tag("body"):text(body):up();
+ return new_stanza("message", attr):text_tag("body", body);
end
end
local function iq(attr)
- if not (attr and attr.id) then
+ if not attr then
+ error("iq stanzas require id and type attributes");
+ end
+ if not attr.id then
error("iq stanzas require an id attribute");
end
+ if not attr.type then
+ error("iq stanzas require a type attribute");
+ end
return new_stanza("iq", attr);
end
local function reply(orig)
+ if not is_stanza(orig) then
+ error("bad argument to reply: expected stanza, got "..type(orig));
+ end
return new_stanza(orig.name,
- orig.attr and {
+ {
to = orig.attr.from,
from = orig.attr.to,
id = orig.attr.id,
@@ -440,12 +447,23 @@ local function reply(orig)
end
local xmpp_stanzas_attr = { xmlns = xmlns_stanzas };
-local function error_reply(orig, error_type, condition, error_message)
+local function error_reply(orig, error_type, condition, error_message, error_by)
+ if not is_stanza(orig) then
+ error("bad argument to error_reply: expected stanza, got "..type(orig));
+ elseif orig.attr.type == "error" then
+ error("bad argument to error_reply: got stanza of type error which must not be replied to");
+ end
local t = reply(orig);
t.attr.type = "error";
- t:tag("error", {type = error_type}) --COMPAT: Some day xmlns:stanzas goes here
+ if t.attr.from == error_by then
+ error_by = nil;
+ end
+ if type(error_type) == "table" then -- an util.error or similar object
+ error_type, condition, error_message = error_type.type, error_type.condition, error_type.text;
+ end
+ t:tag("error", {type = error_type, by = error_by}) --COMPAT: Some day xmlns:stanzas goes here
:tag(condition, xmpp_stanzas_attr):up();
- if error_message then t:tag("text", xmpp_stanzas_attr):text(error_message):up(); end
+ if error_message then t:text_tag("text", error_message, xmpp_stanzas_attr); end
return t; -- stanza ready for adding app-specific errors
end
@@ -493,6 +511,36 @@ else
stanza_mt.pretty_top_tag = stanza_mt.top_tag;
end
+function stanza_mt.indent(t, level, indent)
+ if #t == 0 or (#t == 1 and type(t[1]) == "string") then
+ -- Empty nodes wouldn't have any indentation
+ -- Text-only nodes are preserved as to not alter the text content
+ -- Optimization: Skip clone of these since we don't alter them
+ return t;
+ end
+
+ indent = indent or "\t";
+ level = level or 1;
+ local tag = clone(t, true);
+
+ for child in t:children() do
+ if type(child) == "string" then
+ -- Already indented text would look weird but let's ignore that for now.
+ if child:find("%S") then
+ tag:text("\n" .. indent:rep(level));
+ tag:text(child);
+ end
+ elseif is_stanza(child) then
+ tag:text("\n" .. indent:rep(level));
+ tag:add_direct_child(child:indent(level+1, indent));
+ end
+ end
+ -- before the closing tag
+ tag:text("\n" .. indent:rep((level-1)));
+
+ return tag;
+end
+
return {
stanza_mt = stanza_mt;
stanza = new_stanza;
diff --git a/util/startup.lua b/util/startup.lua
index 24ed6026..d45855f2 100644
--- a/util/startup.lua
+++ b/util/startup.lua
@@ -5,8 +5,10 @@ local startup = {};
local prosody = { events = require "util.events".new() };
local logger = require "util.logger";
local log = logger.init("startup");
+local parse_args = require "util.argparse".parse;
local config = require "core.configmanager";
+local config_warnings;
local dependencies = require "util.dependencies";
@@ -16,55 +18,10 @@ local short_params = { D = "daemonize", F = "no-daemonize" };
local value_params = { config = true };
function startup.parse_args()
- local parsed_opts = {};
- prosody.opts = parsed_opts;
-
- if #arg == 0 then
- return;
- end
- while true do
- local raw_param = arg[1];
- if not raw_param then
- break;
- end
-
- local prefix = raw_param:match("^%-%-?");
- if not prefix then
- break;
- elseif prefix == "--" and raw_param == "--" then
- table.remove(arg, 1);
- break;
- end
- local param = table.remove(arg, 1):sub(#prefix+1);
- if #param == 1 then
- param = short_params[param];
- end
-
- if not param then
- print("Unknown command-line option: "..tostring(param));
- print("Perhaps you meant to use prosodyctl instead?");
- os.exit(1);
- end
-
- local param_k, param_v;
- if value_params[param] then
- param_k, param_v = param, table.remove(arg, 1);
- if not param_v then
- print("Expected a value to follow command-line option: "..raw_param);
- os.exit(1);
- end
- else
- param_k, param_v = param:match("^([^=]+)=(.+)$");
- if not param_k then
- if param:match("^no%-") then
- param_k, param_v = param:sub(4), false;
- else
- param_k, param_v = param, true;
- end
- end
- end
- parsed_opts[param_k] = param_v;
- end
+ prosody.opts = parse_args(arg, {
+ short_params = short_params,
+ value_params = value_params,
+ });
end
function startup.read_config()
@@ -119,6 +76,8 @@ function startup.read_config()
print("**************************");
print("");
os.exit(1);
+ elseif err and #err > 0 then
+ config_warnings = err;
end
prosody.config_loaded = true;
end
@@ -151,8 +110,13 @@ function startup.init_logging()
end);
end
-function startup.log_dependency_warnings()
+function startup.log_startup_warnings()
dependencies.log_warnings();
+ if config_warnings then
+ for _, warning in ipairs(config_warnings) do
+ log("warn", "Configuration warning: %s", warning);
+ end
+ end
end
function startup.sanity_check()
@@ -274,8 +238,8 @@ end
function startup.setup_plugindir()
local custom_plugin_paths = config.get("*", "plugin_paths");
+ local path_sep = package.config:sub(3,3);
if custom_plugin_paths then
- local path_sep = package.config:sub(3,3);
-- path1;path2;path3;defaultpath...
-- luacheck: ignore 111
CFG_PLUGINDIR = table.concat(custom_plugin_paths, path_sep)..path_sep..(CFG_PLUGINDIR or "plugins");
@@ -283,6 +247,19 @@ function startup.setup_plugindir()
end
end
+function startup.setup_plugin_install_path()
+ local installer_plugin_path = config.get("*", "installer_plugin_path") or "custom_plugins";
+ local path_sep = package.config:sub(3,3);
+ -- TODO Figure out what this should be relative to, because CWD could be anywhere
+ installer_plugin_path = config.resolve_relative_path(require "lfs".currentdir(), installer_plugin_path);
+ -- TODO Can probably move directory creation to the install command
+ require "lfs".mkdir(installer_plugin_path);
+ require"util.paths".complement_lua_path(installer_plugin_path);
+ -- luacheck: ignore 111
+ CFG_PLUGINDIR = installer_plugin_path..path_sep..(CFG_PLUGINDIR or "plugins");
+ prosody.paths.plugins = CFG_PLUGINDIR;
+end
+
function startup.chdir()
if prosody.installed then
local lfs = require "lfs";
@@ -304,9 +281,9 @@ function startup.add_global_prosody_functions()
local ok, level, err = config.load(prosody.config_file);
if not ok then
if level == "parser" then
- log("error", "There was an error parsing the configuration file: %s", tostring(err));
+ log("error", "There was an error parsing the configuration file: %s", err);
elseif level == "file" then
- log("error", "Couldn't read the config file when trying to reload: %s", tostring(err));
+ log("error", "Couldn't read the config file when trying to reload: %s", err);
end
else
prosody.events.fire_event("config-reloaded", {
@@ -480,7 +457,7 @@ function startup.switch_user()
print("Warning: Couldn't switch to Prosody user/group '"..tostring(desired_user).."'/'"..tostring(desired_group).."': "..tostring(err));
else
-- Make sure the Prosody user can read the config
- local conf, err, errno = io.open(prosody.config_file);
+ local conf, err, errno = io.open(prosody.config_file); --luacheck: ignore 211/errno
if conf then
conf:close();
else
@@ -568,18 +545,20 @@ end
-- prosodyctl only
function startup.prosodyctl()
+ prosody.process_type = "prosodyctl";
startup.parse_args();
startup.init_global_state();
startup.read_config();
startup.force_console_logging();
startup.init_logging();
startup.setup_plugindir();
+ -- startup.setup_plugin_install_path();
startup.setup_datadir();
startup.chdir();
startup.read_version();
startup.switch_user();
startup.check_dependencies();
- startup.log_dependency_warnings();
+ startup.log_startup_warnings();
startup.check_unwriteable();
startup.load_libraries();
startup.init_http_client();
@@ -589,6 +568,7 @@ end
function startup.prosody()
-- These actions are in a strict order, as many depend on
-- previous steps to have already been performed
+ prosody.process_type = "prosody";
startup.parse_args();
startup.init_global_state();
startup.read_config();
@@ -600,12 +580,13 @@ function startup.prosody()
startup.init_logging();
startup.load_libraries();
startup.setup_plugindir();
+ -- startup.setup_plugin_install_path();
startup.setup_datadir();
startup.chdir();
startup.add_global_prosody_functions();
startup.read_version();
startup.log_greeting();
- startup.log_dependency_warnings();
+ startup.log_startup_warnings();
startup.load_secondary_libraries();
startup.init_http_client();
startup.init_data_store();
diff --git a/util/statistics.lua b/util/statistics.lua
index 39954652..0ec88e21 100644
--- a/util/statistics.lua
+++ b/util/statistics.lua
@@ -57,12 +57,14 @@ local function new_registry(config)
end;
end;
rate = function (name)
- local since, n = time(), 0;
+ local since, n, total = time(), 0, 0;
registry[name..":rate"] = function ()
+ total = total + n;
local t = time();
local stats = {
rate = n/(t-since);
count = n;
+ total = total;
};
since, n = t, 0;
return "rate", stats.rate, stats;
diff --git a/util/termcolours.lua b/util/termcolours.lua
index 829d84af..2c13d0ff 100644
--- a/util/termcolours.lua
+++ b/util/termcolours.lua
@@ -83,7 +83,7 @@ end
setmetatable(stylemap, { __index = function(_, style)
if type(style) == "string" and style:find("%x%x%x%x%x%x") == 1 then
local g = style:sub(7) == " background" and "48;5;" or "38;5;";
- return g .. color(hex2rgb(style));
+ return format("%s%d", g, color(hex2rgb(style)));
end
end } );
diff --git a/util/x509.lua b/util/x509.lua
index 15cc4d3c..342dafde 100644
--- a/util/x509.lua
+++ b/util/x509.lua
@@ -20,9 +20,12 @@
local nameprep = require "util.encodings".stringprep.nameprep;
local idna_to_ascii = require "util.encodings".idna.to_ascii;
+local idna_to_unicode = require "util.encodings".idna.to_unicode;
local base64 = require "util.encodings".base64;
local log = require "util.logger".init("x509");
+local mt = require "util.multitable";
local s_format = string.format;
+local ipairs = ipairs;
local _ENV = nil;
-- luacheck: std none
@@ -216,6 +219,60 @@ local function verify_identity(host, service, cert)
return false
end
+-- TODO Support other SANs
+local function get_identities(cert) --> map of names to sets of services
+ if cert.setencode then
+ cert:setencode("utf8");
+ end
+
+ local names = mt.new();
+
+ local ext = cert:extensions();
+ local sans = ext[oid_subjectaltname];
+ if sans then
+ if sans["dNSName"] then -- Valid for any service
+ for _, name in ipairs(sans["dNSName"]) do
+ name = idna_to_unicode(nameprep(name));
+ if name then
+ names:set(name, "*", true);
+ end
+ end
+ end
+ if sans[oid_xmppaddr] then
+ for _, name in ipairs(sans[oid_xmppaddr]) do
+ name = nameprep(name);
+ if name then
+ names:set(name, "xmpp-client", true);
+ names:set(name, "xmpp-server", true);
+ end
+ end
+ end
+ if sans[oid_dnssrv] then
+ for _, srvname in ipairs(sans[oid_dnssrv]) do
+ local srv, name = srvname:match("^_([^.]+)%.(.*)");
+ if srv then
+ name = nameprep(name);
+ if name then
+ names:set(name, srv, true);
+ end
+ end
+ end
+ end
+ end
+
+ local subject = cert:subject();
+ for i = 1, #subject do
+ local dn = subject[i];
+ if dn.oid == oid_commonname then
+ local name = nameprep(dn.value);
+ if name and idna_to_ascii(name) then
+ names:set(name, "*", true);
+ end
+ end
+ end
+ return names.data;
+end
+
local pat = "%-%-%-%-%-BEGIN ([A-Z ]+)%-%-%-%-%-\r?\n"..
"([0-9A-Za-z+/=\r\n]*)\r?\n%-%-%-%-%-END %1%-%-%-%-%-";
@@ -237,6 +294,7 @@ end
return {
verify_identity = verify_identity;
+ get_identities = get_identities;
pem2der = pem2der;
der2pem = der2pem;
};
diff --git a/util/xmppstream.lua b/util/xmppstream.lua
index 58cbd18e..6aa1def3 100644
--- a/util/xmppstream.lua
+++ b/util/xmppstream.lua
@@ -64,6 +64,8 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
local stream_default_ns = stream_callbacks.default_ns;
+ local stream_lang = "en";
+
local stack = {};
local chardata, stanza = {};
local stanza_size = 0;
@@ -101,6 +103,7 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
if session.notopen then
if tagname == stream_tag then
non_streamns_depth = 0;
+ stream_lang = attr["xml:lang"] or stream_lang;
if cb_streamopened then
if lxp_supports_bytecount then
cb_handleprogress(stanza_size);
@@ -178,6 +181,9 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
cb_handleprogress(stanza_size);
end
stanza_size = 0;
+ if stanza.attr["xml:lang"] == nil then
+ stanza.attr["xml:lang"] = stream_lang;
+ end
if tagname ~= stream_error_tag then
cb_handlestanza(session, stanza);
else