diff options
Diffstat (limited to 'util')
-rw-r--r-- | util/datamanager.lua | 2 | ||||
-rw-r--r-- | util/error.lua | 52 | ||||
-rw-r--r-- | util/format.lua | 18 | ||||
-rw-r--r-- | util/hashring.lua | 88 | ||||
-rw-r--r-- | util/hmac.lua | 9 | ||||
-rw-r--r-- | util/http.lua | 22 | ||||
-rw-r--r-- | util/import.lua | 2 | ||||
-rw-r--r-- | util/iterators.lua | 6 | ||||
-rw-r--r-- | util/multitable.lua | 2 | ||||
-rw-r--r-- | util/promise.lua | 3 | ||||
-rw-r--r-- | util/prosodyctl.lua | 7 | ||||
-rw-r--r-- | util/queue.lua | 12 | ||||
-rw-r--r-- | util/serialization.lua | 27 | ||||
-rw-r--r-- | util/session.lua | 3 | ||||
-rw-r--r-- | util/stanza.lua | 70 | ||||
-rw-r--r-- | util/startup.lua | 14 | ||||
-rw-r--r-- | util/x509.lua | 28 |
17 files changed, 278 insertions, 87 deletions
diff --git a/util/datamanager.lua b/util/datamanager.lua index cf96887b..b52c77fa 100644 --- a/util/datamanager.lua +++ b/util/datamanager.lua @@ -24,7 +24,7 @@ local t_concat = table.concat; local envloadfile = require"util.envload".envloadfile; local serialize = require "util.serialization".serialize; local lfs = require "lfs"; --- Extract directory seperator from package.config (an undocumented string that comes with lua) +-- Extract directory separator from package.config (an undocumented string that comes with lua) local path_separator = assert ( package.config:match ( "^([^\n]+)" ) , "package.config not in standard form" ) local prosody = prosody; diff --git a/util/error.lua b/util/error.lua new file mode 100644 index 00000000..344dd274 --- /dev/null +++ b/util/error.lua @@ -0,0 +1,52 @@ +local error_mt = { __name = "error" }; + +function error_mt:__tostring() + return ("error<%s:%s:%s>"):format(self.type, self.condition, self.text); +end + +local function is_err(e) + return getmetatable(e) == error_mt; +end + +local function new(e, context, registry) + local template = (registry and registry[e]) or e or {}; + return setmetatable({ + type = template.type or "cancel"; + condition = template.condition or "undefined-condition"; + text = template.text; + + context = context or template.context or { _error_id = e }; + }, error_mt); +end + +local function coerce(ok, err, ...) + if ok or is_err(err) then + return ok, err, ...; + end + + local new_err = setmetatable({ + native = err; + + type = "cancel"; + condition = "undefined-condition"; + }, error_mt); + return ok, new_err, ...; +end + +local function from_stanza(stanza, context) + local error_type, condition, text = stanza:get_error(); + return setmetatable({ + type = error_type or "cancel"; + condition = condition or "undefined-condition"; + text = text; + + context = context or { stanza = stanza }; + }, error_mt); +end + +return { + new = new; + coerce = coerce; + is_err = is_err; + from_stanza = from_stanza; +} diff --git a/util/format.lua b/util/format.lua index c5e513fa..c31f599f 100644 --- a/util/format.lua +++ b/util/format.lua @@ -3,12 +3,14 @@ -- local tostring = tostring; -local select = select; local unpack = table.unpack or unpack; -- luacheck: ignore 113/unpack +local pack = require "util.table".pack; -- TODO table.pack in 5.2+ local type = type; +local dump = require "util.serialization".new("debug"); local function format(formatstring, ...) - local args, args_length = { ... }, select('#', ...); + local args = pack(...); + local args_length = args.n; -- format specifier spec: -- 1. Start: '%%' @@ -28,13 +30,15 @@ local function format(formatstring, ...) if spec ~= "%%" then i = i + 1; local arg = args[i]; - if arg == nil then -- special handling for nil - arg = "<nil>" - args[i] = "<nil>"; - end local option = spec:sub(-1); - if option == "q" or option == "s" then -- arg should be string + if arg == nil then + args[i] = "nil"; + spec = "<%s>"; + elseif option == "q" then + args[i] = dump(arg); + spec = "%s"; + elseif option == "s" then args[i] = tostring(arg); elseif type(arg) ~= "number" then -- arg isn't number as expected? args[i] = tostring(arg); diff --git a/util/hashring.lua b/util/hashring.lua new file mode 100644 index 00000000..322bc005 --- /dev/null +++ b/util/hashring.lua @@ -0,0 +1,88 @@ +local function generate_ring(nodes, num_replicas, hash) + local new_ring = {}; + for _, node_name in ipairs(nodes) do + for replica = 1, num_replicas do + local replica_hash = hash(node_name..":"..replica); + new_ring[replica_hash] = node_name; + table.insert(new_ring, replica_hash); + end + end + table.sort(new_ring); + return new_ring; +end + +local hashring_methods = {}; +local hashring_mt = { + __index = function (self, k) + -- Automatically build self.ring if it's missing + if k == "ring" then + local ring = generate_ring(self.nodes, self.num_replicas, self.hash); + rawset(self, "ring", ring); + return ring; + end + return rawget(hashring_methods, k); + end +}; + +local function new(num_replicas, hash_function) + return setmetatable({ nodes = {}, num_replicas = num_replicas, hash = hash_function }, hashring_mt); +end; + +function hashring_methods:add_node(name) + self.ring = nil; + self.nodes[name] = true; + table.insert(self.nodes, name); + return true; +end + +function hashring_methods:add_nodes(nodes) + self.ring = nil; + for _, node_name in ipairs(nodes) do + if not self.nodes[node_name] then + self.nodes[node_name] = true; + table.insert(self.nodes, node_name); + end + end + return true; +end + +function hashring_methods:remove_node(node_name) + self.ring = nil; + if self.nodes[node_name] then + for i, stored_node_name in ipairs(self.nodes) do + if node_name == stored_node_name then + self.nodes[node_name] = nil; + table.remove(self.nodes, i); + return true; + end + end + end + return false; +end + +function hashring_methods:remove_nodes(nodes) + self.ring = nil; + for _, node_name in ipairs(nodes) do + self:remove_node(node_name); + end +end + +function hashring_methods:clone() + local clone_hashring = new(self.num_replicas, self.hash); + clone_hashring:add_nodes(self.nodes); + return clone_hashring; +end + +function hashring_methods:get_node(key) + local key_hash = self.hash(key); + for _, replica_hash in ipairs(self.ring) do + if key_hash < replica_hash then + return self.ring[replica_hash]; + end + end + return self.ring[self.ring[1]]; +end + +return { + new = new; +} diff --git a/util/hmac.lua b/util/hmac.lua index 2c4cc6ef..4cad17cc 100644 --- a/util/hmac.lua +++ b/util/hmac.lua @@ -10,6 +10,9 @@ local hashes = require "util.hashes" -return { md5 = hashes.hmac_md5, - sha1 = hashes.hmac_sha1, - sha256 = hashes.hmac_sha256 }; +return { + md5 = hashes.hmac_md5, + sha1 = hashes.hmac_sha1, + sha256 = hashes.hmac_sha256, + sha512 = hashes.hmac_sha512, +}; diff --git a/util/http.lua b/util/http.lua index cfb89193..3852f91c 100644 --- a/util/http.lua +++ b/util/http.lua @@ -6,24 +6,26 @@ -- local format, char = string.format, string.char; -local pairs, ipairs, tonumber = pairs, ipairs, tonumber; +local pairs, ipairs = pairs, ipairs; local t_insert, t_concat = table.insert, table.concat; +local url_codes = {}; +for i = 0, 255 do + local c = char(i); + local u = format("%%%02x", i); + url_codes[c] = u; + url_codes[u] = c; + url_codes[u:upper()] = c; +end local function urlencode(s) - return s and (s:gsub("[^a-zA-Z0-9.~_-]", function (c) return format("%%%02x", c:byte()); end)); + return s and (s:gsub("[^a-zA-Z0-9.~_-]", url_codes)); end local function urldecode(s) - return s and (s:gsub("%%(%x%x)", function (c) return char(tonumber(c,16)); end)); + return s and (s:gsub("%%%x%x", url_codes)); end local function _formencodepart(s) - return s and (s:gsub("%W", function (c) - if c ~= " " then - return format("%%%02x", c:byte()); - else - return "+"; - end - end)); + return s and (urlencode(s):gsub("%%20", "+")); end local function formencode(form) diff --git a/util/import.lua b/util/import.lua index 8ecfe43c..1007bc0a 100644 --- a/util/import.lua +++ b/util/import.lua @@ -8,7 +8,7 @@ -local unpack = table.unpack or unpack; --luacheck: ignore 113 143 +local unpack = table.unpack or unpack; --luacheck: ignore 113 local t_insert = table.insert; function _G.import(module, ...) local m = package.loaded[module] or require(module); diff --git a/util/iterators.lua b/util/iterators.lua index 302cca36..c03c2fd6 100644 --- a/util/iterators.lua +++ b/util/iterators.lua @@ -11,9 +11,9 @@ local it = {}; local t_insert = table.insert; -local select, next = select, next; -local unpack = table.unpack or unpack; --luacheck: ignore 113 143 -local pack = table.pack or function (...) return { n = select("#", ...), ... }; end -- luacheck: ignore 143 +local next = next; +local unpack = table.unpack or unpack; --luacheck: ignore 113 +local pack = table.pack or require "util.table".pack; local type = type; local table, setmetatable = table, setmetatable; diff --git a/util/multitable.lua b/util/multitable.lua index 8d32ed8a..4f2cd972 100644 --- a/util/multitable.lua +++ b/util/multitable.lua @@ -9,7 +9,7 @@ local select = select; local t_insert = table.insert; local pairs, next, type = pairs, next, type; -local unpack = table.unpack or unpack; --luacheck: ignore 113 143 +local unpack = table.unpack or unpack; --luacheck: ignore 113 local _ENV = nil; -- luacheck: std none diff --git a/util/promise.lua b/util/promise.lua index 07c9c4dc..0b182b54 100644 --- a/util/promise.lua +++ b/util/promise.lua @@ -49,6 +49,9 @@ local function promise_settle(promise, new_state, new_next, cbs, value) for _, cb in ipairs(cbs) do cb(value); end + -- No need to keep references to callbacks + promise._pending_on_fulfilled = nil; + promise._pending_on_rejected = nil; return true; end diff --git a/util/prosodyctl.lua b/util/prosodyctl.lua index 5f0c4d12..9b627bde 100644 --- a/util/prosodyctl.lua +++ b/util/prosodyctl.lua @@ -229,7 +229,8 @@ local function isrunning() return true, signal.kill(pid, 0) == 0; end -local function start(source_dir) +local function start(source_dir, lua) + lua = lua and lua .. " " or ""; local ok, ret = isrunning(); if not ok then return ok, ret; @@ -238,9 +239,9 @@ local function start(source_dir) return false, "already-running"; end if not source_dir then - os.execute("./prosody"); + os.execute(lua .. "./prosody"); else - os.execute(source_dir.."/../../bin/prosody"); + os.execute(lua .. source_dir.."/../../bin/prosody"); end return true; end diff --git a/util/queue.lua b/util/queue.lua index 728e905f..66ed098b 100644 --- a/util/queue.lua +++ b/util/queue.lua @@ -52,18 +52,20 @@ local function new(size, allow_wrapping) return t[tail]; end; items = function (self) - --luacheck: ignore 431/t - return function (t, pos) - if pos >= t:count() then + return function (_, pos) + if pos >= items then return nil; end local read_pos = tail + pos; - if read_pos > t.size then + if read_pos > self.size then read_pos = (read_pos%size); end - return pos+1, t._items[read_pos]; + return pos+1, t[read_pos]; end, self, 0; end; + consume = function (self) + return self.pop, self; + end; }; end diff --git a/util/serialization.lua b/util/serialization.lua index dd6a2a2b..60e341cf 100644 --- a/util/serialization.lua +++ b/util/serialization.lua @@ -16,22 +16,18 @@ local s_char = string.char; local s_match = string.match; local t_concat = table.concat; +local to_hex = require "util.hex".to; + local pcall = pcall; local envload = require"util.envload".envload; local pos_inf, neg_inf = math.huge, -math.huge; --- luacheck: ignore 143/math local m_type = math.type or function (n) return n % 1 == 0 and n <= 9007199254740992 and n >= -9007199254740992 and "integer" or "float"; end; -local char_to_hex = {}; -for i = 0,255 do - char_to_hex[s_char(i)] = s_format("%02x", i); -end - -local function to_hex(s) - return (s_gsub(s, ".", char_to_hex)); +local function rawpairs(t) + return next, t, nil; end local function fatal_error(obj, why) @@ -123,6 +119,7 @@ local function new(opt) local freeze = opt.freeze; local maxdepth = opt.maxdepth or 127; local multirefs = opt.multiref; + local table_pairs = opt.table_iterator or rawpairs; -- serialize one table, recursively -- t - table being serialized @@ -164,7 +161,9 @@ local function new(opt) local indent = s_rep(indentwith, d); local numkey = 1; local ktyp, vtyp; - for k,v in next,t do + local had_items = false; + for k,v in table_pairs(t) do + had_items = true; o[l], l = itemstart, l + 1; o[l], l = indent, l + 1; ktyp, vtyp = type(k), type(v); @@ -195,14 +194,10 @@ local function new(opt) else o[l], l = ser(v), l + 1; end - -- last item? - if next(t, k) ~= nil then - o[l], l = itemsep, l + 1; - else - o[l], l = itemlast, l + 1; - end + o[l], l = itemsep, l + 1; end - if next(t) ~= nil then + if had_items then + o[l - 1] = itemlast; o[l], l = s_rep(indentwith, d-1), l + 1; end o[l], l = tend, l +1; diff --git a/util/session.lua b/util/session.lua index b2a726ce..b9c6bec7 100644 --- a/util/session.lua +++ b/util/session.lua @@ -4,12 +4,13 @@ local logger = require "util.logger"; local function new_session(typ) local session = { type = typ .. "_unauthed"; + base_type = typ; }; return session; end local function set_id(session) - local id = session.type .. tostring(session):match("%x+$"):lower(); + local id = session.base_type .. tostring(session):match("%x+$"):lower(); session.id = id; return session; end diff --git a/util/stanza.lua b/util/stanza.lua index a90d56b3..7fe5c7ae 100644 --- a/util/stanza.lua +++ b/util/stanza.lua @@ -270,6 +270,34 @@ function stanza_mt:find(path) until not self end +local function _clone(stanza, only_top) + local attr, tags = {}, {}; + for k,v in pairs(stanza.attr) do attr[k] = v; end + local old_namespaces, namespaces = stanza.namespaces; + if old_namespaces then + namespaces = {}; + for k,v in pairs(old_namespaces) do namespaces[k] = v; end + end + local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags }; + if not only_top then + for i=1,#stanza do + local child = stanza[i]; + if child.name then + child = _clone(child); + t_insert(tags, child); + end + t_insert(new, child); + end + end + return setmetatable(new, stanza_mt); +end + +local function clone(stanza, only_top) + if not is_stanza(stanza) then + error("bad argument to clone: expected stanza, got "..type(stanza)); + end + return _clone(stanza, only_top); +end local escape_table = { ["'"] = "'", ["\""] = """, ["<"] = "<", [">"] = ">", ["&"] = "&" }; local function xml_escape(str) return (s_gsub(str, "['&<>\"]", escape_table)); end @@ -310,11 +338,8 @@ function stanza_mt.__tostring(t) end function stanza_mt.top_tag(t) - local attr_string = ""; - if t.attr then - for k, v in pairs(t.attr) do if type(k) == "string" then attr_string = attr_string .. s_format(" %s='%s'", k, xml_escape(tostring(v))); end end - end - return s_format("<%s%s>", t.name, attr_string); + local top_tag_clone = clone(t, true); + return tostring(top_tag_clone):sub(1,-3)..">"; end function stanza_mt.get_text(t) @@ -388,33 +413,6 @@ local function deserialize(serialized) end end -local function _clone(stanza) - local attr, tags = {}, {}; - for k,v in pairs(stanza.attr) do attr[k] = v; end - local old_namespaces, namespaces = stanza.namespaces; - if old_namespaces then - namespaces = {}; - for k,v in pairs(old_namespaces) do namespaces[k] = v; end - end - local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags }; - for i=1,#stanza do - local child = stanza[i]; - if child.name then - child = _clone(child); - t_insert(tags, child); - end - t_insert(new, child); - end - return setmetatable(new, stanza_mt); -end - -local function clone(stanza) - if not is_stanza(stanza) then - error("bad argument to clone: expected stanza, got "..type(stanza)); - end - return _clone(stanza); -end - local function message(attr, body) if not body then return new_stanza("message", attr); @@ -423,9 +421,15 @@ local function message(attr, body) end end local function iq(attr) - if not (attr and attr.id) then + if not attr then + error("iq stanzas require id and type attributes"); + end + if not attr.id then error("iq stanzas require an id attribute"); end + if not attr.type then + error("iq stanzas require a type attribute"); + end return new_stanza("iq", attr); end diff --git a/util/startup.lua b/util/startup.lua index c101c290..7a1a95aa 100644 --- a/util/startup.lua +++ b/util/startup.lua @@ -7,6 +7,7 @@ local logger = require "util.logger"; local log = logger.init("startup"); local config = require "core.configmanager"; +local config_warnings; local dependencies = require "util.dependencies"; @@ -64,6 +65,8 @@ function startup.read_config() print("**************************"); print(""); os.exit(1); + elseif err and #err > 0 then + config_warnings = err; end prosody.config_loaded = true; end @@ -96,8 +99,13 @@ function startup.init_logging() end); end -function startup.log_dependency_warnings() +function startup.log_startup_warnings() dependencies.log_warnings(); + if config_warnings then + for _, warning in ipairs(config_warnings) do + log("warn", "Configuration warning: %s", warning); + end + end end function startup.sanity_check() @@ -518,7 +526,7 @@ function startup.prosodyctl() startup.read_version(); startup.switch_user(); startup.check_dependencies(); - startup.log_dependency_warnings(); + startup.log_startup_warnings(); startup.check_unwriteable(); startup.load_libraries(); startup.init_http_client(); @@ -543,7 +551,7 @@ function startup.prosody() startup.add_global_prosody_functions(); startup.read_version(); startup.log_greeting(); - startup.log_dependency_warnings(); + startup.log_startup_warnings(); startup.load_secondary_libraries(); startup.init_http_client(); startup.init_data_store(); diff --git a/util/x509.lua b/util/x509.lua index 15cc4d3c..1cdf07dc 100644 --- a/util/x509.lua +++ b/util/x509.lua @@ -20,6 +20,7 @@ local nameprep = require "util.encodings".stringprep.nameprep; local idna_to_ascii = require "util.encodings".idna.to_ascii; +local idna_to_unicode = require "util.encodings".idna.to_unicode; local base64 = require "util.encodings".base64; local log = require "util.logger".init("x509"); local s_format = string.format; @@ -216,6 +217,32 @@ local function verify_identity(host, service, cert) return false end +-- TODO Support other SANs +local function get_identities(cert) --> set of names + if cert.setencode then + cert:setencode("utf8"); + end + + local names = {}; + + local ext = cert:extensions(); + local sans = ext[oid_subjectaltname]; + if sans and sans["dNSName"] then + for i = 1, #sans["dNSName"] do + names[ idna_to_unicode(sans["dNSName"][i]) ] = true; + end + end + + local subject = cert:subject(); + for i = 1, #subject do + local dn = subject[i]; + if dn.oid == oid_commonname and nameprep(dn.value) then + names[dn.value] = true; + end + end + return names; +end + local pat = "%-%-%-%-%-BEGIN ([A-Z ]+)%-%-%-%-%-\r?\n".. "([0-9A-Za-z+/=\r\n]*)\r?\n%-%-%-%-%-END %1%-%-%-%-%-"; @@ -237,6 +264,7 @@ end return { verify_identity = verify_identity; + get_identities = get_identities; pem2der = pem2der; der2pem = der2pem; }; |