aboutsummaryrefslogtreecommitdiffstats
path: root/util
diff options
context:
space:
mode:
Diffstat (limited to 'util')
-rw-r--r--util/datamanager.lua2
-rw-r--r--util/error.lua52
-rw-r--r--util/format.lua18
-rw-r--r--util/hashring.lua88
-rw-r--r--util/hmac.lua9
-rw-r--r--util/http.lua22
-rw-r--r--util/import.lua2
-rw-r--r--util/iterators.lua6
-rw-r--r--util/multitable.lua2
-rw-r--r--util/promise.lua3
-rw-r--r--util/prosodyctl.lua7
-rw-r--r--util/queue.lua12
-rw-r--r--util/serialization.lua27
-rw-r--r--util/session.lua3
-rw-r--r--util/stanza.lua70
-rw-r--r--util/startup.lua14
-rw-r--r--util/x509.lua28
17 files changed, 278 insertions, 87 deletions
diff --git a/util/datamanager.lua b/util/datamanager.lua
index cf96887b..b52c77fa 100644
--- a/util/datamanager.lua
+++ b/util/datamanager.lua
@@ -24,7 +24,7 @@ local t_concat = table.concat;
local envloadfile = require"util.envload".envloadfile;
local serialize = require "util.serialization".serialize;
local lfs = require "lfs";
--- Extract directory seperator from package.config (an undocumented string that comes with lua)
+-- Extract directory separator from package.config (an undocumented string that comes with lua)
local path_separator = assert ( package.config:match ( "^([^\n]+)" ) , "package.config not in standard form" )
local prosody = prosody;
diff --git a/util/error.lua b/util/error.lua
new file mode 100644
index 00000000..344dd274
--- /dev/null
+++ b/util/error.lua
@@ -0,0 +1,52 @@
+local error_mt = { __name = "error" };
+
+function error_mt:__tostring()
+ return ("error<%s:%s:%s>"):format(self.type, self.condition, self.text);
+end
+
+local function is_err(e)
+ return getmetatable(e) == error_mt;
+end
+
+local function new(e, context, registry)
+ local template = (registry and registry[e]) or e or {};
+ return setmetatable({
+ type = template.type or "cancel";
+ condition = template.condition or "undefined-condition";
+ text = template.text;
+
+ context = context or template.context or { _error_id = e };
+ }, error_mt);
+end
+
+local function coerce(ok, err, ...)
+ if ok or is_err(err) then
+ return ok, err, ...;
+ end
+
+ local new_err = setmetatable({
+ native = err;
+
+ type = "cancel";
+ condition = "undefined-condition";
+ }, error_mt);
+ return ok, new_err, ...;
+end
+
+local function from_stanza(stanza, context)
+ local error_type, condition, text = stanza:get_error();
+ return setmetatable({
+ type = error_type or "cancel";
+ condition = condition or "undefined-condition";
+ text = text;
+
+ context = context or { stanza = stanza };
+ }, error_mt);
+end
+
+return {
+ new = new;
+ coerce = coerce;
+ is_err = is_err;
+ from_stanza = from_stanza;
+}
diff --git a/util/format.lua b/util/format.lua
index c5e513fa..c31f599f 100644
--- a/util/format.lua
+++ b/util/format.lua
@@ -3,12 +3,14 @@
--
local tostring = tostring;
-local select = select;
local unpack = table.unpack or unpack; -- luacheck: ignore 113/unpack
+local pack = require "util.table".pack; -- TODO table.pack in 5.2+
local type = type;
+local dump = require "util.serialization".new("debug");
local function format(formatstring, ...)
- local args, args_length = { ... }, select('#', ...);
+ local args = pack(...);
+ local args_length = args.n;
-- format specifier spec:
-- 1. Start: '%%'
@@ -28,13 +30,15 @@ local function format(formatstring, ...)
if spec ~= "%%" then
i = i + 1;
local arg = args[i];
- if arg == nil then -- special handling for nil
- arg = "<nil>"
- args[i] = "<nil>";
- end
local option = spec:sub(-1);
- if option == "q" or option == "s" then -- arg should be string
+ if arg == nil then
+ args[i] = "nil";
+ spec = "<%s>";
+ elseif option == "q" then
+ args[i] = dump(arg);
+ spec = "%s";
+ elseif option == "s" then
args[i] = tostring(arg);
elseif type(arg) ~= "number" then -- arg isn't number as expected?
args[i] = tostring(arg);
diff --git a/util/hashring.lua b/util/hashring.lua
new file mode 100644
index 00000000..322bc005
--- /dev/null
+++ b/util/hashring.lua
@@ -0,0 +1,88 @@
+local function generate_ring(nodes, num_replicas, hash)
+ local new_ring = {};
+ for _, node_name in ipairs(nodes) do
+ for replica = 1, num_replicas do
+ local replica_hash = hash(node_name..":"..replica);
+ new_ring[replica_hash] = node_name;
+ table.insert(new_ring, replica_hash);
+ end
+ end
+ table.sort(new_ring);
+ return new_ring;
+end
+
+local hashring_methods = {};
+local hashring_mt = {
+ __index = function (self, k)
+ -- Automatically build self.ring if it's missing
+ if k == "ring" then
+ local ring = generate_ring(self.nodes, self.num_replicas, self.hash);
+ rawset(self, "ring", ring);
+ return ring;
+ end
+ return rawget(hashring_methods, k);
+ end
+};
+
+local function new(num_replicas, hash_function)
+ return setmetatable({ nodes = {}, num_replicas = num_replicas, hash = hash_function }, hashring_mt);
+end;
+
+function hashring_methods:add_node(name)
+ self.ring = nil;
+ self.nodes[name] = true;
+ table.insert(self.nodes, name);
+ return true;
+end
+
+function hashring_methods:add_nodes(nodes)
+ self.ring = nil;
+ for _, node_name in ipairs(nodes) do
+ if not self.nodes[node_name] then
+ self.nodes[node_name] = true;
+ table.insert(self.nodes, node_name);
+ end
+ end
+ return true;
+end
+
+function hashring_methods:remove_node(node_name)
+ self.ring = nil;
+ if self.nodes[node_name] then
+ for i, stored_node_name in ipairs(self.nodes) do
+ if node_name == stored_node_name then
+ self.nodes[node_name] = nil;
+ table.remove(self.nodes, i);
+ return true;
+ end
+ end
+ end
+ return false;
+end
+
+function hashring_methods:remove_nodes(nodes)
+ self.ring = nil;
+ for _, node_name in ipairs(nodes) do
+ self:remove_node(node_name);
+ end
+end
+
+function hashring_methods:clone()
+ local clone_hashring = new(self.num_replicas, self.hash);
+ clone_hashring:add_nodes(self.nodes);
+ return clone_hashring;
+end
+
+function hashring_methods:get_node(key)
+ local key_hash = self.hash(key);
+ for _, replica_hash in ipairs(self.ring) do
+ if key_hash < replica_hash then
+ return self.ring[replica_hash];
+ end
+ end
+ return self.ring[self.ring[1]];
+end
+
+return {
+ new = new;
+}
diff --git a/util/hmac.lua b/util/hmac.lua
index 2c4cc6ef..4cad17cc 100644
--- a/util/hmac.lua
+++ b/util/hmac.lua
@@ -10,6 +10,9 @@
local hashes = require "util.hashes"
-return { md5 = hashes.hmac_md5,
- sha1 = hashes.hmac_sha1,
- sha256 = hashes.hmac_sha256 };
+return {
+ md5 = hashes.hmac_md5,
+ sha1 = hashes.hmac_sha1,
+ sha256 = hashes.hmac_sha256,
+ sha512 = hashes.hmac_sha512,
+};
diff --git a/util/http.lua b/util/http.lua
index cfb89193..3852f91c 100644
--- a/util/http.lua
+++ b/util/http.lua
@@ -6,24 +6,26 @@
--
local format, char = string.format, string.char;
-local pairs, ipairs, tonumber = pairs, ipairs, tonumber;
+local pairs, ipairs = pairs, ipairs;
local t_insert, t_concat = table.insert, table.concat;
+local url_codes = {};
+for i = 0, 255 do
+ local c = char(i);
+ local u = format("%%%02x", i);
+ url_codes[c] = u;
+ url_codes[u] = c;
+ url_codes[u:upper()] = c;
+end
local function urlencode(s)
- return s and (s:gsub("[^a-zA-Z0-9.~_-]", function (c) return format("%%%02x", c:byte()); end));
+ return s and (s:gsub("[^a-zA-Z0-9.~_-]", url_codes));
end
local function urldecode(s)
- return s and (s:gsub("%%(%x%x)", function (c) return char(tonumber(c,16)); end));
+ return s and (s:gsub("%%%x%x", url_codes));
end
local function _formencodepart(s)
- return s and (s:gsub("%W", function (c)
- if c ~= " " then
- return format("%%%02x", c:byte());
- else
- return "+";
- end
- end));
+ return s and (urlencode(s):gsub("%%20", "+"));
end
local function formencode(form)
diff --git a/util/import.lua b/util/import.lua
index 8ecfe43c..1007bc0a 100644
--- a/util/import.lua
+++ b/util/import.lua
@@ -8,7 +8,7 @@
-local unpack = table.unpack or unpack; --luacheck: ignore 113 143
+local unpack = table.unpack or unpack; --luacheck: ignore 113
local t_insert = table.insert;
function _G.import(module, ...)
local m = package.loaded[module] or require(module);
diff --git a/util/iterators.lua b/util/iterators.lua
index 302cca36..c03c2fd6 100644
--- a/util/iterators.lua
+++ b/util/iterators.lua
@@ -11,9 +11,9 @@
local it = {};
local t_insert = table.insert;
-local select, next = select, next;
-local unpack = table.unpack or unpack; --luacheck: ignore 113 143
-local pack = table.pack or function (...) return { n = select("#", ...), ... }; end -- luacheck: ignore 143
+local next = next;
+local unpack = table.unpack or unpack; --luacheck: ignore 113
+local pack = table.pack or require "util.table".pack;
local type = type;
local table, setmetatable = table, setmetatable;
diff --git a/util/multitable.lua b/util/multitable.lua
index 8d32ed8a..4f2cd972 100644
--- a/util/multitable.lua
+++ b/util/multitable.lua
@@ -9,7 +9,7 @@
local select = select;
local t_insert = table.insert;
local pairs, next, type = pairs, next, type;
-local unpack = table.unpack or unpack; --luacheck: ignore 113 143
+local unpack = table.unpack or unpack; --luacheck: ignore 113
local _ENV = nil;
-- luacheck: std none
diff --git a/util/promise.lua b/util/promise.lua
index 07c9c4dc..0b182b54 100644
--- a/util/promise.lua
+++ b/util/promise.lua
@@ -49,6 +49,9 @@ local function promise_settle(promise, new_state, new_next, cbs, value)
for _, cb in ipairs(cbs) do
cb(value);
end
+ -- No need to keep references to callbacks
+ promise._pending_on_fulfilled = nil;
+ promise._pending_on_rejected = nil;
return true;
end
diff --git a/util/prosodyctl.lua b/util/prosodyctl.lua
index 5f0c4d12..9b627bde 100644
--- a/util/prosodyctl.lua
+++ b/util/prosodyctl.lua
@@ -229,7 +229,8 @@ local function isrunning()
return true, signal.kill(pid, 0) == 0;
end
-local function start(source_dir)
+local function start(source_dir, lua)
+ lua = lua and lua .. " " or "";
local ok, ret = isrunning();
if not ok then
return ok, ret;
@@ -238,9 +239,9 @@ local function start(source_dir)
return false, "already-running";
end
if not source_dir then
- os.execute("./prosody");
+ os.execute(lua .. "./prosody");
else
- os.execute(source_dir.."/../../bin/prosody");
+ os.execute(lua .. source_dir.."/../../bin/prosody");
end
return true;
end
diff --git a/util/queue.lua b/util/queue.lua
index 728e905f..66ed098b 100644
--- a/util/queue.lua
+++ b/util/queue.lua
@@ -52,18 +52,20 @@ local function new(size, allow_wrapping)
return t[tail];
end;
items = function (self)
- --luacheck: ignore 431/t
- return function (t, pos)
- if pos >= t:count() then
+ return function (_, pos)
+ if pos >= items then
return nil;
end
local read_pos = tail + pos;
- if read_pos > t.size then
+ if read_pos > self.size then
read_pos = (read_pos%size);
end
- return pos+1, t._items[read_pos];
+ return pos+1, t[read_pos];
end, self, 0;
end;
+ consume = function (self)
+ return self.pop, self;
+ end;
};
end
diff --git a/util/serialization.lua b/util/serialization.lua
index dd6a2a2b..60e341cf 100644
--- a/util/serialization.lua
+++ b/util/serialization.lua
@@ -16,22 +16,18 @@ local s_char = string.char;
local s_match = string.match;
local t_concat = table.concat;
+local to_hex = require "util.hex".to;
+
local pcall = pcall;
local envload = require"util.envload".envload;
local pos_inf, neg_inf = math.huge, -math.huge;
--- luacheck: ignore 143/math
local m_type = math.type or function (n)
return n % 1 == 0 and n <= 9007199254740992 and n >= -9007199254740992 and "integer" or "float";
end;
-local char_to_hex = {};
-for i = 0,255 do
- char_to_hex[s_char(i)] = s_format("%02x", i);
-end
-
-local function to_hex(s)
- return (s_gsub(s, ".", char_to_hex));
+local function rawpairs(t)
+ return next, t, nil;
end
local function fatal_error(obj, why)
@@ -123,6 +119,7 @@ local function new(opt)
local freeze = opt.freeze;
local maxdepth = opt.maxdepth or 127;
local multirefs = opt.multiref;
+ local table_pairs = opt.table_iterator or rawpairs;
-- serialize one table, recursively
-- t - table being serialized
@@ -164,7 +161,9 @@ local function new(opt)
local indent = s_rep(indentwith, d);
local numkey = 1;
local ktyp, vtyp;
- for k,v in next,t do
+ local had_items = false;
+ for k,v in table_pairs(t) do
+ had_items = true;
o[l], l = itemstart, l + 1;
o[l], l = indent, l + 1;
ktyp, vtyp = type(k), type(v);
@@ -195,14 +194,10 @@ local function new(opt)
else
o[l], l = ser(v), l + 1;
end
- -- last item?
- if next(t, k) ~= nil then
- o[l], l = itemsep, l + 1;
- else
- o[l], l = itemlast, l + 1;
- end
+ o[l], l = itemsep, l + 1;
end
- if next(t) ~= nil then
+ if had_items then
+ o[l - 1] = itemlast;
o[l], l = s_rep(indentwith, d-1), l + 1;
end
o[l], l = tend, l +1;
diff --git a/util/session.lua b/util/session.lua
index b2a726ce..b9c6bec7 100644
--- a/util/session.lua
+++ b/util/session.lua
@@ -4,12 +4,13 @@ local logger = require "util.logger";
local function new_session(typ)
local session = {
type = typ .. "_unauthed";
+ base_type = typ;
};
return session;
end
local function set_id(session)
- local id = session.type .. tostring(session):match("%x+$"):lower();
+ local id = session.base_type .. tostring(session):match("%x+$"):lower();
session.id = id;
return session;
end
diff --git a/util/stanza.lua b/util/stanza.lua
index a90d56b3..7fe5c7ae 100644
--- a/util/stanza.lua
+++ b/util/stanza.lua
@@ -270,6 +270,34 @@ function stanza_mt:find(path)
until not self
end
+local function _clone(stanza, only_top)
+ local attr, tags = {}, {};
+ for k,v in pairs(stanza.attr) do attr[k] = v; end
+ local old_namespaces, namespaces = stanza.namespaces;
+ if old_namespaces then
+ namespaces = {};
+ for k,v in pairs(old_namespaces) do namespaces[k] = v; end
+ end
+ local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags };
+ if not only_top then
+ for i=1,#stanza do
+ local child = stanza[i];
+ if child.name then
+ child = _clone(child);
+ t_insert(tags, child);
+ end
+ t_insert(new, child);
+ end
+ end
+ return setmetatable(new, stanza_mt);
+end
+
+local function clone(stanza, only_top)
+ if not is_stanza(stanza) then
+ error("bad argument to clone: expected stanza, got "..type(stanza));
+ end
+ return _clone(stanza, only_top);
+end
local escape_table = { ["'"] = "&apos;", ["\""] = "&quot;", ["<"] = "&lt;", [">"] = "&gt;", ["&"] = "&amp;" };
local function xml_escape(str) return (s_gsub(str, "['&<>\"]", escape_table)); end
@@ -310,11 +338,8 @@ function stanza_mt.__tostring(t)
end
function stanza_mt.top_tag(t)
- local attr_string = "";
- if t.attr then
- for k, v in pairs(t.attr) do if type(k) == "string" then attr_string = attr_string .. s_format(" %s='%s'", k, xml_escape(tostring(v))); end end
- end
- return s_format("<%s%s>", t.name, attr_string);
+ local top_tag_clone = clone(t, true);
+ return tostring(top_tag_clone):sub(1,-3)..">";
end
function stanza_mt.get_text(t)
@@ -388,33 +413,6 @@ local function deserialize(serialized)
end
end
-local function _clone(stanza)
- local attr, tags = {}, {};
- for k,v in pairs(stanza.attr) do attr[k] = v; end
- local old_namespaces, namespaces = stanza.namespaces;
- if old_namespaces then
- namespaces = {};
- for k,v in pairs(old_namespaces) do namespaces[k] = v; end
- end
- local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags };
- for i=1,#stanza do
- local child = stanza[i];
- if child.name then
- child = _clone(child);
- t_insert(tags, child);
- end
- t_insert(new, child);
- end
- return setmetatable(new, stanza_mt);
-end
-
-local function clone(stanza)
- if not is_stanza(stanza) then
- error("bad argument to clone: expected stanza, got "..type(stanza));
- end
- return _clone(stanza);
-end
-
local function message(attr, body)
if not body then
return new_stanza("message", attr);
@@ -423,9 +421,15 @@ local function message(attr, body)
end
end
local function iq(attr)
- if not (attr and attr.id) then
+ if not attr then
+ error("iq stanzas require id and type attributes");
+ end
+ if not attr.id then
error("iq stanzas require an id attribute");
end
+ if not attr.type then
+ error("iq stanzas require a type attribute");
+ end
return new_stanza("iq", attr);
end
diff --git a/util/startup.lua b/util/startup.lua
index c101c290..7a1a95aa 100644
--- a/util/startup.lua
+++ b/util/startup.lua
@@ -7,6 +7,7 @@ local logger = require "util.logger";
local log = logger.init("startup");
local config = require "core.configmanager";
+local config_warnings;
local dependencies = require "util.dependencies";
@@ -64,6 +65,8 @@ function startup.read_config()
print("**************************");
print("");
os.exit(1);
+ elseif err and #err > 0 then
+ config_warnings = err;
end
prosody.config_loaded = true;
end
@@ -96,8 +99,13 @@ function startup.init_logging()
end);
end
-function startup.log_dependency_warnings()
+function startup.log_startup_warnings()
dependencies.log_warnings();
+ if config_warnings then
+ for _, warning in ipairs(config_warnings) do
+ log("warn", "Configuration warning: %s", warning);
+ end
+ end
end
function startup.sanity_check()
@@ -518,7 +526,7 @@ function startup.prosodyctl()
startup.read_version();
startup.switch_user();
startup.check_dependencies();
- startup.log_dependency_warnings();
+ startup.log_startup_warnings();
startup.check_unwriteable();
startup.load_libraries();
startup.init_http_client();
@@ -543,7 +551,7 @@ function startup.prosody()
startup.add_global_prosody_functions();
startup.read_version();
startup.log_greeting();
- startup.log_dependency_warnings();
+ startup.log_startup_warnings();
startup.load_secondary_libraries();
startup.init_http_client();
startup.init_data_store();
diff --git a/util/x509.lua b/util/x509.lua
index 15cc4d3c..1cdf07dc 100644
--- a/util/x509.lua
+++ b/util/x509.lua
@@ -20,6 +20,7 @@
local nameprep = require "util.encodings".stringprep.nameprep;
local idna_to_ascii = require "util.encodings".idna.to_ascii;
+local idna_to_unicode = require "util.encodings".idna.to_unicode;
local base64 = require "util.encodings".base64;
local log = require "util.logger".init("x509");
local s_format = string.format;
@@ -216,6 +217,32 @@ local function verify_identity(host, service, cert)
return false
end
+-- TODO Support other SANs
+local function get_identities(cert) --> set of names
+ if cert.setencode then
+ cert:setencode("utf8");
+ end
+
+ local names = {};
+
+ local ext = cert:extensions();
+ local sans = ext[oid_subjectaltname];
+ if sans and sans["dNSName"] then
+ for i = 1, #sans["dNSName"] do
+ names[ idna_to_unicode(sans["dNSName"][i]) ] = true;
+ end
+ end
+
+ local subject = cert:subject();
+ for i = 1, #subject do
+ local dn = subject[i];
+ if dn.oid == oid_commonname and nameprep(dn.value) then
+ names[dn.value] = true;
+ end
+ end
+ return names;
+end
+
local pat = "%-%-%-%-%-BEGIN ([A-Z ]+)%-%-%-%-%-\r?\n"..
"([0-9A-Za-z+/=\r\n]*)\r?\n%-%-%-%-%-END %1%-%-%-%-%-";
@@ -237,6 +264,7 @@ end
return {
verify_identity = verify_identity;
+ get_identities = get_identities;
pem2der = pem2der;
der2pem = der2pem;
};