aboutsummaryrefslogtreecommitdiffstats
path: root/util
diff options
context:
space:
mode:
Diffstat (limited to 'util')
-rw-r--r--util/adminstream.lua2
-rw-r--r--util/array.lua16
-rw-r--r--util/bitcompat.lua14
-rw-r--r--util/datamapper.lua4
-rw-r--r--util/datetime.lua27
-rw-r--r--util/dbuffer.lua18
-rw-r--r--util/dependencies.lua6
-rw-r--r--util/dnsregistry.lua3
-rw-r--r--util/envload.lua39
-rw-r--r--util/format.lua11
-rw-r--r--util/hashring.lua32
-rw-r--r--util/hmac.lua4
-rw-r--r--util/human/io.lua11
-rw-r--r--util/human/units.lua10
-rw-r--r--util/import.lua2
-rw-r--r--util/iterators.lua7
-rw-r--r--util/jid.lua10
-rw-r--r--util/jsonpointer.lua4
-rw-r--r--util/jwt.lua202
-rw-r--r--util/logger.lua16
-rw-r--r--util/mathcompat.lua13
-rw-r--r--util/multitable.lua2
-rw-r--r--util/openmetrics.lua7
-rw-r--r--util/openssl.lua3
-rw-r--r--util/paseto.lua218
-rw-r--r--util/promise.lua7
-rw-r--r--util/prosodyctl.lua3
-rw-r--r--util/prosodyctl/cert.lua2
-rw-r--r--util/prosodyctl/check.lua2
-rw-r--r--util/prosodyctl/shell.lua22
-rw-r--r--util/roles.lua110
-rw-r--r--util/sasl.lua9
-rw-r--r--util/sasl/oauthbearer.lua87
-rw-r--r--util/sasl/scram.lua2
-rw-r--r--util/serialization.lua8
-rw-r--r--util/session.lua6
-rw-r--r--util/sql.lua8
-rw-r--r--util/sqlite3.lua413
-rw-r--r--util/sslconfig.lua73
-rw-r--r--util/stanza.lua96
-rw-r--r--util/startup.lua5
-rw-r--r--util/vcard.lua574
-rw-r--r--util/watchdog.lua39
-rw-r--r--util/x509.lua12
44 files changed, 1352 insertions, 807 deletions
diff --git a/util/adminstream.lua b/util/adminstream.lua
index 4075aa05..ba8ce0a0 100644
--- a/util/adminstream.lua
+++ b/util/adminstream.lua
@@ -145,7 +145,7 @@ local function new_connection(socket_path, listeners)
-- constructor was exported instead of a module table. Due to the lack of a
-- proper release of LuaSocket, distros have settled on shipping either the
-- last RC tag or some commit since then.
- -- Here we accomodate both variants.
+ -- Here we accommodate both variants.
unix = { stream = unix };
end
if type(unix) ~= "table" then
diff --git a/util/array.lua b/util/array.lua
index c33a5ef1..9d438940 100644
--- a/util/array.lua
+++ b/util/array.lua
@@ -8,6 +8,7 @@
local t_insert, t_sort, t_remove, t_concat
= table.insert, table.sort, table.remove, table.concat;
+local t_move = require "util.table".move;
local setmetatable = setmetatable;
local getmetatable = getmetatable;
@@ -137,13 +138,11 @@ function array_base.slice(outa, ina, i, j)
return outa;
end
- for idx = 1, 1+j-i do
- outa[idx] = ina[i+(idx-1)];
- end
+
+ t_move(ina, i, j, 1, outa);
if ina == outa then
- for idx = 2+j-i, #outa do
- outa[idx] = nil;
- end
+ -- Clear (nil) remainder of range
+ t_move(ina, #outa+1, #outa*2, 2+j-i, ina);
end
return outa;
end
@@ -209,10 +208,7 @@ function array_methods:shuffle()
end
function array_methods:append(ina)
- local len, len2 = #self, #ina;
- for i = 1, len2 do
- self[len+i] = ina[i];
- end
+ t_move(ina, 1, #ina, #self+1, self);
return self;
end
diff --git a/util/bitcompat.lua b/util/bitcompat.lua
index 454181af..8f227354 100644
--- a/util/bitcompat.lua
+++ b/util/bitcompat.lua
@@ -5,12 +5,6 @@
-- Lua 5.2 has it by default
if _G.bit32 then
return _G.bit32;
-else
- -- Lua 5.1 may have it as a standalone module that can be installed
- local ok, bitop = pcall(require, "bit32")
- if ok then
- return bitop;
- end
end
do
@@ -21,12 +15,4 @@ do
end
end
-do
- -- Lastly, try the LuaJIT bitop library
- local ok, bitop = pcall(require, "bit")
- if ok then
- return bitop;
- end
-end
-
error "No bit module found. See https://prosody.im/doc/depends#bitop";
diff --git a/util/datamapper.lua b/util/datamapper.lua
index 2378314c..e1484525 100644
--- a/util/datamapper.lua
+++ b/util/datamapper.lua
@@ -1,5 +1,9 @@
-- This file is generated from teal-src/util/datamapper.lua
+if not math.type then
+ require("util.mathcompat")
+end
+
local st = require("util.stanza");
local pointer = require("util.jsonpointer");
diff --git a/util/datetime.lua b/util/datetime.lua
index 2d27ece4..6df146f4 100644
--- a/util/datetime.lua
+++ b/util/datetime.lua
@@ -12,31 +12,41 @@
local os_date = os.date;
local os_time = os.time;
local os_difftime = os.difftime;
+local floor = math.floor;
local tonumber = tonumber;
local _ENV = nil;
-- luacheck: std none
local function date(t)
- return os_date("!%Y-%m-%d", t);
+ return os_date("!%Y-%m-%d", t and floor(t) or nil);
end
local function datetime(t)
- return os_date("!%Y-%m-%dT%H:%M:%SZ", t);
+ if t == nil or t % 1 == 0 then
+ return os_date("!%Y-%m-%dT%H:%M:%SZ", t);
+ end
+ local m = t % 1;
+ local s = floor(t);
+ return os_date("!%Y-%m-%dT%H:%M:%S.%%06dZ", s):format(floor(m * 1000000));
end
local function time(t)
- return os_date("!%H:%M:%S", t);
+ if t == nil or t % 1 == 0 then
+ return os_date("!%H:%M:%S", t);
+ end
+ local m = t % 1;
+ local s = floor(t);
+ return os_date("!%H:%M:%S.%%06d", s):format(floor(m * 1000000));
end
local function legacy(t)
- return os_date("!%Y%m%dT%H:%M:%S", t);
+ return os_date("!%Y%m%dT%H:%M:%S", t and floor(t) or nil);
end
local function parse(s)
if s then
- local year, month, day, hour, min, sec, tzd;
- year, month, day, hour, min, sec, tzd = s:match("^(%d%d%d%d)%-?(%d%d)%-?(%d%d)T(%d%d):(%d%d):(%d%d)%.?%d*([Z+%-]?.*)$");
+ local year, month, day, hour, min, sec, tzd = s:match("^(%d%d%d%d)%-?(%d%d)%-?(%d%d)T(%d%d):(%d%d):(%d%d%.?%d*)([Z+%-]?.*)$");
if year then
local now = os_time();
local time_offset = os_difftime(os_time(os_date("*t", now)), os_time(os_date("!*t", now))); -- to deal with local timezone
@@ -49,8 +59,9 @@ local function parse(s)
tzd_offset = h * 60 * 60 + m * 60;
if sign == "-" then tzd_offset = -tzd_offset; end
end
- sec = (sec + time_offset) - tzd_offset;
- return os_time({year=year, month=month, day=day, hour=hour, min=min, sec=sec, isdst=false});
+ local prec = sec%1;
+ sec = floor(sec + time_offset) - tzd_offset;
+ return os_time({year=year, month=month, day=day, hour=hour, min=min, sec=sec, isdst=false})+prec;
end
end
end
diff --git a/util/dbuffer.lua b/util/dbuffer.lua
index 3ad5fdfe..0a36288d 100644
--- a/util/dbuffer.lua
+++ b/util/dbuffer.lua
@@ -91,18 +91,18 @@ function dbuffer_methods:read_until(char)
end
function dbuffer_methods:discard(requested_bytes)
- if requested_bytes > self._length then
- return nil;
+ if self._length == 0 then return true; end
+ if not requested_bytes or requested_bytes >= self._length then
+ self.front_consumed = 0;
+ self._length = 0;
+ for _ in self.items:consume() do end
+ return true;
end
local chunk, read_bytes = self:read_chunk(requested_bytes);
- if chunk then
- requested_bytes = requested_bytes - read_bytes;
- if requested_bytes == 0 then -- Already read everything we need
- return true;
- end
- else
- return nil;
+ requested_bytes = requested_bytes - read_bytes;
+ if requested_bytes == 0 then -- Already read everything we need
+ return true;
end
while chunk do
diff --git a/util/dependencies.lua b/util/dependencies.lua
index d7836404..165468c5 100644
--- a/util/dependencies.lua
+++ b/util/dependencies.lua
@@ -32,10 +32,10 @@ local function missingdep(name, sources, msg, err) -- luacheck: ignore err
end
local function check_dependencies()
- if _VERSION < "Lua 5.1" then
+ if _VERSION < "Lua 5.2" then
print "***********************************"
print("Unsupported Lua version: ".._VERSION);
- print("At least Lua 5.1 is required.");
+ print("At least Lua 5.2 is required.");
print "***********************************"
return false;
end
@@ -155,7 +155,7 @@ local function log_warnings()
if _VERSION > "Lua 5.4" then
prosody.log("warn", "Support for %s is experimental, please report any issues", _VERSION);
elseif _VERSION < "Lua 5.2" then
- prosody.log("warn", "%s has several issues and support is being phased out, consider upgrading", _VERSION);
+ prosody.log("warn", "%s support is deprecated, upgrade as soon as possible", _VERSION);
end
local ssl = softreq"ssl";
if ssl then
diff --git a/util/dnsregistry.lua b/util/dnsregistry.lua
index 635b7e3a..c52abee9 100644
--- a/util/dnsregistry.lua
+++ b/util/dnsregistry.lua
@@ -1,5 +1,5 @@
-- Source: https://www.iana.org/assignments/dns-parameters/dns-parameters.xml
--- Generated on 2022-02-02
+-- Generated on 2023-01-20
return {
classes = {
["IN"] = 1; [1] = "IN";
@@ -61,7 +61,6 @@ return {
["NSEC3PARAM"] = 51; [51] = "NSEC3PARAM";
["TLSA"] = 52; [52] = "TLSA";
["SMIMEA"] = 53; [53] = "SMIMEA";
- ["Unassigned"] = 54; [54] = "Unassigned";
["HIP"] = 55; [55] = "HIP";
["NINFO"] = 56; [56] = "NINFO";
["RKEY"] = 57; [57] = "RKEY";
diff --git a/util/envload.lua b/util/envload.lua
index 6182a1f9..cf45b702 100644
--- a/util/envload.lua
+++ b/util/envload.lua
@@ -6,38 +6,19 @@
--
-- luacheck: ignore 113/setfenv 113/loadstring
-local load, loadstring, setfenv = load, loadstring, setfenv;
+local load = load;
local io_open = io.open;
-local envload;
-local envloadfile;
-if setfenv then
- function envload(code, source, env)
- local f, err = loadstring(code, source);
- if f and env then setfenv(f, env); end
- return f, err;
- end
-
- function envloadfile(file, env)
- local fh, err, errno = io_open(file);
- if not fh then return fh, err, errno; end
- local f, err = load(function () return fh:read(2048); end, "@"..file);
- fh:close();
- if f and env then setfenv(f, env); end
- return f, err;
- end
-else
- function envload(code, source, env)
- return load(code, source, nil, env);
- end
+local function envload(code, source, env)
+ return load(code, source, nil, env);
+end
- function envloadfile(file, env)
- local fh, err, errno = io_open(file);
- if not fh then return fh, err, errno; end
- local f, err = load(fh:lines(2048), "@"..file, nil, env);
- fh:close();
- return f, err;
- end
+local function envloadfile(file, env)
+ local fh, err, errno = io_open(file);
+ if not fh then return fh, err, errno; end
+ local f, err = load(fh:lines(2048), "@" .. file, nil, env);
+ fh:close();
+ return f, err;
end
return { envload = envload, envloadfile = envloadfile };
diff --git a/util/format.lua b/util/format.lua
index d709aada..0631f423 100644
--- a/util/format.lua
+++ b/util/format.lua
@@ -6,14 +6,12 @@
-- Provides some protection from e.g. CAPEC-135, CWE-117, CWE-134, CWE-93
local tostring = tostring;
-local unpack = table.unpack or unpack; -- luacheck: ignore 113/unpack
-local pack = require "util.table".pack; -- TODO table.pack in 5.2+
+local unpack = table.unpack;
+local pack = table.pack;
local valid_utf8 = require "util.encodings".utf8.valid;
local type = type;
local dump = require "util.serialization".new("debug");
-local num_type = math.type or function (n)
- return n % 1 == 0 and n <= 9007199254740992 and n >= -9007199254740992 and "integer" or "float";
-end
+local num_type = math.type;
-- In Lua 5.3+ these formats throw an error if given a float
local expects_integer = { c = true, d = true, i = true, o = true, u = true, X = true, x = true, };
@@ -35,7 +33,6 @@ local control_symbols = {
["\030"] = "\226\144\158", ["\031"] = "\226\144\159", ["\127"] = "\226\144\161",
};
local supports_p = pcall(string.format, "%p", ""); -- >= Lua 5.4
-local supports_a = pcall(string.format, "%a", 0.0); -- > Lua 5.1
local function format(formatstring, ...)
local args = pack(...);
@@ -93,8 +90,6 @@ local function format(formatstring, ...)
elseif expects_positive[option] and arg < 0 then
args[i] = tostring(arg);
return "[%s]";
- elseif (option == "a" or option == "A") and not supports_a then
- return "%x";
else
return -- acceptable number
end
diff --git a/util/hashring.lua b/util/hashring.lua
index d4555669..5e71654b 100644
--- a/util/hashring.lua
+++ b/util/hashring.lua
@@ -1,3 +1,5 @@
+local it = require "util.iterators";
+
local function generate_ring(nodes, num_replicas, hash)
local new_ring = {};
for _, node_name in ipairs(nodes) do
@@ -28,18 +30,22 @@ local function new(num_replicas, hash_function)
return setmetatable({ nodes = {}, num_replicas = num_replicas, hash = hash_function }, hashring_mt);
end;
-function hashring_methods:add_node(name)
+function hashring_methods:add_node(name, value)
self.ring = nil;
- self.nodes[name] = true;
+ self.nodes[name] = value == nil and true or value;
table.insert(self.nodes, name);
return true;
end
function hashring_methods:add_nodes(nodes)
self.ring = nil;
- for _, node_name in ipairs(nodes) do
- if not self.nodes[node_name] then
- self.nodes[node_name] = true;
+ local iter = pairs;
+ if nodes[1] then -- simple array?
+ iter = it.values;
+ end
+ for node_name, node_value in iter(nodes) do
+ if self.nodes[node_name] == nil then
+ self.nodes[node_name] = node_value == nil and true or node_value;
table.insert(self.nodes, node_name);
end
end
@@ -48,7 +54,7 @@ end
function hashring_methods:remove_node(node_name)
self.ring = nil;
- if self.nodes[node_name] then
+ if self.nodes[node_name] ~= nil then
for i, stored_node_name in ipairs(self.nodes) do
if node_name == stored_node_name then
self.nodes[node_name] = nil;
@@ -69,18 +75,26 @@ end
function hashring_methods:clone()
local clone_hashring = new(self.num_replicas, self.hash);
- clone_hashring:add_nodes(self.nodes);
+ for node_name, node_value in pairs(self.nodes) do
+ clone_hashring.nodes[node_name] = node_value;
+ end
+ clone_hashring.ring = nil;
return clone_hashring;
end
function hashring_methods:get_node(key)
+ local node;
local key_hash = self.hash(key);
for _, replica_hash in ipairs(self.ring) do
if key_hash < replica_hash then
- return self.ring[replica_hash];
+ node = self.ring[replica_hash];
+ break;
end
end
- return self.ring[self.ring[1]];
+ if not node then
+ node = self.ring[self.ring[1]];
+ end
+ return node, self.nodes[node];
end
return {
diff --git a/util/hmac.lua b/util/hmac.lua
index 4cad17cc..ca030259 100644
--- a/util/hmac.lua
+++ b/util/hmac.lua
@@ -13,6 +13,10 @@ local hashes = require "util.hashes"
return {
md5 = hashes.hmac_md5,
sha1 = hashes.hmac_sha1,
+ sha224 = hashes.hmac_sha224,
sha256 = hashes.hmac_sha256,
+ sha384 = hashes.hmac_sha384,
sha512 = hashes.hmac_sha512,
+ blake2s256 = hashes.hmac_blake2s256,
+ blake2b512 = hashes.hmac_blake2b512,
};
diff --git a/util/human/io.lua b/util/human/io.lua
index 7d7dea97..b272af71 100644
--- a/util/human/io.lua
+++ b/util/human/io.lua
@@ -8,7 +8,7 @@ end;
local function getchar(n)
local stty_ret = os.execute("stty raw -echo 2>/dev/null");
local ok, char;
- if stty_ret == true or stty_ret == 0 then
+ if stty_ret then
ok, char = pcall(io.read, n or 1);
os.execute("stty sane");
else
@@ -30,15 +30,12 @@ local function getline()
end
local function getpass()
- local stty_ret, _, status_code = os.execute("stty -echo 2>/dev/null");
- if status_code then -- COMPAT w/ Lua 5.1
- stty_ret = status_code;
- end
- if stty_ret ~= 0 then
+ local stty_ret = os.execute("stty -echo 2>/dev/null");
+ if not stty_ret then
io.write("\027[08m"); -- ANSI 'hidden' text attribute
end
local ok, pass = pcall(io.read, "*l");
- if stty_ret == 0 then
+ if stty_ret then
os.execute("stty sane");
else
io.write("\027[00m");
diff --git a/util/human/units.lua b/util/human/units.lua
index af233e98..329c8518 100644
--- a/util/human/units.lua
+++ b/util/human/units.lua
@@ -4,15 +4,7 @@ local math_floor = math.floor;
local math_log = math.log;
local math_max = math.max;
local math_min = math.min;
-local unpack = table.unpack or unpack; --luacheck: ignore 113
-
-if math_log(10, 10) ~= 1 then
- -- Lua 5.1 COMPAT
- local log10 = math.log10;
- function math_log(n, base)
- return log10(n) / log10(base);
- end
-end
+local unpack = table.unpack;
local large = {
"k", 1000,
diff --git a/util/import.lua b/util/import.lua
index 1007bc0a..0892e9b1 100644
--- a/util/import.lua
+++ b/util/import.lua
@@ -8,7 +8,7 @@
-local unpack = table.unpack or unpack; --luacheck: ignore 113
+local unpack = table.unpack;
local t_insert = table.insert;
function _G.import(module, ...)
local m = package.loaded[module] or require(module);
diff --git a/util/iterators.lua b/util/iterators.lua
index c03c2fd6..eb4c54af 100644
--- a/util/iterators.lua
+++ b/util/iterators.lua
@@ -12,8 +12,8 @@ local it = {};
local t_insert = table.insert;
local next = next;
-local unpack = table.unpack or unpack; --luacheck: ignore 113
-local pack = table.pack or require "util.table".pack;
+local unpack = table.unpack;
+local pack = table.pack;
local type = type;
local table, setmetatable = table, setmetatable;
@@ -240,7 +240,8 @@ function join_methods:prepend(f, s, var)
end
function it.join(f, s, var)
- return setmetatable({ {f, s, var} }, join_mt);
+ local t = setmetatable({ {f, s, var} }, join_mt);
+ return t, { t, 1 };
end
return it;
diff --git a/util/jid.lua b/util/jid.lua
index 694a6b1f..55567ea2 100644
--- a/util/jid.lua
+++ b/util/jid.lua
@@ -35,8 +35,7 @@ local function split(jid)
if jid == nil then return; end
local node, nodepos = match(jid, "^([^@/]+)@()");
local host, hostpos = match(jid, "^([^@/]+)()", nodepos);
- if node ~= nil and host == nil then return nil, nil, nil; end
- local resource = match(jid, "^/(.+)$", hostpos);
+ local resource = host and match(jid, "^/(.+)$", hostpos);
if (host == nil) or ((resource == nil) and #jid >= hostpos) then return nil, nil, nil; end
return node, host, resource;
end
@@ -91,9 +90,9 @@ local function compare(jid, acl)
-- TODO compare to table of rules?
local jid_node, jid_host, jid_resource = split(jid);
local acl_node, acl_host, acl_resource = split(acl);
- if ((acl_node ~= nil and acl_node == jid_node) or acl_node == nil) and
- ((acl_host ~= nil and acl_host == jid_host) or acl_host == nil) and
- ((acl_resource ~= nil and acl_resource == jid_resource) or acl_resource == nil) then
+ if (acl_node == nil or acl_node == jid_node) and
+ (acl_host == nil or acl_host == jid_host) and
+ (acl_resource == nil or acl_resource == jid_resource) then
return true
end
return false
@@ -111,6 +110,7 @@ local function resource(jid)
return (select(3, split(jid)));
end
+-- TODO Forbid \20 at start and end of escaped output per XEP-0106 v1.1
local function escape(s) return s and (s:gsub("\\%x%x", backslash_escapes):gsub("[\"&'/:<>@ ]", escapes)); end
local function unescape(s) return s and (s:gsub("\\%x%x", unescapes)); end
diff --git a/util/jsonpointer.lua b/util/jsonpointer.lua
index 9b871ae7..f1c354a4 100644
--- a/util/jsonpointer.lua
+++ b/util/jsonpointer.lua
@@ -1,6 +1,4 @@
-local m_type = math.type or function (n)
- return n % 1 == 0 and n <= 9007199254740992 and n >= -9007199254740992 and "integer" or "float";
-end;
+local m_type = math.type;
local function unescape_token(escaped_token)
local unescaped = escaped_token:gsub("~1", "/"):gsub("~0", "~")
diff --git a/util/jwt.lua b/util/jwt.lua
index bf106dfa..42a9f7f2 100644
--- a/util/jwt.lua
+++ b/util/jwt.lua
@@ -1,4 +1,5 @@
local s_gsub = string.gsub;
+local crypto = require "util.crypto";
local json = require "util.json";
local hashes = require "util.hashes";
local base64_encode = require "util.encodings".base64.encode;
@@ -13,17 +14,8 @@ local function unb64url(data)
return base64_decode(s_gsub(data, "[-_]", b64url_rep).."==");
end
-local static_header = b64url('{"alg":"HS256","typ":"JWT"}') .. '.';
-
-local function sign(key, payload)
- local encoded_payload = json.encode(payload);
- local signed = static_header .. b64url(encoded_payload);
- local signature = hashes.hmac_sha256(key, signed);
- return signed .. "." .. b64url(signature);
-end
-
local jwt_pattern = "^(([A-Za-z0-9-_]+)%.([A-Za-z0-9-_]+))%.([A-Za-z0-9-_]+)$"
-local function verify(key, blob)
+local function decode_jwt(blob, expected_alg)
local signed, bheader, bpayload, signature = string.match(blob, jwt_pattern);
if not signed then
return nil, "invalid-encoding";
@@ -31,21 +23,197 @@ local function verify(key, blob)
local header = json.decode(unb64url(bheader));
if not header or type(header) ~= "table" then
return nil, "invalid-header";
- elseif header.alg ~= "HS256" then
+ elseif header.alg ~= expected_alg then
return nil, "unsupported-algorithm";
end
- if not secure_equals(b64url(hashes.hmac_sha256(key, signed)), signature) then
- return false, "signature-mismatch";
- end
- local payload, err = json.decode(unb64url(bpayload));
+ return signed, signature, bpayload;
+end
+
+local function new_static_header(algorithm_name)
+ return b64url('{"alg":"'..algorithm_name..'","typ":"JWT"}') .. '.';
+end
+
+local function decode_raw_payload(raw_payload)
+ local payload, err = json.decode(unb64url(raw_payload));
if err ~= nil then
return nil, "json-decode-error";
+ elseif type(payload) ~= "table" then
+ return nil, "invalid-payload-type";
end
return true, payload;
end
+-- HS*** family
+local function new_hmac_algorithm(name)
+ local static_header = new_static_header(name);
+
+ local hmac = hashes["hmac_sha"..name:sub(-3)];
+
+ local function sign(key, payload)
+ local encoded_payload = json.encode(payload);
+ local signed = static_header .. b64url(encoded_payload);
+ local signature = hmac(key, signed);
+ return signed .. "." .. b64url(signature);
+ end
+
+ local function verify(key, blob)
+ local signed, signature, raw_payload = decode_jwt(blob, name);
+ if not signed then return nil, signature; end -- nil, err
+
+ if not secure_equals(b64url(hmac(key, signed)), signature) then
+ return false, "signature-mismatch";
+ end
+
+ return decode_raw_payload(raw_payload);
+ end
+
+ local function load_key(key)
+ assert(type(key) == "string", "key must be string (long, random, secure)");
+ return key;
+ end
+
+ return { sign = sign, verify = verify, load_key = load_key };
+end
+
+local function new_crypto_algorithm(name, key_type, c_sign, c_verify, sig_encode, sig_decode)
+ local static_header = new_static_header(name);
+
+ return {
+ sign = function (private_key, payload)
+ local encoded_payload = json.encode(payload);
+ local signed = static_header .. b64url(encoded_payload);
+
+ local signature = c_sign(private_key, signed);
+ if sig_encode then
+ signature = sig_encode(signature);
+ end
+
+ return signed.."."..b64url(signature);
+ end;
+
+ verify = function (public_key, blob)
+ local signed, signature, raw_payload = decode_jwt(blob, name);
+ if not signed then return nil, signature; end -- nil, err
+
+ signature = unb64url(signature);
+ if sig_decode and signature then
+ signature = sig_decode(signature);
+ end
+ if not signature then
+ return false, "signature-mismatch";
+ end
+
+ local verify_ok = c_verify(public_key, signed, signature);
+ if not verify_ok then
+ return false, "signature-mismatch";
+ end
+
+ return decode_raw_payload(raw_payload);
+ end;
+
+ load_public_key = function (public_key_pem)
+ local key = assert(crypto.import_public_pem(public_key_pem));
+ assert(key:get_type() == key_type, "incorrect key type");
+ return key;
+ end;
+
+ load_private_key = function (private_key_pem)
+ local key = assert(crypto.import_private_pem(private_key_pem));
+ assert(key:get_type() == key_type, "incorrect key type");
+ return key;
+ end;
+ };
+end
+
+-- RS***, PS***
+local rsa_sign_algos = { RS = "rsassa_pkcs1", PS = "rsassa_pss" };
+local function new_rsa_algorithm(name)
+ local family, digest_bits = name:match("^(..)(...)$");
+ local c_sign = crypto[rsa_sign_algos[family].."_sha"..digest_bits.."_sign"];
+ local c_verify = crypto[rsa_sign_algos[family].."_sha"..digest_bits.."_verify"];
+ return new_crypto_algorithm(name, "rsaEncryption", c_sign, c_verify);
+end
+
+-- ES***
+local function new_ecdsa_algorithm(name, c_sign, c_verify, sig_bytes)
+ local function encode_ecdsa_sig(der_sig)
+ local r, s = crypto.parse_ecdsa_signature(der_sig, sig_bytes);
+ return r..s;
+ end
+
+ local expected_sig_length = sig_bytes*2;
+ local function decode_ecdsa_sig(jwk_sig)
+ if #jwk_sig ~= expected_sig_length then
+ return nil;
+ end
+ return crypto.build_ecdsa_signature(jwk_sig:sub(1, sig_bytes), jwk_sig:sub(sig_bytes+1));
+ end
+ return new_crypto_algorithm(name, "id-ecPublicKey", c_sign, c_verify, encode_ecdsa_sig, decode_ecdsa_sig);
+end
+
+local algorithms = {
+ HS256 = new_hmac_algorithm("HS256"), HS384 = new_hmac_algorithm("HS384"), HS512 = new_hmac_algorithm("HS512");
+ ES256 = new_ecdsa_algorithm("ES256", crypto.ecdsa_sha256_sign, crypto.ecdsa_sha256_verify, 32);
+ ES512 = new_ecdsa_algorithm("ES512", crypto.ecdsa_sha512_sign, crypto.ecdsa_sha512_verify, 66);
+ RS256 = new_rsa_algorithm("RS256"), RS384 = new_rsa_algorithm("RS384"), RS512 = new_rsa_algorithm("RS512");
+ PS256 = new_rsa_algorithm("PS256"), PS384 = new_rsa_algorithm("PS384"), PS512 = new_rsa_algorithm("PS512");
+};
+
+local function new_signer(algorithm, key_input, options)
+ local impl = assert(algorithms[algorithm], "Unknown JWT algorithm: "..algorithm);
+ local key = (impl.load_private_key or impl.load_key)(key_input);
+ local sign = impl.sign;
+ local default_ttl = (options and options.default_ttl) or 3600;
+ return function (payload)
+ local issued_at;
+ if not payload.iat then
+ issued_at = os.time();
+ payload.iat = issued_at;
+ end
+ if not payload.exp then
+ payload.exp = (issued_at or os.time()) + default_ttl;
+ end
+ return sign(key, payload);
+ end
+end
+
+local function new_verifier(algorithm, key_input, options)
+ local impl = assert(algorithms[algorithm], "Unknown JWT algorithm: "..algorithm);
+ local key = (impl.load_public_key or impl.load_key)(key_input);
+ local verify = impl.verify;
+ local check_expiry = not (options and options.accept_expired);
+ local claim_verifier = options and options.claim_verifier;
+ return function (token)
+ local ok, payload = verify(key, token);
+ if ok then
+ local expires_at = check_expiry and payload.exp;
+ if expires_at then
+ if type(expires_at) ~= "number" then
+ return nil, "invalid-expiry";
+ elseif expires_at < os.time() then
+ return nil, "token-expired";
+ end
+ end
+ if claim_verifier and not claim_verifier(payload) then
+ return nil, "incorrect-claims";
+ end
+ end
+ return ok, payload;
+ end
+end
+
+local function init(algorithm, private_key, public_key, options)
+ return new_signer(algorithm, private_key, options), new_verifier(algorithm, public_key or private_key, options);
+end
+
return {
- sign = sign;
- verify = verify;
+ init = init;
+ new_signer = new_signer;
+ new_verifier = new_verifier;
+ -- Exported mainly for tests
+ _algorithms = algorithms;
+ -- Deprecated
+ sign = algorithms.HS256.sign;
+ verify = algorithms.HS256.verify;
};
diff --git a/util/logger.lua b/util/logger.lua
index 20a5cef2..148b98dc 100644
--- a/util/logger.lua
+++ b/util/logger.lua
@@ -10,6 +10,7 @@
local pairs = pairs;
local ipairs = ipairs;
local require = require;
+local t_remove = table.remove;
local _ENV = nil;
-- luacheck: std none
@@ -78,6 +79,20 @@ local function add_simple_sink(simple_sink_function, levels)
for _, level in ipairs(levels or {"debug", "info", "warn", "error"}) do
add_level_sink(level, sink_function);
end
+ return sink_function;
+end
+
+local function remove_sink(sink_function)
+ local removed;
+ for level, sinks in pairs(level_sinks) do
+ for i = #sinks, 1, -1 do
+ if sinks[i] == sink_function then
+ t_remove(sinks, i);
+ removed = true;
+ end
+ end
+ end
+ return removed;
end
return {
@@ -87,4 +102,5 @@ return {
add_level_sink = add_level_sink;
add_simple_sink = add_simple_sink;
new = make_logger;
+ remove_sink = remove_sink;
};
diff --git a/util/mathcompat.lua b/util/mathcompat.lua
new file mode 100644
index 00000000..e8acb261
--- /dev/null
+++ b/util/mathcompat.lua
@@ -0,0 +1,13 @@
+if not math.type then
+
+ local function math_type(t)
+ if type(t) == "number" then
+ if t % 1 == 0 and t ~= t + 1 and t ~= t - 1 then
+ return "integer"
+ else
+ return "float"
+ end
+ end
+ end
+ _G.math.type = math_type
+end
diff --git a/util/multitable.lua b/util/multitable.lua
index 4f2cd972..0c292b45 100644
--- a/util/multitable.lua
+++ b/util/multitable.lua
@@ -9,7 +9,7 @@
local select = select;
local t_insert = table.insert;
local pairs, next, type = pairs, next, type;
-local unpack = table.unpack or unpack; --luacheck: ignore 113
+local unpack = table.unpack;
local _ENV = nil;
-- luacheck: std none
diff --git a/util/openmetrics.lua b/util/openmetrics.lua
index c18e63e9..7bdbde9e 100644
--- a/util/openmetrics.lua
+++ b/util/openmetrics.lua
@@ -1,7 +1,7 @@
--[[
This module implements a subset of the OpenMetrics Internet Draft version 00.
-URL: https://tools.ietf.org/html/draft-richih-opsawg-openmetrics-00
+URL: https://datatracker.ietf.org/doc/html/draft-richih-opsawg-openmetrics-00
The following metric types are supported:
@@ -26,7 +26,7 @@ local log = require "util.logger".init("util.openmetrics");
local new_multitable = require "util.multitable".new;
local iter_multitable = require "util.multitable".iter;
local t_concat, t_insert = table.concat, table.insert;
-local t_pack, t_unpack = require "util.table".pack, table.unpack or unpack; --luacheck: ignore 113/unpack
+local t_pack, t_unpack = table.pack, table.unpack;
-- BEGIN of Utility: "metric proxy"
-- This allows to wrap a MetricFamily in a proxy which only provides the
@@ -35,6 +35,7 @@ local t_pack, t_unpack = require "util.table".pack, table.unpack or unpack; --lu
-- `with_partial_label` by the moduleapi in order to pre-set the `host` label
-- on metrics created in non-global modules.
local metric_proxy_mt = {}
+metric_proxy_mt.__name = "metric_proxy"
metric_proxy_mt.__index = metric_proxy_mt
local function new_metric_proxy(metric_family, with_labels_proxy_fun)
@@ -128,6 +129,7 @@ end
-- BEGIN of generic MetricFamily implementation
local metric_family_mt = {}
+metric_family_mt.__name = "metric_family"
metric_family_mt.__index = metric_family_mt
local function histogram_metric_ctor(orig_ctor, buckets)
@@ -278,6 +280,7 @@ local function compose_name(name, unit)
end
local metric_registry_mt = {}
+metric_registry_mt.__name = "metric_registry"
metric_registry_mt.__index = metric_registry_mt
local function new_metric_registry(backend)
diff --git a/util/openssl.lua b/util/openssl.lua
index 32b5aea7..3acb4f04 100644
--- a/util/openssl.lua
+++ b/util/openssl.lua
@@ -166,8 +166,7 @@ do -- Lua to shell calls.
setmetatable(_M, {
__index = function(_, command)
return function(opts)
- local ret = os_execute(serialize(command, type(opts) == "table" and opts or {}));
- return ret == true or ret == 0;
+ return os_execute(serialize(command, type(opts) == "table" and opts or {}));
end;
end;
});
diff --git a/util/paseto.lua b/util/paseto.lua
new file mode 100644
index 00000000..6cd29f68
--- /dev/null
+++ b/util/paseto.lua
@@ -0,0 +1,218 @@
+local crypto = require "util.crypto";
+local json = require "util.json";
+local hashes = require "util.hashes";
+local base64_encode = require "util.encodings".base64.encode;
+local base64_decode = require "util.encodings".base64.decode;
+local secure_equals = require "util.hashes".equals;
+local bit = require "util.bitcompat";
+local hex = require "util.hex";
+local rand = require "util.random";
+local s_pack = require "util.struct".pack;
+
+local s_gsub = string.gsub;
+
+local v4_public = {};
+
+local b64url_rep = { ["+"] = "-", ["/"] = "_", ["="] = "", ["-"] = "+", ["_"] = "/" };
+local function b64url(data)
+ return (s_gsub(base64_encode(data), "[+/=]", b64url_rep));
+end
+
+local valid_tails = {
+ nil; -- Always invalid
+ "^.[AQgw]$"; -- b??????00
+ "^..[AQgwEUk0IYo4Mcs8]$"; -- b????0000
+}
+
+local function unb64url(data)
+ local rem = #data%4;
+ if data:sub(-1,-1) == "=" or rem == 1 or (rem > 1 and not data:sub(-rem):match(valid_tails[rem])) then
+ return nil;
+ end
+ return base64_decode(s_gsub(data, "[-_]", b64url_rep).."==");
+end
+
+local function le64(n)
+ return s_pack("<I8", bit.band(n, 0x7F));
+end
+
+local function pae(parts)
+ if type(parts) ~= "table" then
+ error("bad argument #1 to 'pae' (table expected, got "..type(parts)..")");
+ end
+ local o = { le64(#parts) };
+ for _, part in ipairs(parts) do
+ table.insert(o, le64(#part)..part);
+ end
+ return table.concat(o);
+end
+
+function v4_public.sign(m, sk, f, i)
+ if type(m) ~= "table" then
+ return nil, "PASETO payloads must be a table";
+ end
+ m = json.encode(m);
+ local h = "v4.public.";
+ local m2 = pae({ h, m, f or "", i or "" });
+ local sig = crypto.ed25519_sign(sk, m2);
+ if not f or f == "" then
+ return h..b64url(m..sig);
+ else
+ return h..b64url(m..sig).."."..b64url(f);
+ end
+end
+
+function v4_public.verify(tok, pk, expected_f, i)
+ local h, sm, f = tok:match("^(v4%.public%.)([^%.]+)%.?(.*)$");
+ if not h then
+ return nil, "invalid-token-format";
+ end
+ f = f and unb64url(f) or nil;
+ if expected_f then
+ if not f or not secure_equals(expected_f, f) then
+ return nil, "invalid-footer";
+ end
+ end
+ local raw_sm = unb64url(sm);
+ if not raw_sm or #raw_sm <= 64 then
+ return nil, "invalid-token-format";
+ end
+ local s, m = raw_sm:sub(-64), raw_sm:sub(1, -65);
+ local m2 = pae({ h, m, f or "", i or "" });
+ local ok = crypto.ed25519_verify(pk, m2, s);
+ if not ok then
+ return nil, "invalid-token";
+ end
+ local payload, err = json.decode(m);
+ if err ~= nil or type(payload) ~= "table" then
+ return nil, "json-decode-error";
+ end
+ return payload;
+end
+
+v4_public.import_private_key = crypto.import_private_pem;
+v4_public.import_public_key = crypto.import_public_pem;
+function v4_public.new_keypair()
+ return crypto.generate_ed25519_keypair();
+end
+
+function v4_public.init(private_key_pem, public_key_pem, options)
+ local sign, verify = v4_public.sign, v4_public.verify;
+ local public_key = public_key_pem and v4_public.import_public_key(public_key_pem);
+ local private_key = private_key_pem and v4_public.import_private_key(private_key_pem);
+ local default_footer = options and options.default_footer;
+ local default_assertion = options and options.default_implicit_assertion;
+ return private_key and function (token, token_footer, token_assertion)
+ return sign(token, private_key, token_footer or default_footer, token_assertion or default_assertion);
+ end, public_key and function (token, expected_footer, token_assertion)
+ return verify(token, public_key, expected_footer or default_footer, token_assertion or default_assertion);
+ end;
+end
+
+function v4_public.new_signer(private_key_pem, options)
+ return (v4_public.init(private_key_pem, nil, options));
+end
+
+function v4_public.new_verifier(public_key_pem, options)
+ return (select(2, v4_public.init(nil, public_key_pem, options)));
+end
+
+local v3_local = { _key_mt = {} };
+
+local function v3_local_derive_keys(k, n)
+ local tmp = hashes.hkdf_hmac_sha384(48, k, nil, "paseto-encryption-key"..n);
+ local Ek = tmp:sub(1, 32);
+ local n2 = tmp:sub(33);
+ local Ak = hashes.hkdf_hmac_sha384(48, k, nil, "paseto-auth-key-for-aead"..n);
+ return Ek, Ak, n2;
+end
+
+function v3_local.encrypt(m, k, f, i)
+ assert(#k == 32)
+ if type(m) ~= "table" then
+ return nil, "PASETO payloads must be a table";
+ end
+ m = json.encode(m);
+ local h = "v3.local.";
+ local n = rand.bytes(32);
+ local Ek, Ak, n2 = v3_local_derive_keys(k, n);
+
+ local c = crypto.aes_256_ctr_encrypt(Ek, n2, m);
+ local m2 = pae({ h, n, c, f or "", i or "" });
+ local t = hashes.hmac_sha384(Ak, m2);
+
+ if not f or f == "" then
+ return h..b64url(n..c..t);
+ else
+ return h..b64url(n..c..t).."."..b64url(f);
+ end
+end
+
+function v3_local.decrypt(tok, k, expected_f, i)
+ assert(#k == 32)
+
+ local h, sm, f = tok:match("^(v3%.local%.)([^%.]+)%.?(.*)$");
+ if not h then
+ return nil, "invalid-token-format";
+ end
+ f = f and unb64url(f) or nil;
+ if expected_f then
+ if not f or not secure_equals(expected_f, f) then
+ return nil, "invalid-footer";
+ end
+ end
+ local m = unb64url(sm);
+ if not m or #m <= 80 then
+ return nil, "invalid-token-format";
+ end
+ local n, c, t = m:sub(1, 32), m:sub(33, -49), m:sub(-48);
+ local Ek, Ak, n2 = v3_local_derive_keys(k, n);
+ local preAuth = pae({ h, n, c, f or "", i or "" });
+ local t2 = hashes.hmac_sha384(Ak, preAuth);
+ if not secure_equals(t, t2) then
+ return nil, "invalid-token";
+ end
+ local m2 = crypto.aes_256_ctr_decrypt(Ek, n2, c);
+ if not m2 then
+ return nil, "invalid-token";
+ end
+
+ local payload, err = json.decode(m2);
+ if err ~= nil or type(payload) ~= "table" then
+ return nil, "json-decode-error";
+ end
+ return payload;
+end
+
+function v3_local.new_key()
+ return "secret-token:paseto.v3.local:"..hex.encode(rand.bytes(32));
+end
+
+function v3_local.init(key, options)
+ local encoded_key = key:match("^secret%-token:paseto%.v3%.local:(%x+)$");
+ if not encoded_key or #encoded_key ~= 64 then
+ return error("invalid key for v3.local");
+ end
+ local raw_key = hex.decode(encoded_key);
+ local default_footer = options and options.default_footer;
+ local default_assertion = options and options.default_implicit_assertion;
+ return function (token, token_footer, token_assertion)
+ return v3_local.encrypt(token, raw_key, token_footer or default_footer, token_assertion or default_assertion);
+ end, function (token, token_footer, token_assertion)
+ return v3_local.decrypt(token, raw_key, token_footer or default_footer, token_assertion or default_assertion);
+ end;
+end
+
+function v3_local.new_signer(key, options)
+ return (v3_local.init(key, options));
+end
+
+function v3_local.new_verifier(key, options)
+ return (select(2, v3_local.init(key, options)));
+end
+
+return {
+ pae = pae;
+ v3_local = v3_local;
+ v4_public = v4_public;
+};
diff --git a/util/promise.lua b/util/promise.lua
index c4e166ed..f56502d2 100644
--- a/util/promise.lua
+++ b/util/promise.lua
@@ -2,7 +2,7 @@ local promise_methods = {};
local promise_mt = { __name = "promise", __index = promise_methods };
local xpcall = require "util.xpcall".xpcall;
-local unpack = table.unpack or unpack; --luacheck: ignore 113
+local unpack = table.unpack;
function promise_mt:__tostring()
return "promise (" .. (self._state or "invalid") .. ")";
@@ -57,10 +57,7 @@ local function promise_settle(promise, new_state, new_next, cbs, value)
end
local function new_resolve_functions(p)
- local resolved = false;
local function _resolve(v)
- if resolved then return; end
- resolved = true;
if is_promise(v) then
v:next(new_resolve_functions(p));
elseif promise_settle(p, "fulfilled", next_fulfilled, p._pending_on_fulfilled, v) then
@@ -69,8 +66,6 @@ local function new_resolve_functions(p)
end
local function _reject(e)
- if resolved then return; end
- resolved = true;
if promise_settle(p, "rejected", next_rejected, p._pending_on_rejected, e) then
p.reason = e;
end
diff --git a/util/prosodyctl.lua b/util/prosodyctl.lua
index 4d49cd16..b3163799 100644
--- a/util/prosodyctl.lua
+++ b/util/prosodyctl.lua
@@ -224,8 +224,7 @@ local function call_luarocks(operation, mod, server)
local ok, _, code = os.execute(render_cli("luarocks --lua-version={luav} {op} --tree={dir} {server&--server={server}} {mod?}", {
dir = dir; op = operation; mod = mod; server = server; luav = _VERSION:match("5%.%d");
}));
- if type(ok) == "number" then code = ok; end
- return code;
+ return ok and code;
end
return {
diff --git a/util/prosodyctl/cert.lua b/util/prosodyctl/cert.lua
index 02c81585..ebc14a4e 100644
--- a/util/prosodyctl/cert.lua
+++ b/util/prosodyctl/cert.lua
@@ -179,7 +179,7 @@ local function copy(from, to, umask, owner, group)
os.execute(("chown -c --reference=%s %s"):format(sh_esc(cert_basedir), sh_esc(to)));
elseif owner and group then
local ok = os.execute(("chown %s:%s %s"):format(sh_esc(owner), sh_esc(group), sh_esc(to)));
- assert(ok == true or ok == 0, "Failed to change ownership of "..to);
+ assert(ok, "Failed to change ownership of "..to);
end
if old_umask then pposix.umask(old_umask); end
return true;
diff --git a/util/prosodyctl/check.lua b/util/prosodyctl/check.lua
index 2fd3c784..793ef038 100644
--- a/util/prosodyctl/check.lua
+++ b/util/prosodyctl/check.lua
@@ -155,7 +155,7 @@ local function check_turn_service(turn_service, ping_service)
result.error = "TURN server did not response to allocation request: "..err;
return result;
elseif alloc_response:is_err_resp() then
- result.error = ("TURN allocation failed: %d (%s)"):format(alloc_response:get_error());
+ result.error = ("TURN server failed to create allocation: %d (%s)"):format(alloc_response:get_error());
return result;
elseif not alloc_response:is_success_resp() then
result.error = ("Unexpected TURN response: %d (%s)"):format(alloc_response:get_type());
diff --git a/util/prosodyctl/shell.lua b/util/prosodyctl/shell.lua
index 8cf7df69..61050e4d 100644
--- a/util/prosodyctl/shell.lua
+++ b/util/prosodyctl/shell.lua
@@ -4,6 +4,8 @@ local st = require "util.stanza";
local path = require "util.paths";
local parse_args = require "util.argparse".parse;
local unpack = table.unpack or _G.unpack;
+local tc = require "util.termcolours";
+local isatty = require "util.pposix".isatty;
local have_readline, readline = pcall(require, "readline");
@@ -27,7 +29,7 @@ local function read_line(prompt_string)
end
local function send_line(client, line)
- client.send(st.stanza("repl-input"):text(line));
+ client.send(st.stanza("repl-input", { width = os.getenv "COLUMNS" }):text(line));
end
local function repl(client)
@@ -64,6 +66,7 @@ end
local function start(arg) --luacheck: ignore 212/arg
local client = adminstream.client();
local opts, err, where = parse_args(arg);
+ local ttyout = isatty(io.stdout);
if not opts then
if err == "param-not-found" then
@@ -77,8 +80,7 @@ local function start(arg) --luacheck: ignore 212/arg
if arg[1] then
if arg[2] then
-- prosodyctl shell module reload foo bar.com --> module:reload("foo", "bar.com")
- -- COMPAT Lua 5.1 doesn't have the separator argument to string.rep
- arg[1] = string.format("%s:%s("..string.rep("%q, ", #arg-2):sub(1, -3)..")", unpack(arg));
+ arg[1] = string.format("%s:%s("..string.rep("%q", #arg-2,", ")..")", unpack(arg));
end
client.events.add_handler("connected", function()
@@ -89,11 +91,15 @@ local function start(arg) --luacheck: ignore 212/arg
local errors = 0; -- TODO This is weird, but works for now.
client.events.add_handler("received", function(stanza)
if stanza.name == "repl-output" or stanza.name == "repl-result" then
+ local dest = io.stdout;
if stanza.attr.type == "error" then
errors = errors + 1;
- io.stderr:write(stanza:get_text(), "\n");
+ dest = io.stderr;
+ end
+ if stanza.attr.eol == "0" then
+ dest:write(stanza:get_text());
else
- print(stanza:get_text());
+ dest:write(stanza:get_text(), "\n");
end
end
if stanza.name == "repl-result" then
@@ -118,7 +124,11 @@ local function start(arg) --luacheck: ignore 212/arg
client.events.add_handler("received", function (stanza)
if stanza.name == "repl-output" or stanza.name == "repl-result" then
local result_prefix = stanza.attr.type == "error" and "!" or "|";
- print(result_prefix.." "..stanza:get_text());
+ local out = result_prefix.." "..stanza:get_text();
+ if ttyout and stanza.attr.type == "error" then
+ out = tc.getstring(tc.getstyle("red"), out);
+ end
+ print(out);
end
if stanza.name == "repl-result" then
repl(client);
diff --git a/util/roles.lua b/util/roles.lua
new file mode 100644
index 00000000..2c3a5026
--- /dev/null
+++ b/util/roles.lua
@@ -0,0 +1,110 @@
+local array = require "util.array";
+local it = require "util.iterators";
+local new_short_id = require "util.id".short;
+
+local role_methods = {};
+local role_mt = {
+ __index = role_methods;
+ __name = "role";
+ __add = nil;
+};
+
+local function is_role(o)
+ local mt = getmetatable(o);
+ return mt == role_mt;
+end
+
+local function _new_may(permissions, inherited_mays)
+ local n_inherited = inherited_mays and #inherited_mays;
+ return function (role, action, context)
+ -- Note: 'role' may be a descendent role, not only the one we're attached to
+ local policy = permissions[action];
+ if policy ~= nil then
+ return policy;
+ end
+ if n_inherited then
+ for i = 1, n_inherited do
+ policy = inherited_mays[i](role, action, context);
+ if policy ~= nil then
+ return policy;
+ end
+ end
+ end
+ return nil;
+ end
+end
+
+local permissions_key = {};
+
+-- {
+-- Required:
+-- name = "My fancy role";
+--
+-- Optional:
+-- inherits = { role_obj... }
+-- default = true
+-- priority = 100
+-- permissions = {
+-- ["foo"] = true; -- allow
+-- ["bar"] = false; -- deny
+-- }
+-- }
+local function new(base_config, overrides)
+ local config = setmetatable(overrides or {}, { __index = base_config });
+ local permissions = {};
+ local inherited_mays;
+ if config.inherits then
+ inherited_mays = array.pluck(config.inherits, "may");
+ end
+ local new_role = {
+ id = new_short_id();
+ name = config.name;
+ description = config.description;
+ default = config.default;
+ priority = config.priority;
+ may = _new_may(permissions, inherited_mays);
+ inherits = config.inherits;
+ [permissions_key] = permissions;
+ };
+ local desired_permissions = config.permissions or config[permissions_key];
+ for k, v in pairs(desired_permissions or {}) do
+ permissions[k] = v;
+ end
+ return setmetatable(new_role, role_mt);
+end
+
+function role_methods:clone(overrides)
+ return new(self, overrides);
+end
+
+function role_methods:set_permission(permission_name, policy, overwrite)
+ local permissions = self[permissions_key];
+ if overwrite ~= true and permissions[permission_name] ~= nil and permissions[permission_name] ~= policy then
+ return false, "policy-already-exists";
+ end
+ permissions[permission_name] = policy;
+ return true;
+end
+
+function role_methods:policies()
+ local policy_iterator, s, v = it.join(pairs(self[permissions_key]));
+ if self.inherits then
+ for _, inherited_role in ipairs(self.inherits) do
+ policy_iterator:append(inherited_role:policies());
+ end
+ end
+ return policy_iterator, s, v;
+end
+
+function role_mt.__tostring(self)
+ return ("role<[%s] %s>"):format(self.id or "nil", self.name or "[no name]");
+end
+
+function role_mt.__pairs(self)
+ return it.filter(permissions_key, next, self);
+end
+
+return {
+ is_role = is_role;
+ new = new;
+};
diff --git a/util/sasl.lua b/util/sasl.lua
index 528743d1..1920cf21 100644
--- a/util/sasl.lua
+++ b/util/sasl.lua
@@ -133,10 +133,11 @@ function method:process(message)
end
-- load the mechanisms
-require "util.sasl.plain" .init(registerMechanism);
-require "util.sasl.anonymous" .init(registerMechanism);
-require "util.sasl.scram" .init(registerMechanism);
-require "util.sasl.external" .init(registerMechanism);
+require "util.sasl.plain" .init(registerMechanism);
+require "util.sasl.anonymous" .init(registerMechanism);
+require "util.sasl.oauthbearer" .init(registerMechanism);
+require "util.sasl.scram" .init(registerMechanism);
+require "util.sasl.external" .init(registerMechanism);
return {
registerMechanism = registerMechanism;
diff --git a/util/sasl/oauthbearer.lua b/util/sasl/oauthbearer.lua
new file mode 100644
index 00000000..54c63575
--- /dev/null
+++ b/util/sasl/oauthbearer.lua
@@ -0,0 +1,87 @@
+local saslprep = require "util.encodings".stringprep.saslprep;
+local nodeprep = require "util.encodings".stringprep.nodeprep;
+local jid = require "util.jid";
+local json = require "util.json";
+local log = require "util.logger".init("sasl");
+local _ENV = nil;
+
+
+local function oauthbearer(self, message)
+ if not message then
+ return "failure", "malformed-request";
+ end
+
+ if message == "\001" then
+ return "failure", "not-authorized";
+ end
+
+ local gs2_authzid, kvpairs = message:match("n,a=([^,]+),(.+)$");
+ if not gs2_authzid then
+ return "failure", "malformed-request";
+ end
+
+ local auth_header;
+ for k, v in kvpairs:gmatch("([a-zA-Z]+)=([\033-\126 \009\r\n]*)\001") do
+ if k == "auth" then
+ auth_header = v;
+ break;
+ end
+ end
+
+ if not auth_header then
+ return "failure", "malformed-request";
+ end
+
+ local username = jid.prepped_split(gs2_authzid);
+
+ if not username or username == "" then
+ return "failure", "malformed-request", "Expected authorization identity in the username@hostname format";
+ end
+
+ -- SASLprep username
+ username = saslprep(username);
+
+ if not username or username == "" then
+ log("debug", "Username violates SASLprep.");
+ return "failure", "malformed-request", "Invalid username.";
+ end
+
+ local _nodeprep = self.profile.nodeprep;
+ if _nodeprep ~= false then
+ username = (_nodeprep or nodeprep)(username);
+ if not username or username == "" then
+ return "failure", "malformed-request", "Invalid username or password."
+ end
+ end
+
+ self.username = username;
+
+ local token = auth_header:match("^Bearer (.+)$");
+
+ local correct, state, token_info = self.profile.oauthbearer(self, username, token, self.realm);
+
+ if state == false then
+ return "failure", "account-disabled";
+ elseif state == nil or not correct then
+ -- For token-level errors, RFC 7628 demands use of a JSON-encoded
+ -- challenge response upon failure. We relay additional info from
+ -- the auth backend if available.
+ return "challenge", json.encode({
+ status = token_info and token_info.status or "invalid_token";
+ scope = token_info and token_info.scope or nil;
+ ["openid-configuration"] = token_info and token_info.oidc_discovery_url or nil;
+ });
+ end
+
+ self.resource = token_info.resource;
+ self.role = token_info.role;
+ return "success";
+end
+
+local function init(registerMechanism)
+ registerMechanism("OAUTHBEARER", {"oauthbearer"}, oauthbearer);
+end
+
+return {
+ init = init;
+}
diff --git a/util/sasl/scram.lua b/util/sasl/scram.lua
index 37abf4a4..4606d1fd 100644
--- a/util/sasl/scram.lua
+++ b/util/sasl/scram.lua
@@ -240,7 +240,7 @@ local function init(registerMechanism)
-- register channel binding equivalent
registerMechanism("SCRAM-"..hash_name.."-PLUS",
{"plain", "scram_"..(hashprep(hash_name))},
- scram_gen(hash_name:lower(), hash, hmac_hash, get_auth_db, true), {"tls-unique"});
+ scram_gen(hash_name:lower(), hash, hmac_hash, get_auth_db, true), {"tls-unique", "tls-exporter"});
end
registerSCRAMMechanism("SHA-1", hashes.sha1, hashes.hmac_sha1, hashes.pbkdf2_hmac_sha1);
diff --git a/util/serialization.lua b/util/serialization.lua
index d310a3e8..e2e104f1 100644
--- a/util/serialization.lua
+++ b/util/serialization.lua
@@ -21,10 +21,12 @@ local to_hex = require "util.hex".to;
local pcall = pcall;
local envload = require"util.envload".envload;
+if not math.type then
+ require "util.mathcompat"
+end
+
local pos_inf, neg_inf = math.huge, -math.huge;
-local m_type = math.type or function (n)
- return n % 1 == 0 and n <= 9007199254740992 and n >= -9007199254740992 and "integer" or "float";
-end;
+local m_type = math.type;
local function rawpairs(t)
return next, t, nil;
diff --git a/util/session.lua b/util/session.lua
index 25b22faf..d908476a 100644
--- a/util/session.lua
+++ b/util/session.lua
@@ -57,10 +57,16 @@ local function set_send(session)
return session;
end
+local function set_role(session, role)
+ session.role = role;
+end
+
return {
new = new_session;
+
set_id = set_id;
set_logger = set_logger;
set_conn = set_conn;
set_send = set_send;
+ set_role = set_role;
}
diff --git a/util/sql.lua b/util/sql.lua
index 9d1c86ca..623d4ed4 100644
--- a/util/sql.lua
+++ b/util/sql.lua
@@ -99,6 +99,9 @@ end
function engine:onconnect() -- luacheck: ignore 212/self
-- Override from create_engine()
end
+function engine:ondisconnect() -- luacheck: ignore 212/self
+ -- Override from create_engine()
+end
function engine:prepquery(sql)
if self.params.driver == "MySQL" then
@@ -224,6 +227,7 @@ function engine:transaction(...)
if not conn or not conn:ping() then
log("debug", "Database connection was closed. Will reconnect and retry.");
self.conn = nil;
+ self:ondisconnect();
log("debug", "Retrying SQL transaction [%s]", (...));
ok, ret, b, c = self:_transaction(...);
log("debug", "SQL transaction retry %s", ok and "succeeded" or "failed");
@@ -365,8 +369,8 @@ local function db2uri(params)
};
end
-local function create_engine(_, params, onconnect)
- return setmetatable({ url = db2uri(params), params = params, onconnect = onconnect }, engine_mt);
+local function create_engine(_, params, onconnect, ondisconnect)
+ return setmetatable({ url = db2uri(params); params = params; onconnect = onconnect; ondisconnect = ondisconnect }, engine_mt);
end
return {
diff --git a/util/sqlite3.lua b/util/sqlite3.lua
new file mode 100644
index 00000000..dce45e13
--- /dev/null
+++ b/util/sqlite3.lua
@@ -0,0 +1,413 @@
+
+-- luacheck: ignore 113/unpack 211 212 411 213
+local setmetatable, getmetatable = setmetatable, getmetatable;
+local ipairs, unpack, select = ipairs, table.unpack or unpack, select;
+local tonumber, tostring = tonumber, tostring;
+local assert, xpcall, debug_traceback = assert, xpcall, debug.traceback;
+local error = error
+local type = type
+local t_concat = table.concat;
+local t_insert = table.insert;
+local s_char = string.char;
+local log = require "util.logger".init("sql");
+
+local lsqlite3 = require "lsqlite3";
+local build_url = require "socket.url".build;
+local ROW, DONE = lsqlite3.ROW, lsqlite3.DONE;
+
+-- from sqlite3.h, no copyright claimed
+local sqlite_errors = require"util.error".init("util.sqlite3", {
+ -- FIXME xmpp error conditions?
+ [1] = { code = 1; type = "modify"; condition = "ERROR"; text = "Generic error" };
+ [2] = { code = 2; type = "cancel"; condition = "INTERNAL"; text = "Internal logic error in SQLite" };
+ [3] = { code = 3; type = "auth"; condition = "PERM"; text = "Access permission denied" };
+ [4] = { code = 4; type = "cancel"; condition = "ABORT"; text = "Callback routine requested an abort" };
+ [5] = { code = 5; type = "wait"; condition = "BUSY"; text = "The database file is locked" };
+ [6] = { code = 6; type = "wait"; condition = "LOCKED"; text = "A table in the database is locked" };
+ [7] = { code = 7; type = "wait"; condition = "NOMEM"; text = "A malloc() failed" };
+ [8] = { code = 8; type = "cancel"; condition = "READONLY"; text = "Attempt to write a readonly database" };
+ [9] = { code = 9; type = "cancel"; condition = "INTERRUPT"; text = "Operation terminated by sqlite3_interrupt()" };
+ [10] = { code = 10; type = "wait"; condition = "IOERR"; text = "Some kind of disk I/O error occurred" };
+ [11] = { code = 11; type = "cancel"; condition = "CORRUPT"; text = "The database disk image is malformed" };
+ [12] = { code = 12; type = "modify"; condition = "NOTFOUND"; text = "Unknown opcode in sqlite3_file_control()" };
+ [13] = { code = 13; type = "wait"; condition = "FULL"; text = "Insertion failed because database is full" };
+ [14] = { code = 14; type = "auth"; condition = "CANTOPEN"; text = "Unable to open the database file" };
+ [15] = { code = 15; type = "cancel"; condition = "PROTOCOL"; text = "Database lock protocol error" };
+ [16] = { code = 16; type = "continue"; condition = "EMPTY"; text = "Internal use only" };
+ [17] = { code = 17; type = "modify"; condition = "SCHEMA"; text = "The database schema changed" };
+ [18] = { code = 18; type = "modify"; condition = "TOOBIG"; text = "String or BLOB exceeds size limit" };
+ [19] = { code = 19; type = "modify"; condition = "CONSTRAINT"; text = "Abort due to constraint violation" };
+ [20] = { code = 20; type = "modify"; condition = "MISMATCH"; text = "Data type mismatch" };
+ [21] = { code = 21; type = "modify"; condition = "MISUSE"; text = "Library used incorrectly" };
+ [22] = { code = 22; type = "cancel"; condition = "NOLFS"; text = "Uses OS features not supported on host" };
+ [23] = { code = 23; type = "auth"; condition = "AUTH"; text = "Authorization denied" };
+ [24] = { code = 24; type = "modify"; condition = "FORMAT"; text = "Not used" };
+ [25] = { code = 25; type = "modify"; condition = "RANGE"; text = "2nd parameter to sqlite3_bind out of range" };
+ [26] = { code = 26; type = "cancel"; condition = "NOTADB"; text = "File opened that is not a database file" };
+ [27] = { code = 27; type = "continue"; condition = "NOTICE"; text = "Notifications from sqlite3_log()" };
+ [28] = { code = 28; type = "continue"; condition = "WARNING"; text = "Warnings from sqlite3_log()" };
+ [100] = { code = 100; type = "continue"; condition = "ROW"; text = "sqlite3_step() has another row ready" };
+ [101] = { code = 101; type = "continue"; condition = "DONE"; text = "sqlite3_step() has finished executing" };
+});
+
+local assert = function(cond, errno, err)
+ return assert(sqlite_errors.coerce(cond, err or errno));
+end
+local _ENV = nil;
+-- luacheck: std none
+
+local column_mt = {};
+local table_mt = {};
+local query_mt = {};
+--local op_mt = {};
+local index_mt = {};
+
+local function is_column(x) return getmetatable(x)==column_mt; end
+local function is_index(x) return getmetatable(x)==index_mt; end
+local function is_table(x) return getmetatable(x)==table_mt; end
+local function is_query(x) return getmetatable(x)==query_mt; end
+local function Integer(n) return "Integer()" end
+local function String(n) return "String()" end
+
+local function Column(definition)
+ return setmetatable(definition, column_mt);
+end
+local function Table(definition)
+ local c = {}
+ for i,col in ipairs(definition) do
+ if is_column(col) then
+ c[i], c[col.name] = col, col;
+ elseif is_index(col) then
+ col.table = definition.name;
+ end
+ end
+ return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt);
+end
+local function Index(definition)
+ return setmetatable(definition, index_mt);
+end
+
+function table_mt:__tostring()
+ local s = { 'name="'..self.__table__.name..'"' }
+ for i,col in ipairs(self.__table__) do
+ s[#s+1] = tostring(col);
+ end
+ return 'Table{ '..t_concat(s, ", ")..' }'
+end
+table_mt.__index = {};
+function table_mt.__index:create(engine)
+ return engine:_create_table(self);
+end
+function table_mt:__call(...)
+ -- TODO
+end
+function column_mt:__tostring()
+ return 'Column{ name="'..self.name..'", type="'..self.type..'" }'
+end
+function index_mt:__tostring()
+ local s = 'Index{ name="'..self.name..'"';
+ for i=1,#self do s = s..', "'..self[i]:gsub("[\\\"]", "\\%1")..'"'; end
+ return s..' }';
+-- return 'Index{ name="'..self.name..'", type="'..self.type..'" }'
+end
+
+local function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end)); end
+local function parse_url(url)
+ local scheme, secondpart, database = url:match("^([%w%+]+)://([^/]*)/?(.*)");
+ assert(scheme, "Invalid URL format");
+ local username, password, host, port;
+ local authpart, hostpart = secondpart:match("([^@]+)@([^@+])");
+ if not authpart then hostpart = secondpart; end
+ if authpart then
+ username, password = authpart:match("([^:]*):(.*)");
+ username = username or authpart;
+ password = password and urldecode(password);
+ end
+ if hostpart then
+ host, port = hostpart:match("([^:]*):(.*)");
+ host = host or hostpart;
+ port = port and assert(tonumber(port), "Invalid URL format");
+ end
+ return {
+ scheme = scheme:lower();
+ username = username; password = password;
+ host = host; port = port;
+ database = #database > 0 and database or nil;
+ };
+end
+
+local engine = {};
+function engine:connect()
+ if self.conn then return true; end
+
+ local params = self.params;
+ assert(params.driver == "SQLite3", "Only sqlite3 is supported");
+ local dbh, err = sqlite_errors.coerce(lsqlite3.open(params.database));
+ if not dbh then return nil, err; end
+ self.conn = dbh;
+ self.prepared = {};
+ local ok, err = self:set_encoding();
+ if not ok then
+ return ok, err;
+ end
+ local ok, err = self:onconnect();
+ if ok == false then
+ return ok, err;
+ end
+ return true;
+end
+function engine:onconnect()
+ -- Override from create_engine()
+end
+function engine:ondisconnect() -- luacheck: ignore 212/self
+ -- Override from create_engine()
+end
+function engine:execute(sql, ...)
+ local success, err = self:connect();
+ if not success then return success, err; end
+ local prepared = self.prepared;
+
+ if select('#', ...) == 0 then
+ local ret = self.conn:exec(sql);
+ if ret ~= lsqlite3.OK then
+ local err = sqlite_errors.new(err);
+ err.text = self.conn:errmsg();
+ return err;
+ end
+ return true;
+ end
+
+ local stmt = prepared[sql];
+ if not stmt then
+ local err;
+ stmt, err = self.conn:prepare(sql);
+ if not stmt then
+ err = sqlite_errors.new(err);
+ err.text = self.conn:errmsg();
+ return stmt, err;
+ end
+ prepared[sql] = stmt;
+ end
+
+ local ret = stmt:bind_values(...);
+ if ret ~= lsqlite3.OK then return nil, sqlite_errors.new(ret, { message = self.conn:errmsg() }); end
+ return stmt;
+end
+
+local result_mt = {
+ __index = {
+ affected = function(self) return self.__affected; end;
+ rowcount = function(self) return self.__rowcount; end;
+ },
+};
+
+local function iterator(table)
+ local i=0;
+ return function()
+ i=i+1;
+ local item=table[i];
+ if item ~= nil then
+ return item;
+ end
+ end
+end
+
+local function debugquery(where, sql, ...)
+ local i = 0; local a = {...}
+ sql = sql:gsub("\n?\t+", " ");
+ log("debug", "[%s] %s", where, (sql:gsub("%?", function ()
+ i = i + 1;
+ local v = a[i];
+ if type(v) == "string" then
+ v = ("'%s'"):format(v:gsub("'", "''"));
+ end
+ return tostring(v);
+ end)));
+end
+
+function engine:execute_query(sql, ...)
+ local prepared = self.prepared;
+ local stmt = prepared[sql];
+ if stmt and stmt:isopen() then
+ prepared[sql] = nil; -- Can't be used concurrently
+ else
+ stmt = assert(self.conn:prepare(sql));
+ end
+ local ret = stmt:bind_values(...);
+ if ret ~= lsqlite3.OK then error(self.conn:errmsg()); end
+ local data, ret = {}
+ while stmt:step() == ROW do
+ t_insert(data, stmt:get_values());
+ end
+ -- FIXME Error handling, BUSY, ERROR, MISUSE
+ if stmt:reset() == lsqlite3.OK then
+ prepared[sql] = stmt;
+ end
+ return setmetatable({ __data = data }, { __index = result_mt.__index, __call = iterator(data) });
+end
+function engine:execute_update(sql, ...)
+ local prepared = self.prepared;
+ local stmt = prepared[sql];
+ if not stmt or not stmt:isopen() then
+ stmt = assert(self.conn:prepare(sql));
+ else
+ prepared[sql] = nil;
+ end
+ local ret = stmt:bind_values(...);
+ if ret ~= lsqlite3.OK then error(self.conn:errmsg()); end
+ local rowcount = 0;
+ repeat
+ ret = stmt:step();
+ if ret == lsqlite3.ROW then
+ rowcount = rowcount + 1;
+ end
+ until ret ~= lsqlite3.ROW;
+ local affected = self.conn:changes();
+ if stmt:reset() == lsqlite3.OK then
+ prepared[sql] = stmt;
+ end
+ return setmetatable({ __affected = affected, __rowcount = rowcount }, result_mt);
+end
+engine.insert = engine.execute_update;
+engine.select = engine.execute_query;
+engine.delete = engine.execute_update;
+engine.update = engine.execute_update;
+local function debugwrap(name, f)
+ return function (self, sql, ...)
+ debugquery(name, sql, ...)
+ return f(self, sql, ...)
+ end
+end
+function engine:debug(enable)
+ self._debug = enable;
+ if enable then
+ engine.insert = debugwrap("insert", engine.execute_update);
+ engine.select = debugwrap("select", engine.execute_query);
+ engine.delete = debugwrap("delete", engine.execute_update);
+ engine.update = debugwrap("update", engine.execute_update);
+ else
+ engine.insert = engine.execute_update;
+ engine.select = engine.execute_query;
+ engine.delete = engine.execute_update;
+ engine.update = engine.execute_update;
+ end
+end
+function engine:_(word)
+ local ret = self.conn:exec(word);
+ if ret ~= lsqlite3.OK then return nil, self.conn:errmsg(); end
+ return true;
+end
+function engine:_transaction(func, ...)
+ if not self.conn then
+ local a,b = self:connect();
+ if not a then return a,b; end
+ end
+ --assert(not self.__transaction, "Recursive transactions not allowed");
+ local ok, err = self:_"BEGIN";
+ if not ok then return ok, err; end
+ self.__transaction = true;
+ local success, a, b, c = xpcall(func, debug_traceback, ...);
+ self.__transaction = nil;
+ if success then
+ log("debug", "SQL transaction success [%s]", tostring(func));
+ local ok, err = self:_"COMMIT";
+ if not ok then return ok, err; end -- commit failed
+ return success, a, b, c;
+ else
+ log("debug", "SQL transaction failure [%s]: %s", tostring(func), a);
+ if self.conn then self:_"ROLLBACK"; end
+ return success, a;
+ end
+end
+function engine:transaction(...)
+ local ok, ret = self:_transaction(...);
+ if not ok then
+ local conn = self.conn;
+ if not conn or not conn:isopen() then
+ self.conn = nil;
+ self:ondisconnect();
+ ok, ret = self:_transaction(...);
+ end
+ end
+ return ok, ret;
+end
+function engine:_create_index(index)
+ local sql = "CREATE INDEX IF NOT EXISTS \""..index.name.."\" ON \""..index.table.."\" (";
+ for i=1,#index do
+ sql = sql.."\""..index[i].."\"";
+ if i ~= #index then sql = sql..", "; end
+ end
+ sql = sql..");"
+ if index.unique then
+ sql = sql:gsub("^CREATE", "CREATE UNIQUE");
+ end
+ if self._debug then
+ debugquery("create", sql);
+ end
+ return self:execute(sql);
+end
+function engine:_create_table(table)
+ local sql = "CREATE TABLE IF NOT EXISTS \""..table.name.."\" (";
+ for i,col in ipairs(table.c) do
+ local col_type = col.type;
+ sql = sql.."\""..col.name.."\" "..col_type;
+ if col.nullable == false then sql = sql.." NOT NULL"; end
+ if col.primary_key == true then sql = sql.." PRIMARY KEY"; end
+ if col.auto_increment == true then
+ sql = sql.." AUTOINCREMENT";
+ end
+ if i ~= #table.c then sql = sql..", "; end
+ end
+ sql = sql.. ");"
+ if self._debug then
+ debugquery("create", sql);
+ end
+ local success,err = self:execute(sql);
+ if not success then return success,err; end
+ for i,v in ipairs(table.__table__) do
+ if is_index(v) then
+ self:_create_index(v);
+ end
+ end
+ return success;
+end
+function engine:set_encoding() -- to UTF-8
+ return self:transaction(function()
+ for encoding in self:select"PRAGMA encoding;" do
+ if encoding[1] == "UTF-8" then
+ self.charset = "utf8";
+ end
+ end
+ end);
+end
+local engine_mt = { __index = engine };
+
+local function db2uri(params)
+ return build_url{
+ scheme = params.driver,
+ user = params.username,
+ password = params.password,
+ host = params.host,
+ port = params.port,
+ path = params.database,
+ };
+end
+
+local function create_engine(_, params, onconnect, ondisconnect)
+ assert(params.driver == "SQLite3", "Only SQLite3 is supported without LuaDBI");
+ return setmetatable({ url = db2uri(params); params = params; onconnect = onconnect; ondisconnect = ondisconnect }, engine_mt);
+end
+
+return {
+ is_column = is_column;
+ is_index = is_index;
+ is_table = is_table;
+ is_query = is_query;
+ Integer = Integer;
+ String = String;
+ Column = Column;
+ Table = Table;
+ Index = Index;
+ create_engine = create_engine;
+ db2uri = db2uri;
+};
diff --git a/util/sslconfig.lua b/util/sslconfig.lua
index 6074a1fb..0078365b 100644
--- a/util/sslconfig.lua
+++ b/util/sslconfig.lua
@@ -3,9 +3,12 @@
local type = type;
local pairs = pairs;
local rawset = rawset;
+local rawget = rawget;
+local error = error;
local t_concat = table.concat;
local t_insert = table.insert;
local setmetatable = setmetatable;
+local resolve_path = require"util.paths".resolve_relative_path;
local _ENV = nil;
-- luacheck: std none
@@ -34,7 +37,7 @@ function handlers.options(config, field, new)
options[value] = true;
end
end
- config[field] = options;
+ rawset(config, field, options)
end
handlers.verifyext = handlers.options;
@@ -70,6 +73,20 @@ finalisers.curveslist = finalisers.ciphers;
-- TLS 1.3 ciphers
finalisers.ciphersuites = finalisers.ciphers;
+-- Path expansion
+function finalisers.key(path, config)
+ if type(path) == "string" then
+ return resolve_path(config._basedir, path);
+ else
+ return nil
+ end
+end
+finalisers.certificate = finalisers.key;
+finalisers.cafile = finalisers.key;
+finalisers.capath = finalisers.key;
+-- XXX: copied from core/certmanager.lua, but this seems odd, because it would remove a dhparam function from the config
+finalisers.dhparam = finalisers.key;
+
-- protocol = "x" should enable only that protocol
-- protocol = "x+" should enable x and later versions
@@ -89,37 +106,81 @@ end
-- Merge options from 'new' config into 'config'
local function apply(config, new)
+ rawset(config, "_cache", nil);
if type(new) == "table" then
for field, value in pairs(new) do
- (handlers[field] or rawset)(config, field, value);
+ -- exclude keys which are internal to the config builder
+ if field:sub(1, 1) ~= "_" then
+ (handlers[field] or rawset)(config, field, value);
+ end
end
end
+ return config
end
-- Finalize the config into the form LuaSec expects
local function final(config)
local output = { };
for field, value in pairs(config) do
- output[field] = (finalisers[field] or id)(value);
+ -- exclude keys which are internal to the config builder
+ if field:sub(1, 1) ~= "_" then
+ output[field] = (finalisers[field] or id)(value, config);
+ end
end
-- Need to handle protocols last because it adds to the options list
protocol(output);
return output;
end
+local function build(config)
+ local cached = rawget(config, "_cache");
+ if cached then
+ return cached, nil
+ end
+
+ local ctx, err = rawget(config, "_context_factory")(config:final(), config);
+ if ctx then
+ rawset(config, "_cache", ctx);
+ end
+ return ctx, err
+end
+
local sslopts_mt = {
__index = {
apply = apply;
final = final;
+ build = build;
};
+ __newindex = function()
+ error("SSL config objects cannot be modified directly. Use :apply()")
+ end;
};
-local function new()
- return setmetatable({options={}}, sslopts_mt);
+
+-- passing basedir through everything is required to avoid sslconfig depending
+-- on prosody.paths.config
+local function new(context_factory, basedir)
+ return setmetatable({
+ _context_factory = context_factory,
+ _basedir = basedir,
+ options={},
+ }, sslopts_mt);
end
+local function clone(config)
+ local result = new();
+ for k, v in pairs(config) do
+ -- note that we *do* copy the internal keys on clone -- we have to carry
+ -- both the factory and the cache with us
+ rawset(result, k, v);
+ end
+ return result
+end
+
+sslopts_mt.__index.clone = clone;
+
return {
apply = apply;
final = final;
- new = new;
+ _new = new;
};
diff --git a/util/stanza.lua b/util/stanza.lua
index 86b88169..0f8827d5 100644
--- a/util/stanza.lua
+++ b/util/stanza.lua
@@ -21,12 +21,15 @@ local type = type;
local s_gsub = string.gsub;
local s_sub = string.sub;
local s_find = string.find;
+local t_move = table.move or require "util.table".move;
+local t_create = require"util.table".create;
local valid_utf8 = require "util.encodings".utf8.valid;
local do_pretty_printing, termcolours = pcall(require, "util.termcolours");
local xmlns_stanzas = "urn:ietf:params:xml:ns:xmpp-stanzas";
+local xmpp_stanzas_attr = { xmlns = xmlns_stanzas };
local _ENV = nil;
-- luacheck: std none
@@ -179,6 +182,14 @@ function stanza_mt:get_child_text(name, xmlns)
return nil;
end
+function stanza_mt:get_child_attr(name, xmlns, attr)
+ local tag = self:get_child(name, xmlns);
+ if tag then
+ return tag.attr[attr];
+ end
+ return nil;
+end
+
function stanza_mt:child_with_name(name)
for _, child in ipairs(self.tags) do
if child.name == name then return child; end
@@ -283,25 +294,33 @@ function stanza_mt:find(path)
end
local function _clone(stanza, only_top)
- local attr, tags = {}, {};
+ local attr = {};
for k,v in pairs(stanza.attr) do attr[k] = v; end
local old_namespaces, namespaces = stanza.namespaces;
if old_namespaces then
namespaces = {};
for k,v in pairs(old_namespaces) do namespaces[k] = v; end
end
- local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags };
+ local tags, new;
+ if only_top then
+ tags = {};
+ new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags };
+ else
+ tags = t_create(#stanza.tags, 0);
+ new = t_create(#stanza, 4);
+ new.name = stanza.name;
+ new.attr = attr;
+ new.namespaces = namespaces;
+ new.tags = tags;
+ end
+
+ setmetatable(new, stanza_mt);
if not only_top then
- for i=1,#stanza do
- local child = stanza[i];
- if child.name then
- child = _clone(child);
- t_insert(tags, child);
- end
- t_insert(new, child);
- end
+ t_move(stanza, 1, #stanza, 1, new);
+ t_move(stanza.tags, 1, #stanza.tags, 1, tags);
+ new:maptags(_clone);
end
- return setmetatable(new, stanza_mt);
+ return new;
end
local function clone(stanza, only_top)
@@ -387,6 +406,33 @@ function stanza_mt.get_error(stanza)
return error_type, condition or "undefined-condition", text, extra_tag;
end
+function stanza_mt.add_error(stanza, error_type, condition, error_message, error_by)
+ local extra;
+ if type(error_type) == "table" then -- an util.error or similar object
+ if type(error_type.extra) == "table" then
+ extra = error_type.extra;
+ end
+ if type(error_type.context) == "table" and type(error_type.context.by) == "string" then error_by = error_type.context.by; end
+ error_type, condition, error_message = error_type.type, error_type.condition, error_type.text;
+ end
+ if stanza.attr.from == error_by then
+ error_by = nil;
+ end
+ stanza:tag("error", {type = error_type, by = error_by}) --COMPAT: Some day xmlns:stanzas goes here
+ :tag(condition, xmpp_stanzas_attr);
+ if extra and condition == "gone" and type(extra.uri) == "string" then
+ stanza:text(extra.uri);
+ end
+ stanza:up();
+ if error_message then stanza:text_tag("text", error_message, xmpp_stanzas_attr); end
+ if extra and is_stanza(extra.tag) then
+ stanza:add_child(extra.tag);
+ elseif extra and extra.namespace and extra.condition then
+ stanza:tag(extra.condition, { xmlns = extra.namespace }):up();
+ end
+ return stanza:up();
+end
+
local function preserialize(stanza)
local s = { name = stanza.name, attr = stanza.attr };
for _, child in ipairs(stanza) do
@@ -461,7 +507,6 @@ local function reply(orig)
});
end
-local xmpp_stanzas_attr = { xmlns = xmlns_stanzas };
local function error_reply(orig, error_type, condition, error_message, error_by)
if not is_stanza(orig) then
error("bad argument to error_reply: expected stanza, got "..type(orig));
@@ -470,30 +515,9 @@ local function error_reply(orig, error_type, condition, error_message, error_by)
end
local t = reply(orig);
t.attr.type = "error";
- local extra;
- if type(error_type) == "table" then -- an util.error or similar object
- if type(error_type.extra) == "table" then
- extra = error_type.extra;
- end
- if type(error_type.context) == "table" and type(error_type.context.by) == "string" then error_by = error_type.context.by; end
- error_type, condition, error_message = error_type.type, error_type.condition, error_type.text;
- end
- if t.attr.from == error_by then
- error_by = nil;
- end
- t:tag("error", {type = error_type, by = error_by}) --COMPAT: Some day xmlns:stanzas goes here
- :tag(condition, xmpp_stanzas_attr);
- if extra and condition == "gone" and type(extra.uri) == "string" then
- t:text(extra.uri);
- end
- t:up();
- if error_message then t:text_tag("text", error_message, xmpp_stanzas_attr); end
- if extra and is_stanza(extra.tag) then
- t:add_child(extra.tag);
- elseif extra and extra.namespace and extra.condition then
- t:tag(extra.condition, { xmlns = extra.namespace }):up();
- end
- return t; -- stanza ready for adding app-specific errors
+ t:add_error(error_type, condition, error_message, error_by);
+ t.last_add = { t[1] }; -- ready to add application-specific errors
+ return t;
end
local function presence(attr)
diff --git a/util/startup.lua b/util/startup.lua
index 545b6ae7..8be54884 100644
--- a/util/startup.lua
+++ b/util/startup.lua
@@ -277,6 +277,11 @@ function startup.init_global_state()
startup.detect_platform();
startup.detect_installed();
_G.prosody = prosody;
+
+ -- COMPAT Lua < 5.3
+ if not math.type then
+ require "util.mathcompat"
+ end
end
function startup.setup_datadir()
diff --git a/util/vcard.lua b/util/vcard.lua
deleted file mode 100644
index e311f73f..00000000
--- a/util/vcard.lua
+++ /dev/null
@@ -1,574 +0,0 @@
--- Copyright (C) 2011-2014 Kim Alvefur
---
--- This project is MIT/X11 licensed. Please see the
--- COPYING file in the source package for more information.
---
-
--- TODO
--- Fix folding.
-
-local st = require "util.stanza";
-local t_insert, t_concat = table.insert, table.concat;
-local type = type;
-local pairs, ipairs = pairs, ipairs;
-
-local from_text, to_text, from_xep54, to_xep54;
-
-local line_sep = "\n";
-
-local vCard_dtd; -- See end of file
-local vCard4_dtd;
-
-local function vCard_esc(s)
- return s:gsub("[,:;\\]", "\\%1"):gsub("\n","\\n");
-end
-
-local function vCard_unesc(s)
- return s:gsub("\\?[\\nt:;,]", {
- ["\\\\"] = "\\",
- ["\\n"] = "\n",
- ["\\r"] = "\r",
- ["\\t"] = "\t",
- ["\\:"] = ":", -- FIXME Shouldn't need to escape : in values, just params
- ["\\;"] = ";",
- ["\\,"] = ",",
- [":"] = "\29",
- [";"] = "\30",
- [","] = "\31",
- });
-end
-
-local function item_to_xep54(item)
- local t = st.stanza(item.name, { xmlns = "vcard-temp" });
-
- local prop_def = vCard_dtd[item.name];
- if prop_def == "text" then
- t:text(item[1]);
- elseif type(prop_def) == "table" then
- if prop_def.types and item.TYPE then
- if type(item.TYPE) == "table" then
- for _,v in pairs(prop_def.types) do
- for _,typ in pairs(item.TYPE) do
- if typ:upper() == v then
- t:tag(v):up();
- break;
- end
- end
- end
- else
- t:tag(item.TYPE:upper()):up();
- end
- end
-
- if prop_def.props then
- for _,prop in pairs(prop_def.props) do
- if item[prop] then
- for _, v in ipairs(item[prop]) do
- t:text_tag(prop, v);
- end
- end
- end
- end
-
- if prop_def.value then
- t:text_tag(prop_def.value, item[1]);
- elseif prop_def.values then
- local prop_def_values = prop_def.values;
- local repeat_last = prop_def_values.behaviour == "repeat-last" and prop_def_values[#prop_def_values];
- for i=1,#item do
- t:text_tag(prop_def.values[i] or repeat_last, item[i]);
- end
- end
- end
-
- return t;
-end
-
-local function vcard_to_xep54(vCard)
- local t = st.stanza("vCard", { xmlns = "vcard-temp" });
- for i=1,#vCard do
- t:add_child(item_to_xep54(vCard[i]));
- end
- return t;
-end
-
-function to_xep54(vCards)
- if not vCards[1] or vCards[1].name then
- return vcard_to_xep54(vCards)
- else
- local t = st.stanza("xCard", { xmlns = "vcard-temp" });
- for i=1,#vCards do
- t:add_child(vcard_to_xep54(vCards[i]));
- end
- return t;
- end
-end
-
-function from_text(data)
- data = data -- unfold and remove empty lines
- :gsub("\r\n","\n")
- :gsub("\n ", "")
- :gsub("\n\n+","\n");
- local vCards = {};
- local current;
- for line in data:gmatch("[^\n]+") do
- line = vCard_unesc(line);
- local name, params, value = line:match("^([-%a]+)(\30?[^\29]*)\29(.*)$");
- value = value:gsub("\29",":");
- if #params > 0 then
- local _params = {};
- for k,isval,v in params:gmatch("\30([^=]+)(=?)([^\30]*)") do
- k = k:upper();
- local _vt = {};
- for _p in v:gmatch("[^\31]+") do
- _vt[#_vt+1]=_p
- _vt[_p]=true;
- end
- if isval == "=" then
- _params[k]=_vt;
- else
- _params[k]=true;
- end
- end
- params = _params;
- end
- if name == "BEGIN" and value == "VCARD" then
- current = {};
- vCards[#vCards+1] = current;
- elseif name == "END" and value == "VCARD" then
- current = nil;
- elseif current and vCard_dtd[name] then
- local dtd = vCard_dtd[name];
- local item = { name = name };
- t_insert(current, item);
- local up = current;
- current = item;
- if dtd.types then
- for _, t in ipairs(dtd.types) do
- t = t:lower();
- if ( params.TYPE and params.TYPE[t] == true)
- or params[t] == true then
- current.TYPE=t;
- end
- end
- end
- if dtd.props then
- for _, p in ipairs(dtd.props) do
- if params[p] then
- if params[p] == true then
- current[p]=true;
- else
- for _, prop in ipairs(params[p]) do
- current[p]=prop;
- end
- end
- end
- end
- end
- if dtd == "text" or dtd.value then
- t_insert(current, value);
- elseif dtd.values then
- for p in ("\30"..value):gmatch("\30([^\30]*)") do
- t_insert(current, p);
- end
- end
- current = up;
- end
- end
- return vCards;
-end
-
-local function item_to_text(item)
- local value = {};
- for i=1,#item do
- value[i] = vCard_esc(item[i]);
- end
- value = t_concat(value, ";");
-
- local params = "";
- for k,v in pairs(item) do
- if type(k) == "string" and k ~= "name" then
- params = params .. (";%s=%s"):format(k, type(v) == "table" and t_concat(v,",") or v);
- end
- end
-
- return ("%s%s:%s"):format(item.name, params, value)
-end
-
-local function vcard_to_text(vcard)
- local t={};
- t_insert(t, "BEGIN:VCARD")
- for i=1,#vcard do
- t_insert(t, item_to_text(vcard[i]));
- end
- t_insert(t, "END:VCARD")
- return t_concat(t, line_sep);
-end
-
-function to_text(vCards)
- if vCards[1] and vCards[1].name then
- return vcard_to_text(vCards)
- else
- local t = {};
- for i=1,#vCards do
- t[i]=vcard_to_text(vCards[i]);
- end
- return t_concat(t, line_sep);
- end
-end
-
-local function from_xep54_item(item)
- local prop_name = item.name;
- local prop_def = vCard_dtd[prop_name];
-
- local prop = { name = prop_name };
-
- if prop_def == "text" then
- prop[1] = item:get_text();
- elseif type(prop_def) == "table" then
- if prop_def.value then --single item
- prop[1] = item:get_child_text(prop_def.value) or "";
- elseif prop_def.values then --array
- local value_names = prop_def.values;
- if value_names.behaviour == "repeat-last" then
- for i=1,#item.tags do
- t_insert(prop, item.tags[i]:get_text() or "");
- end
- else
- for i=1,#value_names do
- t_insert(prop, item:get_child_text(value_names[i]) or "");
- end
- end
- elseif prop_def.names then
- local names = prop_def.names;
- for i=1,#names do
- if item:get_child(names[i]) then
- prop[1] = names[i];
- break;
- end
- end
- end
-
- if prop_def.props_verbatim then
- for k,v in pairs(prop_def.props_verbatim) do
- prop[k] = v;
- end
- end
-
- if prop_def.types then
- local types = prop_def.types;
- prop.TYPE = {};
- for i=1,#types do
- if item:get_child(types[i]) then
- t_insert(prop.TYPE, types[i]:lower());
- end
- end
- if #prop.TYPE == 0 then
- prop.TYPE = nil;
- end
- end
-
- -- A key-value pair, within a key-value pair?
- if prop_def.props then
- local params = prop_def.props;
- for i=1,#params do
- local name = params[i]
- local data = item:get_child_text(name);
- if data then
- prop[name] = prop[name] or {};
- t_insert(prop[name], data);
- end
- end
- end
- else
- return nil
- end
-
- return prop;
-end
-
-local function from_xep54_vCard(vCard)
- local tags = vCard.tags;
- local t = {};
- for i=1,#tags do
- t_insert(t, from_xep54_item(tags[i]));
- end
- return t
-end
-
-function from_xep54(vCard)
- if vCard.attr.xmlns ~= "vcard-temp" then
- return nil, "wrong-xmlns";
- end
- if vCard.name == "xCard" then -- A collection of vCards
- local t = {};
- local vCards = vCard.tags;
- for i=1,#vCards do
- t[i] = from_xep54_vCard(vCards[i]);
- end
- return t
- elseif vCard.name == "vCard" then -- A single vCard
- return from_xep54_vCard(vCard)
- end
-end
-
-local vcard4 = { }
-
-function vcard4:text(node, params, value) -- luacheck: ignore 212/params
- self:tag(node:lower())
- -- FIXME params
- if type(value) == "string" then
- self:text_tag("text", value);
- elseif vcard4[node] then
- vcard4[node](value);
- end
- self:up();
-end
-
-function vcard4.N(value)
- for i, k in ipairs(vCard_dtd.N.values) do
- value:text_tag(k, value[i]);
- end
-end
-
-local xmlns_vcard4 = "urn:ietf:params:xml:ns:vcard-4.0"
-
-local function item_to_vcard4(item)
- local typ = item.name:lower();
- local t = st.stanza(typ, { xmlns = xmlns_vcard4 });
-
- local prop_def = vCard4_dtd[typ];
- if prop_def == "text" then
- t:text_tag("text", item[1]);
- elseif prop_def == "uri" then
- if item.ENCODING and item.ENCODING[1] == 'b' then
- t:text_tag("uri", "data:;base64," .. item[1]);
- else
- t:text_tag("uri", item[1]);
- end
- elseif type(prop_def) == "table" then
- if prop_def.values then
- for i, v in ipairs(prop_def.values) do
- t:text_tag(v:lower(), item[i]);
- end
- else
- t:tag("unsupported",{xmlns="http://zash.se/protocol/vcardlib"})
- end
- else
- t:tag("unsupported",{xmlns="http://zash.se/protocol/vcardlib"})
- end
- return t;
-end
-
-local function vcard_to_vcard4xml(vCard)
- local t = st.stanza("vcard", { xmlns = xmlns_vcard4 });
- for i=1,#vCard do
- t:add_child(item_to_vcard4(vCard[i]));
- end
- return t;
-end
-
-local function vcards_to_vcard4xml(vCards)
- if not vCards[1] or vCards[1].name then
- return vcard_to_vcard4xml(vCards)
- else
- local t = st.stanza("vcards", { xmlns = xmlns_vcard4 });
- for i=1,#vCards do
- t:add_child(vcard_to_vcard4xml(vCards[i]));
- end
- return t;
- end
-end
-
--- This was adapted from http://xmpp.org/extensions/xep-0054.html#dtd
-vCard_dtd = {
- VERSION = "text", --MUST be 3.0, so parsing is redundant
- FN = "text",
- N = {
- values = {
- "FAMILY",
- "GIVEN",
- "MIDDLE",
- "PREFIX",
- "SUFFIX",
- },
- },
- NICKNAME = "text",
- PHOTO = {
- props_verbatim = { ENCODING = { "b" } },
- props = { "TYPE" },
- value = "BINVAL", --{ "EXTVAL", },
- },
- BDAY = "text",
- ADR = {
- types = {
- "HOME",
- "WORK",
- "POSTAL",
- "PARCEL",
- "DOM",
- "INTL",
- "PREF",
- },
- values = {
- "POBOX",
- "EXTADD",
- "STREET",
- "LOCALITY",
- "REGION",
- "PCODE",
- "CTRY",
- }
- },
- LABEL = {
- types = {
- "HOME",
- "WORK",
- "POSTAL",
- "PARCEL",
- "DOM",
- "INTL",
- "PREF",
- },
- value = "LINE",
- },
- TEL = {
- types = {
- "HOME",
- "WORK",
- "VOICE",
- "FAX",
- "PAGER",
- "MSG",
- "CELL",
- "VIDEO",
- "BBS",
- "MODEM",
- "ISDN",
- "PCS",
- "PREF",
- },
- value = "NUMBER",
- },
- EMAIL = {
- types = {
- "HOME",
- "WORK",
- "INTERNET",
- "PREF",
- "X400",
- },
- value = "USERID",
- },
- JABBERID = "text",
- MAILER = "text",
- TZ = "text",
- GEO = {
- values = {
- "LAT",
- "LON",
- },
- },
- TITLE = "text",
- ROLE = "text",
- LOGO = "copy of PHOTO",
- AGENT = "text",
- ORG = {
- values = {
- behaviour = "repeat-last",
- "ORGNAME",
- "ORGUNIT",
- }
- },
- CATEGORIES = {
- values = "KEYWORD",
- },
- NOTE = "text",
- PRODID = "text",
- REV = "text",
- SORTSTRING = "text",
- SOUND = "copy of PHOTO",
- UID = "text",
- URL = "text",
- CLASS = {
- names = { -- The item.name is the value if it's one of these.
- "PUBLIC",
- "PRIVATE",
- "CONFIDENTIAL",
- },
- },
- KEY = {
- props = { "TYPE" },
- value = "CRED",
- },
- DESC = "text",
-};
-vCard_dtd.LOGO = vCard_dtd.PHOTO;
-vCard_dtd.SOUND = vCard_dtd.PHOTO;
-
-vCard4_dtd = {
- source = "uri",
- kind = "text",
- xml = "text",
- fn = "text",
- n = {
- values = {
- "family",
- "given",
- "middle",
- "prefix",
- "suffix",
- },
- },
- nickname = "text",
- photo = "uri",
- bday = "date-and-or-time",
- anniversary = "date-and-or-time",
- gender = "text",
- adr = {
- values = {
- "pobox",
- "ext",
- "street",
- "locality",
- "region",
- "code",
- "country",
- }
- },
- tel = "text",
- email = "text",
- impp = "uri",
- lang = "language-tag",
- tz = "text",
- geo = "uri",
- title = "text",
- role = "text",
- logo = "uri",
- org = "text",
- member = "uri",
- related = "uri",
- categories = "text",
- note = "text",
- prodid = "text",
- rev = "timestamp",
- sound = "uri",
- uid = "uri",
- clientpidmap = "number, uuid",
- url = "uri",
- version = "text",
- key = "uri",
- fburl = "uri",
- caladruri = "uri",
- caluri = "uri",
-};
-
-return {
- from_text = from_text;
- to_text = to_text;
-
- from_xep54 = from_xep54;
- to_xep54 = to_xep54;
-
- to_vcard4 = vcards_to_vcard4xml;
-};
diff --git a/util/watchdog.lua b/util/watchdog.lua
index 516e60e4..407028a5 100644
--- a/util/watchdog.lua
+++ b/util/watchdog.lua
@@ -1,6 +1,5 @@
local timer = require "util.timer";
local setmetatable = setmetatable;
-local os_time = os.time;
local _ENV = nil;
-- luacheck: std none
@@ -9,27 +8,35 @@ local watchdog_methods = {};
local watchdog_mt = { __index = watchdog_methods };
local function new(timeout, callback)
- local watchdog = setmetatable({ timeout = timeout, last_reset = os_time(), callback = callback }, watchdog_mt);
- timer.add_task(timeout+1, function (current_time)
- local last_reset = watchdog.last_reset;
- if not last_reset then
- return;
- end
- local time_left = (last_reset + timeout) - current_time;
- if time_left < 0 then
- return watchdog:callback();
- end
- return time_left + 1;
- end);
+ local watchdog = setmetatable({
+ timeout = timeout;
+ callback = callback;
+ timer_id = nil;
+ }, watchdog_mt);
+
+ watchdog:reset(); -- Kick things off
+
return watchdog;
end
-function watchdog_methods:reset()
- self.last_reset = os_time();
+function watchdog_methods:reset(new_timeout)
+ if new_timeout then
+ self.timeout = new_timeout;
+ end
+ if self.timer_id then
+ timer.reschedule(self.timer_id, self.timeout+1);
+ else
+ self.timer_id = timer.add_task(self.timeout+1, function ()
+ return self:callback();
+ end);
+ end
end
function watchdog_methods:cancel()
- self.last_reset = nil;
+ if self.timer_id then
+ timer.stop(self.timer_id);
+ self.timer_id = nil;
+ end
end
return {
diff --git a/util/x509.lua b/util/x509.lua
index 76b50076..51ca3c96 100644
--- a/util/x509.lua
+++ b/util/x509.lua
@@ -11,12 +11,12 @@
-- IDN libraries complicate that.
--- [TLS-CERTS] - http://tools.ietf.org/html/rfc6125
--- [XMPP-CORE] - http://tools.ietf.org/html/rfc6120
--- [SRV-ID] - http://tools.ietf.org/html/rfc4985
--- [IDNA] - http://tools.ietf.org/html/rfc5890
--- [LDAP] - http://tools.ietf.org/html/rfc4519
--- [PKIX] - http://tools.ietf.org/html/rfc5280
+-- [TLS-CERTS] - https://www.rfc-editor.org/rfc/rfc6125.html
+-- [XMPP-CORE] - https://www.rfc-editor.org/rfc/rfc6120.html
+-- [SRV-ID] - https://www.rfc-editor.org/rfc/rfc4985.html
+-- [IDNA] - https://www.rfc-editor.org/rfc/rfc5890.html
+-- [LDAP] - https://www.rfc-editor.org/rfc/rfc4519.html
+-- [PKIX] - https://www.rfc-editor.org/rfc/rfc5280.html
local nameprep = require "util.encodings".stringprep.nameprep;
local idna_to_ascii = require "util.encodings".idna.to_ascii;