diff options
Diffstat (limited to 'util')
72 files changed, 6066 insertions, 1175 deletions
diff --git a/util/adhoc.lua b/util/adhoc.lua index d81b8242..a0ad52bb 100644 --- a/util/adhoc.lua +++ b/util/adhoc.lua @@ -2,7 +2,7 @@ local function new_simple_form(form, result_handler) return function(self, data, state) - if state then + if state or data.form then if data.action == "cancel" then return { status = "canceled" }; end @@ -16,15 +16,21 @@ end local function new_initial_data_form(form, initial_data, result_handler) return function(self, data, state) - if state then + if state or data.form then if data.action == "cancel" then return { status = "canceled" }; end local fields, err = form:data(data.form); return result_handler(fields, err, data); else + local values, err = initial_data(data); + if type(err) == "table" then + return {status = "error"; error = err} + elseif type(err) == "string" then + return {status = "error"; error = {type = "cancel"; condition = "internal-server-error", err}} + end return { status = "executing", actions = {"next", "complete", default = "complete"}, - form = { layout = form, values = initial_data(data) } }, "executing"; + form = { layout = form, values = values } }, "executing"; end end end diff --git a/util/adminstream.lua b/util/adminstream.lua new file mode 100644 index 00000000..4075aa05 --- /dev/null +++ b/util/adminstream.lua @@ -0,0 +1,346 @@ +local st = require "util.stanza"; +local new_xmpp_stream = require "util.xmppstream".new; +local sessionlib = require "util.session"; +local gettime = require "util.time".now; +local runner = require "util.async".runner; +local add_task = require "util.timer".add_task; +local events = require "util.events"; +local server = require "net.server"; + +local stream_close_timeout = 5; + +local log = require "util.logger".init("adminstream"); + +local xmlns_xmpp_streams = "urn:ietf:params:xml:ns:xmpp-streams"; + +local stream_callbacks = { default_ns = "xmpp:prosody.im/admin" }; + +function stream_callbacks.streamopened(session, attr) + -- run _streamopened in async context + session.thread:run({ stream = "opened", attr = attr }); +end + +function stream_callbacks._streamopened(session, attr) --luacheck: ignore 212/attr + if session.type ~= "client" then + session:open_stream(); + end + session.notopen = nil; +end + +function stream_callbacks.streamclosed(session, attr) + -- run _streamclosed in async context + session.thread:run({ stream = "closed", attr = attr }); +end + +function stream_callbacks._streamclosed(session) + session.log("debug", "Received </stream:stream>"); + session:close(false); +end + +function stream_callbacks.error(session, error, data) + if error == "no-stream" then + session.log("debug", "Invalid opening stream header (%s)", (data:gsub("^([^\1]+)\1", "{%1}"))); + session:close("invalid-namespace"); + elseif error == "parse-error" then + session.log("debug", "Client XML parse error: %s", data); + session:close("not-well-formed"); + elseif error == "stream-error" then + local condition, text = "undefined-condition"; + for child in data:childtags(nil, xmlns_xmpp_streams) do + if child.name ~= "text" then + condition = child.name; + else + text = child:get_text(); + end + if condition ~= "undefined-condition" and text then + break; + end + end + text = condition .. (text and (" ("..text..")") or ""); + session.log("info", "Session closed by remote with error: %s", text); + session:close(nil, text); + end +end + +function stream_callbacks.handlestanza(session, stanza) + session.thread:run(stanza); +end + +local runner_callbacks = {}; + +function runner_callbacks:error(err) + self.data.log("error", "Traceback[c2s]: %s", err); +end + +local stream_xmlns_attr = {xmlns='urn:ietf:params:xml:ns:xmpp-streams'}; + +local function destroy_session(session, reason) + if session.destroyed then return; end + session.destroyed = true; + session.log("debug", "Destroying session: %s", reason or "unknown reason"); +end + +local function session_close(session, reason) + local log = session.log or log; + if session.conn then + if session.notopen then + session:open_stream(); + end + if reason then -- nil == no err, initiated by us, false == initiated by client + local stream_error = st.stanza("stream:error"); + if type(reason) == "string" then -- assume stream error + stream_error:tag(reason, {xmlns = 'urn:ietf:params:xml:ns:xmpp-streams' }); + elseif type(reason) == "table" then + if reason.condition then + stream_error:tag(reason.condition, stream_xmlns_attr):up(); + if reason.text then + stream_error:tag("text", stream_xmlns_attr):text(reason.text):up(); + end + if reason.extra then + stream_error:add_child(reason.extra); + end + elseif reason.name then -- a stanza + stream_error = reason; + end + end + stream_error = tostring(stream_error); + log("debug", "Disconnecting client, <stream:error> is: %s", stream_error); + session.send(stream_error); + end + + session.send("</stream:stream>"); + function session.send() return false; end + + local reason_text = (reason and (reason.name or reason.text or reason.condition)) or reason; + session.log("debug", "c2s stream for %s closed: %s", session.full_jid or session.ip or "<unknown>", reason_text or "session closed"); + + -- Authenticated incoming stream may still be sending us stanzas, so wait for </stream:stream> from remote + local conn = session.conn; + if reason_text == nil and not session.notopen and session.type == "c2s" then + -- Grace time to process data from authenticated cleanly-closed stream + add_task(stream_close_timeout, function () + if not session.destroyed then + session.log("warn", "Failed to receive a stream close response, closing connection anyway..."); + destroy_session(session); + conn:close(); + end + end); + else + destroy_session(session, reason_text); + conn:close(); + end + else + local reason_text = (reason and (reason.name or reason.text or reason.condition)) or reason; + destroy_session(session, reason_text); + end +end + +--- Public methods + +local function new_connection(socket_path, listeners) + local have_unix, unix = pcall(require, "socket.unix"); + if have_unix and type(unix) == "function" then + -- COMPAT #1717 + -- Before the introduction of datagram support, only the stream socket + -- constructor was exported instead of a module table. Due to the lack of a + -- proper release of LuaSocket, distros have settled on shipping either the + -- last RC tag or some commit since then. + -- Here we accomodate both variants. + unix = { stream = unix }; + end + if type(unix) ~= "table" then + have_unix = false; + end + local conn, sock; + + return { + connect = function () + if not have_unix then + return nil, "no unix socket support"; + end + if sock or conn then + return nil, "already connected"; + end + sock = unix.stream(); + sock:settimeout(0); + local ok, err = sock:connect(socket_path); + if not ok then + return nil, err; + end + conn = server.wrapclient(sock, nil, nil, listeners, "*a"); + return true; + end; + disconnect = function () + if conn then + conn:close(); + conn = nil; + end + if sock then + sock:close(); + sock = nil; + end + return true; + end; + }; +end + +local function new_server(sessions, stanza_handler) + local listeners = {}; + + function listeners.onconnect(conn) + log("debug", "New connection"); + local session = sessionlib.new("admin"); + sessionlib.set_id(session); + sessionlib.set_logger(session); + sessionlib.set_conn(session, conn); + + session.conntime = gettime(); + session.type = "admin"; + + local stream = new_xmpp_stream(session, stream_callbacks); + session.stream = stream; + session.notopen = true; + + session.thread = runner(function (stanza) + if st.is_stanza(stanza) then + stanza_handler(session, stanza); + elseif stanza.stream == "opened" then + stream_callbacks._streamopened(session, stanza.attr); + elseif stanza.stream == "closed" then + stream_callbacks._streamclosed(session, stanza.attr); + end + end, runner_callbacks, session); + + function session.data(data) + -- Parse the data, which will store stanzas in session.pending_stanzas + if data then + local ok, err = stream:feed(data); + if not ok then + session.log("debug", "Received invalid XML (%s) %d bytes: %q", err, #data, data:sub(1, 300)); + session:close("not-well-formed"); + end + end + end + + session.close = session_close; + + session.send = function (t) + session.log("debug", "Sending[%s]: %s", session.type, t.top_tag and t:top_tag() or t:match("^[^>]*>?")); + return session.rawsend(tostring(t)); + end + + function session.rawsend(t) + local ret, err = conn:write(t); + if not ret then + session.log("debug", "Error writing to connection: %s", err); + return false, err; + end + return true; + end + + sessions[conn] = session; + end + + function listeners.onincoming(conn, data) + local session = sessions[conn]; + if session then + session.data(data); + end + end + + function listeners.ondisconnect(conn, err) + local session = sessions[conn]; + if session then + session.log("info", "Admin client disconnected: %s", err or "connection closed"); + session.conn = nil; + sessions[conn] = nil; + end + end + + function listeners.onreadtimeout(conn) + return conn:send(" "); + end + + return { + listeners = listeners; + }; +end + +local function new_client() + local client = { + type = "client"; + events = events.new(); + log = log; + }; + + local listeners = {}; + + function listeners.onconnect(conn) + log("debug", "Connected"); + client.conn = conn; + + local stream = new_xmpp_stream(client, stream_callbacks); + client.stream = stream; + client.notopen = true; + + client.thread = runner(function (stanza) + if st.is_stanza(stanza) then + if not client.events.fire_event("received", stanza) and not stanza.attr.xmlns then + client.events.fire_event("received/"..stanza.name, stanza); + end + elseif stanza.stream == "opened" then + stream_callbacks._streamopened(client, stanza.attr); + client.events.fire_event("connected"); + elseif stanza.stream == "closed" then + client.events.fire_event("disconnected"); + stream_callbacks._streamclosed(client, stanza.attr); + end + end, runner_callbacks, client); + + client.close = session_close; + + function client.send(t) + client.log("debug", "Sending: %s", t.top_tag and t:top_tag() or t:match("^[^>]*>?")); + return client.rawsend(tostring(t)); + end + + function client.rawsend(t) + local ret, err = conn:write(t); + if not ret then + client.log("debug", "Error writing to connection: %s", err); + return false, err; + end + return true; + end + client.log("debug", "Opening stream..."); + client:open_stream(); + end + + function listeners.onincoming(conn, data) --luacheck: ignore 212/conn + local ok, err = client.stream:feed(data); + if not ok then + client.log("debug", "Received invalid XML (%s) %d bytes: %q", err, #data, data:sub(1, 300)); + client:close("not-well-formed"); + end + end + + function listeners.ondisconnect(conn, err) --luacheck: ignore 212/conn + client.log("info", "Admin client disconnected: %s", err or "connection closed"); + client.conn = nil; + client.events.fire_event("disconnected"); + end + + function listeners.onreadtimeout(conn) + conn:send(" "); + end + + client.listeners = listeners; + + return client; +end + +return { + connection = new_connection; + server = new_server; + client = new_client; +}; diff --git a/util/argparse.lua b/util/argparse.lua new file mode 100644 index 00000000..9ece050a --- /dev/null +++ b/util/argparse.lua @@ -0,0 +1,58 @@ +local function parse(arg, config) + local short_params = config and config.short_params or {}; + local value_params = config and config.value_params or {}; + + local parsed_opts = {}; + + if #arg == 0 then + return parsed_opts; + end + while true do + local raw_param = arg[1]; + if not raw_param then + break; + end + + local prefix = raw_param:match("^%-%-?"); + if not prefix then + break; + elseif prefix == "--" and raw_param == "--" then + table.remove(arg, 1); + break; + end + local param = table.remove(arg, 1):sub(#prefix+1); + if #param == 1 and short_params then + param = short_params[param]; + end + + if not param then + return nil, "param-not-found", raw_param; + end + + local param_k, param_v; + if value_params[param] then + param_k, param_v = param, table.remove(arg, 1); + if not param_v then + return nil, "missing-value", raw_param; + end + else + param_k, param_v = param:match("^([^=]+)=(.+)$"); + if not param_k then + if param:match("^no%-") then + param_k, param_v = param:sub(4), false; + else + param_k, param_v = param, true; + end + end + end + parsed_opts[param_k] = param_v; + end + for i = 1, #arg do + parsed_opts[i] = arg[i]; + end + return parsed_opts; +end + +return { + parse = parse; +} diff --git a/util/array.lua b/util/array.lua index 0b60a4fd..c33a5ef1 100644 --- a/util/array.lua +++ b/util/array.lua @@ -10,6 +10,7 @@ local t_insert, t_sort, t_remove, t_concat = table.insert, table.sort, table.remove, table.concat; local setmetatable = setmetatable; +local getmetatable = getmetatable; local math_random = math.random; local math_floor = math.floor; local pairs, ipairs = pairs, ipairs; @@ -40,6 +41,10 @@ function array_mt.__add(a1, a2) end function array_mt.__eq(a, b) + if getmetatable(a) ~= array_mt or getmetatable(b) ~= array_mt then + -- Lua 5.3+ calls this if both operands are tables, even if metatables differ + return false; + end if #a == #b then for i = 1, #a do if a[i] ~= b[i] then @@ -109,6 +114,40 @@ function array_base.filter(outa, ina, func) return outa; end +function array_base.slice(outa, ina, i, j) + if j == nil then + j = -1; + end + if j < 0 then + j = #ina + (j+1); + end + if i < 0 then + i = #ina + (i+1); + end + if i < 1 then + i = 1; + end + if j > #ina then + j = #ina; + end + if i > j then + for idx = 1, #outa do + outa[idx] = nil; + end + return outa; + end + + for idx = 1, 1+j-i do + outa[idx] = ina[i+(idx-1)]; + end + if ina == outa then + for idx = 2+j-i, #outa do + outa[idx] = nil; + end + end + return outa; +end + function array_base.sort(outa, ina, ...) if ina ~= outa then outa:append(ina); @@ -129,9 +168,13 @@ function array_base.unique(outa, ina) end); end -function array_base.pluck(outa, ina, key) +function array_base.pluck(outa, ina, key, default) for i = 1, #ina do - outa[i] = ina[i][key]; + local v = ina[i][key]; + if v == nil then + v = default; + end + outa[i] = v; end return outa; end diff --git a/util/async.lua b/util/async.lua index 20397785..2830238f 100644 --- a/util/async.lua +++ b/util/async.lua @@ -11,6 +11,12 @@ local function checkthread() return thread; end +-- Configurable functions +local schedule_task = nil; -- schedule_task(seconds, callback) +local next_tick = function (f) + f(); +end + local function runner_from_thread(thread) local level = 0; -- Find the 'level' of the top-most function (0 == current level, 1 == caller, ...) @@ -53,19 +59,21 @@ local function runner_continue(thread) return false; end call_watcher(runner, "error", debug.traceback(thread, err)); - runner.state, runner.thread = "ready", nil; + runner.state = "ready"; return runner:run(); elseif state == "ready" then -- If state is 'ready', it is our responsibility to update runner.state from 'waiting'. -- We also have to :run(), because the queue might have further items that will not be -- processed otherwise. FIXME: It's probably best to do this in a nexttick (0 timer). - runner.state = "ready"; - runner:run(); + next_tick(function () + runner.state = "ready"; + runner:run(); + end); end return true; end -local function waiter(num) +local function waiter(num, allow_many) local thread = checkthread(); num = num or 1; local waiting; @@ -77,7 +85,7 @@ local function waiter(num) num = num - 1; if num == 0 and waiting then runner_continue(thread); - elseif num < 0 then + elseif not allow_many and num < 0 then error("done() called too many times"); end end; @@ -118,6 +126,15 @@ local function guarder() end; end +local function sleep(seconds) + if not schedule_task then + error("async.sleep() is not available - configure schedule function"); + end + local wait, done = waiter(); + schedule_task(seconds, done); + wait(); +end + local runner_mt = {}; runner_mt.__index = runner_mt; @@ -159,6 +176,10 @@ function runner_mt:run(input) local q, thread = self.queue, self.thread; if not thread or coroutine.status(thread) == "dead" then + --luacheck: ignore 143/coroutine + if thread and coroutine.close then + coroutine.close(thread); + end self:log("debug", "creating new coroutine"); -- Create a new coroutine for this runner thread = runner_create_thread(self.func, self); @@ -246,9 +267,30 @@ local function ready() return pcall(checkthread); end +local function wait_for(promise) + local async_wait, async_done = waiter(); + local ret, err = nil, nil; + promise:next( + function (r) ret = r; end, + function (e) err = e; end) + :finally(async_done); + async_wait(); + if ret then + return ret; + else + return nil, err; + end +end + return { ready = ready; waiter = waiter; guarder = guarder; runner = runner; + wait = wait_for; -- COMPAT w/trunk pre-0.12 + wait_for = wait_for; + sleep = sleep; + + set_nexttick = function(new_next_tick) next_tick = new_next_tick; end; + set_schedule_function = function (new_schedule_function) schedule_task = new_schedule_function; end; }; diff --git a/util/bit53.lua b/util/bit53.lua new file mode 100644 index 00000000..b5c473a3 --- /dev/null +++ b/util/bit53.lua @@ -0,0 +1,33 @@ +-- Only the operators needed by net.websocket.frames are provided at this point +return { + band = function (a, b, ...) + local ret = a & b; + if ... then + for i = 1, select("#", ...) do + ret = ret & (select(i, ...)); + end + end + return ret; + end; + bor = function (a, b, ...) + local ret = a | b; + if ... then + for i = 1, select("#", ...) do + ret = ret | (select(i, ...)); + end + end + return ret; + end; + bxor = function (a, b, ...) + local ret = a ~ b; + if ... then + for i = 1, select("#", ...) do + ret = ret ~ (select(i, ...)); + end + end + return ret; + end; + rshift = function (a, n) return a >> n end; + lshift = function (a, n) return a << n end; +}; + diff --git a/util/bitcompat.lua b/util/bitcompat.lua new file mode 100644 index 00000000..454181af --- /dev/null +++ b/util/bitcompat.lua @@ -0,0 +1,32 @@ +-- Compatibility layer for bitwise operations + +-- First try the bit32 lib +-- Lua 5.3 has it with compat enabled +-- Lua 5.2 has it by default +if _G.bit32 then + return _G.bit32; +else + -- Lua 5.1 may have it as a standalone module that can be installed + local ok, bitop = pcall(require, "bit32") + if ok then + return bitop; + end +end + +do + -- Lua 5.3 and 5.4 would be able to use native infix operators + local ok, bitop = pcall(require, "util.bit53") + if ok then + return bitop; + end +end + +do + -- Lastly, try the LuaJIT bitop library + local ok, bitop = pcall(require, "bit") + if ok then + return bitop; + end +end + +error "No bit module found. See https://prosody.im/doc/depends#bitop"; diff --git a/util/cache.lua b/util/cache.lua index a5fd5e6d..cd1b4544 100644 --- a/util/cache.lua +++ b/util/cache.lua @@ -28,7 +28,7 @@ local function _insert(list, m) end local cache_methods = {}; -local cache_mt = { __index = cache_methods }; +local cache_mt = { __name = "cache", __index = cache_methods }; function cache_methods:set(k, v) local m = self._data[k]; diff --git a/util/dataforms.lua b/util/dataforms.lua index 052d6a55..66733895 100644 --- a/util/dataforms.lua +++ b/util/dataforms.lua @@ -10,9 +10,11 @@ local setmetatable = setmetatable; local ipairs = ipairs; local type, next = type, next; local tonumber = tonumber; +local tostring = tostring; local t_concat = table.concat; local st = require "util.stanza"; local jid_prep = require "util.jid".prep; +local datetime = require "util.datetime"; local _ENV = nil; -- luacheck: std none @@ -54,6 +56,12 @@ function form_t.form(layout, data, formtype) if formtype == "form" and field.datatype then form:tag("validate", { xmlns = xmlns_validate, datatype = field.datatype }); + if field.range_min or field.range_max then + form:tag("range", { + min = field.range_min and tostring(field.range_min), + max = field.range_max and tostring(field.range_max), + }):up(); + end -- <basic/> assumed form:up(); end @@ -95,8 +103,15 @@ function form_t.form(layout, data, formtype) if value ~= nil then if type(value) == "number" then - -- TODO validate that this is ok somehow, eg check field.datatype - value = ("%g"):format(value); + if field.datatype == "xs:dateTime" then + value = datetime.datetime(value); + elseif field_type == "boolean" then + value = value ~= 0; + elseif field.datatype == "xs:double" or field.datatype == "xs:decimal" then + value = ("%f"):format(value); + else + value = ("%d"):format(value); + end end -- Add value, depending on type if field_type == "hidden" then @@ -136,7 +151,7 @@ function form_t.form(layout, data, formtype) local media = field.media; if media then - form:tag("media", { xmlns = "urn:xmpp:media-element", height = media.height, width = media.width }); + form:tag("media", { xmlns = "urn:xmpp:media-element", height = ("%d"):format(media.height), width = ("%d"):format(media.width) }); for _, val in ipairs(media) do form:tag("uri", { type = val.type }):text(val.uri):up() end @@ -290,13 +305,34 @@ field_readers["hidden"] = end data_validators["xs:integer"] = - function (data) + function (data, field) local n = tonumber(data); if not n then return false, "not a number"; elseif n % 1 ~= 0 then return false, "not an integer"; end + if field.range_max and n > field.range_max then + return false, "out of bounds"; + elseif field.range_min and n < field.range_min then + return false, "out of bounds"; + end + return true, n; + end + +data_validators["pubsub:integer-or-max"] = + function (data, field) + if data == "max" then + return true, data; + else + return data_validators["xs:integer"](data, field); + end + end + +data_validators["xs:dateTime"] = + function(data, field) -- luacheck: ignore 212/field + local n = datetime.parse(data); + if not n then return false, "invalid timestamp"; end return true, n; end diff --git a/util/datamanager.lua b/util/datamanager.lua index 0d7060b7..c57f4a0e 100644 --- a/util/datamanager.lua +++ b/util/datamanager.lua @@ -24,7 +24,7 @@ local t_concat = table.concat; local envloadfile = require"util.envload".envloadfile; local serialize = require "util.serialization".serialize; local lfs = require "lfs"; --- Extract directory seperator from package.config (an undocumented string that comes with lua) +-- Extract directory separator from package.config (an undocumented string that comes with lua) local path_separator = assert ( package.config:match ( "^([^\n]+)" ) , "package.config not in standard form" ) local prosody = prosody; @@ -157,7 +157,8 @@ end local function atomic_store(filename, data) local scratch = filename.."~"; - local f, ok, msg, errno; + local f, ok, msg, errno; -- luacheck: ignore errno + -- TODO return util.error with code=errno? f, msg, errno = io_open(scratch, "w"); if not f then @@ -221,7 +222,7 @@ local function store(username, host, datastore, data) os_remove(getpath(username, host, datastore)); end -- we write data even when we are deleting because lua doesn't have a - -- platform independent way of checking for non-exisitng files + -- platform independent way of checking for nonexisting files until ok; return true; end @@ -289,7 +290,7 @@ local function list_store(username, host, datastore, data) os_remove(getpath(username, host, datastore, "list")); end -- we write data even when we are deleting because lua doesn't have a - -- platform independent way of checking for non-exisitng files + -- platform independent way of checking for nonexisting files return true; end @@ -319,7 +320,7 @@ local type_map = { } local function users(host, store, typ) -- luacheck: ignore 431/store - typ = type_map[typ or "keyval"]; + typ = "."..(type_map[typ or "keyval"] or typ); local store_dir = format("%s/%s/%s", data_path, encode(host), store_encode(store)); local mode, err = lfs.attributes(store_dir, "mode"); @@ -329,9 +330,8 @@ local function users(host, store, typ) -- luacheck: ignore 431/store local next, state = lfs.dir(store_dir); -- luacheck: ignore 431/next 431/state return function(state) -- luacheck: ignore 431/state for node in next, state do - local file, ext = node:match("^(.*)%.([dalist]+)$"); - if file and ext == typ then - return decode(file); + if node:sub(-#typ, -1) == typ then + return decode(node:sub(1, -#typ-1)); end end end, state; @@ -343,7 +343,7 @@ local function stores(username, host, typ) local mode, err = lfs.attributes(store_dir, "mode"); if not mode then - return function() log("debug", err or (store_dir .. " does not exist")) end + return function() log("debug", "Could not iterate over stores in %s: %s", store_dir, err); end end local next, state = lfs.dir(store_dir); -- luacheck: ignore 431/next 431/state return function(state) -- luacheck: ignore 431/state diff --git a/util/datamapper.lua b/util/datamapper.lua new file mode 100644 index 00000000..2378314c --- /dev/null +++ b/util/datamapper.lua @@ -0,0 +1,349 @@ +-- This file is generated from teal-src/util/datamapper.lua + +local st = require("util.stanza"); +local pointer = require("util.jsonpointer"); + +local schema_t = {} + +local function toboolean(s) + if s == "true" or s == "1" then + return true + elseif s == "false" or s == "0" then + return false + elseif s then + return true + end +end + +local function totype(t, s) + if not s then + return nil + end + if t == "string" then + return s + elseif t == "boolean" then + return toboolean(s) + elseif t == "number" or t == "integer" then + return tonumber(s) + end +end + +local value_goes = {} + +local function resolve_schema(schema, root) + if type(schema) == "table" then + if schema["$ref"] and schema["$ref"]:sub(1, 1) == "#" then + return pointer.resolve(root, schema["$ref"]:sub(2)) + end + end + return schema +end + +local function guess_schema_type(schema) + local schema_types = schema.type + if type(schema_types) == "string" then + return schema_types + elseif schema_types ~= nil then + error("schema has unsupported 'type' property") + elseif schema.properties then + return "object" + elseif schema.items then + return "array" + end + return "string" +end + +local function unpack_propschema(propschema, propname, current_ns) + + local proptype = "string" + local value_where = propname and "in_text_tag" or "in_text" + local name = propname + local namespace + local prefix + local single_attribute + local enums + + if type(propschema) == "table" then + proptype = guess_schema_type(propschema); + elseif type(propschema) == "string" then + error("schema as string is not supported: " .. propschema .. " {" .. current_ns .. "}" .. propname) + end + + if proptype == "object" or proptype == "array" then + value_where = "in_children" + end + + if type(propschema) == "table" then + local xml = propschema.xml + if xml then + if xml.name then + name = xml.name + end + if xml.namespace and xml.namespace ~= current_ns then + namespace = xml.namespace + end + if xml.prefix then + prefix = xml.prefix + end + if proptype == "array" and xml.wrapped then + value_where = "in_wrapper" + elseif xml.attribute then + value_where = "in_attribute" + elseif xml.text then + value_where = "in_text" + elseif xml.x_name_is_value then + value_where = "in_tag_name" + elseif xml.x_single_attribute then + single_attribute = xml.x_single_attribute + value_where = "in_single_attribute" + end + end + if propschema["const"] then + enums = {propschema["const"]} + elseif propschema["enum"] then + enums = propschema["enum"] + end + end + + return proptype, value_where, name, namespace, prefix, single_attribute, enums +end + +local parse_object +local parse_array + +local function extract_value(s, value_where, proptype, name, namespace, prefix, single_attribute, enums) + if value_where == "in_tag_name" then + local c + if proptype == "boolean" then + c = s:get_child(name, namespace); + elseif enums and proptype == "string" then + + for i = 1, #enums do + c = s:get_child(enums[i], namespace); + if c then + break + end + end + else + c = s:get_child(nil, namespace); + end + if c then + return c.name + end + elseif value_where == "in_attribute" then + local attr = name + if prefix then + attr = prefix .. ":" .. name + elseif namespace and namespace ~= s.attr.xmlns then + attr = namespace .. "\1" .. name + end + return s.attr[attr] + + elseif value_where == "in_text" then + return s:get_text() + + elseif value_where == "in_single_attribute" then + local c = s:get_child(name, namespace) + return c and c.attr[single_attribute] + elseif value_where == "in_text_tag" then + return s:get_child_text(name, namespace) + end +end + +function parse_object(schema, s, root) + local out = {} + schema = resolve_schema(schema, root) + if type(schema) == "table" and schema.properties then + for prop, propschema in pairs(schema.properties) do + propschema = resolve_schema(propschema, root) + + local proptype, value_where, name, namespace, prefix, single_attribute, enums = unpack_propschema(propschema, prop, s.attr.xmlns) + + if value_where == "in_children" and type(propschema) == "table" then + if proptype == "object" then + local c = s:get_child(name, namespace) + if c then + out[prop] = parse_object(propschema, c, root); + end + elseif proptype == "array" then + local a = parse_array(propschema, s, root); + if a and a[1] ~= nil then + out[prop] = a; + end + else + error("unreachable") + end + elseif value_where == "in_wrapper" and type(propschema) == "table" and proptype == "array" then + local wrapper = s:get_child(name, namespace); + if wrapper then + out[prop] = parse_array(propschema, wrapper, root); + end + else + local value = extract_value(s, value_where, proptype, name, namespace, prefix, single_attribute, enums) + + out[prop] = totype(proptype, value) + end + end + end + + return out +end + +function parse_array(schema, s, root) + local itemschema = resolve_schema(schema.items, root); + local proptype, value_where, child_name, namespace, prefix, single_attribute, enums = unpack_propschema(itemschema, nil, s.attr.xmlns) + local attr_name + if value_where == "in_single_attribute" then + value_where = "in_attribute"; + attr_name = single_attribute; + end + local out = {} + + if proptype == "object" then + if type(itemschema) == "table" then + for c in s:childtags(child_name, namespace) do + table.insert(out, parse_object(itemschema, c, root)); + end + else + error("array items must be schema object") + end + elseif proptype == "array" then + if type(itemschema) == "table" then + for c in s:childtags(child_name, namespace) do + table.insert(out, parse_array(itemschema, c, root)); + end + end + else + for c in s:childtags(child_name, namespace) do + local value = extract_value(c, value_where, proptype, attr_name or child_name, namespace, prefix, single_attribute, enums) + + table.insert(out, totype(proptype, value)); + end + end + return out +end + +local function parse(schema, s) + local s_type = guess_schema_type(schema) + if s_type == "object" then + return parse_object(schema, s, schema) + elseif s_type == "array" then + return parse_array(schema, s, schema) + else + error("top-level scalars unsupported") + end +end + +local function toxmlstring(proptype, v) + if proptype == "string" and type(v) == "string" then + return v + elseif proptype == "number" and type(v) == "number" then + return string.format("%g", v) + elseif proptype == "integer" and type(v) == "number" then + return string.format("%d", v) + elseif proptype == "boolean" then + return v and "1" or "0" + end +end + +local unparse + +local function unparse_property(out, v, proptype, propschema, value_where, name, namespace, current_ns, prefix, + single_attribute, root) + + if value_where == "in_attribute" then + local attr = name + if prefix then + attr = prefix .. ":" .. name + elseif namespace and namespace ~= current_ns then + attr = namespace .. "\1" .. name + end + + out.attr[attr] = toxmlstring(proptype, v) + elseif value_where == "in_text" then + out:text(toxmlstring(proptype, v)) + elseif value_where == "in_single_attribute" then + assert(single_attribute) + local propattr = {} + + if namespace and namespace ~= current_ns then + propattr.xmlns = namespace + end + + propattr[single_attribute] = toxmlstring(proptype, v) + out:tag(name, propattr):up(); + + else + local propattr + if namespace ~= current_ns then + propattr = {xmlns = namespace} + end + if value_where == "in_tag_name" then + if proptype == "string" and type(v) == "string" then + out:tag(v, propattr):up(); + elseif proptype == "boolean" and v == true then + out:tag(name, propattr):up(); + end + elseif proptype == "object" and type(propschema) == "table" and type(v) == "table" then + local c = unparse(propschema, v, name, namespace, nil, root); + if c then + out:add_direct_child(c); + end + elseif proptype == "array" and type(propschema) == "table" and type(v) == "table" then + if value_where == "in_wrapper" then + local c = unparse(propschema, v, name, namespace, nil, root); + if c then + out:add_direct_child(c); + end + else + unparse(propschema, v, name, namespace, out, root); + end + else + out:text_tag(name, toxmlstring(proptype, v), propattr) + end + end +end + +function unparse(schema, t, current_name, current_ns, ctx, root) + + if root == nil then + root = schema + end + + if schema.xml then + if schema.xml.name then + current_name = schema.xml.name + end + if schema.xml.namespace then + current_ns = schema.xml.namespace + end + + end + + local out = ctx or st.stanza(current_name, {xmlns = current_ns}) + + local s_type = guess_schema_type(schema) + if s_type == "object" then + + for prop, propschema in pairs(schema.properties) do + propschema = resolve_schema(propschema, root) + local v = t[prop] + + if v ~= nil then + local proptype, value_where, name, namespace, prefix, single_attribute = unpack_propschema(propschema, prop, current_ns) + unparse_property(out, v, proptype, propschema, value_where, name, namespace, current_ns, prefix, single_attribute, root) + end + end + return out + + elseif s_type == "array" then + local itemschema = resolve_schema(schema.items, root) + local proptype, value_where, name, namespace, prefix, single_attribute = unpack_propschema(itemschema, current_name, current_ns) + for _, item in ipairs(t) do + unparse_property(out, item, proptype, itemschema, value_where, name, namespace, current_ns, prefix, single_attribute, root) + end + return out + end +end + +return {parse = parse; unparse = unparse} diff --git a/util/dbuffer.lua b/util/dbuffer.lua index 640c1449..3ad5fdfe 100644 --- a/util/dbuffer.lua +++ b/util/dbuffer.lua @@ -2,7 +2,7 @@ local queue = require "util.queue"; local s_byte, s_sub = string.byte, string.sub; local dbuffer_methods = {}; -local dynamic_buffer_mt = { __index = dbuffer_methods }; +local dynamic_buffer_mt = { __name = "dbuffer", __index = dbuffer_methods }; function dbuffer_methods:write(data) if self.max_size and #data + self._length > self.max_size then @@ -76,6 +76,20 @@ function dbuffer_methods:read(requested_bytes) return table.concat(chunks); end +-- Read to, and including, the specified character sequence (return nil if not found) +function dbuffer_methods:read_until(char) + local buffer_pos = 0; + for i, chunk in self.items:items() do + local start = 1 + ((i == 1) and self.front_consumed or 0); + local char_pos = chunk:find(char, start, true); + if char_pos then + return self:read(1 + buffer_pos + char_pos - start); + end + buffer_pos = buffer_pos + #chunk - (start - 1); + end + return nil; +end + function dbuffer_methods:discard(requested_bytes) if requested_bytes > self._length then return nil; diff --git a/util/dependencies.lua b/util/dependencies.lua index 24975567..d7836404 100644 --- a/util/dependencies.lua +++ b/util/dependencies.lua @@ -7,24 +7,22 @@ -- local function softreq(...) local ok, lib = pcall(require, ...); if ok then return lib; else return nil, lib; end end +local platform_table = require "util.human.io".table({ { width = 15, align = "right" }, { width = "100%" } }); -- Required to be able to find packages installed with luarocks if not softreq "luarocks.loader" then -- LuaRocks 2.x softreq "luarocks.require"; -- LuaRocks <1.x end -local function missingdep(name, sources, msg) +local function missingdep(name, sources, msg, err) -- luacheck: ignore err + -- TODO print something about the underlying error, useful for debugging print(""); print("**************************"); print("Prosody was unable to find "..tostring(name)); print("This package can be obtained in the following ways:"); print(""); - local longest_platform = 0; - for platform in pairs(sources) do - longest_platform = math.max(longest_platform, #platform); - end - for platform, source in pairs(sources) do - print("", platform..":"..(" "):rep(4+longest_platform-#platform)..source); + for _, row in ipairs(sources) do + print(platform_table(row)); end print(""); print(msg or (name.." is required for Prosody to run, so we will now exit.")); @@ -44,25 +42,25 @@ local function check_dependencies() local fatal; - local lxp = softreq "lxp" + local lxp, err = softreq "lxp" if not lxp then missingdep("luaexpat", { - ["Debian/Ubuntu"] = "sudo apt-get install lua-expat"; - ["luarocks"] = "luarocks install luaexpat"; - ["Source"] = "http://matthewwild.co.uk/projects/luaexpat/"; - }); + { "Debian/Ubuntu", "sudo apt install lua-expat" }; + { "luarocks", "luarocks install luaexpat" }; + { "Source", "http://matthewwild.co.uk/projects/luaexpat/" }; + }, nil, err); fatal = true; end - local socket = softreq "socket" + local socket, err = softreq "socket" if not socket then missingdep("luasocket", { - ["Debian/Ubuntu"] = "sudo apt-get install lua-socket"; - ["luarocks"] = "luarocks install luasocket"; - ["Source"] = "http://www.tecgraf.puc-rio.br/~diego/professional/luasocket/"; - }); + { "Debian/Ubuntu", "sudo apt install lua-socket" }; + { "luarocks", "luarocks install luasocket" }; + { "Source", "http://www.tecgraf.puc-rio.br/~diego/professional/luasocket/" }; + }, nil, err); fatal = true; elseif not socket.tcp4 then -- COMPAT LuaSocket before being IP-version agnostic @@ -73,39 +71,53 @@ local function check_dependencies() local lfs, err = softreq "lfs" if not lfs then missingdep("luafilesystem", { - ["luarocks"] = "luarocks install luafilesystem"; - ["Debian/Ubuntu"] = "sudo apt-get install lua-filesystem"; - ["Source"] = "http://www.keplerproject.org/luafilesystem/"; - }); + { "luarocks", "luarocks install luafilesystem" }; + { "Debian/Ubuntu", "sudo apt install lua-filesystem" }; + { "Source", "http://www.keplerproject.org/luafilesystem/" }; + }, nil, err); fatal = true; end - local ssl = softreq "ssl" + local ssl, err = softreq "ssl" if not ssl then missingdep("LuaSec", { - ["Debian/Ubuntu"] = "sudo apt-get install lua-sec"; - ["luarocks"] = "luarocks install luasec"; - ["Source"] = "https://github.com/brunoos/luasec"; - }, "SSL/TLS support will not be available"); + { "Debian/Ubuntu", "sudo apt install lua-sec" }; + { "luarocks", "luarocks install luasec" }; + { "Source", "https://github.com/brunoos/luasec" }; + }, nil, err); end - local bit = softreq"bit" or softreq"bit32"; + local bit, err = softreq"util.bitcompat"; if not bit then missingdep("lua-bitops", { - ["Debian/Ubuntu"] = "sudo apt-get install lua-bitop"; - ["luarocks"] = "luarocks install luabitop"; - ["Source"] = "http://bitop.luajit.org/"; - }, "WebSocket support will not be available"); + { "Debian/Ubuntu", "sudo apt install lua-bitop" }; + { "luarocks", "luarocks install luabitop" }; + { "Source", "http://bitop.luajit.org/" }; + }, "WebSocket support will not be available", err); + end + + local unbound, err = softreq"lunbound"; -- luacheck: ignore 211/err + if not unbound then + missingdep("lua-unbound", { + { "Debian/Ubuntu", "sudo apt install lua-unbound" }; + { "luarocks", "luarocks install luaunbound" }; + { "Source", "https://www.zash.se/luaunbound.html" }; + }, "Old DNS resolver library will be used", err); + else + package.preload["net.adns"] = function () + local ub = require "net.unbound"; + return ub; + end end local encodings, err = softreq "util.encodings" if not encodings then if err:match("module '[^']*' not found") then missingdep("util.encodings", { - ["Windows"] = "Make sure you have encodings.dll from the Prosody distribution in util/"; - ["GNU/Linux"] = "Run './configure' and 'make' in the Prosody source directory to build util/encodings.so"; + { "Windows", "Make sure you have encodings.dll from the Prosody distribution in util/" }; + { "GNU/Linux", "Run './configure' and 'make' in the Prosody source directory to build util/encodings.so" }; }); else print "***********************************" @@ -122,8 +134,8 @@ local function check_dependencies() if not hashes then if err:match("module '[^']*' not found") then missingdep("util.hashes", { - ["Windows"] = "Make sure you have hashes.dll from the Prosody distribution in util/"; - ["GNU/Linux"] = "Run './configure' and 'make' in the Prosody source directory to build util/hashes.so"; + { "Windows", "Make sure you have hashes.dll from the Prosody distribution in util/" }; + { "GNU/Linux", "Run './configure' and 'make' in the Prosody source directory to build util/hashes.so" }; }); else print "***********************************" @@ -140,8 +152,10 @@ local function check_dependencies() end local function log_warnings() - if _VERSION > "Lua 5.2" then + if _VERSION > "Lua 5.4" then prosody.log("warn", "Support for %s is experimental, please report any issues", _VERSION); + elseif _VERSION < "Lua 5.2" then + prosody.log("warn", "%s has several issues and support is being phased out, consider upgrading", _VERSION); end local ssl = softreq"ssl"; if ssl then diff --git a/util/dns.lua b/util/dns.lua new file mode 100644 index 00000000..3b58e03e --- /dev/null +++ b/util/dns.lua @@ -0,0 +1,242 @@ +-- libunbound based net.adns replacement for Prosody IM +-- Copyright (C) 2012-2015 Kim Alvefur +-- Copyright (C) 2012 Waqas Hussain +-- +-- This file is MIT licensed. + +local setmetatable = setmetatable; +local table = table; +local t_concat = table.concat; +local t_insert = table.insert; +local s_byte = string.byte; +local s_format = string.format; +local s_sub = string.sub; + +local iana_data = require "util.dnsregistry"; +local tohex = require "util.hex".encode; +local inet_ntop = require "util.net".ntop; + +-- Simplified versions of Waqas DNS parsers +-- Only the per RR parsers are needed and only feed a single RR + +local parsers = {}; + +-- No support for pointers, but libunbound appears to take care of that. +local function readDnsName(packet, pos) + if s_byte(packet, pos) == 0 then return ".", pos+1; end + local pack_len, r, len = #packet, {}; + pos = pos or 1; + repeat + len = s_byte(packet, pos) or 0; + t_insert(r, s_sub(packet, pos + 1, pos + len)); + pos = pos + len + 1; + until len == 0 or pos >= pack_len; + return t_concat(r, "."), pos; +end + +-- These are just simple names. +parsers.CNAME = readDnsName; +parsers.NS = readDnsName +parsers.PTR = readDnsName; + +local soa_mt = { + __tostring = function(rr) + return s_format("%s %s %d %d %d %d %d", rr.mname, rr.rname, rr.serial, rr.refresh, rr.retry, rr.expire, rr.minimum); + end; +}; +function parsers.SOA(packet) + local mname, rname, offset; + + mname, offset = readDnsName(packet, 1); + rname, offset = readDnsName(packet, offset); + + -- Extract all the bytes of these fields in one call + local + s1, s2, s3, s4, -- serial + r1, r2, r3, r4, -- refresh + t1, t2, t3, t4, -- retry + e1, e2, e3, e4, -- expire + m1, m2, m3, m4 -- minimum + = s_byte(packet, offset, offset + 19); + + return setmetatable({ + mname = mname; + rname = rname; + serial = s1*0x1000000 + s2*0x10000 + s3*0x100 + s4; + refresh = r1*0x1000000 + r2*0x10000 + r3*0x100 + r4; + retry = t1*0x1000000 + t2*0x10000 + t3*0x100 + t4; + expire = e1*0x1000000 + e2*0x10000 + e3*0x100 + e4; + minimum = m1*0x1000000 + m2*0x10000 + m3*0x100 + m4; + }, soa_mt); +end + +parsers.A = inet_ntop; +parsers.AAAA = inet_ntop; + +local mx_mt = { + __tostring = function(rr) + return s_format("%d %s", rr.pref, rr.mx) + end +}; +function parsers.MX(packet) + local name = readDnsName(packet, 3); + local b1,b2 = s_byte(packet, 1, 2); + return setmetatable({ + pref = b1*256+b2; + mx = name; + }, mx_mt); +end + +local srv_mt = { + __tostring = function(rr) + return s_format("%d %d %d %s", rr.priority, rr.weight, rr.port, rr.target); + end +}; +function parsers.SRV(packet) + local name = readDnsName(packet, 7); + local b1, b2, b3, b4, b5, b6 = s_byte(packet, 1, 6); + return setmetatable({ + priority = b1*256+b2; + weight = b3*256+b4; + port = b5*256+b6; + target = name; + }, srv_mt); +end + +local txt_mt = { __tostring = t_concat }; +function parsers.TXT(packet) + local pack_len = #packet; + local r, pos, len = {}, 1; + repeat + len = s_byte(packet, pos) or 0; + t_insert(r, s_sub(packet, pos + 1, pos + len)); + pos = pos + len + 1; + until pos >= pack_len; + return setmetatable(r, txt_mt); +end + +parsers.SPF = parsers.TXT; + +-- Acronyms from RFC 7218 +local tlsa_usages = { + [0] = "PKIX-CA"; + [1] = "PKIX-EE"; + [2] = "DANE-TA"; + [3] = "DANE-EE"; + [255] = "PrivCert"; +}; +local tlsa_selectors = { + [0] = "Cert", + [1] = "SPKI", + [255] = "PrivSel", +}; +local tlsa_match_types = { + [0] = "Full", + [1] = "SHA2-256", + [2] = "SHA2-512", + [255] = "PrivMatch", +}; +local tlsa_mt = { + __tostring = function(rr) + return s_format("%s %s %s %s", + tlsa_usages[rr.use] or rr.use, + tlsa_selectors[rr.select] or rr.select, + tlsa_match_types[rr.match] or rr.match, + tohex(rr.data)); + end; + __index = { + getUsage = function(rr) return tlsa_usages[rr.use] end; + getSelector = function(rr) return tlsa_selectors[rr.select] end; + getMatchType = function(rr) return tlsa_match_types[rr.match] end; + } +}; +function parsers.TLSA(packet) + local use, select, match = s_byte(packet, 1,3); + return setmetatable({ + use = use; + select = select; + match = match; + data = s_sub(packet, 4); + }, tlsa_mt); +end + +local svcb_params = {"alpn"; "no-default-alpn"; "port"; "ipv4hint"; "ech"; "ipv6hint"}; +setmetatable(svcb_params, {__index = function(_, n) return "key" .. tostring(n); end}); + +local svcb_mt = { + __tostring = function (rr) + local kv = {}; + for i = 1, #rr.fields do + t_insert(kv, s_format("%s=%q", svcb_params[rr.fields[i].key], tostring(rr.fields[i].value))); + -- FIXME the =value part may be omitted when the value is "empty" + end + return s_format("%d %s %s", rr.prio, rr.name, t_concat(kv, " ")); + end; +}; +local svbc_ip_mt = {__tostring = function(ip) return t_concat(ip, ", "); end} + +function parsers.SVCB(packet) + local prio_h, prio_l = packet:byte(1,2); + local prio = prio_h*256+prio_l; + local name, pos = readDnsName(packet, 3); + local fields = {}; + while #packet > pos do + local key_h, key_l = packet:byte(pos+0,pos+1); + local len_h, len_l = packet:byte(pos+2,pos+3); + local key = key_h*256+key_l; + local len = len_h*256+len_l; + local value = packet:sub(pos+4,pos+4-1+len) + if key == 1 then + value = setmetatable(parsers.TXT(value), svbc_ip_mt); + elseif key == 3 then + local port_h, port_l = value:byte(1,2); + local port = port_h*256+port_l; + value = port; + elseif key == 4 then + local ip = {}; + for i = 1, #value, 4 do + t_insert(ip, parsers.A(value:sub(i, i+3))); + end + value = setmetatable(ip, svbc_ip_mt); + elseif key == 6 then + local ip = {}; + for i = 1, #value, 16 do + t_insert(ip, parsers.AAAA(value:sub(i, i+15))); + end + value = setmetatable(ip, svbc_ip_mt); + end + t_insert(fields, { key = key, value = value, len = len }); + pos = pos+len+4; + end + return setmetatable({ + prio = prio, name = name, fields = fields, + }, svcb_mt); +end + +parsers.HTTPS = parsers.SVCB; + +local params = { + TLSA = { + use = tlsa_usages; + select = tlsa_selectors; + match = tlsa_match_types; + }; +}; + +local fallback_mt = { + __tostring = function(rr) + return s_format([[\# %d %s]], #rr.raw, tohex(rr.raw)); + end; +}; +local function fallback_parser(packet) + return setmetatable({ raw = packet },fallback_mt); +end +setmetatable(parsers, { __index = function() return fallback_parser end }); + +return { + parsers = parsers; + classes = iana_data.classes; + types = iana_data.types; + errors = iana_data.errors; + params = params; +}; diff --git a/util/dnsregistry.lua b/util/dnsregistry.lua new file mode 100644 index 00000000..635b7e3a --- /dev/null +++ b/util/dnsregistry.lua @@ -0,0 +1,122 @@ +-- Source: https://www.iana.org/assignments/dns-parameters/dns-parameters.xml +-- Generated on 2022-02-02 +return { + classes = { + ["IN"] = 1; [1] = "IN"; + ["CH"] = 3; [3] = "CH"; + ["HS"] = 4; [4] = "HS"; + ["ANY"] = 255; [255] = "ANY"; + }; + types = { + ["A"] = 1; [1] = "A"; + ["NS"] = 2; [2] = "NS"; + ["MD"] = 3; [3] = "MD"; + ["MF"] = 4; [4] = "MF"; + ["CNAME"] = 5; [5] = "CNAME"; + ["SOA"] = 6; [6] = "SOA"; + ["MB"] = 7; [7] = "MB"; + ["MG"] = 8; [8] = "MG"; + ["MR"] = 9; [9] = "MR"; + ["NULL"] = 10; [10] = "NULL"; + ["WKS"] = 11; [11] = "WKS"; + ["PTR"] = 12; [12] = "PTR"; + ["HINFO"] = 13; [13] = "HINFO"; + ["MINFO"] = 14; [14] = "MINFO"; + ["MX"] = 15; [15] = "MX"; + ["TXT"] = 16; [16] = "TXT"; + ["RP"] = 17; [17] = "RP"; + ["AFSDB"] = 18; [18] = "AFSDB"; + ["X25"] = 19; [19] = "X25"; + ["ISDN"] = 20; [20] = "ISDN"; + ["RT"] = 21; [21] = "RT"; + ["NSAP"] = 22; [22] = "NSAP"; + ["NSAP-PTR"] = 23; [23] = "NSAP-PTR"; + ["SIG"] = 24; [24] = "SIG"; + ["KEY"] = 25; [25] = "KEY"; + ["PX"] = 26; [26] = "PX"; + ["GPOS"] = 27; [27] = "GPOS"; + ["AAAA"] = 28; [28] = "AAAA"; + ["LOC"] = 29; [29] = "LOC"; + ["NXT"] = 30; [30] = "NXT"; + ["EID"] = 31; [31] = "EID"; + ["NIMLOC"] = 32; [32] = "NIMLOC"; + ["SRV"] = 33; [33] = "SRV"; + ["ATMA"] = 34; [34] = "ATMA"; + ["NAPTR"] = 35; [35] = "NAPTR"; + ["KX"] = 36; [36] = "KX"; + ["CERT"] = 37; [37] = "CERT"; + ["A6"] = 38; [38] = "A6"; + ["DNAME"] = 39; [39] = "DNAME"; + ["SINK"] = 40; [40] = "SINK"; + ["OPT"] = 41; [41] = "OPT"; + ["APL"] = 42; [42] = "APL"; + ["DS"] = 43; [43] = "DS"; + ["SSHFP"] = 44; [44] = "SSHFP"; + ["IPSECKEY"] = 45; [45] = "IPSECKEY"; + ["RRSIG"] = 46; [46] = "RRSIG"; + ["NSEC"] = 47; [47] = "NSEC"; + ["DNSKEY"] = 48; [48] = "DNSKEY"; + ["DHCID"] = 49; [49] = "DHCID"; + ["NSEC3"] = 50; [50] = "NSEC3"; + ["NSEC3PARAM"] = 51; [51] = "NSEC3PARAM"; + ["TLSA"] = 52; [52] = "TLSA"; + ["SMIMEA"] = 53; [53] = "SMIMEA"; + ["Unassigned"] = 54; [54] = "Unassigned"; + ["HIP"] = 55; [55] = "HIP"; + ["NINFO"] = 56; [56] = "NINFO"; + ["RKEY"] = 57; [57] = "RKEY"; + ["TALINK"] = 58; [58] = "TALINK"; + ["CDS"] = 59; [59] = "CDS"; + ["CDNSKEY"] = 60; [60] = "CDNSKEY"; + ["OPENPGPKEY"] = 61; [61] = "OPENPGPKEY"; + ["CSYNC"] = 62; [62] = "CSYNC"; + ["ZONEMD"] = 63; [63] = "ZONEMD"; + ["SVCB"] = 64; [64] = "SVCB"; + ["HTTPS"] = 65; [65] = "HTTPS"; + ["SPF"] = 99; [99] = "SPF"; + ["NID"] = 104; [104] = "NID"; + ["L32"] = 105; [105] = "L32"; + ["L64"] = 106; [106] = "L64"; + ["LP"] = 107; [107] = "LP"; + ["EUI48"] = 108; [108] = "EUI48"; + ["EUI64"] = 109; [109] = "EUI64"; + ["TKEY"] = 249; [249] = "TKEY"; + ["TSIG"] = 250; [250] = "TSIG"; + ["IXFR"] = 251; [251] = "IXFR"; + ["AXFR"] = 252; [252] = "AXFR"; + ["MAILB"] = 253; [253] = "MAILB"; + ["MAILA"] = 254; [254] = "MAILA"; + ["*"] = 255; [255] = "*"; + ["URI"] = 256; [256] = "URI"; + ["CAA"] = 257; [257] = "CAA"; + ["AVC"] = 258; [258] = "AVC"; + ["DOA"] = 259; [259] = "DOA"; + ["AMTRELAY"] = 260; [260] = "AMTRELAY"; + ["TA"] = 32768; [32768] = "TA"; + ["DLV"] = 32769; [32769] = "DLV"; + }; + errors = { + [0] = "NoError"; ["NoError"] = "No Error"; + [1] = "FormErr"; ["FormErr"] = "Format Error"; + [2] = "ServFail"; ["ServFail"] = "Server Failure"; + [3] = "NXDomain"; ["NXDomain"] = "Non-Existent Domain"; + [4] = "NotImp"; ["NotImp"] = "Not Implemented"; + [5] = "Refused"; ["Refused"] = "Query Refused"; + [6] = "YXDomain"; ["YXDomain"] = "Name Exists when it should not"; + [7] = "YXRRSet"; ["YXRRSet"] = "RR Set Exists when it should not"; + [8] = "NXRRSet"; ["NXRRSet"] = "RR Set that should exist does not"; + [9] = "NotAuth"; ["NotAuth"] = "Server Not Authoritative for zone"; + -- [9] = "NotAuth"; ["NotAuth"] = "Not Authorized"; + [10] = "NotZone"; ["NotZone"] = "Name not contained in zone"; + [11] = "DSOTYPENI"; ["DSOTYPENI"] = "DSO-TYPE Not Implemented"; + [16] = "BADVERS"; ["BADVERS"] = "Bad OPT Version"; + -- [16] = "BADSIG"; ["BADSIG"] = "TSIG Signature Failure"; + [17] = "BADKEY"; ["BADKEY"] = "Key not recognized"; + [18] = "BADTIME"; ["BADTIME"] = "Signature out of time window"; + [19] = "BADMODE"; ["BADMODE"] = "Bad TKEY Mode"; + [20] = "BADNAME"; ["BADNAME"] = "Duplicate key name"; + [21] = "BADALG"; ["BADALG"] = "Algorithm not supported"; + [22] = "BADTRUNC"; ["BADTRUNC"] = "Bad Truncation"; + [23] = "BADCOOKIE"; ["BADCOOKIE"] = "Bad/missing Server Cookie"; + }; +}; diff --git a/util/error.lua b/util/error.lua new file mode 100644 index 00000000..326c01f8 --- /dev/null +++ b/util/error.lua @@ -0,0 +1,170 @@ +local id = require "util.id"; + +local util_debug; -- only imported on-demand + +-- Library configuration (see configure()) +local auto_inject_traceback = false; + +local error_mt = { __name = "error" }; + +function error_mt:__tostring() + return ("error<%s:%s:%s>"):format(self.type, self.condition, self.text or ""); +end + +local function is_error(e) + return getmetatable(e) == error_mt; +end + +local function configure(opt) + if opt.auto_inject_traceback ~= nil then + auto_inject_traceback = opt.auto_inject_traceback; + if auto_inject_traceback then + util_debug = require "util.debug"; + end + end +end + +-- Do we want any more well-known fields? +-- Or could we just copy all fields from `e`? +-- Sometimes you want variable details in the `text`, how to handle that? +-- Translations? +-- Should the `type` be restricted to the stanza error types or free-form? +-- What to set `type` to for stream errors or SASL errors? Those don't have a 'type' attr. + +local function new(e, context, registry, source) + if is_error(e) then return e; end + local template = registry and registry[e]; + if not template then + if type(e) == "table" then + template = { + code = e.code; + type = e.type; + condition = e.condition; + text = e.text; + extra = e.extra; + }; + else + template = {}; + end + end + context = context or {}; + + if auto_inject_traceback then + context.traceback = util_debug.get_traceback_table(nil, 2); + end + + local error_instance = setmetatable({ + instance_id = id.short(); + + type = template.type or "cancel"; + condition = template.condition or "undefined-condition"; + text = template.text; + code = template.code; + extra = template.extra; + + context = context; + source = source; + }, error_mt); + + return error_instance; +end + +-- compact --> normal form +local function expand_registry(namespace, registry) + local mapped = {} + for err,template in pairs(registry) do + local e = { + type = template[1]; + condition = template[2]; + text = template[3]; + }; + if namespace and template[4] then + e.extra = { namespace = namespace, condition = template[4] }; + end + mapped[err] = e; + end + return mapped; +end + +local function init(source, namespace, registry) + if type(namespace) == "table" then + -- registry can be given as second argument if namespace is not used + registry, namespace = namespace, nil; + end + local _, protoerr = next(registry, nil); + if protoerr and type(next(protoerr)) == "number" then + registry = expand_registry(namespace, registry); + end + + local function wrap(e, context) + if is_error(e) then + return e; + end + local err = new(registry[e] or { + type = "cancel", condition = "undefined-condition" + }, context, registry, source); + err.context.wrapped_error = e; + return err; + end + + return { + source = source; + registry = registry; + new = function (e, context) + return new(e, context, registry, source); + end; + coerce = function (ok, err, ...) + if ok then + return ok, err, ...; + end + return nil, wrap(err); + end; + wrap = wrap; + is_error = is_error; + }; +end + +local function coerce(ok, err, ...) + if ok or is_error(err) then + return ok, err, ...; + end + + local new_err = new({ + type = "cancel", condition = "undefined-condition" + }, { wrapped_error = err }); + + return ok, new_err, ...; +end + +local function from_stanza(stanza, context, source) + local error_type, condition, text, extra_tag = stanza:get_error(); + local error_tag = stanza:get_child("error"); + context = context or {}; + context.stanza = stanza; + context.by = error_tag.attr.by or stanza.attr.from; + + local uri; + if condition == "gone" or condition == "redirect" then + uri = error_tag:get_child_text(condition, "urn:ietf:params:xml:ns:xmpp-stanzas"); + end + + return new({ + type = error_type or "cancel"; + condition = condition or "undefined-condition"; + text = text; + extra = (extra_tag or uri) and { + uri = uri; + tag = extra_tag; + } or nil; + }, context, nil, source); +end + +return { + new = new; + init = init; + coerce = coerce; + is_error = is_error; + is_err = is_error; -- COMPAT w/ older 0.12 trunk + from_stanza = from_stanza; + configure = configure; +} diff --git a/util/events.lua b/util/events.lua index 0bf0ddcb..5205a457 100644 --- a/util/events.lua +++ b/util/events.lua @@ -26,6 +26,8 @@ local function new() local wrappers = {}; -- Event map: event_map[handler_function] = priority_number local event_map = {}; + -- Debug hook, if any + local active_debug_hook = nil; -- Called on-demand to build handlers entries local function _rebuild_index(self, event) local _handlers = event_map[event]; @@ -74,11 +76,16 @@ local function new() end; local function _fire_event(event_name, event_data) local h = handlers[event_name]; - if h then + if h and not active_debug_hook then for i=1,#h do local ret = h[i](event_data); if ret ~= nil then return ret; end end + elseif h and active_debug_hook then + for i=1,#h do + local ret = active_debug_hook(h[i], event_name, event_data); + if ret ~= nil then return ret; end + end end end; local function fire_event(event_name, event_data) @@ -140,6 +147,13 @@ local function new() end end end + + local function set_debug_hook(new_hook) + local old_hook = active_debug_hook; + active_debug_hook = new_hook; + return old_hook; + end + return { add_handler = add_handler; remove_handler = remove_handler; @@ -150,8 +164,12 @@ local function new() add_handler = add_wrapper; remove_handler = remove_wrapper; }; + add_wrapper = add_wrapper; remove_wrapper = remove_wrapper; + + set_debug_hook = set_debug_hook; + fire_event = fire_event; _handlers = handlers; _event_map = event_map; diff --git a/util/format.lua b/util/format.lua index c5e513fa..d709aada 100644 --- a/util/format.lua +++ b/util/format.lua @@ -1,14 +1,45 @@ -- --- A string.format wrapper that gracefully handles invalid arguments +-- A string.format wrapper that gracefully handles invalid arguments since +-- certain format string and argument combinations may cause errors or other +-- issues like log spoofing -- +-- Provides some protection from e.g. CAPEC-135, CWE-117, CWE-134, CWE-93 local tostring = tostring; -local select = select; local unpack = table.unpack or unpack; -- luacheck: ignore 113/unpack +local pack = require "util.table".pack; -- TODO table.pack in 5.2+ +local valid_utf8 = require "util.encodings".utf8.valid; local type = type; +local dump = require "util.serialization".new("debug"); +local num_type = math.type or function (n) + return n % 1 == 0 and n <= 9007199254740992 and n >= -9007199254740992 and "integer" or "float"; +end + +-- In Lua 5.3+ these formats throw an error if given a float +local expects_integer = { c = true, d = true, i = true, o = true, u = true, X = true, x = true, }; +-- In Lua 5.2 these throw an error given a negative number +local expects_positive = { o = true; u = true; x = true; X = true }; +-- Printable Unicode replacements for control characters +local control_symbols = { + -- 0x00 .. 0x1F --> U+2400 .. U+241F, 0x7F --> U+2421 + ["\000"] = "\226\144\128", ["\001"] = "\226\144\129", ["\002"] = "\226\144\130", + ["\003"] = "\226\144\131", ["\004"] = "\226\144\132", ["\005"] = "\226\144\133", + ["\006"] = "\226\144\134", ["\007"] = "\226\144\135", ["\008"] = "\226\144\136", + ["\009"] = "\226\144\137", ["\010"] = "\226\144\138", ["\011"] = "\226\144\139", + ["\012"] = "\226\144\140", ["\013"] = "\226\144\141", ["\014"] = "\226\144\142", + ["\015"] = "\226\144\143", ["\016"] = "\226\144\144", ["\017"] = "\226\144\145", + ["\018"] = "\226\144\146", ["\019"] = "\226\144\147", ["\020"] = "\226\144\148", + ["\021"] = "\226\144\149", ["\022"] = "\226\144\150", ["\023"] = "\226\144\151", + ["\024"] = "\226\144\152", ["\025"] = "\226\144\153", ["\026"] = "\226\144\154", + ["\027"] = "\226\144\155", ["\028"] = "\226\144\156", ["\029"] = "\226\144\157", + ["\030"] = "\226\144\158", ["\031"] = "\226\144\159", ["\127"] = "\226\144\161", +}; +local supports_p = pcall(string.format, "%p", ""); -- >= Lua 5.4 +local supports_a = pcall(string.format, "%a", 0.0); -- > Lua 5.1 local function format(formatstring, ...) - local args, args_length = { ... }, select('#', ...); + local args = pack(...); + local args_length = args.n; -- format specifier spec: -- 1. Start: '%%' @@ -20,28 +51,83 @@ local function format(formatstring, ...) -- The options c, d, E, e, f, g, G, i, o, u, X, and x all expect a number as argument, whereas q and s expect a string. -- This function does not accept string values containing embedded zeros, except as arguments to the q option. -- a and A are only in Lua 5.2+ + -- Lua 5.4 adds a p format that produces a pointer -- process each format specifier local i = 0; - formatstring = formatstring:gsub("%%[^cdiouxXaAeEfgGqs%%]*[cdiouxXaAeEfgGqs%%]", function(spec) - if spec ~= "%%" then - i = i + 1; - local arg = args[i]; - if arg == nil then -- special handling for nil - arg = "<nil>" - args[i] = "<nil>"; - end + formatstring = formatstring:gsub("%%[^cdiouxXaAeEfgGpqs%%]*[cdiouxXaAeEfgGpqs%%]", function(spec) + if spec == "%%" then return end + i = i + 1; + local arg = args[i]; - local option = spec:sub(-1); - if option == "q" or option == "s" then -- arg should be string + if arg == nil then + args[i] = "nil"; + return "(%s)"; + end + + local option = spec:sub(-1); + local t = type(arg); + + if option == "s" and t == "string" and not arg:find("[%z\1-\31\128-\255]") then + -- No UTF-8 or control characters, assumed to be the common case. + return + elseif t == "number" then + if option == "g" or (option == "d" and num_type(arg) == "integer") then return end + elseif option == "s" and t ~= "string" then + arg = tostring(arg); + t = "string"; + end + + if option ~= "s" and option ~= "q" and option ~= "p" then + -- all other options expect numbers + if t ~= "number" then + -- arg isn't number as expected? + arg = tostring(arg); + option = "s"; + spec = "[%s]"; + t = "string"; + elseif expects_integer[option] and num_type(arg) ~= "integer" then args[i] = tostring(arg); - elseif type(arg) ~= "number" then -- arg isn't number as expected? + return "[%s]"; + elseif expects_positive[option] and arg < 0 then args[i] = tostring(arg); - spec = "[%s]"; + return "[%s]"; + elseif (option == "a" or option == "A") and not supports_a then + return "%x"; + else + return -- acceptable number + end + end + + + if option == "p" and not supports_p then + arg = tostring(arg); + option = "s"; + spec = "[%s]"; + t = "string"; + end + + if t == "string" and option ~= "p" then + if not valid_utf8(arg) then + option = "q"; + elseif option ~= "q" then -- gets fully escaped in the next block + -- Prevent funny things with ASCII control characters and ANSI escape codes (CWE-117) + -- Also ensure embedded newlines can't look like another log line (CWE-93) + args[i] = arg:gsub("[%z\1-\8\11-\31\127]", control_symbols):gsub("\n\t?", "\n\t"); + return spec; end end - return spec; + + if option == "q" then + args[i] = dump(arg); + return "%s"; + end + + if option == "p" and (t == "boolean" or t == "number") then + args[i] = tostring(arg); + return "[%s]"; + end end); -- process extra args @@ -49,9 +135,9 @@ local function format(formatstring, ...) i = i + 1; local arg = args[i]; if arg == nil then - args[i] = "<nil>"; + args[i] = "(nil)"; else - args[i] = tostring(arg); + args[i] = tostring(arg):gsub("[%z\1-\8\11-\31\127]", control_symbols):gsub("\n\t?", "\n\t"); end formatstring = formatstring .. " [%s]" end diff --git a/util/gc.lua b/util/gc.lua index b400af6b..f46e4346 100644 --- a/util/gc.lua +++ b/util/gc.lua @@ -5,7 +5,7 @@ local known_options = { generational = set.new { "mode", "minor_threshold", "major_threshold" }; }; -if _VERSION ~= "5.4" then +if _VERSION ~= "Lua 5.4" then known_options.generational = nil; known_options.incremental:remove("step_size"); end diff --git a/util/hashring.lua b/util/hashring.lua new file mode 100644 index 00000000..d4555669 --- /dev/null +++ b/util/hashring.lua @@ -0,0 +1,88 @@ +local function generate_ring(nodes, num_replicas, hash) + local new_ring = {}; + for _, node_name in ipairs(nodes) do + for replica = 1, num_replicas do + local replica_hash = hash(node_name..":"..replica); + new_ring[replica_hash] = node_name; + table.insert(new_ring, replica_hash); + end + end + table.sort(new_ring); + return new_ring; +end + +local hashring_methods = {}; +local hashring_mt = { + __index = function (self, k) + -- Automatically build self.ring if it's missing + if k == "ring" then + local ring = generate_ring(self.nodes, self.num_replicas, self.hash); + rawset(self, "ring", ring); + return ring; + end + return rawget(hashring_methods, k); + end +}; + +local function new(num_replicas, hash_function) + return setmetatable({ nodes = {}, num_replicas = num_replicas, hash = hash_function }, hashring_mt); +end; + +function hashring_methods:add_node(name) + self.ring = nil; + self.nodes[name] = true; + table.insert(self.nodes, name); + return true; +end + +function hashring_methods:add_nodes(nodes) + self.ring = nil; + for _, node_name in ipairs(nodes) do + if not self.nodes[node_name] then + self.nodes[node_name] = true; + table.insert(self.nodes, node_name); + end + end + return true; +end + +function hashring_methods:remove_node(node_name) + self.ring = nil; + if self.nodes[node_name] then + for i, stored_node_name in ipairs(self.nodes) do + if node_name == stored_node_name then + self.nodes[node_name] = nil; + table.remove(self.nodes, i); + return true; + end + end + end + return false; +end + +function hashring_methods:remove_nodes(nodes) + self.ring = nil; + for _, node_name in ipairs(nodes) do + self:remove_node(node_name); + end +end + +function hashring_methods:clone() + local clone_hashring = new(self.num_replicas, self.hash); + clone_hashring:add_nodes(self.nodes); + return clone_hashring; +end + +function hashring_methods:get_node(key) + local key_hash = self.hash(key); + for _, replica_hash in ipairs(self.ring) do + if key_hash < replica_hash then + return self.ring[replica_hash]; + end + end + return self.ring[self.ring[1]]; +end + +return { + new = new; +} diff --git a/util/helpers.lua b/util/helpers.lua index 02257ffa..139b62ec 100644 --- a/util/helpers.lua +++ b/util/helpers.lua @@ -23,12 +23,27 @@ local function log_events(events, name, logger) logger("debug", "%s firing event: %s", name, event); return f(event, ...); end + + local function event_handler_hook(handler, event_name, event_data) + logger("debug", "calling handler for %s: %s", event_name, handler); + local ok, ret = pcall(handler, event_data); + if not ok then + logger("error", "error in event handler %s: %s", handler, ret); + error(ret); + end + if ret ~= nil then + logger("debug", "event chain ended for %s by %s with result: %s", event_name, handler, ret); + end + return ret; + end + events.set_debug_hook(event_handler_hook); events[events.fire_event] = f; return events; end local function revert_log_events(events) events.fire_event, events[events.fire_event] = events[events.fire_event], nil; -- :)) + events.set_debug_hook(nil); end local function log_host_events(host) diff --git a/util/hex.lua b/util/hex.lua index 4cc28d33..6202620f 100644 --- a/util/hex.lua +++ b/util/hex.lua @@ -23,4 +23,8 @@ local function from(s) return (s_gsub(s_lower(s), "%X*(%x%x)%X*", hex_to_char)); end -return { to = to, from = from } +return { + encode = to, decode = from; + -- COMPAT w/pre-0.12: + to = to, from = from; +}; diff --git a/util/hmac.lua b/util/hmac.lua index 2c4cc6ef..4cad17cc 100644 --- a/util/hmac.lua +++ b/util/hmac.lua @@ -10,6 +10,9 @@ local hashes = require "util.hashes" -return { md5 = hashes.hmac_md5, - sha1 = hashes.hmac_sha1, - sha256 = hashes.hmac_sha256 }; +return { + md5 = hashes.hmac_md5, + sha1 = hashes.hmac_sha1, + sha256 = hashes.hmac_sha256, + sha512 = hashes.hmac_sha512, +}; diff --git a/util/http.lua b/util/http.lua index cfb89193..3852f91c 100644 --- a/util/http.lua +++ b/util/http.lua @@ -6,24 +6,26 @@ -- local format, char = string.format, string.char; -local pairs, ipairs, tonumber = pairs, ipairs, tonumber; +local pairs, ipairs = pairs, ipairs; local t_insert, t_concat = table.insert, table.concat; +local url_codes = {}; +for i = 0, 255 do + local c = char(i); + local u = format("%%%02x", i); + url_codes[c] = u; + url_codes[u] = c; + url_codes[u:upper()] = c; +end local function urlencode(s) - return s and (s:gsub("[^a-zA-Z0-9.~_-]", function (c) return format("%%%02x", c:byte()); end)); + return s and (s:gsub("[^a-zA-Z0-9.~_-]", url_codes)); end local function urldecode(s) - return s and (s:gsub("%%(%x%x)", function (c) return char(tonumber(c,16)); end)); + return s and (s:gsub("%%%x%x", url_codes)); end local function _formencodepart(s) - return s and (s:gsub("%W", function (c) - if c ~= " " then - return format("%%%02x", c:byte()); - else - return "+"; - end - end)); + return s and (urlencode(s):gsub("%%20", "+")); end local function formencode(form) diff --git a/util/human/io.lua b/util/human/io.lua new file mode 100644 index 00000000..7d7dea97 --- /dev/null +++ b/util/human/io.lua @@ -0,0 +1,192 @@ +local array = require "util.array"; +local utf8 = rawget(_G, "utf8") or require"util.encodings".utf8; +local len = utf8.len or function(s) + local _, count = s:gsub("[%z\001-\127\194-\253][\128-\191]*", ""); + return count; +end; + +local function getchar(n) + local stty_ret = os.execute("stty raw -echo 2>/dev/null"); + local ok, char; + if stty_ret == true or stty_ret == 0 then + ok, char = pcall(io.read, n or 1); + os.execute("stty sane"); + else + ok, char = pcall(io.read, "*l"); + if ok then + char = char:sub(1, n or 1); + end + end + if ok then + return char; + end +end + +local function getline() + local ok, line = pcall(io.read, "*l"); + if ok then + return line; + end +end + +local function getpass() + local stty_ret, _, status_code = os.execute("stty -echo 2>/dev/null"); + if status_code then -- COMPAT w/ Lua 5.1 + stty_ret = status_code; + end + if stty_ret ~= 0 then + io.write("\027[08m"); -- ANSI 'hidden' text attribute + end + local ok, pass = pcall(io.read, "*l"); + if stty_ret == 0 then + os.execute("stty sane"); + else + io.write("\027[00m"); + end + io.write("\n"); + if ok then + return pass; + end +end + +local function show_yesno(prompt) + io.write(prompt, " "); + local choice = getchar():lower(); + io.write("\n"); + if not choice:match("%a") then + choice = prompt:match("%[.-(%U).-%]$"); + if not choice then return nil; end + end + return (choice == "y"); +end + +local function read_password() + local password; + while true do + io.write("Enter new password: "); + password = getpass(); + if not password then + print("No password - cancelled"); + return; + end + io.write("Retype new password: "); + if getpass() ~= password then + if not show_yesno [=[Passwords did not match, try again? [Y/n]]=] then + return; + end + else + break; + end + end + return password; +end + +local function show_prompt(prompt) + io.write(prompt, " "); + local line = getline(); + line = line and line:gsub("\n$",""); + return (line and #line > 0) and line or nil; +end + +local function printf(fmt, ...) + print(fmt:format(...)); +end + +local function padright(s, width) + return s..string.rep(" ", width-len(s)); +end + +local function padleft(s, width) + return string.rep(" ", width-len(s))..s; +end + +local pat = "[%z\001-\127\194-\253][\128-\191]*"; +local function utf8_cut(s, pos) + return s:match("^"..pat:rep(pos)) or s; +end + +if utf8.len and utf8.offset then + function utf8_cut(s, pos) + return s:sub(1, utf8.offset(s, pos+1)-1); + end +end + +local function ellipsis(s, width) + if len(s) <= width then return s; end + if width == 1 then return "…"; end + return utf8_cut(s, width - 1) .. "…"; +end + +local function new_table(col_specs, max_width) + max_width = max_width or tonumber(os.getenv("COLUMNS")) or 80; + local separator = " | "; + + local widths = {}; + local total_width = max_width - #separator * (#col_specs-1); + local free_width = total_width; + -- Calculate width of fixed-size columns + for i = 1, #col_specs do + local width = col_specs[i].width or "0"; + if not(type(width) == "string" and width:sub(-1) == "%") then + local title = col_specs[i].title; + width = math.max(tonumber(width), title and (#title+1) or 0); + widths[i] = width; + free_width = free_width - width; + if i > 1 then + free_width = free_width - #separator; + end + end + end + -- Calculate width of %-based columns + for i = 1, #col_specs do + if not widths[i] then + local pc_width = tonumber((col_specs[i].width:gsub("%%$", ""))); + widths[i] = math.floor(free_width*(pc_width/100)); + end + end + + return function (row) + local titles; + if not row then + titles, row = true, array.pluck(col_specs, "title", ""); + end + local output = {}; + for i, column in ipairs(col_specs) do + local width = widths[i]; + local v = row[not titles and column.key or i]; + if not titles and column.mapper then + v = column.mapper(v, row); + end + if v == nil then + v = column.default or ""; + else + v = tostring(v); + end + if len(v) < width then + if column.align == "right" then + v = padleft(v, width); + else + v = padright(v, width); + end + elseif len(v) > width then + v = ellipsis(v, width); + end + table.insert(output, v); + end + return table.concat(output, separator); + end; +end + +return { + getchar = getchar; + getline = getline; + getpass = getpass; + show_yesno = show_yesno; + read_password = read_password; + show_prompt = show_prompt; + printf = printf; + padleft = padleft; + padright = padright; + ellipsis = ellipsis; + table = new_table; +}; diff --git a/util/human/units.lua b/util/human/units.lua new file mode 100644 index 00000000..af233e98 --- /dev/null +++ b/util/human/units.lua @@ -0,0 +1,80 @@ +local math_abs = math.abs; +local math_ceil = math.ceil; +local math_floor = math.floor; +local math_log = math.log; +local math_max = math.max; +local math_min = math.min; +local unpack = table.unpack or unpack; --luacheck: ignore 113 + +if math_log(10, 10) ~= 1 then + -- Lua 5.1 COMPAT + local log10 = math.log10; + function math_log(n, base) + return log10(n) / log10(base); + end +end + +local large = { + "k", 1000, + "M", 1000000, + "G", 1000000000, + "T", 1000000000000, + "P", 1000000000000000, + "E", 1000000000000000000, + "Z", 1000000000000000000000, + "Y", 1000000000000000000000000, +} +local small = { + "m", 0.001, + "μ", 0.000001, + "n", 0.000000001, + "p", 0.000000000001, + "f", 0.000000000000001, + "a", 0.000000000000000001, + "z", 0.000000000000000000001, + "y", 0.000000000000000000000001, +} + +local binary = { + "Ki", 2^10, + "Mi", 2^20, + "Gi", 2^30, + "Ti", 2^40, + "Pi", 2^50, + "Ei", 2^60, + "Zi", 2^70, + "Yi", 2^80, +} + +local function adjusted_unit(n, b) + local round = math_floor; + local prefixes = large; + local logbase = 1000; + if b == 'b' then + prefixes = binary; + logbase = 1024; + elseif n < 1 then + prefixes = small; + round = math_ceil; + end + local m = math_max(0, math_min(8, round(math_abs(math_log(math_abs(n), logbase))))); + local prefix, multiplier = unpack(prefixes, m * 2-1, m*2); + return multiplier or 1, prefix; +end + +-- n: number, the number to format +-- unit: string, the base unit +-- b: optional enum 'b', thousands base +local function format(n, unit, b) --> string + local fmt = "%.3g %s%s"; + if n == 0 then + return fmt:format(n, "", unit); + end + local multiplier, prefix = adjusted_unit(n, b); + return fmt:format(n / multiplier, prefix or "", unit); +end + +return { + adjust = adjusted_unit; + format = format; +}; diff --git a/util/id.lua b/util/id.lua index 731355fa..ff4e919d 100644 --- a/util/id.lua +++ b/util/id.lua @@ -17,9 +17,23 @@ local function b64url_random(len) end return { - short = function () return b64url_random(6); end; - medium = function () return b64url_random(12); end; - long = function () return b64url_random(24); end; + -- sizes divisible by 3 fit nicely into base64 without padding== + + -- for short lived things with low risk of collisions + tiny = function() return b64url_random(3); end; + + -- close to 8 bytes, should be good enough for relatively short lived or uses + -- scoped by host or users, half the size of an uuid + short = function() return b64url_random(9); end; + + -- more entropy than uuid at 2/3 the size + -- should be okay for globally scoped ids or security token + medium = function() return b64url_random(18); end; + + -- as long as an uuid but MOAR entropy + long = function() return b64url_random(27); end; + + -- pick your own adventure custom = function (size) return function () return b64url_random(size); end; end; diff --git a/util/import.lua b/util/import.lua index 8ecfe43c..1007bc0a 100644 --- a/util/import.lua +++ b/util/import.lua @@ -8,7 +8,7 @@ -local unpack = table.unpack or unpack; --luacheck: ignore 113 143 +local unpack = table.unpack or unpack; --luacheck: ignore 113 local t_insert = table.insert; function _G.import(module, ...) local m = package.loaded[module] or require(module); diff --git a/util/interpolation.lua b/util/interpolation.lua index 3e1f8c4a..acae901c 100644 --- a/util/interpolation.lua +++ b/util/interpolation.lua @@ -64,6 +64,9 @@ local function new_render(pat, escape, funcs) elseif opt == '&' then if not value then return ""; end return render(s_sub(block, e), values); + elseif opt == '~' then + if value then return ""; end + return render(s_sub(block, e), values); elseif opt == '?' and not value then return render(s_sub(block, e), values); elseif value ~= nil then diff --git a/util/ip.lua b/util/ip.lua index 05c4ca14..4b450934 100644 --- a/util/ip.lua +++ b/util/ip.lua @@ -19,8 +19,14 @@ local ip_mt = { return ret; end, __tostring = function (ip) return ip.addr; end, - __eq = function (ipA, ipB) return ipA.packed == ipB.packed; end }; +ip_mt.__eq = function (ipA, ipB) + if getmetatable(ipA) ~= ip_mt or getmetatable(ipB) ~= ip_mt then + -- Lua 5.3+ calls this if both operands are tables, even if metatables differ + return false; + end + return ipA.packed == ipB.packed; +end local hex2bits = { ["0"] = "0000", ["1"] = "0001", ["2"] = "0010", ["3"] = "0011", @@ -61,7 +67,7 @@ function ip_methods:normal() end function ip_methods.bits(ip) - return hex.to(ip.packed):upper():gsub(".", hex2bits); + return hex.encode(ip.packed):upper():gsub(".", hex2bits); end function ip_methods.bits_full(ip) diff --git a/util/iterators.lua b/util/iterators.lua index 302cca36..c03c2fd6 100644 --- a/util/iterators.lua +++ b/util/iterators.lua @@ -11,9 +11,9 @@ local it = {}; local t_insert = table.insert; -local select, next = select, next; -local unpack = table.unpack or unpack; --luacheck: ignore 113 143 -local pack = table.pack or function (...) return { n = select("#", ...), ... }; end -- luacheck: ignore 143 +local next = next; +local unpack = table.unpack or unpack; --luacheck: ignore 113 +local pack = table.pack or require "util.table".pack; local type = type; local table, setmetatable = table, setmetatable; diff --git a/util/jid.lua b/util/jid.lua index ec31f180..694a6b1f 100644 --- a/util/jid.lua +++ b/util/jid.lua @@ -22,63 +22,67 @@ local escapes = { ["@"] = "\\40"; ["\\"] = "\\5c"; }; local unescapes = {}; -for k,v in pairs(escapes) do unescapes[v] = k; end +local backslash_escapes = {}; +for k,v in pairs(escapes) do + unescapes[v] = k; + backslash_escapes[v] = v:gsub("\\", escapes) +end local _ENV = nil; -- luacheck: std none local function split(jid) - if not jid then return; end + if jid == nil then return; end local node, nodepos = match(jid, "^([^@/]+)@()"); local host, hostpos = match(jid, "^([^@/]+)()", nodepos); - if node and not host then return nil, nil, nil; end + if node ~= nil and host == nil then return nil, nil, nil; end local resource = match(jid, "^/(.+)$", hostpos); - if (not host) or ((not resource) and #jid >= hostpos) then return nil, nil, nil; end + if (host == nil) or ((resource == nil) and #jid >= hostpos) then return nil, nil, nil; end return node, host, resource; end local function bare(jid) local node, host = split(jid); - if node and host then + if node ~= nil and host ~= nil then return node.."@"..host; end return host; end -local function prepped_split(jid) +local function prepped_split(jid, strict) local node, host, resource = split(jid); - if host and host ~= "." then + if host ~= nil and host ~= "." then if sub(host, -1, -1) == "." then -- Strip empty root label host = sub(host, 1, -2); end - host = nameprep(host); - if not host then return; end - if node then - node = nodeprep(node); - if not node then return; end + host = nameprep(host, strict); + if host == nil then return; end + if node ~= nil then + node = nodeprep(node, strict); + if node == nil then return; end end - if resource then - resource = resourceprep(resource); - if not resource then return; end + if resource ~= nil then + resource = resourceprep(resource, strict); + if resource == nil then return; end end return node, host, resource; end end local function join(node, host, resource) - if not host then return end - if node and resource then + if host == nil then return end + if node ~= nil and resource ~= nil then return node.."@"..host.."/"..resource; - elseif node then + elseif node ~= nil then return node.."@"..host; - elseif resource then + elseif resource ~= nil then return host.."/"..resource; end return host; end -local function prep(jid) - local node, host, resource = prepped_split(jid); +local function prep(jid, strict) + local node, host, resource = prepped_split(jid, strict); return join(node, host, resource); end @@ -107,7 +111,7 @@ local function resource(jid) return (select(3, split(jid))); end -local function escape(s) return s and (s:gsub(".", escapes)); end +local function escape(s) return s and (s:gsub("\\%x%x", backslash_escapes):gsub("[\"&'/:<>@ ]", escapes)); end local function unescape(s) return s and (s:gsub("\\%x%x", unescapes)); end return { diff --git a/util/json.lua b/util/json.lua index a750da2e..e6704b7e 100644 --- a/util/json.lua +++ b/util/json.lua @@ -217,12 +217,19 @@ local function _readobject(json, index) end local function _readarray(json, index) local a = {}; - local oindex = index; while true do - local val; - val, index = _readvalue(json, index + 1); + local val, terminated; + val, index, terminated = _readvalue(json, index + 1, 0x5d); if val == nil then - if json:byte(oindex + 1) == 0x5d then return setmetatable(a, array_mt), oindex + 2; end -- "]" + if terminated then -- "]" found instead of value + if #a ~= 0 then + -- A non-empty array here means we processed a comma, + -- but it wasn't followed by a value. JSON doesn't allow + -- trailing commas. + return nil, "value expected"; + end + val, index = setmetatable(a, array_mt), index+1; + end return val, index; end t_insert(a, val); @@ -294,7 +301,7 @@ local function _readfalse(json, index) end return nil, "false parse failed"; end -function _readvalue(json, index) +function _readvalue(json, index, terminator) index = _skip_whitespace(json, index); local b = json:byte(index); -- TODO try table lookup instead of if-else? @@ -312,6 +319,8 @@ function _readvalue(json, index) return _readtrue(json, index); elseif b == 0x66 then -- "f" return _readfalse(json, index); + elseif b == terminator then + return nil, index, true; else return nil, "value expected"; end diff --git a/util/jsonpointer.lua b/util/jsonpointer.lua new file mode 100644 index 00000000..9b871ae7 --- /dev/null +++ b/util/jsonpointer.lua @@ -0,0 +1,44 @@ +local m_type = math.type or function (n) + return n % 1 == 0 and n <= 9007199254740992 and n >= -9007199254740992 and "integer" or "float"; +end; + +local function unescape_token(escaped_token) + local unescaped = escaped_token:gsub("~1", "/"):gsub("~0", "~") + return unescaped +end + +local function resolve_json_pointer(ref, path) + local ptr_len = #path + 1 + for part, pos in path:gmatch("/([^/]*)()") do + local token = unescape_token(part) + if not (type(ref) == "table") then + return nil + end + local idx = next(ref) + local new_ref + + if type(idx) == "string" then + new_ref = ref[token] + elseif m_type(idx) == "integer" then + local i = tonumber(token) + if token == "-" then + i = #ref + 1 + end + new_ref = ref[i + 1] + else + return nil, "invalid-table" + end + + if pos == ptr_len then + return new_ref + elseif type(new_ref) == "table" then + ref = new_ref + elseif not (type(ref) == "table") then + return nil, "invalid-path" + end + + end + return ref +end + +return { resolve = resolve_json_pointer } diff --git a/util/jsonschema.lua b/util/jsonschema.lua new file mode 100644 index 00000000..eafa8b7c --- /dev/null +++ b/util/jsonschema.lua @@ -0,0 +1,286 @@ +-- This file is generated from teal-src/util/jsonschema.lua + +local m_type = function(n) + return type(n) == "number" and n % 1 == 0 and n <= 9007199254740992 and n >= -9007199254740992 and "integer" or "float"; +end; +local json = require("util.json") +local null = json.null; + +local pointer = require("util.jsonpointer") + +local json_type_name = json.json_type_name + +local schema_t = {} + +local json_schema_object = { xml_t = {} } + +local function simple_validate(schema, data) + if schema == nil then + return true + elseif schema == "object" and type(data) == "table" then + return type(data) == "table" and (next(data) == nil or type((next(data, nil))) == "string") + elseif schema == "array" and type(data) == "table" then + return type(data) == "table" and (next(data) == nil or type((next(data, nil))) == "number") + elseif schema == "integer" then + return m_type(data) == schema + elseif schema == "null" then + return data == null + elseif type(schema) == "table" then + for _, one in ipairs(schema) do + if simple_validate(one, data) then + return true + end + end + return false + else + return type(data) == schema + end +end + +local complex_validate + +local function validate(schema, data, root) + if type(schema) == "boolean" then + return schema + else + return complex_validate(schema, data, root) + end +end + +function complex_validate(schema, data, root) + + if root == nil then + root = schema + end + + if schema["$ref"] and schema["$ref"]:sub(1, 1) == "#" then + local referenced = pointer.resolve(root, schema["$ref"]:sub(2)) + if referenced ~= nil and referenced ~= root and referenced ~= schema then + if not validate(referenced, data, root) then + return false + end + end + end + + if not simple_validate(schema.type, data) then + return false + end + + if schema.type == "object" then + if type(data) == "table" then + + for k in pairs(data) do + if not (type(k) == "string") then + return false + end + end + end + end + + if schema.type == "array" then + if type(data) == "table" then + + for i in pairs(data) do + if not (m_type(i) == "integer") then + return false + end + end + end + end + + if schema["enum"] ~= nil then + local match = false + for _, v in ipairs(schema["enum"]) do + if v == data then + + match = true + break + end + end + if not match then + return false + end + end + + if type(data) == "string" then + if schema.maxLength and #data > schema.maxLength then + return false + end + if schema.minLength and #data < schema.minLength then + return false + end + end + + if type(data) == "number" then + if schema.multipleOf and (data == 0 or data % schema.multipleOf ~= 0) then + return false + end + + if schema.maximum and not (data <= schema.maximum) then + return false + end + + if schema.exclusiveMaximum and not (data < schema.exclusiveMaximum) then + return false + end + + if schema.minimum and not (data >= schema.minimum) then + return false + end + + if schema.exclusiveMinimum and not (data > schema.exclusiveMinimum) then + return false + end + end + + if schema.allOf then + for _, sub in ipairs(schema.allOf) do + if not validate(sub, data, root) then + return false + end + end + end + + if schema.oneOf then + local valid = 0 + for _, sub in ipairs(schema.oneOf) do + if validate(sub, data, root) then + valid = valid + 1 + end + end + if valid ~= 1 then + return false + end + end + + if schema.anyOf then + local match = false + for _, sub in ipairs(schema.anyOf) do + if validate(sub, data, root) then + match = true + break + end + end + if not match then + return false + end + end + + if schema["not"] then + if validate(schema["not"], data, root) then + return false + end + end + + if schema["if"] ~= nil then + if validate(schema["if"], data, root) then + if schema["then"] then + return validate(schema["then"], data, root) + end + else + if schema["else"] then + return validate(schema["else"], data, root) + end + end + end + + if schema.const ~= nil and schema.const ~= data then + return false + end + + if type(data) == "table" then + + if schema.maxItems and #data > schema.maxItems then + return false + end + + if schema.minItems and #data < schema.minItems then + return false + end + + if schema.required then + for _, k in ipairs(schema.required) do + if data[k] == nil then + return false + end + end + end + + if schema.propertyNames ~= nil then + for k in pairs(data) do + if not validate(schema.propertyNames, k, root) then + return false + end + end + end + + if schema.properties then + for k, sub in pairs(schema.properties) do + if data[k] ~= nil and not validate(sub, data[k], root) then + return false + end + end + end + + if schema.additionalProperties ~= nil then + for k, v in pairs(data) do + if schema.properties == nil or schema.properties[k] == nil then + if not validate(schema.additionalProperties, v, root) then + return false + end + end + end + end + + if schema.uniqueItems then + + local values = {} + for _, v in pairs(data) do + if values[v] then + return false + end + values[v] = true + end + end + + local p = 0 + if schema.prefixItems ~= nil then + for i, s in ipairs(schema.prefixItems) do + if data[i] == nil then + break + elseif validate(s, data[i], root) then + p = i + else + return false + end + end + end + + if schema.items ~= nil then + for i = p + 1, #data do + if not validate(schema.items, data[i], root) then + return false + end + end + end + + if schema.contains ~= nil then + local found = false + for i = 1, #data do + if validate(schema.contains, data[i], root) then + found = true + break + end + end + if not found then + return false + end + end + end + + return true +end + +json_schema_object.validate = validate; + +return json_schema_object diff --git a/util/jwt.lua b/util/jwt.lua new file mode 100644 index 00000000..bf106dfa --- /dev/null +++ b/util/jwt.lua @@ -0,0 +1,51 @@ +local s_gsub = string.gsub; +local json = require "util.json"; +local hashes = require "util.hashes"; +local base64_encode = require "util.encodings".base64.encode; +local base64_decode = require "util.encodings".base64.decode; +local secure_equals = require "util.hashes".equals; + +local b64url_rep = { ["+"] = "-", ["/"] = "_", ["="] = "", ["-"] = "+", ["_"] = "/" }; +local function b64url(data) + return (s_gsub(base64_encode(data), "[+/=]", b64url_rep)); +end +local function unb64url(data) + return base64_decode(s_gsub(data, "[-_]", b64url_rep).."=="); +end + +local static_header = b64url('{"alg":"HS256","typ":"JWT"}') .. '.'; + +local function sign(key, payload) + local encoded_payload = json.encode(payload); + local signed = static_header .. b64url(encoded_payload); + local signature = hashes.hmac_sha256(key, signed); + return signed .. "." .. b64url(signature); +end + +local jwt_pattern = "^(([A-Za-z0-9-_]+)%.([A-Za-z0-9-_]+))%.([A-Za-z0-9-_]+)$" +local function verify(key, blob) + local signed, bheader, bpayload, signature = string.match(blob, jwt_pattern); + if not signed then + return nil, "invalid-encoding"; + end + local header = json.decode(unb64url(bheader)); + if not header or type(header) ~= "table" then + return nil, "invalid-header"; + elseif header.alg ~= "HS256" then + return nil, "unsupported-algorithm"; + end + if not secure_equals(b64url(hashes.hmac_sha256(key, signed)), signature) then + return false, "signature-mismatch"; + end + local payload, err = json.decode(unb64url(bpayload)); + if err ~= nil then + return nil, "json-decode-error"; + end + return true, payload; +end + +return { + sign = sign; + verify = verify; +}; + diff --git a/util/mercurial.lua b/util/mercurial.lua index 3f75c4c1..0f2b1d04 100644 --- a/util/mercurial.lua +++ b/util/mercurial.lua @@ -19,7 +19,7 @@ function hg.check_id(path) hg_changelog:close(); end else - local hg_archival,e = io.open(path.."/.hg_archival.txt"); + local hg_archival,e = io.open(path.."/.hg_archival.txt"); -- luacheck: ignore 211/e if hg_archival then local repo = hg_archival:read("*l"); local node = hg_archival:read("*l"); diff --git a/util/multitable.lua b/util/multitable.lua index 8d32ed8a..4f2cd972 100644 --- a/util/multitable.lua +++ b/util/multitable.lua @@ -9,7 +9,7 @@ local select = select; local t_insert = table.insert; local pairs, next, type = pairs, next, type; -local unpack = table.unpack or unpack; --luacheck: ignore 113 143 +local unpack = table.unpack or unpack; --luacheck: ignore 113 local _ENV = nil; -- luacheck: std none diff --git a/util/openmetrics.lua b/util/openmetrics.lua new file mode 100644 index 00000000..c18e63e9 --- /dev/null +++ b/util/openmetrics.lua @@ -0,0 +1,388 @@ +--[[ +This module implements a subset of the OpenMetrics Internet Draft version 00. + +URL: https://tools.ietf.org/html/draft-richih-opsawg-openmetrics-00 + +The following metric types are supported: + +- Counter +- Gauge +- Histogram +- Summary + +It is used by util.statsd and util.statistics to provide the OpenMetrics API. + +To understand what this module is about, it is useful to familiarize oneself +with the terms MetricFamily, Metric, LabelSet, Label and MetricPoint as +defined in the I-D linked above. +--]] +-- metric constructor interface: +-- metric_ctor(..., family_name, labels, extra) + +local time = require "util.time".now; +local select = select; +local array = require "util.array"; +local log = require "util.logger".init("util.openmetrics"); +local new_multitable = require "util.multitable".new; +local iter_multitable = require "util.multitable".iter; +local t_concat, t_insert = table.concat, table.insert; +local t_pack, t_unpack = require "util.table".pack, table.unpack or unpack; --luacheck: ignore 113/unpack + +-- BEGIN of Utility: "metric proxy" +-- This allows to wrap a MetricFamily in a proxy which only provides the +-- `with_labels` and `with_partial_label` methods. This allows to pre-set one +-- or more labels on a metric family. This is used in particular via +-- `with_partial_label` by the moduleapi in order to pre-set the `host` label +-- on metrics created in non-global modules. +local metric_proxy_mt = {} +metric_proxy_mt.__index = metric_proxy_mt + +local function new_metric_proxy(metric_family, with_labels_proxy_fun) + return setmetatable({ + _family = metric_family, + with_labels = function(self, ...) + return with_labels_proxy_fun(self._family, ...) + end; + with_partial_label = function(self, label) + return new_metric_proxy(self._family, function(family, ...) + return family:with_labels(label, ...) + end) + end + }, metric_proxy_mt); +end + +-- END of Utility: "metric proxy" + +-- BEGIN Rendering helper functions (internal) + +local function escape(text) + return text:gsub("\\", "\\\\"):gsub("\"", "\\\""):gsub("\n", "\\n"); +end + +local function escape_name(name) + return name:gsub("/", "__"):gsub("[^A-Za-z0-9_]", "_"):gsub("^[^A-Za-z_]", "_%1"); +end + +local function repr_help(metric, docstring) + docstring = docstring:gsub("\\", "\\\\"):gsub("\n", "\\n"); + return "# HELP "..escape_name(metric).." "..docstring.."\n"; +end + +local function repr_unit(metric, unit) + if not unit then + unit = "" + else + unit = unit:gsub("\\", "\\\\"):gsub("\n", "\\n"); + end + return "# UNIT "..escape_name(metric).." "..unit.."\n"; +end + +-- local allowed_types = { counter = true, gauge = true, histogram = true, summary = true, untyped = true }; +-- local allowed_types = { "counter", "gauge", "histogram", "summary", "untyped" }; +local function repr_type(metric, type_) + -- if not allowed_types:contains(type_) then + -- return; + -- end + return "# TYPE "..escape_name(metric).." "..type_.."\n"; +end + +local function repr_label(key, value) + return key.."=\""..escape(value).."\""; +end + +local function repr_labels(labelkeys, labelvalues, extra_labels) + local values = {} + if labelkeys then + for i, key in ipairs(labelkeys) do + local value = labelvalues[i] + t_insert(values, repr_label(escape_name(key), escape(value))); + end + end + if extra_labels then + for key, value in pairs(extra_labels) do + t_insert(values, repr_label(escape_name(key), escape(value))); + end + end + if #values == 0 then + return ""; + end + return "{"..t_concat(values, ",").."}"; +end + +local function repr_sample(metric, labelkeys, labelvalues, extra_labels, value) + return escape_name(metric)..repr_labels(labelkeys, labelvalues, extra_labels).." "..string.format("%.17g", value).."\n"; +end + +-- END Rendering helper functions (internal) + +local function render_histogram_le(v) + if v == 1/0 then + -- I-D-00: 4.1.2.2.1: + -- Exposers MUST produce output for positive infinity as +Inf. + return "+Inf" + end + + return string.format("%.14g", v) +end + +-- BEGIN of generic MetricFamily implementation + +local metric_family_mt = {} +metric_family_mt.__index = metric_family_mt + +local function histogram_metric_ctor(orig_ctor, buckets) + return function(family_name, labels, extra) + return orig_ctor(buckets, family_name, labels, extra) + end +end + +local function new_metric_family(backend, type_, family_name, unit, description, label_keys, extra) + local metric_ctor = assert(backend[type_], "statistics backend does not support "..type_.." metrics families") + local labels = label_keys or {} + local user_labels = #labels + if type_ == "histogram" then + local buckets = extra and extra.buckets + if not buckets then + error("no buckets given for histogram metric") + end + buckets = array(buckets) + buckets:push(1/0) -- must have +inf bucket + + metric_ctor = histogram_metric_ctor(metric_ctor, buckets) + end + + local data + if #labels == 0 then + data = metric_ctor(family_name, nil, extra) + else + data = new_multitable() + end + + local mf = { + family_name = family_name, + data = data, + type_ = type_, + unit = unit, + description = description, + user_labels = user_labels, + label_keys = labels, + extra = extra, + _metric_ctor = metric_ctor, + } + setmetatable(mf, metric_family_mt); + return mf +end + +function metric_family_mt:new_metric(labels) + return self._metric_ctor(self.family_name, labels, self.extra) +end + +function metric_family_mt:clear() + for _, metric in self:iter_metrics() do + metric:reset() + end +end + +function metric_family_mt:with_labels(...) + local count = select('#', ...) + if count ~= self.user_labels then + error("number of labels passed to with_labels does not match number of label keys") + end + if count == 0 then + return self.data + end + local metric = self.data:get(...) + if not metric then + local values = t_pack(...) + metric = self:new_metric(values) + values[values.n+1] = metric + self.data:set(t_unpack(values, 1, values.n+1)) + end + return metric +end + +function metric_family_mt:with_partial_label(label) + return new_metric_proxy(self, function (family, ...) + return family:with_labels(label, ...) + end) +end + +function metric_family_mt:iter_metrics() + if #self.label_keys == 0 then + local done = false + return function() + if done then + return nil + end + done = true + return {}, self.data + end + end + local searchkeys = {}; + local nlabels = #self.label_keys + for i=1,nlabels do + searchkeys[i] = nil; + end + local it, state = iter_multitable(self.data, t_unpack(searchkeys, 1, nlabels)) + return function(_s) + local label_values = t_pack(it(_s)) + if label_values.n == 0 then + return nil, nil + end + local metric = label_values[label_values.n] + label_values[label_values.n] = nil + label_values.n = label_values.n - 1 + return label_values, metric + end, state +end + +-- END of generic MetricFamily implementation + +-- BEGIN of MetricRegistry implementation + + +-- Helper to test whether two metrics are "equal". +local function equal_metric_family(mf1, mf2) + if mf1.type_ ~= mf2.type_ then + return false + end + if #mf1.label_keys ~= #mf2.label_keys then + return false + end + -- Ignoring unit here because in general it'll be part of the name anyway + -- So either the unit was moved into/out of the name (which is a valid) + -- thing to do on an upgrade or we would expect not to see any conflicts + -- anyway. + --[[ + if mf1.unit ~= mf2.unit then + return false + end + ]] + for i, key in ipairs(mf1.label_keys) do + if key ~= mf2.label_keys[i] then + return false + end + end + return true +end + +-- If the unit is not empty, add it to the full name as per the I-D spec. +local function compose_name(name, unit) + local full_name = name + if unit and unit ~= "" then + full_name = full_name .. "_" .. unit + end + -- TODO: prohibit certain suffixes used by metrics if where they may cause + -- conflicts + return full_name +end + +local metric_registry_mt = {} +metric_registry_mt.__index = metric_registry_mt + +local function new_metric_registry(backend) + local reg = { + families = {}, + backend = backend, + } + setmetatable(reg, metric_registry_mt) + return reg +end + +function metric_registry_mt:register_metric_family(name, metric_family) + local existing = self.families[name]; + if existing then + if not equal_metric_family(metric_family, existing) then + -- We could either be strict about this, or replace the + -- existing metric family with the new one. + -- Being strict is nice to avoid programming errors / + -- conflicts, but causes issues when a new version of a module + -- is loaded. + -- + -- We will thus assume that the new metric is the correct one; + -- That is probably OK because unless you're reaching down into + -- the util.openmetrics or core.statsmanager API, your metric + -- name is going to be scoped to `prosody_mod_$modulename` + -- anyway and the damage is thus controlled. + -- + -- To make debugging such issues easier, we still log. + log("debug", "replacing incompatible existing metric family %s", name) + -- Below is the code to be strict. + --error("conflicting declarations for metric family "..name) + else + return existing + end + end + self.families[name] = metric_family + return metric_family +end + +function metric_registry_mt:gauge(name, unit, description, labels, extra) + name = compose_name(name, unit) + local mf = new_metric_family(self.backend, "gauge", name, unit, description, labels, extra) + mf = self:register_metric_family(name, mf) + return mf +end + +function metric_registry_mt:counter(name, unit, description, labels, extra) + name = compose_name(name, unit) + local mf = new_metric_family(self.backend, "counter", name, unit, description, labels, extra) + mf = self:register_metric_family(name, mf) + return mf +end + +function metric_registry_mt:histogram(name, unit, description, labels, extra) + name = compose_name(name, unit) + local mf = new_metric_family(self.backend, "histogram", name, unit, description, labels, extra) + mf = self:register_metric_family(name, mf) + return mf +end + +function metric_registry_mt:summary(name, unit, description, labels, extra) + name = compose_name(name, unit) + local mf = new_metric_family(self.backend, "summary", name, unit, description, labels, extra) + mf = self:register_metric_family(name, mf) + return mf +end + +function metric_registry_mt:get_metric_families() + return self.families +end + +function metric_registry_mt:render() + local answer = {}; + for metric_family_name, metric_family in pairs(self:get_metric_families()) do + t_insert(answer, repr_help(metric_family_name, metric_family.description)) + t_insert(answer, repr_unit(metric_family_name, metric_family.unit)) + t_insert(answer, repr_type(metric_family_name, metric_family.type_)) + for labelset, metric in metric_family:iter_metrics() do + for suffix, extra_labels, value in metric:iter_samples() do + t_insert(answer, repr_sample(metric_family_name..suffix, metric_family.label_keys, labelset, extra_labels, value)) + end + end + end + t_insert(answer, "# EOF\n") + return t_concat(answer, ""); +end + +-- END of MetricRegistry implementation + +-- BEGIN of general helpers for implementing high-level APIs on top of OpenMetrics + +local function timed(metric) + local t0 = time() + local submitter = assert(metric.sample or metric.set, "metric type cannot be used with timed()") + return function() + local t1 = time() + submitter(metric, t1-t0) + end +end + +-- END of general helpers + +return { + new_metric_proxy = new_metric_proxy; + new_metric_registry = new_metric_registry; + render_histogram_le = render_histogram_le; + timed = timed; +} diff --git a/util/paths.lua b/util/paths.lua index 89f4cad9..b75c35e5 100644 --- a/util/paths.lua +++ b/util/paths.lua @@ -37,8 +37,34 @@ function path_util.glob_to_pattern(glob) end).."$"; end -function path_util.join(...) - return t_concat({...}, path_sep); +function path_util.join(a, b, c, ...) -- (... : string) --> string + -- Optimization: Avoid creating table for most uses + if b then + if c then + if ... then + return t_concat({a,b,c,...}, path_sep); + end + return a..path_sep..b..path_sep..c; + end + return a..path_sep..b; + end + return a; +end + +function path_util.complement_lua_path(installer_plugin_path) + -- Checking for duplicates + -- The commands using luarocks need the path to the directory that has the /share and /lib folders. + local lua_version = _VERSION:match(" (.+)$"); + local lua_path_sep = package.config:sub(3,3); + local dir_sep = package.config:sub(1,1); + local sub_path = dir_sep.."lua"..dir_sep..lua_version..dir_sep; + if not string.find(package.path, installer_plugin_path, 1, true) then + package.path = package.path..lua_path_sep..installer_plugin_path..dir_sep.."share"..sub_path.."?.lua"; + package.path = package.path..lua_path_sep..installer_plugin_path..dir_sep.."share"..sub_path.."?"..dir_sep.."init.lua"; + end + if not string.find(package.path, installer_plugin_path, 1, true) then + package.cpath = package.cpath..lua_path_sep..installer_plugin_path..dir_sep.."lib"..sub_path.."?.so"; + end end return path_util; diff --git a/util/pluginloader.lua b/util/pluginloader.lua index 9ab8f245..f2ccb4cb 100644 --- a/util/pluginloader.lua +++ b/util/pluginloader.lua @@ -8,18 +8,23 @@ -- luacheck: ignore 113/CFG_PLUGINDIR local dir_sep, path_sep = package.config:match("^(%S+)%s(%S+)"); +local lua_version = _VERSION:match(" (.+)$"); local plugin_dir = {}; for path in (CFG_PLUGINDIR or "./plugins/"):gsub("[/\\]", dir_sep):gmatch("[^"..path_sep.."]+") do path = path..dir_sep; -- add path separator to path end - path = path:gsub(dir_sep..dir_sep.."+", dir_sep); -- coalesce multiple separaters + path = path:gsub(dir_sep..dir_sep.."+", dir_sep); -- coalesce multiple separators plugin_dir[#plugin_dir + 1] = path; end local io_open = io.open; local envload = require "util.envload".envload; -local function load_file(names) +local pluginloader_methods = {}; +local pluginloader_mt = { __index = pluginloader_methods }; + +function pluginloader_methods:load_file(names) local file, err, path; + local load_filter_cb = self._options.load_filter_cb; for i=1,#plugin_dir do for j=1,#names do path = plugin_dir[i]..names[j]; @@ -27,39 +32,49 @@ local function load_file(names) if file then local content = file:read("*a"); file:close(); - return content, path; + local metadata; + if load_filter_cb then + path, content, metadata = load_filter_cb(path, content); + end + if content and path then + return content, path, metadata; + end end end end return file, err; end -local function load_resource(plugin, resource) +function pluginloader_methods:load_resource(plugin, resource) resource = resource or "mod_"..plugin..".lua"; - local names = { "mod_"..plugin..dir_sep..plugin..dir_sep..resource; -- mod_hello/hello/mod_hello.lua "mod_"..plugin..dir_sep..resource; -- mod_hello/mod_hello.lua plugin..dir_sep..resource; -- hello/mod_hello.lua resource; -- mod_hello.lua + "share"..dir_sep.."lua"..dir_sep..lua_version..dir_sep..resource; + "share"..dir_sep.."lua"..dir_sep..lua_version..dir_sep.."mod_"..plugin..dir_sep..resource; }; - return load_file(names); + return self:load_file(names); end -local function load_code(plugin, resource, env) - local content, err = load_resource(plugin, resource); +function pluginloader_methods:load_code(plugin, resource, env) + local content, err, metadata = self:load_resource(plugin, resource); if not content then return content, err; end local path = err; local f, err = envload(content, "@"..path, env); if not f then return f, err; end - return f, path; + return f, path, metadata; end -local function load_code_ext(plugin, resource, extension, env) - local content, err = load_resource(plugin, resource.."."..extension); +function pluginloader_methods:load_code_ext(plugin, resource, extension, env) + local content, err, metadata = self:load_resource(plugin, resource.."."..extension); + if not content and extension == "lib.lua" then + content, err, metadata = self:load_resource(plugin, resource..".lua"); + end if not content then - content, err = load_resource(resource, resource.."."..extension); + content, err, metadata = self:load_resource(resource, resource.."."..extension); if not content then return content, err; end @@ -67,12 +82,28 @@ local function load_code_ext(plugin, resource, extension, env) local path = err; local f, err = envload(content, "@"..path, env); if not f then return f, err; end - return f, path; + return f, path, metadata; +end + +local function init(options) + return setmetatable({ + _options = options or {}; + }, pluginloader_mt); end +local function bind(self, method) + return function (...) + return method(self, ...); + end; +end + +local default_loader = init(); + return { - load_file = load_file; - load_resource = load_resource; - load_code = load_code; - load_code_ext = load_code_ext; + load_file = bind(default_loader, default_loader.load_file); + load_resource = bind(default_loader, default_loader.load_resource); + load_code = bind(default_loader, default_loader.load_code); + load_code_ext = bind(default_loader, default_loader.load_code_ext); + + init = init; }; diff --git a/util/promise.lua b/util/promise.lua index 75c8697b..c4e166ed 100644 --- a/util/promise.lua +++ b/util/promise.lua @@ -2,6 +2,7 @@ local promise_methods = {}; local promise_mt = { __name = "promise", __index = promise_methods }; local xpcall = require "util.xpcall".xpcall; +local unpack = table.unpack or unpack; --luacheck: ignore 113 function promise_mt:__tostring() return "promise (" .. (self._state or "invalid") .. ")"; @@ -49,6 +50,9 @@ local function promise_settle(promise, new_state, new_next, cbs, value) for _, cb in ipairs(cbs) do cb(value); end + -- No need to keep references to callbacks + promise._pending_on_fulfilled = nil; + promise._pending_on_rejected = nil; return true; end @@ -74,30 +78,84 @@ local function new_resolve_functions(p) return _resolve, _reject; end +local next_tick = function (f) + f(); +end + local function new(f) local p = setmetatable({ _state = "pending", _next = next_pending, _pending_on_fulfilled = {}, _pending_on_rejected = {} }, promise_mt); if f then - local resolve, reject = new_resolve_functions(p); - local ok, ret = xpcall(f, debug.traceback, resolve, reject); - if not ok and p._state == "pending" then - reject(ret); - end + next_tick(function() + local resolve, reject = new_resolve_functions(p); + local ok, ret = xpcall(f, debug.traceback, resolve, reject); + if not ok and p._state == "pending" then + reject(ret); + end + end); end return p; end local function all(promises) return new(function (resolve, reject) - local count, total, results = 0, #promises, {}; - for i = 1, total do - promises[i]:next(function (v) - results[i] = v; - count = count + 1; - if count == total then - resolve(results); - end - end, reject); + local settled, results, loop_finished = 0, {}, false; + local total = 0; + for k, v in pairs(promises) do + if is_promise(v) then + total = total + 1; + v:next(function (value) + results[k] = value; + settled = settled + 1; + if settled == total and loop_finished then + resolve(results); + end + end, reject); + else + results[k] = v; + end end + loop_finished = true; + if settled == total then + resolve(results); + end + end); +end + +local function all_settled(promises) + return new(function (resolve) + local settled, results, loop_finished = 0, {}, false; + local total = 0; + for k, v in pairs(promises) do + if is_promise(v) then + total = total + 1; + v:next(function (value) + results[k] = { status = "fulfilled", value = value }; + settled = settled + 1; + if settled == total and loop_finished then + resolve(results); + end + end, function (e) + results[k] = { status = "rejected", reason = e }; + settled = settled + 1; + if settled == total and loop_finished then + resolve(results); + end + end); + else + results[k] = v; + end + end + loop_finished = true; + if settled == total then + resolve(results); + end + end); +end + +local function join(handler, ...) + local promises, n = { ... }, select("#", ...); + return all(promises):next(function (results) + return handler(unpack(results, 1, n)); end); end @@ -144,9 +202,12 @@ end return { new = new; resolve = resolve; + join = join; reject = reject; all = all; + all_settled = all_settled; race = race; try = try; is_promise = is_promise; + set_nexttick = function(new_next_tick) next_tick = new_next_tick; end; } diff --git a/util/prosodyctl.lua b/util/prosodyctl.lua index 6c84ab6e..4d49cd16 100644 --- a/util/prosodyctl.lua +++ b/util/prosodyctl.lua @@ -12,10 +12,10 @@ local encodings = require "util.encodings"; local stringprep = encodings.stringprep; local storagemanager = require "core.storagemanager"; local usermanager = require "core.usermanager"; +local interpolation = require "util.interpolation"; local signal = require "util.signal"; local set = require "util.set"; local lfs = require "lfs"; -local pcall = pcall; local type = type; local nodeprep, nameprep = stringprep.nodeprep, stringprep.nameprep; @@ -27,10 +27,22 @@ local tonumber = tonumber; local _G = _G; local prosody = prosody; +local error_messages = setmetatable({ + ["invalid-username"] = "The given username is invalid in a Jabber ID"; + ["invalid-hostname"] = "The given hostname is invalid"; + ["no-password"] = "No password was supplied"; + ["no-such-user"] = "The given user does not exist on the server"; + ["no-such-host"] = "The given hostname does not exist in the config"; + ["unable-to-save-data"] = "Unable to store, perhaps you don't have permission?"; + ["no-pidfile"] = "There is no 'pidfile' option in the configuration file, see https://prosody.im/doc/prosodyctl#pidfile for help"; + ["invalid-pidfile"] = "The 'pidfile' option in the configuration file is not a string, see https://prosody.im/doc/prosodyctl#pidfile for help"; + ["no-posix"] = "The mod_posix module is not enabled in the Prosody config file, see https://prosody.im/doc/prosodyctl for more info"; + ["no-such-method"] = "This module has no commands"; + ["not-running"] = "Prosody is not running"; + }, { __index = function (_,k) return "Error: "..(tostring(k):gsub("%-", " "):gsub("^.", string.upper)); end }); + -- UI helpers -local function show_message(msg, ...) - print(msg:format(...)); -end +local show_message = require "util.human.io".printf; local function show_usage(usage, desc) print("Usage: ".._G.arg[0].." "..usage); @@ -39,92 +51,19 @@ local function show_usage(usage, desc) end end -local function getchar(n) - local stty_ret = os.execute("stty raw -echo 2>/dev/null"); - local ok, char; - if stty_ret == true or stty_ret == 0 then - ok, char = pcall(io.read, n or 1); - os.execute("stty sane"); - else - ok, char = pcall(io.read, "*l"); - if ok then - char = char:sub(1, n or 1); - end - end - if ok then - return char; - end -end - -local function getline() - local ok, line = pcall(io.read, "*l"); - if ok then - return line; - end -end - -local function getpass() - local stty_ret, _, status_code = os.execute("stty -echo 2>/dev/null"); - if status_code then -- COMPAT w/ Lua 5.1 - stty_ret = status_code; - end - if stty_ret ~= 0 then - io.write("\027[08m"); -- ANSI 'hidden' text attribute - end - local ok, pass = pcall(io.read, "*l"); - if stty_ret == 0 then - os.execute("stty sane"); - else - io.write("\027[00m"); - end - io.write("\n"); - if ok then - return pass; - end -end - -local function show_yesno(prompt) - io.write(prompt, " "); - local choice = getchar():lower(); - io.write("\n"); - if not choice:match("%a") then - choice = prompt:match("%[.-(%U).-%]$"); - if not choice then return nil; end - end - return (choice == "y"); -end - -local function read_password() - local password; - while true do - io.write("Enter new password: "); - password = getpass(); - if not password then - show_message("No password - cancelled"); - return; - end - io.write("Retype new password: "); - if getpass() ~= password then - if not show_yesno [=[Passwords did not match, try again? [Y/n]]=] then - return; - end - else - break; - end - end - return password; -end - -local function show_prompt(prompt) - io.write(prompt, " "); - local line = getline(); - line = line and line:gsub("\n$",""); - return (line and #line > 0) and line or nil; +local function show_module_configuration_help(mod_name) + print("Done.") + print("If you installed a prosody plugin, don't forget to add its name under the 'modules_enabled' section inside your configuration file.") + print("Depending on the module, there might be further configuration steps required.") + print("") + print("More info about: ") + print(" modules_enabled: https://prosody.im/doc/modules_enabled") + print(" "..mod_name..": https://modules.prosody.im/"..mod_name..".html") end -- Server control local function adduser(params) - local user, host, password = nodeprep(params.user), nameprep(params.host), params.password; + local user, host, password = nodeprep(params.user, true), nameprep(params.host), params.password; if not user then return false, "invalid-username"; elseif not host then @@ -200,7 +139,7 @@ local function getpid() return false, "pidfile-read-failed", err; end - local locked, err = lfs.lock(file, "w"); + local locked, err = lfs.lock(file, "w"); -- luacheck: ignore 211/err if locked then file:close(); return false, "pidfile-not-locked"; @@ -217,7 +156,7 @@ local function getpid() end local function isrunning() - local ok, pid, err = getpid(); + local ok, pid, err = getpid(); -- luacheck: ignore 211/err if not ok then if pid == "pidfile-read-failed" or pid == "pidfile-not-locked" then -- Report as not running, since we can't open the pidfile @@ -229,7 +168,8 @@ local function isrunning() return true, signal.kill(pid, 0) == 0; end -local function start(source_dir) +local function start(source_dir, lua) + lua = lua and lua .. " " or ""; local ok, ret = isrunning(); if not ok then return ok, ret; @@ -238,9 +178,9 @@ local function start(source_dir) return false, "already-running"; end if not source_dir then - os.execute("./prosody -D"); + os.execute(lua .. "./prosody -D"); else - os.execute(source_dir.."/../../bin/prosody -D"); + os.execute(lua .. source_dir.."/../../bin/prosody -D"); end return true; end @@ -277,16 +217,22 @@ local function reload() return true; end +local render_cli = interpolation.new("%b{}", function (s) return "'"..s:gsub("'","'\\''").."'" end) + +local function call_luarocks(operation, mod, server) + local dir = prosody.paths.installer; + local ok, _, code = os.execute(render_cli("luarocks --lua-version={luav} {op} --tree={dir} {server&--server={server}} {mod?}", { + dir = dir; op = operation; mod = mod; server = server; luav = _VERSION:match("5%.%d"); + })); + if type(ok) == "number" then code = ok; end + return code; +end + return { show_message = show_message; show_warning = show_message; show_usage = show_usage; - getchar = getchar; - getline = getline; - getpass = getpass; - show_yesno = show_yesno; - read_password = read_password; - show_prompt = show_prompt; + show_module_configuration_help = show_module_configuration_help; adduser = adduser; user_exists = user_exists; passwd = passwd; @@ -296,4 +242,6 @@ return { start = start; stop = stop; reload = reload; + call_luarocks = call_luarocks; + error_messages = error_messages; }; diff --git a/util/prosodyctl/cert.lua b/util/prosodyctl/cert.lua new file mode 100644 index 00000000..02c81585 --- /dev/null +++ b/util/prosodyctl/cert.lua @@ -0,0 +1,318 @@ +local lfs = require "lfs"; + +local pctl = require "util.prosodyctl"; +local hi = require "util.human.io"; +local configmanager = require "core.configmanager"; + +local openssl; + +local cert_commands = {}; + +-- If a file already exists, ask if the user wants to use it or replace it +-- Backups the old file if replaced +local function use_existing(filename) + local attrs = lfs.attributes(filename); + if attrs then + if hi.show_yesno(filename .. " exists, do you want to replace it? [y/n]") then + local backup = filename..".bkp~"..os.date("%FT%T", attrs.change); + os.rename(filename, backup); + pctl.show_message("%s backed up to %s", filename, backup); + else + -- Use the existing file + return true; + end + end +end + +local have_pposix, pposix = pcall(require, "util.pposix"); +local cert_basedir = prosody.paths.data == "." and "./certs" or prosody.paths.data; +if have_pposix and pposix.getuid() == 0 then + -- FIXME should be enough to check if this directory is writable + local cert_dir = configmanager.get("*", "certificates") or "certs"; + cert_basedir = configmanager.resolve_relative_path(prosody.paths.config, cert_dir); +end + +function cert_commands.config(arg) + if #arg >= 1 and arg[1] ~= "--help" then + local conf_filename = cert_basedir .. "/" .. arg[1] .. ".cnf"; + if use_existing(conf_filename) then + return nil, conf_filename; + end + local distinguished_name; + if arg[#arg]:find("^/") then + distinguished_name = table.remove(arg); + end + local conf = openssl.config.new(); + conf:from_prosody(prosody.hosts, configmanager, arg); + if distinguished_name then + local dn = {}; + for k, v in distinguished_name:gmatch("/([^=/]+)=([^/]+)") do + table.insert(dn, k); + dn[k] = v; + end + conf.distinguished_name = dn; + else + pctl.show_message("Please provide details to include in the certificate config file."); + pctl.show_message("Leave the field empty to use the default value or '.' to exclude the field.") + for _, k in ipairs(openssl._DN_order) do + local v = conf.distinguished_name[k]; + if v then + local nv = nil; + if k == "commonName" then + v = arg[1] + elseif k == "emailAddress" then + v = "xmpp@" .. arg[1]; + elseif k == "countryName" then + local tld = arg[1]:match"%.([a-z]+)$"; + if tld and #tld == 2 and tld ~= "uk" then + v = tld:upper(); + end + end + nv = hi.show_prompt(("%s (%s):"):format(k, nv or v)); + nv = (not nv or nv == "") and v or nv; + if nv:find"[\192-\252][\128-\191]+" then + conf.req.string_mask = "utf8only" + end + conf.distinguished_name[k] = nv ~= "." and nv or nil; + end + end + end + local conf_file, err = io.open(conf_filename, "w"); + if not conf_file then + pctl.show_warning("Could not open OpenSSL config file for writing"); + pctl.show_warning("%s", err); + os.exit(1); + end + conf_file:write(conf:serialize()); + conf_file:close(); + print(""); + pctl.show_message("Config written to %s", conf_filename); + return nil, conf_filename; + else + pctl.show_usage("cert config HOSTNAME [HOSTNAME+]", "Builds a certificate config file covering the supplied hostname(s)") + end +end + +function cert_commands.key(arg) + if #arg >= 1 and arg[1] ~= "--help" then + local key_filename = cert_basedir .. "/" .. arg[1] .. ".key"; + if use_existing(key_filename) then + return nil, key_filename; + end + os.remove(key_filename); -- This file, if it exists is unlikely to have write permissions + local key_size = tonumber(arg[2] or hi.show_prompt("Choose key size (2048):") or 2048); + local old_umask = pposix.umask("0377"); + if openssl.genrsa{out=key_filename, key_size} then + os.execute(("chmod 400 '%s'"):format(key_filename)); + pctl.show_message("Key written to %s", key_filename); + pposix.umask(old_umask); + return nil, key_filename; + end + pctl.show_message("There was a problem, see OpenSSL output"); + else + pctl.show_usage("cert key HOSTNAME <bits>", "Generates a RSA key named HOSTNAME.key\n " + .."Prompts for a key size if none given") + end +end + +function cert_commands.request(arg) + if #arg >= 1 and arg[1] ~= "--help" then + local req_filename = cert_basedir .. "/" .. arg[1] .. ".req"; + if use_existing(req_filename) then + return nil, req_filename; + end + local _, key_filename = cert_commands.key({arg[1]}); + local _, conf_filename = cert_commands.config(arg); + if openssl.req{new=true, key=key_filename, utf8=true, sha256=true, config=conf_filename, out=req_filename} then + pctl.show_message("Certificate request written to %s", req_filename); + else + pctl.show_message("There was a problem, see OpenSSL output"); + end + else + pctl.show_usage("cert request HOSTNAME [HOSTNAME+]", "Generates a certificate request for the supplied hostname(s)") + end +end + +function cert_commands.generate(arg) + if #arg >= 1 and arg[1] ~= "--help" then + local cert_filename = cert_basedir .. "/" .. arg[1] .. ".crt"; + if use_existing(cert_filename) then + return nil, cert_filename; + end + local _, key_filename = cert_commands.key({arg[1]}); + local _, conf_filename = cert_commands.config(arg); + if key_filename and conf_filename and cert_filename + and openssl.req{new=true, x509=true, nodes=true, key=key_filename, + days=365, sha256=true, utf8=true, config=conf_filename, out=cert_filename} then + pctl.show_message("Certificate written to %s", cert_filename); + print(); + else + pctl.show_message("There was a problem, see OpenSSL output"); + end + else + pctl.show_usage("cert generate HOSTNAME [HOSTNAME+]", "Generates a self-signed certificate for the current hostname(s)") + end +end + +local function sh_esc(s) + return "'" .. s:gsub("'", "'\\''") .. "'"; +end + +local function copy(from, to, umask, owner, group) + local old_umask = umask and pposix.umask(umask); + local attrs = lfs.attributes(to); + if attrs then -- Move old file out of the way + local backup = to..".bkp~"..os.date("%FT%T", attrs.change); + os.rename(to, backup); + end + -- FIXME friendlier error handling, maybe move above backup back? + local input = assert(io.open(from)); + local output = assert(io.open(to, "w")); + local data = input:read(2^11); + while data and output:write(data) do + data = input:read(2^11); + end + assert(input:close()); + assert(output:close()); + if not prosody.installed then + -- FIXME this is possibly specific to GNU chown + os.execute(("chown -c --reference=%s %s"):format(sh_esc(cert_basedir), sh_esc(to))); + elseif owner and group then + local ok = os.execute(("chown %s:%s %s"):format(sh_esc(owner), sh_esc(group), sh_esc(to))); + assert(ok == true or ok == 0, "Failed to change ownership of "..to); + end + if old_umask then pposix.umask(old_umask); end + return true; +end + +function cert_commands.import(arg) + local hostnames = {}; + -- Move hostname arguments out of arg, the rest should be a list of paths + while arg[1] and prosody.hosts[ arg[1] ] do + table.insert(hostnames, table.remove(arg, 1)); + end + if hostnames[1] == nil then + local domains = os.getenv"RENEWED_DOMAINS"; -- Set if invoked via certbot + if domains then + for host in domains:gmatch("%S+") do + table.insert(hostnames, host); + end + else + for host in pairs(prosody.hosts) do + if host ~= "*" and configmanager.get(host, "enabled") ~= false then + table.insert(hostnames, host); + local http_host = configmanager.get(host, "http_host") or host; + if http_host ~= host then + table.insert(hostnames, http_host); + end + end + end + end + end + if not arg[1] or arg[1] == "--help" then -- Probably forgot the path + pctl.show_usage("cert import [HOSTNAME+] /path/to/certs [/other/paths/]+", + "Copies certificates to "..cert_basedir); + return 1; + end + local owner, group; + if pposix.getuid() == 0 then -- We need root to change ownership + owner = configmanager.get("*", "prosody_user") or "prosody"; + group = configmanager.get("*", "prosody_group") or owner; + end + local cm = require "core.certmanager"; + local files_by_name = {} + for _, dir in ipairs(arg) do + cm.index_certs(dir, files_by_name); + end + local imported = {}; + table.sort(hostnames, function (a, b) + -- Try to find base domain name before sub-domains, then alphabetically, so + -- that the order and choice of file name is deterministic. + if #a == #b then + return a < b; + else + return #a < #b; + end + end); + for _, host in ipairs(hostnames) do + local paths = cm.find_cert_in_index(files_by_name, host); + if paths and imported[paths.certificate] then + -- One certificate, many names! + table.insert(imported, host); + elseif paths then + local c = copy(paths.certificate, cert_basedir .. "/" .. host .. ".crt", nil, owner, group); + local k = copy(paths.key, cert_basedir .. "/" .. host .. ".key", "0377", owner, group); + if c and k then + table.insert(imported, host); + imported[paths.certificate] = true; + else + if not c then pctl.show_warning("Could not copy certificate '%s'", paths.certificate); end + if not k then pctl.show_warning("Could not copy key '%s'", paths.key); end + end + else + -- TODO Say where we looked + pctl.show_warning("No certificate for host %s found :(", host); + end + -- TODO Additional checks + -- Certificate names matches the hostname + -- Private key matches public key in certificate + end + if imported[1] then + pctl.show_message("Imported certificate and key for hosts %s", table.concat(imported, ", ")); + local ok, err = pctl.reload(); + if not ok and err ~= "not-running" then + pctl.show_message(pctl.error_messages[err]); + end + else + pctl.show_warning("No certificates imported :("); + return 1; + end +end + +local function cert(arg) + if #arg >= 1 and arg[1] ~= "--help" then + openssl = require "util.openssl"; + lfs = require "lfs"; + local cert_dir_attrs = lfs.attributes(cert_basedir); + if not cert_dir_attrs then + pctl.show_warning("The directory %s does not exist", cert_basedir); + return 1; -- TODO Should we create it? + end + local uid = pposix.getuid(); + if uid ~= 0 and uid ~= cert_dir_attrs.uid then + pctl.show_warning("The directory %s is not owned by the current user, won't be able to write files to it", cert_basedir); + return 1; + elseif not cert_dir_attrs.permissions then -- COMPAT with LuaFilesystem < 1.6.2 (hey CentOS!) + pctl.show_message("Unable to check permissions on %s (LuaFilesystem 1.6.2+ required)", cert_basedir); + pctl.show_message("Please confirm that Prosody (and only Prosody) can write to this directory)"); + elseif cert_dir_attrs.permissions:match("^%.w..%-..%-.$") then + pctl.show_warning("The directory %s not only writable by its owner", cert_basedir); + return 1; + end + local subcmd = table.remove(arg, 1); + if type(cert_commands[subcmd]) == "function" then + if subcmd ~= "import" then -- hostnames are optional for import + if not arg[1] then + pctl.show_message"You need to supply at least one hostname" + arg = { "--help" }; + end + if arg[1] ~= "--help" and not prosody.hosts[arg[1]] then + pctl.show_message(pctl.error_messages["no-such-host"]); + return 1; + end + end + return cert_commands[subcmd](arg); + elseif subcmd == "check" then + return require "util.prosodyctl.check".check({"certs"}); + end + end + pctl.show_usage("cert config|request|generate|key|import", "Helpers for generating X.509 certificates and keys.") + for _, cmd in pairs(cert_commands) do + print() + cmd{ "--help" } + end +end + +return { + cert = cert; +}; diff --git a/util/prosodyctl/check.lua b/util/prosodyctl/check.lua new file mode 100644 index 00000000..42d73f29 --- /dev/null +++ b/util/prosodyctl/check.lua @@ -0,0 +1,1326 @@ +local configmanager = require "core.configmanager"; +local show_usage = require "util.prosodyctl".show_usage; +local show_warning = require "util.prosodyctl".show_warning; +local is_prosody_running = require "util.prosodyctl".isrunning; +local parse_args = require "util.argparse".parse; +local dependencies = require "util.dependencies"; +local socket = require "socket"; +local socket_url = require "socket.url"; +local jid_split = require "util.jid".prepped_split; +local modulemanager = require "core.modulemanager"; +local async = require "util.async"; +local httputil = require "util.http"; + +local function check_ojn(check_type, target_host) + local http = require "net.http"; -- .new({}); + local json = require "util.json"; + + local response, err = async.wait_for(http.request( + ("https://observe.jabber.network/api/v1/check/%s"):format(httputil.urlencode(check_type)), + { + method="POST", + headers={["Accept"] = "application/json"; ["Content-Type"] = "application/json"}, + body=json.encode({target=target_host}), + })); + + if not response then + return false, err; + end + + if response.code ~= 200 then + return false, ("API replied with non-200 code: %d"):format(response.code); + end + + local decoded_body, err = json.decode(response.body); + if decoded_body == nil then + return false, ("Failed to parse API JSON: %s"):format(err) + end + + local success = decoded_body["success"]; + return success == true, nil; +end + +local function check_probe(base_url, probe_module, target) + local http = require "net.http"; -- .new({}); + local params = httputil.formencode({ module = probe_module; target = target }) + local response, err = async.wait_for(http.request(base_url .. "?" .. params)); + + if not response then return false, err; end + + if response.code ~= 200 then return false, ("API replied with non-200 code: %d"):format(response.code); end + + for line in response.body:gmatch("[^\r\n]+") do + local probe_success = line:match("^probe_success%s+(%d+)"); + + if probe_success == "1" then + return true; + elseif probe_success == "0" then + return false; + end + end + return false, "Probe endpoint did not return a success status"; +end + +local function check_turn_service(turn_service, ping_service) + local ip = require "util.ip"; + local stun = require "net.stun"; + + -- Create UDP socket for communication with the server + local sock = assert(require "socket".udp()); + sock:setsockname("*", 0); + sock:setpeername(turn_service.host, turn_service.port); + sock:settimeout(10); + + -- Helper function to receive a packet + local function receive_packet() + local raw_packet, err = sock:receive(); + if not raw_packet then + return nil, err; + end + return stun.new_packet():deserialize(raw_packet); + end + + local result = { warnings = {} }; + + -- Send a "binding" query, i.e. a request for our external IP/port + local bind_query = stun.new_packet("binding", "request"); + bind_query:add_attribute("software", "prosodyctl check turn"); + sock:send(bind_query:serialize()); + + local bind_result, err = receive_packet(); + if not bind_result then + result.error = "No STUN response: "..err; + return result; + elseif bind_result:is_err_resp() then + result.error = ("STUN server returned error: %d (%s)"):format(bind_result:get_error()); + return result; + elseif not bind_result:is_success_resp() then + result.error = ("Unexpected STUN response: %d (%s)"):format(bind_result:get_type()); + return result; + end + + result.external_ip = bind_result:get_xor_mapped_address(); + if not result.external_ip then + result.error = "STUN server did not return an address"; + return result; + end + if ip.new_ip(result.external_ip.address).private then + table.insert(result.warnings, "STUN returned a private IP! Is the TURN server behind a NAT and misconfigured?"); + end + + -- Send a TURN "allocate" request. Expected to fail due to auth, but + -- necessary to obtain a valid realm/nonce from the server. + local pre_request = stun.new_packet("allocate", "request"); + sock:send(pre_request:serialize()); + + local pre_result, err = receive_packet(); + if not pre_result then + result.error = "No initial TURN response: "..err; + return result; + elseif pre_result:is_success_resp() then + result.error = "TURN server does not have authentication enabled"; + return result; + end + + local realm = pre_result:get_attribute("realm"); + local nonce = pre_result:get_attribute("nonce"); + + if not realm then + table.insert(result.warnings, "TURN server did not return an authentication realm. Is authentication enabled?"); + end + if not nonce then + table.insert(result.warnings, "TURN server did not return a nonce"); + end + + -- Use the configured secret to obtain temporary user/pass credentials + local turn_user, turn_pass = stun.get_user_pass_from_secret(turn_service.secret); + + -- Send a TURN allocate request, will fail if auth is wrong + local alloc_request = stun.new_packet("allocate", "request"); + alloc_request:add_requested_transport("udp"); + alloc_request:add_attribute("username", turn_user); + if realm then + alloc_request:add_attribute("realm", realm); + end + if nonce then + alloc_request:add_attribute("nonce", nonce); + end + local key = stun.get_long_term_auth_key(realm or turn_service.host, turn_user, turn_pass); + alloc_request:add_message_integrity(key); + sock:send(alloc_request:serialize()); + + -- Check the response + local alloc_response, err = receive_packet(); + if not alloc_response then + result.error = "TURN server did not response to allocation request: "..err; + return result; + elseif alloc_response:is_err_resp() then + result.error = ("TURN allocation failed: %d (%s)"):format(alloc_response:get_error()); + return result; + elseif not alloc_response:is_success_resp() then + result.error = ("Unexpected TURN response: %d (%s)"):format(alloc_response:get_type()); + return result; + end + + result.relayed_addresses = alloc_response:get_xor_relayed_addresses(); + + if not ping_service then + -- Success! We won't be running the relay test. + return result; + end + + -- Run the relay test - i.e. send a binding request to ping_service + -- and receive a response. + + -- Resolve the IP of the ping service + local ping_host, ping_port = ping_service:match("^([^:]+):(%d+)$"); + if ping_host then + ping_port = tonumber(ping_port); + else + -- Only a hostname specified, use default STUN port + ping_host, ping_port = ping_service, 3478; + end + + if ping_host == turn_service.host then + result.error = ("Unable to perform ping test: please supply an external STUN server address. See https://prosody.im/doc/turn#prosodyctl-check"); + return result; + end + + local ping_service_ip, err = socket.dns.toip(ping_host); + if not ping_service_ip then + result.error = "Unable to resolve ping service hostname: "..err; + return result; + end + + -- Ask the TURN server to allow packets from the ping service IP + local perm_request = stun.new_packet("create-permission"); + perm_request:add_xor_peer_address(ping_service_ip); + perm_request:add_attribute("username", turn_user); + if realm then + perm_request:add_attribute("realm", realm); + end + if nonce then + perm_request:add_attribute("nonce", nonce); + end + perm_request:add_message_integrity(key); + sock:send(perm_request:serialize()); + + local perm_response, err = receive_packet(); + if not perm_response then + result.error = "No response from TURN server when requesting peer permission: "..err; + return result; + elseif perm_response:is_err_resp() then + result.error = ("TURN permission request failed: %d (%s)"):format(perm_response:get_error()); + return result; + elseif not perm_response:is_success_resp() then + result.error = ("Unexpected TURN response: %d (%s)"):format(perm_response:get_type()); + return result; + end + + -- Ask the TURN server to relay a STUN binding request to the ping server + local ping_data = stun.new_packet("binding"):serialize(); + + local ping_request = stun.new_packet("send", "indication"); + ping_request:add_xor_peer_address(ping_service_ip, ping_port); + ping_request:add_attribute("data", ping_data); + ping_request:add_attribute("username", turn_user); + if realm then + ping_request:add_attribute("realm", realm); + end + if nonce then + ping_request:add_attribute("nonce", nonce); + end + ping_request:add_message_integrity(key); + sock:send(ping_request:serialize()); + + local ping_response, err = receive_packet(); + if not ping_response then + result.error = "No response from ping server ("..ping_service_ip.."): "..err; + return result; + elseif not ping_response:is_indication() or select(2, ping_response:get_method()) ~= "data" then + result.error = ("Unexpected TURN response: %s %s"):format(select(2, ping_response:get_method()), select(2, ping_response:get_type())); + return result; + end + + local pong_data = ping_response:get_attribute("data"); + if not pong_data then + result.error = "No data relayed from remote server"; + return result; + end + local pong = stun.new_packet():deserialize(pong_data); + + result.external_ip_pong = pong:get_xor_mapped_address(); + if not result.external_ip_pong then + result.error = "Ping server did not return an address"; + return result; + end + + local relay_address_found, relay_port_matches; + for _, relayed_address in ipairs(result.relayed_addresses) do + if relayed_address.address == result.external_ip_pong.address then + relay_address_found = true; + relay_port_matches = result.external_ip_pong.port == relayed_address.port; + end + end + if not relay_address_found then + table.insert(result.warnings, "TURN external IP vs relay address mismatch! Is the TURN server behind a NAT and misconfigured?"); + elseif not relay_port_matches then + table.insert(result.warnings, "External port does not match reported relay port! This is probably caused by a NAT in front of the TURN server."); + end + + -- + + return result; +end + +local function skip_bare_jid_hosts(host) + if jid_split(host) then + -- See issue #779 + return false; + end + return true; +end + +local check_opts = { + short_params = { + h = "help", v = "verbose"; + }; + value_params = { + ping = true; + }; +}; + +local function check(arg) + if arg[1] == "help" or arg[1] == "--help" then + show_usage([[check]], [[Perform basic checks on your Prosody installation]]); + return 1; + end + local what = table.remove(arg, 1); + local opts, opts_err, opts_info = parse_args(arg, check_opts); + if opts_err == "missing-value" then + print("Error: Expected a value after '"..opts_info.."'"); + return 1; + elseif opts_err == "param-not-found" then + print("Error: Unknown parameter: "..opts_info); + return 1; + end + local array = require "util.array"; + local set = require "util.set"; + local it = require "util.iterators"; + local ok = true; + local function disabled_hosts(host, conf) return host ~= "*" and conf.enabled ~= false; end + local function enabled_hosts() return it.filter(disabled_hosts, pairs(configmanager.getconfig())); end + if not (what == nil or what == "disabled" or what == "config" or what == "dns" or what == "certs" or what == "connectivity" or what == "turn") then + show_warning("Don't know how to check '%s'. Try one of 'config', 'dns', 'certs', 'disabled', 'turn' or 'connectivity'.", what); + show_warning("Note: The connectivity check will connect to a remote server."); + return 1; + end + if not what or what == "disabled" then + local disabled_hosts_set = set.new(); + for host, host_options in it.filter("*", pairs(configmanager.getconfig())) do + if host_options.enabled == false then + disabled_hosts_set:add(host); + end + end + if not disabled_hosts_set:empty() then + local msg = "Checks will be skipped for these disabled hosts: %s"; + if what then msg = "These hosts are disabled: %s"; end + show_warning(msg, tostring(disabled_hosts_set)); + if what then return 0; end + print"" + end + end + if not what or what == "config" then + print("Checking config..."); + + if what == "config" then + local files = configmanager.files(); + print(" The following configuration files have been loaded:"); + print(" - "..table.concat(files, "\n - ")); + end + + local obsolete = set.new({ --> remove + "archive_cleanup_interval", + "cross_domain_bosh", + "cross_domain_websocket", + "dns_timeout", + "muc_log_cleanup_interval", + "s2s_dns_resolvers", + "setgid", + "setuid", + }); + local function instead_use(kind, name, value) + if kind == "option" then + if value then + return string.format("instead, use '%s = %q'", name, value); + else + return string.format("instead, use '%s'", name); + end + elseif kind == "module" then + return string.format("instead, add %q to '%s'", name, value or "modules_enabled"); + elseif kind == "community" then + return string.format("instead, add %q from %s", name, value or "prosody-modules"); + end + return kind + end + local deprecated_replacements = { + anonymous_login = instead_use("option", "authentication", "anonymous"); + daemonize = "instead, use the --daemonize/-D or --foreground/-F command line flags"; + disallow_s2s = instead_use("module", "s2s"); + no_daemonize = "instead, use the --daemonize/-D or --foreground/-F command line flags"; + require_encryption = "instead, use 'c2s_require_encryption' and 's2s_require_encryption'"; + vcard_compatibility = instead_use("community", "mod_compat_vcard"); + use_libevent = instead_use("option", "network_backend", "event"); + whitelist_registration_only = instead_use("option", "allowlist_registration_only"); + registration_whitelist = instead_use("option", "registration_allowlist"); + registration_blacklist = instead_use("option", "registration_blocklist"); + blacklist_on_registration_throttle_overload = instead_use("blocklist_on_registration_throttle_overload"); + }; + -- FIXME all the singular _port and _interface options are supposed to be deprecated too + local deprecated_ports = { bosh = "http", legacy_ssl = "c2s_direct_tls" }; + local port_suffixes = set.new({ "port", "ports", "interface", "interfaces", "ssl" }); + for port, replacement in pairs(deprecated_ports) do + for suffix in port_suffixes do + local rsuffix = (suffix == "port" or suffix == "interface") and suffix.."s" or suffix; + deprecated_replacements[port.."_"..suffix] = "instead, use '"..replacement.."_"..rsuffix.."'" + end + end + local deprecated = set.new(array.collect(it.keys(deprecated_replacements))); + local known_global_options = set.new({ + "access_control_allow_credentials", + "access_control_allow_headers", + "access_control_allow_methods", + "access_control_max_age", + "admin_socket", + "body_size_limit", + "bosh_max_inactivity", + "bosh_max_polling", + "bosh_max_wait", + "buffer_size_limit", + "c2s_close_timeout", + "c2s_stanza_size_limit", + "c2s_tcp_keepalives", + "c2s_timeout", + "component_stanza_size_limit", + "component_tcp_keepalives", + "consider_bosh_secure", + "consider_websocket_secure", + "console_banner", + "console_prettyprint_settings", + "daemonize", + "gc", + "http_default_host", + "http_errors_always_show", + "http_errors_default_message", + "http_errors_detailed", + "http_errors_messages", + "http_max_buffer_size", + "http_max_content_size", + "installer_plugin_path", + "limits", + "limits_resolution", + "log", + "multiplex_buffer_size", + "network_backend", + "network_default_read_size", + "network_settings", + "openmetrics_allow_cidr", + "openmetrics_allow_ips", + "pidfile", + "plugin_paths", + "plugin_server", + "prosodyctl_timeout", + "prosody_group", + "prosody_user", + "run_as_root", + "s2s_close_timeout", + "s2s_insecure_domains", + "s2s_require_encryption", + "s2s_secure_auth", + "s2s_secure_domains", + "s2s_stanza_size_limit", + "s2s_tcp_keepalives", + "s2s_timeout", + "statistics", + "statistics_config", + "statistics_interval", + "tcp_keepalives", + "tls_profile", + "trusted_proxies", + "umask", + "use_dane", + "use_ipv4", + "use_ipv6", + "websocket_frame_buffer_limit", + "websocket_frame_fragment_limit", + "websocket_get_response_body", + "websocket_get_response_text", + }); + local config = configmanager.getconfig(); + -- Check that we have any global options (caused by putting a host at the top) + if it.count(it.filter("log", pairs(config["*"]))) == 0 then + ok = false; + print(""); + print(" No global options defined. Perhaps you have put a host definition at the top") + print(" of the config file? They should be at the bottom, see https://prosody.im/doc/configure#overview"); + end + if it.count(enabled_hosts()) == 0 then + ok = false; + print(""); + if it.count(it.filter("*", pairs(config))) == 0 then + print(" No hosts are defined, please add at least one VirtualHost section") + elseif config["*"]["enabled"] == false then + print(" No hosts are enabled. Remove enabled = false from the global section or put enabled = true under at least one VirtualHost section") + else + print(" All hosts are disabled. Remove enabled = false from at least one VirtualHost section") + end + end + if not config["*"].modules_enabled then + print(" No global modules_enabled is set?"); + local suggested_global_modules; + for host, options in enabled_hosts() do --luacheck: ignore 213/host + if not options.component_module and options.modules_enabled then + suggested_global_modules = set.intersection(suggested_global_modules or set.new(options.modules_enabled), set.new(options.modules_enabled)); + end + end + if suggested_global_modules and not suggested_global_modules:empty() then + print(" Consider moving these modules into modules_enabled in the global section:") + print(" "..tostring(suggested_global_modules / function (x) return ("%q"):format(x) end)); + end + print(); + end + + do -- Check for modules enabled both normally and as components + local modules = set.new(config["*"]["modules_enabled"]); + for host, options in enabled_hosts() do + local component_module = options.component_module; + if component_module and modules:contains(component_module) then + print((" mod_%s is enabled both in modules_enabled and as Component %q %q"):format(component_module, host, component_module)); + print(" This means the service is enabled on all VirtualHosts as well as the Component."); + print(" Are you sure this what you want? It may cause unexpected behaviour."); + end + end + end + + -- Check for global options under hosts + local global_options = set.new(it.to_array(it.keys(config["*"]))); + local obsolete_global_options = set.intersection(global_options, obsolete); + if not obsolete_global_options:empty() then + print(""); + print(" You have some obsolete options you can remove from the global section:"); + print(" "..tostring(obsolete_global_options)) + ok = false; + end + local deprecated_global_options = set.intersection(global_options, deprecated); + if not deprecated_global_options:empty() then + print(""); + print(" You have some deprecated options in the global section:"); + for option in deprecated_global_options do + print((" '%s' -- %s"):format(option, deprecated_replacements[option])); + end + ok = false; + end + for host, options in it.filter(function (h) return h ~= "*" end, pairs(configmanager.getconfig())) do + local host_options = set.new(it.to_array(it.keys(options))); + local misplaced_options = set.intersection(host_options, known_global_options); + for name in pairs(options) do + if name:match("^interfaces?") + or name:match("_ports?$") or name:match("_interfaces?$") + or (name:match("_ssl$") and not name:match("^[cs]2s_ssl$")) then + misplaced_options:add(name); + end + end + -- FIXME These _could_ be misplaced, but we would have to check where the corresponding module is loaded to be sure + misplaced_options:exclude(set.new({ "external_service_port", "turn_external_port" })); + if not misplaced_options:empty() then + ok = false; + print(""); + local n = it.count(misplaced_options); + print(" You have "..n.." option"..(n>1 and "s " or " ").."set under "..host.." that should be"); + print(" in the global section of the config file, above any VirtualHost or Component definitions,") + print(" see https://prosody.im/doc/configure#overview for more information.") + print(""); + print(" You need to move the following option"..(n>1 and "s" or "")..": "..table.concat(it.to_array(misplaced_options), ", ")); + end + end + for host, options in enabled_hosts() do + local host_options = set.new(it.to_array(it.keys(options))); + local subdomain = host:match("^[^.]+"); + if not(host_options:contains("component_module")) and (subdomain == "jabber" or subdomain == "xmpp" + or subdomain == "chat" or subdomain == "im") then + print(""); + print(" Suggestion: If "..host.. " is a new host with no real users yet, consider renaming it now to"); + print(" "..host:gsub("^[^.]+%.", "")..". You can use SRV records to redirect XMPP clients and servers to "..host.."."); + print(" For more information see: https://prosody.im/doc/dns"); + end + end + local all_modules = set.new(config["*"].modules_enabled); + local all_options = set.new(it.to_array(it.keys(config["*"]))); + for host in enabled_hosts() do + all_options:include(set.new(it.to_array(it.keys(config[host])))); + all_modules:include(set.new(config[host].modules_enabled)); + end + for mod in all_modules do + if mod:match("^mod_") then + print(""); + print(" Modules in modules_enabled should not have the 'mod_' prefix included."); + print(" Change '"..mod.."' to '"..mod:match("^mod_(.*)").."'."); + elseif mod:match("^auth_") then + print(""); + print(" Authentication modules should not be added to modules_enabled,"); + print(" but be specified in the 'authentication' option."); + print(" Remove '"..mod.."' from modules_enabled and instead add"); + print(" authentication = '"..mod:match("^auth_(.*)").."'"); + print(" For more information see https://prosody.im/doc/authentication"); + elseif mod:match("^storage_") then + print(""); + print(" storage modules should not be added to modules_enabled,"); + print(" but be specified in the 'storage' option."); + print(" Remove '"..mod.."' from modules_enabled and instead add"); + print(" storage = '"..mod:match("^storage_(.*)").."'"); + print(" For more information see https://prosody.im/doc/storage"); + end + end + if all_modules:contains("vcard") and all_modules:contains("vcard_legacy") then + print(""); + print(" Both mod_vcard_legacy and mod_vcard are enabled but they conflict"); + print(" with each other. Remove one."); + end + if all_modules:contains("pep") and all_modules:contains("pep_simple") then + print(""); + print(" Both mod_pep_simple and mod_pep are enabled but they conflict"); + print(" with each other. Remove one."); + end + for host, host_config in pairs(config) do --luacheck: ignore 213/host + if type(rawget(host_config, "storage")) == "string" and rawget(host_config, "default_storage") then + print(""); + print(" The 'default_storage' option is not needed if 'storage' is set to a string."); + break; + end + end + local require_encryption = set.intersection(all_options, set.new({ + "require_encryption", "c2s_require_encryption", "s2s_require_encryption" + })):empty(); + local ssl = dependencies.softreq"ssl"; + if not ssl then + if not require_encryption then + print(""); + print(" You require encryption but LuaSec is not available."); + print(" Connections will fail."); + ok = false; + end + elseif not ssl.loadcertificate then + if all_options:contains("s2s_secure_auth") then + print(""); + print(" You have set s2s_secure_auth but your version of LuaSec does "); + print(" not support certificate validation, so all s2s connections will"); + print(" fail."); + ok = false; + elseif all_options:contains("s2s_secure_domains") then + local secure_domains = set.new(); + for host in enabled_hosts() do + if config[host].s2s_secure_auth == true then + secure_domains:add("*"); + else + secure_domains:include(set.new(config[host].s2s_secure_domains)); + end + end + if not secure_domains:empty() then + print(""); + print(" You have set s2s_secure_domains but your version of LuaSec does "); + print(" not support certificate validation, so s2s connections to/from "); + print(" these domains will fail."); + ok = false; + end + end + elseif require_encryption and not all_modules:contains("tls") then + print(""); + print(" You require encryption but mod_tls is not enabled."); + print(" Connections will fail."); + ok = false; + end + + do + local global_modules = set.new(config["*"].modules_enabled); + local registration_enabled_hosts = {}; + for host in enabled_hosts() do + local host_modules = set.new(config[host].modules_enabled) + global_modules; + local allow_registration = config[host].allow_registration; + local mod_register = host_modules:contains("register"); + local mod_register_ibr = host_modules:contains("register_ibr"); + local mod_invites_register = host_modules:contains("invites_register"); + local registration_invite_only = config[host].registration_invite_only; + local is_vhost = not config[host].component_module; + if is_vhost and (mod_register_ibr or (mod_register and allow_registration)) + and not (mod_invites_register and registration_invite_only) then + table.insert(registration_enabled_hosts, host); + end + end + if #registration_enabled_hosts > 0 then + table.sort(registration_enabled_hosts); + print(""); + print(" Public registration is enabled on:"); + print(" "..table.concat(registration_enabled_hosts, ", ")); + print(""); + print(" If this is intentional, review our guidelines on running a public server"); + print(" at https://prosody.im/doc/public_servers - otherwise, consider switching to"); + print(" invite-based registration, which is more secure."); + end + end + + do + local orphan_components = {}; + local referenced_components = set.new(); + local enabled_hosts_set = set.new(); + for host, host_options in it.filter("*", pairs(configmanager.getconfig())) do + if host_options.enabled ~= false then + enabled_hosts_set:add(host); + for _, disco_item in ipairs(host_options.disco_items or {}) do + referenced_components:add(disco_item[1]); + end + end + end + for host, host_config in it.filter(skip_bare_jid_hosts, enabled_hosts()) do + local is_component = not not host_config.component_module; + if is_component then + local parent_domain = host:match("^[^.]+%.(.+)$"); + local is_orphan = not (enabled_hosts_set:contains(parent_domain) or referenced_components:contains(host)); + if is_orphan then + table.insert(orphan_components, host); + end + end + end + if #orphan_components > 0 then + table.sort(orphan_components); + print(""); + print(" Your configuration contains the following unreferenced components:\n"); + print(" "..table.concat(orphan_components, "\n ")); + print(""); + print(" Clients may not be able to discover these services because they are not linked to"); + print(" any VirtualHost. They are automatically linked if they are direct subdomains of a"); + print(" VirtualHost. Alternatively, you can explicitly link them using the disco_items option."); + print(" For more information see https://prosody.im/doc/modules/mod_disco#items"); + end + end + + print("Done.\n"); + end + if not what or what == "dns" then + local dns = require "net.dns"; + pcall(function () + local unbound = require"net.unbound"; + dns = unbound.dns; + end) + local idna = require "util.encodings".idna; + local ip = require "util.ip"; + local c2s_ports = set.new(configmanager.get("*", "c2s_ports") or {5222}); + local s2s_ports = set.new(configmanager.get("*", "s2s_ports") or {5269}); + local c2s_tls_ports = set.new(configmanager.get("*", "c2s_direct_tls_ports") or {}); + local s2s_tls_ports = set.new(configmanager.get("*", "s2s_direct_tls_ports") or {}); + + if set.new(configmanager.get("*", "modules_enabled")):contains("net_multiplex") then + local multiplex_ports = set.new(configmanager.get("*", "ports") or {}); + local multiplex_tls_ports = set.new(configmanager.get("*", "ssl_ports") or {}); + if not multiplex_ports:empty() then + c2s_ports = c2s_ports + multiplex_ports; + s2s_ports = s2s_ports + multiplex_ports; + end + if not multiplex_tls_ports:empty() then + c2s_tls_ports = c2s_tls_ports + multiplex_tls_ports; + s2s_tls_ports = s2s_tls_ports + multiplex_tls_ports; + end + end + + local c2s_srv_required, s2s_srv_required, c2s_tls_srv_required, s2s_tls_srv_required; + if not c2s_ports:contains(5222) then + c2s_srv_required = true; + end + if not s2s_ports:contains(5269) then + s2s_srv_required = true; + end + if not c2s_tls_ports:empty() then + c2s_tls_srv_required = true; + end + if not s2s_tls_ports:empty() then + s2s_tls_srv_required = true; + end + + local problem_hosts = set.new(); + + local external_addresses, internal_addresses = set.new(), set.new(); + + local fqdn = socket.dns.tohostname(socket.dns.gethostname()); + if fqdn then + do + local res = dns.lookup(idna.to_ascii(fqdn), "A"); + if res then + for _, record in ipairs(res) do + external_addresses:add(record.a); + end + end + end + do + local res = dns.lookup(idna.to_ascii(fqdn), "AAAA"); + if res then + for _, record in ipairs(res) do + external_addresses:add(record.aaaa); + end + end + end + end + + local local_addresses = require"util.net".local_addresses() or {}; + + for addr in it.values(local_addresses) do + if not ip.new_ip(addr).private then + external_addresses:add(addr); + else + internal_addresses:add(addr); + end + end + + -- Allow admin to specify additional (e.g. undiscoverable) IP addresses in the config + for _, address in ipairs(configmanager.get("*", "external_addresses") or {}) do + external_addresses:add(address); + end + + if external_addresses:empty() then + print(""); + print(" Failed to determine the external addresses of this server. Checks may be inaccurate."); + c2s_srv_required, s2s_srv_required = true, true; + end + + local v6_supported = not not socket.tcp6; + local use_ipv4 = configmanager.get("*", "use_ipv4") ~= false; + local use_ipv6 = v6_supported and configmanager.get("*", "use_ipv6") ~= false; + + local function trim_dns_name(n) + return (n:gsub("%.$", "")); + end + + local unknown_addresses = set.new(); + + for jid, host_options in enabled_hosts() do + local all_targets_ok, some_targets_ok = true, false; + local node, host = jid_split(jid); + + local modules, component_module = modulemanager.get_modules_for_host(host); + if component_module then + modules:add(component_module); + end + + local is_component = not not host_options.component_module; + print("Checking DNS for "..(is_component and "component" or "host").." "..jid.."..."); + if node then + print("Only the domain part ("..host..") is used in DNS.") + end + local target_hosts = set.new(); + if modules:contains("c2s") then + local res = dns.lookup("_xmpp-client._tcp."..idna.to_ascii(host)..".", "SRV"); + if res and #res > 0 then + for _, record in ipairs(res) do + if record.srv.target == "." then -- TODO is this an error if mod_c2s is enabled? + print(" 'xmpp-client' service disabled by pointing to '.'"); -- FIXME Explain better what this is + break; + end + local target = trim_dns_name(record.srv.target); + target_hosts:add(target); + if not c2s_ports:contains(record.srv.port) then + print(" SRV target "..target.." contains unknown client port: "..record.srv.port); + end + end + else + if c2s_srv_required then + print(" No _xmpp-client SRV record found for "..host..", but it looks like you need one."); + all_targets_ok = false; + else + target_hosts:add(host); + end + end + end + if modules:contains("c2s") and c2s_tls_srv_required then + local res = dns.lookup("_xmpps-client._tcp."..idna.to_ascii(host)..".", "SRV"); + if res and #res > 0 then + for _, record in ipairs(res) do + if record.srv.target == "." then -- TODO is this an error if mod_c2s is enabled? + print(" 'xmpps-client' service disabled by pointing to '.'"); -- FIXME Explain better what this is + break; + end + local target = trim_dns_name(record.srv.target); + target_hosts:add(target); + if not c2s_tls_ports:contains(record.srv.port) then + print(" SRV target "..target.." contains unknown Direct TLS client port: "..record.srv.port); + end + end + else + print(" No _xmpps-client SRV record found for "..host..", but it looks like you need one."); + all_targets_ok = false; + end + end + if modules:contains("s2s") then + local res = dns.lookup("_xmpp-server._tcp."..idna.to_ascii(host)..".", "SRV"); + if res and #res > 0 then + for _, record in ipairs(res) do + if record.srv.target == "." then -- TODO Is this an error if mod_s2s is enabled? + print(" 'xmpp-server' service disabled by pointing to '.'"); -- FIXME Explain better what this is + break; + end + local target = trim_dns_name(record.srv.target); + target_hosts:add(target); + if not s2s_ports:contains(record.srv.port) then + print(" SRV target "..target.." contains unknown server port: "..record.srv.port); + end + end + else + if s2s_srv_required then + print(" No _xmpp-server SRV record found for "..host..", but it looks like you need one."); + all_targets_ok = false; + else + target_hosts:add(host); + end + end + end + if modules:contains("s2s") and s2s_tls_srv_required then + local res = dns.lookup("_xmpps-server._tcp."..idna.to_ascii(host)..".", "SRV"); + if res and #res > 0 then + for _, record in ipairs(res) do + if record.srv.target == "." then -- TODO is this an error if mod_s2s is enabled? + print(" 'xmpps-server' service disabled by pointing to '.'"); -- FIXME Explain better what this is + break; + end + local target = trim_dns_name(record.srv.target); + target_hosts:add(target); + if not s2s_tls_ports:contains(record.srv.port) then + print(" SRV target "..target.." contains unknown Direct TLS server port: "..record.srv.port); + end + end + else + print(" No _xmpps-server SRV record found for "..host..", but it looks like you need one."); + all_targets_ok = false; + end + end + if target_hosts:empty() then + target_hosts:add(host); + end + + if target_hosts:contains("localhost") then + print(" Target 'localhost' cannot be accessed from other servers"); + target_hosts:remove("localhost"); + end + + local function check_address(target) + local A, AAAA = dns.lookup(idna.to_ascii(target), "A"), dns.lookup(idna.to_ascii(target), "AAAA"); + local prob = {}; + if use_ipv4 and not (A and #A > 0) then table.insert(prob, "A"); end + if use_ipv6 and not (AAAA and #AAAA > 0) then table.insert(prob, "AAAA"); end + return prob; + end + + if modules:contains("proxy65") then + local proxy65_target = configmanager.get(host, "proxy65_address") or host; + if type(proxy65_target) == "string" then + local prob = check_address(proxy65_target); + if #prob > 0 then + print(" File transfer proxy "..proxy65_target.." has no "..table.concat(prob, "/") + .." record. Create one or set 'proxy65_address' to the correct host/IP."); + end + else + print(" proxy65_address for "..host.." should be set to a string, unable to perform DNS check"); + end + end + + local known_http_modules = set.new { "bosh"; "http_files"; "http_file_share"; "http_openmetrics"; "websocket" }; + local function contains_match(hayset, needle) + for member in hayset do if member:find(needle) then return true end end + end + + if modules:contains("http") or not set.intersection(modules, known_http_modules):empty() + or contains_match(modules, "^http_") or contains_match(modules, "_web$") then + + local http_host = configmanager.get(host, "http_host") or host; + local http_internal_host = http_host; + local http_url = configmanager.get(host, "http_external_url"); + if http_url then + local url_parse = require "socket.url".parse; + local external_url_parts = url_parse(http_url); + if external_url_parts then + http_host = external_url_parts.host; + else + print(" The 'http_external_url' setting is not a valid URL"); + end + end + + local prob = check_address(http_host); + if #prob > 1 then + print(" HTTP service " .. http_host .. " has no " .. table.concat(prob, "/") .. " record. Create one or change " + .. (http_url and "'http_external_url'" or "'http_host'").." to the correct host."); + end + + if http_host ~= http_internal_host then + print(" Ensure the reverse proxy sets the HTTP Host header to '" .. http_internal_host .. "'"); + end + end + + if not use_ipv4 and not use_ipv6 then + print(" Both IPv6 and IPv4 are disabled, Prosody will not listen on any ports"); + print(" nor be able to connect to any remote servers."); + all_targets_ok = false; + end + + for target_host in target_hosts do + local host_ok_v4, host_ok_v6; + do + local res = dns.lookup(idna.to_ascii(target_host), "A"); + if res then + for _, record in ipairs(res) do + if external_addresses:contains(record.a) then + some_targets_ok = true; + host_ok_v4 = true; + elseif internal_addresses:contains(record.a) then + host_ok_v4 = true; + some_targets_ok = true; + print(" "..target_host.." A record points to internal address, external connections might fail"); + else + print(" "..target_host.." A record points to unknown address "..record.a); + unknown_addresses:add(record.a); + all_targets_ok = false; + end + end + end + end + do + local res = dns.lookup(idna.to_ascii(target_host), "AAAA"); + if res then + for _, record in ipairs(res) do + if external_addresses:contains(record.aaaa) then + some_targets_ok = true; + host_ok_v6 = true; + elseif internal_addresses:contains(record.aaaa) then + host_ok_v6 = true; + some_targets_ok = true; + print(" "..target_host.." AAAA record points to internal address, external connections might fail"); + else + print(" "..target_host.." AAAA record points to unknown address "..record.aaaa); + unknown_addresses:add(record.aaaa); + all_targets_ok = false; + end + end + end + end + + if host_ok_v4 and not use_ipv4 then + print(" Host "..target_host.." does seem to resolve to this server but IPv4 has been disabled"); + all_targets_ok = false; + end + + if host_ok_v6 and not use_ipv6 then + print(" Host "..target_host.." does seem to resolve to this server but IPv6 has been disabled"); + all_targets_ok = false; + end + + local bad_protos = {} + if use_ipv4 and not host_ok_v4 then + table.insert(bad_protos, "IPv4"); + end + if use_ipv6 and not host_ok_v6 then + table.insert(bad_protos, "IPv6"); + end + if #bad_protos > 0 then + print(" Host "..target_host.." does not seem to resolve to this server ("..table.concat(bad_protos, "/")..")"); + end + if host_ok_v6 and not v6_supported then + print(" Host "..target_host.." has AAAA records, but your version of LuaSocket does not support IPv6."); + print(" Please see https://prosody.im/doc/ipv6 for more information."); + elseif host_ok_v6 and not use_ipv6 then + print(" Host "..target_host.." has AAAA records, but IPv6 is disabled."); + -- TODO Tell them to drop the AAAA records or enable IPv6? + print(" Please see https://prosody.im/doc/ipv6 for more information."); + end + end + if not all_targets_ok then + print(" "..(some_targets_ok and "Only some" or "No").." targets for "..host.." appear to resolve to this server."); + if is_component then + print(" DNS records are necessary if you want users on other servers to access this component."); + end + problem_hosts:add(host); + end + print(""); + end + if not problem_hosts:empty() then + if not unknown_addresses:empty() then + print(""); + print("Some of your DNS records point to unknown IP addresses. This may be expected if your server"); + print("is behind a NAT or proxy. The unrecognized addresses were:"); + print(""); + print(" Unrecognized: "..tostring(unknown_addresses)); + print(""); + print("The addresses we found on this system are:"); + print(""); + print(" Internal: "..tostring(internal_addresses)); + print(" External: "..tostring(external_addresses)); + end + print(""); + print("For more information about DNS configuration please see https://prosody.im/doc/dns"); + print(""); + ok = false; + end + end + if not what or what == "certs" then + local cert_ok; + print"Checking certificates..." + local x509_verify_identity = require"util.x509".verify_identity; + local create_context = require "core.certmanager".create_context; + local ssl = dependencies.softreq"ssl"; + -- local datetime_parse = require"util.datetime".parse_x509; + local load_cert = ssl and ssl.loadcertificate; + -- or ssl.cert_from_pem + if not ssl then + print("LuaSec not available, can't perform certificate checks") + if what == "certs" then cert_ok = false end + elseif not load_cert then + print("This version of LuaSec (" .. ssl._VERSION .. ") does not support certificate checking"); + cert_ok = false + else + for host in it.filter(skip_bare_jid_hosts, enabled_hosts()) do + print("Checking certificate for "..host); + -- First, let's find out what certificate this host uses. + local host_ssl_config = configmanager.rawget(host, "ssl") + or configmanager.rawget(host:match("%.(.*)"), "ssl"); + local global_ssl_config = configmanager.rawget("*", "ssl"); + local ok, err, ssl_config = create_context(host, "server", host_ssl_config, global_ssl_config); + if not ok then + print(" Error: "..err); + cert_ok = false + elseif not ssl_config.certificate then + print(" No 'certificate' found for "..host) + cert_ok = false + elseif not ssl_config.key then + print(" No 'key' found for "..host) + cert_ok = false + else + local key, err = io.open(ssl_config.key); -- Permissions check only + if not key then + print(" Could not open "..ssl_config.key..": "..err); + cert_ok = false + else + key:close(); + end + local cert_fh, err = io.open(ssl_config.certificate); -- Load the file. + if not cert_fh then + print(" Could not open "..ssl_config.certificate..": "..err); + cert_ok = false + else + print(" Certificate: "..ssl_config.certificate) + local cert = load_cert(cert_fh:read"*a"); cert_fh:close(); + if not cert:validat(os.time()) then + print(" Certificate has expired.") + cert_ok = false + elseif not cert:validat(os.time() + 86400) then + print(" Certificate expires within one day.") + cert_ok = false + elseif not cert:validat(os.time() + 86400*7) then + print(" Certificate expires within one week.") + elseif not cert:validat(os.time() + 86400*31) then + print(" Certificate expires within one month.") + end + if configmanager.get(host, "component_module") == nil + and not x509_verify_identity(host, "_xmpp-client", cert) then + print(" Not valid for client connections to "..host..".") + cert_ok = false + end + if (not (configmanager.get(host, "anonymous_login") + or configmanager.get(host, "authentication") == "anonymous")) + and not x509_verify_identity(host, "_xmpp-server", cert) then + print(" Not valid for server-to-server connections to "..host..".") + cert_ok = false + end + end + end + end + end + if cert_ok == false then + print("") + print("For more information about certificates please see https://prosody.im/doc/certificates"); + ok = false + end + print("") + end + -- intentionally not doing this by default + if what == "connectivity" then + local _, prosody_is_running = is_prosody_running(); + if configmanager.get("*", "pidfile") and not prosody_is_running then + print("Prosody does not appear to be running, which is required for this test."); + print("Start it and then try again."); + return 1; + end + + local checker = "observe.jabber.network"; + local probe_instance; + local probe_modules = { + ["xmpp-client"] = "c2s_normal_auth"; + ["xmpp-server"] = "s2s_normal"; + ["xmpps-client"] = nil; -- TODO + ["xmpps-server"] = nil; -- TODO + }; + local probe_settings = configmanager.get("*", "connectivity_probe"); + if type(probe_settings) == "string" then + probe_instance = probe_settings; + elseif type(probe_settings) == "table" and type(probe_settings.url) == "string" then + probe_instance = probe_settings.url; + if type(probe_settings.modules) == "table" then + probe_modules = probe_settings.modules; + end + elseif probe_settings ~= nil then + print("The 'connectivity_probe' setting not understood."); + print("Expected an URL or a table with 'url' and 'modules' fields"); + print("See https://prosody.im/doc/prosodyctl#check for more information."); -- FIXME + return 1; + end + + local check_api; + if probe_instance then + local parsed_url = socket_url.parse(probe_instance); + if not parsed_url then + print(("'connectivity_probe' is not a valid URL: %q"):format(probe_instance)); + print("Set it to the URL of an XMPP Blackbox Exporter instance and try again"); + return 1; + end + checker = parsed_url.host; + + function check_api(protocol, host) + local target = socket_url.build({scheme="xmpp",path=host}); + local probe_module = probe_modules[protocol]; + if not probe_module then + return nil, "Checking protocol '"..protocol.."' is currently unsupported"; + end + return check_probe(probe_instance, probe_module, target); + end + else + check_api = check_ojn; + end + + for host in it.filter(skip_bare_jid_hosts, enabled_hosts()) do + local modules, component_module = modulemanager.get_modules_for_host(host); + if component_module then + modules:add(component_module) + end + + print("Checking external connectivity for "..host.." via "..checker) + local function check_connectivity(protocol) + local success, err = check_api(protocol, host); + if not success and err ~= nil then + print((" %s: Failed to request check at API: %s"):format(protocol, err)) + elseif success then + print((" %s: Works"):format(protocol)) + else + print((" %s: Check service failed to establish (secure) connection"):format(protocol)) + ok = false + end + end + + if modules:contains("c2s") then + check_connectivity("xmpp-client") + if configmanager.get("*", "c2s_direct_tls_ports") then + check_connectivity("xmpps-client"); + end + end + + if modules:contains("s2s") then + check_connectivity("xmpp-server") + if configmanager.get("*", "s2s_direct_tls_ports") then + check_connectivity("xmpps-server"); + end + end + + print() + end + print("Note: The connectivity check only checks the reachability of the domain.") + print("Note: It does not ensure that the check actually reaches this specific prosody instance.") + end + + if not what or what == "turn" then + local turn_enabled_hosts = {}; + local turn_services = {}; + + for host in enabled_hosts() do + local has_external_turn = modulemanager.get_modules_for_host(host):contains("turn_external"); + if has_external_turn then + table.insert(turn_enabled_hosts, host); + local turn_host = configmanager.get(host, "turn_external_host") or host; + local turn_port = configmanager.get(host, "turn_external_port") or 3478; + local turn_secret = configmanager.get(host, "turn_external_secret"); + if not turn_secret then + print("Error: Your configuration is missing a turn_external_secret for "..host); + print("Error: TURN will not be advertised for this host."); + ok = false; + else + local turn_id = ("%s:%d"):format(turn_host, turn_port); + if turn_services[turn_id] and turn_services[turn_id].secret ~= turn_secret then + print("Error: Your configuration contains multiple differing secrets"); + print(" for the TURN service at "..turn_id.." - we will only test one."); + elseif not turn_services[turn_id] then + turn_services[turn_id] = { + host = turn_host; + port = turn_port; + secret = turn_secret; + }; + end + end + end + end + + if what == "turn" then + local count = it.count(pairs(turn_services)); + if count == 0 then + print("Error: Unable to find any TURN services configured. Enable mod_turn_external!"); + ok = false; + else + print("Identified "..tostring(count).." TURN services."); + print(""); + end + end + + for turn_id, turn_service in pairs(turn_services) do + print("Testing TURN service "..turn_id.."..."); + + local result = check_turn_service(turn_service, opts.ping); + if #result.warnings > 0 then + print(("%d warnings:\n"):format(#result.warnings)); + print(" "..table.concat(result.warnings, "\n ")); + print(""); + end + + if opts.verbose then + if result.external_ip then + print(("External IP: %s"):format(result.external_ip.address)); + end + if result.relayed_addresses then + for i, relayed_address in ipairs(result.relayed_addresses) do + print(("Relayed address %d: %s:%d"):format(i, relayed_address.address, relayed_address.port)); + end + end + if result.external_ip_pong then + print(("TURN external address: %s:%d"):format(result.external_ip_pong.address, result.external_ip_pong.port)); + end + end + + if result.error then + print("Error: "..result.error.."\n"); + ok = false; + else + print("Success!\n"); + end + end + end + + if not ok then + print("Problems found, see above."); + else + print("All checks passed, congratulations!"); + end + return ok and 0 or 2; +end + +return { + check = check; +}; diff --git a/util/prosodyctl/shell.lua b/util/prosodyctl/shell.lua new file mode 100644 index 00000000..bce27b94 --- /dev/null +++ b/util/prosodyctl/shell.lua @@ -0,0 +1,148 @@ +local config = require "core.configmanager"; +local server = require "net.server"; +local st = require "util.stanza"; +local path = require "util.paths"; +local parse_args = require "util.argparse".parse; +local unpack = table.unpack or _G.unpack; + +local have_readline, readline = pcall(require, "readline"); + +local adminstream = require "util.adminstream"; + +if have_readline then + readline.set_readline_name("prosody"); + readline.set_options({ + histfile = path.join(prosody.paths.data, ".shell_history"); + ignoredups = true; + }); +end + +local function read_line(prompt_string) + if have_readline then + return readline.readline(prompt_string); + else + io.write(prompt_string); + return io.read("*line"); + end +end + +local function send_line(client, line) + client.send(st.stanza("repl-input"):text(line)); +end + +local function repl(client) + local line = read_line(client.prompt_string or "prosody> "); + if not line or line == "quit" or line == "exit" or line == "bye" then + if not line then + print(""); + end + if have_readline then + readline.save_history(); + end + os.exit(); + end + send_line(client, line); +end + +local function printbanner() + local banner = config.get("*", "console_banner"); + if banner then return print(banner); end + print([[ + ____ \ / _ + | _ \ _ __ ___ ___ _-_ __| |_ _ + | |_) | '__/ _ \/ __|/ _ \ / _` | | | | + | __/| | | (_) \__ \ |_| | (_| | |_| | + |_| |_| \___/|___/\___/ \__,_|\__, | + A study in simplicity |___/ + +]]); + print("Welcome to the Prosody administration console. For a list of commands, type: help"); + print("You may find more help on using this console in our online documentation at "); + print("https://prosody.im/doc/console\n"); +end + +local function start(arg) --luacheck: ignore 212/arg + local client = adminstream.client(); + local opts, err, where = parse_args(arg); + + if not opts then + if err == "param-not-found" then + print("Unknown command-line option: "..tostring(where)); + elseif err == "missing-value" then + print("Expected a value to follow command-line option: "..where); + end + os.exit(1); + end + + if arg[1] then + if arg[2] then + -- prosodyctl shell module reload foo bar.com --> module:reload("foo", "bar.com") + -- COMPAT Lua 5.1 doesn't have the separator argument to string.rep + arg[1] = string.format("%s:%s("..string.rep("%q, ", #arg-2):sub(1, -3)..")", unpack(arg)); + end + + client.events.add_handler("connected", function() + client.send(st.stanza("repl-input"):text(arg[1])); + return true; + end, 1); + + local errors = 0; -- TODO This is weird, but works for now. + client.events.add_handler("received", function(stanza) + if stanza.name == "repl-output" or stanza.name == "repl-result" then + if stanza.attr.type == "error" then + errors = errors + 1; + io.stderr:write(stanza:get_text(), "\n"); + else + print(stanza:get_text()); + end + end + if stanza.name == "repl-result" then + os.exit(errors); + end + return true; + end, 1); + end + + client.events.add_handler("connected", function () + if not opts.quiet then + printbanner(); + end + repl(client); + end); + + client.events.add_handler("disconnected", function () + print("--- session closed ---"); + os.exit(); + end); + + client.events.add_handler("received", function (stanza) + if stanza.name == "repl-output" or stanza.name == "repl-result" then + local result_prefix = stanza.attr.type == "error" and "!" or "|"; + print(result_prefix.." "..stanza:get_text()); + end + if stanza.name == "repl-result" then + repl(client); + end + end); + + client.prompt_string = config.get("*", "admin_shell_prompt"); + + local socket_path = path.resolve_relative_path(prosody.paths.data, opts.socket or config.get("*", "admin_socket") or "prosody.sock"); + local conn = adminstream.connection(socket_path, client.listeners); + local ok, err = conn:connect(); + if not ok then + if err == "no unix socket support" then + print("** LuaSocket unix socket support not available or incompatible, ensure your"); + print("** version is up to date."); + else + print("** Unable to connect to server - is it running? Is mod_admin_shell enabled?"); + print("** Connection error: "..err); + end + os.exit(1); + end + server.loop(); +end + +return { + shell = start; +}; diff --git a/util/pubsub.lua b/util/pubsub.lua index 7ccc817f..acb34db9 100644 --- a/util/pubsub.lua +++ b/util/pubsub.lua @@ -1,9 +1,11 @@ local events = require "util.events"; local cache = require "util.cache"; +local errors = require "util.error"; local service_mt = {}; local default_config = { + max_items = 256; itemstore = function (config, _) return cache.new(config["max_items"]) end; broadcaster = function () end; subscriber_filter = function (subs) return subs end; @@ -131,10 +133,11 @@ local default_config = { local default_config_mt = { __index = default_config }; local default_node_config = { - ["persist_items"] = false; + ["persist_items"] = true; ["max_items"] = 20; ["access_model"] = "open"; ["publish_model"] = "publishers"; + ["send_last_published_item"] = "never"; }; local default_node_config_mt = { __index = default_node_config }; @@ -176,8 +179,11 @@ local function new(config) -- Load nodes from storage, if we have a store and it supports iterating over stored items if config.nodestore and config.nodestore.users then for node_name in config.nodestore:users() do - service.nodes[node_name] = load_node_from_store(service, node_name); - service.data[node_name] = config.itemstore(service.nodes[node_name].config, node_name); + local node = load_node_from_store(service, node_name); + service.nodes[node_name] = node; + if node.config.persist_items then + service.data[node_name] = config.itemstore(service.nodes[node_name].config, node_name); + end for jid in pairs(service.nodes[node_name].subscribers) do local normal_jid = service.config.normalize_jid(jid); @@ -280,7 +286,8 @@ function service:set_affiliation(node, actor, jid, affiliation) --> ok, err node_obj.affiliations[jid] = affiliation; if self.config.nodestore then - local ok, err = save_node_to_store(self, node_obj); + -- TODO pass the error from storage to caller eg wrapped in an util.error + local ok, err = save_node_to_store(self, node_obj); -- luacheck: ignore 211/err if not ok then node_obj.affiliations[jid] = old_affiliation; return ok, "internal-server-error"; @@ -344,7 +351,8 @@ function service:add_subscription(node, actor, jid, options) --> ok, err end if self.config.nodestore then - local ok, err = save_node_to_store(self, node_obj); + -- TODO pass the error from storage to caller eg wrapped in an util.error + local ok, err = save_node_to_store(self, node_obj); -- luacheck: ignore 211/err if not ok then node_obj.subscribers[jid] = old_subscription; self.subscriptions[normal_jid][jid][node] = old_subscription and true or nil; @@ -396,7 +404,8 @@ function service:remove_subscription(node, actor, jid) --> ok, err end if self.config.nodestore then - local ok, err = save_node_to_store(self, node_obj); + -- TODO pass the error from storage to caller eg wrapped in an util.error + local ok, err = save_node_to_store(self, node_obj); -- luacheck: ignore 211/err if not ok then node_obj.subscribers[jid] = old_subscription; self.subscriptions[normal_jid][jid][node] = old_subscription and true or nil; @@ -454,14 +463,18 @@ function service:create(node, actor, options) --> ok, err }; if self.config.nodestore then - local ok, err = save_node_to_store(self, self.nodes[node]); + -- TODO pass the error from storage to caller eg wrapped in an util.error + local ok, err = save_node_to_store(self, self.nodes[node]); -- luacheck: ignore 211/err if not ok then self.nodes[node] = nil; return ok, "internal-server-error"; end end - self.data[node] = self.config.itemstore(self.nodes[node].config, node); + if config.persist_items then + self.data[node] = self.config.itemstore(self.nodes[node].config, node); + end + self.events.fire_event("node-created", { service = self, node = node, actor = actor }); if actor ~= true then local ok, err = self:set_affiliation(node, true, actor, "owner"); @@ -511,7 +524,7 @@ local function check_preconditions(node_config, required_config) end for config_field, value in pairs(required_config) do if node_config[config_field] ~= value then - return false; + return false, config_field; end end return true; @@ -547,23 +560,28 @@ function service:publish(node, actor, id, item, requested_config) --> ok, err node_obj = self.nodes[node]; elseif requested_config and not requested_config._defaults_only then -- Check that node has the requested config before we publish - if not check_preconditions(node_obj.config, requested_config) then - return false, "precondition-not-met"; + local ok, field = check_preconditions(node_obj.config, requested_config); + if not ok then + local err = errors.new({ + type = "cancel", condition = "conflict", text = "Field does not match: "..field; + }); + err.pubsub_condition = "precondition-not-met"; + return false, err; end end if not self.config.itemcheck(item) then return nil, "invalid-item"; end - local node_data = self.data[node]; - if not node_data then - -- FIXME how is this possible? #1657 - return nil, "internal-server-error"; - end - local ok = node_data:set(id, item); - if not ok then - return nil, "internal-server-error"; + if node_obj.config.persist_items then + if not self.data[node] then + self.data[node] = self.config.itemstore(self.nodes[node].config, node); + end + local ok = self.data[node]:set(id, item); + if not ok then + return nil, "internal-server-error"; + end + if type(ok) == "string" then id = ok; end end - if type(ok) == "string" then id = ok; end local event_data = { service = self, node = node, actor = actor, id = id, item = item }; self.events.fire_event("item-published/"..node, event_data); self.events.fire_event("item-published", event_data); @@ -583,12 +601,17 @@ function service:retract(node, actor, id, retract) --> ok, err end -- local node_obj = self.nodes[node]; - if (not node_obj) or (not self.data[node]:get(id)) then + if not node_obj then return false, "item-not-found"; end - local ok = self.data[node]:set(id, nil); - if not ok then - return nil, "internal-server-error"; + if self.data[node] then + if not self.data[node]:get(id) then + return false, "item-not-found"; + end + local ok = self.data[node]:set(id, nil); + if not ok then + return nil, "internal-server-error"; + end end self.events.fire_event("item-retracted", { service = self, node = node, actor = actor, id = id }); if retract then @@ -607,10 +630,12 @@ function service:purge(node, actor, notify) --> ok, err if not node_obj then return false, "item-not-found"; end - if self.data[node] and self.data[node].clear then - self.data[node]:clear() - else - self.data[node] = self.config.itemstore(self.nodes[node].config, node); + if self.data[node] then + if self.data[node].clear then + self.data[node]:clear() + else + self.data[node] = self.config.itemstore(self.nodes[node].config, node); + end end self.events.fire_event("node-purged", { service = self, node = node, actor = actor }); if notify then @@ -619,7 +644,7 @@ function service:purge(node, actor, notify) --> ok, err return true end -function service:get_items(node, actor, ids) --> (true, { id, [id] = node }) or (false, err) +function service:get_items(node, actor, ids, resultspec) --> (true, { id, [id] = node }) or (false, err) -- Access checking if not self:may(node, actor, "get_items") then return false, "forbidden"; @@ -629,22 +654,31 @@ function service:get_items(node, actor, ids) --> (true, { id, [id] = node }) or if not node_obj then return false, "item-not-found"; end + if not self.data[node] then + -- Disabled rather than unsupported, but close enough. + return false, "persistent-items-unsupported"; + end if type(ids) == "string" then -- COMPAT see #1305 ids = { ids }; end local data = {}; + local limit = resultspec and resultspec.max; if ids then for _, key in ipairs(ids) do local value = self.data[node]:get(key); if value then data[#data+1] = key; data[key] = value; + -- Limits and ids seem like a problematic combination. + if limit and #data >= limit then break end end end else for key, value in self.data[node]:items() do data[#data+1] = key; data[key] = value; + if limit and #data >= limit then break + end end end return true, data; @@ -662,6 +696,11 @@ function service:get_last_item(node, actor) --> (true, id, node) or (false, err) return false, "item-not-found"; end + if not self.data[node] then + -- FIXME Should this be a success or failure? + return true, nil; + end + -- Returns success, id, item return true, self.data[node]:head(); end @@ -772,7 +811,8 @@ function service:set_node_config(node, actor, new_config) --> ok, err node_obj.config = new_config; if self.config.nodestore then - local ok, err = save_node_to_store(self, node_obj); + -- TODO pass the error from storage to caller eg wrapped in an util.error + local ok, err = save_node_to_store(self, node_obj); -- luacheck: ignore 211/err if not ok then node_obj.config = old_config; return ok, "internal-server-error"; @@ -792,9 +832,22 @@ function service:set_node_config(node, actor, new_config) --> ok, err end if old_config["persist_items"] ~= node_obj.config["persist_items"] then - self.data[node] = self.config.itemstore(self.nodes[node].config, node); + if node_obj.config["persist_items"] then + self.data[node] = self.config.itemstore(self.nodes[node].config, node); + elseif self.data[node] then + if self.data[node].clear then + self.data[node]:clear() + end + self.data[node] = nil; + end elseif old_config["max_items"] ~= node_obj.config["max_items"] then - self.data[node]:resize(self.nodes[node].config["max_items"]); + if self.data[node] then + local max_items = self.nodes[node].config["max_items"]; + if max_items == "max" then + max_items = self.config.max_items; + end + self.data[node]:resize(max_items); + end end return true; diff --git a/util/queue.lua b/util/queue.lua index c8e71514..c94c62ae 100644 --- a/util/queue.lua +++ b/util/queue.lua @@ -59,18 +59,20 @@ local function new(size, allow_wrapping) return true; end; items = function (self) - --luacheck: ignore 431/t - return function (t, pos) - if pos >= t:count() then + return function (_, pos) + if pos >= items then return nil; end local read_pos = tail + pos; - if read_pos > t.size then + if read_pos > self.size then read_pos = (read_pos%size); end - return pos+1, t._items[read_pos]; + return pos+1, t[read_pos]; end, self, 0; end; + consume = function (self) + return self.pop, self; + end; }; end diff --git a/util/random.lua b/util/random.lua index 6782d7fa..3305172f 100644 --- a/util/random.lua +++ b/util/random.lua @@ -7,7 +7,7 @@ -- local ok, crand = pcall(require, "util.crand"); -if ok then return crand; end +if ok and pcall(crand.bytes, 1) then return crand; end local urandom, urandom_err = io.open("/dev/urandom", "r"); diff --git a/util/rsm.lua b/util/rsm.lua index 40a78fb5..e6060af8 100644 --- a/util/rsm.lua +++ b/util/rsm.lua @@ -10,10 +10,15 @@ -- local stanza = require"util.stanza".stanza; -local tostring, tonumber = tostring, tonumber; +local tonumber = tonumber; +local s_format = string.format; local type = type; local pairs = pairs; +local function inttostr(n) + return s_format("%d", n); +end + local xmlns_rsm = 'http://jabber.org/protocol/rsm'; local element_parsers = {}; @@ -45,22 +50,31 @@ end local element_generators = setmetatable({ first = function(st, data) if type(data) == "table" then - st:tag("first", { index = data.index }):text(data[1]):up(); + st:tag("first", { index = inttostr(data.index) }):text(data[1]):up(); else - st:tag("first"):text(tostring(data)):up(); + st:text_tag("first", data); end end; before = function(st, data) if data == true then st:tag("before"):up(); else - st:tag("before"):text(tostring(data)):up(); + st:text_tag("before", data); end - end + end; + max = function (st, data) + st:text_tag("max", inttostr(data)); + end; + index = function (st, data) + st:text_tag("index", inttostr(data)); + end; + count = function (st, data) + st:text_tag("count", inttostr(data)); + end; }, { __index = function(_, name) return function(st, data) - st:tag(name):text(tostring(data)):up(); + st:text_tag(name, data); end end; }); diff --git a/util/sasl.lua b/util/sasl.lua index 50851405..528743d1 100644 --- a/util/sasl.lua +++ b/util/sasl.lua @@ -27,7 +27,7 @@ Authentication Backend Prototypes: state = false : disabled state = true : enabled -state = nil : non-existant +state = nil : non-existent Channel Binding: @@ -47,7 +47,7 @@ local registered_mechanisms = {}; local backend_mechanism = {}; local mechanism_channelbindings = {}; --- register a new SASL mechanims +-- register a new SASL mechanisms local function registerMechanism(name, backends, f, cb_backends) assert(type(name) == "string", "Parameter name MUST be a string."); assert(type(backends) == "string" or type(backends) == "table", "Parameter backends MUST be either a string or a table."); @@ -97,7 +97,7 @@ function method:clean_clone() return new(self.realm, self.profile) end --- get a list of possible SASL mechanims to use +-- get a list of possible SASL mechanisms to use function method:mechanisms() local current_mechs = {}; for mech, _ in pairs(self.mechs) do @@ -134,7 +134,6 @@ end -- load the mechanisms require "util.sasl.plain" .init(registerMechanism); -require "util.sasl.digest-md5".init(registerMechanism); require "util.sasl.anonymous" .init(registerMechanism); require "util.sasl.scram" .init(registerMechanism); require "util.sasl.external" .init(registerMechanism); diff --git a/util/sasl/digest-md5.lua b/util/sasl/digest-md5.lua deleted file mode 100644 index 7542a037..00000000 --- a/util/sasl/digest-md5.lua +++ /dev/null @@ -1,251 +0,0 @@ --- sasl.lua v0.4 --- Copyright (C) 2008-2010 Tobias Markmann --- --- All rights reserved. --- --- Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: --- --- * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. --- * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. --- * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. --- --- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -local tostring = tostring; -local type = type; - -local s_gmatch = string.gmatch; -local s_match = string.match; -local t_concat = table.concat; -local t_insert = table.insert; -local to_byte, to_char = string.byte, string.char; - -local md5 = require "util.hashes".md5; -local log = require "util.logger".init("sasl"); -local generate_uuid = require "util.uuid".generate; -local nodeprep = require "util.encodings".stringprep.nodeprep; - -local _ENV = nil; --- luacheck: std none - ---========================= ---SASL DIGEST-MD5 according to RFC 2831 - ---[[ -Supported Authentication Backends - -digest_md5: - function(username, domain, realm, encoding) -- domain and realm are usually the same; for some broken - -- implementations it's not - return digesthash, state; - end - -digest_md5_test: - function(username, domain, realm, encoding, digesthash) - return true or false, state; - end -]] - -local function digest(self, message) - --TODO complete support for authzid - - local function serialize(message) - local data = "" - - -- testing all possible values - if message["realm"] then data = data..[[realm="]]..message.realm..[[",]] end - if message["nonce"] then data = data..[[nonce="]]..message.nonce..[[",]] end - if message["qop"] then data = data..[[qop="]]..message.qop..[[",]] end - if message["charset"] then data = data..[[charset=]]..message.charset.."," end - if message["algorithm"] then data = data..[[algorithm=]]..message.algorithm.."," end - if message["rspauth"] then data = data..[[rspauth=]]..message.rspauth.."," end - data = data:gsub(",$", "") - return data - end - - local function utf8tolatin1ifpossible(passwd) - local i = 1; - while i <= #passwd do - local passwd_i = to_byte(passwd:sub(i, i)); - if passwd_i > 0x7F then - if passwd_i < 0xC0 or passwd_i > 0xC3 then - return passwd; - end - i = i + 1; - passwd_i = to_byte(passwd:sub(i, i)); - if passwd_i < 0x80 or passwd_i > 0xBF then - return passwd; - end - end - i = i + 1; - end - - local p = {}; - local j = 0; - i = 1; - while (i <= #passwd) do - local passwd_i = to_byte(passwd:sub(i, i)); - if passwd_i > 0x7F then - i = i + 1; - local passwd_i_1 = to_byte(passwd:sub(i, i)); - t_insert(p, to_char(passwd_i%4*64 + passwd_i_1%64)); -- I'm so clever - else - t_insert(p, to_char(passwd_i)); - end - i = i + 1; - end - return t_concat(p); - end - local function latin1toutf8(str) - local p = {}; - for ch in s_gmatch(str, ".") do - ch = to_byte(ch); - if (ch < 0x80) then - t_insert(p, to_char(ch)); - elseif (ch < 0xC0) then - t_insert(p, to_char(0xC2, ch)); - else - t_insert(p, to_char(0xC3, ch - 64)); - end - end - return t_concat(p); - end - local function parse(data) - local message = {} - -- COMPAT: %z in the pattern to work around jwchat bug (sends "charset=utf-8\0") - for k, v in s_gmatch(data, [[([%w%-]+)="?([^",%z]*)"?,?]]) do -- FIXME The hacky regex makes me shudder - message[k] = v; - end - return message; - end - - if not self.nonce then - self.nonce = generate_uuid(); - self.step = 0; - self.nonce_count = {}; - end - - self.step = self.step + 1; - if (self.step == 1) then - local challenge = serialize({ nonce = self.nonce, - qop = "auth", - charset = "utf-8", - algorithm = "md5-sess", - realm = self.realm}); - return "challenge", challenge; - elseif (self.step == 2) then - local response = parse(message); - -- check for replay attack - if response["nc"] then - if self.nonce_count[response["nc"]] then return "failure", "not-authorized" end - end - - -- check for username, it's REQUIRED by RFC 2831 - local username = response["username"]; - local _nodeprep = self.profile.nodeprep; - if username and _nodeprep ~= false then - username = (_nodeprep or nodeprep)(username); -- FIXME charset - end - if not username or username == "" then - return "failure", "malformed-request"; - end - self.username = username; - - -- check for nonce, ... - if not response["nonce"] then - return "failure", "malformed-request"; - else - -- check if it's the right nonce - if response["nonce"] ~= tostring(self.nonce) then return "failure", "malformed-request" end - end - - if not response["cnonce"] then return "failure", "malformed-request", "Missing entry for cnonce in SASL message." end - if not response["qop"] then response["qop"] = "auth" end - - if response["realm"] == nil or response["realm"] == "" then - response["realm"] = ""; - elseif response["realm"] ~= self.realm then - return "failure", "not-authorized", "Incorrect realm value"; - end - - local decoder; - if response["charset"] == nil then - decoder = utf8tolatin1ifpossible; - elseif response["charset"] ~= "utf-8" then - return "failure", "incorrect-encoding", "The client's response uses "..response["charset"].." for encoding with isn't supported by sasl.lua. Supported encodings are latin or utf-8."; - end - - local domain = ""; - local protocol = ""; - if response["digest-uri"] then - protocol, domain = response["digest-uri"]:match("(%w+)/(.*)$"); - if protocol == nil or domain == nil then return "failure", "malformed-request" end - else - return "failure", "malformed-request", "Missing entry for digest-uri in SASL message." - end - - --TODO maybe realm support - local Y, state; - if self.profile.plain then - local password, state = self.profile.plain(self, response["username"], self.realm) - if state == nil then return "failure", "not-authorized" - elseif state == false then return "failure", "account-disabled" end - Y = md5(response["username"]..":"..response["realm"]..":"..password); - elseif self.profile["digest-md5"] then - Y, state = self.profile["digest-md5"](self, response["username"], self.realm, response["realm"], response["charset"]) - if state == nil then return "failure", "not-authorized" - elseif state == false then return "failure", "account-disabled" end - elseif self.profile["digest-md5-test"] then - -- TODO - end - --local password_encoding, Y = self.credentials_handler("DIGEST-MD5", response["username"], self.realm, response["realm"], decoder); - --if Y == nil then return "failure", "not-authorized" - --elseif Y == false then return "failure", "account-disabled" end - local A1 = ""; - if response.authzid then - if response.authzid == self.username or response.authzid == self.username.."@"..self.realm then - -- COMPAT - log("warn", "Client is violating RFC 3920 (section 6.1, point 7)."); - A1 = Y..":"..response["nonce"]..":"..response["cnonce"]..":"..response.authzid; - else - return "failure", "invalid-authzid"; - end - else - A1 = Y..":"..response["nonce"]..":"..response["cnonce"]; - end - local A2 = "AUTHENTICATE:"..protocol.."/"..domain; - - local HA1 = md5(A1, true); - local HA2 = md5(A2, true); - - local KD = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":"..HA2; - local response_value = md5(KD, true); - - if response_value == response["response"] then - -- calculate rspauth - A2 = ":"..protocol.."/"..domain; - - HA1 = md5(A1, true); - HA2 = md5(A2, true); - - KD = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":"..HA2 - local rspauth = md5(KD, true); - self.authenticated = true; - --TODO: considering sending the rspauth in a success node for saving one roundtrip; allowed according to http://tools.ietf.org/html/draft-saintandre-rfc3920bis-09#section-7.3.6 - return "challenge", serialize({rspauth = rspauth}); - else - return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated." - end - elseif self.step == 3 then - if self.authenticated ~= nil then return "success" - else return "failure", "malformed-request" end - end -end - -local function init(registerMechanism) - registerMechanism("DIGEST-MD5", {"plain"}, digest); -end - -return { - init = init; -} diff --git a/util/sasl/scram.lua b/util/sasl/scram.lua index f64feb8b..37abf4a4 100644 --- a/util/sasl/scram.lua +++ b/util/sasl/scram.lua @@ -14,16 +14,12 @@ local s_match = string.match; local type = type local base64 = require "util.encodings".base64; -local hmac_sha1 = require "util.hashes".hmac_sha1; -local sha1 = require "util.hashes".sha1; -local Hi = require "util.hashes".scram_Hi_sha1; +local hashes = require "util.hashes"; local generate_uuid = require "util.uuid".generate; local saslprep = require "util.encodings".stringprep.saslprep; local nodeprep = require "util.encodings".stringprep.nodeprep; local log = require "util.logger".init("sasl"); -local t_concat = table.concat; -local char = string.char; -local byte = string.byte; +local binaryXOR = require "util.strbitop".sxor; local _ENV = nil; -- luacheck: std none @@ -45,33 +41,7 @@ Supported Channel Binding Backends 'tls-unique' according to RFC 5929 ]] -local default_i = 4096 - -local xor_map = { - 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,1,0,3,2,5,4,7,6,9,8,11,10, - 13,12,15,14,2,3,0,1,6,7,4,5,10,11,8,9,14,15,12,13,3,2,1,0,7,6,5, - 4,11,10,9,8,15,14,13,12,4,5,6,7,0,1,2,3,12,13,14,15,8,9,10,11,5, - 4,7,6,1,0,3,2,13,12,15,14,9,8,11,10,6,7,4,5,2,3,0,1,14,15,12,13, - 10,11,8,9,7,6,5,4,3,2,1,0,15,14,13,12,11,10,9,8,8,9,10,11,12,13, - 14,15,0,1,2,3,4,5,6,7,9,8,11,10,13,12,15,14,1,0,3,2,5,4,7,6,10, - 11,8,9,14,15,12,13,2,3,0,1,6,7,4,5,11,10,9,8,15,14,13,12,3,2,1, - 0,7,6,5,4,12,13,14,15,8,9,10,11,4,5,6,7,0,1,2,3,13,12,15,14,9,8, - 11,10,5,4,7,6,1,0,3,2,14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1,15, - 14,13,12,11,10,9,8,7,6,5,4,3,2,1,0, -}; - -local result = {}; -local function binaryXOR( a, b ) - for i=1, #a do - local x, y = byte(a, i), byte(b, i); - local lowx, lowy = x % 16, y % 16; - local hix, hiy = (x - lowx) / 16, (y - lowy) / 16; - local lowr, hir = xor_map[lowx * 16 + lowy + 1], xor_map[hix * 16 + hiy + 1]; - local r = hir * 16 + lowr; - result[i] = char(r) - end - return t_concat(result); -end +local default_i = 10000 local function validate_username(username, _nodeprep) -- check for forbidden char sequences @@ -99,24 +69,26 @@ local function hashprep(hashname) return hashname:lower():gsub("-", "_"); end -local function getAuthenticationDatabaseSHA1(password, salt, iteration_count) - if type(password) ~= "string" or type(salt) ~= "string" or type(iteration_count) ~= "number" then - return false, "inappropriate argument types" - end - if iteration_count < 4096 then - log("warn", "Iteration count < 4096 which is the suggested minimum according to RFC 5802.") - end - password = saslprep(password); - if not password then - return false, "password fails SASLprep"; +local function get_scram_hasher(H, HMAC, Hi) + return function (password, salt, iteration_count) + if type(password) ~= "string" or type(salt) ~= "string" or type(iteration_count) ~= "number" then + return false, "inappropriate argument types" + end + if iteration_count < 4096 then + log("warn", "Iteration count < 4096 which is the suggested minimum according to RFC 5802.") + end + password = saslprep(password); + if not password then + return false, "password fails SASLprep"; + end + local salted_password = Hi(password, salt, iteration_count); + local stored_key = H(HMAC(salted_password, "Client Key")) + local server_key = HMAC(salted_password, "Server Key"); + return true, stored_key, server_key end - local salted_password = Hi(password, salt, iteration_count); - local stored_key = sha1(hmac_sha1(salted_password, "Client Key")) - local server_key = hmac_sha1(salted_password, "Server Key"); - return true, stored_key, server_key end -local function scram_gen(hash_name, H_f, HMAC_f) +local function scram_gen(hash_name, H_f, HMAC_f, get_auth_db, expect_cb) local profile_name = "scram_" .. hashprep(hash_name); local function scram_hash(self, message) local support_channel_binding = false; @@ -129,6 +101,7 @@ local function scram_gen(hash_name, H_f, HMAC_f) local client_first_message = message; -- TODO: fail if authzid is provided, since we don't support them yet + -- luacheck: ignore 211/authzid local gs2_header, gs2_cbind_flag, gs2_cbind_name, authzid, client_first_message_bare, username, clientnonce = s_match(client_first_message, "^(([pny])=?([^,]*),([^,]*),)(m?=?[^,]*,?n=([^,]*),r=([^,]*),?.*)$"); @@ -144,6 +117,10 @@ local function scram_gen(hash_name, H_f, HMAC_f) if gs2_cbind_flag == "n" then -- "n" -> client doesn't support channel binding. + if expect_cb then + log("debug", "Client unexpectedly doesn't support channel binding"); + -- XXX Is it sensible to abort if the client starts -PLUS but doesn't use channel binding? + end support_channel_binding = false; end @@ -181,7 +158,7 @@ local function scram_gen(hash_name, H_f, HMAC_f) iteration_count = default_i; local succ; - succ, stored_key, server_key = getAuthenticationDatabaseSHA1(password, salt, iteration_count); + succ, stored_key, server_key = get_auth_db(password, salt, iteration_count); if not succ then log("error", "Generating authentication database failed. Reason: %s", stored_key); return "failure", "temporary-auth-failure"; @@ -194,11 +171,11 @@ local function scram_gen(hash_name, H_f, HMAC_f) end local nonce = clientnonce .. generate_uuid(); - local server_first_message = "r="..nonce..",s="..base64.encode(salt)..",i="..iteration_count; + local server_first_message = ("r=%s,s=%s,i=%d"):format(nonce, base64.encode(salt), iteration_count); self.state = { gs2_header = gs2_header; gs2_cbind_name = gs2_cbind_name; - username = username; + username = self.username; -- Reference property instead of local, in case it was modified by the profile nonce = nonce; server_key = server_key; @@ -251,22 +228,28 @@ local function scram_gen(hash_name, H_f, HMAC_f) return scram_hash; end +local auth_db_getters = {} local function init(registerMechanism) - local function registerSCRAMMechanism(hash_name, hash, hmac_hash) + local function registerSCRAMMechanism(hash_name, hash, hmac_hash, pbkdf2) + local get_auth_db = get_scram_hasher(hash, hmac_hash, pbkdf2); + auth_db_getters[hash_name] = get_auth_db; registerMechanism("SCRAM-"..hash_name, {"plain", "scram_"..(hashprep(hash_name))}, - scram_gen(hash_name:lower(), hash, hmac_hash)); + scram_gen(hash_name:lower(), hash, hmac_hash, get_auth_db)); -- register channel binding equivalent registerMechanism("SCRAM-"..hash_name.."-PLUS", {"plain", "scram_"..(hashprep(hash_name))}, - scram_gen(hash_name:lower(), hash, hmac_hash), {"tls-unique"}); + scram_gen(hash_name:lower(), hash, hmac_hash, get_auth_db, true), {"tls-unique"}); end - registerSCRAMMechanism("SHA-1", sha1, hmac_sha1); + registerSCRAMMechanism("SHA-1", hashes.sha1, hashes.hmac_sha1, hashes.pbkdf2_hmac_sha1); + registerSCRAMMechanism("SHA-256", hashes.sha256, hashes.hmac_sha256, hashes.pbkdf2_hmac_sha256); end return { - getAuthenticationDatabaseSHA1 = getAuthenticationDatabaseSHA1; + get_hash = get_scram_hasher; + hashers = auth_db_getters; + getAuthenticationDatabaseSHA1 = get_scram_hasher(hashes.sha1, hashes.hmac_sha1, hashes.pbkdf2_hmac_sha1); -- COMPAT init = init; } diff --git a/util/sasl_cyrus.lua b/util/sasl_cyrus.lua deleted file mode 100644 index a6bd0628..00000000 --- a/util/sasl_cyrus.lua +++ /dev/null @@ -1,169 +0,0 @@ --- sasl.lua v0.4 --- Copyright (C) 2008-2009 Tobias Markmann --- --- All rights reserved. --- --- Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: --- --- * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. --- * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. --- * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. --- --- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -local cyrussasl = require "cyrussasl"; -local log = require "util.logger".init("sasl_cyrus"); - -local setmetatable = setmetatable - -local pcall = pcall -local s_match, s_gmatch = string.match, string.gmatch - -local sasl_errstring = { - -- SASL result codes -- - [1] = "another step is needed in authentication"; - [0] = "successful result"; - [-1] = "generic failure"; - [-2] = "memory shortage failure"; - [-3] = "overflowed buffer"; - [-4] = "mechanism not supported"; - [-5] = "bad protocol / cancel"; - [-6] = "can't request info until later in exchange"; - [-7] = "invalid parameter supplied"; - [-8] = "transient failure (e.g., weak key)"; - [-9] = "integrity check failed"; - [-12] = "SASL library not initialized"; - - -- client only codes -- - [2] = "needs user interaction"; - [-10] = "server failed mutual authentication step"; - [-11] = "mechanism doesn't support requested feature"; - - -- server only codes -- - [-13] = "authentication failure"; - [-14] = "authorization failure"; - [-15] = "mechanism too weak for this user"; - [-16] = "encryption needed to use mechanism"; - [-17] = "One time use of a plaintext password will enable requested mechanism for user"; - [-18] = "passphrase expired, has to be reset"; - [-19] = "account disabled"; - [-20] = "user not found"; - [-23] = "version mismatch with plug-in"; - [-24] = "remote authentication server unavailable"; - [-26] = "user exists, but no verifier for user"; - - -- codes for password setting -- - [-21] = "passphrase locked"; - [-22] = "requested change was not needed"; - [-27] = "passphrase is too weak for security policy"; - [-28] = "user supplied passwords not permitted"; -}; -setmetatable(sasl_errstring, { __index = function() return "undefined error!" end }); - -local _ENV = nil; --- luacheck: std none - -local method = {}; -method.__index = method; -local initialized = false; - -local function init(service_name) - if not initialized then - local st, errmsg = pcall(cyrussasl.server_init, service_name); - if st then - initialized = true; - else - log("error", "Failed to initialize Cyrus SASL: %s", errmsg); - end - end -end - --- create a new SASL object which can be used to authenticate clients --- host_fqdn may be nil in which case gethostname() gives the value. --- For GSSAPI, this determines the hostname in the service ticket (after --- reverse DNS canonicalization, only if [libdefaults] rdns = true which --- is the default). -local function new(realm, service_name, app_name, host_fqdn) - - init(app_name or service_name); - - local st, ret = pcall(cyrussasl.server_new, service_name, host_fqdn, realm, nil, nil) - if not st then - log("error", "Creating SASL server connection failed: %s", ret); - return nil; - end - - local sasl_i = { realm = realm, service_name = service_name, cyrus = ret }; - - if cyrussasl.set_canon_cb then - local c14n_cb = function (user) - local node = s_match(user, "^([^@]+)"); - log("debug", "Canonicalizing username %s to %s", user, node) - return node - end - cyrussasl.set_canon_cb(sasl_i.cyrus, c14n_cb); - end - - cyrussasl.setssf(sasl_i.cyrus, 0, 0xffffffff) - local mechanisms = {}; - local cyrus_mechs = cyrussasl.listmech(sasl_i.cyrus, nil, "", " ", ""); - for w in s_gmatch(cyrus_mechs, "[^ ]+") do - mechanisms[w] = true; - end - sasl_i.mechs = mechanisms; - return setmetatable(sasl_i, method); -end - --- get a fresh clone with the same realm and service name -function method:clean_clone() - return new(self.realm, self.service_name) -end - --- get a list of possible SASL mechanims to use -function method:mechanisms() - return self.mechs; -end - --- select a mechanism to use -function method:select(mechanism) - if not self.selected and self.mechs[mechanism] then - self.selected = mechanism; - return true; - end -end - --- feed new messages to process into the library -function method:process(message) - local err; - local data; - - if not self.first_step_done then - err, data = cyrussasl.server_start(self.cyrus, self.selected, message or "") - self.first_step_done = true; - else - err, data = cyrussasl.server_step(self.cyrus, message or "") - end - - self.username = cyrussasl.get_username(self.cyrus) - - if (err == 0) then -- SASL_OK - if self.require_provisioning and not self.require_provisioning(self.username) then - return "failure", "not-authorized", "User authenticated successfully, but not provisioned for XMPP"; - end - return "success", data - elseif (err == 1) then -- SASL_CONTINUE - return "challenge", data - elseif (err == -4) then -- SASL_NOMECH - log("debug", "SASL mechanism not available from remote end") - return "failure", "invalid-mechanism", "SASL mechanism not available" - elseif (err == -13) then -- SASL_BADAUTH - return "failure", "not-authorized", sasl_errstring[err]; - else - log("debug", "Got SASL error condition %d: %s", err, sasl_errstring[err]); - return "failure", "undefined-condition", sasl_errstring[err]; - end -end - -return { - new = new; -}; diff --git a/util/serialization.lua b/util/serialization.lua index 5121a9f9..d310a3e8 100644 --- a/util/serialization.lua +++ b/util/serialization.lua @@ -16,22 +16,18 @@ local s_char = string.char; local s_match = string.match; local t_concat = table.concat; +local to_hex = require "util.hex".to; + local pcall = pcall; local envload = require"util.envload".envload; local pos_inf, neg_inf = math.huge, -math.huge; --- luacheck: ignore 143/math local m_type = math.type or function (n) return n % 1 == 0 and n <= 9007199254740992 and n >= -9007199254740992 and "integer" or "float"; end; -local char_to_hex = {}; -for i = 0,255 do - char_to_hex[s_char(i)] = s_format("%02x", i); -end - -local function to_hex(s) - return (s_gsub(s, ".", char_to_hex)); +local function rawpairs(t) + return next, t, nil; end local function fatal_error(obj, why) @@ -123,6 +119,7 @@ local function new(opt) local freeze = opt.freeze; local maxdepth = opt.maxdepth or 127; local multirefs = opt.multiref; + local table_pairs = opt.table_iterator or rawpairs; -- serialize one table, recursively -- t - table being serialized @@ -153,6 +150,10 @@ local function new(opt) if type(fr) == "function" then t = fr(t); + if type(t) == "string" then + o[l], l = t, l + 1; + return l; + end if type(tag) == "string" then o[l], l = tag, l + 1; end @@ -164,7 +165,9 @@ local function new(opt) local indent = s_rep(indentwith, d); local numkey = 1; local ktyp, vtyp; - for k,v in next,t do + local had_items = false; + for k,v in table_pairs(t) do + had_items = true; o[l], l = itemstart, l + 1; o[l], l = indent, l + 1; ktyp, vtyp = type(k), type(v); @@ -195,14 +198,10 @@ local function new(opt) else o[l], l = ser(v), l + 1; end - -- last item? - if next(t, k) ~= nil then - o[l], l = itemsep, l + 1; - else - o[l], l = itemlast, l + 1; - end + o[l], l = itemsep, l + 1; end - if next(t) ~= nil then + if had_items then + o[l - 1] = itemlast; o[l], l = s_rep(indentwith, d-1), l + 1; end o[l], l = tend, l +1; diff --git a/util/session.lua b/util/session.lua index b2a726ce..25b22faf 100644 --- a/util/session.lua +++ b/util/session.lua @@ -4,12 +4,13 @@ local logger = require "util.logger"; local function new_session(typ) local session = { type = typ .. "_unauthed"; + base_type = typ; }; return session; end local function set_id(session) - local id = session.type .. tostring(session):match("%x+$"):lower(); + local id = session.base_type .. tostring(session):match("%x+$"):lower(); session.id = id; return session; end @@ -30,7 +31,7 @@ local function set_send(session) local conn = session.conn; if not conn then function session.send(data) - session.log("debug", "Discarding data sent to unconnected session: %s", tostring(data)); + session.log("debug", "Discarding data sent to unconnected session: %s", data); return false; end return session; @@ -46,7 +47,7 @@ local function set_send(session) if t then local ret, err = w(conn, t); if not ret then - session.log("debug", "Error writing to connection: %s", tostring(err)); + session.log("debug", "Error writing to connection: %s", err); return false, err; end end diff --git a/util/set.lua b/util/set.lua index b7345e7e..69dfef5d 100644 --- a/util/set.lua +++ b/util/set.lua @@ -6,8 +6,9 @@ -- COPYING file in the source package for more information. -- -local ipairs, pairs, getmetatable, setmetatable, next, tostring = - ipairs, pairs, getmetatable, setmetatable, next, tostring; +local ipairs, pairs, setmetatable, next, tostring = + ipairs, pairs, setmetatable, next, tostring; +local getmetatable = getmetatable; local t_concat = table.concat; local _ENV = nil; @@ -51,6 +52,15 @@ local function new(list) return items[item]; end + function set:contains_set(other_set) + for item in other_set do + if not self:contains(item) then + return false; + end + end + return true; + end + function set:items() return next, items; end @@ -151,6 +161,11 @@ function set_mt.__div(set, func) return new_set; end function set_mt.__eq(set1, set2) + if getmetatable(set1) ~= set_mt or getmetatable(set2) ~= set_mt then + -- Lua 5.3+ calls this if both operands are tables, even if metatables differ + return false; + end + set1, set2 = set1._items, set2._items; for item in pairs(set1) do if not set2[item] then diff --git a/util/smqueue.lua b/util/smqueue.lua new file mode 100644 index 00000000..6d8348d4 --- /dev/null +++ b/util/smqueue.lua @@ -0,0 +1,56 @@ +local queue = require("util.queue"); + +local lib = { smqueue = {} } + +local smqueue = lib.smqueue; + +function smqueue:push(v) + self._head = self._head + 1; + + assert(self._queue:push(v)); +end + +function smqueue:ack(h) + if h < self._tail then + return nil, "tail" + elseif h > self._head then + return nil, "head" + end + + local acked = {}; + self._tail = h; + local expect = self._head - self._tail; + while expect < self._queue:count() do + local v = self._queue:pop(); + if not v then return nil, "pop" end + table.insert(acked, v); + end + return acked +end + +function smqueue:count_unacked() return self._head - self._tail end + +function smqueue:count_acked() return self._tail end + +function smqueue:resumable() return self._queue:count() >= (self._head - self._tail) end + +function smqueue:resume() return self._queue:items() end + +function smqueue:consume() return self._queue:consume() end + +function smqueue:table() + local t = {}; + for i, v in self:resume() do t[i] = v; end + return t +end + +local function freeze(q) return { head = q._head; tail = q._tail } end + +local queue_mt = { __name = "smqueue"; __index = smqueue; __len = smqueue.count_unacked; __freeze = freeze } + +function lib.new(size) + assert(size > 0); + return setmetatable({ _head = 0; _tail = 0; _queue = queue.new(size, true) }, queue_mt) +end + +return lib diff --git a/util/sql.lua b/util/sql.lua index 00c7b57f..9d1c86ca 100644 --- a/util/sql.lua +++ b/util/sql.lua @@ -201,31 +201,31 @@ function engine:_transaction(func, ...) if not ok then return ok, err; end end --assert(not self.__transaction, "Recursive transactions not allowed"); - log("debug", "SQL transaction begin [%s]", tostring(func)); + log("debug", "SQL transaction begin [%s]", func); self.__transaction = true; local success, a, b, c = xpcall(func, handleerr, ...); self.__transaction = nil; if success then - log("debug", "SQL transaction success [%s]", tostring(func)); + log("debug", "SQL transaction success [%s]", func); local ok, err = self.conn:commit(); -- LuaDBI doesn't actually return an error message here, just a boolean if not ok then return ok, err or "commit failed"; end return success, a, b, c; else - log("debug", "SQL transaction failure [%s]: %s", tostring(func), a.err); + log("debug", "SQL transaction failure [%s]: %s", func, a.err); if self.conn then self.conn:rollback(); end return success, a.err; end end function engine:transaction(...) - local ok, ret = self:_transaction(...); + local ok, ret, b, c = self:_transaction(...); if not ok then local conn = self.conn; if not conn or not conn:ping() then log("debug", "Database connection was closed. Will reconnect and retry."); self.conn = nil; - log("debug", "Retrying SQL transaction [%s]", tostring((...))); - ok, ret = self:_transaction(...); + log("debug", "Retrying SQL transaction [%s]", (...)); + ok, ret, b, c = self:_transaction(...); log("debug", "SQL transaction retry %s", ok and "succeeded" or "failed"); else log("debug", "SQL connection is up, so not retrying"); @@ -234,7 +234,7 @@ function engine:transaction(...) log("error", "Error in SQL transaction: %s", ret); end end - return ok, ret; + return ok, ret, b, c; end function engine:_create_index(index) local sql = "CREATE INDEX \""..index.name.."\" ON \""..index.table.."\" ("; @@ -335,6 +335,9 @@ function engine:set_encoding() -- to UTF-8 local ok, actual_charset = self:transaction(function () return self:select"SHOW SESSION VARIABLES LIKE 'character_set_client'"; end); + if not ok then + return false, "Failed to detect connection encoding"; + end local charset_ok = true; for row in actual_charset do if row[2] ~= charset then diff --git a/util/sslconfig.lua b/util/sslconfig.lua index a5827a76..6074a1fb 100644 --- a/util/sslconfig.lua +++ b/util/sslconfig.lua @@ -67,6 +67,9 @@ end -- Curve list too finalisers.curveslist = finalisers.ciphers; +-- TLS 1.3 ciphers +finalisers.ciphersuites = finalisers.ciphers; + -- protocol = "x" should enable only that protocol -- protocol = "x+" should enable x and later versions diff --git a/util/stanza.lua b/util/stanza.lua index 3cd56c5f..86b88169 100644 --- a/util/stanza.lua +++ b/util/stanza.lua @@ -11,7 +11,6 @@ local error = error; local t_insert = table.insert; local t_remove = table.remove; local t_concat = table.concat; -local s_format = string.format; local s_match = string.match; local tostring = tostring; local setmetatable = setmetatable; @@ -22,20 +21,10 @@ local type = type; local s_gsub = string.gsub; local s_sub = string.sub; local s_find = string.find; -local os = os; local valid_utf8 = require "util.encodings".utf8.valid; -local do_pretty_printing = not os.getenv("WINDIR"); -local getstyle, getstring; -if do_pretty_printing then - local ok, termcolours = pcall(require, "util.termcolours"); - if ok then - getstyle, getstring = termcolours.getstyle, termcolours.getstring; - else - do_pretty_printing = nil; - end -end +local do_pretty_printing, termcolours = pcall(require, "util.termcolours"); local xmlns_stanzas = "urn:ietf:params:xml:ns:xmpp-stanzas"; @@ -80,16 +69,11 @@ end local function check_attr(attr) if attr ~= nil then if type(attr) ~= "table" then - error("invalid attributes, expected table got "..type(attr)); + error("invalid attributes: expected table, got "..type(attr)); end for k, v in pairs(attr) do check_name(k, "attribute"); check_text(v, "attribute"); - if type(v) ~= "string" then - error("invalid attribute value for '"..k.."': expected string, got "..type(v)); - elseif not valid_utf8(v) then - error("invalid attribute value for '"..k.."': contains invalid utf8"); - end end end end @@ -110,7 +94,7 @@ function stanza_mt:query(xmlns) end function stanza_mt:body(text, attr) - return self:tag("body", attr):text(text); + return self:text_tag("body", text, attr); end function stanza_mt:text_tag(name, text, attr, namespaces) @@ -140,6 +124,10 @@ function stanza_mt:up() return self; end +function stanza_mt:at_top() + return self.last_add == nil or #self.last_add == 0 +end + function stanza_mt:reset() self.last_add = nil; return self; @@ -180,6 +168,7 @@ function stanza_mt:get_child(name, xmlns) return child; end end + return nil; end function stanza_mt:get_child_text(name, xmlns) @@ -194,12 +183,23 @@ function stanza_mt:child_with_name(name) for _, child in ipairs(self.tags) do if child.name == name then return child; end end + return nil; end function stanza_mt:child_with_ns(ns) for _, child in ipairs(self.tags) do if child.attr.xmlns == ns then return child; end end + return nil; +end + +function stanza_mt:get_child_with_attr(name, xmlns, attr_name, attr_value, normalize) + for tag in self:childtags(name, xmlns) do + if (normalize and normalize(tag.attr[attr_name]) or tag.attr[attr_name]) == attr_value then + return tag; + end + end + return nil; end function stanza_mt:children() @@ -282,6 +282,34 @@ function stanza_mt:find(path) until not self end +local function _clone(stanza, only_top) + local attr, tags = {}, {}; + for k,v in pairs(stanza.attr) do attr[k] = v; end + local old_namespaces, namespaces = stanza.namespaces; + if old_namespaces then + namespaces = {}; + for k,v in pairs(old_namespaces) do namespaces[k] = v; end + end + local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags }; + if not only_top then + for i=1,#stanza do + local child = stanza[i]; + if child.name then + child = _clone(child); + t_insert(tags, child); + end + t_insert(new, child); + end + end + return setmetatable(new, stanza_mt); +end + +local function clone(stanza, only_top) + if not is_stanza(stanza) then + error("bad argument to clone: expected stanza, got "..type(stanza)); + end + return _clone(stanza, only_top); +end local escape_table = { ["'"] = "'", ["\""] = """, ["<"] = "<", [">"] = ">", ["&"] = "&" }; local function xml_escape(str) return (s_gsub(str, "['&<>\"]", escape_table)); end @@ -322,25 +350,23 @@ function stanza_mt.__tostring(t) end function stanza_mt.top_tag(t) - local attr_string = ""; - if t.attr then - for k, v in pairs(t.attr) do if type(k) == "string" then attr_string = attr_string .. s_format(" %s='%s'", k, xml_escape(tostring(v))); end end - end - return s_format("<%s%s>", t.name, attr_string); + local top_tag_clone = clone(t, true); + return tostring(top_tag_clone):sub(1,-3)..">"; end function stanza_mt.get_text(t) if #t.tags == 0 then return t_concat(t); end + return nil; end function stanza_mt.get_error(stanza) - local error_type, condition, text; + local error_type, condition, text, extra_tag; local error_tag = stanza:get_child("error"); if not error_tag then - return nil, nil, nil; + return nil, nil, nil, nil; end error_type = error_tag.attr.type; @@ -351,12 +377,14 @@ function stanza_mt.get_error(stanza) elseif not condition then condition = child.name; end - if condition and text then - break; - end + else + extra_tag = child; + end + if condition and text and extra_tag then + break; end end - return error_type, condition or "undefined-condition", text; + return error_type, condition or "undefined-condition", text, extra_tag; end local function preserialize(stanza) @@ -400,50 +428,32 @@ local function deserialize(serialized) end end -local function _clone(stanza) - local attr, tags = {}, {}; - for k,v in pairs(stanza.attr) do attr[k] = v; end - local old_namespaces, namespaces = stanza.namespaces; - if old_namespaces then - namespaces = {}; - for k,v in pairs(old_namespaces) do namespaces[k] = v; end - end - local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags }; - for i=1,#stanza do - local child = stanza[i]; - if child.name then - child = _clone(child); - t_insert(tags, child); - end - t_insert(new, child); - end - return setmetatable(new, stanza_mt); -end - -local function clone(stanza) - if not is_stanza(stanza) then - error("bad argument to clone: expected stanza, got "..type(stanza)); - end - return _clone(stanza); -end - local function message(attr, body) if not body then return new_stanza("message", attr); else - return new_stanza("message", attr):tag("body"):text(body):up(); + return new_stanza("message", attr):text_tag("body", body); end end local function iq(attr) - if not (attr and attr.id) then + if not attr then + error("iq stanzas require id and type attributes"); + end + if not attr.id then error("iq stanzas require an id attribute"); end + if not attr.type then + error("iq stanzas require a type attribute"); + end return new_stanza("iq", attr); end local function reply(orig) + if not is_stanza(orig) then + error("bad argument to reply: expected stanza, got "..type(orig)); + end return new_stanza(orig.name, - orig.attr and { + { to = orig.attr.from, from = orig.attr.to, id = orig.attr.id, @@ -452,12 +462,37 @@ local function reply(orig) end local xmpp_stanzas_attr = { xmlns = xmlns_stanzas }; -local function error_reply(orig, error_type, condition, error_message) +local function error_reply(orig, error_type, condition, error_message, error_by) + if not is_stanza(orig) then + error("bad argument to error_reply: expected stanza, got "..type(orig)); + elseif orig.attr.type == "error" then + error("bad argument to error_reply: got stanza of type error which must not be replied to"); + end local t = reply(orig); t.attr.type = "error"; - t:tag("error", {type = error_type}) --COMPAT: Some day xmlns:stanzas goes here - :tag(condition, xmpp_stanzas_attr):up(); - if error_message then t:tag("text", xmpp_stanzas_attr):text(error_message):up(); end + local extra; + if type(error_type) == "table" then -- an util.error or similar object + if type(error_type.extra) == "table" then + extra = error_type.extra; + end + if type(error_type.context) == "table" and type(error_type.context.by) == "string" then error_by = error_type.context.by; end + error_type, condition, error_message = error_type.type, error_type.condition, error_type.text; + end + if t.attr.from == error_by then + error_by = nil; + end + t:tag("error", {type = error_type, by = error_by}) --COMPAT: Some day xmlns:stanzas goes here + :tag(condition, xmpp_stanzas_attr); + if extra and condition == "gone" and type(extra.uri) == "string" then + t:text(extra.uri); + end + t:up(); + if error_message then t:text_tag("text", error_message, xmpp_stanzas_attr); end + if extra and is_stanza(extra.tag) then + t:add_child(extra.tag); + elseif extra and extra.namespace and extra.condition then + t:tag(extra.condition, { xmlns = extra.namespace }):up(); + end return t; -- stanza ready for adding app-specific errors end @@ -465,39 +500,50 @@ local function presence(attr) return new_stanza("presence", attr); end +local pretty; if do_pretty_printing then - local style_attrk = getstyle("yellow"); - local style_attrv = getstyle("red"); - local style_tagname = getstyle("red"); - local style_punc = getstyle("magenta"); - - local attr_format = " "..getstring(style_attrk, "%s")..getstring(style_punc, "=")..getstring(style_attrv, "'%s'"); - local top_tag_format = getstring(style_punc, "<")..getstring(style_tagname, "%s").."%s"..getstring(style_punc, ">"); - --local tag_format = getstring(style_punc, "<")..getstring(style_tagname, "%s").."%s"..getstring(style_punc, ">").."%s"..getstring(style_punc, "</")..getstring(style_tagname, "%s")..getstring(style_punc, ">"); - local tag_format = top_tag_format.."%s"..getstring(style_punc, "</")..getstring(style_tagname, "%s")..getstring(style_punc, ">"); - function stanza_mt.pretty_print(t) - local children_text = ""; - for _, child in ipairs(t) do - if type(child) == "string" then - children_text = children_text .. xml_escape(child); - else - children_text = children_text .. child:pretty_print(); - end - end + local getstyle, getstring = termcolours.getstyle, termcolours.getstring; + + local blue1 = getstyle("1b3967"); + local blue2 = getstyle("13b5ea"); + local green1 = getstyle("439639"); + local green2 = getstyle("a0ce67"); + local orange1 = getstyle("d9541e"); + local orange2 = getstyle("e96d1f"); + + local attr_replace = ( + getstring(green2, "%1") .. -- attr name + getstring(green1, "%2") .. -- equal + getstring(orange1, "%3") .. -- quote + getstring(orange2, "%4") .. -- attr value + getstring(orange1, "%5") -- quote + ); + + local text_replace = ( + getstring(green1, "%1") .. -- & + getstring(green2, "%2") .. -- amp + getstring(green1, "%3") -- ; + ); + + function pretty(s) + -- Tag soup color + -- Outer gsub call takes each <tag>, applies colour to the brackets, the + -- tag name, then applies one inner gsub call to colour the attributes and + -- another for any text content. + return (s:gsub("(<[?/]?)([^ >/?]*)(.-)([?/]?>)([^<]*)", function(opening_bracket, tag_name, attrs, closing_bracket, content) + return getstring(blue1, opening_bracket)..getstring(blue2, tag_name).. + attrs:gsub("([^=]+)(=)([\"'])(.-)([\"'])", attr_replace) .. + getstring(blue1, closing_bracket) .. + content:gsub("(&#?)(%w+)(;)", text_replace); + end, 100)); + end - local attr_string = ""; - if t.attr then - for k, v in pairs(t.attr) do if type(k) == "string" then attr_string = attr_string .. s_format(attr_format, k, tostring(v)); end end - end - return s_format(tag_format, t.name, attr_string, children_text, t.name); + function stanza_mt.pretty_print(t) + return pretty(tostring(t)); end function stanza_mt.pretty_top_tag(t) - local attr_string = ""; - if t.attr then - for k, v in pairs(t.attr) do if type(k) == "string" then attr_string = attr_string .. s_format(attr_format, k, tostring(v)); end end - end - return s_format(top_tag_format, t.name, attr_string); + return pretty(t:top_tag()); end else -- Sorry, fresh out of colours for you guys ;) @@ -505,6 +551,36 @@ else stanza_mt.pretty_top_tag = stanza_mt.top_tag; end +function stanza_mt.indent(t, level, indent) + if #t == 0 or (#t == 1 and type(t[1]) == "string") then + -- Empty nodes wouldn't have any indentation + -- Text-only nodes are preserved as to not alter the text content + -- Optimization: Skip clone of these since we don't alter them + return t; + end + + indent = indent or "\t"; + level = level or 1; + local tag = clone(t, true); + + for child in t:children() do + if type(child) == "string" then + -- Already indented text would look weird but let's ignore that for now. + if child:find("%S") then + tag:text("\n" .. indent:rep(level)); + tag:text(child); + end + elseif is_stanza(child) then + tag:text("\n" .. indent:rep(level)); + tag:add_direct_child(child:indent(level+1, indent)); + end + end + -- before the closing tag + tag:text("\n" .. indent:rep((level-1))); + + return tag; +end + return { stanza_mt = stanza_mt; stanza = new_stanza; @@ -518,4 +594,5 @@ return { error_reply = error_reply; presence = presence; xml_escape = xml_escape; + pretty_print = pretty; }; diff --git a/util/startup.lua b/util/startup.lua index 602dfe5e..10ff1875 100644 --- a/util/startup.lua +++ b/util/startup.lua @@ -5,8 +5,10 @@ local startup = {}; local prosody = { events = require "util.events".new() }; local logger = require "util.logger"; local log = logger.init("startup"); +local parse_args = require "util.argparse".parse; local config = require "core.configmanager"; +local config_warnings; local dependencies = require "util.dependencies"; @@ -20,59 +22,45 @@ local default_gc_params = { minor_threshold = 20, major_threshold = 50; }; -local short_params = { D = "daemonize", F = "no-daemonize" }; -local value_params = { config = true }; - -function startup.parse_args() - local parsed_opts = {}; - prosody.opts = parsed_opts; - - if #arg == 0 then - return; - end - while true do - local raw_param = arg[1]; - if not raw_param then - break; - end - - local prefix = raw_param:match("^%-%-?"); - if not prefix then - break; - elseif prefix == "--" and raw_param == "--" then - table.remove(arg, 1); - break; - end - local param = table.remove(arg, 1):sub(#prefix+1); - if #param == 1 then - param = short_params[param]; +local arg_settigs = { + prosody = { + short_params = { D = "daemonize"; F = "no-daemonize", h = "help", ["?"] = "help" }; + value_params = { config = true }; + }; + prosodyctl = { + short_params = { v = "verbose", h = "help", ["?"] = "help" }; + value_params = { config = true }; + }; +} + +function startup.parse_args(profile) + local opts, err, where = parse_args(arg, arg_settigs[profile or prosody.process_type] or profile); + if not opts then + if err == "param-not-found" then + print("Unknown command-line option: "..tostring(where)); + if prosody.process_type == "prosody" then + print("Perhaps you meant to use prosodyctl instead?"); + end + elseif err == "missing-value" then + print("Expected a value to follow command-line option: "..where); end - - if not param then - print("Unknown command-line option: "..tostring(raw_param)); - print("Perhaps you meant to use prosodyctl instead?"); - os.exit(1); + os.exit(1); + end + if prosody.process_type == "prosody" then + if #arg > 0 then + print("Unrecognized option: "..arg[1]); + print("(Did you mean 'prosodyctl "..arg[1].."'?)"); + print(""); end - - local param_k, param_v; - if value_params[param] then - param_k, param_v = param, table.remove(arg, 1); - if not param_v then - print("Expected a value to follow command-line option: "..raw_param); - os.exit(1); - end - else - param_k, param_v = param:match("^([^=]+)=(.+)$"); - if not param_k then - if param:match("^no%-") then - param_k, param_v = param:sub(4), false; - else - param_k, param_v = param, true; - end - end + if opts.help or #arg > 0 then + print("prosody [ -D | -F ] [ --config /path/to/prosody.cfg.lua ]"); + print(" -D, --daemonize Run in the background") + print(" -F, --no-daemonize Run in the foreground") + print(" --config FILE Specify config file") + os.exit(0); end - parsed_opts[param_k] = param_v; end + prosody.opts = opts; end function startup.read_config() @@ -127,6 +115,8 @@ function startup.read_config() print("**************************"); print(""); os.exit(1); + elseif err and #err > 0 then + config_warnings = err; end prosody.config_loaded = true; end @@ -142,6 +132,7 @@ end function startup.load_libraries() -- Load socket framework -- luacheck: ignore 111/server 111/socket + require "util.import" socket = require "socket"; server = require "net.server" end @@ -159,8 +150,13 @@ function startup.init_logging() end); end -function startup.log_dependency_warnings() +function startup.log_startup_warnings() dependencies.log_warnings(); + if config_warnings then + for _, warning in ipairs(config_warnings) do + log("warn", "Configuration warning: %s", warning); + end + end end function startup.sanity_check() @@ -229,8 +225,15 @@ function startup.set_function_metatable() end end function mt.__tostring(f) - local info = debug.getinfo(f); - return ("function(%s:%d)"):format(info.short_src:match("[^\\/]*$"), info.linedefined); + local info = debug.getinfo(f, "Su"); + local n_params = info.nparams or 0; + for i = 1, n_params do + info[i] = debug.getlocal(f, i); + end + if info.isvararg then + info[n_params+1] = "..."; + end + return ("function<%s:%d>(%s)"):format(info.short_src:match("[^\\/]*$"), info.linedefined, table.concat(info, ", ")); end debug.setmetatable(function() end, mt); end @@ -282,8 +285,8 @@ end function startup.setup_plugindir() local custom_plugin_paths = config.get("*", "plugin_paths"); + local path_sep = package.config:sub(3,3); if custom_plugin_paths then - local path_sep = package.config:sub(3,3); -- path1;path2;path3;defaultpath... -- luacheck: ignore 111 CFG_PLUGINDIR = table.concat(custom_plugin_paths, path_sep)..path_sep..(CFG_PLUGINDIR or "plugins"); @@ -291,6 +294,17 @@ function startup.setup_plugindir() end end +function startup.setup_plugin_install_path() + local installer_plugin_path = config.get("*", "installer_plugin_path") or "custom_plugins"; + local path_sep = package.config:sub(3,3); + installer_plugin_path = config.resolve_relative_path(CFG_DATADIR or "data", installer_plugin_path); + require"util.paths".complement_lua_path(installer_plugin_path); + -- luacheck: ignore 111 + CFG_PLUGINDIR = installer_plugin_path..path_sep..(CFG_PLUGINDIR or "plugins"); + prosody.paths.installer = installer_plugin_path; + prosody.paths.plugins = CFG_PLUGINDIR; +end + function startup.chdir() if prosody.installed then local lfs = require "lfs"; @@ -312,9 +326,9 @@ function startup.add_global_prosody_functions() local ok, level, err = config.load(prosody.config_file); if not ok then if level == "parser" then - log("error", "There was an error parsing the configuration file: %s", tostring(err)); + log("error", "There was an error parsing the configuration file: %s", err); elseif level == "file" then - log("error", "Couldn't read the config file when trying to reload: %s", tostring(err)); + log("error", "Couldn't read the config file when trying to reload: %s", err); end else prosody.events.fire_event("config-reloaded", { @@ -340,13 +354,12 @@ function startup.add_global_prosody_functions() reason = reason; code = code; }); - server.setquitting(true); + prosody.main_thread:run(startup.shutdown); end end function startup.load_secondary_libraries() --- Load and initialise core modules - require "util.import" require "util.xmppstream" require "core.stanza_router" require "core.statsmanager" @@ -387,6 +400,22 @@ function startup.init_http_client() local https_client = config.get("*", "client_https_ssl") http.default.options.sslctx = require "core.certmanager".create_context("client_https port 0", "client", { capath = config_ssl.capath, cafile = config_ssl.cafile, verify = "peer", }, https_client); + http.default.options.use_dane = config.get("*", "use_dane") +end + +function startup.init_promise() + local promise = require "util.promise"; + + local timer = require "util.timer"; + promise.set_nexttick(function(f) return timer.add_task(0, f); end); +end + +function startup.init_async() + local async = require "util.async"; + + local timer = require "util.timer"; + async.set_nexttick(function(f) return timer.add_task(0, f); end); + async.set_schedule_function(timer.add_task); end function startup.init_data_store() @@ -448,7 +477,18 @@ end -- Override logging config (used by prosodyctl) function startup.force_console_logging() original_logging_config = config.get("*", "log"); - config.set("*", "log", { { levels = { min = os.getenv("PROSODYCTL_LOG_LEVEL") or "info" }, to = "console" } }); + local log_level = os.getenv("PROSODYCTL_LOG_LEVEL"); + if not log_level then + if prosody.opts.verbose then + log_level = "debug"; + elseif prosody.opts.quiet then + log_level = "error"; + elseif prosody.opts.silent then + config.set("*", "log", {}); -- ssssshush! + return + end + end + config.set("*", "log", { { levels = { min = log_level or "info" }, to = "console" } }); end function startup.switch_user() @@ -486,9 +526,9 @@ function startup.switch_user() if not prosody.switched_user then -- Boo! print("Warning: Couldn't switch to Prosody user/group '"..tostring(desired_user).."'/'"..tostring(desired_group).."': "..tostring(err)); - else + elseif prosody.config_file then -- Make sure the Prosody user can read the config - local conf, err, errno = io.open(prosody.config_file); + local conf, err, errno = io.open(prosody.config_file); --luacheck: ignore 211/errno if conf then conf:close(); else @@ -565,6 +605,10 @@ function startup.init_gc() return true; end +function startup.init_errors() + require "util.error".configure(config.get("*", "error_library") or {}); +end + function startup.make_host(hostname) return { type = "local", @@ -587,21 +631,44 @@ function startup.make_dummy_hosts() end end +function startup.cleanup() + prosody.log("info", "Shutdown status: Cleaning up"); + prosody.events.fire_event("server-cleanup"); +end + +function startup.shutdown() + prosody.log("info", "Shutting down..."); + startup.cleanup(); + prosody.events.fire_event("server-stopped"); + prosody.log("info", "Shutdown complete"); + + prosody.log("debug", "Shutdown reason was: %s", prosody.shutdown_reason or "not specified"); + prosody.log("debug", "Exiting with status code: %d", prosody.shutdown_code or 0); + server.setquitting(true); +end + +function startup.exit() + os.exit(prosody.shutdown_code); +end + -- prosodyctl only function startup.prosodyctl() + prosody.process_type = "prosodyctl"; startup.parse_args(); startup.init_global_state(); startup.read_config(); startup.force_console_logging(); startup.init_logging(); startup.init_gc(); + startup.init_errors(); startup.setup_plugindir(); + startup.setup_plugin_install_path(); startup.setup_datadir(); startup.chdir(); startup.read_version(); startup.switch_user(); startup.check_dependencies(); - startup.log_dependency_warnings(); + startup.log_startup_warnings(); startup.check_unwriteable(); startup.load_libraries(); startup.init_http_client(); @@ -611,24 +678,29 @@ end function startup.prosody() -- These actions are in a strict order, as many depend on -- previous steps to have already been performed + prosody.process_type = "prosody"; startup.parse_args(); startup.init_global_state(); startup.read_config(); startup.init_logging(); startup.init_gc(); + startup.init_errors(); startup.sanity_check(); startup.sandbox_require(); startup.set_function_metatable(); startup.check_dependencies(); startup.load_libraries(); startup.setup_plugindir(); + startup.setup_plugin_install_path(); startup.setup_datadir(); startup.chdir(); startup.add_global_prosody_functions(); startup.read_version(); startup.log_greeting(); - startup.log_dependency_warnings(); + startup.log_startup_warnings(); startup.load_secondary_libraries(); + startup.init_promise(); + startup.init_async(); startup.init_http_client(); startup.init_data_store(); startup.init_global_protection(); diff --git a/util/statistics.lua b/util/statistics.lua index 39954652..cb6481c5 100644 --- a/util/statistics.lua +++ b/util/statistics.lua @@ -1,160 +1,191 @@ -local t_sort = table.sort -local m_floor = math.floor; local time = require "util.time".now; +local new_metric_registry = require "util.openmetrics".new_metric_registry; +local render_histogram_le = require "util.openmetrics".render_histogram_le; -local function nop_function() end +-- BEGIN of Metric implementations -local function percentile(arr, length, pc) - local n = pc/100 * (length + 1); - local k, d = m_floor(n), n%1; - if k == 0 then - return arr[1] or 0; - elseif k >= length then - return arr[length]; - end - return arr[k] + d*(arr[k+1] - arr[k]); +-- Gauges +local gauge_metric_mt = {} +gauge_metric_mt.__index = gauge_metric_mt + +local function new_gauge_metric() + local metric = { value = 0 } + setmetatable(metric, gauge_metric_mt) + return metric +end + +function gauge_metric_mt:set(value) + self.value = value +end + +function gauge_metric_mt:add(delta) + self.value = self.value + delta end -local function new_registry(config) - config = config or {}; - local duration_sample_interval = config.duration_sample_interval or 5; - local duration_max_samples = config.duration_max_stored_samples or 5000; +function gauge_metric_mt:reset() + self.value = 0 +end - local function get_distribution_stats(events, n_actual_events, since, new_time, units) - local n_stored_events = #events; - t_sort(events); - local sum = 0; - for i = 1, n_stored_events do - sum = sum + events[i]; +function gauge_metric_mt:iter_samples() + local done = false + return function(_s) + if done then + return nil, true end + done = true + return "", nil, _s.value + end, self +end - return { - samples = events; - sample_count = n_stored_events; - count = n_actual_events, - rate = n_actual_events/(new_time-since); - average = n_stored_events > 0 and sum/n_stored_events or 0, - min = events[1] or 0, - max = events[n_stored_events] or 0, - units = units, - }; - end +-- Counters +local counter_metric_mt = {} +counter_metric_mt.__index = counter_metric_mt + +local function new_counter_metric() + local metric = { + _created = time(), + value = 0, + } + setmetatable(metric, counter_metric_mt) + return metric +end +function counter_metric_mt:set(value) + self.value = value +end - local registry = {}; - local methods; - methods = { - amount = function (name, initial) - local v = initial or 0; - registry[name..":amount"] = function () return "amount", v; end - return function (new_v) v = new_v; end - end; - counter = function (name, initial) - local v = initial or 0; - registry[name..":amount"] = function () return "amount", v; end - return function (delta) - v = v + delta; - end; - end; - rate = function (name) - local since, n = time(), 0; - registry[name..":rate"] = function () - local t = time(); - local stats = { - rate = n/(t-since); - count = n; - }; - since, n = t, 0; - return "rate", stats.rate, stats; - end; - return function () - n = n + 1; - end; - end; - distribution = function (name, unit, type) - type = type or "distribution"; - local events, last_event = {}, 0; - local n_actual_events = 0; - local since = time(); - - registry[name..":"..type] = function () - local new_time = time(); - local stats = get_distribution_stats(events, n_actual_events, since, new_time, unit); - events, last_event = {}, 0; - n_actual_events = 0; - since = new_time; - return type, stats.average, stats; - end; - - return function (value) - n_actual_events = n_actual_events + 1; - if n_actual_events%duration_sample_interval == 1 then - last_event = (last_event%duration_max_samples) + 1; - events[last_event] = value; - end - end; - end; - sizes = function (name) - return methods.distribution(name, "bytes", "size"); - end; - times = function (name) - local events, last_event = {}, 0; - local n_actual_events = 0; - local since = time(); - - registry[name..":duration"] = function () - local new_time = time(); - local stats = get_distribution_stats(events, n_actual_events, since, new_time, "seconds"); - events, last_event = {}, 0; - n_actual_events = 0; - since = new_time; - return "duration", stats.average, stats; - end; - - return function () - n_actual_events = n_actual_events + 1; - if n_actual_events%duration_sample_interval ~= 1 then - return nop_function; - end - - local start_time = time(); - return function () - local end_time = time(); - local duration = end_time - start_time; - last_event = (last_event%duration_max_samples) + 1; - events[last_event] = duration; - end - end; - end; - - get_stats = function () - return registry; - end; - }; - return methods; +function counter_metric_mt:add(value) + self.value = (self.value or 0) + value end -return { - new = new_registry; - get_histogram = function (duration, n_buckets) - n_buckets = n_buckets or 100; - local events, n_events = duration.samples, duration.sample_count; - if not (events and n_events) then - return nil, "not a valid distribution stat"; +function counter_metric_mt:iter_samples() + local step = 0 + return function(_s) + step = step + 1 + if step == 1 then + return "_created", nil, _s._created + elseif step == 2 then + return "_total", nil, _s.value + else + return nil, nil, true + end + end, self +end + +function counter_metric_mt:reset() + self.value = 0 +end + +-- Histograms +local histogram_metric_mt = {} +histogram_metric_mt.__index = histogram_metric_mt + +local function new_histogram_metric(buckets) + local metric = { + _created = time(), + _sum = 0, + _count = 0, + } + -- the order of buckets matters unfortunately, so we cannot directly use + -- the thresholds as table keys + for i, threshold in ipairs(buckets) do + metric[i] = { + threshold = threshold, + threshold_s = render_histogram_le(threshold), + count = 0 + } + end + setmetatable(metric, histogram_metric_mt) + return metric +end + +function histogram_metric_mt:sample(value) + -- According to the I-D, values must be part of all buckets + for i, bucket in pairs(self) do + if "number" == type(i) and value <= bucket.threshold then + bucket.count = bucket.count + 1 end - local histogram = {}; + end + self._sum = self._sum + value + self._count = self._count + 1 +end - for i = 1, 100, 100/n_buckets do - histogram[i] = percentile(events, n_events, i); +function histogram_metric_mt:iter_samples() + local key = nil + return function (_s) + local data + key, data = next(_s, key) + if key == "_created" or key == "_sum" or key == "_count" then + return key, nil, data + elseif key ~= nil then + return "_bucket", {["le"] = data.threshold_s}, data.count + else + return nil, nil, nil end - return histogram; - end; + end, self +end - get_percentile = function (duration, pc) - local events, n_events = duration.samples, duration.sample_count; - if not (events and n_events) then - return nil, "not a valid distribution stat"; +function histogram_metric_mt:reset() + self._created = time() + self._count = 0 + self._sum = 0 + for i, bucket in pairs(self) do + if "number" == type(i) then + bucket.count = 0 end - return percentile(events, n_events, pc); - end; + end +end + +-- Summary +local summary_metric_mt = {} +summary_metric_mt.__index = summary_metric_mt + +local function new_summary_metric() + -- quantiles are not supported yet + local metric = { + _created = time(), + _sum = 0, + _count = 0, + } + setmetatable(metric, summary_metric_mt) + return metric +end + +function summary_metric_mt:sample(value) + self._sum = self._sum + value + self._count = self._count + 1 +end + +function summary_metric_mt:iter_samples() + local key = nil + return function (_s) + local data + key, data = next(_s, key) + return key, nil, data + end, self +end + +function summary_metric_mt:reset() + self._created = time() + self._count = 0 + self._sum = 0 +end + +local pull_backend = { + gauge = new_gauge_metric, + counter = new_counter_metric, + histogram = new_histogram_metric, + summary = new_summary_metric, +} + +-- END of Metric implementations + +local function new() + return { + metric_registry = new_metric_registry(pull_backend), + } +end + +return { + new = new; } diff --git a/util/statsd.lua b/util/statsd.lua index 67481c36..581f945a 100644 --- a/util/statsd.lua +++ b/util/statsd.lua @@ -1,82 +1,267 @@ local socket = require "socket"; +local time = require "util.time".now; +local array = require "util.array"; +local t_concat = table.concat; -local time = require "util.time".now +local new_metric_registry = require "util.openmetrics".new_metric_registry; +local render_histogram_le = require "util.openmetrics".render_histogram_le; -local function new(config) - if not config or not config.statsd_server then - return nil, "No statsd server specified in the config, please see https://prosody.im/doc/statistics"; +-- BEGIN of Metric implementations + +-- Gauges +local gauge_metric_mt = {} +gauge_metric_mt.__index = gauge_metric_mt + +local function new_gauge_metric(full_name, impl) + local metric = { + _full_name = full_name; + _impl = impl; + value = 0; + } + setmetatable(metric, gauge_metric_mt) + return metric +end + +function gauge_metric_mt:set(value) + self.value = value + self._impl:push_gauge(self._full_name, value) +end + +function gauge_metric_mt:add(delta) + self.value = self.value + delta + self._impl:push_gauge(self._full_name, self.value) +end + +function gauge_metric_mt:reset() + self.value = 0 + self._impl:push_gauge(self._full_name, 0) +end + +function gauge_metric_mt.iter_samples() + -- statsd backend does not support iteration. + return function() + return nil end +end - local sock = socket.udp(); - sock:setpeername(config.statsd_server, config.statsd_port or 8125); +-- Counters +local counter_metric_mt = {} +counter_metric_mt.__index = counter_metric_mt - local prefix = (config.prefix or "prosody").."."; +local function new_counter_metric(full_name, impl) + local metric = { + _full_name = full_name, + _impl = impl, + value = 0, + } + setmetatable(metric, counter_metric_mt) + return metric +end + +function counter_metric_mt:set(value) + local delta = value - self.value + self.value = value + self._impl:push_counter_delta(self._full_name, delta) +end - local function send_metric(s) - return sock:send(prefix..s); +function counter_metric_mt:add(value) + self.value = (self.value or 0) + value + self._impl:push_counter_delta(self._full_name, value) +end + +function counter_metric_mt.iter_samples() + -- statsd backend does not support iteration. + return function() + return nil + end +end + +function counter_metric_mt:reset() + self.value = 0 +end + +-- Histograms +local histogram_metric_mt = {} +histogram_metric_mt.__index = histogram_metric_mt + +local function new_histogram_metric(buckets, full_name, impl) + -- NOTE: even though the more or less proprietary dogstatsd has Its own + -- histogram implementation, we push the individual buckets in this statsd + -- backend for both consistency and compatibility across statsd + -- implementations. + local metric = { + _sum_name = full_name..".sum", + _count_name = full_name..".count", + _impl = impl, + _created = time(), + _sum = 0, + _count = 0, + } + -- the order of buckets matters unfortunately, so we cannot directly use + -- the thresholds as table keys + for i, threshold in ipairs(buckets) do + local threshold_s = render_histogram_le(threshold) + metric[i] = { + threshold = threshold, + threshold_s = threshold_s, + count = 0, + _full_name = full_name..".bucket."..(threshold_s:gsub("%.", "_")), + } end + setmetatable(metric, histogram_metric_mt) + return metric +end - local function send_gauge(name, amount, relative) - local s_amount = tostring(amount); - if relative and amount > 0 then - s_amount = "+"..s_amount; +function histogram_metric_mt:sample(value) + -- According to the I-D, values must be part of all buckets + for i, bucket in pairs(self) do + if "number" == type(i) and value <= bucket.threshold then + bucket.count = bucket.count + 1 + self._impl:push_counter_delta(bucket._full_name, 1) end - return send_metric(name..":"..s_amount.."|g"); end + self._sum = self._sum + value + self._count = self._count + 1 + self._impl:push_gauge(self._sum_name, self._sum) + self._impl:push_counter_delta(self._count_name, 1) +end - local function send_counter(name, amount) - return send_metric(name..":"..tostring(amount).."|c"); +function histogram_metric_mt.iter_samples() + -- statsd backend does not support iteration. + return function() + return nil end +end - local function send_duration(name, duration) - return send_metric(name..":"..tostring(duration).."|ms"); +function histogram_metric_mt:reset() + self._created = time() + self._count = 0 + self._sum = 0 + for i, bucket in pairs(self) do + if "number" == type(i) then + bucket.count = 0 + end end + self._impl:push_gauge(self._sum_name, self._sum) +end + +-- Summaries +local summary_metric_mt = {} +summary_metric_mt.__index = summary_metric_mt + +local function new_summary_metric(full_name, impl) + local metric = { + _sum_name = full_name..".sum", + _count_name = full_name..".count", + _impl = impl, + } + setmetatable(metric, summary_metric_mt) + return metric +end + +function summary_metric_mt:sample(value) + self._impl:push_counter_delta(self._sum_name, value) + self._impl:push_counter_delta(self._count_name, 1) +end - local function send_histogram_sample(name, sample) - return send_metric(name..":"..tostring(sample).."|h"); +function summary_metric_mt.iter_samples() + -- statsd backend does not support iteration. + return function() + return nil end +end - local methods; - methods = { - amount = function (name, initial) - if initial then - send_gauge(name, initial); - end - return function (new_v) send_gauge(name, new_v); end - end; - counter = function (name, initial) --luacheck: ignore 212/initial - return function (delta) - send_gauge(name, delta, true); - end; - end; - rate = function (name) - return function () - send_counter(name, 1); - end; +function summary_metric_mt.reset() +end + +-- BEGIN of statsd client implementation + +local statsd_mt = {} +statsd_mt.__index = statsd_mt + +function statsd_mt:cork() + self.corked = true + self.cork_buffer = self.cork_buffer or {} +end + +function statsd_mt:uncork() + self.corked = false + self:_flush_cork_buffer() +end + +function statsd_mt:_flush_cork_buffer() + local buffer = self.cork_buffer + for metric_name, value in pairs(buffer) do + self:_send_gauge(metric_name, value) + buffer[metric_name] = nil + end +end + +function statsd_mt:push_gauge(metric_name, value) + if self.corked then + self.cork_buffer[metric_name] = value + else + self:_send_gauge(metric_name, value) + end +end + +function statsd_mt:_send_gauge(metric_name, value) + self:_send(self.prefix..metric_name..":"..tostring(value).."|g") +end + +function statsd_mt:push_counter_delta(metric_name, delta) + self:_send(self.prefix..metric_name..":"..tostring(delta).."|c") +end + +function statsd_mt:_send(s) + return self.sock:send(s) +end + +-- END of statsd client implementation + +local function build_metric_name(family_name, labels) + local parts = array { family_name } + if labels then + parts:append(labels) + end + return t_concat(parts, "/"):gsub("%.", "_"):gsub("/", ".") +end + +local function new(config) + if not config or not config.statsd_server then + return nil, "No statsd server specified in the config, please see https://prosody.im/doc/statistics"; + end + + local sock = socket.udp(); + sock:setpeername(config.statsd_server, config.statsd_port or 8125); + + local prefix = (config.prefix or "prosody").."."; + + local impl = { + metric_registry = nil; + sock = sock; + prefix = prefix; + }; + setmetatable(impl, statsd_mt) + + local backend = { + gauge = function(family_name, labels) + return new_gauge_metric(build_metric_name(family_name, labels), impl) end; - distribution = function (name, unit, type) --luacheck: ignore 212/unit 212/type - return function (value) - send_histogram_sample(name, value); - end; + counter = function(family_name, labels) + return new_counter_metric(build_metric_name(family_name, labels), impl) end; - sizes = function (name) - name = name.."_size"; - return function (value) - send_histogram_sample(name, value); - end; + histogram = function(buckets, family_name, labels) + return new_histogram_metric(buckets, build_metric_name(family_name, labels), impl) end; - times = function (name) - return function () - local start_time = time(); - return function () - local end_time = time(); - local duration = end_time - start_time; - send_duration(name, duration*1000); - end - end; + summary = function(family_name, labels, extra) + return new_summary_metric(build_metric_name(family_name, labels), impl, extra) end; }; - return methods; + + impl.metric_registry = new_metric_registry(backend); + + return impl; end return { diff --git a/util/termcolours.lua b/util/termcolours.lua index 829d84af..2c13d0ff 100644 --- a/util/termcolours.lua +++ b/util/termcolours.lua @@ -83,7 +83,7 @@ end setmetatable(stylemap, { __index = function(_, style) if type(style) == "string" and style:find("%x%x%x%x%x%x") == 1 then local g = style:sub(7) == " background" and "48;5;" or "38;5;"; - return g .. color(hex2rgb(style)); + return format("%s%d", g, color(hex2rgb(style))); end end } ); diff --git a/util/timer.lua b/util/timer.lua index bc3836be..84da02cf 100644 --- a/util/timer.lua +++ b/util/timer.lua @@ -17,6 +17,11 @@ local xpcall = require "util.xpcall".xpcall; local math_max = math.max; local pairs = pairs; +if server.timer then + -- The selected net.server implements this API, so defer to that + return server.timer; +end + local _ENV = nil; -- luacheck: std none diff --git a/util/uuid.lua b/util/uuid.lua index f4fd21f6..54ea99b4 100644 --- a/util/uuid.lua +++ b/util/uuid.lua @@ -8,7 +8,7 @@ local random = require "util.random"; local random_bytes = random.bytes; -local hex = require "util.hex".to; +local hex = require "util.hex".encode; local m_ceil = math.ceil; local function get_nibbles(n) diff --git a/util/vcard.lua b/util/vcard.lua index bb299fab..e311f73f 100644 --- a/util/vcard.lua +++ b/util/vcard.lua @@ -29,7 +29,7 @@ local function vCard_unesc(s) ["\\n"] = "\n", ["\\r"] = "\r", ["\\t"] = "\t", - ["\\:"] = ":", -- FIXME Shouldn't need to espace : in values, just params + ["\\:"] = ":", -- FIXME Shouldn't need to escape : in values, just params ["\\;"] = ";", ["\\,"] = ",", [":"] = "\29", diff --git a/util/x509.lua b/util/x509.lua index 15cc4d3c..76b50076 100644 --- a/util/x509.lua +++ b/util/x509.lua @@ -20,9 +20,12 @@ local nameprep = require "util.encodings".stringprep.nameprep; local idna_to_ascii = require "util.encodings".idna.to_ascii; +local idna_to_unicode = require "util.encodings".idna.to_unicode; local base64 = require "util.encodings".base64; local log = require "util.logger".init("x509"); +local mt = require "util.multitable"; local s_format = string.format; +local ipairs = ipairs; local _ENV = nil; -- luacheck: std none @@ -216,6 +219,63 @@ local function verify_identity(host, service, cert) return false end +-- TODO Support other SANs +local function get_identities(cert) --> map of names to sets of services + if cert.setencode then + cert:setencode("utf8"); + end + + local names = mt.new(); + + local ext = cert:extensions(); + local sans = ext[oid_subjectaltname]; + if sans then + if sans["dNSName"] then -- Valid for any service + for _, name in ipairs(sans["dNSName"]) do + local is_wildcard = name:sub(1, 2) == "*."; + if is_wildcard then name = name:sub(3); end + name = idna_to_unicode(nameprep(name)); + if name then + if is_wildcard then name = "*." .. name; end + names:set(name, "*", true); + end + end + end + if sans[oid_xmppaddr] then + for _, name in ipairs(sans[oid_xmppaddr]) do + name = nameprep(name); + if name then + names:set(name, "xmpp-client", true); + names:set(name, "xmpp-server", true); + end + end + end + if sans[oid_dnssrv] then + for _, srvname in ipairs(sans[oid_dnssrv]) do + local srv, name = srvname:match("^_([^.]+)%.(.*)"); + if srv then + name = nameprep(name); + if name then + names:set(name, srv, true); + end + end + end + end + end + + local subject = cert:subject(); + for i = 1, #subject do + local dn = subject[i]; + if dn.oid == oid_commonname then + local name = nameprep(dn.value); + if name and idna_to_ascii(name) then + names:set(name, "*", true); + end + end + end + return names.data; +end + local pat = "%-%-%-%-%-BEGIN ([A-Z ]+)%-%-%-%-%-\r?\n".. "([0-9A-Za-z+/=\r\n]*)\r?\n%-%-%-%-%-END %1%-%-%-%-%-"; @@ -237,6 +297,7 @@ end return { verify_identity = verify_identity; + get_identities = get_identities; pem2der = pem2der; der2pem = der2pem; }; diff --git a/util/xml.lua b/util/xml.lua index 4327dfba..2bf1ff4e 100644 --- a/util/xml.lua +++ b/util/xml.lua @@ -72,11 +72,14 @@ local parse_xml = (function() end end handler.StartDoctypeDecl = restricted_handler; - handler.ProcessingInstruction = restricted_handler; if not options or not options.allow_comments then -- NOTE: comments are generally harmless and can be useful when parsing configuration files or other data, even user-provided data handler.Comment = restricted_handler; end + if not options or not options.allow_processing_instructions then + -- Processing instructions should generally be safe to just ignore + handler.ProcessingInstruction = restricted_handler; + end local parser = lxp.new(handler, ns_separator); local ok, err, line, col = parser:parse(xml); if ok then ok, err, line, col = parser:parse(); end @@ -84,7 +87,7 @@ local parse_xml = (function() if ok then return stanza.tags[1]; else - return ok, err.." (line "..line..", col "..col..")"; + return ok, ("%s (line %d, col %d))"):format(err, line, col); end end; end)(); diff --git a/util/xmppstream.lua b/util/xmppstream.lua index 82a9820f..be113396 100644 --- a/util/xmppstream.lua +++ b/util/xmppstream.lua @@ -64,6 +64,8 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress) local stream_default_ns = stream_callbacks.default_ns; + local stream_lang = "en"; + local stack = {}; local chardata, stanza = {}; local stanza_size = 0; @@ -101,6 +103,7 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress) if session.notopen then if tagname == stream_tag then non_streamns_depth = 0; + stream_lang = attr["xml:lang"] or stream_lang; if cb_streamopened then if lxp_supports_bytecount then cb_handleprogress(stanza_size); @@ -178,6 +181,9 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress) cb_handleprogress(stanza_size); end stanza_size = 0; + if stanza.attr["xml:lang"] == nil then + stanza.attr["xml:lang"] = stream_lang; + end if tagname ~= stream_error_tag then cb_handlestanza(session, stanza); else @@ -259,14 +265,13 @@ local function new(session, stream_callbacks, stanza_size_limit) ["xml:lang"] = "en", xmlns = stream_callbacks.default_ns, version = session.version and (session.version > 0 and "1.0" or nil), - id = session.streamid, + id = session.streamid or "", from = from or session.host, to = to, }; if session.stream_attrs then session:stream_attrs(from, to, attr) end - send("<?xml version='1.0'?>"); - send(st.stanza("stream:stream", attr):top_tag()); + send("<?xml version='1.0'?>"..st.stanza("stream:stream", attr):top_tag()); return true; end diff --git a/util/xtemplate.lua b/util/xtemplate.lua new file mode 100644 index 00000000..254c8af0 --- /dev/null +++ b/util/xtemplate.lua @@ -0,0 +1,86 @@ +local s_gsub = string.gsub; +local s_match = string.match; +local s_sub = string.sub; +local t_concat = table.concat; + +local st = require("util.stanza"); + +local function render(template, root, escape, filters) + escape = escape or st.xml_escape; + + return (s_gsub(template, "%b{}", function(block) + local inner = s_sub(block, 2, -2); + local path, pipe, pos = s_match(inner, "^([^|]+)(|?)()"); + if not (type(path) == "string") then return end + local value + if path == "." then + value = root; + elseif path == "#" then + value = root:get_text(); + else + value = root:find(path); + end + local is_escaped = false; + + while pipe == "|" do + local func, args, tmpl, p = s_match(inner, "^(%w+)(%b())(%b{})()", pos); + if not func then func, args, p = s_match(inner, "^(%w+)(%b())()", pos); end + if not func then func, tmpl, p = s_match(inner, "^(%w+)(%b{})()", pos); end + if not func then func, p = s_match(inner, "^(%w+)()", pos); end + if not func then break end + if tmpl then tmpl = s_sub(tmpl, 2, -2); end + if args then args = s_sub(args, 2, -2); end + + if func == "each" and tmpl and st.is_stanza(value) then + if not args then value, args = root, path; end + local ns, name = s_match(args, "^(%b{})(.*)$"); + if ns then + ns = s_sub(ns, 2, -2); + else + name, ns = args, nil; + end + if ns == "" then ns = nil; end + if name == "" then name = nil; end + local out, i = {}, 1; + for c in (value):childtags(name, ns) do out[i], i = render(tmpl, c, escape, filters), i + 1; end + value = t_concat(out); + is_escaped = true; + elseif func == "and" and tmpl then + local condition = value; + if args then condition = root:find(args); end + if condition then + value = render(tmpl, root, escape, filters); + is_escaped = true; + end + elseif func == "or" and tmpl then + local condition = value; + if args then condition = root:find(args); end + if not condition then + value = render(tmpl, root, escape, filters); + is_escaped = true; + end + elseif filters and filters[func] then + local f = filters[func]; + if args == nil then + value, is_escaped = f(value, tmpl); + else + value, is_escaped = f(args, value, tmpl); + end + else + error("No such filter function: " .. func); + end + pipe, pos = s_match(inner, "^(|?)()", p); + end + + if type(value) == "string" then + if not is_escaped then value = escape(value); end + return value + elseif st.is_stanza(value) then + value = value:get_text(); + if value then return escape(value) end + end + return "" + end)) +end + +return { render = render } |