diff options
Diffstat (limited to 'util')
84 files changed, 2730 insertions, 1354 deletions
diff --git a/util/adminstream.lua b/util/adminstream.lua index 4075aa05..a29222cf 100644 --- a/util/adminstream.lua +++ b/util/adminstream.lua @@ -1,15 +1,15 @@ -local st = require "util.stanza"; -local new_xmpp_stream = require "util.xmppstream".new; -local sessionlib = require "util.session"; -local gettime = require "util.time".now; -local runner = require "util.async".runner; -local add_task = require "util.timer".add_task; -local events = require "util.events"; -local server = require "net.server"; +local st = require "prosody.util.stanza"; +local new_xmpp_stream = require "prosody.util.xmppstream".new; +local sessionlib = require "prosody.util.session"; +local gettime = require "prosody.util.time".now; +local runner = require "prosody.util.async".runner; +local add_task = require "prosody.util.timer".add_task; +local events = require "prosody.util.events"; +local server = require "prosody.net.server"; local stream_close_timeout = 5; -local log = require "util.logger".init("adminstream"); +local log = require "prosody.util.logger".init("adminstream"); local xmlns_xmpp_streams = "urn:ietf:params:xml:ns:xmpp-streams"; @@ -145,7 +145,7 @@ local function new_connection(socket_path, listeners) -- constructor was exported instead of a module table. Due to the lack of a -- proper release of LuaSocket, distros have settled on shipping either the -- last RC tag or some commit since then. - -- Here we accomodate both variants. + -- Here we accommodate both variants. unix = { stream = unix }; end if type(unix) ~= "table" then diff --git a/util/argparse.lua b/util/argparse.lua index 9ece050a..7a55cb0b 100644 --- a/util/argparse.lua +++ b/util/argparse.lua @@ -1,6 +1,7 @@ local function parse(arg, config) local short_params = config and config.short_params or {}; local value_params = config and config.value_params or {}; + local array_params = config and config.array_params or {}; local parsed_opts = {}; @@ -30,7 +31,7 @@ local function parse(arg, config) end local param_k, param_v; - if value_params[param] then + if value_params[param] or array_params[param] then param_k, param_v = param, table.remove(arg, 1); if not param_v then return nil, "missing-value", raw_param; @@ -44,8 +45,17 @@ local function parse(arg, config) param_k, param_v = param, true; end end + param_k = param_k:gsub("%-", "_"); + end + if array_params[param] then + if parsed_opts[param_k] then + table.insert(parsed_opts[param_k], param_v); + else + parsed_opts[param_k] = { param_v }; + end + else + parsed_opts[param_k] = param_v; end - parsed_opts[param_k] = param_v; end for i = 1, #arg do parsed_opts[i] = arg[i]; diff --git a/util/array.lua b/util/array.lua index 39a97e7f..074e9bc7 100644 --- a/util/array.lua +++ b/util/array.lua @@ -8,6 +8,7 @@ local t_insert, t_sort, t_remove, t_concat = table.insert, table.sort, table.remove, table.concat; +local t_move = require "prosody.util.table".move; local setmetatable = setmetatable; local getmetatable = getmetatable; @@ -23,7 +24,7 @@ local array_methods = {}; local array_mt = { __index = array_methods; __name = "array"; - __tostring = function (self) return "{"..self:concat(", ").."}"; end; + __tostring = function (self) return "["..self:concat(", ").."]"; end; }; function array_mt:__freeze() return self; end @@ -141,13 +142,11 @@ function array_base.slice(outa, ina, i, j) return outa; end - for idx = 1, 1+j-i do - outa[idx] = ina[i+(idx-1)]; - end + + t_move(ina, i, j, 1, outa); if ina == outa then - for idx = 2+j-i, #outa do - outa[idx] = nil; - end + -- Clear (nil) remainder of range + t_move(ina, #outa+1, #outa*2, 2+j-i, ina); end return outa; end @@ -213,10 +212,7 @@ function array_methods:shuffle() end function array_methods:append(ina) - local len, len2 = #self, #ina; - for i = 1, len2 do - self[len+i] = ina[i]; - end + t_move(ina, 1, #ina, #self+1, self); return self; end diff --git a/util/async.lua b/util/async.lua index 2830238f..9a3485e9 100644 --- a/util/async.lua +++ b/util/async.lua @@ -1,7 +1,8 @@ -local logger = require "util.logger"; +local logger = require "prosody.util.logger"; local log = logger.init("util.async"); -local new_id = require "util.id".short; -local xpcall = require "util.xpcall".xpcall; +local new_id = require "prosody.util.id".short; +local xpcall = require "prosody.util.xpcall".xpcall; +local time_now = require "prosody.util.time".now; local function checkthread() local thread, main = coroutine.running(); @@ -45,7 +46,9 @@ end local function runner_continue(thread) -- ASSUMPTION: runner is in 'waiting' state (but we don't have the runner to know for sure) if coroutine.status(thread) ~= "suspended" then -- This should suffice - log("error", "unexpected async state: thread not suspended"); + log("error", "unexpected async state: thread not suspended (%s, %s)", thread, coroutine.status(thread)); + -- Fetching the traceback is likely to *crash* if a C library is calling us while suspended + --log("error", "coroutine stack: %s", debug.traceback()); return false; end local ok, state, runner = coroutine.resume(thread); @@ -138,6 +141,8 @@ end local runner_mt = {}; runner_mt.__index = runner_mt; +local waiting_runners = {}; + local function runner_create_thread(func, self) local thread = coroutine.create(function (self) -- luacheck: ignore 432/self while true do @@ -195,6 +200,8 @@ function runner_mt:run(input) -- Loop through queue items, and attempt to run them for i = 1,n do local queued_input = q[i]; + self:log("Resuming thread with new item [%s]", thread); + self.current_item = queued_input; local ok, new_state = coroutine.resume(thread, queued_input); if not ok then -- There was an error running the coroutine, save the error, mark runner as ready to begin again @@ -221,8 +228,13 @@ function runner_mt:run(input) end -- Runner processed all items it can, so save current runner state self.state = state; + if state == "ready" and self.current_item then + self.current_item = nil; + end + if err or state ~= self.notified_state then - self:log("debug", "changed state from %s to %s", self.notified_state, err and ("error ("..state..")") or state); + self:log("debug", "changed state from %s to %s [%s %s]", self.notified_state, err and ("error (" .. state .. ")") or state, self.thread, + self.thread and coroutine.status(self.thread)); if err then state = "error" else @@ -234,6 +246,7 @@ function runner_mt:run(input) if n > 0 then return self:run(); end + waiting_runners[self] = state == "waiting" and time_now() or nil; return true, state, n; end @@ -293,4 +306,7 @@ return { set_nexttick = function(new_next_tick) next_tick = new_next_tick; end; set_schedule_function = function (new_schedule_function) schedule_task = new_schedule_function; end; + + waiting_runners = waiting_runners; + default_runner_func = default_func; }; diff --git a/util/bit53.lua b/util/bit53.lua index b5c473a3..42f17ce8 100644 --- a/util/bit53.lua +++ b/util/bit53.lua @@ -27,6 +27,9 @@ return { end return ret; end; + bnot = function (x) + return ~x; + end; rshift = function (a, n) return a >> n end; lshift = function (a, n) return a << n end; }; diff --git a/util/bitcompat.lua b/util/bitcompat.lua index 454181af..43fd0f6e 100644 --- a/util/bitcompat.lua +++ b/util/bitcompat.lua @@ -5,25 +5,11 @@ -- Lua 5.2 has it by default if _G.bit32 then return _G.bit32; -else - -- Lua 5.1 may have it as a standalone module that can be installed - local ok, bitop = pcall(require, "bit32") - if ok then - return bitop; - end end do -- Lua 5.3 and 5.4 would be able to use native infix operators - local ok, bitop = pcall(require, "util.bit53") - if ok then - return bitop; - end -end - -do - -- Lastly, try the LuaJIT bitop library - local ok, bitop = pcall(require, "bit") + local ok, bitop = pcall(require, "prosody.util.bit53") if ok then return bitop; end diff --git a/util/cache.lua b/util/cache.lua index cd1b4544..e1873cc8 100644 --- a/util/cache.lua +++ b/util/cache.lua @@ -54,12 +54,17 @@ function cache_methods:set(k, v) if self._count == self.size then local tail = self._tail; local on_evict, evicted_key, evicted_value = self._on_evict, tail.key, tail.value; - if on_evict ~= nil and (on_evict == false or on_evict(evicted_key, evicted_value) == false) then + + local do_evict = on_evict and on_evict(evicted_key, evicted_value, self); + + if do_evict == false then -- Cache is full, and we're not allowed to evict return false; + elseif self._count == self.size then + -- Cache wasn't grown + _remove(self, tail); + self._data[evicted_key] = nil; end - _remove(self, tail); - self._data[evicted_key] = nil; end m = { key = k, value = v, prev = nil, next = nil }; @@ -124,7 +129,7 @@ function cache_methods:resize(new_size) while self._count > new_size do local tail = self._tail; local evicted_key, evicted_value = tail.key, tail.value; - if on_evict ~= nil and (on_evict == false or on_evict(evicted_key, evicted_value) == false) then + if on_evict ~= nil and (on_evict == false or on_evict(evicted_key, evicted_value, self) == false) then -- Cache is full, and we're not allowed to evict return false; end diff --git a/util/caps.lua b/util/caps.lua index de492edb..08743393 100644 --- a/util/caps.lua +++ b/util/caps.lua @@ -6,8 +6,8 @@ -- COPYING file in the source package for more information. -- -local base64 = require "util.encodings".base64.encode; -local sha1 = require "util.hashes".sha1; +local base64 = require "prosody.util.encodings".base64.encode; +local sha1 = require "prosody.util.hashes".sha1; local t_insert, t_sort, t_concat = table.insert, table.sort, table.concat; local ipairs = ipairs; diff --git a/util/dataforms.lua b/util/dataforms.lua index 66733895..deb570cb 100644 --- a/util/dataforms.lua +++ b/util/dataforms.lua @@ -12,9 +12,9 @@ local type, next = type, next; local tonumber = tonumber; local tostring = tostring; local t_concat = table.concat; -local st = require "util.stanza"; -local jid_prep = require "util.jid".prep; -local datetime = require "util.datetime"; +local st = require "prosody.util.stanza"; +local jid_prep = require "prosody.util.jid".prep; +local datetime = require "prosody.util.datetime"; local _ENV = nil; -- luacheck: std none diff --git a/util/datamanager.lua b/util/datamanager.lua index c57f4a0e..8192d12a 100644 --- a/util/datamanager.lua +++ b/util/datamanager.lua @@ -7,35 +7,42 @@ -- +local string = string; local format = string.format; local setmetatable = setmetatable; local ipairs = ipairs; local char = string.char; local pcall = pcall; -local log = require "util.logger".init("datamanager"); +local log = require "prosody.util.logger".init("datamanager"); local io_open = io.open; local os_remove = os.remove; local os_rename = os.rename; local tonumber = tonumber; +local floor = math.floor; local next = next; local type = type; local t_insert = table.insert; local t_concat = table.concat; -local envloadfile = require"util.envload".envloadfile; -local serialize = require "util.serialization".serialize; +local envloadfile = require"prosody.util.envload".envloadfile; +local envload = require"prosody.util.envload".envload; +local serialize = require "prosody.util.serialization".serialize; local lfs = require "lfs"; -- Extract directory separator from package.config (an undocumented string that comes with lua) local path_separator = assert ( package.config:match ( "^([^\n]+)" ) , "package.config not in standard form" ) local prosody = prosody; +--luacheck: ignore 211/blocksize 211/remove_blocks +local blocksize = 0x1000; local raw_mkdir = lfs.mkdir; local atomic_append; +local remove_blocks; local ENOENT = 2; pcall(function() - local pposix = require "util.pposix"; + local pposix = require "prosody.util.pposix"; raw_mkdir = pposix.mkdir or raw_mkdir; -- Doesn't trample on umask atomic_append = pposix.atomic_append; + -- remove_blocks = pposix.remove_blocks; ENOENT = pposix.ENOENT or ENOENT; end); @@ -239,6 +246,14 @@ local function append(username, host, datastore, ext, data) end local pos = f:seek("end"); + --[[ TODO needs tests + if (blocksize-(pos%blocksize)) < (#data%blocksize) then + -- pad to blocksize with newlines so that the next item is both on a new + -- block and a new line + atomic_append(f, ("\n"):rep(blocksize-(pos%blocksize))); + pos = f:seek("end"); + end + --]] local ok, msg = atomic_append(f, data); @@ -255,6 +270,13 @@ local function append(username, host, datastore, ext, data) return true, pos; end +local index_fmt, index_item_size, index_magic; +if string.packsize then + index_fmt = "T"; -- offset to the end of the item, length can be derived from two index items + index_item_size = string.packsize(index_fmt); + index_magic = string.pack(index_fmt, 7767639 + 1); -- Magic string: T9 for "prosody", version number +end + local function list_append(username, host, datastore, data) if not data then return; end if callback(username, host, datastore) == false then return true; end @@ -267,6 +289,22 @@ local function list_append(username, host, datastore, data) datastore, msg, where, username or "nil", host or "nil"); return ok, msg; end + if string.packsize then + local offset = type(msg) == "number" and msg or 0; + local index_entry = string.pack(index_fmt, offset + #data); + if offset == 0 then + index_entry = index_magic .. index_entry; + end + local ok, off = append(username, host, datastore, "lidx", index_entry); + off = off or 0; + -- If this was the first item, then both the data and index offsets should + -- be zero, otherwise there's some kind of mismatch and we should drop the + -- index and recreate it from scratch + -- TODO Actually rebuild the index in this case? + if not ok or (off == 0 and offset ~= 0) or (off ~= 0 and offset == 0) then + os_remove(getpath(username, host, datastore, "lidx")); + end + end return true; end @@ -280,6 +318,7 @@ local function list_store(username, host, datastore, data) for i, item in ipairs(data) do d[i] = "item(" .. serialize(item) .. ");\n"; end + os_remove(getpath(username, host, datastore, "lidx")); local ok, msg = atomic_store(getpath(username, host, datastore, "list", true), t_concat(d)); if not ok then log("error", "Unable to write to %s storage ('%s') for user: %s@%s", datastore, msg, username or "nil", host or "nil"); @@ -294,6 +333,160 @@ local function list_store(username, host, datastore, data) return true; end +local function build_list_index(username, host, datastore, items) + log("debug", "Building index for (%s@%s/%s)", username, host, datastore); + local filename = getpath(username, host, datastore, "list"); + local fh, err, errno = io_open(filename); + if not fh then + return fh, err, errno; + end + local prev_pos = 0; -- position before reading + local last_item_start = nil; + + if items and items[1] then + local last_item = items[#items]; + last_item_start = fh:seek("set", last_item.start + last_item.length); + else + items = {}; + end + + for line in fh:lines() do + if line:sub(1, 4) == "item" then + if prev_pos ~= 0 and last_item_start then + t_insert(items, { start = last_item_start; length = prev_pos - last_item_start }); + end + last_item_start = prev_pos + end + -- seek position is at the start of the next line within each loop iteration + -- so we need to collect the "current" position at the end of the previous + prev_pos = fh:seek() + end + fh:close(); + if prev_pos ~= 0 then + t_insert(items, { start = last_item_start; length = prev_pos - last_item_start }); + end + return items; +end + +local function store_list_index(username, host, datastore, index) + local data = { index_magic }; + for i, v in ipairs(index) do + data[i + 1] = string.pack(index_fmt, v.start + v.length); + end + local filename = getpath(username, host, datastore, "lidx"); + return atomic_store(filename, t_concat(data)); +end + +local index_mt = { + __index = function(t, i) + if type(i) ~= "number" or i % 1 ~= 0 or i < 0 then + return + end + if i <= 0 then + return 0 + end + local fh = t.file; + local pos = (i - 1) * index_item_size; + if fh:seek("set", pos) ~= pos then + return nil + end + local data = fh:read(index_item_size * 2); + if not data or #data ~= index_item_size * 2 then + return nil + end + local start, next_pos = string.unpack(index_fmt .. index_fmt, data); + if pos == 0 then + start = 0 + end + local length = next_pos - start; + local v = { start = start; length = length }; + t[i] = v; + return v; + end; + __len = function(t) + -- Account for both the header and the fence post error + return floor(t.file:seek("end") / index_item_size) - 1; + end; +} + +local function get_list_index(username, host, datastore) + log("debug", "Loading index for (%s@%s/%s)", username, host, datastore); + local index_filename = getpath(username, host, datastore, "lidx"); + local ih = io_open(index_filename); + if ih then + local magic = ih:read(#index_magic); + if magic ~= index_magic then + log("debug", "Index %q has wrong version number (got %q, expected %q), rebuilding...", index_filename, magic, index_magic); + -- wrong version or something + ih:close(); + ih = nil; + end + end + + if ih then + local first_length = string.unpack(index_fmt, ih:read(index_item_size)); + return setmetatable({ file = ih; { start = 0; length = first_length } }, index_mt); + end + + local index, err = build_list_index(username, host, datastore); + if not index then + return index, err + end + + -- TODO How to handle failure to store the index? + local dontcare = store_list_index(username, host, datastore, index); -- luacheck: ignore 211/dontcare + return index; +end + +local function list_load_one(fh, start, length) + if fh:seek("set", start) ~= start then + return nil + end + local raw_data = fh:read(length) + if not raw_data or #raw_data ~= length then + return + end + local item; + local data, err, errno = envload(raw_data, "@list", { + item = function(i) + item = i; + end; + }); + if not data then + return data, err, errno + end + local success, ret = pcall(data); + if not success then + return success, ret; + end + return item; +end + +local function list_close(list) + if list.index and list.index.file then + list.index.file:close(); + end + return list.file:close(); +end + +local indexed_list_mt = { + __index = function(t, i) + if type(i) ~= "number" or i % 1 ~= 0 or i < 1 then + return + end + local ix = t.index[i]; + if not ix then + return + end + local item = list_load_one(t.file, ix.start, ix.length); + return item; + end; + __len = function(t) + return #t.index; + end; + __close = list_close; +} + local function list_load(username, host, datastore) local items = {}; local data, err, errno = envloadfile(getpath(username, host, datastore, "list"), {item = function(i) t_insert(items, i); end}); @@ -314,6 +507,123 @@ local function list_load(username, host, datastore) return items; end +local function list_open(username, host, datastore) + if not index_magic then + log("debug", "Falling back from lazy loading to loading full list for %s storage for user: %s@%s", datastore, username or "nil", host or "nil"); + return list_load(username, host, datastore); + end + local filename = getpath(username, host, datastore, "list"); + local file, err, errno = io_open(filename); + if not file then + if errno == ENOENT then + return nil; + end + return file, err, errno; + end + local index, err = get_list_index(username, host, datastore); + if not index then + file:close() + return index, err; + end + return setmetatable({ file = file; index = index; close = list_close }, indexed_list_mt); +end + +local function shift_index(index_filename, index, trim_to, offset) -- luacheck: ignore 212 + os_remove(index_filename); + return "deleted"; + -- TODO move and recalculate remaining items +end + +local function list_shift(username, host, datastore, trim_to) + if trim_to == 1 then + return true + end + if type(trim_to) ~= "number" or trim_to < 1 then + return nil, "invalid-argument"; + end + local list_filename = getpath(username, host, datastore, "list"); + local index_filename = getpath(username, host, datastore, "lidx"); + local index, err = get_list_index(username, host, datastore); + if not index then + return nil, err; + end + + local new_first = index[trim_to]; + if not new_first then + os_remove(index_filename); + return os_remove(list_filename); + end + + local offset = new_first.start; + if offset == 0 then + return true; + end + + --[[ + if remove_blocks then + local f, err = io_open(list_filename, "r+"); + if not f then + return f, err; + end + + local diff = 0; + local block_offset = 0; + if offset % 0x1000 ~= 0 then + -- Not an even block boundary, we will have to overwrite + diff = offset % 0x1000; + block_offset = offset - diff; + end + + if block_offset == 0 then + log("debug", "") + else + local ok, err = remove_blocks(f, 0, block_offset); + log("debug", "remove_blocks(%s, 0, %d)", f, block_offset); + if not ok then + log("warn", "Could not remove blocks from %q[%d, %d]: %s", list_filename, 0, block_offset, err); + else + if diff ~= 0 then + -- overwrite unaligned leftovers + if f:seek("set", 0) then + local wrote, err = f:write(string.rep("\n", diff)); + if not wrote then + log("error", "Could not blank out %q[%d, %d]: %s", list_filename, 0, diff, err); + end + end + end + local ok, err = f:close(); + shift_index(index_filename, index, trim_to, offset); -- Shift or delete the index + return ok, err; + end + end + end + --]] + + local r, err = io_open(list_filename, "r"); + if not r then + return nil, err; + end + local w, err = io_open(list_filename .. "~", "w"); + if not w then + return nil, err; + end + r:seek("set", offset); + for block in r:lines(0x1000) do + local ok, err = w:write(block); + if not ok then + return nil, err; + end + end + r:close(); + local ok, err = w:close(); + if not ok then + return nil, err; + end + shift_index(index_filename, index, trim_to, offset) + return os_rename(list_filename .. "~", list_filename); +end + + local type_map = { keyval = "dat"; list = "list"; @@ -392,6 +702,8 @@ local function purge(username, host) local ok, err = do_remove(getpath(username, host, store_name, "list")); if not ok then errs[#errs+1] = err; end + local ok, err = do_remove(getpath(username, host, store_name, "lidx")); + if not ok then errs[#errs+1] = err; end end end return #errs == 0, t_concat(errs, ", "); @@ -414,4 +726,8 @@ return { purge = purge; path_decode = decode; path_encode = encode; + + build_list_index = build_list_index; + list_open = list_open; + list_shift = list_shift; }; diff --git a/util/datamapper.lua b/util/datamapper.lua index 2378314c..03abc7ad 100644 --- a/util/datamapper.lua +++ b/util/datamapper.lua @@ -1,7 +1,11 @@ -- This file is generated from teal-src/util/datamapper.lua -local st = require("util.stanza"); -local pointer = require("util.jsonpointer"); +if not math.type then + require("prosody.util.mathcompat") +end + +local st = require("prosody.util.stanza"); +local pointer = require("prosody.util.jsonpointer"); local schema_t = {} diff --git a/util/datetime.lua b/util/datetime.lua index 2d27ece4..6df146f4 100644 --- a/util/datetime.lua +++ b/util/datetime.lua @@ -12,31 +12,41 @@ local os_date = os.date; local os_time = os.time; local os_difftime = os.difftime; +local floor = math.floor; local tonumber = tonumber; local _ENV = nil; -- luacheck: std none local function date(t) - return os_date("!%Y-%m-%d", t); + return os_date("!%Y-%m-%d", t and floor(t) or nil); end local function datetime(t) - return os_date("!%Y-%m-%dT%H:%M:%SZ", t); + if t == nil or t % 1 == 0 then + return os_date("!%Y-%m-%dT%H:%M:%SZ", t); + end + local m = t % 1; + local s = floor(t); + return os_date("!%Y-%m-%dT%H:%M:%S.%%06dZ", s):format(floor(m * 1000000)); end local function time(t) - return os_date("!%H:%M:%S", t); + if t == nil or t % 1 == 0 then + return os_date("!%H:%M:%S", t); + end + local m = t % 1; + local s = floor(t); + return os_date("!%H:%M:%S.%%06d", s):format(floor(m * 1000000)); end local function legacy(t) - return os_date("!%Y%m%dT%H:%M:%S", t); + return os_date("!%Y%m%dT%H:%M:%S", t and floor(t) or nil); end local function parse(s) if s then - local year, month, day, hour, min, sec, tzd; - year, month, day, hour, min, sec, tzd = s:match("^(%d%d%d%d)%-?(%d%d)%-?(%d%d)T(%d%d):(%d%d):(%d%d)%.?%d*([Z+%-]?.*)$"); + local year, month, day, hour, min, sec, tzd = s:match("^(%d%d%d%d)%-?(%d%d)%-?(%d%d)T(%d%d):(%d%d):(%d%d%.?%d*)([Z+%-]?.*)$"); if year then local now = os_time(); local time_offset = os_difftime(os_time(os_date("*t", now)), os_time(os_date("!*t", now))); -- to deal with local timezone @@ -49,8 +59,9 @@ local function parse(s) tzd_offset = h * 60 * 60 + m * 60; if sign == "-" then tzd_offset = -tzd_offset; end end - sec = (sec + time_offset) - tzd_offset; - return os_time({year=year, month=month, day=day, hour=hour, min=min, sec=sec, isdst=false}); + local prec = sec%1; + sec = floor(sec + time_offset) - tzd_offset; + return os_time({year=year, month=month, day=day, hour=hour, min=min, sec=sec, isdst=false})+prec; end end end diff --git a/util/dbuffer.lua b/util/dbuffer.lua index 3ad5fdfe..4c00eabc 100644 --- a/util/dbuffer.lua +++ b/util/dbuffer.lua @@ -1,4 +1,4 @@ -local queue = require "util.queue"; +local queue = require "prosody.util.queue"; local s_byte, s_sub = string.byte, string.sub; local dbuffer_methods = {}; @@ -91,18 +91,18 @@ function dbuffer_methods:read_until(char) end function dbuffer_methods:discard(requested_bytes) - if requested_bytes > self._length then - return nil; + if self._length == 0 then return true; end + if not requested_bytes or requested_bytes >= self._length then + self.front_consumed = 0; + self._length = 0; + for _ in self.items:consume() do end + return true; end local chunk, read_bytes = self:read_chunk(requested_bytes); - if chunk then - requested_bytes = requested_bytes - read_bytes; - if requested_bytes == 0 then -- Already read everything we need - return true; - end - else - return nil; + requested_bytes = requested_bytes - read_bytes; + if requested_bytes == 0 then -- Already read everything we need + return true; end while chunk do diff --git a/util/debug.lua b/util/debug.lua index 4c924d40..7a8312a9 100644 --- a/util/debug.lua +++ b/util/debug.lua @@ -12,7 +12,7 @@ local censored_names = { }; local optimal_line_length = 65; -local termcolours = require "util.termcolours"; +local termcolours = require "prosody.util.termcolours"; local getstring = termcolours.getstring; local styles; do diff --git a/util/dependencies.lua b/util/dependencies.lua index d7836404..30b53970 100644 --- a/util/dependencies.lua +++ b/util/dependencies.lua @@ -7,7 +7,6 @@ -- local function softreq(...) local ok, lib = pcall(require, ...); if ok then return lib; else return nil, lib; end end -local platform_table = require "util.human.io".table({ { width = 15, align = "right" }, { width = "100%" } }); -- Required to be able to find packages installed with luarocks if not softreq "luarocks.loader" then -- LuaRocks 2.x @@ -22,7 +21,7 @@ local function missingdep(name, sources, msg, err) -- luacheck: ignore err print("This package can be obtained in the following ways:"); print(""); for _, row in ipairs(sources) do - print(platform_table(row)); + print(string.format("%15s | %s", table.unpack(row))); end print(""); print(msg or (name.." is required for Prosody to run, so we will now exit.")); @@ -32,10 +31,10 @@ local function missingdep(name, sources, msg, err) -- luacheck: ignore err end local function check_dependencies() - if _VERSION < "Lua 5.1" then + if _VERSION < "Lua 5.2" then print "***********************************" print("Unsupported Lua version: ".._VERSION); - print("At least Lua 5.1 is required."); + print("At least Lua 5.2 is required."); print "***********************************" return false; end @@ -88,7 +87,7 @@ local function check_dependencies() }, nil, err); end - local bit, err = softreq"util.bitcompat"; + local bit, err = softreq"prosody.util.bitcompat"; if not bit then missingdep("lua-bitops", { @@ -106,16 +105,16 @@ local function check_dependencies() { "Source", "https://www.zash.se/luaunbound.html" }; }, "Old DNS resolver library will be used", err); else - package.preload["net.adns"] = function () - local ub = require "net.unbound"; + package.preload["prosody.net.adns"] = function () + local ub = require "prosody.net.unbound"; return ub; end end - local encodings, err = softreq "util.encodings" + local encodings, err = softreq "prosody.util.encodings" if not encodings then if err:match("module '[^']*' not found") then - missingdep("util.encodings", { + missingdep("prosody.util.encodings", { { "Windows", "Make sure you have encodings.dll from the Prosody distribution in util/" }; { "GNU/Linux", "Run './configure' and 'make' in the Prosody source directory to build util/encodings.so" }; }); @@ -130,10 +129,10 @@ local function check_dependencies() fatal = true; end - local hashes, err = softreq "util.hashes" + local hashes, err = softreq "prosody.util.hashes" if not hashes then if err:match("module '[^']*' not found") then - missingdep("util.hashes", { + missingdep("prosody.util.hashes", { { "Windows", "Make sure you have hashes.dll from the Prosody distribution in util/" }; { "GNU/Linux", "Run './configure' and 'make' in the Prosody source directory to build util/hashes.so" }; }); @@ -155,7 +154,7 @@ local function log_warnings() if _VERSION > "Lua 5.4" then prosody.log("warn", "Support for %s is experimental, please report any issues", _VERSION); elseif _VERSION < "Lua 5.2" then - prosody.log("warn", "%s has several issues and support is being phased out, consider upgrading", _VERSION); + prosody.log("warn", "%s support is deprecated, upgrade as soon as possible", _VERSION); end local ssl = softreq"ssl"; if ssl then diff --git a/util/dns.lua b/util/dns.lua index 3b58e03e..f113f97e 100644 --- a/util/dns.lua +++ b/util/dns.lua @@ -12,9 +12,9 @@ local s_byte = string.byte; local s_format = string.format; local s_sub = string.sub; -local iana_data = require "util.dnsregistry"; -local tohex = require "util.hex".encode; -local inet_ntop = require "util.net".ntop; +local iana_data = require "prosody.util.dnsregistry"; +local tohex = require "prosody.util.hex".encode; +local inet_ntop = require "prosody.util.net".ntop; -- Simplified versions of Waqas DNS parsers -- Only the per RR parsers are needed and only feed a single RR diff --git a/util/dnsregistry.lua b/util/dnsregistry.lua index 635b7e3a..b65debe0 100644 --- a/util/dnsregistry.lua +++ b/util/dnsregistry.lua @@ -1,5 +1,5 @@ -- Source: https://www.iana.org/assignments/dns-parameters/dns-parameters.xml --- Generated on 2022-02-02 +-- Generated on 2024-10-26 return { classes = { ["IN"] = 1; [1] = "IN"; @@ -61,7 +61,6 @@ return { ["NSEC3PARAM"] = 51; [51] = "NSEC3PARAM"; ["TLSA"] = 52; [52] = "TLSA"; ["SMIMEA"] = 53; [53] = "SMIMEA"; - ["Unassigned"] = 54; [54] = "Unassigned"; ["HIP"] = 55; [55] = "HIP"; ["NINFO"] = 56; [56] = "NINFO"; ["RKEY"] = 57; [57] = "RKEY"; @@ -80,6 +79,7 @@ return { ["LP"] = 107; [107] = "LP"; ["EUI48"] = 108; [108] = "EUI48"; ["EUI64"] = 109; [109] = "EUI64"; + ["NXNAME"] = 128; [128] = "NXNAME"; ["TKEY"] = 249; [249] = "TKEY"; ["TSIG"] = 250; [250] = "TSIG"; ["IXFR"] = 251; [251] = "IXFR"; @@ -92,6 +92,10 @@ return { ["AVC"] = 258; [258] = "AVC"; ["DOA"] = 259; [259] = "DOA"; ["AMTRELAY"] = 260; [260] = "AMTRELAY"; + ["RESINFO"] = 261; [261] = "RESINFO"; + ["WALLET"] = 262; [262] = "WALLET"; + ["CLA"] = 263; [263] = "CLA"; + ["IPN"] = 264; [264] = "IPN"; ["TA"] = 32768; [32768] = "TA"; ["DLV"] = 32769; [32769] = "DLV"; }; diff --git a/util/envload.lua b/util/envload.lua index 6182a1f9..cf45b702 100644 --- a/util/envload.lua +++ b/util/envload.lua @@ -6,38 +6,19 @@ -- -- luacheck: ignore 113/setfenv 113/loadstring -local load, loadstring, setfenv = load, loadstring, setfenv; +local load = load; local io_open = io.open; -local envload; -local envloadfile; -if setfenv then - function envload(code, source, env) - local f, err = loadstring(code, source); - if f and env then setfenv(f, env); end - return f, err; - end - - function envloadfile(file, env) - local fh, err, errno = io_open(file); - if not fh then return fh, err, errno; end - local f, err = load(function () return fh:read(2048); end, "@"..file); - fh:close(); - if f and env then setfenv(f, env); end - return f, err; - end -else - function envload(code, source, env) - return load(code, source, nil, env); - end +local function envload(code, source, env) + return load(code, source, nil, env); +end - function envloadfile(file, env) - local fh, err, errno = io_open(file); - if not fh then return fh, err, errno; end - local f, err = load(fh:lines(2048), "@"..file, nil, env); - fh:close(); - return f, err; - end +local function envloadfile(file, env) + local fh, err, errno = io_open(file); + if not fh then return fh, err, errno; end + local f, err = load(fh:lines(2048), "@" .. file, nil, env); + fh:close(); + return f, err; end return { envload = envload, envloadfile = envloadfile }; diff --git a/util/error.lua b/util/error.lua index b83f81e5..64c742ae 100644 --- a/util/error.lua +++ b/util/error.lua @@ -1,4 +1,4 @@ -local id = require "util.id"; +local id = require "prosody.util.id"; local util_debug; -- only imported on-demand @@ -19,7 +19,7 @@ local function configure(opt) if opt.auto_inject_traceback ~= nil then auto_inject_traceback = opt.auto_inject_traceback; if auto_inject_traceback then - util_debug = require "util.debug"; + util_debug = require "prosody.util.debug"; end end end diff --git a/util/format.lua b/util/format.lua index d709aada..f95f0575 100644 --- a/util/format.lua +++ b/util/format.lua @@ -6,14 +6,20 @@ -- Provides some protection from e.g. CAPEC-135, CWE-117, CWE-134, CWE-93 local tostring = tostring; -local unpack = table.unpack or unpack; -- luacheck: ignore 113/unpack -local pack = require "util.table".pack; -- TODO table.pack in 5.2+ -local valid_utf8 = require "util.encodings".utf8.valid; +local unpack = table.unpack; +local pack = table.pack; +local valid_utf8 = require "prosody.util.encodings".utf8.valid; local type = type; -local dump = require "util.serialization".new("debug"); -local num_type = math.type or function (n) - return n % 1 == 0 and n <= 9007199254740992 and n >= -9007199254740992 and "integer" or "float"; -end +local dump = require"prosody.util.serialization".new({ + preset = "compact"; + fallback = function(v, why) + return "_[[" .. (why or tostring(v)) .. "]] "; + end; + freeze = true; + fatal = false; + maxdepth = 5; +}); +local num_type = math.type; -- In Lua 5.3+ these formats throw an error if given a float local expects_integer = { c = true, d = true, i = true, o = true, u = true, X = true, x = true, }; @@ -35,7 +41,6 @@ local control_symbols = { ["\030"] = "\226\144\158", ["\031"] = "\226\144\159", ["\127"] = "\226\144\161", }; local supports_p = pcall(string.format, "%p", ""); -- >= Lua 5.4 -local supports_a = pcall(string.format, "%a", 0.0); -- > Lua 5.1 local function format(formatstring, ...) local args = pack(...); @@ -93,8 +98,6 @@ local function format(formatstring, ...) elseif expects_positive[option] and arg < 0 then args[i] = tostring(arg); return "[%s]"; - elseif (option == "a" or option == "A") and not supports_a then - return "%x"; else return -- acceptable number end diff --git a/util/fsm.lua b/util/fsm.lua new file mode 100644 index 00000000..0afc2d16 --- /dev/null +++ b/util/fsm.lua @@ -0,0 +1,154 @@ +local events = require "prosody.util.events"; + +local fsm_methods = {}; +local fsm_mt = { __index = fsm_methods }; + +local function is_fsm(o) + local mt = getmetatable(o); + return mt == fsm_mt; +end + +local function notify_transition(fire_event, transition_event) + local ret; + ret = fire_event("transition", transition_event); + if ret ~= nil then return ret; end + if transition_event.from ~= transition_event.to then + ret = fire_event("leave/"..transition_event.from, transition_event); + if ret ~= nil then return ret; end + end + ret = fire_event("transition/"..transition_event.name, transition_event); + if ret ~= nil then return ret; end +end + +local function notify_transitioned(fire_event, transition_event) + if transition_event.to ~= transition_event.from then + fire_event("enter/"..transition_event.to, transition_event); + end + if transition_event.name then + fire_event("transitioned/"..transition_event.name, transition_event); + end + fire_event("transitioned", transition_event); +end + +local function do_transition(name) + return function (self, attr) + local new_state = self.fsm.states[self.state][name] or self.fsm.states["*"][name]; + if not new_state then + return error(("Invalid state transition: %s cannot %s"):format(self.state, name)); + end + + local transition_event = { + instance = self; + + name = name; + to = new_state; + to_attr = attr; + + from = self.state; + from_attr = self.state_attr; + }; + + local fire_event = self.fsm.events.fire_event; + local ret = notify_transition(fire_event, transition_event); + if ret ~= nil then return nil, ret; end + + self.state = new_state; + self.state_attr = attr; + + notify_transitioned(fire_event, transition_event); + return true; + end; +end + +local function new(desc) + local self = setmetatable({ + default_state = desc.default_state; + events = events.new(); + }, fsm_mt); + + -- states[state_name][transition_name] = new_state_name + local states = { ["*"] = {} }; + if desc.default_state then + states[desc.default_state] = {}; + end + self.states = states; + + local instance_methods = {}; + self._instance_mt = { __index = instance_methods }; + + for _, transition in ipairs(desc.transitions or {}) do + local from_states = transition.from; + if type(from_states) ~= "table" then + from_states = { from_states }; + end + for _, from in ipairs(from_states) do + if not states[from] then + states[from] = {}; + end + if not states[transition.to] then + states[transition.to] = {}; + end + if states[from][transition.name] then + return error(("Duplicate transition in FSM specification: %s from %s"):format(transition.name, from)); + end + states[from][transition.name] = transition.to; + end + + -- Add public method to trigger this transition + instance_methods[transition.name] = do_transition(transition.name); + end + + if desc.state_handlers then + for state_name, handler in pairs(desc.state_handlers) do + self.events.add_handler("enter/"..state_name, handler); + end + end + + if desc.transition_handlers then + for transition_name, handler in pairs(desc.transition_handlers) do + self.events.add_handler("transition/"..transition_name, handler); + end + end + + if desc.handlers then + self.events.add_handlers(desc.handlers); + end + + return self; +end + +function fsm_methods:init(state_name, state_attr) + local initial_state = assert(state_name or self.default_state, "no initial state specified"); + if not self.states[initial_state] then + return error("Invalid initial state: "..initial_state); + end + local instance = setmetatable({ + fsm = self; + state = initial_state; + state_attr = state_attr; + }, self._instance_mt); + + if initial_state ~= self.default_state then + local fire_event = self.events.fire_event; + notify_transitioned(fire_event, { + instance = instance; + + to = initial_state; + to_attr = state_attr; + + from = self.default_state; + }); + end + + return instance; +end + +function fsm_methods:is_instance(o) + local mt = getmetatable(o); + return mt == self._instance_mt; +end + +return { + new = new; + is_fsm = is_fsm; +}; diff --git a/util/gc.lua b/util/gc.lua index f46e4346..7fcf7546 100644 --- a/util/gc.lua +++ b/util/gc.lua @@ -1,4 +1,4 @@ -local set = require "util.set"; +local set = require "prosody.util.set"; local known_options = { incremental = set.new { "mode", "threshold", "speed", "step_size" }; diff --git a/util/hashring.lua b/util/hashring.lua index d4555669..a1debb53 100644 --- a/util/hashring.lua +++ b/util/hashring.lua @@ -1,3 +1,5 @@ +local it = require "prosody.util.iterators"; + local function generate_ring(nodes, num_replicas, hash) local new_ring = {}; for _, node_name in ipairs(nodes) do @@ -28,18 +30,22 @@ local function new(num_replicas, hash_function) return setmetatable({ nodes = {}, num_replicas = num_replicas, hash = hash_function }, hashring_mt); end; -function hashring_methods:add_node(name) +function hashring_methods:add_node(name, value) self.ring = nil; - self.nodes[name] = true; + self.nodes[name] = value == nil and true or value; table.insert(self.nodes, name); return true; end function hashring_methods:add_nodes(nodes) self.ring = nil; - for _, node_name in ipairs(nodes) do - if not self.nodes[node_name] then - self.nodes[node_name] = true; + local iter = pairs; + if nodes[1] then -- simple array? + iter = it.values; + end + for node_name, node_value in iter(nodes) do + if self.nodes[node_name] == nil then + self.nodes[node_name] = node_value == nil and true or node_value; table.insert(self.nodes, node_name); end end @@ -48,7 +54,7 @@ end function hashring_methods:remove_node(node_name) self.ring = nil; - if self.nodes[node_name] then + if self.nodes[node_name] ~= nil then for i, stored_node_name in ipairs(self.nodes) do if node_name == stored_node_name then self.nodes[node_name] = nil; @@ -69,18 +75,26 @@ end function hashring_methods:clone() local clone_hashring = new(self.num_replicas, self.hash); - clone_hashring:add_nodes(self.nodes); + for node_name, node_value in pairs(self.nodes) do + clone_hashring.nodes[node_name] = node_value; + end + clone_hashring.ring = nil; return clone_hashring; end function hashring_methods:get_node(key) + local node; local key_hash = self.hash(key); for _, replica_hash in ipairs(self.ring) do if key_hash < replica_hash then - return self.ring[replica_hash]; + node = self.ring[replica_hash]; + break; end end - return self.ring[self.ring[1]]; + if not node then + node = self.ring[self.ring[1]]; + end + return node, self.nodes[node]; end return { diff --git a/util/helpers.lua b/util/helpers.lua index 139b62ec..da680cae 100644 --- a/util/helpers.lua +++ b/util/helpers.lua @@ -6,11 +6,11 @@ -- COPYING file in the source package for more information. -- -local debug = require "util.debug"; +local debug = require "prosody.util.debug"; -- Helper functions for debugging -local log = require "util.logger".init("util.debug"); +local log = require "prosody.util.logger".init("util.debug"); local function log_events(events, name, logger) local f = events.fire_event; diff --git a/util/hmac.lua b/util/hmac.lua index 4cad17cc..c782da3c 100644 --- a/util/hmac.lua +++ b/util/hmac.lua @@ -8,11 +8,15 @@ -- COMPAT: Only for external pre-0.9 modules -local hashes = require "util.hashes" +local hashes = require "prosody.util.hashes" return { md5 = hashes.hmac_md5, sha1 = hashes.hmac_sha1, + sha224 = hashes.hmac_sha224, sha256 = hashes.hmac_sha256, + sha384 = hashes.hmac_sha384, sha512 = hashes.hmac_sha512, + blake2s256 = hashes.hmac_blake2s256, + blake2b512 = hashes.hmac_blake2b512, }; diff --git a/util/http.lua b/util/http.lua index 3852f91c..b21bf798 100644 --- a/util/http.lua +++ b/util/http.lua @@ -69,9 +69,42 @@ local function normalize_path(path, is_dir) return path; end +--- Parse the RFC 7239 Forwarded header into array of key-value pairs. +local function parse_forwarded(forwarded) + if type(forwarded) ~= "string" then + return nil; + end + + local fwd = {}; -- array + local cur = {}; -- map, to which we add the next key-value pair + for key, quoted, value, delim in forwarded:gmatch("(%w+)%s*=%s*(\"?)([^,;\"]+)%2%s*(.?)") do + -- FIXME quoted quotes like "foo\"bar" + -- unlikely when only dealing with IP addresses + if quoted == '"' then + value = value:gsub("\\(.)", "%1"); + end + + cur[key:lower()] = value; + if delim == "" or delim == "," then + t_insert(fwd, cur) + if delim == "" then + -- end of the string + break; + end + cur = {}; + elseif delim ~= ";" then + -- misparsed + return false; + end + end + + return fwd; +end + return { urlencode = urlencode, urldecode = urldecode; formencode = formencode, formdecode = formdecode; contains_token = contains_token; normalize_path = normalize_path; + parse_forwarded = parse_forwarded; }; diff --git a/util/human/io.lua b/util/human/io.lua index 7d7dea97..d6112b3b 100644 --- a/util/human/io.lua +++ b/util/human/io.lua @@ -1,5 +1,6 @@ -local array = require "util.array"; -local utf8 = rawget(_G, "utf8") or require"util.encodings".utf8; +local array = require "prosody.util.array"; +local pposix = require "prosody.util.pposix"; +local utf8 = rawget(_G, "utf8") or require"prosody.util.encodings".utf8; local len = utf8.len or function(s) local _, count = s:gsub("[%z\001-\127\194-\253][\128-\191]*", ""); return count; @@ -8,7 +9,7 @@ end; local function getchar(n) local stty_ret = os.execute("stty raw -echo 2>/dev/null"); local ok, char; - if stty_ret == true or stty_ret == 0 then + if stty_ret then ok, char = pcall(io.read, n or 1); os.execute("stty sane"); else @@ -30,15 +31,12 @@ local function getline() end local function getpass() - local stty_ret, _, status_code = os.execute("stty -echo 2>/dev/null"); - if status_code then -- COMPAT w/ Lua 5.1 - stty_ret = status_code; - end - if stty_ret ~= 0 then + local stty_ret = os.execute("stty -echo 2>/dev/null"); + if not stty_ret then io.write("\027[08m"); -- ANSI 'hidden' text attribute end local ok, pass = pcall(io.read, "*l"); - if stty_ret == 0 then + if stty_ret then os.execute("stty sane"); else io.write("\027[00m"); @@ -111,14 +109,30 @@ if utf8.len and utf8.offset then end end +local function term_width(default) + local env_cols = tonumber(os.getenv "COLUMNS"); + if env_cols then return env_cols; end + if not pposix.isatty(io.stdout) then + return default; + end + local stty = io.popen("stty -a"); + if not stty then return default; end + local result = stty:read("*a"); + if result then + result = result:match("%f[%w]columns[ =]*(%d+)"); + end + stty:close(); + return tonumber(result or default); +end + local function ellipsis(s, width) if len(s) <= width then return s; end - if width == 1 then return "…"; end + if width <= 1 then return "…"; end return utf8_cut(s, width - 1) .. "…"; end local function new_table(col_specs, max_width) - max_width = max_width or tonumber(os.getenv("COLUMNS")) or 80; + max_width = max_width or term_width(80); local separator = " | "; local widths = {}; @@ -127,21 +141,28 @@ local function new_table(col_specs, max_width) -- Calculate width of fixed-size columns for i = 1, #col_specs do local width = col_specs[i].width or "0"; - if not(type(width) == "string" and width:sub(-1) == "%") then + if not (type(width) == "string" and width:match("[p%%]$")) then local title = col_specs[i].title; width = math.max(tonumber(width), title and (#title+1) or 0); widths[i] = width; free_width = free_width - width; - if i > 1 then - free_width = free_width - #separator; - end end end - -- Calculate width of %-based columns + + -- Calculate width of proportional columns + local total_proportional_width = 0; for i = 1, #col_specs do if not widths[i] then - local pc_width = tonumber((col_specs[i].width:gsub("%%$", ""))); - widths[i] = math.floor(free_width*(pc_width/100)); + local width_spec = col_specs[i].width:match("([%d%.]+)[p%%]"); + total_proportional_width = total_proportional_width + tonumber(width_spec); + end + end + + for i = 1, #col_specs do + if not widths[i] then + local width_spec = col_specs[i].width:match("([%d%.]+)[p%%]"); + local rel_width = tonumber(width_spec); + widths[i] = math.floor(free_width*(rel_width/total_proportional_width)); end end @@ -155,7 +176,7 @@ local function new_table(col_specs, max_width) local width = widths[i]; local v = row[not titles and column.key or i]; if not titles and column.mapper then - v = column.mapper(v, row); + v = column.mapper(v, row, width, column); end if v == nil then v = column.default or ""; @@ -169,12 +190,36 @@ local function new_table(col_specs, max_width) v = padright(v, width); end elseif len(v) > width then - v = ellipsis(v, width); + v = (column.ellipsis or ellipsis)(v, width); end table.insert(output, v); end return table.concat(output, separator); - end; + end, max_width; +end + +local day = 86400; +local multipliers = { + d = day, w = day * 7, mon = 31 * day, y = 365.2425 * day; + s = 1, min = 60, h = 3600, ho = 3600 +}; + +local function parse_duration(duration_string) + local n, m = duration_string:lower():match("(%d+)%s*([smhdwy]?[io]?n?)"); + if not n or not multipliers[m] then return nil; end + return tonumber(n) * ( multipliers[m] or 1 ); +end + +local multipliers_lax = setmetatable({ + m = multipliers.mon; + mo = multipliers.mon; + mi = multipliers.min; +}, { __index = multipliers }); + +local function parse_duration_lax(duration_string) + local n, m = duration_string:lower():match("(%d+)%s*([smhdwy]?[io]?)"); + if not n then return nil; end + return tonumber(n) * ( multipliers_lax[m] or 1 ); end return { @@ -187,6 +232,9 @@ return { printf = printf; padleft = padleft; padright = padright; + term_width = term_width; ellipsis = ellipsis; table = new_table; + parse_duration = parse_duration; + parse_duration_lax = parse_duration_lax; }; diff --git a/util/human/units.lua b/util/human/units.lua index af233e98..329c8518 100644 --- a/util/human/units.lua +++ b/util/human/units.lua @@ -4,15 +4,7 @@ local math_floor = math.floor; local math_log = math.log; local math_max = math.max; local math_min = math.min; -local unpack = table.unpack or unpack; --luacheck: ignore 113 - -if math_log(10, 10) ~= 1 then - -- Lua 5.1 COMPAT - local log10 = math.log10; - function math_log(n, base) - return log10(n) / log10(base); - end -end +local unpack = table.unpack; local large = { "k", 1000, diff --git a/util/id.lua b/util/id.lua index ff4e919d..70b0cecc 100644 --- a/util/id.lua +++ b/util/id.lua @@ -8,8 +8,8 @@ -- local s_gsub = string.gsub; -local random_bytes = require "util.random".bytes; -local base64_encode = require "util.encodings".base64.encode; +local random_bytes = require "prosody.util.random".bytes; +local base64_encode = require "prosody.util.encodings".base64.encode; local b64url = { ["+"] = "-", ["/"] = "_", ["="] = "" }; local function b64url_random(len) diff --git a/util/import.lua b/util/import.lua index 1007bc0a..0892e9b1 100644 --- a/util/import.lua +++ b/util/import.lua @@ -8,7 +8,7 @@ -local unpack = table.unpack or unpack; --luacheck: ignore 113 +local unpack = table.unpack; local t_insert = table.insert; function _G.import(module, ...) local m = package.loaded[module] or require(module); diff --git a/util/ip.lua b/util/ip.lua index 4b450934..d820e72d 100644 --- a/util/ip.lua +++ b/util/ip.lua @@ -5,8 +5,8 @@ -- COPYING file in the source package for more information. -- -local net = require "util.net"; -local hex = require "util.hex"; +local net = require "prosody.util.net"; +local strbit = require "prosody.util.strbitop"; local ip_methods = {}; @@ -28,13 +28,6 @@ ip_mt.__eq = function (ipA, ipB) return ipA.packed == ipB.packed; end -local hex2bits = { - ["0"] = "0000", ["1"] = "0001", ["2"] = "0010", ["3"] = "0011", - ["4"] = "0100", ["5"] = "0101", ["6"] = "0110", ["7"] = "0111", - ["8"] = "1000", ["9"] = "1001", ["A"] = "1010", ["B"] = "1011", - ["C"] = "1100", ["D"] = "1101", ["E"] = "1110", ["F"] = "1111", -}; - local function new_ip(ipStr, proto) local zone; if (not proto or proto == "IPv6") and ipStr:find('%', 1, true) then @@ -66,27 +59,18 @@ function ip_methods:normal() return net.ntop(self.packed); end -function ip_methods.bits(ip) - return hex.encode(ip.packed):upper():gsub(".", hex2bits); -end - -function ip_methods.bits_full(ip) +-- Returns the longest packed representation, i.e. IPv4 will be mapped +function ip_methods.packed_full(ip) if ip.proto == "IPv4" then ip = ip.toV4mapped; end - return ip.bits; + return ip.packed; end local match; local function commonPrefixLength(ipA, ipB) - ipA, ipB = ipA.bits_full, ipB.bits_full; - for i = 1, 128 do - if ipA:sub(i,i) ~= ipB:sub(i,i) then - return i-1; - end - end - return 128; + return strbit.common_prefix_bits(ipA.packed_full, ipB.packed_full); end -- Instantiate once @@ -238,7 +222,22 @@ function match(ipA, ipB, bits) bits = bits + (128 - 32); end end - return ipA.bits:sub(1, bits) == ipB.bits:sub(1, bits); + return strbit.common_prefix_bits(ipA.packed, ipB.packed) >= bits; +end + +local function is_ip(obj) + return getmetatable(obj) == ip_mt; +end + +local function truncate(ip, n_bits) + if n_bits % 8 ~= 0 then + return error("ip.truncate() only supports multiples of 8 bits"); + end + local n_octets = n_bits / 8; + if not is_ip(ip) then + ip = new_ip(ip); + end + return new_ip(net.ntop(ip.packed:sub(1, n_octets)..("\0"):rep(#ip.packed-n_octets))) end return { @@ -246,4 +245,6 @@ return { commonPrefixLength = commonPrefixLength, parse_cidr = parse_cidr, match = match, + is_ip = is_ip; + truncate = truncate; }; diff --git a/util/iterators.lua b/util/iterators.lua index c03c2fd6..eb4c54af 100644 --- a/util/iterators.lua +++ b/util/iterators.lua @@ -12,8 +12,8 @@ local it = {}; local t_insert = table.insert; local next = next; -local unpack = table.unpack or unpack; --luacheck: ignore 113 -local pack = table.pack or require "util.table".pack; +local unpack = table.unpack; +local pack = table.pack; local type = type; local table, setmetatable = table, setmetatable; @@ -240,7 +240,8 @@ function join_methods:prepend(f, s, var) end function it.join(f, s, var) - return setmetatable({ {f, s, var} }, join_mt); + local t = setmetatable({ {f, s, var} }, join_mt); + return t, { t, 1 }; end return it; diff --git a/util/jid.lua b/util/jid.lua index 694a6b1f..2c3436ca 100644 --- a/util/jid.lua +++ b/util/jid.lua @@ -10,9 +10,9 @@ local select = select; local match, sub = string.match, string.sub; -local nodeprep = require "util.encodings".stringprep.nodeprep; -local nameprep = require "util.encodings".stringprep.nameprep; -local resourceprep = require "util.encodings".stringprep.resourceprep; +local nodeprep = require "prosody.util.encodings".stringprep.nodeprep; +local nameprep = require "prosody.util.encodings".stringprep.nameprep; +local resourceprep = require "prosody.util.encodings".stringprep.resourceprep; local escapes = { [" "] = "\\20"; ['"'] = "\\22"; @@ -35,8 +35,7 @@ local function split(jid) if jid == nil then return; end local node, nodepos = match(jid, "^([^@/]+)@()"); local host, hostpos = match(jid, "^([^@/]+)()", nodepos); - if node ~= nil and host == nil then return nil, nil, nil; end - local resource = match(jid, "^/(.+)$", hostpos); + local resource = host and match(jid, "^/(.+)$", hostpos); if (host == nil) or ((resource == nil) and #jid >= hostpos) then return nil, nil, nil; end return node, host, resource; end @@ -91,9 +90,9 @@ local function compare(jid, acl) -- TODO compare to table of rules? local jid_node, jid_host, jid_resource = split(jid); local acl_node, acl_host, acl_resource = split(acl); - if ((acl_node ~= nil and acl_node == jid_node) or acl_node == nil) and - ((acl_host ~= nil and acl_host == jid_host) or acl_host == nil) and - ((acl_resource ~= nil and acl_resource == jid_resource) or acl_resource == nil) then + if (acl_node == nil or acl_node == jid_node) and + (acl_host == nil or acl_host == jid_host) and + (acl_resource == nil or acl_resource == jid_resource) then return true end return false @@ -111,6 +110,7 @@ local function resource(jid) return (select(3, split(jid))); end +-- TODO Forbid \20 at start and end of escaped output per XEP-0106 v1.1 local function escape(s) return s and (s:gsub("\\%x%x", backslash_escapes):gsub("[\"&'/:<>@ ]", escapes)); end local function unescape(s) return s and (s:gsub("\\%x%x", unescapes)); end diff --git a/util/json.lua b/util/json.lua index e6704b7e..369b9762 100644 --- a/util/json.lua +++ b/util/json.lua @@ -10,12 +10,12 @@ local type = type; local t_insert, t_concat, t_remove = table.insert, table.concat, table.remove; local s_char = string.char; local tostring, tonumber = tostring, tonumber; -local pairs, ipairs, spairs = pairs, ipairs, require "util.iterators".sorted_pairs; +local pairs, ipairs, spairs = pairs, ipairs, require "prosody.util.iterators".sorted_pairs; local next = next; local getmetatable, setmetatable = getmetatable, setmetatable; local print = print; -local has_array, array = pcall(require, "util.array"); +local has_array, array = pcall(require, "prosody.util.array"); local array_mt = has_array and getmetatable(array()) or {}; --module("json") diff --git a/util/jsonpointer.lua b/util/jsonpointer.lua index 9b871ae7..f1c354a4 100644 --- a/util/jsonpointer.lua +++ b/util/jsonpointer.lua @@ -1,6 +1,4 @@ -local m_type = math.type or function (n) - return n % 1 == 0 and n <= 9007199254740992 and n >= -9007199254740992 and "integer" or "float"; -end; +local m_type = math.type; local function unescape_token(escaped_token) local unescaped = escaped_token:gsub("~1", "/"):gsub("~0", "~") diff --git a/util/jsonschema.lua b/util/jsonschema.lua index eafa8b7c..d7a2f9c9 100644 --- a/util/jsonschema.lua +++ b/util/jsonschema.lua @@ -1,16 +1,19 @@ -- This file is generated from teal-src/util/jsonschema.lua -local m_type = function(n) - return type(n) == "number" and n % 1 == 0 and n <= 9007199254740992 and n >= -9007199254740992 and "integer" or "float"; -end; -local json = require("util.json") -local null = json.null; +if not math.type then + require("prosody.util.mathcompat") +end -local pointer = require("util.jsonpointer") +local utf8_enc = rawget(_G, "utf8") or require("prosody.util.encodings").utf8; +local utf8_len = utf8_enc.len or function(s) + local _, count = s:gsub("[%z\001-\127\194-\253][\128-\191]*", ""); + return count +end; -local json_type_name = json.json_type_name +local json = require("prosody.util.json") +local null = json.null; -local schema_t = {} +local pointer = require("prosody.util.jsonpointer") local json_schema_object = { xml_t = {} } @@ -22,7 +25,7 @@ local function simple_validate(schema, data) elseif schema == "array" and type(data) == "table" then return type(data) == "table" and (next(data) == nil or type((next(data, nil))) == "number") elseif schema == "integer" then - return m_type(data) == schema + return math.type(data) == schema elseif schema == "null" then return data == null elseif type(schema) == "table" then @@ -37,33 +40,35 @@ local function simple_validate(schema, data) end end -local complex_validate +local function mkerr(sloc, iloc, err) + return { schemaLocation = sloc; instanceLocation = iloc; error = err } +end -local function validate(schema, data, root) +local function validate(schema, data, root, sloc, iloc, errs) if type(schema) == "boolean" then return schema - else - return complex_validate(schema, data, root) end -end - -function complex_validate(schema, data, root) if root == nil then root = schema + iloc = "" + sloc = "" + errs = {}; end if schema["$ref"] and schema["$ref"]:sub(1, 1) == "#" then local referenced = pointer.resolve(root, schema["$ref"]:sub(2)) if referenced ~= nil and referenced ~= root and referenced ~= schema then - if not validate(referenced, data, root) then - return false + if not validate(referenced, data, root, schema["$ref"], iloc, errs) then + table.insert(errs, mkerr(sloc .. "/$ref", iloc, "Subschema failed validation")) + return false, errs end end end if not simple_validate(schema.type, data) then - return false + table.insert(errs, mkerr(sloc .. "/type", iloc, "unexpected type")); + return false, errs end if schema.type == "object" then @@ -71,7 +76,8 @@ function complex_validate(schema, data, root) for k in pairs(data) do if not (type(k) == "string") then - return false + table.insert(errs, mkerr(sloc .. "/type", iloc, "'object' had non-string keys")); + return false, errs end end end @@ -81,8 +87,9 @@ function complex_validate(schema, data, root) if type(data) == "table" then for i in pairs(data) do - if not (m_type(i) == "integer") then - return false + if not (math.type(i) == "integer") then + table.insert(errs, mkerr(sloc .. "/type", iloc, "'array' had non-integer keys")); + return false, errs end end end @@ -98,146 +105,217 @@ function complex_validate(schema, data, root) end end if not match then - return false + table.insert(errs, mkerr(sloc .. "/enum", iloc, "not one of the enumerated values")); + return false, errs end end if type(data) == "string" then - if schema.maxLength and #data > schema.maxLength then - return false + if schema.maxLength and utf8_len(data) > schema.maxLength then + table.insert(errs, mkerr(sloc .. "/maxLength", iloc, "string too long")) + return false, errs + end + if schema.minLength and utf8_len(data) < schema.minLength then + table.insert(errs, mkerr(sloc .. "/maxLength", iloc, "string too short")) + return false, errs end - if schema.minLength and #data < schema.minLength then - return false + if schema.luaPattern and not data:match(schema.luaPattern) then + table.insert(errs, mkerr(sloc .. "/luaPattern", iloc, "string does not match pattern")) + return false, errs end end if type(data) == "number" then if schema.multipleOf and (data == 0 or data % schema.multipleOf ~= 0) then - return false + table.insert(errs, mkerr(sloc .. "/luaPattern", iloc, "not a multiple")) + return false, errs end if schema.maximum and not (data <= schema.maximum) then - return false + table.insert(errs, mkerr(sloc .. "/maximum", iloc, "number exceeds maximum")) + return false, errs end if schema.exclusiveMaximum and not (data < schema.exclusiveMaximum) then - return false + table.insert(errs, mkerr(sloc .. "/exclusiveMaximum", iloc, "number exceeds exclusive maximum")) + return false, errs end if schema.minimum and not (data >= schema.minimum) then - return false + table.insert(errs, mkerr(sloc .. "/minimum", iloc, "number below minimum")) + return false, errs end if schema.exclusiveMinimum and not (data > schema.exclusiveMinimum) then - return false + table.insert(errs, mkerr(sloc .. "/exclusiveMinimum", iloc, "number below exclusive minimum")) + return false, errs end end if schema.allOf then - for _, sub in ipairs(schema.allOf) do - if not validate(sub, data, root) then - return false + for i, sub in ipairs(schema.allOf) do + if not validate(sub, data, root, sloc .. "/allOf/" .. i, iloc, errs) then + table.insert(errs, mkerr(sloc .. "/allOf", iloc, "did not match all subschemas")) + return false, errs end end end if schema.oneOf then local valid = 0 - for _, sub in ipairs(schema.oneOf) do - if validate(sub, data, root) then + for i, sub in ipairs(schema.oneOf) do + if validate(sub, data, root, sloc .. "/oneOf" .. i, iloc, errs) then valid = valid + 1 end end if valid ~= 1 then - return false + table.insert(errs, mkerr(sloc .. "/oneOf", iloc, "did not match exactly one subschema")) + return false, errs end end if schema.anyOf then local match = false - for _, sub in ipairs(schema.anyOf) do - if validate(sub, data, root) then + for i, sub in ipairs(schema.anyOf) do + if validate(sub, data, root, sloc .. "/anyOf/" .. i, iloc, errs) then match = true break end end if not match then - return false + table.insert(errs, mkerr(sloc .. "/anyOf", iloc, "did not match any subschema")) + return false, errs end end if schema["not"] then - if validate(schema["not"], data, root) then - return false + if validate(schema["not"], data, root, sloc .. "/not", iloc, errs) then + table.insert(errs, mkerr(sloc .. "/not", iloc, "did match subschema")) + return false, errs end end if schema["if"] ~= nil then - if validate(schema["if"], data, root) then + if validate(schema["if"], data, root, sloc .. "/if", iloc, errs) then if schema["then"] then - return validate(schema["then"], data, root) + if not validate(schema["then"], data, root, sloc .. "/then", iloc, errs) then + table.insert(errs, mkerr(sloc .. "/then", iloc, "did not match subschema")) + return false, errs + end end else if schema["else"] then - return validate(schema["else"], data, root) + if not validate(schema["else"], data, root, sloc .. "/else", iloc, errs) then + table.insert(errs, mkerr(sloc .. "/else", iloc, "did not match subschema")) + return false, errs + end end end end if schema.const ~= nil and schema.const ~= data then - return false + table.insert(errs, mkerr(sloc .. "/const", iloc, "did not match constant value")) + return false, errs end if type(data) == "table" then - if schema.maxItems and #data > schema.maxItems then - return false + if schema.maxItems and #(data) > schema.maxItems then + table.insert(errs, mkerr(sloc .. "/maxItems", iloc, "too many items")) + return false, errs end - if schema.minItems and #data < schema.minItems then - return false + if schema.minItems and #(data) < schema.minItems then + table.insert(errs, mkerr(sloc .. "/minItems", iloc, "too few items")) + return false, errs end if schema.required then for _, k in ipairs(schema.required) do if data[k] == nil then - return false + table.insert(errs, mkerr(sloc .. "/required", iloc .. "/" .. tostring(k), "missing required property")) + return false, errs + end + end + end + + if schema.dependentRequired then + for k, reqs in pairs(schema.dependentRequired) do + if data[k] ~= nil then + for _, req in ipairs(reqs) do + if data[req] == nil then + table.insert(errs, mkerr(sloc .. "/dependentRequired", iloc, "missing dependent required property")) + return false, errs + end + end end end end if schema.propertyNames ~= nil then + for k in pairs(data) do - if not validate(schema.propertyNames, k, root) then - return false + if not validate(schema.propertyNames, k, root, sloc .. "/propertyNames", iloc .. "/" .. tostring(k), errs) then + table.insert(errs, mkerr(sloc .. "/propertyNames", iloc .. "/" .. tostring(k), "a property name did not match subschema")) + return false, errs end end end + local seen_properties = {} + if schema.properties then for k, sub in pairs(schema.properties) do - if data[k] ~= nil and not validate(sub, data[k], root) then - return false + if data[k] ~= nil and not validate(sub, data[k], root, sloc .. "/" .. tostring(k), iloc .. "/" .. tostring(k), errs) then + table.insert(errs, mkerr(sloc .. "/" .. tostring(k), iloc .. "/" .. tostring(k), "a property did not match subschema")) + return false, errs + end + seen_properties[k] = true + end + end + + if schema.luaPatternProperties then + + for pattern, sub in pairs(schema.luaPatternProperties) do + for k in pairs(data) do + if type(k) == "string" and k:match(pattern) then + if not validate(sub, data[k], root, sloc .. "/luaPatternProperties", iloc, errs) then + table.insert(errs, mkerr(sloc .. "/luaPatternProperties/" .. pattern, iloc .. "/" .. tostring(k), "a property did not match subschema")) + return false, errs + end + seen_properties[k] = true + end end end end if schema.additionalProperties ~= nil then for k, v in pairs(data) do - if schema.properties == nil or schema.properties[k] == nil then - if not validate(schema.additionalProperties, v, root) then - return false + if not seen_properties[k] then + if not validate(schema.additionalProperties, v, root, sloc .. "/additionalProperties", iloc .. "/" .. tostring(k), errs) then + table.insert(errs, mkerr(sloc .. "/additionalProperties", iloc .. "/" .. tostring(k), "additional property did not match subschema")) + return false, errs end end end end + if schema.dependentSchemas then + for k, sub in pairs(schema.dependentSchemas) do + if data[k] ~= nil and not validate(sub, data, root, sloc .. "/dependentSchemas/" .. k, iloc, errs) then + table.insert(errs, mkerr(sloc .. "/dependentSchemas", iloc .. "/" .. tostring(k), "did not match dependent subschema")) + return false, errs + end + end + end + if schema.uniqueItems then local values = {} for _, v in pairs(data) do if values[v] then - return false + table.insert(errs, mkerr(sloc .. "/uniqueItems", iloc, "had duplicate items")) + return false, errs end values[v] = true end @@ -248,32 +326,39 @@ function complex_validate(schema, data, root) for i, s in ipairs(schema.prefixItems) do if data[i] == nil then break - elseif validate(s, data[i], root) then + elseif validate(s, data[i], root, sloc .. "/prefixItems/" .. i, iloc .. "/" .. i, errs) then p = i else - return false + table.insert(errs, mkerr(sloc .. "/prefixItems/" .. i, iloc .. "/" .. tostring(i), "did not match subschema")) + return false, errs end end end if schema.items ~= nil then - for i = p + 1, #data do - if not validate(schema.items, data[i], root) then - return false + for i = p + 1, #(data) do + if not validate(schema.items, data[i], root, sloc, iloc .. "/" .. i, errs) then + table.insert(errs, mkerr(sloc .. "/prefixItems/" .. i, iloc .. "/" .. i, "did not match subschema")) + return false, errs end end end if schema.contains ~= nil then - local found = false - for i = 1, #data do - if validate(schema.contains, data[i], root) then - found = true - break + local found = 0 + for i = 1, #(data) do + if validate(schema.contains, data[i], root, sloc .. "/contains", iloc .. "/" .. i, errs) then + found = found + 1 + else + table.insert(errs, mkerr(sloc .. "/contains", iloc .. "/" .. i, "did not match subschema")) end end - if not found then - return false + if found < (schema.minContains or 1) then + table.insert(errs, mkerr(sloc .. "/minContains", iloc, "too few matches")) + return false, errs + elseif found > (schema.maxContains or math.huge) then + table.insert(errs, mkerr(sloc .. "/maxContains", iloc, "too many matches")) + return false, errs end end end diff --git a/util/jwt.lua b/util/jwt.lua index bf106dfa..997f0068 100644 --- a/util/jwt.lua +++ b/util/jwt.lua @@ -1,9 +1,10 @@ local s_gsub = string.gsub; -local json = require "util.json"; -local hashes = require "util.hashes"; -local base64_encode = require "util.encodings".base64.encode; -local base64_decode = require "util.encodings".base64.decode; -local secure_equals = require "util.hashes".equals; +local crypto = require "prosody.util.crypto"; +local json = require "prosody.util.json"; +local hashes = require "prosody.util.hashes"; +local base64_encode = require "prosody.util.encodings".base64.encode; +local base64_decode = require "prosody.util.encodings".base64.decode; +local secure_equals = require "prosody.util.hashes".equals; local b64url_rep = { ["+"] = "-", ["/"] = "_", ["="] = "", ["-"] = "+", ["_"] = "/" }; local function b64url(data) @@ -13,17 +14,8 @@ local function unb64url(data) return base64_decode(s_gsub(data, "[-_]", b64url_rep).."=="); end -local static_header = b64url('{"alg":"HS256","typ":"JWT"}') .. '.'; - -local function sign(key, payload) - local encoded_payload = json.encode(payload); - local signed = static_header .. b64url(encoded_payload); - local signature = hashes.hmac_sha256(key, signed); - return signed .. "." .. b64url(signature); -end - local jwt_pattern = "^(([A-Za-z0-9-_]+)%.([A-Za-z0-9-_]+))%.([A-Za-z0-9-_]+)$" -local function verify(key, blob) +local function decode_jwt(blob, expected_alg) local signed, bheader, bpayload, signature = string.match(blob, jwt_pattern); if not signed then return nil, "invalid-encoding"; @@ -31,21 +23,197 @@ local function verify(key, blob) local header = json.decode(unb64url(bheader)); if not header or type(header) ~= "table" then return nil, "invalid-header"; - elseif header.alg ~= "HS256" then + elseif header.alg ~= expected_alg then return nil, "unsupported-algorithm"; end - if not secure_equals(b64url(hashes.hmac_sha256(key, signed)), signature) then - return false, "signature-mismatch"; - end - local payload, err = json.decode(unb64url(bpayload)); + return signed, signature, bpayload; +end + +local function new_static_header(algorithm_name) + return b64url('{"alg":"'..algorithm_name..'","typ":"JWT"}') .. '.'; +end + +local function decode_raw_payload(raw_payload) + local payload, err = json.decode(unb64url(raw_payload)); if err ~= nil then return nil, "json-decode-error"; + elseif type(payload) ~= "table" then + return nil, "invalid-payload-type"; end return true, payload; end +-- HS*** family +local function new_hmac_algorithm(name) + local static_header = new_static_header(name); + + local hmac = hashes["hmac_sha"..name:sub(-3)]; + + local function sign(key, payload) + local encoded_payload = json.encode(payload); + local signed = static_header .. b64url(encoded_payload); + local signature = hmac(key, signed); + return signed .. "." .. b64url(signature); + end + + local function verify(key, blob) + local signed, signature, raw_payload = decode_jwt(blob, name); + if not signed then return nil, signature; end -- nil, err + + if not secure_equals(b64url(hmac(key, signed)), signature) then + return false, "signature-mismatch"; + end + + return decode_raw_payload(raw_payload); + end + + local function load_key(key) + assert(type(key) == "string", "key must be string (long, random, secure)"); + return key; + end + + return { sign = sign, verify = verify, load_key = load_key }; +end + +local function new_crypto_algorithm(name, key_type, c_sign, c_verify, sig_encode, sig_decode) + local static_header = new_static_header(name); + + return { + sign = function (private_key, payload) + local encoded_payload = json.encode(payload); + local signed = static_header .. b64url(encoded_payload); + + local signature = c_sign(private_key, signed); + if sig_encode then + signature = sig_encode(signature); + end + + return signed.."."..b64url(signature); + end; + + verify = function (public_key, blob) + local signed, signature, raw_payload = decode_jwt(blob, name); + if not signed then return nil, signature; end -- nil, err + + signature = unb64url(signature); + if sig_decode and signature then + signature = sig_decode(signature); + end + if not signature then + return false, "signature-mismatch"; + end + + local verify_ok = c_verify(public_key, signed, signature); + if not verify_ok then + return false, "signature-mismatch"; + end + + return decode_raw_payload(raw_payload); + end; + + load_public_key = function (public_key_pem) + local key = assert(crypto.import_public_pem(public_key_pem)); + assert(key:get_type() == key_type, "incorrect key type"); + return key; + end; + + load_private_key = function (private_key_pem) + local key = assert(crypto.import_private_pem(private_key_pem)); + assert(key:get_type() == key_type, "incorrect key type"); + return key; + end; + }; +end + +-- RS***, PS*** +local rsa_sign_algos = { RS = "rsassa_pkcs1", PS = "rsassa_pss" }; +local function new_rsa_algorithm(name) + local family, digest_bits = name:match("^(..)(...)$"); + local c_sign = crypto[rsa_sign_algos[family].."_sha"..digest_bits.."_sign"]; + local c_verify = crypto[rsa_sign_algos[family].."_sha"..digest_bits.."_verify"]; + return new_crypto_algorithm(name, "rsaEncryption", c_sign, c_verify); +end + +-- ES*** +local function new_ecdsa_algorithm(name, c_sign, c_verify, sig_bytes) + local function encode_ecdsa_sig(der_sig) + local r, s = crypto.parse_ecdsa_signature(der_sig, sig_bytes); + return r..s; + end + + local expected_sig_length = sig_bytes*2; + local function decode_ecdsa_sig(jwk_sig) + if #jwk_sig ~= expected_sig_length then + return nil; + end + return crypto.build_ecdsa_signature(jwk_sig:sub(1, sig_bytes), jwk_sig:sub(sig_bytes+1)); + end + return new_crypto_algorithm(name, "id-ecPublicKey", c_sign, c_verify, encode_ecdsa_sig, decode_ecdsa_sig); +end + +local algorithms = { + HS256 = new_hmac_algorithm("HS256"), HS384 = new_hmac_algorithm("HS384"), HS512 = new_hmac_algorithm("HS512"); + ES256 = new_ecdsa_algorithm("ES256", crypto.ecdsa_sha256_sign, crypto.ecdsa_sha256_verify, 32); + ES512 = new_ecdsa_algorithm("ES512", crypto.ecdsa_sha512_sign, crypto.ecdsa_sha512_verify, 66); + RS256 = new_rsa_algorithm("RS256"), RS384 = new_rsa_algorithm("RS384"), RS512 = new_rsa_algorithm("RS512"); + PS256 = new_rsa_algorithm("PS256"), PS384 = new_rsa_algorithm("PS384"), PS512 = new_rsa_algorithm("PS512"); +}; + +local function new_signer(algorithm, key_input, options) + local impl = assert(algorithms[algorithm], "Unknown JWT algorithm: "..algorithm); + local key = (impl.load_private_key or impl.load_key)(key_input); + local sign = impl.sign; + local default_ttl = (options and options.default_ttl) or 3600; + return function (payload) + local issued_at; + if not payload.iat then + issued_at = os.time(); + payload.iat = issued_at; + end + if not payload.exp then + payload.exp = (issued_at or os.time()) + default_ttl; + end + return sign(key, payload); + end +end + +local function new_verifier(algorithm, key_input, options) + local impl = assert(algorithms[algorithm], "Unknown JWT algorithm: "..algorithm); + local key = (impl.load_public_key or impl.load_key)(key_input); + local verify = impl.verify; + local check_expiry = not (options and options.accept_expired); + local claim_verifier = options and options.claim_verifier; + return function (token) + local ok, payload = verify(key, token); + if ok then + local expires_at = check_expiry and payload.exp; + if expires_at then + if type(expires_at) ~= "number" then + return nil, "invalid-expiry"; + elseif expires_at < os.time() then + return nil, "token-expired"; + end + end + if claim_verifier and not claim_verifier(payload) then + return nil, "incorrect-claims"; + end + end + return ok, payload; + end +end + +local function init(algorithm, private_key, public_key, options) + return new_signer(algorithm, private_key, options), new_verifier(algorithm, public_key or private_key, options); +end + return { - sign = sign; - verify = verify; + init = init; + new_signer = new_signer; + new_verifier = new_verifier; + -- Exported mainly for tests + _algorithms = algorithms; + -- Deprecated + sign = algorithms.HS256.sign; + verify = algorithms.HS256.verify; }; diff --git a/util/logger.lua b/util/logger.lua index 20a5cef2..ad6921a1 100644 --- a/util/logger.lua +++ b/util/logger.lua @@ -10,6 +10,7 @@ local pairs = pairs; local ipairs = ipairs; local require = require; +local t_remove = table.remove; local _ENV = nil; -- luacheck: std none @@ -71,13 +72,27 @@ local function add_level_sink(level, sink_function) end local function add_simple_sink(simple_sink_function, levels) - local format = require "util.format".format; + local format = require "prosody.util.format".format; local function sink_function(name, level, msg, ...) return simple_sink_function(name, level, format(msg, ...)); end for _, level in ipairs(levels or {"debug", "info", "warn", "error"}) do add_level_sink(level, sink_function); end + return sink_function; +end + +local function remove_sink(sink_function) + local removed; + for level, sinks in pairs(level_sinks) do + for i = #sinks, 1, -1 do + if sinks[i] == sink_function then + t_remove(sinks, i); + removed = true; + end + end + end + return removed; end return { @@ -87,4 +102,5 @@ return { add_level_sink = add_level_sink; add_simple_sink = add_simple_sink; new = make_logger; + remove_sink = remove_sink; }; diff --git a/util/mathcompat.lua b/util/mathcompat.lua new file mode 100644 index 00000000..e8acb261 --- /dev/null +++ b/util/mathcompat.lua @@ -0,0 +1,13 @@ +if not math.type then + + local function math_type(t) + if type(t) == "number" then + if t % 1 == 0 and t ~= t + 1 and t ~= t - 1 then + return "integer" + else + return "float" + end + end + end + _G.math.type = math_type +end diff --git a/util/multitable.lua b/util/multitable.lua index 4f2cd972..0c292b45 100644 --- a/util/multitable.lua +++ b/util/multitable.lua @@ -9,7 +9,7 @@ local select = select; local t_insert = table.insert; local pairs, next, type = pairs, next, type; -local unpack = table.unpack or unpack; --luacheck: ignore 113 +local unpack = table.unpack; local _ENV = nil; -- luacheck: std none diff --git a/util/openmetrics.lua b/util/openmetrics.lua index c18e63e9..cf9f5d24 100644 --- a/util/openmetrics.lua +++ b/util/openmetrics.lua @@ -1,7 +1,7 @@ --[[ This module implements a subset of the OpenMetrics Internet Draft version 00. -URL: https://tools.ietf.org/html/draft-richih-opsawg-openmetrics-00 +URL: https://datatracker.ietf.org/doc/html/draft-richih-opsawg-openmetrics-00 The following metric types are supported: @@ -19,14 +19,14 @@ defined in the I-D linked above. -- metric constructor interface: -- metric_ctor(..., family_name, labels, extra) -local time = require "util.time".now; +local time = require "prosody.util.time".now; local select = select; -local array = require "util.array"; -local log = require "util.logger".init("util.openmetrics"); -local new_multitable = require "util.multitable".new; -local iter_multitable = require "util.multitable".iter; +local array = require "prosody.util.array"; +local log = require "prosody.util.logger".init("util.openmetrics"); +local new_multitable = require "prosody.util.multitable".new; +local iter_multitable = require "prosody.util.multitable".iter; local t_concat, t_insert = table.concat, table.insert; -local t_pack, t_unpack = require "util.table".pack, table.unpack or unpack; --luacheck: ignore 113/unpack +local t_pack, t_unpack = table.pack, table.unpack; -- BEGIN of Utility: "metric proxy" -- This allows to wrap a MetricFamily in a proxy which only provides the @@ -35,6 +35,7 @@ local t_pack, t_unpack = require "util.table".pack, table.unpack or unpack; --lu -- `with_partial_label` by the moduleapi in order to pre-set the `host` label -- on metrics created in non-global modules. local metric_proxy_mt = {} +metric_proxy_mt.__name = "metric_proxy" metric_proxy_mt.__index = metric_proxy_mt local function new_metric_proxy(metric_family, with_labels_proxy_fun) @@ -128,6 +129,7 @@ end -- BEGIN of generic MetricFamily implementation local metric_family_mt = {} +metric_family_mt.__name = "metric_family" metric_family_mt.__index = metric_family_mt local function histogram_metric_ctor(orig_ctor, buckets) @@ -278,6 +280,7 @@ local function compose_name(name, unit) end local metric_registry_mt = {} +metric_registry_mt.__name = "metric_registry" metric_registry_mt.__index = metric_registry_mt local function new_metric_registry(backend) diff --git a/util/openssl.lua b/util/openssl.lua index 32b5aea7..422bc494 100644 --- a/util/openssl.lua +++ b/util/openssl.lua @@ -5,7 +5,7 @@ local s_format = string.format; local oid_xmppaddr = "1.3.6.1.5.5.7.8.5"; -- [XMPP-CORE] local oid_dnssrv = "1.3.6.1.5.5.7.8.7"; -- [SRV-ID] -local idna_to_ascii = require "util.encodings".idna.to_ascii; +local idna_to_ascii = require "prosody.util.encodings".idna.to_ascii; local _M = {}; local config = {}; @@ -166,8 +166,7 @@ do -- Lua to shell calls. setmetatable(_M, { __index = function(_, command) return function(opts) - local ret = os_execute(serialize(command, type(opts) == "table" and opts or {})); - return ret == true or ret == 0; + return os_execute(serialize(command, type(opts) == "table" and opts or {})); end; end; }); diff --git a/util/paseto.lua b/util/paseto.lua new file mode 100644 index 00000000..5a9acafb --- /dev/null +++ b/util/paseto.lua @@ -0,0 +1,218 @@ +local crypto = require "prosody.util.crypto"; +local json = require "prosody.util.json"; +local hashes = require "prosody.util.hashes"; +local base64_encode = require "prosody.util.encodings".base64.encode; +local base64_decode = require "prosody.util.encodings".base64.decode; +local secure_equals = require "prosody.util.hashes".equals; +local bit = require "prosody.util.bitcompat"; +local hex = require "prosody.util.hex"; +local rand = require "prosody.util.random"; +local s_pack = require "prosody.util.struct".pack; + +local s_gsub = string.gsub; + +local v4_public = {}; + +local b64url_rep = { ["+"] = "-", ["/"] = "_", ["="] = "", ["-"] = "+", ["_"] = "/" }; +local function b64url(data) + return (s_gsub(base64_encode(data), "[+/=]", b64url_rep)); +end + +local valid_tails = { + nil; -- Always invalid + "^.[AQgw]$"; -- b??????00 + "^..[AQgwEUk0IYo4Mcs8]$"; -- b????0000 +} + +local function unb64url(data) + local rem = #data%4; + if data:sub(-1,-1) == "=" or rem == 1 or (rem > 1 and not data:sub(-rem):match(valid_tails[rem])) then + return nil; + end + return base64_decode(s_gsub(data, "[-_]", b64url_rep).."=="); +end + +local function le64(n) + return s_pack("<I8", bit.band(n, 0x7F)); +end + +local function pae(parts) + if type(parts) ~= "table" then + error("bad argument #1 to 'pae' (table expected, got "..type(parts)..")"); + end + local o = { le64(#parts) }; + for _, part in ipairs(parts) do + table.insert(o, le64(#part)..part); + end + return table.concat(o); +end + +function v4_public.sign(m, sk, f, i) + if type(m) ~= "table" then + return nil, "PASETO payloads must be a table"; + end + m = json.encode(m); + local h = "v4.public."; + local m2 = pae({ h, m, f or "", i or "" }); + local sig = crypto.ed25519_sign(sk, m2); + if not f or f == "" then + return h..b64url(m..sig); + else + return h..b64url(m..sig).."."..b64url(f); + end +end + +function v4_public.verify(tok, pk, expected_f, i) + local h, sm, f = tok:match("^(v4%.public%.)([^%.]+)%.?(.*)$"); + if not h then + return nil, "invalid-token-format"; + end + f = f and unb64url(f) or nil; + if expected_f then + if not f or not secure_equals(expected_f, f) then + return nil, "invalid-footer"; + end + end + local raw_sm = unb64url(sm); + if not raw_sm or #raw_sm <= 64 then + return nil, "invalid-token-format"; + end + local s, m = raw_sm:sub(-64), raw_sm:sub(1, -65); + local m2 = pae({ h, m, f or "", i or "" }); + local ok = crypto.ed25519_verify(pk, m2, s); + if not ok then + return nil, "invalid-token"; + end + local payload, err = json.decode(m); + if err ~= nil or type(payload) ~= "table" then + return nil, "json-decode-error"; + end + return payload; +end + +v4_public.import_private_key = crypto.import_private_pem; +v4_public.import_public_key = crypto.import_public_pem; +function v4_public.new_keypair() + return crypto.generate_ed25519_keypair(); +end + +function v4_public.init(private_key_pem, public_key_pem, options) + local sign, verify = v4_public.sign, v4_public.verify; + local public_key = public_key_pem and v4_public.import_public_key(public_key_pem); + local private_key = private_key_pem and v4_public.import_private_key(private_key_pem); + local default_footer = options and options.default_footer; + local default_assertion = options and options.default_implicit_assertion; + return private_key and function (token, token_footer, token_assertion) + return sign(token, private_key, token_footer or default_footer, token_assertion or default_assertion); + end, public_key and function (token, expected_footer, token_assertion) + return verify(token, public_key, expected_footer or default_footer, token_assertion or default_assertion); + end; +end + +function v4_public.new_signer(private_key_pem, options) + return (v4_public.init(private_key_pem, nil, options)); +end + +function v4_public.new_verifier(public_key_pem, options) + return (select(2, v4_public.init(nil, public_key_pem, options))); +end + +local v3_local = { _key_mt = {} }; + +local function v3_local_derive_keys(k, n) + local tmp = hashes.hkdf_hmac_sha384(48, k, nil, "paseto-encryption-key"..n); + local Ek = tmp:sub(1, 32); + local n2 = tmp:sub(33); + local Ak = hashes.hkdf_hmac_sha384(48, k, nil, "paseto-auth-key-for-aead"..n); + return Ek, Ak, n2; +end + +function v3_local.encrypt(m, k, f, i) + assert(#k == 32) + if type(m) ~= "table" then + return nil, "PASETO payloads must be a table"; + end + m = json.encode(m); + local h = "v3.local."; + local n = rand.bytes(32); + local Ek, Ak, n2 = v3_local_derive_keys(k, n); + + local c = crypto.aes_256_ctr_encrypt(Ek, n2, m); + local m2 = pae({ h, n, c, f or "", i or "" }); + local t = hashes.hmac_sha384(Ak, m2); + + if not f or f == "" then + return h..b64url(n..c..t); + else + return h..b64url(n..c..t).."."..b64url(f); + end +end + +function v3_local.decrypt(tok, k, expected_f, i) + assert(#k == 32) + + local h, sm, f = tok:match("^(v3%.local%.)([^%.]+)%.?(.*)$"); + if not h then + return nil, "invalid-token-format"; + end + f = f and unb64url(f) or nil; + if expected_f then + if not f or not secure_equals(expected_f, f) then + return nil, "invalid-footer"; + end + end + local m = unb64url(sm); + if not m or #m <= 80 then + return nil, "invalid-token-format"; + end + local n, c, t = m:sub(1, 32), m:sub(33, -49), m:sub(-48); + local Ek, Ak, n2 = v3_local_derive_keys(k, n); + local preAuth = pae({ h, n, c, f or "", i or "" }); + local t2 = hashes.hmac_sha384(Ak, preAuth); + if not secure_equals(t, t2) then + return nil, "invalid-token"; + end + local m2 = crypto.aes_256_ctr_decrypt(Ek, n2, c); + if not m2 then + return nil, "invalid-token"; + end + + local payload, err = json.decode(m2); + if err ~= nil or type(payload) ~= "table" then + return nil, "json-decode-error"; + end + return payload; +end + +function v3_local.new_key() + return "secret-token:paseto.v3.local:"..hex.encode(rand.bytes(32)); +end + +function v3_local.init(key, options) + local encoded_key = key:match("^secret%-token:paseto%.v3%.local:(%x+)$"); + if not encoded_key or #encoded_key ~= 64 then + return error("invalid key for v3.local"); + end + local raw_key = hex.decode(encoded_key); + local default_footer = options and options.default_footer; + local default_assertion = options and options.default_implicit_assertion; + return function (token, token_footer, token_assertion) + return v3_local.encrypt(token, raw_key, token_footer or default_footer, token_assertion or default_assertion); + end, function (token, token_footer, token_assertion) + return v3_local.decrypt(token, raw_key, token_footer or default_footer, token_assertion or default_assertion); + end; +end + +function v3_local.new_signer(key, options) + return (v3_local.init(key, options)); +end + +function v3_local.new_verifier(key, options) + return (select(2, v3_local.init(key, options))); +end + +return { + pae = pae; + v3_local = v3_local; + v4_public = v4_public; +}; diff --git a/util/pluginloader.lua b/util/pluginloader.lua index f2ccb4cb..634bd6f8 100644 --- a/util/pluginloader.lua +++ b/util/pluginloader.lua @@ -17,7 +17,7 @@ for path in (CFG_PLUGINDIR or "./plugins/"):gsub("[/\\]", dir_sep):gmatch("[^".. end local io_open = io.open; -local envload = require "util.envload".envload; +local envload = require "prosody.util.envload".envload; local pluginloader_methods = {}; local pluginloader_mt = { __index = pluginloader_methods }; diff --git a/util/promise.lua b/util/promise.lua index c4e166ed..90780626 100644 --- a/util/promise.lua +++ b/util/promise.lua @@ -1,8 +1,8 @@ local promise_methods = {}; local promise_mt = { __name = "promise", __index = promise_methods }; -local xpcall = require "util.xpcall".xpcall; -local unpack = table.unpack or unpack; --luacheck: ignore 113 +local xpcall = require "prosody.util.xpcall".xpcall; +local unpack = table.unpack; function promise_mt:__tostring() return "promise (" .. (self._state or "invalid") .. ")"; @@ -57,10 +57,7 @@ local function promise_settle(promise, new_state, new_next, cbs, value) end local function new_resolve_functions(p) - local resolved = false; local function _resolve(v) - if resolved then return; end - resolved = true; if is_promise(v) then v:next(new_resolve_functions(p)); elseif promise_settle(p, "fulfilled", next_fulfilled, p._pending_on_fulfilled, v) then @@ -69,8 +66,6 @@ local function new_resolve_functions(p) end local function _reject(e) - if resolved then return; end - resolved = true; if promise_settle(p, "rejected", next_rejected, p._pending_on_rejected, e) then p.reason = e; end diff --git a/util/prosodyctl.lua b/util/prosodyctl.lua index 4d49cd16..9cb4b4dd 100644 --- a/util/prosodyctl.lua +++ b/util/prosodyctl.lua @@ -7,17 +7,21 @@ -- -local config = require "core.configmanager"; -local encodings = require "util.encodings"; +local config = require "prosody.core.configmanager"; +local encodings = require "prosody.util.encodings"; local stringprep = encodings.stringprep; -local storagemanager = require "core.storagemanager"; -local usermanager = require "core.usermanager"; -local interpolation = require "util.interpolation"; -local signal = require "util.signal"; -local set = require "util.set"; +local storagemanager = require "prosody.core.storagemanager"; +local usermanager = require "prosody.core.usermanager"; +local interpolation = require "prosody.util.interpolation"; +local signal = require "prosody.util.signal"; +local set = require "prosody.util.set"; +local path = require"prosody.util.paths"; local lfs = require "lfs"; local type = type; +local have_socket_unix, socket_unix = pcall(require, "socket.unix"); +have_socket_unix = have_socket_unix and type(socket_unix) == "table"; -- was a function in older LuaSocket + local nodeprep, nameprep = stringprep.nodeprep, stringprep.nameprep; local io, os = io, os; @@ -42,7 +46,7 @@ local error_messages = setmetatable({ }, { __index = function (_,k) return "Error: "..(tostring(k):gsub("%-", " "):gsub("^.", string.upper)); end }); -- UI helpers -local show_message = require "util.human.io".printf; +local show_message = require "prosody.util.human.io".printf; local function show_usage(usage, desc) print("Usage: ".._G.arg[0].." "..usage); @@ -177,11 +181,31 @@ local function start(source_dir, lua) if ret then return false, "already-running"; end + local notify_socket; + if have_socket_unix then + local notify_path = path.join(prosody.paths.data, "notify.sock"); + os.remove(notify_path); + lua = string.format("NOTIFY_SOCKET=%q %s", notify_path, lua); + notify_socket = socket_unix.dgram(); + local ok = notify_socket:setsockname(notify_path); + if not ok then return false, "notify-failed"; end + end if not source_dir then os.execute(lua .. "./prosody -D"); else os.execute(lua .. source_dir.."/../../bin/prosody -D"); end + + if notify_socket then + for i = 1, 5 do + notify_socket:settimeout(i); + if notify_socket:receivefrom() == "READY=1" then + return true; + end + end + return false, "not-ready"; + end + return true; end @@ -224,8 +248,7 @@ local function call_luarocks(operation, mod, server) local ok, _, code = os.execute(render_cli("luarocks --lua-version={luav} {op} --tree={dir} {server&--server={server}} {mod?}", { dir = dir; op = operation; mod = mod; server = server; luav = _VERSION:match("5%.%d"); })); - if type(ok) == "number" then code = ok; end - return code; + return ok and code; end return { diff --git a/util/prosodyctl/cert.lua b/util/prosodyctl/cert.lua index abd2e1d6..70c09443 100644 --- a/util/prosodyctl/cert.lua +++ b/util/prosodyctl/cert.lua @@ -1,8 +1,8 @@ local lfs = require "lfs"; -local pctl = require "util.prosodyctl"; -local hi = require "util.human.io"; -local configmanager = require "core.configmanager"; +local pctl = require "prosody.util.prosodyctl"; +local hi = require "prosody.util.human.io"; +local configmanager = require "prosody.core.configmanager"; local openssl; @@ -24,7 +24,7 @@ local function use_existing(filename) end end -local have_pposix, pposix = pcall(require, "util.pposix"); +local have_pposix, pposix = pcall(require, "prosody.util.pposix"); local cert_basedir = prosody.paths.data == "." and "./certs" or prosody.paths.data; if have_pposix and pposix.getuid() == 0 then -- FIXME should be enough to check if this directory is writable @@ -179,7 +179,7 @@ local function copy(from, to, umask, owner, group) os.execute(("chown -c --reference=%s %s"):format(sh_esc(cert_basedir), sh_esc(to))); elseif owner and group then local ok = os.execute(("chown %s:%s %s"):format(sh_esc(owner), sh_esc(group), sh_esc(to))); - assert(ok == true or ok == 0, "Failed to change ownership of "..to); + assert(ok, "Failed to change ownership of "..to); end if old_umask then pposix.umask(old_umask); end return true; @@ -219,7 +219,7 @@ function cert_commands.import(arg) owner = configmanager.get("*", "prosody_user") or "prosody"; group = configmanager.get("*", "prosody_group") or owner; end - local cm = require "core.certmanager"; + local cm = require "prosody.core.certmanager"; local files_by_name = {} for _, dir in ipairs(arg) do cm.index_certs(dir, files_by_name); @@ -271,7 +271,7 @@ end local function cert(arg) if #arg >= 1 and arg[1] ~= "--help" then - openssl = require "util.openssl"; + openssl = require "prosody.util.openssl"; lfs = require "lfs"; local cert_dir_attrs = lfs.attributes(cert_basedir); if not cert_dir_attrs then @@ -303,7 +303,7 @@ local function cert(arg) end return cert_commands[subcmd](arg); elseif subcmd == "check" then - return require "util.prosodyctl.check".check({"certs"}); + return require "prosody.util.prosodyctl.check".check({"certs"}); end end pctl.show_usage("cert config|request|generate|key|import", "Helpers for generating X.509 certificates and keys.") diff --git a/util/prosodyctl/check.lua b/util/prosodyctl/check.lua index 43406f0c..ac8cc9c1 100644 --- a/util/prosodyctl/check.lua +++ b/util/prosodyctl/check.lua @@ -1,24 +1,24 @@ -local configmanager = require "core.configmanager"; -local moduleapi = require "core.moduleapi"; -local show_usage = require "util.prosodyctl".show_usage; -local show_warning = require "util.prosodyctl".show_warning; -local is_prosody_running = require "util.prosodyctl".isrunning; -local parse_args = require "util.argparse".parse; -local dependencies = require "util.dependencies"; +local configmanager = require "prosody.core.configmanager"; +local moduleapi = require "prosody.core.moduleapi"; +local show_usage = require "prosody.util.prosodyctl".show_usage; +local show_warning = require "prosody.util.prosodyctl".show_warning; +local is_prosody_running = require "prosody.util.prosodyctl".isrunning; +local parse_args = require "prosody.util.argparse".parse; +local dependencies = require "prosody.util.dependencies"; local socket = require "socket"; local socket_url = require "socket.url"; -local jid_split = require "util.jid".prepped_split; -local modulemanager = require "core.modulemanager"; -local async = require "util.async"; -local httputil = require "util.http"; +local jid_split = require "prosody.util.jid".prepped_split; +local modulemanager = require "prosody.core.modulemanager"; +local async = require "prosody.util.async"; +local httputil = require "prosody.util.http"; local function api(host) return setmetatable({ name = "prosodyctl.check"; host = host; log = prosody.log }, { __index = moduleapi }) end local function check_ojn(check_type, target_host) - local http = require "net.http"; -- .new({}); - local json = require "util.json"; + local http = require "prosody.net.http"; -- .new({}); + local json = require "prosody.util.json"; local response, err = async.wait_for(http.request( ("https://observe.jabber.network/api/v1/check/%s"):format(httputil.urlencode(check_type)), @@ -46,7 +46,7 @@ local function check_ojn(check_type, target_host) end local function check_probe(base_url, probe_module, target) - local http = require "net.http"; -- .new({}); + local http = require "prosody.net.http"; -- .new({}); local params = httputil.formencode({ module = probe_module; target = target }) local response, err = async.wait_for(http.request(base_url .. "?" .. params)); @@ -67,8 +67,8 @@ local function check_probe(base_url, probe_module, target) end local function check_turn_service(turn_service, ping_service) - local ip = require "util.ip"; - local stun = require "net.stun"; + local ip = require "prosody.util.ip"; + local stun = require "prosody.net.stun"; local result = { warnings = {} }; @@ -170,7 +170,7 @@ local function check_turn_service(turn_service, ping_service) result.error = "TURN server did not response to allocation request: "..err; return result; elseif alloc_response:is_err_resp() then - result.error = ("TURN allocation failed: %d (%s)"):format(alloc_response:get_error()); + result.error = ("TURN server failed to create allocation: %d (%s)"):format(alloc_response:get_error()); return result; elseif not alloc_response:is_success_resp() then result.error = ("Unexpected TURN response: %d (%s)"):format(alloc_response:get_type()); @@ -319,18 +319,15 @@ local function check(arg) print("Error: Unknown parameter: "..opts_info); return 1; end - local array = require "util.array"; - local set = require "util.set"; - local it = require "util.iterators"; + local array = require "prosody.util.array"; + local set = require "prosody.util.set"; + local it = require "prosody.util.iterators"; local ok = true; + local function contains_match(hayset, needle) for member in hayset do if member:find(needle) then return true end end end local function disabled_hosts(host, conf) return host ~= "*" and conf.enabled ~= false; end local function enabled_hosts() return it.filter(disabled_hosts, pairs(configmanager.getconfig())); end - if not (what == nil or what == "disabled" or what == "config" or what == "dns" or what == "certs" or what == "connectivity" or what == "turn") then - show_warning("Don't know how to check '%s'. Try one of 'config', 'dns', 'certs', 'disabled', 'turn' or 'connectivity'.", what); - show_warning("Note: The connectivity check will connect to a remote server."); - return 1; - end - if not what or what == "disabled" then + local checks = {}; + function checks.disabled() local disabled_hosts_set = set.new(); for host in it.filter("*", pairs(configmanager.getconfig())) do if api(host):get_option_boolean("enabled") == false then @@ -345,7 +342,7 @@ local function check(arg) print"" end end - if not what or what == "config" then + function checks.config() print("Checking config..."); if what == "config" then @@ -520,9 +517,9 @@ local function check(arg) end for k, v in pairs(modules) do if type(k) ~= "number" or type(v) ~= "string" then - print(" The " .. name .. " in the " .. host .. " section should not be a map of " .. type(k) .. " to " .. type(v) - .. " but a list of strings, e.g."); + print(" The " .. name .. " in the " .. host .. " section should be a list of strings, e.g."); print(" " .. name .. " = { \"name_of_module\", \"another_plugin\", }") + print(" It should not contain key = value pairs, try putting them outside the {} brackets."); ok = false break end @@ -750,7 +747,7 @@ local function check(arg) -- Check hostname validity do - local idna = require "util.encodings".idna; + local idna = require "prosody.util.encodings".idna; local invalid_hosts = {}; local alabel_hosts = {}; for host in it.filter("*", pairs(configmanager.getconfig())) do @@ -801,14 +798,14 @@ local function check(arg) print("Done.\n"); end - if not what or what == "dns" then - local dns = require "net.dns"; + function checks.dns() + local dns = require "prosody.net.dns"; pcall(function () - local unbound = require"net.unbound"; + local unbound = require"prosody.net.unbound"; dns = unbound.dns; end) - local idna = require "util.encodings".idna; - local ip = require "util.ip"; + local idna = require "prosody.util.encodings".idna; + local ip = require "prosody.util.ip"; local global = api("*"); local c2s_ports = global:get_option_set("c2s_ports", {5222}); local s2s_ports = global:get_option_set("s2s_ports", {5269}); @@ -871,7 +868,7 @@ local function check(arg) end end - local local_addresses = require"util.net".local_addresses() or {}; + local local_addresses = require"prosody.util.net".local_addresses() or {}; for addr in it.values(local_addresses) do if not ip.new_ip(addr).private then @@ -1038,9 +1035,6 @@ local function check(arg) end local known_http_modules = set.new { "bosh"; "http_files"; "http_file_share"; "http_openmetrics"; "websocket" }; - local function contains_match(hayset, needle) - for member in hayset do if member:find(needle) then return true end end - end if modules:contains("http") or not set.intersection(modules, known_http_modules):empty() or contains_match(modules, "^http_") or contains_match(modules, "_web$") then @@ -1176,11 +1170,14 @@ local function check(arg) ok = false; end end - if not what or what == "certs" then + function checks.certs() local cert_ok; print"Checking certificates..." - local x509_verify_identity = require"util.x509".verify_identity; - local create_context = require "core.certmanager".create_context; + local x509_verify_identity = require"prosody.util.x509".verify_identity; + local use_dane = configmanager.get("*", "use_dane"); + local pem2der = require"prosody.util.x509".pem2der; + local sha256 = require"prosody.util.hashes".sha256; + local create_context = require "prosody.core.certmanager".create_context; local ssl = dependencies.softreq"ssl"; -- local datetime_parse = require"util.datetime".parse_x509; local load_cert = ssl and ssl.loadcertificate; @@ -1193,13 +1190,14 @@ local function check(arg) cert_ok = false else for host in it.filter(skip_bare_jid_hosts, enabled_hosts()) do + local modules = modulemanager.get_modules_for_host(host); print("Checking certificate for "..host); -- First, let's find out what certificate this host uses. local host_ssl_config = configmanager.rawget(host, "ssl") or configmanager.rawget(host:match("%.(.*)"), "ssl"); local global_ssl_config = configmanager.rawget("*", "ssl"); - local ok, err, ssl_config = create_context(host, "server", host_ssl_config, global_ssl_config); - if not ok then + local ctx_ok, err, ssl_config = create_context(host, "server", host_ssl_config, global_ssl_config); + if not ctx_ok then print(" Error: "..err); cert_ok = false elseif not ssl_config.certificate then @@ -1234,17 +1232,39 @@ local function check(arg) elseif not cert:validat(os.time() + 86400*31) then print(" Certificate expires within one month.") end - if select(2, modulemanager.get_modules_for_host(host)) == nil - and not x509_verify_identity(host, "_xmpp-client", cert) then + if modules:contains("c2s") and not x509_verify_identity(host, "_xmpp-client", cert) then print(" Not valid for client connections to "..host..".") cert_ok = false end - if (not (api(host):get_option_boolean("anonymous_login", false) - or api(host):get_option_string("authentication", "internal_hashed") == "anonymous")) - and not x509_verify_identity(host, "_xmpp-server", cert) then + local anon = api(host):get_option_string("authentication", "internal_hashed") == "anonymous"; + local anon_s2s = api(host):get_option_boolean("allow_anonymous_s2s", false); + if modules:contains("s2s") and (anon_s2s or not anon) and not x509_verify_identity(host, "_xmpp-server", cert) then print(" Not valid for server-to-server connections to "..host..".") cert_ok = false end + + local known_http_modules = set.new { "bosh"; "http_files"; "http_file_share"; "http_openmetrics"; "websocket" }; + local http_loaded = modules:contains("http") + or not set.intersection(modules, known_http_modules):empty() + or contains_match(modules, "^http_") + or contains_match(modules, "_web$"); + + local http_host = api(host):get_option_string("http_host", host); + if api(host):get_option_string("http_external_url") then + -- Assumed behind a reverse proxy + http_loaded = false; + end + if http_loaded and not x509_verify_identity(http_host, nil, cert) then + print(" Not valid for HTTPS connections to "..host..".") + cert_ok = false + end + if use_dane then + if cert.pubkey then + print(" DANE: TLSA 3 1 1 "..sha256(pem2der(cert:pubkey()), true)) + elseif cert.pem then + print(" DANE: TLSA 3 0 1 "..sha256(pem2der(cert:pem()), true)) + end + end end end end @@ -1257,7 +1277,7 @@ local function check(arg) print("") end -- intentionally not doing this by default - if what == "connectivity" then + function checks.connectivity() local _, prosody_is_running = is_prosody_running(); if api("*"):get_option_string("pidfile") and not prosody_is_running then print("Prosody does not appear to be running, which is required for this test."); @@ -1349,7 +1369,7 @@ local function check(arg) print("Note: It does not ensure that the check actually reaches this specific prosody instance.") end - if not what or what == "turn" then + function checks.turn() local turn_enabled_hosts = {}; local turn_services = {}; @@ -1424,6 +1444,26 @@ local function check(arg) end end end + if what == nil or what == "all" then + local ret; + ret = checks.disabled(); + if ret ~= nil then return ret; end + ret = checks.config(); + if ret ~= nil then return ret; end + ret = checks.dns(); + if ret ~= nil then return ret; end + ret = checks.certs(); + if ret ~= nil then return ret; end + ret = checks.turn(); + if ret ~= nil then return ret; end + elseif checks[what] then + local ret = checks[what](); + if ret ~= nil then return ret; end + else + show_warning("Don't know how to check '%s'. Try one of 'config', 'dns', 'certs', 'disabled', 'turn' or 'connectivity'.", what); + show_warning("Note: The connectivity check will connect to a remote server."); + return 1; + end if not ok then print("Problems found, see above."); diff --git a/util/prosodyctl/shell.lua b/util/prosodyctl/shell.lua index 8cf7df69..05f81f15 100644 --- a/util/prosodyctl/shell.lua +++ b/util/prosodyctl/shell.lua @@ -1,13 +1,15 @@ -local config = require "core.configmanager"; -local server = require "net.server"; -local st = require "util.stanza"; -local path = require "util.paths"; -local parse_args = require "util.argparse".parse; -local unpack = table.unpack or _G.unpack; +local config = require "prosody.core.configmanager"; +local server = require "prosody.net.server"; +local st = require "prosody.util.stanza"; +local path = require "prosody.util.paths"; +local parse_args = require "prosody.util.argparse".parse; +local tc = require "prosody.util.termcolours"; +local isatty = require "prosody.util.pposix".isatty; +local term_width = require"prosody.util.human.io".term_width; local have_readline, readline = pcall(require, "readline"); -local adminstream = require "util.adminstream"; +local adminstream = require "prosody.util.adminstream"; if have_readline then readline.set_readline_name("prosody"); @@ -27,7 +29,7 @@ local function read_line(prompt_string) end local function send_line(client, line) - client.send(st.stanza("repl-input"):text(line)); + client.send(st.stanza("repl-input", { width = tostring(term_width()) }):text(line)); end local function repl(client) @@ -64,6 +66,7 @@ end local function start(arg) --luacheck: ignore 212/arg local client = adminstream.client(); local opts, err, where = parse_args(arg); + local ttyout = isatty(io.stdout); if not opts then if err == "param-not-found" then @@ -76,24 +79,36 @@ local function start(arg) --luacheck: ignore 212/arg if arg[1] then if arg[2] then - -- prosodyctl shell module reload foo bar.com --> module:reload("foo", "bar.com") - -- COMPAT Lua 5.1 doesn't have the separator argument to string.rep - arg[1] = string.format("%s:%s("..string.rep("%q, ", #arg-2):sub(1, -3)..")", unpack(arg)); + local fmt = { "%s"; ":%s("; ")" }; + for i = 3, #arg do + if arg[i]:sub(1, 1) == ":" then + table.insert(fmt, i, ")%s("); + elseif i > 3 and fmt[i - 1]:match("%%q$") then + table.insert(fmt, i, ", %q"); + else + table.insert(fmt, i, "%q"); + end + end + arg[1] = string.format(table.concat(fmt), table.unpack(arg)); end client.events.add_handler("connected", function() - client.send(st.stanza("repl-input"):text(arg[1])); + send_line(client, arg[1]); return true; end, 1); local errors = 0; -- TODO This is weird, but works for now. client.events.add_handler("received", function(stanza) if stanza.name == "repl-output" or stanza.name == "repl-result" then + local dest = io.stdout; if stanza.attr.type == "error" then errors = errors + 1; - io.stderr:write(stanza:get_text(), "\n"); + dest = io.stderr; + end + if stanza.attr.eol == "0" then + dest:write(stanza:get_text()); else - print(stanza:get_text()); + dest:write(stanza:get_text(), "\n"); end end if stanza.name == "repl-result" then @@ -118,7 +133,11 @@ local function start(arg) --luacheck: ignore 212/arg client.events.add_handler("received", function (stanza) if stanza.name == "repl-output" or stanza.name == "repl-result" then local result_prefix = stanza.attr.type == "error" and "!" or "|"; - print(result_prefix.." "..stanza:get_text()); + local out = result_prefix.." "..stanza:get_text(); + if ttyout and stanza.attr.type == "error" then + out = tc.getstring(tc.getstyle("red"), out); + end + print(out); end if stanza.name == "repl-result" then repl(client); diff --git a/util/pubsub.lua b/util/pubsub.lua index acb34db9..d6779736 100644 --- a/util/pubsub.lua +++ b/util/pubsub.lua @@ -1,6 +1,5 @@ -local events = require "util.events"; -local cache = require "util.cache"; -local errors = require "util.error"; +local events = require "prosody.util.events"; +local cache = require "prosody.util.cache"; local service_mt = {}; @@ -12,6 +11,7 @@ local default_config = { itemcheck = function () return true; end; get_affiliation = function () end; normalize_jid = function (jid) return jid; end; + metadata_subset = {}; capabilities = { outcast = { create = false; @@ -46,6 +46,7 @@ local default_config = { get_subscription = true; get_subscriptions = true; get_items = false; + get_metadata = true; subscribe_other = false; unsubscribe_other = false; @@ -68,6 +69,7 @@ local default_config = { get_subscription = true; get_subscriptions = true; get_items = true; + get_metadata = true; subscribe_other = false; unsubscribe_other = false; @@ -91,6 +93,7 @@ local default_config = { get_subscription = true; get_subscriptions = true; get_items = true; + get_metadata = true; subscribe_other = false; unsubscribe_other = false; @@ -116,6 +119,7 @@ local default_config = { get_subscription = true; get_subscriptions = true; get_items = true; + get_metadata = true; subscribe_other = true; @@ -263,7 +267,7 @@ function service:get_default_affiliation(node, actor) --> affiliation if self.config.access_models then local check = self.config.access_models[access_model]; if check then - local aff = check(actor); + local aff = check(actor, node_obj); if aff then return aff; end @@ -562,11 +566,7 @@ function service:publish(node, actor, id, item, requested_config) --> ok, err -- Check that node has the requested config before we publish local ok, field = check_preconditions(node_obj.config, requested_config); if not ok then - local err = errors.new({ - type = "cancel", condition = "conflict", text = "Field does not match: "..field; - }); - err.pubsub_condition = "precondition-not-met"; - return false, err; + return false, "precondition-not-met", { field = field }; end end if not self.config.itemcheck(item) then @@ -877,6 +877,20 @@ function service:get_node_config(node, actor) --> (true, config) or (false, err) return true, config_table; end +function service:get_node_metadata(node, actor) + if not self:may(node, actor, "get_metadata") then + return false, "forbidden"; + end + + local ok, config = self:get_node_config(node, true); + if not ok then return ok, config; end + local meta = {}; + for _, k in ipairs(self.config.metadata_subset) do + meta[k] = config[k]; + end + return true, meta; +end + return { new = new; }; diff --git a/util/queue.lua b/util/queue.lua index c94c62ae..8cff944a 100644 --- a/util/queue.lua +++ b/util/queue.lua @@ -9,7 +9,7 @@ -- Small ringbuffer library (i.e. an efficient FIFO queue with a size limit) -- (because unbounded dynamically-growing queues are a bad thing...) -local have_utable, utable = pcall(require, "util.table"); -- For pre-allocation of table +local have_utable, utable = pcall(require, "prosody.util.table"); -- For pre-allocation of table local function new(size, allow_wrapping) -- Head is next insert, tail is next read diff --git a/util/random.lua b/util/random.lua index 3305172f..a3adb605 100644 --- a/util/random.lua +++ b/util/random.lua @@ -6,7 +6,7 @@ -- COPYING file in the source package for more information. -- -local ok, crand = pcall(require, "util.crand"); +local ok, crand = pcall(require, "prosody.util.crand"); if ok and pcall(crand.bytes, 1) then return crand; end local urandom, urandom_err = io.open("/dev/urandom", "r"); diff --git a/util/rfc6724.lua b/util/rfc6724.lua deleted file mode 100644 index 81f78d55..00000000 --- a/util/rfc6724.lua +++ /dev/null @@ -1,141 +0,0 @@ --- Prosody IM --- Copyright (C) 2011-2013 Florian Zeitz --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - --- This is used to sort destination addresses by preference --- during S2S connections. --- We can't hand this off to getaddrinfo, since it blocks - -local ip_commonPrefixLength = require"util.ip".commonPrefixLength - -local function commonPrefixLength(ipA, ipB) - local len = ip_commonPrefixLength(ipA, ipB); - return len < 64 and len or 64; -end - -local function t_sort(t, comp) - for i = 1, (#t - 1) do - for j = (i + 1), #t do - local a, b = t[i], t[j]; - if not comp(a,b) then - t[i], t[j] = b, a; - end - end - end -end - -local function source(dest, candidates) - local function comp(ipA, ipB) - -- Rule 1: Prefer same address - if dest == ipA then - return true; - elseif dest == ipB then - return false; - end - - -- Rule 2: Prefer appropriate scope - if ipA.scope < ipB.scope then - if ipA.scope < dest.scope then - return false; - else - return true; - end - elseif ipA.scope > ipB.scope then - if ipB.scope < dest.scope then - return true; - else - return false; - end - end - - -- Rule 3: Avoid deprecated addresses - -- XXX: No way to determine this - -- Rule 4: Prefer home addresses - -- XXX: Mobility Address related, no way to determine this - -- Rule 5: Prefer outgoing interface - -- XXX: Interface to address relation. No way to determine this - -- Rule 6: Prefer matching label - if ipA.label == dest.label and ipB.label ~= dest.label then - return true; - elseif ipB.label == dest.label and ipA.label ~= dest.label then - return false; - end - - -- Rule 7: Prefer temporary addresses (over public ones) - -- XXX: No way to determine this - -- Rule 8: Use longest matching prefix - if commonPrefixLength(ipA, dest) > commonPrefixLength(ipB, dest) then - return true; - else - return false; - end - end - - t_sort(candidates, comp); - return candidates[1]; -end - -local function destination(candidates, sources) - local sourceAddrs = {}; - local function comp(ipA, ipB) - local ipAsource = sourceAddrs[ipA]; - local ipBsource = sourceAddrs[ipB]; - -- Rule 1: Avoid unusable destinations - -- XXX: No such information - -- Rule 2: Prefer matching scope - if ipA.scope == ipAsource.scope and ipB.scope ~= ipBsource.scope then - return true; - elseif ipA.scope ~= ipAsource.scope and ipB.scope == ipBsource.scope then - return false; - end - - -- Rule 3: Avoid deprecated addresses - -- XXX: No way to determine this - -- Rule 4: Prefer home addresses - -- XXX: Mobility Address related, no way to determine this - -- Rule 5: Prefer matching label - if ipAsource.label == ipA.label and ipBsource.label ~= ipB.label then - return true; - elseif ipBsource.label == ipB.label and ipAsource.label ~= ipA.label then - return false; - end - - -- Rule 6: Prefer higher precedence - if ipA.precedence > ipB.precedence then - return true; - elseif ipA.precedence < ipB.precedence then - return false; - end - - -- Rule 7: Prefer native transport - -- XXX: No way to determine this - -- Rule 8: Prefer smaller scope - if ipA.scope < ipB.scope then - return true; - elseif ipA.scope > ipB.scope then - return false; - end - - -- Rule 9: Use longest matching prefix - if commonPrefixLength(ipA, ipAsource) > commonPrefixLength(ipB, ipBsource) then - return true; - elseif commonPrefixLength(ipA, ipAsource) < commonPrefixLength(ipB, ipBsource) then - return false; - end - - -- Rule 10: Otherwise, leave order unchanged - return true; - end - for _, ip in ipairs(candidates) do - sourceAddrs[ip] = source(ip, sources); - end - - t_sort(candidates, comp); - return candidates; -end - -return {source = source, - destination = destination}; diff --git a/util/roles.lua b/util/roles.lua new file mode 100644 index 00000000..e7f22c12 --- /dev/null +++ b/util/roles.lua @@ -0,0 +1,123 @@ +local array = require "prosody.util.array"; +local it = require "prosody.util.iterators"; +local new_short_id = require "prosody.util.id".short; + +local role_methods = {}; +local role_mt = { + __index = role_methods; + __name = "role"; + __add = nil; +}; + +local function is_role(o) + local mt = getmetatable(o); + return mt == role_mt; +end + +local function _new_may(permissions, inherited_mays) + local n_inherited = inherited_mays and #inherited_mays; + return function (role, action, context) + -- Note: 'role' may be a descendent role, not only the one we're attached to + local policy = permissions[action]; + if policy ~= nil then + return policy; + end + if n_inherited then + for i = 1, n_inherited do + policy = inherited_mays[i](role, action, context); + if policy ~= nil then + return policy; + end + end + end + return nil; + end +end + +local permissions_key = {}; + +-- { +-- Required: +-- name = "My fancy role"; +-- +-- Optional: +-- inherits = { role_obj... } +-- default = true +-- priority = 100 +-- permissions = { +-- ["foo"] = true; -- allow +-- ["bar"] = false; -- deny +-- } +-- } +local function new(base_config, overrides) + local config = setmetatable(overrides or {}, { __index = base_config }); + local permissions = {}; + local inherited_mays; + if config.inherits then + inherited_mays = array.pluck(config.inherits, "may"); + end + local new_role = { + id = new_short_id(); + name = config.name; + description = config.description; + default = config.default; + priority = config.priority; + may = _new_may(permissions, inherited_mays); + inherits = config.inherits; + [permissions_key] = permissions; + }; + local desired_permissions = config.permissions or config[permissions_key]; + for k, v in pairs(desired_permissions or {}) do + permissions[k] = v; + end + return setmetatable(new_role, role_mt); +end + +function role_mt:__freeze() + local t = { + id = self.id; + name = self.name; + description = self.description; + default = self.default; + priority = self.priority; + inherits = self.inherits; + permissions = self[permissions_key]; + }; + return t; +end + +function role_methods:clone(overrides) + return new(self, overrides); +end + +function role_methods:set_permission(permission_name, policy, overwrite) + local permissions = self[permissions_key]; + if overwrite ~= true and permissions[permission_name] ~= nil and permissions[permission_name] ~= policy then + return false, "policy-already-exists"; + end + permissions[permission_name] = policy; + return true; +end + +function role_methods:policies() + local policy_iterator, s, v = it.join(pairs(self[permissions_key])); + if self.inherits then + for _, inherited_role in ipairs(self.inherits) do + policy_iterator:append(inherited_role:policies()); + end + end + return policy_iterator, s, v; +end + +function role_mt.__tostring(self) + return ("role<[%s] %s>"):format(self.id or "nil", self.name or "[no name]"); +end + +function role_mt.__pairs(self) + return it.filter(permissions_key, next, self); +end + +return { + is_role = is_role; + new = new; +}; diff --git a/util/rsm.lua b/util/rsm.lua index e6060af8..fc6a048d 100644 --- a/util/rsm.lua +++ b/util/rsm.lua @@ -9,7 +9,7 @@ -- XEP-0313: Message Archive Management for Prosody -- -local stanza = require"util.stanza".stanza; +local stanza = require"prosody.util.stanza".stanza; local tonumber = tonumber; local s_format = string.format; local type = type; diff --git a/util/sasl.lua b/util/sasl.lua index 528743d1..c3c22a1c 100644 --- a/util/sasl.lua +++ b/util/sasl.lua @@ -47,7 +47,7 @@ local registered_mechanisms = {}; local backend_mechanism = {}; local mechanism_channelbindings = {}; --- register a new SASL mechanisms +-- register a new SASL mechanism local function registerMechanism(name, backends, f, cb_backends) assert(type(name) == "string", "Parameter name MUST be a string."); assert(type(backends) == "string" or type(backends) == "table", "Parameter backends MUST be either a string or a table."); @@ -133,10 +133,11 @@ function method:process(message) end -- load the mechanisms -require "util.sasl.plain" .init(registerMechanism); -require "util.sasl.anonymous" .init(registerMechanism); -require "util.sasl.scram" .init(registerMechanism); -require "util.sasl.external" .init(registerMechanism); +require "prosody.util.sasl.plain" .init(registerMechanism); +require "prosody.util.sasl.anonymous" .init(registerMechanism); +require "prosody.util.sasl.oauthbearer" .init(registerMechanism); +require "prosody.util.sasl.scram" .init(registerMechanism); +require "prosody.util.sasl.external" .init(registerMechanism); return { registerMechanism = registerMechanism; diff --git a/util/sasl/anonymous.lua b/util/sasl/anonymous.lua index de98a5e2..be2c20d4 100644 --- a/util/sasl/anonymous.lua +++ b/util/sasl/anonymous.lua @@ -12,7 +12,7 @@ -- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -local generate_random_id = require "util.id".medium; +local generate_random_id = require "prosody.util.id".medium; local _ENV = nil; -- luacheck: std none @@ -33,8 +33,8 @@ local function anonymous(self, message) -- luacheck: ignore 212/message local username; repeat username = generate_random_id():lower(); - until self.profile.anonymous(self, username, self.realm); - self.username = username; + self.username = username; + until self.profile.anonymous(self, username, self.realm, message); return "success" end diff --git a/util/sasl/external.lua b/util/sasl/external.lua index ce50743e..c3a3beb8 100644 --- a/util/sasl/external.lua +++ b/util/sasl/external.lua @@ -1,4 +1,4 @@ -local saslprep = require "util.encodings".stringprep.saslprep; +local saslprep = require "prosody.util.encodings".stringprep.saslprep; local _ENV = nil; -- luacheck: std none diff --git a/util/sasl/oauthbearer.lua b/util/sasl/oauthbearer.lua new file mode 100644 index 00000000..0a2fe9dd --- /dev/null +++ b/util/sasl/oauthbearer.lua @@ -0,0 +1,62 @@ +local json = require "prosody.util.json"; +local _ENV = nil; + + +local function oauthbearer(self, message) + if not message then + return "failure", "malformed-request"; + end + + if message == "\001" then + return "failure", "not-authorized"; + end + + -- gs2-header kvsep *kvpair kvsep + local gs2_header, kvpairs = message:match("^(n,[^,]*,)\001(.+)\001$"); + if not gs2_header then + return "failure", "malformed-request"; + end + local gs2_authzid = gs2_header:match("^[^,]*,a=([^,]*),$"); + + -- key "=" value kvsep + local auth_header; + for k, v in kvpairs:gmatch("([a-zA-Z]+)=([\033-\126 \009\r\n]*)\001") do + if k == "auth" then + auth_header = v; + break; + end + end + + if not auth_header then + return "failure", "malformed-request"; + end + + local token = auth_header:match("^Bearer (.+)$"); + + local username, state, token_info = self.profile.oauthbearer(self, token, self.realm, gs2_authzid); + + if state == false then + return "failure", "account-disabled"; + elseif state == nil or not username then + -- For token-level errors, RFC 7628 demands use of a JSON-encoded + -- challenge response upon failure. We relay additional info from + -- the auth backend if available. + return "challenge", json.encode({ + status = token_info and token_info.status or "invalid_token"; + scope = token_info and token_info.scope or nil; + ["openid-configuration"] = token_info and token_info.oidc_discovery_url or nil; + }); + end + self.username = username; + self.token_info = token_info; + + return "success"; +end + +local function init(registerMechanism) + registerMechanism("OAUTHBEARER", {"oauthbearer"}, oauthbearer); +end + +return { + init = init; +} diff --git a/util/sasl/plain.lua b/util/sasl/plain.lua index 43a66c5b..da867fb1 100644 --- a/util/sasl/plain.lua +++ b/util/sasl/plain.lua @@ -12,9 +12,9 @@ -- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. local s_match = string.match; -local saslprep = require "util.encodings".stringprep.saslprep; -local nodeprep = require "util.encodings".stringprep.nodeprep; -local log = require "util.logger".init("sasl"); +local saslprep = require "prosody.util.encodings".stringprep.saslprep; +local nodeprep = require "prosody.util.encodings".stringprep.nodeprep; +local log = require "prosody.util.logger".init("sasl"); local _ENV = nil; -- luacheck: std none @@ -69,10 +69,10 @@ local function plain(self, message) local correct, state = false, false; if self.profile.plain then local correct_password; - correct_password, state = self.profile.plain(self, authentication, self.realm); + correct_password, state = self.profile.plain(self, authentication, self.realm, authorization); correct = (saslprep(correct_password) == password); elseif self.profile.plain_test then - correct, state = self.profile.plain_test(self, authentication, password, self.realm); + correct, state = self.profile.plain_test(self, authentication, password, self.realm, authorization); end if state == false then diff --git a/util/sasl/scram.lua b/util/sasl/scram.lua index 37abf4a4..ad279999 100644 --- a/util/sasl/scram.lua +++ b/util/sasl/scram.lua @@ -13,13 +13,13 @@ local s_match = string.match; local type = type -local base64 = require "util.encodings".base64; -local hashes = require "util.hashes"; -local generate_uuid = require "util.uuid".generate; -local saslprep = require "util.encodings".stringprep.saslprep; -local nodeprep = require "util.encodings".stringprep.nodeprep; -local log = require "util.logger".init("sasl"); -local binaryXOR = require "util.strbitop".sxor; +local base64 = require "prosody.util.encodings".base64; +local hashes = require "prosody.util.hashes"; +local generate_uuid = require "prosody.util.uuid".generate; +local saslprep = require "prosody.util.encodings".stringprep.saslprep; +local nodeprep = require "prosody.util.encodings".stringprep.nodeprep; +local log = require "prosody.util.logger".init("sasl"); +local binaryXOR = require "prosody.util.strbitop".sxor; local _ENV = nil; -- luacheck: std none @@ -101,7 +101,6 @@ local function scram_gen(hash_name, H_f, HMAC_f, get_auth_db, expect_cb) local client_first_message = message; -- TODO: fail if authzid is provided, since we don't support them yet - -- luacheck: ignore 211/authzid local gs2_header, gs2_cbind_flag, gs2_cbind_name, authzid, client_first_message_bare, username, clientnonce = s_match(client_first_message, "^(([pny])=?([^,]*),([^,]*),)(m?=?[^,]*,?n=([^,]*),r=([^,]*),?.*)$"); @@ -112,8 +111,8 @@ local function scram_gen(hash_name, H_f, HMAC_f, get_auth_db, expect_cb) if support_channel_binding and gs2_cbind_flag == "y" then -- "y" -> client does support channel binding -- but thinks the server does not. - return "failure", "malformed-request"; - end + return "failure", "malformed-request"; + end if gs2_cbind_flag == "n" then -- "n" -> client doesn't support channel binding. @@ -144,7 +143,7 @@ local function scram_gen(hash_name, H_f, HMAC_f, get_auth_db, expect_cb) -- retrieve credentials local stored_key, server_key, salt, iteration_count; if self.profile.plain then - local password, status = self.profile.plain(self, username, self.realm) + local password, status = self.profile.plain(self, username, self.realm, authzid) if status == nil then return "failure", "not-authorized" elseif status == false then return "failure", "account-disabled" end @@ -165,7 +164,7 @@ local function scram_gen(hash_name, H_f, HMAC_f, get_auth_db, expect_cb) end elseif self.profile[profile_name] then local status; - stored_key, server_key, iteration_count, salt, status = self.profile[profile_name](self, username, self.realm); + stored_key, server_key, iteration_count, salt, status = self.profile[profile_name](self, username, self.realm, authzid); if status == nil then return "failure", "not-authorized" elseif status == false then return "failure", "account-disabled" end end @@ -240,7 +239,7 @@ local function init(registerMechanism) -- register channel binding equivalent registerMechanism("SCRAM-"..hash_name.."-PLUS", {"plain", "scram_"..(hashprep(hash_name))}, - scram_gen(hash_name:lower(), hash, hmac_hash, get_auth_db, true), {"tls-unique"}); + scram_gen(hash_name:lower(), hash, hmac_hash, get_auth_db, true), {"tls-unique", "tls-exporter"}); end registerSCRAMMechanism("SHA-1", hashes.sha1, hashes.hmac_sha1, hashes.pbkdf2_hmac_sha1); diff --git a/util/serialization.lua b/util/serialization.lua index d310a3e8..ee4751e2 100644 --- a/util/serialization.lua +++ b/util/serialization.lua @@ -16,15 +16,17 @@ local s_char = string.char; local s_match = string.match; local t_concat = table.concat; -local to_hex = require "util.hex".to; +local to_hex = require "prosody.util.hex".to; local pcall = pcall; -local envload = require"util.envload".envload; +local envload = require"prosody.util.envload".envload; + +if not math.type then + require "prosody.util.mathcompat" +end local pos_inf, neg_inf = math.huge, -math.huge; -local m_type = math.type or function (n) - return n % 1 == 0 and n <= 9007199254740992 and n >= -9007199254740992 and "integer" or "float"; -end; +local m_type = math.type; local function rawpairs(t) return next, t, nil; @@ -94,6 +96,10 @@ local function new(opt) opt.itemlast = opt.itemlast or ""; opt.equals = opt.equals or "="; opt.unquoted = true; + elseif opt.preset == "pretty" then + opt.fatal = false; + opt.freeze = true; + opt.unquoted = true; end local fallback = opt.fallback or opt.fatal == false and nonfatal_fallback or fatal_error; diff --git a/util/session.lua b/util/session.lua index 25b22faf..2d3cbd54 100644 --- a/util/session.lua +++ b/util/session.lua @@ -1,10 +1,12 @@ -local initialize_filters = require "util.filters".initialize; -local logger = require "util.logger"; +local initialize_filters = require "prosody.util.filters".initialize; +local time = require "prosody.util.time"; +local logger = require "prosody.util.logger"; local function new_session(typ) local session = { type = typ .. "_unauthed"; base_type = typ; + since = time.now(); }; return session; end @@ -57,10 +59,16 @@ local function set_send(session) return session; end +local function set_role(session, role) + session.role = role; +end + return { new = new_session; + set_id = set_id; set_logger = set_logger; set_conn = set_conn; set_send = set_send; + set_role = set_role; } diff --git a/util/set.lua b/util/set.lua index 69dfef5d..6c860c3f 100644 --- a/util/set.lua +++ b/util/set.lua @@ -186,7 +186,7 @@ function set_mt.__tostring(set) for item in pairs(items) do s[#s+1] = tostring(item); end - return t_concat(s, ", "); + return "{"..t_concat(s, ", ").."}"; end return { diff --git a/util/smqueue.lua b/util/smqueue.lua index 6d8348d4..0514a979 100644 --- a/util/smqueue.lua +++ b/util/smqueue.lua @@ -1,4 +1,4 @@ -local queue = require("util.queue"); +local queue = require("prosody.util.queue"); local lib = { smqueue = {} } diff --git a/util/sql.lua b/util/sql.lua index 9d1c86ca..c897d734 100644 --- a/util/sql.lua +++ b/util/sql.lua @@ -4,9 +4,9 @@ local ipairs = ipairs; local tostring = tostring; local type = type; local assert, pcall, debug_traceback = assert, pcall, debug.traceback; -local xpcall = require "util.xpcall".xpcall; +local xpcall = require "prosody.util.xpcall".xpcall; local t_concat = table.concat; -local log = require "util.logger".init("sql"); +local log = require "prosody.util.logger".init("sql"); local DBI = require "DBI"; -- This loads all available drivers while globals are unlocked @@ -27,8 +27,6 @@ local function is_column(x) return getmetatable(x)==column_mt; end local function is_index(x) return getmetatable(x)==index_mt; end local function is_table(x) return getmetatable(x)==table_mt; end local function is_query(x) return getmetatable(x)==query_mt; end -local function Integer() return "Integer()" end -local function String() return "String()" end local function Column(definition) return setmetatable(definition, column_mt); @@ -99,6 +97,9 @@ end function engine:onconnect() -- luacheck: ignore 212/self -- Override from create_engine() end +function engine:ondisconnect() -- luacheck: ignore 212/self + -- Override from create_engine() +end function engine:prepquery(sql) if self.params.driver == "MySQL" then @@ -224,6 +225,7 @@ function engine:transaction(...) if not conn or not conn:ping() then log("debug", "Database connection was closed. Will reconnect and retry."); self.conn = nil; + self:ondisconnect(); log("debug", "Retrying SQL transaction [%s]", (...)); ok, ret, b, c = self:_transaction(...); log("debug", "SQL transaction retry %s", ok and "succeeded" or "failed"); @@ -365,8 +367,8 @@ local function db2uri(params) }; end -local function create_engine(_, params, onconnect) - return setmetatable({ url = db2uri(params), params = params, onconnect = onconnect }, engine_mt); +local function create_engine(_, params, onconnect, ondisconnect) + return setmetatable({ url = db2uri(params); params = params; onconnect = onconnect; ondisconnect = ondisconnect }, engine_mt); end return { @@ -374,8 +376,6 @@ return { is_index = is_index; is_table = is_table; is_query = is_query; - Integer = Integer; - String = String; Column = Column; Table = Table; Index = Index; diff --git a/util/sqlite3.lua b/util/sqlite3.lua new file mode 100644 index 00000000..470eb46d --- /dev/null +++ b/util/sqlite3.lua @@ -0,0 +1,370 @@ + +local setmetatable, getmetatable = setmetatable, getmetatable; +local ipairs, select = ipairs, select; +local tostring = tostring; +local assert, xpcall, debug_traceback = assert, xpcall, debug.traceback; +local error = error +local type = type +local t_concat = table.concat; +local array = require "prosody.util.array"; +local log = require "prosody.util.logger".init("sql"); + +local lsqlite3 = require "lsqlite3"; +local build_url = require "socket.url".build; + +-- from sqlite3.h, no copyright claimed +local sqlite_errors = require"prosody.util.error".init("util.sqlite3", { + -- FIXME xmpp error conditions? + [1] = { code = 1; type = "modify"; condition = "ERROR"; text = "Generic error" }; + [2] = { code = 2; type = "cancel"; condition = "INTERNAL"; text = "Internal logic error in SQLite" }; + [3] = { code = 3; type = "auth"; condition = "PERM"; text = "Access permission denied" }; + [4] = { code = 4; type = "cancel"; condition = "ABORT"; text = "Callback routine requested an abort" }; + [5] = { code = 5; type = "wait"; condition = "BUSY"; text = "The database file is locked" }; + [6] = { code = 6; type = "wait"; condition = "LOCKED"; text = "A table in the database is locked" }; + [7] = { code = 7; type = "wait"; condition = "NOMEM"; text = "A malloc() failed" }; + [8] = { code = 8; type = "cancel"; condition = "READONLY"; text = "Attempt to write a readonly database" }; + [9] = { code = 9; type = "cancel"; condition = "INTERRUPT"; text = "Operation terminated by sqlite3_interrupt()" }; + [10] = { code = 10; type = "wait"; condition = "IOERR"; text = "Some kind of disk I/O error occurred" }; + [11] = { code = 11; type = "cancel"; condition = "CORRUPT"; text = "The database disk image is malformed" }; + [12] = { code = 12; type = "modify"; condition = "NOTFOUND"; text = "Unknown opcode in sqlite3_file_control()" }; + [13] = { code = 13; type = "wait"; condition = "FULL"; text = "Insertion failed because database is full" }; + [14] = { code = 14; type = "auth"; condition = "CANTOPEN"; text = "Unable to open the database file" }; + [15] = { code = 15; type = "cancel"; condition = "PROTOCOL"; text = "Database lock protocol error" }; + [16] = { code = 16; type = "continue"; condition = "EMPTY"; text = "Internal use only" }; + [17] = { code = 17; type = "modify"; condition = "SCHEMA"; text = "The database schema changed" }; + [18] = { code = 18; type = "modify"; condition = "TOOBIG"; text = "String or BLOB exceeds size limit" }; + [19] = { code = 19; type = "modify"; condition = "CONSTRAINT"; text = "Abort due to constraint violation" }; + [20] = { code = 20; type = "modify"; condition = "MISMATCH"; text = "Data type mismatch" }; + [21] = { code = 21; type = "modify"; condition = "MISUSE"; text = "Library used incorrectly" }; + [22] = { code = 22; type = "cancel"; condition = "NOLFS"; text = "Uses OS features not supported on host" }; + [23] = { code = 23; type = "auth"; condition = "AUTH"; text = "Authorization denied" }; + [24] = { code = 24; type = "modify"; condition = "FORMAT"; text = "Not used" }; + [25] = { code = 25; type = "modify"; condition = "RANGE"; text = "2nd parameter to sqlite3_bind out of range" }; + [26] = { code = 26; type = "cancel"; condition = "NOTADB"; text = "File opened that is not a database file" }; + [27] = { code = 27; type = "continue"; condition = "NOTICE"; text = "Notifications from sqlite3_log()" }; + [28] = { code = 28; type = "continue"; condition = "WARNING"; text = "Warnings from sqlite3_log()" }; + [100] = { code = 100; type = "continue"; condition = "ROW"; text = "sqlite3_step() has another row ready" }; + [101] = { code = 101; type = "continue"; condition = "DONE"; text = "sqlite3_step() has finished executing" }; +}); + +-- luacheck: ignore 411/assert +local assert = function(cond, errno, err) + return assert(sqlite_errors.coerce(cond, err or errno)); +end +local _ENV = nil; +-- luacheck: std none + +local column_mt = {}; +local table_mt = {}; +local query_mt = {}; +--local op_mt = {}; +local index_mt = {}; + +local function is_column(x) return getmetatable(x)==column_mt; end +local function is_index(x) return getmetatable(x)==index_mt; end +local function is_table(x) return getmetatable(x)==table_mt; end +local function is_query(x) return getmetatable(x)==query_mt; end + +local function Column(definition) + return setmetatable(definition, column_mt); +end +local function Table(definition) + local c = {} + for i,col in ipairs(definition) do + if is_column(col) then + c[i], c[col.name] = col, col; + elseif is_index(col) then + col.table = definition.name; + end + end + return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt); +end +local function Index(definition) + return setmetatable(definition, index_mt); +end + +function table_mt:__tostring() + local s = { 'name="'..self.__table__.name..'"' } + for _, col in ipairs(self.__table__) do + s[#s+1] = tostring(col); + end + return 'Table{ '..t_concat(s, ", ")..' }' +end +table_mt.__index = {}; +function table_mt.__index:create(engine) + return engine:_create_table(self); +end +function column_mt:__tostring() + return 'Column{ name="'..self.name..'", type="'..self.type..'" }' +end +function index_mt:__tostring() + local s = 'Index{ name="'..self.name..'"'; + for i=1,#self do s = s..', "'..self[i]:gsub("[\\\"]", "\\%1")..'"'; end + return s..' }'; +-- return 'Index{ name="'..self.name..'", type="'..self.type..'" }' +end + +local engine = {}; +function engine:connect() + if self.conn then return true; end + + local params = self.params; + assert(params.driver == "SQLite3", "Only sqlite3 is supported"); + local dbh, err = sqlite_errors.coerce(lsqlite3.open(params.database)); + if not dbh then return nil, err; end + self.conn = dbh; + self.prepared = {}; + local ok, err = self:set_encoding(); + if not ok then + return ok, err; + end + local ok, err = self:onconnect(); + if ok == false then + return ok, err; + end + return true; +end +function engine:onconnect() -- luacheck: ignore 212/self + -- Override from create_engine() +end +function engine:ondisconnect() -- luacheck: ignore 212/self + -- Override from create_engine() +end + +function engine:execute(sql, ...) + local success, err = self:connect(); + if not success then return success, err; end + + if select('#', ...) == 0 then + local ret = self.conn:exec(sql); + if ret ~= lsqlite3.OK then + local err = sqlite_errors.new(err); + err.text = self.conn:errmsg(); + return err; + end + return true; + end + + local stmt, err = self.conn:prepare(sql); + if not stmt then + err = sqlite_errors.new(err); + err.text = self.conn:errmsg(); + return stmt, err; + end + + local ret = stmt:bind_values(...); + if ret ~= lsqlite3.OK then + return nil, sqlite_errors.new(ret, { message = self.conn:errmsg() }); + end + return stmt; +end + +local function iterator(table) + local i = 0; + return function() + i = i + 1; + local item = table[i]; + if item ~= nil then + return item; + end + end +end + +local result_mt = { + __len = function(self) + return self.__rowcount; + end; + __index = { + affected = function(self) + return self.__affected; + end; + rowcount = function(self) + return self.__rowcount; + end; + }; + __call = function(self) + return iterator(self.__data); + end; +}; + +local function debugquery(where, sql, ...) + local i = 0; local a = {...} + sql = sql:gsub("\n?\t+", " "); + log("debug", "[%s] %s", where, (sql:gsub("%?", function () + i = i + 1; + local v = a[i]; + if type(v) == "string" then + v = ("'%s'"):format(v:gsub("'", "''")); + end + return tostring(v); + end))); +end + +function engine:execute_update(sql, ...) + local prepared = self.prepared; + local stmt = prepared[sql]; + if stmt and stmt:isopen() then + prepared[sql] = nil; -- Can't be used concurrently + else + stmt = assert(self.conn:prepare(sql)); + end + local ret = stmt:bind_values(...); + if ret ~= lsqlite3.OK then error(self.conn:errmsg()); end + local data = array(); + for row in stmt:rows() do + data:push(array(row)); + end + -- FIXME Error handling, BUSY, ERROR, MISUSE + if stmt:reset() == lsqlite3.OK then + prepared[sql] = stmt; + end + local affected = self.conn:changes(); + return setmetatable({ __affected = affected; __rowcount = #data; __data = data }, result_mt); +end + +function engine:execute_query(sql, ...) + return self:execute_update(sql, ...)() +end + +engine.insert = engine.execute_update; +engine.select = engine.execute_query; +engine.delete = engine.execute_update; +engine.update = engine.execute_update; +local function debugwrap(name, f) + return function (self, sql, ...) + debugquery(name, sql, ...) + return f(self, sql, ...) + end +end +function engine:debug(enable) + self._debug = enable; + if enable then + engine.insert = debugwrap("insert", engine.execute_update); + engine.select = debugwrap("select", engine.execute_query); + engine.delete = debugwrap("delete", engine.execute_update); + engine.update = debugwrap("update", engine.execute_update); + else + engine.insert = engine.execute_update; + engine.select = engine.execute_query; + engine.delete = engine.execute_update; + engine.update = engine.execute_update; + end +end +function engine:_(word) + local ret = self.conn:exec(word); + if ret ~= lsqlite3.OK then return nil, self.conn:errmsg(); end + return true; +end +function engine:_transaction(func, ...) + if not self.conn then + local a,b = self:connect(); + if not a then return a,b; end + end + --assert(not self.__transaction, "Recursive transactions not allowed"); + local ok, err = self:_"BEGIN"; + if not ok then return ok, err; end + self.__transaction = true; + local success, a, b, c = xpcall(func, debug_traceback, ...); + self.__transaction = nil; + if success then + log("debug", "SQL transaction success [%s]", tostring(func)); + local ok, err = self:_"COMMIT"; + if not ok then return ok, err; end -- commit failed + return success, a, b, c; + else + log("debug", "SQL transaction failure [%s]: %s", tostring(func), a); + if self.conn then self:_"ROLLBACK"; end + return success, a; + end +end +function engine:transaction(...) + local ok, ret = self:_transaction(...); + if not ok then + local conn = self.conn; + if not conn or not conn:isopen() then + self.conn = nil; + self:ondisconnect(); + ok, ret = self:_transaction(...); + end + end + return ok, ret; +end +function engine:_create_index(index) + local sql = "CREATE INDEX IF NOT EXISTS \""..index.name.."\" ON \""..index.table.."\" ("; + for i=1,#index do + sql = sql.."\""..index[i].."\""; + if i ~= #index then sql = sql..", "; end + end + sql = sql..");" + if index.unique then + sql = sql:gsub("^CREATE", "CREATE UNIQUE"); + end + if self._debug then + debugquery("create", sql); + end + return self:execute(sql); +end +function engine:_create_table(table) + local sql = "CREATE TABLE IF NOT EXISTS \""..table.name.."\" ("; + for i,col in ipairs(table.c) do + local col_type = col.type; + sql = sql.."\""..col.name.."\" "..col_type; + if col.nullable == false then sql = sql.." NOT NULL"; end + if col.primary_key == true then sql = sql.." PRIMARY KEY"; end + if col.auto_increment == true then + sql = sql.." AUTOINCREMENT"; + end + if i ~= #table.c then sql = sql..", "; end + end + sql = sql.. ");" + if self._debug then + debugquery("create", sql); + end + local success,err = self:execute(sql); + if not success then return success,err; end + for _, v in ipairs(table.__table__) do + if is_index(v) then + self:_create_index(v); + end + end + return success; +end + +function engine:set_encoding() -- to UTF-8 + return self:transaction(function() + for encoding in self:select "PRAGMA encoding;" do + if encoding[1] == "UTF-8" then + self.charset = "utf8"; + end + end + end); +end +local engine_mt = { __index = engine }; + +local function db2uri(params) + return build_url{ + scheme = params.driver, + user = params.username, + password = params.password, + host = params.host, + port = params.port, + path = params.database, + }; +end + +local function create_engine(_, params, onconnect, ondisconnect) + assert(params.driver == "SQLite3", "Only SQLite3 is supported without LuaDBI"); + return setmetatable({ url = db2uri(params); params = params; onconnect = onconnect; ondisconnect = ondisconnect }, engine_mt); +end + +return { + is_column = is_column; + is_index = is_index; + is_table = is_table; + is_query = is_query; + Column = Column; + Table = Table; + Index = Index; + create_engine = create_engine; + db2uri = db2uri; +}; diff --git a/util/sslconfig.lua b/util/sslconfig.lua index 6074a1fb..01a8adb5 100644 --- a/util/sslconfig.lua +++ b/util/sslconfig.lua @@ -3,9 +3,12 @@ local type = type; local pairs = pairs; local rawset = rawset; +local rawget = rawget; +local error = error; local t_concat = table.concat; local t_insert = table.insert; local setmetatable = setmetatable; +local resolve_path = require"prosody.util.paths".resolve_relative_path; local _ENV = nil; -- luacheck: std none @@ -34,7 +37,7 @@ function handlers.options(config, field, new) options[value] = true; end end - config[field] = options; + rawset(config, field, options) end handlers.verifyext = handlers.options; @@ -70,6 +73,30 @@ finalisers.curveslist = finalisers.ciphers; -- TLS 1.3 ciphers finalisers.ciphersuites = finalisers.ciphers; +-- Path expansion +function finalisers.key(path, config) + if type(path) == "string" then + return resolve_path(config._basedir, path); + else + return nil + end +end +finalisers.certificate = finalisers.key; +finalisers.cafile = finalisers.key; +finalisers.capath = finalisers.key; + +function finalisers.dhparam(value, config) + if type(value) == "string" then + if value:sub(1, 10) == "-----BEGIN" then + -- literal value + return value; + else + -- assume a filename + return resolve_path(config._basedir, value); + end + end +end + -- protocol = "x" should enable only that protocol -- protocol = "x+" should enable x and later versions @@ -89,37 +116,81 @@ end -- Merge options from 'new' config into 'config' local function apply(config, new) + rawset(config, "_cache", nil); if type(new) == "table" then for field, value in pairs(new) do - (handlers[field] or rawset)(config, field, value); + -- exclude keys which are internal to the config builder + if field:sub(1, 1) ~= "_" then + (handlers[field] or rawset)(config, field, value); + end end end + return config end -- Finalize the config into the form LuaSec expects local function final(config) local output = { }; for field, value in pairs(config) do - output[field] = (finalisers[field] or id)(value); + -- exclude keys which are internal to the config builder + if field:sub(1, 1) ~= "_" then + output[field] = (finalisers[field] or id)(value, config); + end end -- Need to handle protocols last because it adds to the options list protocol(output); return output; end +local function build(config) + local cached = rawget(config, "_cache"); + if cached then + return cached, nil + end + + local ctx, err = rawget(config, "_context_factory")(config:final(), config); + if ctx then + rawset(config, "_cache", ctx); + end + return ctx, err +end + local sslopts_mt = { __index = { apply = apply; final = final; + build = build; }; + __newindex = function() + error("SSL config objects cannot be modified directly. Use :apply()") + end; }; -local function new() - return setmetatable({options={}}, sslopts_mt); + +-- passing basedir through everything is required to avoid sslconfig depending +-- on prosody.paths.config +local function new(context_factory, basedir) + return setmetatable({ + _context_factory = context_factory, + _basedir = basedir, + options={}, + }, sslopts_mt); end +local function clone(config) + local result = new(); + for k, v in pairs(config) do + -- note that we *do* copy the internal keys on clone -- we have to carry + -- both the factory and the cache with us + rawset(result, k, v); + end + return result +end + +sslopts_mt.__index.clone = clone; + return { apply = apply; final = final; - new = new; + _new = new; }; diff --git a/util/stanza.lua b/util/stanza.lua index 86b88169..2f4e964a 100644 --- a/util/stanza.lua +++ b/util/stanza.lua @@ -21,12 +21,15 @@ local type = type; local s_gsub = string.gsub; local s_sub = string.sub; local s_find = string.find; +local t_move = table.move or require "prosody.util.table".move; +local t_create = require"prosody.util.table".create; -local valid_utf8 = require "util.encodings".utf8.valid; +local valid_utf8 = require "prosody.util.encodings".utf8.valid; -local do_pretty_printing, termcolours = pcall(require, "util.termcolours"); +local do_pretty_printing, termcolours = pcall(require, "prosody.util.termcolours"); local xmlns_stanzas = "urn:ietf:params:xml:ns:xmpp-stanzas"; +local xmpp_stanzas_attr = { xmlns = xmlns_stanzas }; local _ENV = nil; -- luacheck: std none @@ -179,6 +182,14 @@ function stanza_mt:get_child_text(name, xmlns) return nil; end +function stanza_mt:get_child_attr(name, xmlns, attr) + local tag = self:get_child(name, xmlns); + if tag then + return tag.attr[attr]; + end + return nil; +end + function stanza_mt:child_with_name(name) for _, child in ipairs(self.tags) do if child.name == name then return child; end @@ -266,6 +277,13 @@ function stanza_mt:find(path) local xmlns, name, text; local char = s_sub(path, pos, pos); if char == "@" then + if s_sub(path, pos + 1, pos + 1) == "{" then + return self.attr[s_gsub(s_sub(path, pos+2), "}", "\1", 1)]; + end + local prefix, attr = s_match(path, "^([^:]+):(.*)", pos+1); + if prefix and self.namespaces and self.namespaces[prefix] then + return self.attr[self.namespaces[prefix] .. "\1" .. attr]; + end return self.attr[s_sub(path, pos + 1)]; elseif char == "{" then xmlns, pos = s_match(path, "^([^}]+)}()", pos + 1); @@ -283,25 +301,33 @@ function stanza_mt:find(path) end local function _clone(stanza, only_top) - local attr, tags = {}, {}; + local attr = {}; for k,v in pairs(stanza.attr) do attr[k] = v; end local old_namespaces, namespaces = stanza.namespaces; if old_namespaces then namespaces = {}; for k,v in pairs(old_namespaces) do namespaces[k] = v; end end - local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags }; + local tags, new; + if only_top then + tags = {}; + new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags }; + else + tags = t_create(#stanza.tags, 0); + new = t_create(#stanza, 4); + new.name = stanza.name; + new.attr = attr; + new.namespaces = namespaces; + new.tags = tags; + end + + setmetatable(new, stanza_mt); if not only_top then - for i=1,#stanza do - local child = stanza[i]; - if child.name then - child = _clone(child); - t_insert(tags, child); - end - t_insert(new, child); - end + t_move(stanza, 1, #stanza, 1, new); + t_move(stanza.tags, 1, #stanza.tags, 1, tags); + new:maptags(_clone); end - return setmetatable(new, stanza_mt); + return new; end local function clone(stanza, only_top) @@ -387,6 +413,33 @@ function stanza_mt.get_error(stanza) return error_type, condition or "undefined-condition", text, extra_tag; end +function stanza_mt.add_error(stanza, error_type, condition, error_message, error_by) + local extra; + if type(error_type) == "table" then -- an util.error or similar object + if type(error_type.extra) == "table" then + extra = error_type.extra; + end + if type(error_type.context) == "table" and type(error_type.context.by) == "string" then error_by = error_type.context.by; end + error_type, condition, error_message = error_type.type, error_type.condition, error_type.text; + end + if stanza.attr.from == error_by then + error_by = nil; + end + stanza:tag("error", {type = error_type, by = error_by}) --COMPAT: Some day xmlns:stanzas goes here + :tag(condition, xmpp_stanzas_attr); + if extra and condition == "gone" and type(extra.uri) == "string" then + stanza:text(extra.uri); + end + stanza:up(); + if error_message then stanza:text_tag("text", error_message, xmpp_stanzas_attr); end + if extra and is_stanza(extra.tag) then + stanza:add_child(extra.tag); + elseif extra and extra.namespace and extra.condition then + stanza:tag(extra.condition, { xmlns = extra.namespace }):up(); + end + return stanza:up(); +end + local function preserialize(stanza) local s = { name = stanza.name, attr = stanza.attr }; for _, child in ipairs(stanza) do @@ -461,7 +514,6 @@ local function reply(orig) }); end -local xmpp_stanzas_attr = { xmlns = xmlns_stanzas }; local function error_reply(orig, error_type, condition, error_message, error_by) if not is_stanza(orig) then error("bad argument to error_reply: expected stanza, got "..type(orig)); @@ -470,30 +522,9 @@ local function error_reply(orig, error_type, condition, error_message, error_by) end local t = reply(orig); t.attr.type = "error"; - local extra; - if type(error_type) == "table" then -- an util.error or similar object - if type(error_type.extra) == "table" then - extra = error_type.extra; - end - if type(error_type.context) == "table" and type(error_type.context.by) == "string" then error_by = error_type.context.by; end - error_type, condition, error_message = error_type.type, error_type.condition, error_type.text; - end - if t.attr.from == error_by then - error_by = nil; - end - t:tag("error", {type = error_type, by = error_by}) --COMPAT: Some day xmlns:stanzas goes here - :tag(condition, xmpp_stanzas_attr); - if extra and condition == "gone" and type(extra.uri) == "string" then - t:text(extra.uri); - end - t:up(); - if error_message then t:text_tag("text", error_message, xmpp_stanzas_attr); end - if extra and is_stanza(extra.tag) then - t:add_child(extra.tag); - elseif extra and extra.namespace and extra.condition then - t:tag(extra.condition, { xmlns = extra.namespace }):up(); - end - return t; -- stanza ready for adding app-specific errors + t:add_error(error_type, condition, error_message, error_by); + t.last_add = { t[1] }; -- ready to add application-specific errors + return t; end local function presence(attr) diff --git a/util/startup.lua b/util/startup.lua index 545b6ae7..caae895d 100644 --- a/util/startup.lua +++ b/util/startup.lua @@ -2,15 +2,15 @@ -- luacheck: ignore 113/CFG_CONFIGDIR 113/CFG_SOURCEDIR 113/CFG_DATADIR 113/CFG_PLUGINDIR local startup = {}; -local prosody = { events = require "util.events".new() }; -local logger = require "util.logger"; +local prosody = { events = require "prosody.util.events".new() }; +local logger = require "prosody.util.logger"; local log = logger.init("startup"); -local parse_args = require "util.argparse".parse; +local parse_args = require "prosody.util.argparse".parse; -local config = require "core.configmanager"; +local config = require "prosody.core.configmanager"; local config_warnings; -local dependencies = require "util.dependencies"; +local dependencies = require "prosody.util.dependencies"; local original_logging_config; @@ -132,14 +132,14 @@ end function startup.load_libraries() -- Load socket framework -- luacheck: ignore 111/server 111/socket - require "util.import" + require "prosody.util.import" socket = require "socket"; - server = require "net.server" + server = require "prosody.net.server" end function startup.init_logging() -- Initialize logging - local loggingmanager = require "core.loggingmanager" + local loggingmanager = require "prosody.core.loggingmanager" loggingmanager.reload_logging(); prosody.events.add_handler("config-reloaded", function () prosody.events.fire_event("reopen-log-files"); @@ -233,7 +233,7 @@ function startup.set_function_metatable() if info.isvararg then info[n_params+1] = "..."; end - return ("function<%s:%d>(%s)"):format(info.short_src:match("[^\\/]*$"), info.linedefined, table.concat(info, ", ")); + return ("function @%s:%d(%s)"):format(info.short_src:match("[^\\/]*$"), info.linedefined, table.concat(info, ", ")); end debug.setmetatable(function() end, mt); end @@ -277,6 +277,11 @@ function startup.init_global_state() startup.detect_platform(); startup.detect_installed(); _G.prosody = prosody; + + -- COMPAT Lua < 5.3 + if not math.type then + require "prosody.util.mathcompat" + end end function startup.setup_datadir() @@ -298,7 +303,7 @@ function startup.setup_plugin_install_path() local installer_plugin_path = config.get("*", "installer_plugin_path") or "custom_plugins"; local path_sep = package.config:sub(3,3); installer_plugin_path = config.resolve_relative_path(CFG_DATADIR or "data", installer_plugin_path); - require"util.paths".complement_lua_path(installer_plugin_path); + require"prosody.util.paths".complement_lua_path(installer_plugin_path); -- luacheck: ignore 111 CFG_PLUGINDIR = installer_plugin_path..path_sep..(CFG_PLUGINDIR or "plugins"); prosody.paths.installer = installer_plugin_path; @@ -360,28 +365,23 @@ end function startup.load_secondary_libraries() --- Load and initialise core modules - require "util.xmppstream" - require "core.stanza_router" - require "core.statsmanager" - require "core.hostmanager" - require "core.portmanager" - require "core.modulemanager" - require "core.usermanager" - require "core.rostermanager" - require "core.sessionmanager" - package.loaded['core.componentmanager'] = setmetatable({},{__index=function() - -- COMPAT which version? - log("warn", "componentmanager is deprecated: %s", debug.traceback():match("\n[^\n]*\n[ \t]*([^\n]*)")); - return function() end - end}); - - require "util.array" - require "util.datetime" - require "util.iterators" - require "util.timer" - require "util.helpers" - - pcall(require, "util.signal") -- Not on Windows + require "prosody.util.xmppstream" + require "prosody.core.stanza_router" + require "prosody.core.statsmanager".metric("gauge", "prosody_info", "", "Prosody version", { "version" }):with_labels(prosody.version):set(1); + require "prosody.core.hostmanager" + require "prosody.core.portmanager" + require "prosody.core.modulemanager" + require "prosody.core.usermanager" + require "prosody.core.rostermanager" + require "prosody.core.sessionmanager" + + require "prosody.util.array" + require "prosody.util.datetime" + require "prosody.util.iterators" + require "prosody.util.timer" + require "prosody.util.helpers" + + pcall(require, "prosody.util.signal") -- Not on Windows -- Commented to protect us from -- the second kind of people @@ -390,43 +390,83 @@ function startup.load_secondary_libraries() if remdebug then remdebug.engine.start() end ]] - require "util.stanza" - require "util.jid" + require "prosody.util.stanza" + require "prosody.util.jid" + + prosody.features = require "prosody.core.features".available; end function startup.init_http_client() - local http = require "net.http" + local http = require "prosody.net.http" local config_ssl = config.get("*", "ssl") or {} local https_client = config.get("*", "client_https_ssl") - http.default.options.sslctx = require "core.certmanager".create_context("client_https port 0", "client", + http.default.options.sslctx = require "prosody.core.certmanager".create_context("client_https port 0", "client", { capath = config_ssl.capath, cafile = config_ssl.cafile, verify = "peer", }, https_client); http.default.options.use_dane = config.get("*", "use_dane") end function startup.init_promise() - local promise = require "util.promise"; + local promise = require "prosody.util.promise"; - local timer = require "util.timer"; + local timer = require "prosody.util.timer"; promise.set_nexttick(function(f) return timer.add_task(0, f); end); end function startup.init_async() - local async = require "util.async"; + local async = require "prosody.util.async"; - local timer = require "util.timer"; + local timer = require "prosody.util.timer"; async.set_nexttick(function(f) return timer.add_task(0, f); end); async.set_schedule_function(timer.add_task); end function startup.init_data_store() - require "core.storagemanager"; + require "prosody.core.storagemanager"; end +local running_state = require "prosody.util.fsm".new({ + default_state = "uninitialized"; + transitions = { + { name = "begin_startup", from = "uninitialized", to = "starting" }; + { name = "finish_startup", from = "starting", to = "running" }; + { name = "begin_shutdown", from = { "running", "starting" }, to = "stopping" }; + { name = "finish_shutdown", from = "stopping", to = "stopped" }; + }; + handlers = { + transitioned = function (transition) + prosody.state = transition.to; + end; + }; + state_handlers = { + starting = function () + prosody.log("debug", "Firing server-starting event"); + prosody.events.fire_event("server-starting"); + prosody.start_time = os.time(); + end; + running = function () + prosody.log("debug", "Startup complete, firing server-started"); + prosody.events.fire_event("server-started"); + end; + }; +}):init(); + function startup.prepare_to_start() log("info", "Prosody is using the %s backend for connection handling", server.get_backend()); + -- Signal to modules that we are ready to start - prosody.events.fire_event("server-starting"); - prosody.start_time = os.time(); + prosody.started = require "prosody.util.promise".new(function (resolve) + if prosody.state == "running" then + resolve(); + else + prosody.events.add_handler("server-started", function () + resolve(); + end); + end + end):catch(function (err) + prosody.log("error", "Prosody startup error: %s", err); + end); + + running_state:begin_startup(); end function startup.init_global_protection() @@ -460,7 +500,7 @@ function startup.read_version() prosody.version = "hg:"..prosody.version; end else - local hg = require"util.mercurial"; + local hg = require"prosody.util.mercurial"; local hgid = hg.check_id(CFG_SOURCEDIR or "."); if hgid then prosody.version = "hg:" .. hgid; end end @@ -471,7 +511,7 @@ function startup.log_greeting() end function startup.notify_started() - prosody.events.fire_event("server-started"); + running_state:finish_startup(); end -- Override logging config (used by prosodyctl) @@ -491,21 +531,30 @@ function startup.force_console_logging() config.set("*", "log", { { levels = { min = log_level or "info" }, to = "console" } }); end +local function check_posix() + if prosody.platform ~= "posix" then return end + + local want_pposix_version = "0.4.1"; + local have_pposix, pposix = pcall(require, "prosody.util.pposix"); + + if pposix._VERSION ~= want_pposix_version then + print(string.format("Unknown version (%s) of binary pposix module, expected %s", + tostring(pposix._VERSION), want_pposix_version)); + os.exit(1); + end + if have_pposix and pposix then + return pposix; + end +end + function startup.switch_user() -- Switch away from root and into the prosody user -- -- NOTE: This function is only used by prosodyctl. -- The prosody process is built with the assumption that -- it is already started as the appropriate user. - local want_pposix_version = "0.4.0"; - local have_pposix, pposix = pcall(require, "util.pposix"); - - if have_pposix and pposix then - if pposix._VERSION ~= want_pposix_version then - print(string.format("Unknown version (%s) of binary pposix module, expected %s", - tostring(pposix._VERSION), want_pposix_version)); - os.exit(1); - end + local pposix = check_posix() + if pposix then prosody.current_uid = pposix.getuid(); local arg_root = prosody.opts.root; if prosody.current_uid == 0 and config.get("*", "run_as_root") ~= true and not arg_root then @@ -594,7 +643,7 @@ end function startup.init_gc() -- Apply garbage collector settings from the config file - local gc = require "util.gc"; + local gc = require "prosody.util.gc"; local gc_settings = config.get("*", "gc") or { mode = default_gc_params.mode }; local ok, err = gc.configure(gc_settings, default_gc_params); @@ -606,7 +655,7 @@ function startup.init_gc() end function startup.init_errors() - require "util.error".configure(config.get("*", "error_library") or {}); + require "prosody.util.error".configure(config.get("*", "error_library") or {}); end function startup.make_host(hostname) @@ -615,7 +664,7 @@ function startup.make_host(hostname) events = prosody.events, modules = {}, sessions = {}, - users = require "core.usermanager".new_null_provider(hostname) + users = require "prosody.core.usermanager".new_null_provider(hostname) }; end @@ -631,17 +680,183 @@ function startup.make_dummy_hosts() end end +function startup.posix_umask() + if prosody.platform ~= "posix" then return end + local pposix = require "prosody.util.pposix"; + local umask = config.get("*", "umask") or "027"; + pposix.umask(umask); +end + +function startup.check_user() + local pposix = check_posix(); + if not pposix then return end + -- Don't even think about it! + if pposix.getuid() == 0 and not config.get("*", "run_as_root") then + print("Danger, Will Robinson! Prosody doesn't need to be run as root, so don't do it!"); + print("For more information on running Prosody as root, see https://prosody.im/doc/root"); + os.exit(1); -- Refusing to run as root + end +end + +local function remove_pidfile() + local pidfile = prosody.pidfile; + if prosody.pidfile_handle then + prosody.pidfile_handle:close(); + os.remove(pidfile); + prosody.pidfile, prosody.pidfile_handle = nil, nil; + end +end + +function startup.write_pidfile() + local pposix = check_posix(); + if not pposix then return end + local lfs = require "lfs"; + local stat = lfs.attributes; + local pidfile = config.get("*", "pidfile") or nil; + if not pidfile then return end + pidfile = config.resolve_relative_path(prosody.paths.data, pidfile); + local mode = stat(pidfile) and "r+" or "w+"; + local pidfile_handle, err = io.open(pidfile, mode); + if not pidfile_handle then + log("error", "Couldn't write pidfile at %s; %s", pidfile, err); + os.exit(1); + else + prosody.pidfile = pidfile; + if not lfs.lock(pidfile_handle, "w") then -- Exclusive lock + local other_pid = pidfile_handle:read("*a"); + log("error", "Another Prosody instance seems to be running with PID %s, quitting", other_pid); + prosody.pidfile_handle = nil; + os.exit(1); + else + pidfile_handle:close(); + pidfile_handle, err = io.open(pidfile, "w+"); + if not pidfile_handle then + log("error", "Couldn't write pidfile at %s; %s", pidfile, err); + os.exit(1); + else + if lfs.lock(pidfile_handle, "w") then + pidfile_handle:write(tostring(pposix.getpid())); + pidfile_handle:flush(); + prosody.pidfile_handle = pidfile_handle; + end + end + end + end + prosody.events.add_handler("server-stopped", remove_pidfile); +end + +local function remove_log_sinks() + local lm = require "prosody.core.loggingmanager"; + lm.register_sink_type("console", nil); + lm.register_sink_type("stdout", nil); + lm.reload_logging(); +end + +function startup.posix_daemonize() + if not prosody.opts.daemonize then return end + local pposix = check_posix(); + log("info", "Prosody is about to detach from the console, disabling further console output"); + remove_log_sinks(); + local ok, ret = pposix.daemonize(); + if not ok then + log("error", "Failed to daemonize: %s", ret); + elseif ret and ret > 0 then + os.exit(0); + else + log("info", "Successfully daemonized to PID %d", pposix.getpid()); + end +end + +function startup.hook_posix_signals() + if prosody.platform ~= "posix" then return end + local have_signal, signal = pcall(require, "prosody.util.signal"); + if not have_signal then + log("warn", "Couldn't load signal library, won't respond to SIGTERM"); + return + end + signal.signal("SIGTERM", function() + log("warn", "Received SIGTERM"); + prosody.main_thread:run(function() + prosody.unlock_globals(); + prosody.shutdown("Received SIGTERM"); + prosody.lock_globals(); + end); + end); + + signal.signal("SIGHUP", function() + log("info", "Received SIGHUP"); + prosody.main_thread:run(function() prosody.reload_config(); end); + -- this also reloads logging + end); + + signal.signal("SIGINT", function() + log("info", "Received SIGINT"); + prosody.main_thread:run(function() + prosody.unlock_globals(); + prosody.shutdown("Received SIGINT"); + prosody.lock_globals(); + end); + end); + + signal.signal("SIGUSR1", function() + log("info", "Received SIGUSR1"); + prosody.events.fire_event("signal/SIGUSR1"); + end); + + signal.signal("SIGUSR2", function() + log("info", "Received SIGUSR2"); + prosody.events.fire_event("signal/SIGUSR2"); + end); +end + +function startup.systemd_notify() + local notify_socket_name = os.getenv("NOTIFY_SOCKET"); + if not notify_socket_name then return end + local have_unix, unix = pcall(require, "socket.unix"); + if not have_unix or type(unix) ~= "table" then + log("error", "LuaSocket without UNIX socket support, can't notify systemd.") + return os.exit(1); + end + log("debug", "Will notify on socket %q", notify_socket_name); + notify_socket_name = notify_socket_name:gsub("^@", "\0"); + local notify_socket = unix.dgram(); + local ok, err = notify_socket:setpeername(notify_socket_name); + if not ok then + log("error", "Could not connect to systemd notification socket %q: %q", notify_socket_name, err); + return os.exit(1); + end + local time = require "prosody.util.time"; + + prosody.notify_socket = notify_socket; + prosody.events.add_handler("server-started", function() + notify_socket:send("READY=1"); + end); + prosody.events.add_handler("reloading-config", function() + notify_socket:send(string.format("RELOADING=1\nMONOTONIC_USEC=%d", math.floor(time.monotonic() * 1000000))); + end); + prosody.events.add_handler("config-reloaded", function() + notify_socket:send("READY=1"); + end); + prosody.events.add_handler("server-stopping", function() + notify_socket:send("STOPPING=1"); + end); +end + function startup.cleanup() prosody.log("info", "Shutdown status: Cleaning up"); prosody.events.fire_event("server-cleanup"); end function startup.shutdown() + running_state:begin_shutdown(); + prosody.log("info", "Shutting down..."); startup.cleanup(); prosody.events.fire_event("server-stopped"); - prosody.log("info", "Shutdown complete"); + running_state:finish_shutdown(); + + prosody.log("info", "Shutdown complete"); prosody.log("debug", "Shutdown reason was: %s", prosody.shutdown_reason or "not specified"); prosody.log("debug", "Exiting with status code: %d", prosody.shutdown_code or 0); server.setquitting(true); @@ -682,6 +897,7 @@ function startup.prosody() startup.parse_args(); startup.init_global_state(); startup.read_config(); + startup.check_user(); startup.init_logging(); startup.init_gc(); startup.init_errors(); @@ -704,6 +920,10 @@ function startup.prosody() startup.init_http_client(); startup.init_data_store(); startup.init_global_protection(); + startup.posix_daemonize(); + startup.write_pidfile(); + startup.hook_posix_signals(); + startup.systemd_notify(); startup.prepare_to_start(); startup.notify_started(); end diff --git a/util/statistics.lua b/util/statistics.lua index cb6481c5..a0b5cea9 100644 --- a/util/statistics.lua +++ b/util/statistics.lua @@ -1,6 +1,6 @@ -local time = require "util.time".now; -local new_metric_registry = require "util.openmetrics".new_metric_registry; -local render_histogram_le = require "util.openmetrics".render_histogram_le; +local time = require "prosody.util.time".now; +local new_metric_registry = require "prosody.util.openmetrics".new_metric_registry; +local render_histogram_le = require "prosody.util.openmetrics".render_histogram_le; -- BEGIN of Metric implementations diff --git a/util/statsd.lua b/util/statsd.lua index 581f945a..6c035dfe 100644 --- a/util/statsd.lua +++ b/util/statsd.lua @@ -1,10 +1,10 @@ local socket = require "socket"; -local time = require "util.time".now; -local array = require "util.array"; +local time = require "prosody.util.time".now; +local array = require "prosody.util.array"; local t_concat = table.concat; -local new_metric_registry = require "util.openmetrics".new_metric_registry; -local render_histogram_le = require "util.openmetrics".render_histogram_le; +local new_metric_registry = require "prosody.util.openmetrics".new_metric_registry; +local render_histogram_le = require "prosody.util.openmetrics".render_histogram_le; -- BEGIN of Metric implementations diff --git a/util/template.lua b/util/template.lua index c11037c5..d14c4c28 100644 --- a/util/template.lua +++ b/util/template.lua @@ -1,13 +1,13 @@ -- luacheck: ignore 213/i -local stanza_mt = require "util.stanza".stanza_mt; +local stanza_mt = require "prosody.util.stanza".stanza_mt; local setmetatable = setmetatable; local pairs = pairs; local ipairs = ipairs; local error = error; -local envload = require "util.envload".envload; +local envload = require "prosody.util.envload".envload; local debug = debug; local t_remove = table.remove; -local parse_xml = require "util.xml".parse; +local parse_xml = require "prosody.util.xml".parse; local _ENV = nil; -- luacheck: std none diff --git a/util/termcolours.lua b/util/termcolours.lua index 2c13d0ff..7b6b24a7 100644 --- a/util/termcolours.lua +++ b/util/termcolours.lua @@ -21,7 +21,7 @@ local pairs = pairs; local windows; if os.getenv("WINDIR") then - windows = require "util.windows"; + windows = require "prosody.util.windows"; end local orig_color = windows and windows.get_consolecolor and windows.get_consolecolor(); diff --git a/util/throttle.lua b/util/throttle.lua index d2036e9e..c07bcf68 100644 --- a/util/throttle.lua +++ b/util/throttle.lua @@ -1,5 +1,5 @@ -local gettime = require "util.time".now +local gettime = require "prosody.util.time".now local setmetatable = setmetatable; local _ENV = nil; diff --git a/util/timer.lua b/util/timer.lua index 84da02cf..532bf112 100644 --- a/util/timer.lua +++ b/util/timer.lua @@ -6,14 +6,14 @@ -- COPYING file in the source package for more information. -- -local indexedbheap = require "util.indexedbheap"; -local log = require "util.logger".init("timer"); -local server = require "net.server"; -local get_time = require "util.time".now +local indexedbheap = require "prosody.util.indexedbheap"; +local log = require "prosody.util.logger".init("timer"); +local server = require "prosody.net.server"; +local get_time = require "prosody.util.time".now local type = type; local debug_traceback = debug.traceback; local tostring = tostring; -local xpcall = require "util.xpcall".xpcall; +local xpcall = require "prosody.util.xpcall".xpcall; local math_max = math.max; local pairs = pairs; diff --git a/util/uuid.lua b/util/uuid.lua index 54ea99b4..a70750bb 100644 --- a/util/uuid.lua +++ b/util/uuid.lua @@ -6,10 +6,12 @@ -- COPYING file in the source package for more information. -- -local random = require "util.random"; +local random = require "prosody.util.random"; local random_bytes = random.bytes; -local hex = require "util.hex".encode; +local time = require "prosody.util.time"; +local hex = require "prosody.util.hex".encode; local m_ceil = math.ceil; +local m_floor = math.floor; local function get_nibbles(n) return hex(random_bytes(m_ceil(n/2))):sub(1, n); @@ -24,7 +26,22 @@ local function generate() return get_nibbles(8).."-"..get_nibbles(4).."-4"..get_nibbles(3).."-"..(get_twobits())..get_nibbles(3).."-"..get_nibbles(12); end +local function generate_v7() + -- Sortable based on time and random + -- https://datatracker.ietf.org/doc/html/draft-peabody-dispatch-new-uuid-format-01#section-4.4 + local t = time.now(); + local unixts = m_floor(t); + local unixts_a = m_floor(unixts / 16); + local unixts_b = m_floor(unixts % 16); + local subsec = t % 1; + local subsec_a = m_floor(subsec * 0x1000); + local subsec_b = m_floor(subsec * 0x1000000) % 0x1000; + return ("%08x-%x%03x-7%03x-%4s-%12s"):format(unixts_a, unixts_b, subsec_a, subsec_b, get_twobits() .. get_nibbles(3), get_nibbles(12)); +end + return { + v4 = generate; + v7 = generate_v7; get_nibbles=get_nibbles; generate = generate ; -- COMPAT diff --git a/util/vcard.lua b/util/vcard.lua deleted file mode 100644 index e311f73f..00000000 --- a/util/vcard.lua +++ /dev/null @@ -1,574 +0,0 @@ --- Copyright (C) 2011-2014 Kim Alvefur --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - --- TODO --- Fix folding. - -local st = require "util.stanza"; -local t_insert, t_concat = table.insert, table.concat; -local type = type; -local pairs, ipairs = pairs, ipairs; - -local from_text, to_text, from_xep54, to_xep54; - -local line_sep = "\n"; - -local vCard_dtd; -- See end of file -local vCard4_dtd; - -local function vCard_esc(s) - return s:gsub("[,:;\\]", "\\%1"):gsub("\n","\\n"); -end - -local function vCard_unesc(s) - return s:gsub("\\?[\\nt:;,]", { - ["\\\\"] = "\\", - ["\\n"] = "\n", - ["\\r"] = "\r", - ["\\t"] = "\t", - ["\\:"] = ":", -- FIXME Shouldn't need to escape : in values, just params - ["\\;"] = ";", - ["\\,"] = ",", - [":"] = "\29", - [";"] = "\30", - [","] = "\31", - }); -end - -local function item_to_xep54(item) - local t = st.stanza(item.name, { xmlns = "vcard-temp" }); - - local prop_def = vCard_dtd[item.name]; - if prop_def == "text" then - t:text(item[1]); - elseif type(prop_def) == "table" then - if prop_def.types and item.TYPE then - if type(item.TYPE) == "table" then - for _,v in pairs(prop_def.types) do - for _,typ in pairs(item.TYPE) do - if typ:upper() == v then - t:tag(v):up(); - break; - end - end - end - else - t:tag(item.TYPE:upper()):up(); - end - end - - if prop_def.props then - for _,prop in pairs(prop_def.props) do - if item[prop] then - for _, v in ipairs(item[prop]) do - t:text_tag(prop, v); - end - end - end - end - - if prop_def.value then - t:text_tag(prop_def.value, item[1]); - elseif prop_def.values then - local prop_def_values = prop_def.values; - local repeat_last = prop_def_values.behaviour == "repeat-last" and prop_def_values[#prop_def_values]; - for i=1,#item do - t:text_tag(prop_def.values[i] or repeat_last, item[i]); - end - end - end - - return t; -end - -local function vcard_to_xep54(vCard) - local t = st.stanza("vCard", { xmlns = "vcard-temp" }); - for i=1,#vCard do - t:add_child(item_to_xep54(vCard[i])); - end - return t; -end - -function to_xep54(vCards) - if not vCards[1] or vCards[1].name then - return vcard_to_xep54(vCards) - else - local t = st.stanza("xCard", { xmlns = "vcard-temp" }); - for i=1,#vCards do - t:add_child(vcard_to_xep54(vCards[i])); - end - return t; - end -end - -function from_text(data) - data = data -- unfold and remove empty lines - :gsub("\r\n","\n") - :gsub("\n ", "") - :gsub("\n\n+","\n"); - local vCards = {}; - local current; - for line in data:gmatch("[^\n]+") do - line = vCard_unesc(line); - local name, params, value = line:match("^([-%a]+)(\30?[^\29]*)\29(.*)$"); - value = value:gsub("\29",":"); - if #params > 0 then - local _params = {}; - for k,isval,v in params:gmatch("\30([^=]+)(=?)([^\30]*)") do - k = k:upper(); - local _vt = {}; - for _p in v:gmatch("[^\31]+") do - _vt[#_vt+1]=_p - _vt[_p]=true; - end - if isval == "=" then - _params[k]=_vt; - else - _params[k]=true; - end - end - params = _params; - end - if name == "BEGIN" and value == "VCARD" then - current = {}; - vCards[#vCards+1] = current; - elseif name == "END" and value == "VCARD" then - current = nil; - elseif current and vCard_dtd[name] then - local dtd = vCard_dtd[name]; - local item = { name = name }; - t_insert(current, item); - local up = current; - current = item; - if dtd.types then - for _, t in ipairs(dtd.types) do - t = t:lower(); - if ( params.TYPE and params.TYPE[t] == true) - or params[t] == true then - current.TYPE=t; - end - end - end - if dtd.props then - for _, p in ipairs(dtd.props) do - if params[p] then - if params[p] == true then - current[p]=true; - else - for _, prop in ipairs(params[p]) do - current[p]=prop; - end - end - end - end - end - if dtd == "text" or dtd.value then - t_insert(current, value); - elseif dtd.values then - for p in ("\30"..value):gmatch("\30([^\30]*)") do - t_insert(current, p); - end - end - current = up; - end - end - return vCards; -end - -local function item_to_text(item) - local value = {}; - for i=1,#item do - value[i] = vCard_esc(item[i]); - end - value = t_concat(value, ";"); - - local params = ""; - for k,v in pairs(item) do - if type(k) == "string" and k ~= "name" then - params = params .. (";%s=%s"):format(k, type(v) == "table" and t_concat(v,",") or v); - end - end - - return ("%s%s:%s"):format(item.name, params, value) -end - -local function vcard_to_text(vcard) - local t={}; - t_insert(t, "BEGIN:VCARD") - for i=1,#vcard do - t_insert(t, item_to_text(vcard[i])); - end - t_insert(t, "END:VCARD") - return t_concat(t, line_sep); -end - -function to_text(vCards) - if vCards[1] and vCards[1].name then - return vcard_to_text(vCards) - else - local t = {}; - for i=1,#vCards do - t[i]=vcard_to_text(vCards[i]); - end - return t_concat(t, line_sep); - end -end - -local function from_xep54_item(item) - local prop_name = item.name; - local prop_def = vCard_dtd[prop_name]; - - local prop = { name = prop_name }; - - if prop_def == "text" then - prop[1] = item:get_text(); - elseif type(prop_def) == "table" then - if prop_def.value then --single item - prop[1] = item:get_child_text(prop_def.value) or ""; - elseif prop_def.values then --array - local value_names = prop_def.values; - if value_names.behaviour == "repeat-last" then - for i=1,#item.tags do - t_insert(prop, item.tags[i]:get_text() or ""); - end - else - for i=1,#value_names do - t_insert(prop, item:get_child_text(value_names[i]) or ""); - end - end - elseif prop_def.names then - local names = prop_def.names; - for i=1,#names do - if item:get_child(names[i]) then - prop[1] = names[i]; - break; - end - end - end - - if prop_def.props_verbatim then - for k,v in pairs(prop_def.props_verbatim) do - prop[k] = v; - end - end - - if prop_def.types then - local types = prop_def.types; - prop.TYPE = {}; - for i=1,#types do - if item:get_child(types[i]) then - t_insert(prop.TYPE, types[i]:lower()); - end - end - if #prop.TYPE == 0 then - prop.TYPE = nil; - end - end - - -- A key-value pair, within a key-value pair? - if prop_def.props then - local params = prop_def.props; - for i=1,#params do - local name = params[i] - local data = item:get_child_text(name); - if data then - prop[name] = prop[name] or {}; - t_insert(prop[name], data); - end - end - end - else - return nil - end - - return prop; -end - -local function from_xep54_vCard(vCard) - local tags = vCard.tags; - local t = {}; - for i=1,#tags do - t_insert(t, from_xep54_item(tags[i])); - end - return t -end - -function from_xep54(vCard) - if vCard.attr.xmlns ~= "vcard-temp" then - return nil, "wrong-xmlns"; - end - if vCard.name == "xCard" then -- A collection of vCards - local t = {}; - local vCards = vCard.tags; - for i=1,#vCards do - t[i] = from_xep54_vCard(vCards[i]); - end - return t - elseif vCard.name == "vCard" then -- A single vCard - return from_xep54_vCard(vCard) - end -end - -local vcard4 = { } - -function vcard4:text(node, params, value) -- luacheck: ignore 212/params - self:tag(node:lower()) - -- FIXME params - if type(value) == "string" then - self:text_tag("text", value); - elseif vcard4[node] then - vcard4[node](value); - end - self:up(); -end - -function vcard4.N(value) - for i, k in ipairs(vCard_dtd.N.values) do - value:text_tag(k, value[i]); - end -end - -local xmlns_vcard4 = "urn:ietf:params:xml:ns:vcard-4.0" - -local function item_to_vcard4(item) - local typ = item.name:lower(); - local t = st.stanza(typ, { xmlns = xmlns_vcard4 }); - - local prop_def = vCard4_dtd[typ]; - if prop_def == "text" then - t:text_tag("text", item[1]); - elseif prop_def == "uri" then - if item.ENCODING and item.ENCODING[1] == 'b' then - t:text_tag("uri", "data:;base64," .. item[1]); - else - t:text_tag("uri", item[1]); - end - elseif type(prop_def) == "table" then - if prop_def.values then - for i, v in ipairs(prop_def.values) do - t:text_tag(v:lower(), item[i]); - end - else - t:tag("unsupported",{xmlns="http://zash.se/protocol/vcardlib"}) - end - else - t:tag("unsupported",{xmlns="http://zash.se/protocol/vcardlib"}) - end - return t; -end - -local function vcard_to_vcard4xml(vCard) - local t = st.stanza("vcard", { xmlns = xmlns_vcard4 }); - for i=1,#vCard do - t:add_child(item_to_vcard4(vCard[i])); - end - return t; -end - -local function vcards_to_vcard4xml(vCards) - if not vCards[1] or vCards[1].name then - return vcard_to_vcard4xml(vCards) - else - local t = st.stanza("vcards", { xmlns = xmlns_vcard4 }); - for i=1,#vCards do - t:add_child(vcard_to_vcard4xml(vCards[i])); - end - return t; - end -end - --- This was adapted from http://xmpp.org/extensions/xep-0054.html#dtd -vCard_dtd = { - VERSION = "text", --MUST be 3.0, so parsing is redundant - FN = "text", - N = { - values = { - "FAMILY", - "GIVEN", - "MIDDLE", - "PREFIX", - "SUFFIX", - }, - }, - NICKNAME = "text", - PHOTO = { - props_verbatim = { ENCODING = { "b" } }, - props = { "TYPE" }, - value = "BINVAL", --{ "EXTVAL", }, - }, - BDAY = "text", - ADR = { - types = { - "HOME", - "WORK", - "POSTAL", - "PARCEL", - "DOM", - "INTL", - "PREF", - }, - values = { - "POBOX", - "EXTADD", - "STREET", - "LOCALITY", - "REGION", - "PCODE", - "CTRY", - } - }, - LABEL = { - types = { - "HOME", - "WORK", - "POSTAL", - "PARCEL", - "DOM", - "INTL", - "PREF", - }, - value = "LINE", - }, - TEL = { - types = { - "HOME", - "WORK", - "VOICE", - "FAX", - "PAGER", - "MSG", - "CELL", - "VIDEO", - "BBS", - "MODEM", - "ISDN", - "PCS", - "PREF", - }, - value = "NUMBER", - }, - EMAIL = { - types = { - "HOME", - "WORK", - "INTERNET", - "PREF", - "X400", - }, - value = "USERID", - }, - JABBERID = "text", - MAILER = "text", - TZ = "text", - GEO = { - values = { - "LAT", - "LON", - }, - }, - TITLE = "text", - ROLE = "text", - LOGO = "copy of PHOTO", - AGENT = "text", - ORG = { - values = { - behaviour = "repeat-last", - "ORGNAME", - "ORGUNIT", - } - }, - CATEGORIES = { - values = "KEYWORD", - }, - NOTE = "text", - PRODID = "text", - REV = "text", - SORTSTRING = "text", - SOUND = "copy of PHOTO", - UID = "text", - URL = "text", - CLASS = { - names = { -- The item.name is the value if it's one of these. - "PUBLIC", - "PRIVATE", - "CONFIDENTIAL", - }, - }, - KEY = { - props = { "TYPE" }, - value = "CRED", - }, - DESC = "text", -}; -vCard_dtd.LOGO = vCard_dtd.PHOTO; -vCard_dtd.SOUND = vCard_dtd.PHOTO; - -vCard4_dtd = { - source = "uri", - kind = "text", - xml = "text", - fn = "text", - n = { - values = { - "family", - "given", - "middle", - "prefix", - "suffix", - }, - }, - nickname = "text", - photo = "uri", - bday = "date-and-or-time", - anniversary = "date-and-or-time", - gender = "text", - adr = { - values = { - "pobox", - "ext", - "street", - "locality", - "region", - "code", - "country", - } - }, - tel = "text", - email = "text", - impp = "uri", - lang = "language-tag", - tz = "text", - geo = "uri", - title = "text", - role = "text", - logo = "uri", - org = "text", - member = "uri", - related = "uri", - categories = "text", - note = "text", - prodid = "text", - rev = "timestamp", - sound = "uri", - uid = "uri", - clientpidmap = "number, uuid", - url = "uri", - version = "text", - key = "uri", - fburl = "uri", - caladruri = "uri", - caluri = "uri", -}; - -return { - from_text = from_text; - to_text = to_text; - - from_xep54 = from_xep54; - to_xep54 = to_xep54; - - to_vcard4 = vcards_to_vcard4xml; -}; diff --git a/util/watchdog.lua b/util/watchdog.lua index 516e60e4..70df2530 100644 --- a/util/watchdog.lua +++ b/util/watchdog.lua @@ -1,6 +1,5 @@ -local timer = require "util.timer"; +local timer = require "prosody.util.timer"; local setmetatable = setmetatable; -local os_time = os.time; local _ENV = nil; -- luacheck: std none @@ -9,27 +8,35 @@ local watchdog_methods = {}; local watchdog_mt = { __index = watchdog_methods }; local function new(timeout, callback) - local watchdog = setmetatable({ timeout = timeout, last_reset = os_time(), callback = callback }, watchdog_mt); - timer.add_task(timeout+1, function (current_time) - local last_reset = watchdog.last_reset; - if not last_reset then - return; - end - local time_left = (last_reset + timeout) - current_time; - if time_left < 0 then - return watchdog:callback(); - end - return time_left + 1; - end); + local watchdog = setmetatable({ + timeout = timeout; + callback = callback; + timer_id = nil; + }, watchdog_mt); + + watchdog:reset(); -- Kick things off + return watchdog; end -function watchdog_methods:reset() - self.last_reset = os_time(); +function watchdog_methods:reset(new_timeout) + if new_timeout then + self.timeout = new_timeout; + end + if self.timer_id then + timer.reschedule(self.timer_id, self.timeout+1); + else + self.timer_id = timer.add_task(self.timeout+1, function () + return self:callback(); + end); + end end function watchdog_methods:cancel() - self.last_reset = nil; + if self.timer_id then + timer.stop(self.timer_id); + self.timer_id = nil; + end end return { diff --git a/util/x509.lua b/util/x509.lua index 76b50076..9ecb5b60 100644 --- a/util/x509.lua +++ b/util/x509.lua @@ -11,19 +11,19 @@ -- IDN libraries complicate that. --- [TLS-CERTS] - http://tools.ietf.org/html/rfc6125 --- [XMPP-CORE] - http://tools.ietf.org/html/rfc6120 --- [SRV-ID] - http://tools.ietf.org/html/rfc4985 --- [IDNA] - http://tools.ietf.org/html/rfc5890 --- [LDAP] - http://tools.ietf.org/html/rfc4519 --- [PKIX] - http://tools.ietf.org/html/rfc5280 - -local nameprep = require "util.encodings".stringprep.nameprep; -local idna_to_ascii = require "util.encodings".idna.to_ascii; -local idna_to_unicode = require "util.encodings".idna.to_unicode; -local base64 = require "util.encodings".base64; -local log = require "util.logger".init("x509"); -local mt = require "util.multitable"; +-- [TLS-CERTS] - https://www.rfc-editor.org/rfc/rfc6125.html +-- [XMPP-CORE] - https://www.rfc-editor.org/rfc/rfc6120.html +-- [SRV-ID] - https://www.rfc-editor.org/rfc/rfc4985.html +-- [IDNA] - https://www.rfc-editor.org/rfc/rfc5890.html +-- [LDAP] - https://www.rfc-editor.org/rfc/rfc4519.html +-- [PKIX] - https://www.rfc-editor.org/rfc/rfc5280.html + +local nameprep = require "prosody.util.encodings".stringprep.nameprep; +local idna_to_ascii = require "prosody.util.encodings".idna.to_ascii; +local idna_to_unicode = require "prosody.util.encodings".idna.to_unicode; +local base64 = require "prosody.util.encodings".base64; +local log = require "prosody.util.logger".init("x509"); +local mt = require "prosody.util.multitable"; local s_format = string.format; local ipairs = ipairs; diff --git a/util/xml.lua b/util/xml.lua index 2bf1ff4e..e23916be 100644 --- a/util/xml.lua +++ b/util/xml.lua @@ -1,5 +1,5 @@ -local st = require "util.stanza"; +local st = require "prosody.util.stanza"; local lxp = require "lxp"; local t_insert = table.insert; local t_remove = table.remove; diff --git a/util/xmppstream.lua b/util/xmppstream.lua index be113396..7f998312 100644 --- a/util/xmppstream.lua +++ b/util/xmppstream.lua @@ -7,7 +7,7 @@ -- local lxp = require "lxp"; -local st = require "util.stanza"; +local st = require "prosody.util.stanza"; local stanza_mt = st.stanza_mt; local error = error; diff --git a/util/xpcall.lua b/util/xpcall.lua index d2fc5011..852df03c 100644 --- a/util/xpcall.lua +++ b/util/xpcall.lua @@ -1,7 +1,7 @@ local xpcall = xpcall; if select(2, xpcall(function (x) return x end, function () end, "test")) ~= "test" then - xpcall = require"util.compat".xpcall; + xpcall = require"prosody.util.compat".xpcall; end return { diff --git a/util/xtemplate.lua b/util/xtemplate.lua index 88baf1f7..446d7d1f 100644 --- a/util/xtemplate.lua +++ b/util/xtemplate.lua @@ -3,13 +3,21 @@ local s_match = string.match; local s_sub = string.sub; local t_concat = table.concat; -local st = require("util.stanza"); +local st = require("prosody.util.stanza"); local function render(template, root, escape, filters) escape = escape or st.xml_escape; - return (s_gsub(template, "%b{}", function(block) + return (s_gsub(template, "(%s*)(%b{})(%s*)", function(pre_blank, block, post_blank) local inner = s_sub(block, 2, -2); + if inner:sub(1, 1) == "-" then + pre_blank = ""; + inner = inner:sub(2); + end + if inner:sub(-1, -1) == "-" then + post_blank = ""; + inner = inner:sub(1, -2); + end local path, pipe, pos = s_match(inner, "^([^|]+)(|?)()"); if not (type(path) == "string") then return end local value @@ -32,7 +40,7 @@ local function render(template, root, escape, filters) if args then args = s_sub(args, 2, -2); end if func == "each" and tmpl then - if not st.is_stanza(value) then return "" end + if not st.is_stanza(value) then return pre_blank .. post_blank end if not args then value, args = root, path; end local ns, name = s_match(args, "^(%b{})(.*)$"); if ns then @@ -62,11 +70,7 @@ local function render(template, root, escape, filters) end elseif filters and filters[func] then local f = filters[func]; - if args == nil then - value, is_escaped = f(value, tmpl); - else - value, is_escaped = f(args, value, tmpl); - end + value, is_escaped = f(value, args, tmpl); else error("No such filter function: " .. func); end @@ -75,12 +79,12 @@ local function render(template, root, escape, filters) if type(value) == "string" then if not is_escaped then value = escape(value); end - return value + return pre_blank .. value .. post_blank elseif st.is_stanza(value) then value = value:get_text(); - if value then return escape(value) end + if value then return pre_blank .. escape(value) .. post_blank end end - return "" + return pre_blank .. post_blank end)) end |