diff options
191 files changed, 7222 insertions, 2018 deletions
diff --git a/.luacheckrc b/.luacheckrc index 85035b86..9e0d1fdf 100644 --- a/.luacheckrc +++ b/.luacheckrc @@ -2,7 +2,7 @@ cache = true codes = true ignore = { "411/err", "421/err", "411/ok", "421/ok", "211/_ENV", "431/log", "214", "581" } -std = "lua53c" +std = "lua54c" max_line_length = 150 read_globals = { @@ -62,6 +62,8 @@ files["plugins/"] = { "module.broadcast", "module.context", "module.depends", + "module.default_permission", + "module.default_permissions", "module.fire_event", "module.get_directory", "module.get_host", @@ -86,6 +88,7 @@ files["plugins/"] = { "module.load_resource", "module.log", "module.log_status", + "module.may", "module.measure", "module.metric", "module.open_store", @@ -149,10 +152,6 @@ if os.getenv("PROSODY_STRICT_LINT") ~= "1" then "net/dns.lua"; "net/server_select.lua"; - "util/vcard.lua"; - - "plugins/mod_storage_sql1.lua"; - "spec/core_moduleapi_spec.lua"; "spec/util_http_spec.lua"; "spec/util_ip_spec.lua"; @@ -171,6 +170,7 @@ if os.getenv("PROSODY_STRICT_LINT") ~= "1" then "tools/migration/migrator/prosody_sql.lua"; "tools/migration/prosody-migrator.lua"; "tools/openfire2prosody.lua"; + "tools/test_mutants.sh.lua"; "tools/xep227toprosody.lua"; } for _, file in ipairs(exclude_files) do diff --git a/.semgrep.yml b/.semgrep.yml index de1ef89e..22bfcfea 100644 --- a/.semgrep.yml +++ b/.semgrep.yml @@ -22,3 +22,9 @@ rules: message: Non-string default from :get_option_string severity: ERROR languages: [lua] +- id: stanza-empty-text-constructor + patterns: + - pattern: $A:text() + message: Use :get_text() to read text, or pass a value here to add text + severity: WARNING + languages: [lua] @@ -1,3 +1,47 @@ +TRUNK +===== + +## New + +### Administration + +- Add 'watch log' command to follow live debug logs at runtime (even if disabled) + +### Networking + +- Honour 'weight' parameter during SRV record selection +- Support for RFC 8305 "Happy Eyeballs" to improve IPv4/IPv6 connectivity +- Support for TCP Fast Open in server_epoll (pending LuaSocket support) +- Support for deferred accept in server_epoll (pending LuaSocket support) + +### MUC + +- Permissions updates: + - Room creation restricted to local users (of the parent host) by default + - restrict_room_creation = true restricts to admins, false disables all restrictions + - Persistent rooms can only be created by local users (parent host) by default + - muc_room_allow_persistent = false restricts to admins + - Public rooms can only be created by local users (parent host) by default + - muc_room_allow_public = false restricts to admins +- Commands to show occupants and affiliations in the Shell + +### Security and authentication + +- Advertise supported SASL Channel-Binding types (XEP-0440) +- Implement RFC 9266 'tls-exporter' channel binding with TLS 1.3 +- New role and permissions framework and API + +## Changes + +- Support sub-second precision timestamps +- mod_blocklist: New option 'migrate_legacy_blocking' to disable migration from mod_privacy +- Ability to use SQLite3 storage using LuaSQLite3 instead of LuaDBI + +## Removed + +- Lua 5.1 support +- XEP-0090 support removed from mod_time + 0.12.0 ====== diff --git a/GNUmakefile b/GNUmakefile index e9ec78c4..c8d2d3dd 100644 --- a/GNUmakefile +++ b/GNUmakefile @@ -71,12 +71,13 @@ install-util: util/encodings.so util/encodings.so util/pposix.so util/signal.so install-plugins: $(MKDIR) $(MODULES) - $(MKDIR) $(MODULES)/mod_pubsub $(MODULES)/adhoc $(MODULES)/muc $(MODULES)/mod_mam + $(MKDIR) $(MODULES)/mod_pubsub $(MODULES)/adhoc $(MODULES)/muc $(MODULES)/mod_mam $(MODULES)/mod_debug_stanzas $(INSTALL_DATA) plugins/*.lua $(MODULES) $(INSTALL_DATA) plugins/mod_pubsub/*.lua $(MODULES)/mod_pubsub $(INSTALL_DATA) plugins/adhoc/*.lua $(MODULES)/adhoc $(INSTALL_DATA) plugins/muc/*.lua $(MODULES)/muc $(INSTALL_DATA) plugins/mod_mam/*.lua $(MODULES)/mod_mam + $(INSTALL_DATA) plugins/mod_debug_stanzas/*.lua $(MODULES)/mod_debug_stanzas install-man: $(MKDIR) $(MAN)/man1 diff --git a/certs/openssl.cnf b/certs/openssl.cnf index ee17b1cf..8a05ecc3 100644 --- a/certs/openssl.cnf +++ b/certs/openssl.cnf @@ -46,7 +46,7 @@ subjectAltName = @subject_alternative_name [ subject_alternative_name ] -# See http://tools.ietf.org/html/rfc6120#section-13.7.1.2 for more info. +# See https://www.rfc-editor.org/rfc/rfc6120.html#section-13.7.1.2 for more info. DNS.0 = example.com otherName.0 = xmppAddr;FORMAT:UTF8,UTF8:example.com @@ -45,7 +45,7 @@ Configure $APP_NAME prior to building. Default is \$PREFIX/lib --datadir=DIR Location where the server data should be stored. Default is \$PREFIX/var/lib/$APP_DIRNAME ---lua-version=VERSION Use specific Lua version: 5.1, 5.2, or 5.3 +--lua-version=VERSION Use specific Lua version: 5.2, 5.3, or 5.4 Default is auto-detected. --lua-suffix=SUFFIX Versioning suffix to use in Lua filenames. Default is "$LUA_SUFFIX" (lua$LUA_SUFFIX...) @@ -173,7 +173,8 @@ do --lua-version|--with-lua-version) [ -n "$value" ] || die "Missing value in flag $key." LUA_VERSION="$value" - [ "$LUA_VERSION" = "5.1" ] || [ "$LUA_VERSION" = "5.2" ] || [ "$LUA_VERSION" = "5.3" ] || [ "$LUA_VERSION" = "5.4" ] || die "Invalid Lua version in flag $key." + [ "$LUA_VERSION" != "5.1" ] || die "Lua 5.1 is no longer supported" + [ "$LUA_VERSION" = "5.2" ] || [ "$LUA_VERSION" = "5.3" ] || [ "$LUA_VERSION" = "5.4" ] || die "Invalid Lua version in flag $key." LUA_VERSION_SET=yes ;; --with-lua) @@ -275,11 +276,11 @@ if [ "$OSPRESET_SET" = "yes" ]; then CFLAGS="$CFLAGS -ggdb" fi if [ "$OSPRESET" = "freebsd" ] || [ "$OSPRESET" = "openbsd" ]; then - LUA_INCDIR="/usr/local/include/lua51" + LUA_INCDIR="/usr/local/include/lua52" LUA_INCDIR_SET=yes CFLAGS="-Wall -fPIC -I/usr/local/include" LDFLAGS="-I/usr/local/include -L/usr/local/lib -shared" - LUA_SUFFIX="51" + LUA_SUFFIX="52" LUA_SUFFIX_SET=yes LUA_DIR=/usr/local LUA_DIR_SET=yes @@ -291,16 +292,16 @@ if [ "$OSPRESET_SET" = "yes" ]; then LUA_INCDIR_SET="yes" fi if [ "$OSPRESET" = "netbsd" ]; then - LUA_INCDIR="/usr/pkg/include/lua-5.1" + LUA_INCDIR="/usr/pkg/include/lua-5.2" LUA_INCDIR_SET=yes - LUA_LIBDIR="/usr/pkg/lib/lua/5.1" + LUA_LIBDIR="/usr/pkg/lib/lua/5.2" LUA_LIBDIR_SET=yes CFLAGS="-Wall -fPIC -I/usr/pkg/include" LDFLAGS="-L/usr/pkg/lib -Wl,-rpath,/usr/pkg/lib -shared" fi if [ "$OSPRESET" = "pkg-config" ]; then if [ "$LUA_SUFFIX_SET" != "yes" ]; then - LUA_SUFFIX="5.1"; + LUA_SUFFIX="5.4"; LUA_SUFFIX_SET=yes fi LUA_CF="$(pkg-config --cflags-only-I lua"$LUA_SUFFIX")" @@ -335,7 +336,7 @@ then fi detect_lua_version() { - detected_lua=$("$1" -e 'print(_VERSION:match(" (5%.[1234])$"))' 2> /dev/null) + detected_lua=$("$1" -e 'print(_VERSION:match(" (5%.[234])$"))' 2> /dev/null) if [ "$detected_lua" != "nil" ] then if [ "$LUA_VERSION_SET" != "yes" ] @@ -389,10 +390,7 @@ search_interpreter() { lua_interp_found=no if [ "$LUA_SUFFIX_SET" != "yes" ] then - if [ "$LUA_VERSION_SET" = "yes" ] && [ "$LUA_VERSION" = "5.1" ] - then - suffixes="5.1 51 -5.1 -51" - elif [ "$LUA_VERSION_SET" = "yes" ] && [ "$LUA_VERSION" = "5.2" ] + if [ "$LUA_VERSION_SET" = "yes" ] && [ "$LUA_VERSION" = "5.2" ] then suffixes="5.2 52 -5.2 -52" elif [ "$LUA_VERSION_SET" = "yes" ] && [ "$LUA_VERSION" = "5.3" ] @@ -402,8 +400,7 @@ then then suffixes="5.4 54 -5.4 -54" else - suffixes="5.1 51 -5.1 -51" - suffixes="$suffixes 5.2 52 -5.2 -52" + suffixes="5.2 52 -5.2 -52" suffixes="$suffixes 5.3 53 -5.3 -53" suffixes="$suffixes 5.4 54 -5.4 -54" fi diff --git a/core/certmanager.lua b/core/certmanager.lua index 7a82c786..0c71e448 100644 --- a/core/certmanager.lua +++ b/core/certmanager.lua @@ -9,9 +9,8 @@ local ssl = require "ssl"; local configmanager = require "core.configmanager"; local log = require "util.logger".init("certmanager"); -local ssl_context = ssl.context or require "ssl.context"; local ssl_newcontext = ssl.newcontext; -local new_config = require"util.sslconfig".new; +local new_config = require"net.server".tls_builder; local stat = require "lfs".attributes; local x509 = require "util.x509"; @@ -313,10 +312,6 @@ else core_defaults.curveslist = nil; end -local path_options = { -- These we pass through resolve_path() - key = true, certificate = true, cafile = true, capath = true, dhparam = true -} - local function create_context(host, mode, ...) local cfg = new_config(); cfg:apply(core_defaults); @@ -352,34 +347,7 @@ local function create_context(host, mode, ...) if user_ssl_config.certificate and not user_ssl_config.key then return nil, "No key present in SSL/TLS configuration for "..host; end end - for option in pairs(path_options) do - if type(user_ssl_config[option]) == "string" then - user_ssl_config[option] = resolve_path(config_path, user_ssl_config[option]); - else - user_ssl_config[option] = nil; - end - end - - -- LuaSec expects dhparam to be a callback that takes two arguments. - -- We ignore those because it is mostly used for having a separate - -- set of params for EXPORT ciphers, which we don't have by default. - if type(user_ssl_config.dhparam) == "string" then - local f, err = io_open(user_ssl_config.dhparam); - if not f then return nil, "Could not open DH parameters: "..err end - local dhparam = f:read("*a"); - f:close(); - user_ssl_config.dhparam = function() return dhparam; end - end - - local ctx, err = ssl_newcontext(user_ssl_config); - - -- COMPAT Older LuaSec ignores the cipher list from the config, so we have to take care - -- of it ourselves (W/A for #x) - if ctx and user_ssl_config.ciphers then - local success; - success, err = ssl_context.setcipher(ctx, user_ssl_config.ciphers); - if not success then ctx = nil; end - end + local ctx, err = cfg:build(); if not ctx then err = err or "invalid ssl config" diff --git a/core/configmanager.lua b/core/configmanager.lua index 092b3946..4b8df96e 100644 --- a/core/configmanager.lua +++ b/core/configmanager.lua @@ -40,16 +40,10 @@ function _M.getconfig() return config; end -function _M.get(host, key, _oldkey) - if key == "core" then - key = _oldkey; -- COMPAT with code that still uses "core" - end +function _M.get(host, key) return config[host][key]; end -function _M.rawget(host, key, _oldkey) - if key == "core" then - key = _oldkey; -- COMPAT with code that still uses "core" - end +function _M.rawget(host, key) local hostconfig = rawget(config, host); if hostconfig then return rawget(hostconfig, key); @@ -68,10 +62,7 @@ local function set(config_table, host, key, value) return false; end -function _M.set(host, key, value, _oldvalue) - if key == "core" then - key, value = value, _oldvalue; --COMPAT with code that still uses "core" - end +function _M.set(host, key, value) return set(config, host, key, value); end diff --git a/core/features.lua b/core/features.lua index 7248f881..96023b09 100644 --- a/core/features.lua +++ b/core/features.lua @@ -4,5 +4,7 @@ return { available = set.new{ -- mod_bookmarks bundled "mod_bookmarks"; + -- Roles, module.may and per-session authz + "permissions"; }; }; diff --git a/core/moduleapi.lua b/core/moduleapi.lua index 870a6a50..5f9ee3c7 100644 --- a/core/moduleapi.lua +++ b/core/moduleapi.lua @@ -19,6 +19,7 @@ local promise = require "util.promise"; local time_now = require "util.time".now; local format = require "util.format".format; local jid_node = require "util.jid".node; +local jid_split = require "util.jid".split; local jid_resource = require "util.jid".resource; local t_insert, t_remove, t_concat = table.insert, table.remove, table.concat; @@ -26,8 +27,8 @@ local error, setmetatable, type = error, setmetatable, type; local ipairs, pairs, select = ipairs, pairs, select; local tonumber, tostring = tonumber, tostring; local require = require; -local pack = table.pack or require "util.table".pack; -- table.pack is only in 5.2 -local unpack = table.unpack or unpack; --luacheck: ignore 113 -- renamed in 5.2 +local pack = table.pack; +local unpack = table.unpack; local prosody = prosody; local hosts = prosody.hosts; @@ -537,6 +538,7 @@ function api:load_resource(path, mode) end function api:open_store(name, store_type) + if self.host == "*" then return nil, "global-storage-not-supported"; end return require"core.storagemanager".open(self.host, name or self.name, store_type); end @@ -578,7 +580,7 @@ local status_priorities = { error = 3, warn = 2, info = 1, core = 0 }; function api:set_status(status_type, status_message, override) local priority = status_priorities[status_type]; if not priority then - self:log("error", "set_status: Invalid status type '%s', assuming 'info'"); + self:log("error", "set_status: Invalid status type '%s', assuming 'info'", status_type); status_type, priority = "info", status_priorities.info; end local current_priority = status_priorities[self.status_type] or 0; @@ -601,4 +603,73 @@ function api:get_status() return self.status_type, self.status_message, self.status_time; end +function api:default_permission(role_name, permission) + permission = permission:gsub("^:", self.name..":"); + if self.host == "*" then + for _, host in pairs(hosts) do + if host.authz then + host.authz.add_default_permission(role_name, permission); + end + end + return + end + hosts[self.host].authz.add_default_permission(role_name, permission); +end + +function api:default_permissions(role_name, permissions) + for _, permission in ipairs(permissions) do + self:default_permission(role_name, permission); + end +end + +function api:may(action, context) + if action:byte(1) == 58 then -- action begins with ':' + action = self.name..action; -- prepend module name + end + if type(context) == "string" then -- check JID permissions + local role; + local node, host = jid_split(context); + if host == self.host then + role = hosts[host].authz.get_user_role(node); + else + role = hosts[self.host].authz.get_jid_role(context); + end + if not role then + self:log("debug", "Access denied: JID <%s> may not %s (no role found)", context, action); + return false; + end + local permit = role:may(action); + if not permit then + self:log("debug", "Access denied: JID <%s> may not %s (not permitted by role %s)", context, action, role.name); + end + return permit; + end + + local session = context.origin or context.session; + if type(session) ~= "table" then + error("Unable to identify actor session from context"); + end + if session.role and session.type == "c2s" and session.host == self.host then + local permit = session.role:may(action, context); + if not permit then + self:log("debug", "Access denied: session %s (%s) may not %s (not permitted by role %s)", + session.id, session.full_jid, action, session.role.name + ); + end + return permit; + else + local actor_jid = context.stanza.attr.from; + local role = hosts[self.host].authz.get_jid_role(actor_jid); + if not role then + self:log("debug", "Access denied: JID <%s> may not %s (no role found)", actor_jid, action); + return false; + end + local permit = role:may(action, context); + if not permit then + self:log("debug", "Access denied: JID <%s> may not %s (not permitted by role %s)", actor_jid, action, role.name); + end + return permit; + end +end + return api; diff --git a/core/portmanager.lua b/core/portmanager.lua index 38c74b66..8c7dfddb 100644 --- a/core/portmanager.lua +++ b/core/portmanager.lua @@ -240,21 +240,22 @@ local function add_sni_host(host, service) log("debug", "Gathering certificates for SNI for host %s, %s service", host, service or "default"); for name, interface, port, n, active_service --luacheck: ignore 213 in active_services:iter(service, nil, nil, nil) do - if active_service.server.hosts and active_service.tls_cfg then - local config_prefix = (active_service.config_prefix or name).."_"; - if config_prefix == "_" then config_prefix = ""; end - local prefix_ssl_config = config.get(host, config_prefix.."ssl"); + if active_service.server and active_service.tls_cfg then local alternate_host = name and config.get(host, name.."_host"); if not alternate_host and name == "https" then -- TODO should this be some generic thing? e.g. in the service definition alternate_host = config.get(host, "http_host"); end local autocert = certmanager.find_host_cert(alternate_host or host); - -- luacheck: ignore 211/cfg - local ssl, err, cfg = certmanager.create_context(host, "server", prefix_ssl_config, autocert, active_service.tls_cfg); - if ssl then - active_service.server.hosts[alternate_host or host] = ssl; - else + local manualcert = active_service.tls_cfg; + local certificate = (autocert and autocert.certificate) or manualcert.certificate; + local key = (autocert and autocert.key) or manualcert.key; + local ok, err = active_service.server:sslctx():set_sni_host( + host, + certificate, + key + ); + if not ok then log("error", "Error creating TLS context for SNI host %s: %s", host, err); end end @@ -277,7 +278,7 @@ prosody.events.add_handler("host-deactivated", function (host) for name, interface, port, n, active_service --luacheck: ignore 213 in active_services:iter(nil, nil, nil, nil) do if active_service.tls_cfg then - active_service.server.hosts[host] = nil; + active_service.server:sslctx():remove_sni_host(host) end end end); diff --git a/core/sessionmanager.lua b/core/sessionmanager.lua index f8279bb4..509568c6 100644 --- a/core/sessionmanager.lua +++ b/core/sessionmanager.lua @@ -10,7 +10,7 @@ local tostring, setmetatable = tostring, setmetatable; local pairs, next= pairs, next; -local hosts = prosody.hosts; +local prosody, hosts = prosody, prosody.hosts; local full_sessions = prosody.full_sessions; local bare_sessions = prosody.bare_sessions; @@ -92,6 +92,51 @@ local function retire_session(session) return setmetatable(session, resting_session); end +-- Update a session with a new one (transplanting connection, filters, etc.) +-- new_session should be discarded after this call returns +local function update_session(to_session, from_session) + to_session.log("debug", "Updating with parameters from session %s", from_session.id); + from_session.log("debug", "Session absorbed into %s", to_session.id); + + local replaced_conn = to_session.conn; + if replaced_conn then + to_session.log("debug", "closing a replaced connection for this session"); + replaced_conn:close(); + end + + to_session.ip = from_session.ip; + to_session.conn = from_session.conn; + to_session.rawsend = from_session.rawsend; + to_session.rawsend.session = to_session; + to_session.rawsend.conn = to_session.conn; + to_session.send = from_session.send; + to_session.send.session = to_session; + to_session.close = from_session.close; + to_session.filter = from_session.filter; + to_session.filter.session = to_session; + to_session.filters = from_session.filters; + to_session.send.filter = to_session.filter; + to_session.sasl_handler = from_session.sasl_handler; + to_session.stream = from_session.stream; + to_session.secure = from_session.secure; + to_session.hibernating = nil; + to_session.resumption_counter = (to_session.resumption_counter or 0) + 1; + from_session.log = to_session.log; + from_session.type = to_session.type; + -- Inform xmppstream of the new session (passed to its callbacks) + to_session.stream:set_session(to_session); + + -- Notify modules, allowing them to copy further fields or update state + prosody.events.fire_event("c2s-session-updated", { + session = to_session; + from_session = from_session; + replaced_conn = replaced_conn; + }); + + -- Retire the session we've pulled from, to avoid two sessions on the same connection + retire_session(from_session); +end + local function destroy_session(session, err) if session.destroyed then return; end @@ -130,15 +175,24 @@ local function destroy_session(session, err) retire_session(session); end -local function make_authenticated(session, username, scope) +local function make_authenticated(session, username, role_name) username = nodeprep(username); if not username or #username == 0 then return nil, "Invalid username"; end session.username = username; if session.type == "c2s_unauthed" then session.type = "c2s_unbound"; end - session.auth_scope = scope; - session.log("info", "Authenticated as %s@%s", username, session.host or "(unknown)"); + + local role; + if role_name then + role = hosts[session.host].authz.get_role_by_name(role_name); + else + role = hosts[session.host].authz.get_user_role(username); + end + if role then + sessionlib.set_role(session, role); + end + session.log("info", "Authenticated as %s@%s [%s]", username, session.host or "(unknown)", role and role.name or "no role"); return true; end @@ -265,6 +319,7 @@ end return { new_session = new_session; retire_session = retire_session; + update_session = update_session; destroy_session = destroy_session; make_authenticated = make_authenticated; bind_resource = bind_resource; diff --git a/core/stanza_router.lua b/core/stanza_router.lua index b54ea1ab..89a02c02 100644 --- a/core/stanza_router.lua +++ b/core/stanza_router.lua @@ -127,7 +127,7 @@ function core_process_stanza(origin, stanza) end core_post_stanza(origin, stanza, origin.full_jid); else - local h = hosts[stanza.attr.to or origin.host or origin.to_host]; + local h = hosts[stanza.attr.to or origin.host]; if h then local event; if xmlns == nil then @@ -143,7 +143,7 @@ function core_process_stanza(origin, stanza) if h.events.fire_event(event, {origin = origin, stanza = stanza}) then return; end end if host and not hosts[host] then host = nil; end -- COMPAT: workaround for a Pidgin bug which sets 'to' to the SRV result - handle_unhandled_stanza(host or origin.host or origin.to_host, origin, stanza); + handle_unhandled_stanza(host or origin.host, origin, stanza); end end diff --git a/core/usermanager.lua b/core/usermanager.lua index 45f104fa..fcb1fa9b 100644 --- a/core/usermanager.lua +++ b/core/usermanager.lua @@ -9,14 +9,10 @@ local modulemanager = require "core.modulemanager"; local log = require "util.logger".init("usermanager"); local type = type; -local it = require "util.iterators"; -local jid_bare = require "util.jid".bare; local jid_split = require "util.jid".split; -local jid_prep = require "util.jid".prep; local config = require "core.configmanager"; local sasl_new = require "util.sasl".new; local storagemanager = require "core.storagemanager"; -local set = require "util.set"; local prosody = _G.prosody; local hosts = prosody.hosts; @@ -25,6 +21,8 @@ local setmetatable = setmetatable; local default_provider = "internal_hashed"; +local debug = debug; + local _ENV = nil; -- luacheck: std none @@ -36,26 +34,25 @@ local function new_null_provider() }); end -local global_admins_config = config.get("*", "admins"); -if type(global_admins_config) ~= "table" then - global_admins_config = nil; -- TODO: factor out moduleapi magic config handling and use it here -end -local global_admins = set.new(global_admins_config) / jid_prep; +local fallback_authz_provider = { + -- luacheck: ignore 212 + get_jids_with_role = function (role) end; -local admin_role = { ["prosody:admin"] = true }; -local global_authz_provider = { - get_user_roles = function (user) end; --luacheck: ignore 212/user - get_jid_roles = function (jid) - if global_admins:contains(jid) then - return admin_role; - end - end; - get_jids_with_role = function (role) - if role ~= "prosody:admin" then return {}; end - return it.to_array(global_admins); - end; - set_user_roles = function (user, roles) end; -- luacheck: ignore 212 - set_jid_roles = function (jid, roles) end; -- luacheck: ignore 212 + get_user_role = function (user) end; + set_user_role = function (user, role_name) end; + + get_user_secondary_roles = function (user) end; + add_user_secondary_role = function (user, host, role_name) end; + remove_user_secondary_role = function (user, host, role_name) end; + + user_can_assume_role = function(user, role_name) end; + + get_jid_role = function (jid) end; + set_jid_role = function (jid, role) end; + + get_users_with_role = function (role_name) end; + add_default_permission = function (role_name, action, policy) end; + get_role_by_name = function (role_name) end; }; local provider_mt = { __index = new_null_provider() }; @@ -66,7 +63,7 @@ local function initialize_host(host) local authz_provider_name = config.get(host, "authorization") or "internal"; local authz_mod = modulemanager.load(host, "authz_"..authz_provider_name); - host_session.authz = authz_mod or global_authz_provider; + host_session.authz = authz_mod or fallback_authz_provider; if host_session.type ~= "local" then return; end @@ -116,6 +113,12 @@ local function set_password(username, password, host, resource) return ok, err; end +local function get_account_info(username, host) + local method = hosts[host].users.get_account_info; + if not method then return nil, "method-not-supported"; end + return method(username); +end + local function user_exists(username, host) if hosts[host].sessions[username] then return true; end return hosts[host].users.user_exists(username); @@ -144,70 +147,113 @@ local function get_provider(host) return hosts[host].users; end -local function get_roles(jid, host) +local function get_user_role(user, host) if host and not hosts[host] then return false; end - if type(jid) ~= "string" then return false; end + if type(user) ~= "string" then return false; end - jid = jid_bare(jid); - host = host or "*"; + return hosts[host].authz.get_user_role(user); +end - local actor_user, actor_host = jid_split(jid); - local roles; +local function set_user_role(user, host, role_name) + if host and not hosts[host] then return false; end + if type(user) ~= "string" then return false; end - local authz_provider = (host ~= "*" and hosts[host].authz) or global_authz_provider; + local role, err = hosts[host].authz.set_user_role(user, role_name); + if role then + prosody.events.fire_event("user-role-changed", { + username = user, host = host, role = role; + }); + end + return role, err; +end + +local function user_can_assume_role(user, host, role_name) + if host and not hosts[host] then return false; end + if type(user) ~= "string" then return false; end + + return hosts[host].authz.user_can_assume_role(user, role_name); +end + +local function add_user_secondary_role(user, host, role_name) + if host and not hosts[host] then return false; end + if type(user) ~= "string" then return false; end - if actor_user and actor_host == host then -- Local user - roles = authz_provider.get_user_roles(actor_user); - else -- Remote user/JID - roles = authz_provider.get_jid_roles(jid); + local role, err = hosts[host].authz.add_user_secondary_role(user, role_name); + if role then + prosody.events.fire_event("user-role-added", { + username = user, host = host, role = role; + }); end + return role, err; +end + +local function remove_user_secondary_role(user, host, role_name) + if host and not hosts[host] then return false; end + if type(user) ~= "string" then return false; end - return roles; + local ok, err = hosts[host].authz.remove_user_secondary_role(user, role_name); + if ok then + prosody.events.fire_event("user-role-removed", { + username = user, host = host, role_name = role_name; + }); + end + return ok, err; end -local function set_roles(jid, host, roles) +local function get_user_secondary_roles(user, host) if host and not hosts[host] then return false; end - if type(jid) ~= "string" then return false; end + if type(user) ~= "string" then return false; end - jid = jid_bare(jid); - host = host or "*"; + return hosts[host].authz.get_user_secondary_roles(user); +end - local actor_user, actor_host = jid_split(jid); +local function get_jid_role(jid, host) + local jid_node, jid_host = jid_split(jid); + if host == jid_host and jid_node then + return hosts[host].authz.get_user_role(jid_node); + end + return hosts[host].authz.get_jid_role(jid); +end - local authz_provider = (host ~= "*" and hosts[host].authz) or global_authz_provider; - if actor_user and actor_host == host then -- Local user - local ok, err = authz_provider.set_user_roles(actor_user, roles); - if ok then - prosody.events.fire_event("user-roles-changed", { - username = actor_user, host = actor_host - }); - end - return ok, err; - else -- Remote entity - return authz_provider.set_jid_roles(jid, roles) +local function set_jid_role(jid, host, role_name) + local _, jid_host = jid_split(jid); + if host == jid_host then + return nil, "unexpected-local-jid"; end + return hosts[host].authz.set_jid_role(jid, role_name) end +local strict_deprecate_is_admin; +local legacy_admin_roles = { ["prosody:admin"] = true, ["prosody:operator"] = true }; local function is_admin(jid, host) - local roles = get_roles(jid, host); - return roles and roles["prosody:admin"]; + if strict_deprecate_is_admin == nil then + strict_deprecate_is_admin = (config.get("*", "strict_deprecate_is_admin") == true); + end + if strict_deprecate_is_admin then + log("error", "Attempt to use deprecated is_admin() API: %s", debug.traceback()); + return false; + end + log("warn", "Usage of legacy is_admin() API, which will be disabled in a future build: %s", debug.traceback()); + log("warn", "See https://prosody.im/doc/developers/permissions about the new permissions API"); + return legacy_admin_roles[get_jid_role(jid, host)] or false; end local function get_users_with_role(role, host) if not hosts[host] then return false; end if type(role) ~= "string" then return false; end - return hosts[host].authz.get_users_with_role(role); end local function get_jids_with_role(role, host) if host and not hosts[host] then return false; end if type(role) ~= "string" then return false; end + return hosts[host].authz.get_jids_with_role(role); +end - host = host or "*"; - - local authz_provider = (host ~= "*" and hosts[host].authz) or global_authz_provider; - return authz_provider.get_jids_with_role(role); +local function get_role_by_name(role_name, host) + if host and not hosts[host] then return false; end + if type(role_name) ~= "string" then return false; end + return hosts[host].authz.get_role_by_name(role_name); end return { @@ -216,15 +262,25 @@ return { test_password = test_password; get_password = get_password; set_password = set_password; + get_account_info = get_account_info; user_exists = user_exists; create_user = create_user; delete_user = delete_user; users = users; get_sasl_handler = get_sasl_handler; get_provider = get_provider; - get_roles = get_roles; - set_roles = set_roles; - is_admin = is_admin; + get_user_role = get_user_role; + set_user_role = set_user_role; + user_can_assume_role = user_can_assume_role; + add_user_secondary_role = add_user_secondary_role; + remove_user_secondary_role = remove_user_secondary_role; + get_user_secondary_roles = get_user_secondary_roles; get_users_with_role = get_users_with_role; + get_jid_role = get_jid_role; + set_jid_role = set_jid_role; get_jids_with_role = get_jids_with_role; + get_role_by_name = get_role_by_name; + + -- Deprecated + is_admin = is_admin; }; diff --git a/doc/doap.xml b/doc/doap.xml index fa3893f8..4e938163 100644 --- a/doc/doap.xml +++ b/doc/doap.xml @@ -60,6 +60,8 @@ <implements rdf:resource="https://www.rfc-editor.org/info/rfc7395"/> <implements rdf:resource="https://www.rfc-editor.org/info/rfc7590"/> <implements rdf:resource="https://www.rfc-editor.org/info/rfc7673"/> + <implements rdf:resource="https://www.rfc-editor.org/info/rfc8305"/> + <implements rdf:resource="https://www.rfc-editor.org/info/rfc9266"/> <implements rdf:resource="https://datatracker.ietf.org/doc/draft-cridland-xmpp-session/"> <!-- since=0.6.0 note=Added in hg:0bbbc9042361 --> </implements> @@ -67,7 +69,7 @@ <implements> <xmpp:SupportedXep> <xmpp:xep rdf:resource="https://xmpp.org/extensions/xep-0004.html"/> - <xmpp:version>2.12.1</xmpp:version> + <xmpp:version>2.13.0</xmpp:version> <xmpp:since>0.4.0</xmpp:since> <xmpp:status>partial</xmpp:status> <xmpp:note>no support for multiple items (reported tag)</xmpp:note> @@ -119,7 +121,7 @@ <implements> <xmpp:SupportedXep> <xmpp:xep rdf:resource="https://xmpp.org/extensions/xep-0045.html"/> - <xmpp:version>1.34.1</xmpp:version> + <xmpp:version>1.34.3</xmpp:version> <xmpp:since>0.3.0</xmpp:since> <xmpp:status>partial</xmpp:status> </xmpp:SupportedXep> @@ -172,7 +174,7 @@ <implements> <xmpp:SupportedXep> <xmpp:xep rdf:resource="https://xmpp.org/extensions/xep-0060.html"/> - <xmpp:version>1.22.0</xmpp:version> + <xmpp:version>1.24.1</xmpp:version> <xmpp:since>0.9.0</xmpp:since> <xmpp:status>partial</xmpp:status> <xmpp:note>mod_pubsub</xmpp:note> @@ -240,7 +242,8 @@ <xmpp:xep rdf:resource="https://xmpp.org/extensions/xep-0090.html"/> <xmpp:version>1.2</xmpp:version> <xmpp:since>0.1.0</xmpp:since> - <xmpp:status>complete</xmpp:status> + <xmpp:until>trunk</xmpp:until> + <xmpp:status>removed</xmpp:status> <xmpp:note>mod_time</xmpp:note> </xmpp:SupportedXep> </implements> @@ -268,6 +271,7 @@ <xmpp:version>1.0</xmpp:version> <xmpp:since>0.9.0</xmpp:since> <xmpp:status>complete</xmpp:status> + <xmpp:note>util.jid.(un)escape, missing rejection of \20 at start or end per xep version 1.1</xmpp:note> </xmpp:SupportedXep> </implements> <implements> @@ -297,7 +301,7 @@ <implements> <xmpp:SupportedXep> <xmpp:xep rdf:resource="https://xmpp.org/extensions/xep-0115.html"/> - <xmpp:version>1.5.2</xmpp:version> + <xmpp:version>1.6.0</xmpp:version> <xmpp:since>0.8.0</xmpp:since> <xmpp:status>complete</xmpp:status> </xmpp:SupportedXep> @@ -355,7 +359,7 @@ <implements> <xmpp:SupportedXep> <xmpp:xep rdf:resource="https://xmpp.org/extensions/xep-0138.html"/> - <xmpp:version>2.0</xmpp:version> + <xmpp:version>2.1</xmpp:version> <xmpp:since>0.6.0</xmpp:since> <xmpp:until>0.10.0</xmpp:until> <xmpp:status>removed</xmpp:status> @@ -390,7 +394,7 @@ <implements> <xmpp:SupportedXep> <xmpp:xep rdf:resource="https://xmpp.org/extensions/xep-0163.html"/> - <xmpp:version>1.2.1</xmpp:version> + <xmpp:version>1.2.2</xmpp:version> <xmpp:since>0.5.0</xmpp:since> <xmpp:status>complete</xmpp:status> </xmpp:SupportedXep> @@ -561,7 +565,7 @@ <implements> <xmpp:SupportedXep> <xmpp:xep rdf:resource="https://xmpp.org/extensions/xep-0215.html"/> - <xmpp:version>0.7.1</xmpp:version> + <xmpp:version>1.0.0</xmpp:version> <xmpp:status>complete</xmpp:status> <xmpp:since>0.12.0</xmpp:since> <xmpp:note>mod_external_services</xmpp:note> @@ -623,7 +627,7 @@ <implements> <xmpp:SupportedXep> <xmpp:xep rdf:resource="https://xmpp.org/extensions/xep-0280.html"/> - <xmpp:version>1.0.0</xmpp:version> + <xmpp:version>1.0.1</xmpp:version> <xmpp:status>complete</xmpp:status> <xmpp:since>0.10.0</xmpp:since> </xmpp:SupportedXep> @@ -657,7 +661,7 @@ <implements> <xmpp:SupportedXep> <xmpp:xep rdf:resource="https://xmpp.org/extensions/xep-0297.html"/> - <xmpp:version>1.0.0</xmpp:version> + <xmpp:version>1.0</xmpp:version> <xmpp:since>0.11.0</xmpp:since> <xmpp:status>complete</xmpp:status> <xmpp:note>Used by XEP-0280, XEP-0313</xmpp:note> @@ -683,7 +687,7 @@ <implements> <xmpp:SupportedXep> <xmpp:xep rdf:resource="https://xmpp.org/extensions/xep-0313.html"/> - <xmpp:version>1.0.0</xmpp:version> + <xmpp:version>1.0.1</xmpp:version> <xmpp:status>complete</xmpp:status> <xmpp:since>0.10.0</xmpp:since> <xmpp:note>mod_mam, mod_muc_mam</xmpp:note> @@ -737,7 +741,7 @@ <implements> <xmpp:SupportedXep> <xmpp:xep rdf:resource="https://xmpp.org/extensions/xep-0363.html"/> - <xmpp:version>1.0.0</xmpp:version> + <xmpp:version>1.1.0</xmpp:version> <xmpp:status>complete</xmpp:status> <xmpp:since>0.12.0</xmpp:since> <xmpp:note>mod_http_file_share</xmpp:note> @@ -763,7 +767,7 @@ <implements> <xmpp:SupportedXep> <xmpp:xep rdf:resource="https://xmpp.org/extensions/xep-0380.html"/> - <xmpp:version>0.3.0</xmpp:version> + <xmpp:version>0.4.0</xmpp:version> <xmpp:since>0.11.0</xmpp:since> <xmpp:status>complete</xmpp:status> <xmpp:note>Used in context of XEP-0352</xmpp:note> @@ -772,7 +776,7 @@ <implements> <xmpp:SupportedXep> <xmpp:xep rdf:resource="https://xmpp.org/extensions/xep-0384.html"/> - <xmpp:version>0.8.1</xmpp:version> + <xmpp:version>0.8.3</xmpp:version> <xmpp:status>complete</xmpp:status> <xmpp:note>via XEP-0163, XEP-0222</xmpp:note> </xmpp:SupportedXep> @@ -789,7 +793,7 @@ <implements> <xmpp:SupportedXep> <xmpp:xep rdf:resource="https://xmpp.org/extensions/xep-0401.html"/> - <xmpp:version>0.3.0</xmpp:version> + <xmpp:version>0.5.0</xmpp:version> <xmpp:since>0.12.0</xmpp:since> <xmpp:status>partial</xmpp:status> </xmpp:SupportedXep> @@ -845,5 +849,27 @@ <xmpp:note>Broken out of XEP-0313</xmpp:note> </xmpp:SupportedXep> </implements> + <implements> + <xmpp:SupportedXep> + <xmpp:xep rdf:resource="https://xmpp.org/extensions/xep-0440.html"/> + <xmpp:version>0.4.0</xmpp:version> + <xmpp:since>trunk</xmpp:since> + <xmpp:status>complete</xmpp:status> + </xmpp:SupportedXep> + </implements> + <implements> + <xmpp:SupportedXep> + <xmpp:xep rdf:resource="https://xmpp.org/extensions/xep-0445.html"/> + <xmpp:version>0.2.0</xmpp:version> + <xmpp:since>0.12.0</xmpp:since> + <xmpp:status>complete</xmpp:status> + </xmpp:SupportedXep> + </implements> + <implements> + <xmpp:SupportedXep> + <xmpp:xep rdf:resources="https://xmpp.org/extensions/inbox/xep-sla.xml"/> + <xmpp:since>trunk</xmpp:since> + </xmpp:SupportedXep> + </implements> </Project> </rdf:RDF> diff --git a/doc/storage.tld b/doc/storage.tld deleted file mode 100644 index 0fc8ff54..00000000 --- a/doc/storage.tld +++ /dev/null @@ -1,68 +0,0 @@ --- Storage Interface API Description --- --- This is written as a TypedLua description - --- Key-Value stores (the default) - -interface keyval_store - get : ( self, string? ) -> (any) | (nil, string) - set : ( self, string?, any ) -> (boolean) | (nil, string) -end - --- Map stores (key-key-value stores) - -interface map_store - get : ( self, string?, any ) -> (any) | (nil, string) - set : ( self, string?, any, any ) -> (boolean) | (nil, string) - set_keys : ( self, string?, { any : any }) -> (boolean) | (nil, string) - remove : {} -end - --- Archive stores - -typealias archive_query = { - "start" : number?, -- timestamp - "end" : number?, -- timestamp - "with" : string?, - "after" : string?, -- archive id - "before" : string?, -- archive id - "total" : boolean?, -} - -interface archive_store - -- Optional set of capabilities - caps : { - -- Optional total count of matching items returned as second return value from :find() - "total" : boolean?, - }? - - -- Add to the archive - append : ( self, string?, string?, any, number?, string? ) -> (string) | (nil, string) - - -- Iterate over archive - find : ( self, string?, archive_query? ) -> ( () -> ( string, any, number?, string? ), integer? ) - - -- Removal of items. API like find. Optional? - delete : ( self, string?, archive_query? ) -> (boolean) | (number) | (nil, string) - - -- Array of dates which do have messages (Optional?) - dates : ( self, string? ) -> ({ string }) | (nil, string) - - -- Map of counts per "with" field - summary : ( self, string?, archive_query? ) -> ( { string : integer } ) | (nil, string) - - -- Map-store API - get : ( self, string, string ) -> (stanza, number?, string?) | (nil, string) - set : ( self, string, string, stanza, number?, string? ) -> (boolean) | (nil, string) -end - --- This represents moduleapi -interface module - -- If the first string is omitted then the name of the module is used - -- The second string is one of "keyval" (default), "map" or "archive" - open_store : (self, string?, string?) -> (keyval_store) | (map_store) | (archive_store) | (nil, string) - - -- Other module methods omitted -end - -module : module @@ -73,12 +73,13 @@ install-util: util/encodings.so util/encodings.so util/pposix.so util/signal.so install-plugins: $(MKDIR) $(MODULES) - $(MKDIR) $(MODULES)/mod_pubsub $(MODULES)/adhoc $(MODULES)/muc $(MODULES)/mod_mam + $(MKDIR) $(MODULES)/mod_pubsub $(MODULES)/adhoc $(MODULES)/muc $(MODULES)/mod_mam $(MODULES)/mod_debug_stanzas $(INSTALL_DATA) plugins/*.lua $(MODULES) $(INSTALL_DATA) plugins/mod_pubsub/*.lua $(MODULES)/mod_pubsub $(INSTALL_DATA) plugins/adhoc/*.lua $(MODULES)/adhoc $(INSTALL_DATA) plugins/muc/*.lua $(MODULES)/muc $(INSTALL_DATA) plugins/mod_mam/*.lua $(MODULES)/mod_mam + $(INSTALL_DATA) plugins/mod_debug_stanzas/*.lua $(MODULES)/mod_debug_stanzas install-man: $(MKDIR) $(MAN)/man1 diff --git a/net/connect.lua b/net/connect.lua index 4b602be4..3cb407a1 100644 --- a/net/connect.lua +++ b/net/connect.lua @@ -1,8 +1,8 @@ local server = require "net.server"; local log = require "util.logger".init("net.connect"); local new_id = require "util.id".short; +local timer = require "util.timer"; --- TODO #1246 Happy Eyeballs -- FIXME RFC 6724 -- FIXME Error propagation from resolvers doesn't work -- FIXME #1428 Reuse DNS resolver object between service and basic resolver @@ -28,16 +28,17 @@ local pending_connection_listeners = {}; local function attempt_connection(p) p:log("debug", "Checking for targets..."); - if p.conn then - pending_connections_map[p.conn] = nil; - p.conn = nil; - end - p.target_resolver:next(function (conn_type, ip, port, extra) + p.target_resolver:next(function (conn_type, ip, port, extra, more_targets_available) if not conn_type then -- No more targets to try p:log("debug", "No more connection targets to try", p.target_resolver.last_error); - if p.listeners.onfail then - p.listeners.onfail(p.data, p.last_error or p.target_resolver.last_error or "unable to resolve service"); + if next(p.conns) == nil then + p:log("debug", "No more targets, no pending connections. Connection failed."); + if p.listeners.onfail then + p.listeners.onfail(p.data, p.last_error or p.target_resolver.last_error or "unable to resolve service"); + end + else + p:log("debug", "One or more connection attempts are still pending. Waiting for now."); end return; end @@ -49,8 +50,16 @@ local function attempt_connection(p) p.last_error = err or "unknown reason"; return attempt_connection(p); end - p.conn = conn; + p.conns[conn] = true; pending_connections_map[conn] = p; + if more_targets_available then + timer.add_task(0.250, function () + if not p.connected then + p:log("debug", "Still not connected, making parallel connection attempt..."); + attempt_connection(p); + end + end); + end end); end @@ -62,6 +71,13 @@ function pending_connection_listeners.onconnect(conn) return; end pending_connections_map[conn] = nil; + if p.connected then + -- We already succeeded in connecting + p.conns[conn] = nil; + conn:close(); + return; + end + p.connected = true; p:log("debug", "Successfully connected"); conn:setlistener(p.listeners, p.data); return p.listeners.onconnect(conn); @@ -73,9 +89,18 @@ function pending_connection_listeners.ondisconnect(conn, reason) log("warn", "Failed connection, but unexpected!"); return; end + p.conns[conn] = nil; + pending_connections_map[conn] = nil; p.last_error = reason or "unknown reason"; p:log("debug", "Connection attempt failed: %s", p.last_error); - attempt_connection(p); + if p.connected then + p:log("debug", "Connection already established, ignoring failure"); + elseif next(p.conns) == nil then + p:log("debug", "No pending connection attempts, and not yet connected"); + attempt_connection(p); + else + p:log("debug", "Other attempts are still pending, ignoring failure"); + end end local function connect(target_resolver, listeners, options, data) @@ -85,6 +110,7 @@ local function connect(target_resolver, listeners, options, data) listeners = assert(listeners); options = options or {}; data = data; + conns = {}; }, pending_connection_mt); p:log("debug", "Starting connection process"); diff --git a/net/dns.lua b/net/dns.lua index a9846e86..e6179637 100644 --- a/net/dns.lua +++ b/net/dns.lua @@ -8,8 +8,8 @@ -- todo: cache results of encodeName --- reference: http://tools.ietf.org/html/rfc1035 --- reference: http://tools.ietf.org/html/rfc1876 (LOC) +-- reference: https://www.rfc-editor.org/rfc/rfc1035.html +-- reference: https://www.rfc-editor.org/rfc/rfc1876.html (LOC) local socket = require "socket"; diff --git a/net/http/codes.lua b/net/http/codes.lua index 4327f151..b2949286 100644 --- a/net/http/codes.lua +++ b/net/http/codes.lua @@ -2,62 +2,62 @@ local response_codes = { -- Source: http://www.iana.org/assignments/http-status-codes - [100] = "Continue"; -- RFC7231, Section 6.2.1 - [101] = "Switching Protocols"; -- RFC7231, Section 6.2.2 + [100] = "Continue"; -- RFC9110, Section 15.2.1 + [101] = "Switching Protocols"; -- RFC9110, Section 15.2.2 [102] = "Processing"; [103] = "Early Hints"; -- [104-199] = "Unassigned"; - [200] = "OK"; -- RFC7231, Section 6.3.1 - [201] = "Created"; -- RFC7231, Section 6.3.2 - [202] = "Accepted"; -- RFC7231, Section 6.3.3 - [203] = "Non-Authoritative Information"; -- RFC7231, Section 6.3.4 - [204] = "No Content"; -- RFC7231, Section 6.3.5 - [205] = "Reset Content"; -- RFC7231, Section 6.3.6 - [206] = "Partial Content"; -- RFC7233, Section 4.1 + [200] = "OK"; -- RFC9110, Section 15.3.1 + [201] = "Created"; -- RFC9110, Section 15.3.2 + [202] = "Accepted"; -- RFC9110, Section 15.3.3 + [203] = "Non-Authoritative Information"; -- RFC9110, Section 15.3.4 + [204] = "No Content"; -- RFC9110, Section 15.3.5 + [205] = "Reset Content"; -- RFC9110, Section 15.3.6 + [206] = "Partial Content"; -- RFC9110, Section 15.3.7 [207] = "Multi-Status"; [208] = "Already Reported"; -- [209-225] = "Unassigned"; [226] = "IM Used"; -- [227-299] = "Unassigned"; - [300] = "Multiple Choices"; -- RFC7231, Section 6.4.1 - [301] = "Moved Permanently"; -- RFC7231, Section 6.4.2 - [302] = "Found"; -- RFC7231, Section 6.4.3 - [303] = "See Other"; -- RFC7231, Section 6.4.4 - [304] = "Not Modified"; -- RFC7232, Section 4.1 - [305] = "Use Proxy"; -- RFC7231, Section 6.4.5 - -- [306] = "(Unused)"; -- RFC7231, Section 6.4.6 - [307] = "Temporary Redirect"; -- RFC7231, Section 6.4.7 - [308] = "Permanent Redirect"; + [300] = "Multiple Choices"; -- RFC9110, Section 15.4.1 + [301] = "Moved Permanently"; -- RFC9110, Section 15.4.2 + [302] = "Found"; -- RFC9110, Section 15.4.3 + [303] = "See Other"; -- RFC9110, Section 15.4.4 + [304] = "Not Modified"; -- RFC9110, Section 15.4.5 + [305] = "Use Proxy"; -- RFC9110, Section 15.4.6 + -- [306] = "(Unused)"; -- RFC9110, Section 15.4.7 + [307] = "Temporary Redirect"; -- RFC9110, Section 15.4.8 + [308] = "Permanent Redirect"; -- RFC9110, Section 15.4.9 -- [309-399] = "Unassigned"; - [400] = "Bad Request"; -- RFC7231, Section 6.5.1 - [401] = "Unauthorized"; -- RFC7235, Section 3.1 - [402] = "Payment Required"; -- RFC7231, Section 6.5.2 - [403] = "Forbidden"; -- RFC7231, Section 6.5.3 - [404] = "Not Found"; -- RFC7231, Section 6.5.4 - [405] = "Method Not Allowed"; -- RFC7231, Section 6.5.5 - [406] = "Not Acceptable"; -- RFC7231, Section 6.5.6 - [407] = "Proxy Authentication Required"; -- RFC7235, Section 3.2 - [408] = "Request Timeout"; -- RFC7231, Section 6.5.7 - [409] = "Conflict"; -- RFC7231, Section 6.5.8 - [410] = "Gone"; -- RFC7231, Section 6.5.9 - [411] = "Length Required"; -- RFC7231, Section 6.5.10 - [412] = "Precondition Failed"; -- RFC7232, Section 4.2 - [413] = "Payload Too Large"; -- RFC7231, Section 6.5.11 - [414] = "URI Too Long"; -- RFC7231, Section 6.5.12 - [415] = "Unsupported Media Type"; -- RFC7231, Section 6.5.13 - [416] = "Range Not Satisfiable"; -- RFC7233, Section 4.4 - [417] = "Expectation Failed"; -- RFC7231, Section 6.5.14 + [400] = "Bad Request"; -- RFC9110, Section 15.5.1 + [401] = "Unauthorized"; -- RFC9110, Section 15.5.2 + [402] = "Payment Required"; -- RFC9110, Section 15.5.3 + [403] = "Forbidden"; -- RFC9110, Section 15.5.4 + [404] = "Not Found"; -- RFC9110, Section 15.5.5 + [405] = "Method Not Allowed"; -- RFC9110, Section 15.5.6 + [406] = "Not Acceptable"; -- RFC9110, Section 15.5.7 + [407] = "Proxy Authentication Required"; -- RFC9110, Section 15.5.8 + [408] = "Request Timeout"; -- RFC9110, Section 15.5.9 + [409] = "Conflict"; -- RFC9110, Section 15.5.10 + [410] = "Gone"; -- RFC9110, Section 15.5.11 + [411] = "Length Required"; -- RFC9110, Section 15.5.12 + [412] = "Precondition Failed"; -- RFC9110, Section 15.5.13 + [413] = "Content Too Large"; -- RFC9110, Section 15.5.14 + [414] = "URI Too Long"; -- RFC9110, Section 15.5.15 + [415] = "Unsupported Media Type"; -- RFC9110, Section 15.5.16 + [416] = "Range Not Satisfiable"; -- RFC9110, Section 15.5.17 + [417] = "Expectation Failed"; -- RFC9110, Section 15.5.18 [418] = "I'm a teapot"; -- RFC2324, Section 2.3.2 -- [419-420] = "Unassigned"; - [421] = "Misdirected Request"; -- RFC7540, Section 9.1.2 - [422] = "Unprocessable Entity"; + [421] = "Misdirected Request"; -- RFC9110, Section 15.5.20 + [422] = "Unprocessable Content"; -- RFC9110, Section 15.5.21 [423] = "Locked"; [424] = "Failed Dependency"; [425] = "Too Early"; - [426] = "Upgrade Required"; -- RFC7231, Section 6.5.15 + [426] = "Upgrade Required"; -- RFC9110, Section 15.5.22 -- [427] = "Unassigned"; [428] = "Precondition Required"; [429] = "Too Many Requests"; @@ -67,17 +67,17 @@ local response_codes = { [451] = "Unavailable For Legal Reasons"; -- [452-499] = "Unassigned"; - [500] = "Internal Server Error"; -- RFC7231, Section 6.6.1 - [501] = "Not Implemented"; -- RFC7231, Section 6.6.2 - [502] = "Bad Gateway"; -- RFC7231, Section 6.6.3 - [503] = "Service Unavailable"; -- RFC7231, Section 6.6.4 - [504] = "Gateway Timeout"; -- RFC7231, Section 6.6.5 - [505] = "HTTP Version Not Supported"; -- RFC7231, Section 6.6.6 + [500] = "Internal Server Error"; -- RFC9110, Section 15.6.1 + [501] = "Not Implemented"; -- RFC9110, Section 15.6.2 + [502] = "Bad Gateway"; -- RFC9110, Section 15.6.3 + [503] = "Service Unavailable"; -- RFC9110, Section 15.6.4 + [504] = "Gateway Timeout"; -- RFC9110, Section 15.6.5 + [505] = "HTTP Version Not Supported"; -- RFC9110, Section 15.6.6 [506] = "Variant Also Negotiates"; [507] = "Insufficient Storage"; [508] = "Loop Detected"; -- [509] = "Unassigned"; - [510] = "Not Extended"; + [510] = "Not Extended"; -- (OBSOLETED) [511] = "Network Authentication Required"; -- [512-599] = "Unassigned"; }; diff --git a/net/http/server.lua b/net/http/server.lua index 43e6bc9f..e15165c2 100644 --- a/net/http/server.lua +++ b/net/http/server.lua @@ -378,11 +378,11 @@ function _M.send_file(response, f) response.conn:write(chunk); else incomplete[response.conn] = nil; + if f.close then f:close(); end if chunked then response.conn:write("0\r\n\r\n"); end -- io.write("\n"); - if f.close then f:close(); end return response:done(); end end diff --git a/net/resolvers/basic.lua b/net/resolvers/basic.lua index 305bce76..3ef916a2 100644 --- a/net/resolvers/basic.lua +++ b/net/resolvers/basic.lua @@ -2,13 +2,61 @@ local adns = require "net.adns"; local inet_pton = require "util.net".pton; local inet_ntop = require "util.net".ntop; local idna_to_ascii = require "util.encodings".idna.to_ascii; -local unpack = table.unpack or unpack; -- luacheck: ignore 113 +local promise = require "util.promise"; +local t_move = require "util.table".move; local methods = {}; local resolver_mt = { __index = methods }; -- FIXME RFC 6724 +local function do_dns_lookup(self, dns_resolver, record_type, name, allow_insecure) + return promise.new(function (resolve, reject) + local ipv = (record_type == "A" and "4") or (record_type == "AAAA" and "6") or nil; + if ipv and self.extra["use_ipv"..ipv] == false then + return reject(("IPv%s disabled - %s lookup skipped"):format(ipv, record_type)); + elseif record_type == "TLSA" and self.extra.use_dane ~= true then + return reject("DANE disabled - TLSA lookup skipped"); + end + dns_resolver:lookup(function (answer, err) + if not answer then + return reject(err); + elseif answer.bogus then + return reject(("Validation error in %s lookup"):format(record_type)); + elseif not (answer.secure or allow_insecure) then + return reject(("Insecure response in %s lookup"):format(record_type)); + elseif answer.status and #answer == 0 then + return reject(("%s in %s lookup"):format(answer.status, record_type)); + end + + local targets = { secure = answer.secure }; + for _, record in ipairs(answer) do + if ipv then + table.insert(targets, { self.conn_type..ipv, record[record_type:lower()], self.port, self.extra }); + else + table.insert(targets, record[record_type:lower()]); + end + end + return resolve(targets); + end, name, record_type, "IN"); + end); +end + +local function merge_targets(ipv4_targets, ipv6_targets) + local result = { secure = ipv4_targets.secure and ipv6_targets.secure }; + local common_length = math.min(#ipv4_targets, #ipv6_targets); + for i = 1, common_length do + table.insert(result, ipv6_targets[i]); + table.insert(result, ipv4_targets[i]); + end + if common_length < #ipv4_targets then + t_move(ipv4_targets, common_length+1, #ipv4_targets, common_length+1, result); + elseif common_length < #ipv6_targets then + t_move(ipv6_targets, common_length+1, #ipv6_targets, common_length+1, result); + end + return result; +end + -- Find the next target to connect to, and -- pass it to cb() function methods:next(cb) @@ -18,7 +66,7 @@ function methods:next(cb) return; end local next_target = table.remove(self.targets, 1); - cb(unpack(next_target, 1, 4)); + cb(next_target[1], next_target[2], next_target[3], next_target[4], not not self.targets[1]); return; end @@ -28,91 +76,47 @@ function methods:next(cb) return; end - local secure = true; - local tlsa = {}; - local targets = {}; - local n = 3; - local function ready() - n = n - 1; - if n > 0 then return; end - self.targets = targets; + -- Resolve DNS to target list + local dns_resolver = adns.resolver(); + + local dns_lookups = { + ipv4 = do_dns_lookup(self, dns_resolver, "A", self.hostname, true); + ipv6 = do_dns_lookup(self, dns_resolver, "AAAA", self.hostname, true); + tlsa = do_dns_lookup(self, dns_resolver, "TLSA", ("_%d._%s.%s"):format(self.port, self.conn_type, self.hostname)); + }; + + promise.all_settled(dns_lookups):next(function (dns_results) + -- Combine targets, assign to self.targets, self:next(cb) + local have_ipv4 = dns_results.ipv4.status == "fulfilled"; + local have_ipv6 = dns_results.ipv6.status == "fulfilled"; + + if have_ipv4 and have_ipv6 then + self.targets = merge_targets(dns_results.ipv4.value, dns_results.ipv6.value); + elseif have_ipv4 then + self.targets = dns_results.ipv4.value; + elseif have_ipv6 then + self.targets = dns_results.ipv6.value; + else + self.targets = {}; + end + if self.extra and self.extra.use_dane then - if secure and tlsa[1] then - self.extra.tlsa = tlsa; + if self.targets.secure and dns_results.tlsa.status == "fulfilled" then + self.extra.tlsa = dns_results.tlsa.value; self.extra.dane_hostname = self.hostname; else self.extra.tlsa = nil; self.extra.dane_hostname = nil; end + elseif self.extra and self.extra.srv_secure then + self.extra.secure_hostname = self.hostname; end - self:next(cb); - end - -- Resolve DNS to target list - local dns_resolver = adns.resolver(); - - if not self.extra or self.extra.use_ipv4 ~= false then - dns_resolver:lookup(function (answer, err) - if answer then - secure = secure and answer.secure; - for _, record in ipairs(answer) do - table.insert(targets, { self.conn_type.."4", record.a, self.port, self.extra }); - end - if answer.bogus then - self.last_error = "Validation error in A lookup"; - elseif answer.status then - self.last_error = answer.status .. " in A lookup"; - end - else - self.last_error = err; - end - ready(); - end, self.hostname, "A", "IN"); - else - ready(); - end - - if not self.extra or self.extra.use_ipv6 ~= false then - dns_resolver:lookup(function (answer, err) - if answer then - secure = secure and answer.secure; - for _, record in ipairs(answer) do - table.insert(targets, { self.conn_type.."6", record.aaaa, self.port, self.extra }); - end - if answer.bogus then - self.last_error = "Validation error in AAAA lookup"; - elseif answer.status then - self.last_error = answer.status .. " in AAAA lookup"; - end - else - self.last_error = err; - end - ready(); - end, self.hostname, "AAAA", "IN"); - else - ready(); - end - - if self.extra and self.extra.use_dane == true then - dns_resolver:lookup(function (answer, err) - if answer then - secure = secure and answer.secure; - for _, record in ipairs(answer) do - table.insert(tlsa, record.tlsa); - end - if answer.bogus then - self.last_error = "Validation error in TLSA lookup"; - elseif answer.status then - self.last_error = answer.status .. " in TLSA lookup"; - end - else - self.last_error = err; - end - ready(); - end, ("_%d._tcp.%s"):format(self.port, self.hostname), "TLSA", "IN"); - else - ready(); - end + self:next(cb); + end):catch(function (err) + self.last_error = err; + self.targets = {}; + end); end local function new(hostname, port, conn_type, extra) @@ -137,7 +141,7 @@ local function new(hostname, port, conn_type, extra) hostname = ascii_host; port = port; conn_type = conn_type; - extra = extra; + extra = extra or {}; targets = targets; }, resolver_mt); end diff --git a/net/resolvers/manual.lua b/net/resolvers/manual.lua index dbc40256..c766a11f 100644 --- a/net/resolvers/manual.lua +++ b/net/resolvers/manual.lua @@ -1,6 +1,6 @@ local methods = {}; local resolver_mt = { __index = methods }; -local unpack = table.unpack or unpack; -- luacheck: ignore 113 +local unpack = table.unpack; -- Find the next target to connect to, and -- pass it to cb() diff --git a/net/resolvers/service.lua b/net/resolvers/service.lua index 3810cac8..104a45b2 100644 --- a/net/resolvers/service.lua +++ b/net/resolvers/service.lua @@ -2,23 +2,78 @@ local adns = require "net.adns"; local basic = require "net.resolvers.basic"; local inet_pton = require "util.net".pton; local idna_to_ascii = require "util.encodings".idna.to_ascii; -local unpack = table.unpack or unpack; -- luacheck: ignore 113 local methods = {}; local resolver_mt = { __index = methods }; +local function new_target_selector(rrset) + local rr_count = rrset and #rrset; + if not rr_count or rr_count == 0 then + rrset = nil; + else + table.sort(rrset, function (a, b) return a.srv.priority < b.srv.priority end); + end + local rrset_pos = 1; + local priority_bucket, bucket_total_weight, bucket_len, bucket_used; + return function () + if not rrset then return; end + + if not priority_bucket or bucket_used >= bucket_len then + if rrset_pos > rr_count then return; end -- Used up all records + + -- Going to start on a new priority now. Gather up all the next + -- records with the same priority and add them to priority_bucket + priority_bucket, bucket_total_weight, bucket_len, bucket_used = {}, 0, 0, 0; + local current_priority; + repeat + local curr_record = rrset[rrset_pos].srv; + if not current_priority then + current_priority = curr_record.priority; + elseif current_priority ~= curr_record.priority then + break; + end + table.insert(priority_bucket, curr_record); + bucket_total_weight = bucket_total_weight + curr_record.weight; + bucket_len = bucket_len + 1; + rrset_pos = rrset_pos + 1; + until rrset_pos > rr_count; + end + + bucket_used = bucket_used + 1; + local n, running_total = math.random(0, bucket_total_weight), 0; + local target_record; + for i = 1, bucket_len do + local candidate = priority_bucket[i]; + if candidate then + running_total = running_total + candidate.weight; + if running_total >= n then + target_record = candidate; + bucket_total_weight = bucket_total_weight - candidate.weight; + priority_bucket[i] = nil; + break; + end + end + end + return target_record; + end; +end + -- Find the next target to connect to, and -- pass it to cb() function methods:next(cb) - if self.targets then - if not self.resolver then - if #self.targets == 0 then + if self.resolver or self._get_next_target then + if not self.resolver then -- Do we have a basic resolver currently? + -- We don't, so fetch a new SRV target, create a new basic resolver for it + local next_srv_target = self._get_next_target and self._get_next_target(); + if not next_srv_target then + -- No more SRV targets left cb(nil); return; end - local next_target = table.remove(self.targets, 1); - self.resolver = basic.new(unpack(next_target, 1, 4)); + -- Create a new basic resolver for this SRV target + self.resolver = basic.new(next_srv_target.target, next_srv_target.port, self.conn_type, self.extra); end + -- Look up the next (basic) target from the current target's resolver self.resolver:next(function (...) if self.resolver then self.last_error = self.resolver.last_error; @@ -31,6 +86,9 @@ function methods:next(cb) end end); return; + elseif self.in_progress then + cb(nil); + return; end if not self.hostname then @@ -39,9 +97,9 @@ function methods:next(cb) return; end - local targets = {}; + self.in_progress = true; + local function ready() - self.targets = targets; self:next(cb); end @@ -53,17 +111,23 @@ function methods:next(cb) answer = {}; end if answer then - if self.extra and not answer.secure then - self.extra.use_dane = false; - elseif answer.bogus then + if answer.bogus then self.last_error = "Validation error in SRV lookup"; ready(); return; + elseif not answer.secure then + if self.extra then + -- Insecure results, so no DANE + self.extra.use_dane = false; + end + end + if self.extra then + self.extra.srv_secure = answer.secure; end if #answer == 0 then if self.extra and self.extra.default_port then - table.insert(targets, { self.hostname, self.extra.default_port, self.conn_type, self.extra }); + self.resolver = basic.new(self.hostname, self.extra.default_port, self.conn_type, self.extra); else self.last_error = "zero SRV records found"; end @@ -77,10 +141,7 @@ function methods:next(cb) return; end - table.sort(answer, function (a, b) return a.srv.priority < b.srv.priority end); - for _, record in ipairs(answer) do - table.insert(targets, { record.srv.target, record.srv.port, self.conn_type, self.extra }); - end + self._get_next_target = new_target_selector(answer); else self.last_error = err; end diff --git a/net/server.lua b/net/server.lua index 0696fd52..72272bef 100644 --- a/net/server.lua +++ b/net/server.lua @@ -118,6 +118,13 @@ if prosody and set_config then prosody.events.add_handler("config-reloaded", load_config); end +local tls_builder = server.tls_builder; +-- resolving the basedir here avoids util.sslconfig depending on +-- prosody.paths.config +function server.tls_builder() + return tls_builder(prosody.paths.config or "") +end + -- require "net.server" shall now forever return this, -- ie. server_select or server_event as chosen above. return server; diff --git a/net/server_epoll.lua b/net/server_epoll.lua index fa275d71..c7ab0c13 100644 --- a/net/server_epoll.lua +++ b/net/server_epoll.lua @@ -18,7 +18,6 @@ local traceback = debug.traceback; local logger = require "util.logger"; local log = logger.init("server_epoll"); local socket = require "socket"; -local luasec = require "ssl"; local realtime = require "util.time".now; local monotonic = require "util.time".monotonic; local indexedbheap = require "util.indexedbheap"; @@ -28,6 +27,8 @@ local inet_pton = inet.pton; local _SOCKETINVALID = socket._SOCKETINVALID or -1; local new_id = require "util.id".short; local xpcall = require "util.xpcall".xpcall; +local sslconfig = require "util.sslconfig"; +local tls_impl = require "net.tls_luasec"; local poller = require "util.poll" local EEXIST = poller.EEXIST; @@ -91,6 +92,12 @@ local default_config = { __index = { --- How long to wait after getting the shutdown signal before forcefully tearing down every socket shutdown_deadline = 5; + + -- TCP Fast Open + tcp_fastopen = false; + + -- Defer accept until incoming data is available + tcp_defer_accept = false; }}; local cfg = default_config.__index; @@ -614,6 +621,42 @@ function interface:set_sslctx(sslctx) self._sslctx = sslctx; end +function interface:sslctx() + return self.tls_ctx +end + +function interface:ssl_info() + local sock = self.conn; + if not sock.info then return nil, "not-implemented"; end + return sock:info(); +end + +function interface:ssl_peercertificate() + local sock = self.conn; + if not sock.getpeercertificate then return nil, "not-implemented"; end + return sock:getpeercertificate(); +end + +function interface:ssl_peerverification() + local sock = self.conn; + if not sock.getpeerverification then return nil, { { "Chain verification not supported" } }; end + return sock:getpeerverification(); +end + +function interface:ssl_peerfinished() + local sock = self.conn; + if not sock.getpeerfinished then return nil, "not-implemented"; end + return sock:getpeerfinished(); +end + +function interface:ssl_exportkeyingmaterial(label, len, context) + local sock = self.conn; + if sock.exportkeyingmaterial then + return sock:exportkeyingmaterial(label, len, context); + end +end + + function interface:starttls(tls_ctx) if tls_ctx then self.tls_ctx = tls_ctx; end self.starttls = false; @@ -641,11 +684,7 @@ function interface:inittls(tls_ctx, now) self.starttls = false; self:debug("Starting TLS now"); self:updatenames(); -- Can't getpeer/sockname after wrap() - local ok, conn, err = pcall(luasec.wrap, self.conn, self.tls_ctx); - if not ok then - conn, err = ok, conn; - self:debug("Failed to initialize TLS: %s", err); - end + local conn, err = self.tls_ctx:wrap(self.conn); if not conn then self:on("disconnect", err); self:destroy(); @@ -656,8 +695,8 @@ function interface:inittls(tls_ctx, now) if conn.sni then if self.servername then conn:sni(self.servername); - elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then - conn:sni(self._server.hosts, true); + elseif next(self.tls_ctx._sni_contexts) ~= nil then + conn:sni(self.tls_ctx._sni_contexts, true); end end if self.extra and self.extra.tlsa and conn.settlsa then @@ -741,7 +780,6 @@ local function wrapsocket(client, server, read_size, listeners, tls_ctx, extra) end end - conn:updatenames(); return conn; end @@ -767,6 +805,7 @@ function interface:onacceptable() return; end local client = wrapsocket(conn, self, nil, self.listeners); + client:updatenames(); client:debug("New connection %s on server %s", client, self); client:defaultoptions(); client._writable = cfg.opportunistic_writes; @@ -885,6 +924,12 @@ local function wrapserver(conn, addr, port, listeners, config) log = logger.init(("serv%s"):format(new_id())); }, interface_mt); server:debug("Server %s created", server); + if cfg.tcp_fastopen then + server:setoption("tcp-fastopen", cfg.tcp_fastopen); + end + if type(cfg.tcp_defer_accept) == "number" then + server:setoption("tcp-defer-accept", cfg.tcp_defer_accept); + end server:add(true, false); return server; end @@ -908,6 +953,7 @@ end -- COMPAT local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx, extra) local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra); + client:updatenames(); if not client.peername then client.peername, client.peerport = addr, port; end @@ -941,9 +987,13 @@ local function addclient(addr, port, listeners, read_size, tls_ctx, typ, extra) if not conn then return conn, err; end local ok, err = conn:settimeout(0); if not ok then return ok, err; end + local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra) + if cfg.tcp_fastopen then + client:setoption("tcp-fastopen-connect", 1); + end local ok, err = conn:setpeername(addr, port); if not ok and err ~= "timeout" then return ok, err; end - local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra) + client:updatenames(); local ok, err = client:init(); if not client.peername then -- otherwise not set until connected @@ -1032,12 +1082,38 @@ local function setquitting(quit) end end +local function loop_once() + runtimers(); -- Ignore return value because we only do this once + local fd, r, w = poll:wait(0); + if fd then + local conn = fds[fd]; + if conn then + if r then + conn:onreadable(); + end + if w then + conn:onwritable(); + end + else + log("debug", "Removing unknown fd %d", fd); + poll:del(fd); + end + else + return fd, r; + end +end + -- Main loop local function loop(once) - repeat - local t = runtimers(cfg.max_wait, cfg.min_wait); + if once then + return loop_once(); + end + + local t = 0; + while not quitting do local fd, r, w = poll:wait(t); - while fd do + if fd then + t = 0; local conn = fds[fd]; if conn then if r then @@ -1050,12 +1126,12 @@ local function loop(once) log("debug", "Removing unknown fd %d", fd); poll:del(fd); end - fd, r, w = poll:wait(0); - end - if r ~= "timeout" and r ~= "signal" then + elseif r == "timeout" then + t = runtimers(cfg.max_wait, cfg.min_wait); + elseif r ~= "signal" then log("debug", "epoll_wait error: %s[%d]", r, w); end - until once or (quitting and next(fds) == nil); + end return quitting; end @@ -1085,6 +1161,10 @@ return { cfg = setmetatable(newconfig, default_config); end; + tls_builder = function(basedir) + return sslconfig._new(tls_impl.new_context, basedir) + end, + -- libevent emulation event = { EV_READ = "r", EV_WRITE = "w", EV_READWRITE = "rw", EV_LEAVE = -1 }; addevent = function (fd, mode, callback) diff --git a/net/server_event.lua b/net/server_event.lua index c30181b8..d8f08c8d 100644 --- a/net/server_event.lua +++ b/net/server_event.lua @@ -47,11 +47,13 @@ local s_sub = string.sub local coroutine_wrap = coroutine.wrap local coroutine_yield = coroutine.yield -local has_luasec, ssl = pcall ( require , "ssl" ) +local has_luasec = pcall ( require , "ssl" ) local socket = require "socket" local levent = require "luaevent.core" local inet = require "util.net"; local inet_pton = inet.pton; +local sslconfig = require "util.sslconfig"; +local tls_impl = require "net.tls_luasec"; local socket_gettime = socket.gettime @@ -153,7 +155,7 @@ function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed _ = self.eventwrite and self.eventwrite:close( ) self.eventread, self.eventwrite = nil, nil local err - self.conn, err = ssl.wrap( self.conn, self._sslctx ) + self.conn, err = self._sslctx:wrap(self.conn) if err then self.fatalerror = err self.conn = nil -- cannot be used anymore @@ -168,8 +170,8 @@ function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed if self.conn.sni then if self.servername then self.conn:sni(self.servername); - elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then - self.conn:sni(self._server.hosts, true); + elseif next(self._sslctx._sni_contexts) ~= nil then + self.conn:sni(self._sslctx._sni_contexts, true); end end @@ -274,6 +276,34 @@ function interface_mt:pause() return self:_lock(self.nointerface, true, self.nowriting); end +function interface_mt:sslctx() + return self._sslctx +end + +function interface_mt:ssl_info() + local sock = self.conn; + if not sock.info then return nil, "not-implemented"; end + return sock:info(); +end + +function interface_mt:ssl_peercertificate() + local sock = self.conn; + if not sock.getpeercertificate then return nil, "not-implemented"; end + return sock:getpeercertificate(); +end + +function interface_mt:ssl_peerverification() + local sock = self.conn; + if not sock.getpeerverification then return nil, { { "Chain verification not supported" } }; end + return sock:getpeerverification(); +end + +function interface_mt:ssl_peerfinished() + local sock = self.conn; + if not sock.getpeerfinished then return nil, "not-implemented"; end + return sock:getpeerfinished(); +end + function interface_mt:resume() self:_lock(self.nointerface, false, self.nowriting); if self.readcallback and not self.eventread then @@ -924,6 +954,10 @@ return { add_task = add_task, watchfd = watchfd, + tls_builder = function(basedir) + return sslconfig._new(tls_impl.new_context, basedir) + end, + __NAME = SCRIPT_NAME, __DATE = LAST_MODIFIED, __AUTHOR = SCRIPT_AUTHOR, diff --git a/net/server_select.lua b/net/server_select.lua index eea850ce..651bdfde 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -47,15 +47,15 @@ local coroutine_yield = coroutine.yield --// extern libs //-- -local has_luasec, luasec = pcall ( require , "ssl" ) local luasocket = use "socket" or require "socket" local luasocket_gettime = luasocket.gettime local inet = require "util.net"; local inet_pton = inet.pton; +local sslconfig = require "util.sslconfig"; +local has_luasec, tls_impl = pcall(require, "net.tls_luasec"); --// extern lib methods //-- -local ssl_wrap = ( has_luasec and luasec.wrap ) local socket_bind = luasocket.bind local socket_select = luasocket.select @@ -359,6 +359,21 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport handler.sslctx = function ( ) return sslctx end + handler.ssl_info = function( ) + return socket.info and socket:info() + end + handler.ssl_peercertificate = function( ) + if not socket.getpeercertificate then return nil, "not-implemented"; end + return socket:getpeercertificate() + end + handler.ssl_peerverification = function( ) + if not socket.getpeerverification then return nil, { { "Chain verification not supported" } }; end + return socket:getpeerverification(); + end + handler.ssl_peerfinished = function( ) + if not socket.getpeerfinished then return nil, "not-implemented"; end + return socket:getpeerfinished(); + end handler.send = function( _, data, i, j ) return send( socket, data, i, j ) end @@ -652,7 +667,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport end out_put( "server.lua: attempting to start tls on " .. tostring( socket ) ) local oldsocket, err = socket - socket, err = ssl_wrap( socket, sslctx ) -- wrap socket + socket, err = sslctx:wrap(socket) -- wrap socket if not socket then out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") ) @@ -662,8 +677,8 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport if socket.sni then if self.servername then socket:sni(self.servername); - elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then - socket:sni(self.server().hosts, true); + elseif next(sslctx._sni_contexts) ~= nil then + socket:sni(sslctx._sni_contexts, true); end end @@ -1169,4 +1184,8 @@ return { removeserver = removeserver, get_backend = get_backend, changesettings = changesettings, + + tls_builder = function(basedir) + return sslconfig._new(tls_impl.new_context, basedir) + end, } diff --git a/net/tls_luasec.lua b/net/tls_luasec.lua new file mode 100644 index 00000000..2bedb5ab --- /dev/null +++ b/net/tls_luasec.lua @@ -0,0 +1,89 @@ +-- Prosody IM +-- Copyright (C) 2021 Prosody folks +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +--[[ +This file provides a shim abstraction over LuaSec, consolidating some code +which was previously spread between net.server backends, portmanager and +certmanager. + +The goal is to provide a more or less well-defined API on top of LuaSec which +abstracts away some of the things which are not needed and simplifies usage of +commonly used things (such as SNI contexts). Eventually, network backends +which do not rely on LuaSocket+LuaSec should be able to provide *this* API +instead of having to mimic LuaSec. +]] +local ssl = require "ssl"; +local ssl_newcontext = ssl.newcontext; +local ssl_context = ssl.context or require "ssl.context"; +local io_open = io.open; + +local context_api = {}; +local context_mt = {__index = context_api}; + +function context_api:set_sni_host(host, cert, key) + local ctx, err = self._builder:clone():apply({ + certificate = cert, + key = key, + }):build(); + if not ctx then + return false, err + end + + self._sni_contexts[host] = ctx._inner + + return true, nil +end + +function context_api:remove_sni_host(host) + self._sni_contexts[host] = nil +end + +function context_api:wrap(sock) + local ok, conn, err = pcall(ssl.wrap, sock, self._inner); + if not ok then + return nil, err + end + return conn, nil +end + +local function new_context(cfg, builder) + -- LuaSec expects dhparam to be a callback that takes two arguments. + -- We ignore those because it is mostly used for having a separate + -- set of params for EXPORT ciphers, which we don't have by default. + if type(cfg.dhparam) == "string" then + local f, err = io_open(cfg.dhparam); + if not f then return nil, "Could not open DH parameters: "..err end + local dhparam = f:read("*a"); + f:close(); + cfg.dhparam = function() return dhparam; end + end + + local inner, err = ssl_newcontext(cfg); + if not inner then + return nil, err + end + + -- COMPAT Older LuaSec ignores the cipher list from the config, so we have to take care + -- of it ourselves (W/A for #x) + if inner and cfg.ciphers then + local success; + success, err = ssl_context.setcipher(inner, cfg.ciphers); + if not success then + return nil, err + end + end + + return setmetatable({ + _inner = inner, + _builder = builder, + _sni_contexts = {}, + }, context_mt), nil +end + +return { + new_context = new_context, +}; diff --git a/plugins/adhoc/adhoc.lib.lua b/plugins/adhoc/adhoc.lib.lua index 4cf6911d..9f091e3b 100644 --- a/plugins/adhoc/adhoc.lib.lua +++ b/plugins/adhoc/adhoc.lib.lua @@ -23,10 +23,16 @@ end function _M.new(name, node, handler, permission) if not permission then error "adhoc.new() expects a permission argument, none given" - end - if permission == "user" then + elseif permission == "user" then error "the permission mode 'user' has been renamed 'any', please update your code" end + if permission == "admin" then + module:default_permission("prosody:admin", "mod_adhoc:"..node); + permission = "check"; + elseif permission == "global_admin" then + module:default_permission("prosody:operator", "mod_adhoc:"..node); + permission = "check"; + end return { name = name, node = node, handler = handler, cmdtag = _cmdtag, permission = permission }; end @@ -34,6 +40,8 @@ function _M.handle_cmd(command, origin, stanza) local cmdtag = stanza.tags[1] local sessionid = cmdtag.attr.sessionid or uuid.generate(); local dataIn = { + origin = origin; + stanza = stanza; to = stanza.attr.to; from = stanza.attr.from; action = cmdtag.attr.action or "execute"; diff --git a/plugins/adhoc/mod_adhoc.lua b/plugins/adhoc/mod_adhoc.lua index 09a72075..c94ff24f 100644 --- a/plugins/adhoc/mod_adhoc.lua +++ b/plugins/adhoc/mod_adhoc.lua @@ -7,7 +7,6 @@ local it = require "util.iterators"; local st = require "util.stanza"; -local is_admin = require "core.usermanager".is_admin; local jid_host = require "util.jid".host; local adhoc_handle_cmd = module:require "adhoc".handle_cmd; local xmlns_cmd = "http://jabber.org/protocol/commands"; @@ -15,18 +14,17 @@ local commands = {}; module:add_feature(xmlns_cmd); +local function check_permissions(event, node, command) + return (command.permission == "check" and module:may("mod_adhoc:"..node, event)) + or (command.permission == "local_user" and jid_host(event.stanza.attr.from) == module.host) + or (command.permission == "any"); +end + module:hook("host-disco-info-node", function (event) local stanza, origin, reply, node = event.stanza, event.origin, event.reply, event.node; if commands[node] then - local from = stanza.attr.from; - local privileged = is_admin(from, stanza.attr.to); - local global_admin = is_admin(from); - local hostname = jid_host(from); local command = commands[node]; - if (command.permission == "admin" and privileged) - or (command.permission == "global_admin" and global_admin) - or (command.permission == "local_user" and hostname == module.host) - or (command.permission == "any") then + if check_permissions(event, node, command) then reply:tag("identity", { name = command.name, category = "automation", type = "command-node" }):up(); reply:tag("feature", { var = xmlns_cmd }):up(); @@ -44,20 +42,13 @@ module:hook("host-disco-info-node", function (event) end); module:hook("host-disco-items-node", function (event) - local stanza, reply, disco_node = event.stanza, event.reply, event.node; + local reply, disco_node = event.reply, event.node; if disco_node ~= xmlns_cmd then return; end - local from = stanza.attr.from; - local admin = is_admin(from, stanza.attr.to); - local global_admin = is_admin(from); - local hostname = jid_host(from); for node, command in it.sorted_pairs(commands) do - if (command.permission == "admin" and admin) - or (command.permission == "global_admin" and global_admin) - or (command.permission == "local_user" and hostname == module.host) - or (command.permission == "any") then + if check_permissions(event, node, command) then reply:tag("item", { name = command.name, node = node, jid = module:get_host() }); reply:up(); @@ -71,20 +62,14 @@ module:hook("iq-set/host/"..xmlns_cmd..":command", function (event) local node = stanza.tags[1].attr.node local command = commands[node]; if command then - local from = stanza.attr.from; - local admin = is_admin(from, stanza.attr.to); - local global_admin = is_admin(from); - local hostname = jid_host(from); - if (command.permission == "admin" and not admin) - or (command.permission == "global_admin" and not global_admin) - or (command.permission == "local_user" and hostname ~= module.host) then + if not check_permissions(event, node, command) then origin.send(st.error_reply(stanza, "auth", "forbidden", "You don't have permission to execute this command"):up() - :add_child(commands[node]:cmdtag("canceled") + :add_child(command:cmdtag("canceled") :tag("note", {type="error"}):text("You don't have permission to execute this command"))); return true end -- User has permission now execute the command - adhoc_handle_cmd(commands[node], origin, stanza); + adhoc_handle_cmd(command, origin, stanza); return true; end end, 500); diff --git a/plugins/mod_admin_shell.lua b/plugins/mod_admin_shell.lua index e1016425..c03a74b8 100644 --- a/plugins/mod_admin_shell.lua +++ b/plugins/mod_admin_shell.lua @@ -22,10 +22,10 @@ local _G = _G; local prosody = _G.prosody; -local unpack = table.unpack or unpack; -- luacheck: ignore 113 +local unpack = table.unpack; local iterators = require "util.iterators"; local keys, values = iterators.keys, iterators.values; -local jid_bare, jid_split, jid_join = import("util.jid", "bare", "prepped_split", "join"); +local jid_bare, jid_split, jid_join, jid_resource = import("util.jid", "bare", "prepped_split", "join", "resource"); local set, array = require "util.set", require "util.array"; local cert_verify_identity = require "util.x509".verify_identity; local envload = require "util.envload".envload; @@ -36,6 +36,7 @@ local serialization = require "util.serialization"; local serialize_config = serialization.new ({ fatal = false, unquoted = true}); local time = require "util.time"; local promise = require "util.promise"; +local logger = require "util.logger"; local t_insert = table.insert; local t_concat = table.concat; @@ -83,8 +84,8 @@ function runner_callbacks:error(err) self.data.print("Error: "..tostring(err)); end -local function send_repl_output(session, line) - return session.send(st.stanza("repl-output"):text(tostring(line))); +local function send_repl_output(session, line, attr) + return session.send(st.stanza("repl-output", attr):text(tostring(line))); end function console:new_session(admin_session) @@ -99,8 +100,14 @@ function console:new_session(admin_session) end return send_repl_output(admin_session, table.concat(t, "\t")); end; + write = function (t) + return send_repl_output(admin_session, t, { eol = "0" }); + end; serialize = tostring; disconnect = function () admin_session:close(); end; + is_connected = function () + return not not admin_session.conn; + end }; session.env = setmetatable({}, default_env_mt); @@ -126,6 +133,11 @@ local function handle_line(event) session = console:new_session(event.origin); event.origin.shell_session = session; end + + local default_width = 132; -- The common default of 80 is a bit too narrow for e.g. s2s:show(), 132 was another common width for hardware terminals + local margin = 2; -- To account for '| ' when lines are printed + session.width = (tonumber(event.stanza.attr.width) or default_width)-margin; + local line = event.stanza:get_text(); local useglobalenv; @@ -213,7 +225,7 @@ function commands.help(session, data) print [[Commands are divided into multiple sections. For help on a particular section, ]] print [[type: help SECTION (for example, 'help c2s'). Sections are: ]] print [[]] - local row = format_table({ { title = "Section"; width = 7 }; { title = "Description"; width = "100%" } }) + local row = format_table({ { title = "Section", width = 7 }, { title = "Description", width = "100%" } }, session.width) print(row()) print(row { "c2s"; "Commands to manage local client-to-server sessions" }) print(row { "s2s"; "Commands to manage sessions between this server and others" }) @@ -229,6 +241,7 @@ function commands.help(session, data) print(row { "dns"; "Commands to manage and inspect the internal DNS resolver" }) print(row { "xmpp"; "Commands for sending XMPP stanzas" }) print(row { "debug"; "Commands for debugging the server" }) + print(row { "watch"; "Commands for watching live logs from the server" }) print(row { "config"; "Reloading the configuration, etc." }) print(row { "columns"; "Information about customizing session listings" }) print(row { "console"; "Help regarding the console itself" }) @@ -256,28 +269,29 @@ function commands.help(session, data) print [[host:deactivate(hostname) - Disconnects all clients on this host and deactivates]] print [[host:list() - List the currently-activated hosts]] elseif section == "user" then - print [[user:create(jid, password, roles) - Create the specified user account]] + print [[user:create(jid, password, role) - Create the specified user account]] print [[user:password(jid, password) - Set the password for the specified user account]] print [[user:roles(jid, host) - Show current roles for an user]] - print [[user:setroles(jid, host, roles) - Set roles for an user (see 'help roles')]] + print [[user:setrole(jid, host, role) - Set primary role of a user (see 'help roles')]] + print [[user:addrole(jid, host, role) - Add a secondary role to a user]] + print [[user:delrole(jid, host, role) - Remove a secondary role from a user]] print [[user:delete(jid) - Permanently remove the specified user account]] print [[user:list(hostname, pattern) - List users on the specified host, optionally filtering with a pattern]] elseif section == "roles" then print [[Roles may grant access or restrict users from certain operations]] print [[Built-in roles are:]] - print [[ prosody:admin - Administrator]] - print [[ (empty set) - Normal user]] + print [[ prosody:user - Normal user (default)]] + print [[ prosody:admin - Host administrator]] + print [[ prosody:operator - Server administrator]] print [[]] - print [[The canonical role format looks like: { ["example:role"] = true }]] - print [[For convenience, the following formats are also accepted:]] - print [["admin" - short for "prosody:admin", the normal admin status (like the admins config option)]] - print [["example:role" - short for {["example:role"]=true}]] - print [[{"example:role"} - short for {["example:role"]=true}]] + print [[Roles can be assigned using the user management commands (see 'help user').]] elseif section == "muc" then -- TODO `muc:room():foo()` commands print [[muc:create(roomjid, { config }) - Create the specified MUC room with the given config]] print [[muc:list(host) - List rooms on the specified MUC component]] print [[muc:room(roomjid) - Reference the specified MUC room to access MUC API methods]] + print [[muc:occupants(roomjid, filter) - List room occupants, optionally filtered on substring or role]] + print [[muc:affiliations(roomjid, filter) - List affiliated members of the room, optionally filtered on substring or affiliation]] elseif section == "server" then print [[server:version() - Show the server's version number]] print [[server:uptime() - Show how long the server has been running]] @@ -305,6 +319,9 @@ function commands.help(session, data) print [[debug:logevents(host) - Enable logging of fired events on host]] print [[debug:events(host, event) - Show registered event handlers]] print [[debug:timers() - Show information about scheduled timers]] + elseif section == "watch" then + print [[watch:log() - Follow debug logs]] + print [[watch:stanzas(target, filter) - Watch live stanzas matching the specified target and filter]] elseif section == "console" then print [[Hey! Welcome to Prosody's admin console.]] print [[First thing, if you're ever wondering how to get out, simply type 'quit'.]] @@ -335,7 +352,7 @@ function commands.help(session, data) meta_columns[2].width = math.max(meta_columns[2].width or 0, #(spec.title or "")); meta_columns[3].width = math.max(meta_columns[3].width or 0, #(spec.description or "")); end - local row = format_table(meta_columns, 120) + local row = format_table(meta_columns, session.width) print(row()); for column, spec in iterators.sorted_pairs(available_columns) do print(row({ column, spec.title, spec.description })); @@ -481,6 +498,16 @@ function def_env.module:info(name, hosts) local function item_name(item) return item.name; end + local function task_timefmt(t) + if not t then + return "no last run time" + elseif os.difftime(os.time(), t) < 86400 then + return os.date("last run today at %H:%M", t); + else + return os.date("last run %A at %H:%M", t); + end + end + local friendly_descriptions = { ["adhoc-provider"] = "Ad-hoc commands", ["auth-provider"] = "Authentication provider", @@ -498,12 +525,22 @@ function def_env.module:info(name, hosts) ["auth-provider"] = item_name, ["storage-provider"] = item_name, ["http-provider"] = function(item, mod) return mod:http_url(item.name, item.default_path); end, - ["net-provider"] = item_name, + ["net-provider"] = function(item) + local service_name = item.name; + local ports_list = {}; + for _, interface, port in portmanager.get_active_services():iter(service_name, nil, nil) do + table.insert(ports_list, "["..interface.."]:"..port); + end + if not ports_list[1] then + return service_name..": not listening on any ports"; + end + return service_name..": "..table.concat(ports_list, ", "); + end, ["measure"] = function(item) return item.name .. " (" .. suf(item.conf and item.conf.unit, " ") .. item.type .. ")"; end, ["metric"] = function(item) return ("%s (%s%s)%s"):format(item.name, suf(item.mf.unit, " "), item.mf.type_, pre(": ", item.mf.description)); end, - ["task"] = function (item) return string.format("%s (%s)", item.name or item.id, item.when); end + ["task"] = function (item) return string.format("%s (%s, %s)", item.name or item.id, item.when, task_timefmt(item.last)); end }; for host in hosts do @@ -540,14 +577,14 @@ function def_env.module:info(name, hosts) return true; end -function def_env.module:load(name, hosts, config) +function def_env.module:load(name, hosts) hosts = get_hosts_with_module(hosts); -- Load the module for each host local ok, err, count, mod = true, nil, 0; for host in hosts do if (not modulemanager.is_loaded(host, name)) then - mod, err = modulemanager.load(host, name, config); + mod, err = modulemanager.load(host, name); if not mod then ok = false; if err == "global-module-already-loaded" then @@ -805,9 +842,7 @@ available_columns = { mapper = function(conn, session) if not session.secure then return "insecure"; end if not conn or not conn:ssl() then return "secure" end - local sock = conn and conn:socket(); - if not sock then return "secure"; end - local tls_info = sock.info and sock:info(); + local tls_info = conn.ssl_info and conn:ssl_info(); return tls_info and tls_info.protocol or "secure"; end; }; @@ -817,8 +852,7 @@ available_columns = { width = 30; key = "conn"; mapper = function(conn) - local sock = conn and conn:socket(); - local info = sock and sock.info and sock:info(); + local info = conn and conn.ssl_info and conn:ssl_info(); if info then return info.cipher end end; }; @@ -915,6 +949,15 @@ available_columns = { end end }; + role = { + title = "Role"; + description = "Session role"; + width = 20; + key = "role"; + mapper = function(role) + return role and role.name; + end; + } }; local function get_colspec(colspec, default) @@ -935,8 +978,8 @@ end function def_env.c2s:show(match_jid, colspec) local print = self.session.print; - local columns = get_colspec(colspec, { "id"; "jid"; "ipv"; "status"; "secure"; "smacks"; "csi" }); - local row = format_table(columns, 120); + local columns = get_colspec(colspec, { "id"; "jid"; "role"; "ipv"; "status"; "secure"; "smacks"; "csi" }); + local row = format_table(columns, self.session.width); local function match(session) local jid = get_jid(session) @@ -1019,7 +1062,7 @@ end function def_env.s2s:show(match_jid, colspec) local print = self.session.print; local columns = get_colspec(colspec, { "id"; "host"; "dir"; "remote"; "ipv"; "secure"; "s2s_sasl"; "dialback" }); - local row = format_table(columns, 132); + local row = format_table(columns, self.session.width); local function match(session) local host, remote = get_s2s_hosts(session); @@ -1229,18 +1272,18 @@ end function def_env.host:list() local print = self.session.print; local i = 0; - local type; + local host_type; for host, host_session in iterators.sorted_pairs(prosody.hosts, _sort_hosts) do i = i + 1; - type = host_session.type; - if type == "local" then + host_type = host_session.type; + if host_type == "local" then print(host); else - type = module:context(host):get_option_string("component_module", type); - if type ~= "component" then - type = type .. " component"; + host_type = module:context(host):get_option_string("component_module", host_type); + if host_type ~= "component" then + host_type = host_type .. " component"; end - print(("%s (%s)"):format(host, type)); + print(("%s (%s)"):format(host, host_type)); end end return true, i.." hosts"; @@ -1307,6 +1350,20 @@ local function check_muc(jid) return room_name, host; end +local function get_muc(room_jid) + local room_name, host = check_muc(room_jid); + if not room_name then + return room_name, host; + end + local room_obj = prosody.hosts[host].modules.muc.get_room_from_jid(room_jid); + if not room_obj then + return nil, "No such room: "..room_jid; + end + return room_obj; +end + +local muc_util = module:require"muc/util"; + function def_env.muc:create(room_jid, config) local room_name, host = check_muc(room_jid); if not room_name then @@ -1319,13 +1376,9 @@ function def_env.muc:create(room_jid, config) end function def_env.muc:room(room_jid) - local room_name, host = check_muc(room_jid); - if not room_name then - return room_name, host; - end - local room_obj = prosody.hosts[host].modules.muc.get_room_from_jid(room_jid); + local room_obj, err = get_muc(room_jid); if not room_obj then - return nil, "No such room: "..room_jid; + return room_obj, err; end return setmetatable({ room = room_obj }, console_room_mt); end @@ -1344,34 +1397,128 @@ function def_env.muc:list(host) return true, c.." rooms"; end -local um = require"core.usermanager"; +function def_env.muc:occupants(room_jid, filter) + local room_obj, err = get_muc(room_jid); + if not room_obj then + return room_obj, err; + end + + local print = self.session.print; + local row = format_table({ + { title = "Role"; width = 12; key = "role" }; -- longest role name + { title = "JID"; width = "75%"; key = "bare_jid" }; + { title = "Nickname"; width = "25%"; key = "nick"; mapper = jid_resource }; + }, self.session.width); + local occupants = array.collect(iterators.select(2, room_obj:each_occupant())); + local total = #occupants; + if filter then + occupants:filter(function(occupant) + return occupant.role == filter or jid_resource(occupant.nick):find(filter, 1, true); + end); + end + local displayed = #occupants; + occupants:sort(function(a, b) + if a.role ~= b.role then + return muc_util.valid_roles[a.role] > muc_util.valid_roles[b.role]; + else + return a.bare_jid < b.bare_jid; + end + end); + + if displayed == 0 then + return true, ("%d out of %d occupant%s listed"):format(displayed, total, total ~= 1 and "s" or "") + end + + print(row()); + for _, occupant in ipairs(occupants) do + print(row(occupant)); + end + + if total == displayed then + return true, ("%d occupant%s listed"):format(total, total ~= 1 and "s" or "") + else + return true, ("%d out of %d occupant%s listed"):format(displayed, total, total ~= 1 and "s" or "") + end +end -local function coerce_roles(roles) - if roles == "admin" then roles = "prosody:admin"; end - if type(roles) == "string" then roles = { [roles] = true }; end - if roles[1] then for i, role in ipairs(roles) do roles[role], roles[i] = true, nil; end end - return roles; +function def_env.muc:affiliations(room_jid, filter) + local room_obj, err = get_muc(room_jid); + if not room_obj then + return room_obj, err; + end + + local print = self.session.print; + local row = format_table({ + { title = "Affiliation"; width = 12 }; -- longest affiliation name + { title = "JID"; width = "75%" }; + { title = "Nickname"; width = "25%"; key = "reserved_nickname" }; + }, self.session.width); + local affiliated = array(); + for affiliated_jid, affiliation, affiliation_data in room_obj:each_affiliation() do + affiliated:push(setmetatable({ affiliation; affiliated_jid }, { __index = affiliation_data })); + end + + local total = #affiliated; + if filter then + affiliated:filter(function(affiliation) + return filter == affiliation[1] or affiliation[2]:find(filter, 1, true); + end); + end + local displayed = #affiliated; + local aff_ranking = muc_util.valid_affiliations; + affiliated:sort(function(a, b) + if a[1] ~= b[1] then + return aff_ranking[a[1]] > aff_ranking[b[1]]; + else + return a[2] < b[2]; + end + end); + + if displayed == 0 then + return true, ("%d out of %d affiliations%s listed"):format(displayed, total, total ~= 1 and "s" or "") + end + + print(row()); + for _, affiliation in ipairs(affiliated) do + print(row(affiliation)); + end + + + if total == displayed then + return true, ("%d affiliation%s listed"):format(total, total ~= 1 and "s" or "") + else + return true, ("%d out of %d affiliation%s listed"):format(displayed, total, total ~= 1 and "s" or "") + end end +local um = require"core.usermanager"; + def_env.user = {}; -function def_env.user:create(jid, password, roles) +function def_env.user:create(jid, password, role) local username, host = jid_split(jid); if not prosody.hosts[host] then return nil, "No such host: "..host; elseif um.user_exists(username, host) then return nil, "User exists"; end - local ok, err = um.create_user(username, password, host); - if ok then - if ok and roles then - roles = coerce_roles(roles); - local roles_ok, rerr = um.set_roles(jid, host, roles); - if not roles_ok then return nil, "User created, but could not set roles: " .. tostring(rerr); end - end - return true, "User created"; - else + local ok, err = um.create_user(username, nil, host); + if not ok then return nil, "Could not create user: "..err; end + + if role then + local role_ok, rerr = um.set_user_role(jid, host, role); + if not role_ok then + return nil, "Could not set role: " .. tostring(rerr); + end + end + + local ok, err = um.set_password(username, password, host, nil); + if not ok then + return nil, "Could not set password for user: "..err; + end + + return true, "User created"; end function def_env.user:delete(jid) @@ -1404,41 +1551,64 @@ function def_env.user:password(jid, password) end end -function def_env.user:roles(jid, host, new_roles) - if new_roles or type(host) == "table" then - return nil, "Use user:setroles(jid, host, roles) to change user roles"; - end +function def_env.user:role(jid, host) + local print = self.session.print; local username, userhost = jid_split(jid); if host == nil then host = userhost; end - if host ~= "*" and not prosody.hosts[host] then + if not prosody.hosts[host] then return nil, "No such host: "..host; elseif prosody.hosts[userhost] and not um.user_exists(username, userhost) then return nil, "No such user"; end - local roles = um.get_roles(jid, host); - if not roles then return true, "No roles"; end - local count = 0; - local print = self.session.print; - for role in pairs(roles) do + + local primary_role = um.get_user_role(username, host); + local secondary_roles = um.get_user_secondary_roles(username, host); + + print(primary_role and primary_role.name or "<none>"); + + local count = primary_role and 1 or 0; + for role_name in pairs(secondary_roles or {}) do count = count + 1; - print(role); + print(role_name.." (secondary)"); end + return true, count == 1 and "1 role" or count.." roles"; end -def_env.user.showroles = def_env.user.roles; -- COMPAT +def_env.user.roles = def_env.user.role; --- user:roles("someone@example.com", "example.com", {"prosody:admin"}) --- user:roles("someone@example.com", {"prosody:admin"}) -function def_env.user:setroles(jid, host, new_roles) +-- user:setrole("someone@example.com", "example.com", "prosody:admin") +-- user:setrole("someone@example.com", "prosody:admin") +function def_env.user:setrole(jid, host, new_role) local username, userhost = jid_split(jid); - if new_roles == nil then host, new_roles = userhost, host; end - if host ~= "*" and not prosody.hosts[host] then + if new_role == nil then host, new_role = userhost, host; end + if not prosody.hosts[host] then return nil, "No such host: "..host; elseif prosody.hosts[userhost] and not um.user_exists(username, userhost) then return nil, "No such user"; end - if host == "*" then host = nil; end - return um.set_roles(jid, host, coerce_roles(new_roles)); + return um.set_user_role(username, host, new_role); +end + +function def_env.user:addrole(jid, host, new_role) + local username, userhost = jid_split(jid); + if new_role == nil then host, new_role = userhost, host; end + if not prosody.hosts[host] then + return nil, "No such host: "..host; + elseif prosody.hosts[userhost] and not um.user_exists(username, userhost) then + return nil, "No such user"; + end + return um.add_user_secondary_role(username, host, new_role); +end + +function def_env.user:delrole(jid, host, role_name) + local username, userhost = jid_split(jid); + if role_name == nil then host, role_name = userhost, host; end + if not prosody.hosts[host] then + return nil, "No such host: "..host; + elseif prosody.hosts[userhost] and not um.user_exists(username, userhost) then + return nil, "No such user"; + end + return um.remove_user_secondary_role(username, host, role_name); end -- TODO switch to table view, include roles @@ -1509,7 +1679,7 @@ function def_env.xmpp:ping(localhost, remotehost, timeout) module:unhook("s2sin-established", onestablished); module:unhook("s2s-destroyed", ondestroyed); end):next(function(pong) - return ("pong from %s in %gs"):format(pong.stanza.attr.from, time.now() - time_start); + return ("pong from %s on %s in %gs"):format(pong.stanza.attr.from, pong.origin.id, time.now() - time_start); end); end @@ -1561,7 +1731,7 @@ function def_env.http:list(hosts) local output = format_table({ { title = "Module", width = "20%" }, { title = "URL", width = "80%" }, - }, 132); + }, self.session.width); for _, host in ipairs(hosts) do local http_apps = modulemanager.get_items("http-provider", host); @@ -1592,6 +1762,60 @@ function def_env.http:list(hosts) return true; end +def_env.watch = {}; + +function def_env.watch:log() + local writing = false; + local sink = logger.add_simple_sink(function (source, level, message) + if writing then return; end + writing = true; + self.session.print(source, level, message); + writing = false; + end); + + while self.session.is_connected() do + async.sleep(3); + end + if not logger.remove_sink(sink) then + module:log("warn", "Unable to remove watch:log() sink"); + end +end + +local stanza_watchers = module:require("mod_debug_stanzas/watcher"); +function def_env.watch:stanzas(target_spec, filter_spec) + local function handler(event_type, stanza, session) + if stanza then + if event_type == "sent" then + self.session.print(("\n<!-- sent to %s -->"):format(session.id)); + elseif event_type == "received" then + self.session.print(("\n<!-- received from %s -->"):format(session.id)); + else + self.session.print(("\n<!-- %s (%s) -->"):format(event_type, session.id)); + end + self.session.print(stanza); + elseif session then + self.session.print("\n<!-- session "..session.id.." "..event_type.." -->"); + elseif event_type then + self.session.print("\n<!-- "..event_type.." -->"); + end + end + + stanza_watchers.add({ + target_spec = { + jid = target_spec; + }; + filter_spec = filter_spec and { + with_jid = filter_spec; + }; + }, handler); + + while self.session.is_connected() do + async.sleep(3); + end + + stanza_watchers.remove(handler); +end + def_env.debug = {}; function def_env.debug:logevents(host) @@ -1935,6 +2159,10 @@ function def_env.stats:show(name_filter) end +function module.unload() + stanza_watchers.cleanup(); +end + ------------- diff --git a/plugins/mod_admin_socket.lua b/plugins/mod_admin_socket.lua index 157e746c..fdf099db 100644 --- a/plugins/mod_admin_socket.lua +++ b/plugins/mod_admin_socket.lua @@ -8,7 +8,7 @@ if have_unix and type(unix) == "function" then -- constructor was exported instead of a module table. Due to the lack of a -- proper release of LuaSocket, distros have settled on shipping either the -- last RC tag or some commit since then. - -- Here we accomodate both variants. + -- Here we accommodate both variants. unix = { stream = unix }; end if not have_unix or type(unix) ~= "table" then diff --git a/plugins/mod_announce.lua b/plugins/mod_announce.lua index c742ebb8..8161d4ba 100644 --- a/plugins/mod_announce.lua +++ b/plugins/mod_announce.lua @@ -9,7 +9,6 @@ local st, jid = require "util.stanza", require "util.jid"; local hosts = prosody.hosts; -local is_admin = require "core.usermanager".is_admin; function send_to_online(message, host) local sessions; @@ -34,6 +33,7 @@ function send_to_online(message, host) return c; end +module:default_permission("prosody:admin", ":send-announcement"); -- Old <message>-based jabberd-style announcement sending function handle_announcement(event) @@ -45,8 +45,8 @@ function handle_announcement(event) return; -- Not an announcement end - if not is_admin(stanza.attr.from, host) then - -- Not an admin? Not allowed! + if not module:may(":send-announcement", event) then + -- Not allowed! module:log("warn", "Non-admin '%s' tried to send server announcement", stanza.attr.from); return; end diff --git a/plugins/mod_auth_insecure.lua b/plugins/mod_auth_insecure.lua index dc5ee616..5428d1fa 100644 --- a/plugins/mod_auth_insecure.lua +++ b/plugins/mod_auth_insecure.lua @@ -27,6 +27,7 @@ function provider.set_password(username, password) return nil, "Password fails SASLprep."; end if account then + account.updated = os.time(); account.password = password; return datamanager.store(username, host, "accounts", account); end @@ -38,7 +39,8 @@ function provider.user_exists(username) end function provider.create_user(username, password) - return datamanager.store(username, host, "accounts", {password = password}); + local now = os.time(); + return datamanager.store(username, host, "accounts", { created = now; updated = now; password = password }); end function provider.delete_user(username) diff --git a/plugins/mod_auth_internal_hashed.lua b/plugins/mod_auth_internal_hashed.lua index cf851eef..ddff31e9 100644 --- a/plugins/mod_auth_internal_hashed.lua +++ b/plugins/mod_auth_internal_hashed.lua @@ -86,11 +86,21 @@ function provider.set_password(username, password) account.server_key = server_key_hex account.password = nil; + account.updated = os.time(); return accounts:set(username, account); end return nil, "Account not available."; end +function provider.get_account_info(username) + local account = accounts:get(username); + if not account then return nil, "Account not available"; end + return { + created = account.created; + password_updated = account.updated; + }; +end + function provider.user_exists(username) local account = accounts:get(username); if not account then @@ -105,8 +115,9 @@ function provider.users() end function provider.create_user(username, password) + local now = os.time(); if password == nil then - return accounts:set(username, {}); + return accounts:set(username, { created = now; updated = now; disabled = true }); end local salt = generate_uuid(); local valid, stored_key, server_key = get_auth_db(password, salt, default_iteration_count); @@ -117,7 +128,8 @@ function provider.create_user(username, password) local server_key_hex = to_hex(server_key); return accounts:set(username, { stored_key = stored_key_hex, server_key = server_key_hex, - salt = salt, iteration_count = default_iteration_count + salt = salt, iteration_count = default_iteration_count, + created = now, updated = now; }); end diff --git a/plugins/mod_auth_internal_plain.lua b/plugins/mod_auth_internal_plain.lua index 8a50e820..0f65323c 100644 --- a/plugins/mod_auth_internal_plain.lua +++ b/plugins/mod_auth_internal_plain.lua @@ -48,11 +48,21 @@ function provider.set_password(username, password) local account = accounts:get(username); if account then account.password = password; + account.updated = os.time(); return accounts:set(username, account); end return nil, "Account not available."; end +function provider.get_account_info(username) + local account = accounts:get(username); + if not account then return nil, "Account not available"; end + return { + created = account.created; + password_updated = account.updated; + }; +end + function provider.user_exists(username) local account = accounts:get(username); if not account then @@ -71,7 +81,11 @@ function provider.create_user(username, password) if not password then return nil, "Password fails SASLprep."; end - return accounts:set(username, {password = password}); + local now = os.time(); + return accounts:set(username, { + password = password; + created = now, updated = now; + }); end function provider.delete_user(username) diff --git a/plugins/mod_auth_ldap.lua b/plugins/mod_auth_ldap.lua index 4d484aaa..a3ea880c 100644 --- a/plugins/mod_auth_ldap.lua +++ b/plugins/mod_auth_ldap.lua @@ -1,6 +1,5 @@ -- mod_auth_ldap -local jid_split = require "util.jid".split; local new_sasl = require "util.sasl".new; local lualdap = require "lualdap"; @@ -21,6 +20,13 @@ local ldap_admins = module:get_option_string("ldap_admin_filter", module:get_option_string("ldap_admins")); -- COMPAT with mistake in documentation local host = ldap_filter_escape(module:get_option_string("realm", module.host)); +if ldap_admins then + module:log("error", "The 'ldap_admin_filter' option has been deprecated, ".. + "and will be ignored. Equivalent functionality may be added in ".. + "the future if there is demand." + ); +end + -- Initiate connection local ld = nil; module.unload = function() if ld then pcall(ld, ld.close); end end @@ -133,22 +139,4 @@ else module:log("error", "Unsupported ldap_mode %s", tostring(ldap_mode)); end -if ldap_admins then - function provider.is_admin(jid) - local username, user_host = jid_split(jid); - if user_host ~= module.host then - return false; - end - return ldap_do("search", 2, { - base = ldap_base; - scope = ldap_scope; - sizelimit = 1; - filter = ldap_admins:gsub("%$(%a+)", { - user = ldap_filter_escape(username); - host = host; - }); - }); - end -end - module:provides("auth", provider); diff --git a/plugins/mod_authz_internal.lua b/plugins/mod_authz_internal.lua index 17687959..c2895613 100644 --- a/plugins/mod_authz_internal.lua +++ b/plugins/mod_authz_internal.lua @@ -1,59 +1,330 @@ local array = require "util.array"; local it = require "util.iterators"; local set = require "util.set"; -local jid_split = require "util.jid".split; +local jid_split, jid_bare, jid_host = import("util.jid", "split", "bare", "host"); local normalize = require "util.jid".prep; +local roles = require "util.roles"; + +local config_global_admin_jids = module:context("*"):get_option_set("admins", {}) / normalize; local config_admin_jids = module:get_option_inherited_set("admins", {}) / normalize; local host = module.host; -local role_store = module:open_store("roles"); -local role_map_store = module:open_store("roles", "map"); +local host_suffix = host:gsub("^[^%.]+%.", ""); + +local hosts = prosody.hosts; +local is_component = hosts[host].type == "component"; +local host_user_role, server_user_role, public_user_role; +if is_component then + host_user_role = module:get_option_string("host_user_role", "prosody:user"); + server_user_role = module:get_option_string("server_user_role"); + public_user_role = module:get_option_string("public_user_role"); +end + +local role_store = module:open_store("account_roles"); +local role_map_store = module:open_store("account_roles", "map"); + +local role_registry = {}; + +function register_role(role) + if role_registry[role.name] ~= nil then + return error("A role '"..role.name.."' is already registered"); + end + if not roles.is_role(role) then + -- Convert table syntax to real role object + for i, inherited_role in ipairs(role.inherits or {}) do + if type(inherited_role) == "string" then + role.inherits[i] = assert(role_registry[inherited_role], "The named role '"..inherited_role.."' is not registered"); + end + end + if not role.permissions then role.permissions = {}; end + for _, allow_permission in ipairs(role.allow or {}) do + role.permissions[allow_permission] = true; + end + for _, deny_permission in ipairs(role.deny or {}) do + role.permissions[deny_permission] = false; + end + role = roles.new(role); + end + role_registry[role.name] = role; +end + +-- Default roles +register_role { + name = "prosody:restricted"; + priority = 15; +}; + +register_role { + name = "prosody:user"; + priority = 25; + inherits = { "prosody:restricted" }; +}; + +register_role { + name = "prosody:admin"; + priority = 50; + inherits = { "prosody:user" }; +}; + +register_role { + name = "prosody:operator"; + priority = 75; + inherits = { "prosody:admin" }; +}; + + +-- Process custom roles from config + +local custom_roles = module:get_option("custom_roles", {}); +for n, role_config in ipairs(custom_roles) do + local ok, err = pcall(register_role, role_config); + if not ok then + module:log("error", "Error registering custom role %s: %s", role_config.name or tostring(n), err); + end +end + +-- Process custom permissions from config -local admin_role = { ["prosody:admin"] = true }; +local config_add_perms = module:get_option("add_permissions", {}); +local config_remove_perms = module:get_option("remove_permissions", {}); -function get_user_roles(user) - if config_admin_jids:contains(user.."@"..host) then - return admin_role; +for role_name, added_permissions in pairs(config_add_perms) do + if not role_registry[role_name] then + module:log("error", "Cannot add permissions to unknown role '%s'", role_name); + else + for _, permission in ipairs(added_permissions) do + role_registry[role_name]:set_permission(permission, true, true); + end end - return role_store:get(user); end -function set_user_roles(user, roles) - role_store:set(user, roles) - return true; +for role_name, removed_permissions in pairs(config_remove_perms) do + if not role_registry[role_name] then + module:log("error", "Cannot remove permissions from unknown role '%s'", role_name); + else + for _, permission in ipairs(removed_permissions) do + role_registry[role_name]:set_permission(permission, false, true); + end + end end -function get_users_with_role(role) - local storage_role_users = it.to_array(it.keys(role_map_store:get_all(role) or {})); - if role == "prosody:admin" then - local config_admin_users = config_admin_jids / function (admin_jid) +-- Public API + +-- Get the primary role of a user +function get_user_role(user) + local bare_jid = user.."@"..host; + + -- Check config first + if config_global_admin_jids:contains(bare_jid) then + return role_registry["prosody:operator"]; + elseif config_admin_jids:contains(bare_jid) then + return role_registry["prosody:admin"]; + end + + -- Check storage + local stored_roles, err = role_store:get(user); + if not stored_roles then + if err then + -- Unable to fetch role, fail + return nil, err; + end + -- No role set, use default role + return role_registry["prosody:user"]; + end + if stored_roles._default == nil then + -- No primary role explicitly set, return default + return role_registry["prosody:user"]; + end + local primary_stored_role = role_registry[stored_roles._default]; + if not primary_stored_role then + return nil, "unknown-role"; + end + return primary_stored_role; +end + +-- Set the primary role of a user +function set_user_role(user, role_name) + local role = role_registry[role_name]; + if not role then + return error("Cannot assign default user an unknown role: "..tostring(role_name)); + end + local keys_update = { + _default = role_name; + -- Primary role cannot be secondary role + [role_name] = role_map_store.remove; + }; + if role_name == "prosody:user" then + -- Don't store default + keys_update._default = role_map_store.remove; + end + local ok, err = role_map_store:set_keys(user, keys_update); + if not ok then + return nil, err; + end + return role; +end + +function add_user_secondary_role(user, role_name) + if not role_registry[role_name] then + return error("Cannot assign default user an unknown role: "..tostring(role_name)); + end + role_map_store:set(user, role_name, true); +end + +function remove_user_secondary_role(user, role_name) + role_map_store:set(user, role_name, nil); +end + +function get_user_secondary_roles(user) + local stored_roles, err = role_store:get(user); + if not stored_roles then + if err then + -- Unable to fetch role, fail + return nil, err; + end + -- No role set + return {}; + end + stored_roles._default = nil; + for role_name in pairs(stored_roles) do + stored_roles[role_name] = role_registry[role_name]; + end + return stored_roles; +end + +function user_can_assume_role(user, role_name) + local primary_role = get_user_role(user); + if primary_role and primary_role.role_name == role_name then + return true; + end + local secondary_roles = get_user_secondary_roles(user); + if secondary_roles and secondary_roles[role_name] then + return true; + end + return false; +end + +-- This function is *expensive* +function get_users_with_role(role_name) + local function role_filter(username, default_role) --luacheck: ignore 212/username + return default_role == role_name; + end + local primary_role_users = set.new(it.to_array(it.filter(role_filter, pairs(role_map_store:get_all("_default") or {})))); + local secondary_role_users = set.new(it.to_array(it.keys(role_map_store:get_all(role_name) or {}))); + + local config_set; + if role_name == "prosody:admin" then + config_set = config_admin_jids; + elseif role_name == "prosody:operator" then + config_set = config_global_admin_jids; + end + if config_set then + local config_admin_users = config_set / function (admin_jid) local j_node, j_host = jid_split(admin_jid); if j_host == host then return j_node; end end; - return it.to_array(config_admin_users + set.new(storage_role_users)); + return it.to_array(config_admin_users + primary_role_users + secondary_role_users); end - return storage_role_users; + return it.to_array(primary_role_users + secondary_role_users); end -function get_jid_roles(jid) - if config_admin_jids:contains(jid) then - return admin_role; +function get_jid_role(jid) + local bare_jid = jid_bare(jid); + if config_global_admin_jids:contains(bare_jid) then + return role_registry["prosody:operator"]; + elseif config_admin_jids:contains(bare_jid) then + return role_registry["prosody:admin"]; + elseif is_component then + local user_host = jid_host(bare_jid); + if host_user_role and user_host == host_suffix then + return role_registry[host_user_role]; + elseif server_user_role and hosts[user_host] then + return role_registry[server_user_role]; + elseif public_user_role then + return role_registry[public_user_role]; + end end return nil; end -function set_jid_roles(jid) -- luacheck: ignore 212 +function set_jid_role(jid, role_name) -- luacheck: ignore 212 return false; end -function get_jids_with_role(role) +function get_jids_with_role(role_name) -- Fetch role users from storage - local storage_role_jids = array.map(get_users_with_role(role), function (username) + local storage_role_jids = array.map(get_users_with_role(role_name), function (username) return username.."@"..host; end); - if role == "prosody:admin" then + if role_name == "prosody:admin" then return it.to_array(config_admin_jids + set.new(storage_role_jids)); + elseif role_name == "prosody:operator" then + return it.to_array(config_global_admin_jids + set.new(storage_role_jids)); end return storage_role_jids; end + +function add_default_permission(role_name, action, policy) + local role = role_registry[role_name]; + if not role then + module:log("warn", "Attempt to add default permission for unknown role: %s", role_name); + return nil, "no-such-role"; + end + if policy == nil then policy = true; end + module:log("debug", "Adding policy %s for permission %s on role %s", policy, action, role_name); + return role:set_permission(action, policy); +end + +function get_role_by_name(role_name) + return assert(role_registry[role_name], role_name); +end + +-- COMPAT: Migrate from 0.12 role storage +local function do_migration(migrate_host) + local old_role_store = assert(module:context(migrate_host):open_store("roles")); + local new_role_store = assert(module:context(migrate_host):open_store("account_roles")); + + local migrated, failed, skipped = 0, 0, 0; + -- Iterate all users + for username in assert(old_role_store:users()) do + local old_roles = it.to_array(it.filter(function (k) return k:sub(1,1) ~= "_"; end, it.keys(old_role_store:get(username)))); + if #old_roles == 1 then + local ok, err = new_role_store:set(username, { + _default = old_roles[1]; + }); + if ok then + migrated = migrated + 1; + else + failed = failed + 1; + print("EE: Failed to store new role info for '"..username.."': "..err); + end + else + print("WW: User '"..username.."' has multiple roles and cannot be automatically migrated"); + skipped = skipped + 1; + end + end + return migrated, failed, skipped; +end + +function module.command(arg) + if arg[1] == "migrate" then + table.remove(arg, 1); + local migrate_host = arg[1]; + if not migrate_host or not prosody.hosts[migrate_host] then + print("EE: Please supply a valid host to migrate to the new role storage"); + return 1; + end + + -- Initialize storage layer + require "core.storagemanager".initialize_host(migrate_host); + + print("II: Migrating roles..."); + local migrated, failed, skipped = do_migration(migrate_host); + print(("II: %d migrated, %d failed, %d skipped"):format(migrated, failed, skipped)); + return (failed + skipped == 0) and 0 or 1; + else + print("EE: Unknown command: "..(arg[1] or "<none given>")); + print(" Hint: try 'migrate'?"); + end +end diff --git a/plugins/mod_blocklist.lua b/plugins/mod_blocklist.lua index dad06b62..13e98e00 100644 --- a/plugins/mod_blocklist.lua +++ b/plugins/mod_blocklist.lua @@ -54,6 +54,7 @@ local function set_blocklist(username, blocklist) end -- Migrates from the old mod_privacy storage +-- TODO mod_privacy was removed in 0.10.0, this should be phased out local function migrate_privacy_list(username) local legacy_data = module:open_store("privacy"):get(username); if not legacy_data or not legacy_data.lists or not legacy_data.default then return; end @@ -77,6 +78,13 @@ local function migrate_privacy_list(username) return migrated_data; end +if not module:get_option_boolean("migrate_legacy_blocking", true) then + migrate_privacy_list = function (username) + module:log("debug", "Migrating from mod_privacy disabled, user '%s' will start with a fresh blocklist", username); + return nil; + end +end + local function get_blocklist(username) local blocklist = cache2:get(username); if not blocklist then diff --git a/plugins/mod_c2s.lua b/plugins/mod_c2s.lua index c8f54fa7..9af463b6 100644 --- a/plugins/mod_c2s.lua +++ b/plugins/mod_c2s.lua @@ -117,8 +117,7 @@ function stream_callbacks._streamopened(session, attr) session.secure = true; session.encrypted = true; - local sock = session.conn:socket(); - local info = sock.info and sock:info(); + local info = session.conn:ssl_info(); if type(info) == "table" then (session.log or log)("info", "Stream encrypted (%s with %s)", info.protocol, info.cipher); session.compressed = info.compression; @@ -129,8 +128,13 @@ function stream_callbacks._streamopened(session, attr) end local features = st.stanza("stream:features"); - hosts[session.host].events.fire_event("stream-features", { origin = session, features = features }); + hosts[session.host].events.fire_event("stream-features", { origin = session, features = features, stream = attr }); if features.tags[1] or session.full_jid then + if stanza_size_limit then + features:reset(); + features:tag("limits", { xmlns = "urn:xmpp:stream-limits:0" }) + :text_tag("max-size", string.format("%d", stanza_size_limit)):up(); + end send(features); else if session.secure then @@ -260,9 +264,17 @@ local function disconnect_user_sessions(reason, leave_resource) end module:hook_global("user-password-changed", disconnect_user_sessions({ condition = "reset", text = "Password changed" }, true), 200); -module:hook_global("user-roles-changed", disconnect_user_sessions({ condition = "reset", text = "Roles changed" }), 200); +module:hook_global("user-role-changed", disconnect_user_sessions({ condition = "reset", text = "Role changed" }), 200); module:hook_global("user-deleted", disconnect_user_sessions({ condition = "not-authorized", text = "Account deleted" }), 200); +module:hook_global("c2s-session-updated", function (event) + sessions[event.session.conn] = event.session; + local replaced_conn = event.replaced_conn; + if replaced_conn then + sessions[replaced_conn] = nil; + end +end); + function runner_callbacks:ready() if self.data.conn then self.data.conn:resume(); @@ -295,8 +307,7 @@ function listener.onconnect(conn) session.encrypted = true; -- Check if TLS compression is used - local sock = conn:socket(); - local info = sock.info and sock:info(); + local info = conn:ssl_info(); if type(info) == "table" then (session.log or log)("info", "Stream encrypted (%s with %s)", info.protocol, info.cipher); session.compressed = info.compression; diff --git a/plugins/mod_component.lua b/plugins/mod_component.lua index f57c4381..c1c29b5e 100644 --- a/plugins/mod_component.lua +++ b/plugins/mod_component.lua @@ -17,7 +17,7 @@ local logger = require "util.logger"; local sha1 = require "util.hashes".sha1; local st = require "util.stanza"; -local jid_split = require "util.jid".split; +local jid_host = require "util.jid".host; local new_xmpp_stream = require "util.xmppstream".new; local uuid_gen = require "util.uuid".generate; @@ -222,22 +222,19 @@ function stream_callbacks.handlestanza(session, stanza) end if not stanza.attr.xmlns or stanza.attr.xmlns == "jabber:client" then local from = stanza.attr.from; - if from then - if session.component_validate_from then - local _, domain = jid_split(stanza.attr.from); - if domain ~= session.host then - -- Return error - session.log("warn", "Component sent stanza with missing or invalid 'from' address"); - session:close{ - condition = "invalid-from"; - text = "Component tried to send from address <"..tostring(from) - .."> which is not in domain <"..tostring(session.host)..">"; - }; - return; - end + if session.component_validate_from then + if not from or (jid_host(from) ~= session.host) then + -- Return error + session.log("warn", "Component sent stanza with missing or invalid 'from' address"); + session:close{ + condition = "invalid-from"; + text = "Component tried to send from address <"..(from or "< [missing 'from' attribute] >") + .."> which is not in domain <"..tostring(session.host)..">"; + }; + return; end - else - stanza.attr.from = session.host; -- COMPAT: Strictly we shouldn't allow this + elseif not from then + stanza.attr.from = session.host; end if not stanza.attr.to then session.log("warn", "Rejecting stanza with no 'to' address"); diff --git a/plugins/mod_csi_simple.lua b/plugins/mod_csi_simple.lua index 569916b0..b9a470f5 100644 --- a/plugins/mod_csi_simple.lua +++ b/plugins/mod_csi_simple.lua @@ -116,6 +116,9 @@ local flush_reasons = module:metric( { "reason" } ); +local flush_sizes = module:metric("histogram", "flush_stanza_count", "", "Number of stanzas flushed at once", {}, + { buckets = { 0, 1, 2, 4, 8, 16, 32, 64, 128, 256 } }):with_labels(); + local function manage_buffer(stanza, session) local ctr = session.csi_counter or 0; if session.state ~= "inactive" then @@ -129,6 +132,7 @@ local function manage_buffer(stanza, session) session.csi_measure_buffer_hold = nil; end flush_reasons:with_labels(why or "important"):add(1); + flush_sizes:sample(ctr); session.log("debug", "Flushing buffer (%s; queue size is %d)", why or "important", session.csi_counter); session.state = "flushing"; module:fire_event("csi-flushing", { session = session }); @@ -147,6 +151,7 @@ local function flush_buffer(data, session) session.log("debug", "Flushing buffer (%s; queue size is %d)", "client activity", session.csi_counter); session.state = "flushing"; module:fire_event("csi-flushing", { session = session }); + flush_sizes:sample(ctr); flush_reasons:with_labels("client activity"):add(1); if session.csi_measure_buffer_hold then session.csi_measure_buffer_hold(); diff --git a/plugins/mod_debug_stanzas/watcher.lib.lua b/plugins/mod_debug_stanzas/watcher.lib.lua new file mode 100644 index 00000000..e21fc946 --- /dev/null +++ b/plugins/mod_debug_stanzas/watcher.lib.lua @@ -0,0 +1,220 @@ +local filters = require "util.filters"; +local jid = require "util.jid"; +local set = require "util.set"; + +local client_watchers = {}; + +-- active_filters[session] = { +-- filter_func = filter_func; +-- downstream = { cb1, cb2, ... }; +-- } +local active_filters = {}; + +local function subscribe_session_stanzas(session, handler, reason) + if active_filters[session] then + table.insert(active_filters[session].downstream, handler); + if reason then + handler(reason, nil, session); + end + return; + end + local downstream = { handler }; + active_filters[session] = { + filter_in = function (stanza) + module:log("debug", "NOTIFY WATCHER %d", #downstream); + for i = 1, #downstream do + downstream[i]("received", stanza, session); + end + return stanza; + end; + filter_out = function (stanza) + module:log("debug", "NOTIFY WATCHER %d", #downstream); + for i = 1, #downstream do + downstream[i]("sent", stanza, session); + end + return stanza; + end; + downstream = downstream; + }; + filters.add_filter(session, "stanzas/in", active_filters[session].filter_in); + filters.add_filter(session, "stanzas/out", active_filters[session].filter_out); + if reason then + handler(reason, nil, session); + end +end + +local function unsubscribe_session_stanzas(session, handler, reason) + local active_filter = active_filters[session]; + if not active_filter then + return; + end + for i = #active_filter.downstream, 1, -1 do + if active_filter.downstream[i] == handler then + table.remove(active_filter.downstream, i); + if reason then + handler(reason, nil, session); + end + end + end + if #active_filter.downstream == 0 then + filters.remove_filter(session, "stanzas/in", active_filter.filter_in); + filters.remove_filter(session, "stanzas/out", active_filter.filter_out); + end + active_filters[session] = nil; +end + +local function unsubscribe_all_from_session(session, reason) + local active_filter = active_filters[session]; + if not active_filter then + return; + end + for i = #active_filter.downstream, 1, -1 do + local handler = table.remove(active_filter.downstream, i); + if reason then + handler(reason, nil, session); + end + end + filters.remove_filter(session, "stanzas/in", active_filter.filter_in); + filters.remove_filter(session, "stanzas/out", active_filter.filter_out); + active_filters[session] = nil; +end + +local function unsubscribe_handler_from_all(handler, reason) + for session in pairs(active_filters) do + unsubscribe_session_stanzas(session, handler, reason); + end +end + +local s2s_watchers = {}; + +module:hook("s2sin-established", function (event) + for _, watcher in ipairs(s2s_watchers) do + if watcher.target_spec == event.session.from_host then + subscribe_session_stanzas(event.session, watcher.handler, "opened"); + end + end +end); + +module:hook("s2sout-established", function (event) + for _, watcher in ipairs(s2s_watchers) do + if watcher.target_spec == event.session.to_host then + subscribe_session_stanzas(event.session, watcher.handler, "opened"); + end + end +end); + +module:hook("s2s-closed", function (event) + unsubscribe_all_from_session(event.session, "closed"); +end); + +local watched_hosts = set.new(); + +local handler_map = setmetatable({}, { __mode = "kv" }); + +local function add_stanza_watcher(spec, orig_handler) + local function filtering_handler(event_type, stanza, session) + if stanza and spec.filter_spec then + if spec.filter_spec.with_jid then + if event_type == "sent" and (not stanza.attr.from or not jid.compare(stanza.attr.from, spec.filter_spec.with_jid)) then + return; + elseif event_type == "received" and (not stanza.attr.to or not jid.compare(stanza.attr.to, spec.filter_spec.with_jid)) then + return; + end + end + end + return orig_handler(event_type, stanza, session); + end + handler_map[orig_handler] = filtering_handler; + if spec.target_spec.jid then + local target_is_remote_host = not jid.node(spec.target_spec.jid) and not prosody.hosts[spec.target_spec.jid]; + + if target_is_remote_host then + -- Watch s2s sessions + table.insert(s2s_watchers, { + target_spec = spec.target_spec.jid; + handler = filtering_handler; + orig_handler = orig_handler; + }); + + -- Scan existing s2sin for matches + for session in pairs(prosody.incoming_s2s) do + if spec.target_spec.jid == session.from_host then + subscribe_session_stanzas(session, filtering_handler, "attached"); + end + end + -- Scan existing s2sout for matches + for local_host, local_session in pairs(prosody.hosts) do --luacheck: ignore 213/local_host + for remote_host, remote_session in pairs(local_session.s2sout) do + if spec.target_spec.jid == remote_host then + subscribe_session_stanzas(remote_session, filtering_handler, "attached"); + end + end + end + else + table.insert(client_watchers, { + target_spec = spec.target_spec.jid; + handler = filtering_handler; + orig_handler = orig_handler; + }); + local host = jid.host(spec.target_spec.jid); + if not watched_hosts:contains(host) and prosody.hosts[host] then + module:context(host):hook("resource-bind", function (event) + for _, watcher in ipairs(client_watchers) do + module:log("debug", "NEW CLIENT: %s vs %s", event.session.full_jid, watcher.target_spec); + if jid.compare(event.session.full_jid, watcher.target_spec) then + module:log("debug", "MATCH"); + subscribe_session_stanzas(event.session, watcher.handler, "opened"); + else + module:log("debug", "NO MATCH"); + end + end + end); + + module:context(host):hook("resource-unbind", function (event) + unsubscribe_all_from_session(event.session, "closed"); + end); + + watched_hosts:add(host); + end + for full_jid, session in pairs(prosody.full_sessions) do + if jid.compare(full_jid, spec.target_spec.jid) then + subscribe_session_stanzas(session, filtering_handler, "attached"); + end + end + end + else + error("No recognized target selector"); + end +end + +local function remove_stanza_watcher(orig_handler) + local handler = handler_map[orig_handler]; + unsubscribe_handler_from_all(handler, "detached"); + handler_map[orig_handler] = nil; + + for i = #client_watchers, 1, -1 do + if client_watchers[i].orig_handler == orig_handler then + table.remove(client_watchers, i); + end + end + + for i = #s2s_watchers, 1, -1 do + if s2s_watchers[i].orig_handler == orig_handler then + table.remove(s2s_watchers, i); + end + end +end + +local function cleanup(reason) + client_watchers = {}; + s2s_watchers = {}; + for session in pairs(active_filters) do + unsubscribe_all_from_session(session, reason or "cancelled"); + end +end + +return { + add = add_stanza_watcher; + remove = remove_stanza_watcher; + cleanup = cleanup; +}; diff --git a/plugins/mod_disco.lua b/plugins/mod_disco.lua index 79249c52..7b3e5caf 100644 --- a/plugins/mod_disco.lua +++ b/plugins/mod_disco.lua @@ -8,7 +8,6 @@ local get_children = require "core.hostmanager".get_children; local is_contact_subscribed = require "core.rostermanager".is_contact_subscribed; -local um_is_admin = require "core.usermanager".is_admin; local jid_split = require "util.jid".split; local jid_bare = require "util.jid".bare; local st = require "util.stanza" @@ -162,14 +161,16 @@ module:hook("s2s-stream-features", function (event) end end); +module:default_permission("prosody:admin", ":be-discovered-admin"); + -- Handle disco requests to user accounts if module:get_host_type() ~= "local" then return end -- skip for components module:hook("iq-get/bare/http://jabber.org/protocol/disco#info:query", function(event) local origin, stanza = event.origin, event.stanza; local node = stanza.tags[1].attr.node; local username = jid_split(stanza.attr.to) or origin.username; - local is_admin = um_is_admin(stanza.attr.to or origin.full_jid, module.host) - if not stanza.attr.to or (expose_admins and is_admin) or is_contact_subscribed(username, module.host, jid_bare(stanza.attr.from)) then + local target_is_admin = module:may(":be-discovered-admin", stanza.attr.to or origin.full_jid); + if not stanza.attr.to or (expose_admins and target_is_admin) or is_contact_subscribed(username, module.host, jid_bare(stanza.attr.from)) then if node and node ~= "" then local reply = st.reply(stanza):tag('query', {xmlns='http://jabber.org/protocol/disco#info', node=node}); if not reply.attr.from then reply.attr.from = origin.username.."@"..origin.host; end -- COMPAT To satisfy Psi when querying own account @@ -185,7 +186,7 @@ module:hook("iq-get/bare/http://jabber.org/protocol/disco#info:query", function( end local reply = st.reply(stanza):tag('query', {xmlns='http://jabber.org/protocol/disco#info'}); if not reply.attr.from then reply.attr.from = origin.username.."@"..origin.host; end -- COMPAT To satisfy Psi when querying own account - if is_admin then + if target_is_admin then reply:tag('identity', {category='account', type='admin'}):up(); elseif prosody.hosts[module.host].users.name == "anonymous" then reply:tag('identity', {category='account', type='anonymous'}):up(); diff --git a/plugins/mod_external_services.lua b/plugins/mod_external_services.lua index ae418fd8..6a76b922 100644 --- a/plugins/mod_external_services.lua +++ b/plugins/mod_external_services.lua @@ -16,7 +16,7 @@ local configured_services = module:get_option_array("external_services", {}); local access = module:get_option_set("external_service_access", {}); --- https://tools.ietf.org/html/draft-uberti-behave-turn-rest-00 +-- https://datatracker.ietf.org/doc/html/draft-uberti-behave-turn-rest-00 local function behave_turn_rest_credentials(srv, item, secret) local ttl = default_ttl; if type(item.ttl) == "number" then diff --git a/plugins/mod_http_file_share.lua b/plugins/mod_http_file_share.lua index b6200628..c45e7732 100644 --- a/plugins/mod_http_file_share.lua +++ b/plugins/mod_http_file_share.lua @@ -12,7 +12,6 @@ local jid = require "util.jid"; local st = require "util.stanza"; local url = require "socket.url"; local dm = require "core.storagemanager".olddm; -local jwt = require "util.jwt"; local errors = require "util.error"; local dataform = require "util.dataforms".new; local urlencode = require "util.http".urlencode; @@ -44,6 +43,8 @@ local expiry = module:get_option_number(module.name .. "_expires_after", 7 * 864 local daily_quota = module:get_option_number(module.name .. "_daily_quota", file_size_limit*10); -- 100 MB / day local total_storage_limit = module:get_option_number(module.name.."_global_quota", unlimited); +local create_jwt, verify_jwt = require "util.jwt".init("HS256", secret); + local access = module:get_option_set(module.name .. "_access", {}); if not external_base_url then @@ -169,16 +170,13 @@ function may_upload(uploader, filename, filesize, filetype) -- > boolean, error end function get_authz(slot, uploader, filename, filesize, filetype) -local now = os.time(); - return jwt.sign(secret, { + return create_jwt({ -- token properties sub = uploader; - iat = now; - exp = now+300; -- slot properties slot = slot; - expires = expiry >= 0 and (now+expiry) or nil; + expires = expiry >= 0 and (os.time()+expiry) or nil; -- file properties filename = filename; filesize = filesize; @@ -249,32 +247,34 @@ end function handle_upload(event, path) -- PUT /upload/:slot local request = event.request; - local authz = request.headers.authorization; - if authz then - authz = authz:match("^Bearer (.*)") - end - if not authz then - module:log("debug", "Missing or malformed Authorization header"); - event.response.headers.www_authenticate = "Bearer"; - return 401; - end - local authed, upload_info = jwt.verify(secret, authz); - if not (authed and type(upload_info) == "table" and type(upload_info.exp) == "number") then - module:log("debug", "Unauthorized or invalid token: %s, %q", authed, upload_info); - return 401; - end - if not request.body_sink and upload_info.exp < os.time() then - module:log("debug", "Authorization token expired on %s", dt.datetime(upload_info.exp)); - return 410; - end - if not path or upload_info.slot ~= path:match("^[^/]+") then - module:log("debug", "Invalid upload slot: %q, path: %q", upload_info.slot, path); - return 400; - end - if request.headers.content_length and tonumber(request.headers.content_length) ~= upload_info.filesize then - return 413; - -- Note: We don't know the size if the upload is streamed in chunked encoding, - -- so we also check the final file size on completion. + local upload_info = request.http_file_share_upload_info; + + if not upload_info then -- Initial handling of request + local authz = request.headers.authorization; + if authz then + authz = authz:match("^Bearer (.*)") + end + if not authz then + module:log("debug", "Missing or malformed Authorization header"); + event.response.headers.www_authenticate = "Bearer"; + return 401; + end + local authed, authed_upload_info = verify_jwt(authz); + if not authed then + module:log("debug", "Unauthorized or invalid token: %s, %q", authz, authed_upload_info); + return 401; + end + if not path or authed_upload_info.slot ~= path:match("^[^/]+") then + module:log("debug", "Invalid upload slot: %q, path: %q", authed_upload_info.slot, path); + return 400; + end + if request.headers.content_length and tonumber(request.headers.content_length) ~= authed_upload_info.filesize then + return 413; + -- Note: We don't know the size if the upload is streamed in chunked encoding, + -- so we also check the final file size on completion. + end + upload_info = authed_upload_info; + request.http_file_share_upload_info = upload_info; end local filename = get_filename(upload_info.slot, true); diff --git a/plugins/mod_invites_adhoc.lua b/plugins/mod_invites_adhoc.lua index bd6f0c2e..04c74461 100644 --- a/plugins/mod_invites_adhoc.lua +++ b/plugins/mod_invites_adhoc.lua @@ -2,7 +2,6 @@ local dataforms = require "util.dataforms"; local datetime = require "util.datetime"; local split_jid = require "util.jid".split; -local usermanager = require "core.usermanager"; local new_adhoc = module:require("adhoc").new; @@ -13,8 +12,7 @@ local allow_user_invites = module:get_option_boolean("allow_user_invites", false -- on the server, use the option above instead. local allow_contact_invites = module:get_option_boolean("allow_contact_invites", true); -local allow_user_invite_roles = module:get_option_set("allow_user_invites_by_roles"); -local deny_user_invite_roles = module:get_option_set("deny_user_invites_by_roles"); +module:default_permission(allow_user_invites and "prosody:user" or "prosody:admin", ":invite-users"); local invites; if prosody.shutdown then -- COMPAT hack to detect prosodyctl @@ -42,36 +40,8 @@ local invite_result_form = dataforms.new({ -- This is for checking if the specified JID may create invites -- that allow people to register accounts on this host. -local function may_invite_new_users(jid) - if usermanager.get_roles then - local user_roles = usermanager.get_roles(jid, module.host); - if not user_roles then - -- User has no roles we can check, just return default - return allow_user_invites; - end - - if user_roles["prosody:admin"] then - return true; - end - if allow_user_invite_roles then - for allowed_role in allow_user_invite_roles do - if user_roles[allowed_role] then - return true; - end - end - end - if deny_user_invite_roles then - for denied_role in deny_user_invite_roles do - if user_roles[denied_role] then - return false; - end - end - end - elseif usermanager.is_admin(jid, module.host) then -- COMPAT w/0.11 - return true; -- Admins may always create invitations - end - -- No role matches, so whatever the default is - return allow_user_invites; +local function may_invite_new_users(context) + return module:may(":invite-users", context); end module:depends("adhoc"); @@ -91,7 +61,7 @@ module:provides("adhoc", new_adhoc("Create new contact invite", "urn:xmpp:invite }; }; end - local invite = invites.create_contact(username, may_invite_new_users(data.from), { + local invite = invites.create_contact(username, may_invite_new_users(data), { source = data.from }); --TODO: check errors diff --git a/plugins/mod_mam/mod_mam.lua b/plugins/mod_mam/mod_mam.lua index bebee812..a726746c 100644 --- a/plugins/mod_mam/mod_mam.lua +++ b/plugins/mod_mam/mod_mam.lua @@ -34,9 +34,9 @@ local rm_load_roster = require "core.rostermanager".load_roster; local is_stanza = st.is_stanza; local tostring = tostring; -local time_now = os.time; +local time_now = require "util.time".now; local m_min = math.min; -local timestamp, datestamp = import( "util.datetime", "datetime", "date"); +local timestamp, datestamp = import("util.datetime", "datetime", "date"); local default_max_items, max_max_items = 20, module:get_option_number("max_archive_query_results", 50); local strip_tags = module:get_option_set("dont_archive_namespaces", { "http://jabber.org/protocol/chatstates" }); @@ -53,8 +53,12 @@ if not archive.find then end local use_total = module:get_option_boolean("mam_include_total", true); -function schedule_cleanup() - -- replaced later if cleanup is enabled +function schedule_cleanup(_username, _date) -- luacheck: ignore 212 + -- Called to make a note of which users have messages on which days, which in + -- turn is used to optimize the message expiry routine. + -- + -- This noop is conditionally replaced later depending on retention settings + -- and storage backend capabilities. end -- Handle prefs. diff --git a/plugins/mod_muc_mam.lua b/plugins/mod_muc_mam.lua index 1c34b8af..1655bc1d 100644 --- a/plugins/mod_muc_mam.lua +++ b/plugins/mod_muc_mam.lua @@ -29,7 +29,7 @@ local get_room_from_jid = mod_muc.get_room_from_jid; local is_stanza = st.is_stanza; local tostring = tostring; -local time_now = os.time; +local time_now = require "util.time".now; local m_min = math.min; local timestamp, datestamp = import("util.datetime", "datetime", "date"); local default_max_items, max_max_items = 20, module:get_option_number("max_archive_query_results", 50); diff --git a/plugins/mod_pep_simple.lua b/plugins/mod_pep_simple.lua index e686b99b..1314aece 100644 --- a/plugins/mod_pep_simple.lua +++ b/plugins/mod_pep_simple.lua @@ -14,7 +14,7 @@ local is_contact_subscribed = require "core.rostermanager".is_contact_subscribed local pairs = pairs; local next = next; local type = type; -local unpack = table.unpack or unpack; -- luacheck: ignore 113 +local unpack = table.unpack; local calculate_hash = require "util.caps".calculate_hash; local core_post_stanza = prosody.core_post_stanza; local bare_sessions = prosody.bare_sessions; diff --git a/plugins/mod_pubsub/mod_pubsub.lua b/plugins/mod_pubsub/mod_pubsub.lua index ef31f326..f51e8fe4 100644 --- a/plugins/mod_pubsub/mod_pubsub.lua +++ b/plugins/mod_pubsub/mod_pubsub.lua @@ -1,7 +1,6 @@ local pubsub = require "util.pubsub"; local st = require "util.stanza"; local jid_bare = require "util.jid".bare; -local usermanager = require "core.usermanager"; local new_id = require "util.id".medium; local storagemanager = require "core.storagemanager"; local xtemplate = require "util.xtemplate"; @@ -177,9 +176,10 @@ module:hook("host-disco-items", function (event) end); local admin_aff = module:get_option_string("default_admin_affiliation", "owner"); +module:default_permission("prosody:admin", ":service-admin"); local function get_affiliation(jid) local bare_jid = jid_bare(jid); - if bare_jid == module.host or usermanager.is_admin(bare_jid, module.host) then + if bare_jid == module.host or module:may(":service-admin", bare_jid) then return admin_aff; end end diff --git a/plugins/mod_pubsub/pubsub.lib.lua b/plugins/mod_pubsub/pubsub.lib.lua index 83cef808..cd3efb09 100644 --- a/plugins/mod_pubsub/pubsub.lib.lua +++ b/plugins/mod_pubsub/pubsub.lib.lua @@ -1,4 +1,4 @@ -local t_unpack = table.unpack or unpack; -- luacheck: ignore 113 +local t_unpack = table.unpack; local time_now = os.time; local jid_prep = require "util.jid".prep; @@ -678,8 +678,7 @@ end function handlers.set_retract(origin, stanza, retract, service) local node, notify = retract.attr.node, retract.attr.notify; notify = (notify == "1") or (notify == "true"); - local item = retract:get_child("item"); - local id = item and item.attr.id + local id = retract:get_child_attr("item", nil, "id"); if not (node and id) then origin.send(pubsub_error_reply(stanza, node and "item-not-found" or "nodeid-required")); return true; diff --git a/plugins/mod_s2s.lua b/plugins/mod_s2s.lua index 300a747e..6d1a76a0 100644 --- a/plugins/mod_s2s.lua +++ b/plugins/mod_s2s.lua @@ -146,17 +146,17 @@ local function bounce_sendq(session, reason) elseif type(reason) == "string" then reason_text = reason; end - for i, data in ipairs(sendq) do - local reply = data[2]; - if reply and not(reply.attr.xmlns) and bouncy_stanzas[reply.name] then - reply.attr.type = "error"; - reply:tag("error", {type = error_type, by = session.from_host}) - :tag(condition, {xmlns = "urn:ietf:params:xml:ns:xmpp-stanzas"}):up(); - if reason_text then - reply:tag("text", {xmlns = "urn:ietf:params:xml:ns:xmpp-stanzas"}) - :text("Server-to-server connection failed: "..reason_text):up(); - end + for i, stanza in ipairs(sendq) do + if not stanza.attr.xmlns and bouncy_stanzas[stanza.name] and stanza.attr.type ~= "error" and stanza.attr.type ~= "result" then + local reply = st.error_reply( + stanza, + error_type, + condition, + reason_text and ("Server-to-server connection failed: "..reason_text) or nil + ); core_process_stanza(dummy, reply); + else + (session.log or log)("debug", "Not eligible for bouncing, discarding %s", stanza:top_tag()); end sendq[i] = nil; end @@ -182,15 +182,11 @@ function route_to_existing_session(event) (host.log or log)("debug", "trying to send over unauthed s2sout to "..to_host); -- Queue stanza until we are able to send it - local queued_item = { - tostring(stanza), - stanza.attr.type ~= "error" and stanza.attr.type ~= "result" and st.reply(stanza); - }; if host.sendq then - t_insert(host.sendq, queued_item); + t_insert(host.sendq, st.clone(stanza)); else -- luacheck: ignore 122 - host.sendq = { queued_item }; + host.sendq = { st.clone(stanza) }; end host.log("debug", "stanza [%s] queued ", stanza.name); return true; @@ -215,7 +211,7 @@ function route_to_new_session(event) -- Store in buffer host_session.bounce_sendq = bounce_sendq; - host_session.sendq = { {tostring(stanza), stanza.attr.type ~= "error" and stanza.attr.type ~= "result" and st.reply(stanza)} }; + host_session.sendq = { st.clone(stanza) }; log("debug", "stanza [%s] queued until connection complete", stanza.name); -- FIXME Cleaner solution to passing extra data from resolvers to net.server -- This mt-clone allows resolvers to add extra data, currently used for DANE TLSA records @@ -251,9 +247,26 @@ function module.add_host(module) end module:hook("route/remote", route_to_existing_session, -1); module:hook("route/remote", route_to_new_session, -10); + module:hook("s2sout-stream-features", function (event) + if stanza_size_limit then + event.features:tag("limits", { xmlns = "urn:xmpp:stream-limits:0" }) + :text_tag("max-size", string.format("%d", stanza_size_limit)):up(); + end + end); + module:hook_tag("urn:xmpp:bidi", "bidi", function(session, stanza) + -- Advertising features on bidi connections where no <stream:features> is sent in the other direction + local limits = stanza:get_child("limits", "urn:xmpp:stream-limits:0"); + if limits then + session.outgoing_stanza_size_limit = tonumber(limits:get_child_text("max-size")); + end + end, 100); module:hook("s2s-authenticated", make_authenticated, -1); module:hook("s2s-read-timeout", keepalive, -1); module:hook_stanza("http://etherx.jabber.org/streams", "features", function (session, stanza) -- luacheck: ignore 212/stanza + local limits = stanza:get_child("limits", "urn:xmpp:stream-limits:0"); + if limits then + session.outgoing_stanza_size_limit = tonumber(limits:get_child_text("max-size")); + end if session.type == "s2sout" then -- Stream is authenticated and we are seem to be done with feature negotiation, -- so the stream is ready for stanzas. RFC 6120 Section 4.3 @@ -279,7 +292,7 @@ function module.add_host(module) function module.unload() if module.reloading then return end for _, session in pairs(sessions) do - if session.to_host == module.host or session.from_host == module.host then + if session.host == module.host then session:close("host-gone"); end end @@ -324,8 +337,8 @@ function mark_connected(session) if sendq then session.log("debug", "sending %d queued stanzas across new outgoing connection to %s", #sendq, session.to_host); local send = session.sends2s; - for i, data in ipairs(sendq) do - send(data[1]); + for i, stanza in ipairs(sendq) do + send(stanza); sendq[i] = nil; end session.sendq = nil; @@ -389,10 +402,10 @@ end --- Helper to check that a session peer's certificate is valid local function check_cert_status(session) local host = session.direction == "outgoing" and session.to_host or session.from_host - local conn = session.conn:socket() + local conn = session.conn local cert - if conn.getpeercertificate then - cert = conn:getpeercertificate() + if conn.ssl_peercertificate then + cert = conn:ssl_peercertificate() end return module:fire_event("s2s-check-certificate", { host = host, session = session, cert = cert }); @@ -404,8 +417,7 @@ local function session_secure(session) session.secure = true; session.encrypted = true; - local sock = session.conn:socket(); - local info = sock.info and sock:info(); + local info = session.conn:ssl_info(); if type(info) == "table" then (session.log or log)("info", "Stream encrypted (%s with %s)", info.protocol, info.cipher); session.compressed = info.compression; @@ -434,7 +446,8 @@ function stream_callbacks._streamopened(session, attr) session.had_stream = true; -- Had a stream opened at least once -- TODO: Rename session.secure to session.encrypted - if session.secure == false then + if session.secure == false then -- Set by mod_tls during STARTTLS handshake + session.starttls = "completed"; session_secure(session); end @@ -522,6 +535,12 @@ function stream_callbacks._streamopened(session, attr) end if ( session.type == "s2sin" or session.type == "s2sout" ) or features.tags[1] then + if stanza_size_limit then + features:reset(); + features:tag("limits", { xmlns = "urn:xmpp:stream-limits:0" }) + :text_tag("max-size", string.format("%d", stanza_size_limit)):up(); + end + log("debug", "Sending stream features: %s", features); session.sends2s(features); else @@ -756,6 +775,7 @@ local function initialize_session(session) local w = conn.write; if conn:ssl() then + -- Direct TLS was used session_secure(session); end @@ -766,6 +786,11 @@ local function initialize_session(session) end if t then t = filter("bytes/out", tostring(t)); + if session.outgoing_stanza_size_limit and #t > session.outgoing_stanza_size_limit then + log("warn", "Attempt to send a stanza exceeding session limit of %dB (%dB)!", session.outgoing_stanza_size_limit, #t); + -- TODO Pass identifiable error condition back to allow appropriate handling + return false + end if t then return w(conn, t); end @@ -935,6 +960,16 @@ local function friendly_cert_error(session) --> string elseif cert_errors:contains("self signed certificate") then return "is self-signed"; end + + local chain_errors = set.new(session.cert_chain_errors[2]); + for i, e in pairs(session.cert_chain_errors) do + if i > 2 then chain_errors:add_list(e); end + end + if chain_errors:contains("certificate has expired") then + return "has an expired certificate chain"; + elseif chain_errors:contains("No matching DANE TLSA records") then + return "does not match any DANE TLSA records"; + end end return "is not trusted"; -- for some other reason elseif session.cert_identity_status == "invalid" then diff --git a/plugins/mod_s2s_auth_certs.lua b/plugins/mod_s2s_auth_certs.lua index 992ee934..f917b116 100644 --- a/plugins/mod_s2s_auth_certs.lua +++ b/plugins/mod_s2s_auth_certs.lua @@ -9,17 +9,19 @@ local measure_cert_statuses = module:metric("counter", "checked", "", "Certifica module:hook("s2s-check-certificate", function(event) local session, host, cert = event.session, event.host, event.cert; - local conn = session.conn:socket(); + local conn = session.conn; local log = session.log or log; + local secure_hostname = conn.extra and conn.extra.secure_hostname; + if not cert then log("warn", "No certificate provided by %s", host or "unknown host"); return; end local chain_valid, errors; - if conn.getpeerverification then - chain_valid, errors = conn:getpeerverification(); + if conn.ssl_peerverification then + chain_valid, errors = conn:ssl_peerverification(); else chain_valid, errors = false, { { "Chain verification not supported by this version of LuaSec" } }; end @@ -45,6 +47,14 @@ module:hook("s2s-check-certificate", function(event) end log("debug", "certificate identity validation result: %s", session.cert_identity_status); end + + -- Check for DNSSEC-signed SRV hostname + if secure_hostname and session.cert_identity_status ~= "valid" then + if cert_verify_identity(secure_hostname, "xmpp-server", cert) then + module:log("info", "Secure SRV name delegation %q -> %q", secure_hostname, host); + session.cert_identity_status = "valid" + end + end end measure_cert_statuses:with_labels(session.cert_chain_status or "unknown", session.cert_identity_status or "unknown"):add(1); end, 509); diff --git a/plugins/mod_s2s_bidi.lua b/plugins/mod_s2s_bidi.lua index addcd6e2..d5cd7f9a 100644 --- a/plugins/mod_s2s_bidi.lua +++ b/plugins/mod_s2s_bidi.lua @@ -25,7 +25,9 @@ module:hook_tag("http://etherx.jabber.org/streams", "features", function (sessio if bidi then session.incoming = true; session.log("debug", "Requesting bidirectional stream"); - session.sends2s(st.stanza("bidi", { xmlns = xmlns_bidi })); + local request_bidi = st.stanza("bidi", { xmlns = xmlns_bidi }); + module:fire_event("s2sout-stream-features", { origin = session, features = request_bidi }); + session.sends2s(request_bidi); end end end, 200); diff --git a/plugins/mod_saslauth.lua b/plugins/mod_saslauth.lua index ab863aa3..4c8858cb 100644 --- a/plugins/mod_saslauth.lua +++ b/plugins/mod_saslauth.lua @@ -52,7 +52,7 @@ local function handle_status(session, status, ret, err_msg) module:fire_event("authentication-failure", { session = session, condition = ret, text = err_msg }); session.sasl_handler = session.sasl_handler:clean_clone(); elseif status == "success" then - local ok, err = sm_make_authenticated(session, session.sasl_handler.username, session.sasl_handler.scope); + local ok, err = sm_make_authenticated(session, session.sasl_handler.username, session.sasl_handler.role); if ok then module:fire_event("authentication-success", { session = session }); session.sasl_handler = nil; @@ -242,7 +242,16 @@ module:hook("stanza/urn:ietf:params:xml:ns:xmpp-sasl:abort", function(event) end); local function tls_unique(self) - return self.userdata["tls-unique"]:getpeerfinished(); + return self.userdata["tls-unique"]:ssl_peerfinished(); +end + +local function tls_exporter(conn) + if not conn.ssl_exportkeyingmaterial then return end + return conn:ssl_exportkeyingmaterial("EXPORTER-Channel-Binding", 32, ""); +end + +local function sasl_tls_exporter(self) + return tls_exporter(self.userdata["tls-exporter"]); end local mechanisms_attr = { xmlns='urn:ietf:params:xml:ns:xmpp-sasl' }; @@ -258,22 +267,29 @@ module:hook("stream-features", function(event) end local sasl_handler = usermanager_get_sasl_handler(module.host, origin) origin.sasl_handler = sasl_handler; + local channel_bindings = set.new() if origin.encrypted then -- check whether LuaSec has the nifty binding to the function needed for tls-unique -- FIXME: would be nice to have this check only once and not for every socket if sasl_handler.add_cb_handler then - local socket = origin.conn:socket(); - local info = socket.info and socket:info(); - if info.protocol == "TLSv1.3" then + local info = origin.conn:ssl_info(); + if info and info.protocol == "TLSv1.3" then log("debug", "Channel binding 'tls-unique' undefined in context of TLS 1.3"); - elseif socket.getpeerfinished and socket:getpeerfinished() then + if tls_exporter(origin.conn) then + log("debug", "Channel binding 'tls-exporter' supported"); + sasl_handler:add_cb_handler("tls-exporter", sasl_tls_exporter); + channel_bindings:add("tls-exporter"); + end + elseif origin.conn.ssl_peerfinished and origin.conn:ssl_peerfinished() then log("debug", "Channel binding 'tls-unique' supported"); sasl_handler:add_cb_handler("tls-unique", tls_unique); + channel_bindings:add("tls-unique"); else log("debug", "Channel binding 'tls-unique' not supported (by LuaSec?)"); end sasl_handler["userdata"] = { - ["tls-unique"] = socket; + ["tls-unique"] = origin.conn; + ["tls-exporter"] = origin.conn; }; else log("debug", "Channel binding not supported by SASL handler"); @@ -306,6 +322,14 @@ module:hook("stream-features", function(event) mechanisms:tag("mechanism"):text(mechanism):up(); end features:add_child(mechanisms); + if not channel_bindings:empty() then + -- XXX XEP-0440 is Experimental + features:tag("sasl-channel-binding", {xmlns='urn:xmpp:sasl-cb:0'}) + for channel_binding in channel_bindings do + features:tag("channel-binding", {type=channel_binding}):up() + end + features:up(); + end return; end @@ -328,7 +352,7 @@ module:hook("stream-features", function(event) authmod, available_disabled); end - else + elseif not origin.full_jid then features:tag("bind", bind_attr):tag("required"):up():up(); features:tag("session", xmpp_session_attr):tag("optional"):up():up(); end diff --git a/plugins/mod_smacks.lua b/plugins/mod_smacks.lua index a4ff0444..c5ff6644 100644 --- a/plugins/mod_smacks.lua +++ b/plugins/mod_smacks.lua @@ -2,7 +2,7 @@ -- -- Copyright (C) 2010-2015 Matthew Wild -- Copyright (C) 2010 Waqas Hussain --- Copyright (C) 2012-2021 Kim Alvefur +-- Copyright (C) 2012-2022 Kim Alvefur -- Copyright (C) 2012 Thijs Alkemade -- Copyright (C) 2014 Florian Zeitz -- Copyright (C) 2016-2020 Thilo Molitor @@ -10,6 +10,7 @@ -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- +-- TODO unify sendq and smqueue local tonumber = tonumber; local tostring = tostring; @@ -83,6 +84,26 @@ local all_old_sessions = module:open_store("smacks_h"); local old_session_registry = module:open_store("smacks_h", "map"); local session_registry = module:shared "/*/smacks/resumption-tokens"; -- > user@host/resumption-token --> resource +local function registry_key(session, id) + return jid.join(session.username, session.host, id or session.resumption_token); +end + +local function track_session(session, id) + session_registry[registry_key(session, id)] = session; + session.resumption_token = id; +end + +local function save_old_session(session) + session_registry[registry_key(session)] = nil; + return old_session_registry:set(session.username, session.resumption_token, + { h = session.handled_stanza_count; t = os.time() }) +end + +local function clear_old_session(session, id) + session_registry[registry_key(session, id)] = nil; + return old_session_registry:set(session.username, id or session.resumption_token, nil) +end + local ack_errors = require"util.error".init("mod_smacks", xmlns_sm3, { head = { condition = "undefined-condition"; text = "Client acknowledged more stanzas than sent by server" }; tail = { condition = "undefined-condition"; text = "Client acknowledged less stanzas than already acknowledged" }; @@ -90,6 +111,16 @@ local ack_errors = require"util.error".init("mod_smacks", xmlns_sm3, { overflow = { condition = "resource-constraint", text = "Too many unacked stanzas remaining, session can't be resumed" } }); +local enable_errors = require "util.error".init("mod_smacks", xmlns_sm3, { + already_enabled = { condition = "unexpected-request", text = "Stream management is already enabled" }; + bind_required = { condition = "unexpected-request", text = "Client must bind a resource before enabling stream management" }; + unavailable = { condition = "service-unavailable", text = "Stream management is not available for this stream" }; + -- Resumption + expired = { condition = "item-not-found", text = "Session expired, and cannot be resumed" }; + already_bound = { condition = "unexpected-request", text = "Cannot resume another session after a resource is bound" }; + unknown_session = { condition = "item-not-found", text = "Unknown session" }; +}); + -- COMPAT note the use of compatibility wrapper in events (queue:table()) local function ack_delayed(session, stanza) @@ -104,18 +135,18 @@ local function ack_delayed(session, stanza) end local function can_do_smacks(session, advertise_only) - if session.smacks then return false, "unexpected-request", "Stream management is already enabled"; end + if session.smacks then return false, enable_errors.new("already_enabled"); end local session_type = session.type; if session.username then if not(advertise_only) and not(session.resource) then -- Fail unless we're only advertising sm - return false, "unexpected-request", "Client must bind a resource before enabling stream management"; + return false, enable_errors.new("bind_required"); end return true; elseif s2s_smacks and (session_type == "s2sin" or session_type == "s2sout") then return true; end - return false, "service-unavailable", "Stream management is not available for this stream"; + return false, enable_errors.new("unavailable"); end module:hook("stream-features", @@ -155,13 +186,12 @@ end local function request_ack(session, reason) local queue = session.outgoing_stanza_queue; - session.log("debug", "Sending <r> (inside timer, before send) from %s - #queue=%d", reason, queue:count_unacked()); + session.log("debug", "Sending <r> from %s - #queue=%d", reason, queue:count_unacked()); session.awaiting_ack = true; (session.sends2s or session.send)(st.stanza("r", { xmlns = session.smacks })) if session.destroyed then return end -- sending something can trigger destruction -- expected_h could be lower than this expression e.g. more stanzas added to the queue meanwhile) session.last_requested_h = queue:count_acked() + queue:count_unacked(); - session.log("debug", "Sending <r> (inside timer, after send) from %s - #queue=%d", reason, queue:count_unacked()); if not session.delayed_ack_timer then session.delayed_ack_timer = timer.add_task(delayed_ack_timeout, function() ack_delayed(session, nil); -- we don't know if this is the only new stanza in the queue @@ -180,7 +210,6 @@ local function outgoing_stanza_filter(stanza, session) -- supposed to be nil. -- However, when using mod_smacks with mod_websocket, then mod_websocket's -- stanzas/out filter can get called before this one and adds the xmlns. - if session.resending_unacked then return stanza end if not session.smacks then return stanza end local is_stanza = st.is_stanza(stanza) and (not stanza.attr.xmlns or stanza.attr.xmlns == 'jabber:client') @@ -234,8 +263,7 @@ module:hook("pre-session-close", function(event) if session.smacks == nil then return end if session.resumption_token then session.log("debug", "Revoking resumption token"); - session_registry[jid.join(session.username, session.host, session.resumption_token)] = nil; - old_session_registry:set(session.username, session.resumption_token, nil); + clear_old_session(session); session.resumption_token = nil; else session.log("debug", "Session not resumable"); @@ -274,17 +302,16 @@ local function wrap_session(session, resume) return session; end -function handle_enable(session, stanza, xmlns_sm) - local ok, err, err_text = can_do_smacks(session); +function do_enable(session, stanza) + local ok, err = can_do_smacks(session); if not ok then - session.log("warn", "Failed to enable smacks: %s", err_text); -- TODO: XEP doesn't say we can send error text, should it? - (session.sends2s or session.send)(st.stanza("failed", { xmlns = xmlns_sm }):tag(err, { xmlns = xmlns_errors})); - return true; + session.log("warn", "Failed to enable smacks: %s", err.text); -- TODO: XEP doesn't say we can send error text, should it? + return nil, err; end if session.username then local old_sessions, err = all_old_sessions:get(session.username); - module:log("debug", "Old sessions: %q", old_sessions) + session.log("debug", "Old sessions: %q", old_sessions) if old_sessions then local keep, count = {}, 0; for token, info in it.sorted_pairs(old_sessions, function(a, b) @@ -296,54 +323,73 @@ function handle_enable(session, stanza, xmlns_sm) end all_old_sessions:set(session.username, keep); elseif err then - module:log("error", "Unable to retrieve old resumption counters: %s", err); + session.log("error", "Unable to retrieve old resumption counters: %s", err); end end - module:log("debug", "Enabling stream management"); - session.smacks = xmlns_sm; - - wrap_session(session, false); - - local resume_max; local resume_token; local resume = stanza.attr.resume; if (resume == "true" or resume == "1") and session.username then -- resumption on s2s is not currently supported resume_token = new_id(); - session_registry[jid.join(session.username, session.host, resume_token)] = session; - session.resumption_token = resume_token; - resume_max = tostring(resume_timeout); end - (session.sends2s or session.send)(st.stanza("enabled", { xmlns = xmlns_sm, id = resume_token, resume = resume, max = resume_max })); + + return { + type = "enabled"; + id = resume_token; + resume_max = resume_token and tostring(resume_timeout) or nil; + session = session; + finish = function () + session.log("debug", "Enabling stream management"); + + session.smacks = stanza.attr.xmlns; + if resume_token then + track_session(session, resume_token); + end + wrap_session(session, false); + end; + }; +end + +function handle_enable(session, stanza, xmlns_sm) + local enabled, err = do_enable(session, stanza); + if not enabled then + (session.sends2s or session.send)(st.stanza("failed", { xmlns = xmlns_sm }):add_error(err)); + return true; + end + + (session.sends2s or session.send)(st.stanza("enabled", { + xmlns = xmlns_sm; + id = enabled.id; + resume = enabled.id and "true" or nil; -- COMPAT w/ Conversations 2.10.10 requires 'true' not '1' + max = enabled.resume_max; + })); + + session.smacks = xmlns_sm; + enabled.finish(); + return true; end module:hook_tag(xmlns_sm2, "enable", function (session, stanza) return handle_enable(session, stanza, xmlns_sm2); end, 100); module:hook_tag(xmlns_sm3, "enable", function (session, stanza) return handle_enable(session, stanza, xmlns_sm3); end, 100); -module:hook_tag("http://etherx.jabber.org/streams", "features", - function (session, stanza) - -- Needs to be done after flushing sendq since those aren't stored as - -- stanzas and counting them is weird. - -- TODO unify sendq and smqueue - timer.add_task(1e-6, function () - if can_do_smacks(session) then - if stanza:get_child("sm", xmlns_sm3) then - session.sends2s(st.stanza("enable", sm3_attr)); - session.smacks = xmlns_sm3; - elseif stanza:get_child("sm", xmlns_sm2) then - session.sends2s(st.stanza("enable", sm2_attr)); - session.smacks = xmlns_sm2; - else - return; - end - wrap_session_out(session, false); - end - end); - end); +module:hook_tag("http://etherx.jabber.org/streams", "features", function(session, stanza) + if can_do_smacks(session) then + session.smacks_feature = stanza:get_child("sm", xmlns_sm3) or stanza:get_child("sm", xmlns_sm2); + end +end); + +module:hook("s2sout-established", function (event) + local session = event.session; + if not session.smacks_feature then return end + + session.smacks = session.smacks_feature.attr.xmlns; + wrap_session_out(session, false); + session.sends2s(st.stanza("enable", { xmlns = session.smacks })); +end); function handle_enabled(session, stanza, xmlns_sm) -- luacheck: ignore 212/stanza - module:log("debug", "Enabling stream management"); + session.log("debug", "Enabling stream management"); session.smacks = xmlns_sm; wrap_session_in(session, false); @@ -357,10 +403,10 @@ module:hook_tag(xmlns_sm3, "enabled", function (session, stanza) return handle_e function handle_r(origin, stanza, xmlns_sm) -- luacheck: ignore 212/stanza if not origin.smacks then - module:log("debug", "Received ack request from non-smack-enabled session"); + origin.log("debug", "Received ack request from non-smack-enabled session"); return; end - module:log("debug", "Received ack request, acking for %d", origin.handled_stanza_count); + origin.log("debug", "Received ack request, acking for %d", origin.handled_stanza_count); -- Reply with <a> (origin.sends2s or origin.send)(st.stanza("a", { xmlns = xmlns_sm, h = format_h(origin.handled_stanza_count) })); -- piggyback our own ack request if needed (see request_ack_if_needed() for explanation of last_requested_h) @@ -413,13 +459,14 @@ local function handle_unacked_stanzas(session) local queue = session.outgoing_stanza_queue; local unacked = queue:count_unacked() if unacked > 0 then + local error_from = jid.join(session.username, session.host or module.host); tx_dropped_stanzas:sample(unacked); session.smacks = false; -- Disable queueing session.outgoing_stanza_queue = nil; for stanza in queue._queue:consume() do if not module:fire_event("delivery/failure", { session = session, stanza = stanza }) then if stanza.attr.type ~= "error" and stanza.attr.from ~= session.full_jid then - local reply = st.error_reply(stanza, "cancel", "recipient-unavailable"); + local reply = st.error_reply(stanza, "cancel", "recipient-unavailable", nil, error_from); module:send(reply); end end @@ -486,11 +533,8 @@ module:hook("pre-resource-unbind", function (event) end session.log("debug", "Destroying session for hibernating too long"); - session_registry[jid.join(session.username, session.host, session.resumption_token)] = nil; - old_session_registry:set(session.username, session.resumption_token, - { h = session.handled_stanza_count; t = os.time() }); + save_old_session(session); session.resumption_token = nil; - session.resending_unacked = true; -- stop outgoing_stanza_filter from re-queueing anything anymore sessionmanager.destroy_session(session, "Hibernating too long"); sessions_expired(1); end); @@ -524,131 +568,110 @@ end module:hook("s2sout-destroyed", handle_s2s_destroyed); module:hook("s2sin-destroyed", handle_s2s_destroyed); -local function get_session_id(session) - return session.id or (tostring(session):match("[a-f0-9]+$")); -end - -function handle_resume(session, stanza, xmlns_sm) +function do_resume(session, stanza) if session.full_jid then session.log("warn", "Tried to resume after resource binding"); - session.send(st.stanza("failed", { xmlns = xmlns_sm }) - :tag("unexpected-request", { xmlns = xmlns_errors }) - ); - return true; + return nil, enable_errors.new("already_bound"); end local id = stanza.attr.previd; - local original_session = session_registry[jid.join(session.username, session.host, id)]; + local original_session = session_registry[registry_key(session, id)]; if not original_session then local old_session = old_session_registry:get(session.username, id); if old_session then session.log("debug", "Tried to resume old expired session with id %s", id); - session.send(st.stanza("failed", { xmlns = xmlns_sm, h = format_h(old_session.h) }) - :tag("item-not-found", { xmlns = xmlns_errors }) - ); - old_session_registry:set(session.username, id, nil); + clear_old_session(session, id); resumption_expired(1); - else - session.log("debug", "Tried to resume non-existent session with id %s", id); - session.send(st.stanza("failed", { xmlns = xmlns_sm }) - :tag("item-not-found", { xmlns = xmlns_errors }) - ); - end; - else - if original_session.hibernating_watchdog then - original_session.log("debug", "Letting the watchdog go"); - original_session.hibernating_watchdog:cancel(); - original_session.hibernating_watchdog = nil; - elseif session.hibernating then - original_session.log("error", "Hibernating session has no watchdog!") - end - -- zero age = was not hibernating yet - local age = 0; - if original_session.hibernating then - local now = os_time(); - age = now - original_session.hibernating; - end - session.log("debug", "mod_smacks resuming existing session %s...", get_session_id(original_session)); - original_session.log("debug", "mod_smacks session resumed from %s...", get_session_id(session)); - -- TODO: All this should move to sessionmanager (e.g. session:replace(new_session)) - if original_session.conn then - original_session.log("debug", "mod_smacks closing an old connection for this session"); - local conn = original_session.conn; - c2s_sessions[conn] = nil; - conn:close(); + return nil, enable_errors.new("expired", { h = old_session.h }); end + session.log("debug", "Tried to resume non-existent session with id %s", id); + return nil, enable_errors.new("unknown_session"); + end - local migrated_session_log = session.log; - original_session.ip = session.ip; - original_session.conn = session.conn; - original_session.rawsend = session.rawsend; - original_session.rawsend.session = original_session; - original_session.rawsend.conn = original_session.conn; - original_session.send = session.send; - original_session.send.session = original_session; - original_session.close = session.close; - original_session.filter = session.filter; - original_session.filter.session = original_session; - original_session.filters = session.filters; - original_session.send.filter = original_session.filter; - original_session.stream = session.stream; - original_session.secure = session.secure; - original_session.hibernating = nil; - original_session.resumption_counter = (original_session.resumption_counter or 0) + 1; - session.log = original_session.log; - session.type = original_session.type; - wrap_session(original_session, true); - -- Inform xmppstream of the new session (passed to its callbacks) - original_session.stream:set_session(original_session); - -- Similar for connlisteners - c2s_sessions[session.conn] = original_session; - - local queue = original_session.outgoing_stanza_queue; - local h = tonumber(stanza.attr.h); - - original_session.log("debug", "Pre-resumption #queue = %d", queue:count_unacked()) - local acked, err = ack_errors.coerce(queue:ack(h)); -- luacheck: ignore 211/acked - - if not err and not queue:resumable() then - err = ack_errors.new("overflow"); - end + if original_session.hibernating_watchdog then + original_session.log("debug", "Letting the watchdog go"); + original_session.hibernating_watchdog:cancel(); + original_session.hibernating_watchdog = nil; + elseif session.hibernating then + original_session.log("error", "Hibernating session has no watchdog!") + end + -- zero age = was not hibernating yet + local age = 0; + if original_session.hibernating then + local now = os_time(); + age = now - original_session.hibernating; + end - if err or not queue:resumable() then - original_session.send(st.stanza("failed", - { xmlns = xmlns_sm; h = format_h(original_session.handled_stanza_count); previd = id })); - original_session:close(err); - return false; - end + session.log("debug", "mod_smacks resuming existing session %s...", original_session.id); - original_session.send(st.stanza("resumed", { xmlns = xmlns_sm, - h = format_h(original_session.handled_stanza_count), previd = id })); + local queue = original_session.outgoing_stanza_queue; + local h = tonumber(stanza.attr.h); - -- Ok, we need to re-send any stanzas that the client didn't see - -- ...they are what is now left in the outgoing stanza queue - -- We have to use the send of "session" because we don't want to add our resent stanzas - -- to the outgoing queue again + original_session.log("debug", "Pre-resumption #queue = %d", queue:count_unacked()) + local acked, err = ack_errors.coerce(queue:ack(h)); -- luacheck: ignore 211/acked - session.log("debug", "resending all unacked stanzas that are still queued after resume, #queue = %d", queue:count_unacked()); - -- FIXME Which session is it that the queue filter sees? - session.resending_unacked = true; - original_session.resending_unacked = true; - for _, queued_stanza in queue:resume() do - session.send(queued_stanza); - end - session.resending_unacked = nil; - original_session.resending_unacked = nil; - session.log("debug", "all stanzas resent, now disabling send() in this migrated session, #queue = %d", queue:count_unacked()); - function session.send(stanza) -- luacheck: ignore 432 - migrated_session_log("error", "Tried to send stanza on old session migrated by smacks resume (maybe there is a bug?): %s", tostring(stanza)); - return false; - end - module:fire_event("smacks-hibernation-end", {origin = session, resumed = original_session, queue = queue:table()}); - original_session.awaiting_ack = nil; -- Don't wait for acks from before the resumption - request_ack_now_if_needed(original_session, true, "handle_resume", nil); - resumption_age:sample(age); + if not err and not queue:resumable() then + err = ack_errors.new("overflow"); end + + if err then + session.log("debug", "Resumption failed: %s", err); + return nil, err; + end + + -- Update original_session with the parameters (connection, etc.) from the new session + sessionmanager.update_session(original_session, session); + + return { + type = "resumed"; + session = original_session; + id = id; + -- Return function to complete the resumption and resync unacked stanzas + -- This is two steps so we can support SASL2/ISR + finish = function () + -- Ok, we need to re-send any stanzas that the client didn't see + -- ...they are what is now left in the outgoing stanza queue + -- We have to use the send of "session" because we don't want to add our resent stanzas + -- to the outgoing queue again + + original_session.log("debug", "resending all unacked stanzas that are still queued after resume, #queue = %d", queue:count_unacked()); + for _, queued_stanza in queue:resume() do + original_session.send(queued_stanza); + end + original_session.log("debug", "all stanzas resent, enabling stream management on resumed stream, #queue = %d", queue:count_unacked()); + + -- Add our own handlers to the resumed session (filters have been reset in the update) + wrap_session(original_session, true); + + -- Let everyone know that we are no longer hibernating + module:fire_event("smacks-hibernation-end", {origin = session, resumed = original_session, queue = queue:table()}); + original_session.awaiting_ack = nil; -- Don't wait for acks from before the resumption + request_ack_now_if_needed(original_session, true, "handle_resume", nil); + resumption_age:sample(age); + end; + }; +end + +function handle_resume(session, stanza, xmlns_sm) + local resumed, err = do_resume(session, stanza); + if not resumed then + session.send(st.stanza("failed", { xmlns = xmlns_sm, h = format_h(err.context.h) }) + :tag(err.condition, { xmlns = xmlns_errors })); + return true; + end + + session = resumed.session; + + -- Inform client of successful resumption + session.send(st.stanza("resumed", { xmlns = xmlns_sm, + h = format_h(session.handled_stanza_count), previd = resumed.id })); + + -- Complete resume (sync stanzas, etc.) + resumed.finish(); + return true; end + module:hook_tag(xmlns_sm2, "resume", function (session, stanza) return handle_resume(session, stanza, xmlns_sm2); end); module:hook_tag(xmlns_sm3, "resume", function (session, stanza) return handle_resume(session, stanza, xmlns_sm3); end); @@ -703,8 +726,7 @@ module:hook_global("server-stopping", function(event) for _, user in pairs(local_sessions) do for _, session in pairs(user.sessions) do if session.resumption_token then - if old_session_registry:set(session.username, session.resumption_token, - { h = session.handled_stanza_count; t = os.time() }) then + if save_old_session(session) then session.resumption_token = nil; -- Deal with unacked stanzas diff --git a/plugins/mod_storage_sql.lua b/plugins/mod_storage_sql.lua index b3ed7638..86a64ff6 100644 --- a/plugins/mod_storage_sql.lua +++ b/plugins/mod_storage_sql.lua @@ -1,9 +1,11 @@ -- luacheck: ignore 212/self +local deps = require "util.dependencies"; local cache = require "util.cache"; local json = require "util.json"; -local sql = require "util.sql"; +local sqlite = deps.softreq "util.sqlite3"; +local dbisql = (sqlite and deps.softreq or require) "util.sql"; local xml_parse = require "util.xml".parse; local uuid = require "util.uuid"; local resolve_relative_path = require "util.paths".resolve_relative_path; @@ -13,7 +15,7 @@ local is_stanza = require"util.stanza".is_stanza; local t_concat = table.concat; local noop = function() end -local unpack = table.unpack or unpack; -- luacheck: ignore 113 +local unpack = table.unpack; local function iterator(result) return function(result_) local row = result_(); @@ -321,7 +323,8 @@ function archive_store:append(username, key, value, when, with) end end - when = when or os.time(); + -- FIXME update the schema to allow precision timestamps + when = when and math.floor(when) or os.time(); with = with or ""; local ok, ret = engine:transaction(function() local delete_sql = [[ @@ -354,12 +357,12 @@ end local function archive_where(query, args, where) -- Time range, inclusive if query.start then - args[#args+1] = query.start + args[#args+1] = math.floor(query.start); where[#where+1] = "\"when\" >= ?" end if query["end"] then - args[#args+1] = query["end"]; + args[#args+1] = math.floor(query["end"]); if query.start then where[#where] = "\"when\" BETWEEN ? AND ?" -- is this inclusive? else @@ -382,8 +385,7 @@ local function archive_where(query, args, where) -- Set of ids if query.ids then local nids, nargs = #query.ids, #args; - -- COMPAT Lua 5.1: No separator argument to string.rep - where[#where + 1] = "\"key\" IN (" .. string.rep("?,", nids):sub(1,-2) .. ")"; + where[#where + 1] = "\"key\" IN (" .. string.rep("?", nids, ",") .. ")"; for i, id in ipairs(query.ids) do args[nargs+i] = id; end @@ -692,6 +694,7 @@ end local function create_table(engine) -- luacheck: ignore 431/engine + local sql = engine.params.driver == "SQLite3" and sqlite or dbisql; local Table, Column, Index = sql.Table, sql.Column, sql.Index; local ProsodyTable = Table { @@ -732,6 +735,7 @@ end local function upgrade_table(engine, params, apply_changes) -- luacheck: ignore 431/engine local changes = false; if params.driver == "MySQL" then + local sql = dbisql; local success,err = engine:transaction(function() do local result = assert(engine:execute("SHOW COLUMNS FROM \"prosody\" WHERE \"Field\"='value' and \"Type\"='text'")); @@ -831,6 +835,7 @@ end function module.load() local engines = module:shared("/*/sql/connections"); local params = normalize_params(module:get_option("sql", default_params)); + local sql = params.driver == "SQLite3" and sqlite or dbisql; local db_uri = sql.db2uri(params); engine = engines[db_uri]; if not engine then @@ -853,8 +858,13 @@ function module.load() end end end + module:set_status("info", "Connected to " .. engine.params.driver); + end, function (engine) -- luacheck: ignore 431/engine + module:set_status("error", "Disconnected from " .. engine.params.driver); end); engines[sql.db2uri(params)] = engine; + else + module:set_status("info", "Using existing engine"); end module:provides("storage", driver); @@ -869,6 +879,7 @@ function module.command(arg) local uris = {}; for host in pairs(prosody.hosts) do -- luacheck: ignore 431/host local params = normalize_params(config.get(host, "sql") or default_params); + local sql = engine.params.driver == "SQLite3" and sqlite or dbisql; uris[sql.db2uri(params)] = params; end print("We will check and upgrade the following databases:\n"); @@ -884,6 +895,7 @@ function module.command(arg) -- Upgrade each one for _, params in pairs(uris) do print("Checking "..params.database.."..."); + local sql = params.driver == "SQLite3" and sqlite or dbisql; engine = sql:create_engine(params); upgrade_table(engine, params, true); end diff --git a/plugins/mod_storage_xep0227.lua b/plugins/mod_storage_xep0227.lua index 5c3cf7f6..079e49d8 100644 --- a/plugins/mod_storage_xep0227.lua +++ b/plugins/mod_storage_xep0227.lua @@ -2,7 +2,7 @@ local ipairs, pairs = ipairs, pairs; local setmetatable = setmetatable; local tostring = tostring; -local next, unpack = next, table.unpack or unpack; --luacheck: ignore 113/unpack +local next, unpack = next, table.unpack; local os_remove = os.remove; local io_open = io.open; local jid_bare = require "util.jid".bare; diff --git a/plugins/mod_time.lua b/plugins/mod_time.lua index 0cd5a4ea..c9197799 100644 --- a/plugins/mod_time.lua +++ b/plugins/mod_time.lua @@ -8,7 +8,7 @@ local st = require "util.stanza"; local datetime = require "util.datetime".datetime; -local legacy = require "util.datetime".legacy; +local now = require "util.time".now; -- XEP-0202: Entity Time @@ -18,23 +18,10 @@ local function time_handler(event) local origin, stanza = event.origin, event.stanza; origin.send(st.reply(stanza):tag("time", {xmlns="urn:xmpp:time"}) :tag("tzo"):text("+00:00"):up() -- TODO get the timezone in a platform independent fashion - :tag("utc"):text(datetime())); + :tag("utc"):text(datetime(now()))); return true; end module:hook("iq-get/bare/urn:xmpp:time:time", time_handler); module:hook("iq-get/host/urn:xmpp:time:time", time_handler); --- XEP-0090: Entity Time (deprecated) - -module:add_feature("jabber:iq:time"); - -local function legacy_time_handler(event) - local origin, stanza = event.origin, event.stanza; - origin.send(st.reply(stanza):tag("query", {xmlns="jabber:iq:time"}) - :tag("utc"):text(legacy())); - return true; -end - -module:hook("iq-get/bare/jabber:iq:time:query", legacy_time_handler); -module:hook("iq-get/host/jabber:iq:time:query", legacy_time_handler); diff --git a/plugins/mod_tls.lua b/plugins/mod_tls.lua index afc1653a..380effe3 100644 --- a/plugins/mod_tls.lua +++ b/plugins/mod_tls.lua @@ -80,6 +80,9 @@ end module:hook_global("config-reloaded", module.load); local function can_do_tls(session) + if session.secure then + return false; + end if session.conn and not session.conn.starttls then if not session.secure then session.log("debug", "Underlying connection does not support STARTTLS"); @@ -125,7 +128,15 @@ end); -- Hook <starttls/> module:hook("stanza/urn:ietf:params:xml:ns:xmpp-tls:starttls", function(event) local origin = event.origin; + origin.starttls = "requested"; if can_do_tls(origin) then + if origin.conn.block_reads then + -- we need to ensure that no data is read anymore, otherwise we could end up in a situation where + -- <proceed/> is sent and the socket receives the TLS handshake (and passes the data to lua) before + -- it is asked to initiate TLS + -- (not with the classical single-threaded server backends) + origin.conn:block_reads() + end (origin.sends2s or origin.send)(starttls_proceed); if origin.destroyed then return end origin:reset_stream(); @@ -166,6 +177,7 @@ module:hook_tag("http://etherx.jabber.org/streams", "features", function (sessio module:log("debug", "%s is not offering TLS", session.to_host); return; end + session.starttls = "initiated"; session.sends2s(starttls_initiate); return true; end @@ -183,7 +195,8 @@ module:hook_tag(xmlns_starttls, "proceed", function (session, stanza) -- luachec if session.type == "s2sout_unauthed" and can_do_tls(session) then module:log("debug", "Proceeding with TLS on s2sout..."); session:reset_stream(); - session.conn:starttls(session.ssl_ctx); + session.starttls = "proceeding" + session.conn:starttls(session.ssl_ctx, session.to_host); session.secure = false; return true; end diff --git a/plugins/mod_tokenauth.lua b/plugins/mod_tokenauth.lua index c04a1aa4..9cd73570 100644 --- a/plugins/mod_tokenauth.lua +++ b/plugins/mod_tokenauth.lua @@ -1,10 +1,19 @@ local id = require "util.id"; local jid = require "util.jid"; local base64 = require "util.encodings".base64; +local usermanager = require "core.usermanager"; +local generate_identifier = require "util.id".short; local token_store = module:open_store("auth_tokens", "map"); -function create_jid_token(actor_jid, token_jid, token_scope, token_ttl) +local function select_role(username, host, role) + if role then + return prosody.hosts[host].authz.get_role_by_name(role); + end + return usermanager.get_user_role(username, host); +end + +function create_jid_token(actor_jid, token_jid, token_role, token_ttl, token_data) token_jid = jid.prep(token_jid); if not actor_jid or token_jid ~= actor_jid and not jid.compare(token_jid, actor_jid) then return nil, "not-authorized"; @@ -21,13 +30,10 @@ function create_jid_token(actor_jid, token_jid, token_scope, token_ttl) created = os.time(); expires = token_ttl and (os.time() + token_ttl) or nil; jid = token_jid; - session = { - username = token_username; - host = token_host; - resource = token_resource; - auth_scope = token_scope; - }; + resource = token_resource; + role = token_role; + data = token_data; }; local token_id = id.long(); @@ -46,11 +52,7 @@ local function parse_token(encoded_token) return token_id, token_user, token_host; end -function get_token_info(token) - local token_id, token_user, token_host = parse_token(token); - if not token_id then - return nil, "invalid-token-format"; - end +local function _get_parsed_token_info(token_id, token_user, token_host) if token_host ~= module.host then return nil, "invalid-host"; end @@ -64,12 +66,47 @@ function get_token_info(token) end if token_info.expires and token_info.expires < os.time() then + token_store:set(token_user, token_id, nil); + return nil, "not-authorized"; + end + + local account_info = usermanager.get_account_info(token_user, module.host); + local password_updated_at = account_info and account_info.password_updated; + if password_updated_at and password_updated_at > token_info.created then + token_store:set(token_user, token_id, nil); return nil, "not-authorized"; end return token_info end +function get_token_info(token) + local token_id, token_user, token_host = parse_token(token); + if not token_id then + return nil, "invalid-token-format"; + end + return _get_parsed_token_info(token_id, token_user, token_host); +end + +function get_token_session(token, resource) + local token_id, token_user, token_host = parse_token(token); + if not token_id then + return nil, "invalid-token-format"; + end + + local token_info, err = _get_parsed_token_info(token_id, token_user, token_host); + if not token_info then return nil, err; end + + return { + username = token_user; + host = token_host; + resource = token_info.resource or resource or generate_identifier(); + + role = select_role(token_user, token_host, token_info.role); + }; +end + + function revoke_token(token) local token_id, token_user, token_host = parse_token(token); if not token_id then diff --git a/plugins/muc/hidden.lib.lua b/plugins/muc/hidden.lib.lua index 153df21a..087fa102 100644 --- a/plugins/muc/hidden.lib.lua +++ b/plugins/muc/hidden.lib.lua @@ -8,7 +8,7 @@ -- local restrict_public = not module:get_option_boolean("muc_room_allow_public", true); -local um_is_admin = require "core.usermanager".is_admin; +module:default_permission(restrict_public and "prosody:admin" or "prosody:user", ":create-public-room"); local function get_hidden(room) return room._data.hidden; @@ -22,8 +22,8 @@ local function set_hidden(room, hidden) end module:hook("muc-config-form", function(event) - if restrict_public and not um_is_admin(event.actor, module.host) then - -- Don't show option if public rooms are restricted and user is not admin of this host + if not module:may(":create-public-room", event.actor) then + -- Hide config option if this user is not allowed to create public rooms return; end table.insert(event.form, { @@ -36,7 +36,7 @@ module:hook("muc-config-form", function(event) end, 100-9); module:hook("muc-config-submitted/muc#roomconfig_publicroom", function(event) - if restrict_public and not um_is_admin(event.actor, module.host) then + if not module:may(":create-public-room", event.actor) then return; -- Not allowed end if set_hidden(event.room, not event.value) then diff --git a/plugins/muc/mod_muc.lua b/plugins/muc/mod_muc.lua index 5873b1a2..259423cf 100644 --- a/plugins/muc/mod_muc.lua +++ b/plugins/muc/mod_muc.lua @@ -89,7 +89,7 @@ room_mt.handle_register_iq = register.handle_register_iq; local presence_broadcast = module:require "muc/presence_broadcast"; room_mt.get_presence_broadcast = presence_broadcast.get; room_mt.set_presence_broadcast = presence_broadcast.set; -room_mt.get_valid_broadcast_roles = presence_broadcast.get_valid_broadcast_roles; +room_mt.get_valid_broadcast_roles = presence_broadcast.get_valid_broadcast_roles; -- FIXME doesn't exist in the library local occupant_id = module:require "muc/occupant_id"; room_mt.get_salt = occupant_id.get_room_salt; @@ -100,7 +100,6 @@ local jid_prep = require "util.jid".prep; local jid_bare = require "util.jid".bare; local st = require "util.stanza"; local cache = require "util.cache"; -local um_is_admin = require "core.usermanager".is_admin; module:require "muc/config_form_sections"; @@ -111,21 +110,23 @@ module:depends "muc_unique" module:require "muc/hats"; module:require "muc/lock"; -local function is_admin(jid) - return um_is_admin(jid, module.host); -end +module:default_permissions("prosody:admin", { + ":automatic-ownership"; + ":create-room"; + ":recreate-destroyed-room"; +}); if module:get_option_boolean("component_admins_as_room_owners", true) then -- Monkey patch to make server admins room owners local _get_affiliation = room_mt.get_affiliation; function room_mt:get_affiliation(jid) - if is_admin(jid) then return "owner"; end + if module:may(":automatic-ownership", jid) then return "owner"; end return _get_affiliation(self, jid); end local _set_affiliation = room_mt.set_affiliation; function room_mt:set_affiliation(actor, jid, affiliation, reason, data) - if affiliation ~= "owner" and is_admin(jid) then return nil, "modify", "not-acceptable"; end + if affiliation ~= "owner" and module:may(":automatic-ownership", jid) then return nil, "modify", "not-acceptable"; end return _set_affiliation(self, actor, jid, affiliation, reason, data); end end @@ -412,26 +413,15 @@ if module:get_option_boolean("muc_tombstones", true) then end, -10); end -do - local restrict_room_creation = module:get_option("restrict_room_creation"); - if restrict_room_creation == true then - restrict_room_creation = "admin"; - end - if restrict_room_creation then - local host_suffix = module.host:gsub("^[^%.]+%.", ""); - module:hook("muc-room-pre-create", function(event) - local origin, stanza = event.origin, event.stanza; - local user_jid = stanza.attr.from; - if not is_admin(user_jid) and not ( - restrict_room_creation == "local" and - select(2, jid_split(user_jid)) == host_suffix - ) then - origin.send(st.error_reply(stanza, "cancel", "not-allowed", "Room creation is restricted", module.host)); - return true; - end - end); +local restrict_room_creation = module:get_option("restrict_room_creation"); +module:default_permission(restrict_room_creation == true and "prosody:admin" or "prosody:user", ":create-room"); +module:hook("muc-room-pre-create", function(event) + local origin, stanza = event.origin, event.stanza; + if restrict_room_creation ~= false and not module:may(":create-room", event) then + origin.send(st.error_reply(stanza, "cancel", "not-allowed", "Room creation is restricted", module.host)); + return true; end -end +end); for event_name, method in pairs { -- Normal room interactions @@ -465,7 +455,7 @@ for event_name, method in pairs { if room and room._data.destroyed then if room._data.locked < os.time() - or (is_admin(stanza.attr.from) and stanza.name == "presence" and stanza.attr.type == nil) then + or (module:may(":recreate-destroyed-room", event) and stanza.name == "presence" and stanza.attr.type == nil) then -- Allow the room to be recreated by admin or after time has passed delete_room(room); room = nil; diff --git a/plugins/muc/persistent.lib.lua b/plugins/muc/persistent.lib.lua index c3b16ea4..4c753921 100644 --- a/plugins/muc/persistent.lib.lua +++ b/plugins/muc/persistent.lib.lua @@ -8,7 +8,10 @@ -- local restrict_persistent = not module:get_option_boolean("muc_room_allow_persistent", true); -local um_is_admin = require "core.usermanager".is_admin; +module:default_permission( + restrict_persistent and "prosody:admin" or "prosody:user", + ":create-persistent-room" +); local function get_persistent(room) return room._data.persistent; @@ -22,8 +25,8 @@ local function set_persistent(room, persistent) end module:hook("muc-config-form", function(event) - if restrict_persistent and not um_is_admin(event.actor, module.host) then - -- Don't show option if hidden rooms are restricted and user is not admin of this host + if not module:may(":create-persistent-room", event.actor) then + -- Hide config option if this user is not allowed to create persistent rooms return; end table.insert(event.form, { @@ -36,7 +39,7 @@ module:hook("muc-config-form", function(event) end, 100-5); module:hook("muc-config-submitted/muc#roomconfig_persistentroom", function(event) - if restrict_persistent and not um_is_admin(event.actor, module.host) then + if not module:may(":create-persistent-room", event.actor) then return; -- Not allowed end if set_persistent(event.room, event.value) then @@ -44,6 +44,12 @@ if CFG_DATADIR then end +-- Check before first require, to preempt the probable failure +if _VERSION < "Lua 5.2" then + io.stderr:write("Prosody is no longer compatible with Lua 5.1\n") + io.stderr:write("See https://prosody.im/doc/depends#lua for more information\n") + return os.exit(1); +end local startup = require "util.startup"; local async = require "util.async"; @@ -44,6 +44,13 @@ end ----------- +-- Check before first require, to preempt the probable failure +if _VERSION < "Lua 5.2" then + io.stderr:write("Prosody is no longer compatible with Lua 5.1\n") + io.stderr:write("See https://prosody.im/doc/depends#lua for more information\n") + return os.exit(1); +end + local startup = require "util.startup"; startup.prosodyctl(); @@ -573,7 +580,7 @@ function commands.reload(arg) end -- ejabberdctl compatibility -local unpack = table.unpack or unpack; -- luacheck: ignore 113 +local unpack = table.unpack; function commands.register(arg) local user, host, password = unpack(arg); diff --git a/spec/core_storagemanager_spec.lua b/spec/core_storagemanager_spec.lua index ae4f44c8..04acb1dd 100644 --- a/spec/core_storagemanager_spec.lua +++ b/spec/core_storagemanager_spec.lua @@ -1,4 +1,4 @@ -local unpack = table.unpack or unpack; -- luacheck: ignore 113 +local unpack = table.unpack; local server = require "net.server_select"; package.loaded["net.server"] = server; diff --git a/spec/inputs/test_keys.lua b/spec/inputs/test_keys.lua new file mode 100644 index 00000000..e0e9ff8c --- /dev/null +++ b/spec/inputs/test_keys.lua @@ -0,0 +1,179 @@ +local test_keys = { + -- ECDSA keypair from jwt.io + ecdsa_private_pem = [[ +-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgevZzL1gdAFr88hb2 +OF/2NxApJCzGCEDdfSp6VQO30hyhRANCAAQRWz+jn65BtOMvdyHKcvjBeBSDZH2r +1RTwjmYSi9R/zpBnuQ4EiMnCqfMPWiZqB4QdbAd0E7oH50VpuZ1P087G +-----END PRIVATE KEY----- +]]; + + ecdsa_public_pem = [[ +-----BEGIN PUBLIC KEY----- +MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEEVs/o5+uQbTjL3chynL4wXgUg2R9 +q9UU8I5mEovUf86QZ7kOBIjJwqnzD1omageEHWwHdBO6B+dFabmdT9POxg== +-----END PUBLIC KEY----- +]]; + + -- Self-generated ECDSA keypair + alt_ecdsa_private_pem = [[ +-----BEGIN PRIVATE KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgQnn4AHz2Zy+JMAgp +AZfKAm9F3s6791PstPf5XjHtETKhRANCAAScv9jI3+BOXXlCOXwmQYosIbl9mf4V +uOwfIoCYSLylAghyxO0n2of8Kji+D+4C1zxNKmZIQa4s8neaIIzXnMY1 +-----END PRIVATE KEY----- +]]; + + alt_ecdsa_public_pem = [[ +-----BEGIN PUBLIC KEY----- +MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEnL/YyN/gTl15Qjl8JkGKLCG5fZn+ +FbjsHyKAmEi8pQIIcsTtJ9qH/Co4vg/uAtc8TSpmSEGuLPJ3miCM15zGNQ== +-----END PUBLIC KEY----- +]]; + + -- JWT reference keys for ES512 + + ecdsa_521_public_pem = [[ +-----BEGIN PUBLIC KEY----- +MIGbMBAGByqGSM49AgEGBSuBBAAjA4GGAAQBgc4HZz+/fBbC7lmEww0AO3NK9wVZ +PDZ0VEnsaUFLEYpTzb90nITtJUcPUbvOsdZIZ1Q8fnbquAYgxXL5UgHMoywAib47 +6MkyyYgPk0BXZq3mq4zImTRNuaU9slj9TVJ3ScT3L1bXwVuPJDzpr5GOFpaj+WwM +Al8G7CqwoJOsW7Kddns= +-----END PUBLIC KEY----- +]]; + + ecdsa_521_private_pem = [[ +-----BEGIN PRIVATE KEY----- +MIHuAgEAMBAGByqGSM49AgEGBSuBBAAjBIHWMIHTAgEBBEIBiyAa7aRHFDCh2qga +9sTUGINE5jHAFnmM8xWeT/uni5I4tNqhV5Xx0pDrmCV9mbroFtfEa0XVfKuMAxxf +Z6LM/yKhgYkDgYYABAGBzgdnP798FsLuWYTDDQA7c0r3BVk8NnRUSexpQUsRilPN +v3SchO0lRw9Ru86x1khnVDx+duq4BiDFcvlSAcyjLACJvjvoyTLJiA+TQFdmrear +jMiZNE25pT2yWP1NUndJxPcvVtfBW48kPOmvkY4WlqP5bAwCXwbsKrCgk6xbsp12 +ew== +-----END PRIVATE KEY----- +]]; + + -- Self-generated keys for ES512 + + alt_ecdsa_521_public_pem = [[ +-----BEGIN PUBLIC KEY----- +MIGbMBAGByqGSM49AgEGBSuBBAAjA4GGAAQBIxV0ecG/+qFc/kVPKs8Z6tjJEuRe +dzrEaqABY6THu7BhCjEoxPr6iRYdiFPzNruFORsCAKf/NFLSoCqyrw9S0YMA1xc+ +uW01145oxT7Sp8BOH1MyOh7xNh+LFLi6X4lV6j5GQrM1sKSa3O5m0+VJmLy5b7cy +oxNCzXrnEByz+EO2nYI= +-----END PUBLIC KEY----- +]]; + + alt_ecdsa_521_private_pem = [[ +-----BEGIN EC PRIVATE KEY----- +MIHcAgEBBEIAV2XJQ4/5Pa5m43/AJdL4XzrRV/l7eQ1JObqmI95YDs3zxM5Mfygz +DivhvuPdZCZUR+TdZQEdYN4LpllCzrDwmTCgBwYFK4EEACOhgYkDgYYABAEjFXR5 +wb/6oVz+RU8qzxnq2MkS5F53OsRqoAFjpMe7sGEKMSjE+vqJFh2IU/M2u4U5GwIA +p/80UtKgKrKvD1LRgwDXFz65bTXXjmjFPtKnwE4fUzI6HvE2H4sUuLpfiVXqPkZC +szWwpJrc7mbT5UmYvLlvtzKjE0LNeucQHLP4Q7adgg== +-----END EC PRIVATE KEY----- +]]; + + -- Self-generated EdDSA (Ed25519) keypair + eddsa_private_pem = [[ +-----BEGIN PRIVATE KEY----- +MC4CAQAwBQYDK2VwBCIEIOmrajEfnqdzdJzkJ4irQMCGbYRqrl0RlwPHIw+a5b7M +-----END PRIVATE KEY----- +]]; + + eddsa_public_pem = [[ +-----BEGIN PUBLIC KEY----- +MCowBQYDK2VwAyEAFipbSXeGvPVK7eA4+hIOdutZTUUyXswVSbMGi0j1QKE= +-----END PUBLIC KEY----- +]]; + + -- RSA keypair from jwt.io + rsa_private_pem = [[ +-----BEGIN PRIVATE KEY----- +MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQC7VJTUt9Us8cKj +MzEfYyjiWA4R4/M2bS1GB4t7NXp98C3SC6dVMvDuictGeurT8jNbvJZHtCSuYEvu +NMoSfm76oqFvAp8Gy0iz5sxjZmSnXyCdPEovGhLa0VzMaQ8s+CLOyS56YyCFGeJZ +qgtzJ6GR3eqoYSW9b9UMvkBpZODSctWSNGj3P7jRFDO5VoTwCQAWbFnOjDfH5Ulg +p2PKSQnSJP3AJLQNFNe7br1XbrhV//eO+t51mIpGSDCUv3E0DDFcWDTH9cXDTTlR +ZVEiR2BwpZOOkE/Z0/BVnhZYL71oZV34bKfWjQIt6V/isSMahdsAASACp4ZTGtwi +VuNd9tybAgMBAAECggEBAKTmjaS6tkK8BlPXClTQ2vpz/N6uxDeS35mXpqasqskV +laAidgg/sWqpjXDbXr93otIMLlWsM+X0CqMDgSXKejLS2jx4GDjI1ZTXg++0AMJ8 +sJ74pWzVDOfmCEQ/7wXs3+cbnXhKriO8Z036q92Qc1+N87SI38nkGa0ABH9CN83H +mQqt4fB7UdHzuIRe/me2PGhIq5ZBzj6h3BpoPGzEP+x3l9YmK8t/1cN0pqI+dQwY +dgfGjackLu/2qH80MCF7IyQaseZUOJyKrCLtSD/Iixv/hzDEUPfOCjFDgTpzf3cw +ta8+oE4wHCo1iI1/4TlPkwmXx4qSXtmw4aQPz7IDQvECgYEA8KNThCO2gsC2I9PQ +DM/8Cw0O983WCDY+oi+7JPiNAJwv5DYBqEZB1QYdj06YD16XlC/HAZMsMku1na2T +N0driwenQQWzoev3g2S7gRDoS/FCJSI3jJ+kjgtaA7Qmzlgk1TxODN+G1H91HW7t +0l7VnL27IWyYo2qRRK3jzxqUiPUCgYEAx0oQs2reBQGMVZnApD1jeq7n4MvNLcPv +t8b/eU9iUv6Y4Mj0Suo/AU8lYZXm8ubbqAlwz2VSVunD2tOplHyMUrtCtObAfVDU +AhCndKaA9gApgfb3xw1IKbuQ1u4IF1FJl3VtumfQn//LiH1B3rXhcdyo3/vIttEk +48RakUKClU8CgYEAzV7W3COOlDDcQd935DdtKBFRAPRPAlspQUnzMi5eSHMD/ISL +DY5IiQHbIH83D4bvXq0X7qQoSBSNP7Dvv3HYuqMhf0DaegrlBuJllFVVq9qPVRnK +xt1Il2HgxOBvbhOT+9in1BzA+YJ99UzC85O0Qz06A+CmtHEy4aZ2kj5hHjECgYEA +mNS4+A8Fkss8Js1RieK2LniBxMgmYml3pfVLKGnzmng7H2+cwPLhPIzIuwytXywh +2bzbsYEfYx3EoEVgMEpPhoarQnYPukrJO4gwE2o5Te6T5mJSZGlQJQj9q4ZB2Dfz +et6INsK0oG8XVGXSpQvQh3RUYekCZQkBBFcpqWpbIEsCgYAnM3DQf3FJoSnXaMhr +VBIovic5l0xFkEHskAjFTevO86Fsz1C2aSeRKSqGFoOQ0tmJzBEs1R6KqnHInicD +TQrKhArgLXX4v3CddjfTRJkFWDbE/CkvKZNOrcf1nhaGCPspRJj2KUkj1Fhl9Cnc +dn/RsYEONbwQSjIfMPkvxF+8HQ== +-----END PRIVATE KEY----- +]]; + + rsa_public_pem = [[ +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAu1SU1LfVLPHCozMxH2Mo +4lgOEePzNm0tRgeLezV6ffAt0gunVTLw7onLRnrq0/IzW7yWR7QkrmBL7jTKEn5u ++qKhbwKfBstIs+bMY2Zkp18gnTxKLxoS2tFczGkPLPgizskuemMghRniWaoLcyeh +kd3qqGElvW/VDL5AaWTg0nLVkjRo9z+40RQzuVaE8AkAFmxZzow3x+VJYKdjykkJ +0iT9wCS0DRTXu269V264Vf/3jvredZiKRkgwlL9xNAwxXFg0x/XFw005UWVRIkdg +cKWTjpBP2dPwVZ4WWC+9aGVd+Gyn1o0CLelf4rEjGoXbAAEgAqeGUxrcIlbjXfbc +mwIDAQAB +-----END PUBLIC KEY----- +]]; + + + -- Self-generated RSA keypair + alt_rsa_private_pem = [[ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEA4bt6kor2TomqRXfjCFe6T42ibatloyHntZCUdlDDAkUh4oJ/ +4jDCXAUMYqmEsZKCPXxUGQgrmSmNnJPEDMTq3XLDsjhyN4stxEi0UVAiqqBkcEnk +qbQIJSc9v5gpQF8IuJFWRvSNic0uClFL5W9R2s5AHcOhdFYKeDuitqHT5r+dC7cy +WZs5YleKaESxmK6i6wMVhL9adAilTuETyMH0yLSh+aXsPYhjns4AbjGmiKOjqd5w +sPwllEg6rGcIUi/o79z9HN8yLMXq3XNFCCA8RI4Zh3cADI1I5fe6wk1ETN+30cDw +dGQ+uQbaQrzqmKVRNjZcorMwBjsOX5AMQBFx7wIDAQABAoIBAGxj5pZpTZ4msnEL +ASQnY9oBS4ZXr8UmaamgU/mADDOR2JR4T0ngWeNvtSPG/GV70TgO9B7U8oJoFoyh +05jCEXjmO5vfSNDs7rv6oUMONKczvybABKGMRgD5F8hhGyXCvGBLwV7u3OvXbw0b +PlNcIbTsJpNkNam0CvDyyc3iZOq+HjIqituREV7lDw0rFeAR2YfEWn4VjZsQRZUZ +XkpQJ5silrXgGemIEGqVA4YyM7i2HmTiLozfVYaVckMc02VFgOaoK9Z/wGlBxtS5 +evc/IGErSA4dc7uXBEeVjhtZoBkof2JV9BNt4hl4KN9wX3tkEX5Aq1K2lirSmg2r +k+UEtwkCgYEA/5uYg25OR+jCFY/7uNS8e32Re1lgDeO+TeT1m+hcF1gCb2GBLifL +yprnuytaz1/mPqawfwbilaxntLBoa5cmNKB3zDsgv4sM451yGZ0oxU0dXpDVHblu +3nhxcaOXtb8jiSsr2MqgMbFlu7m8OupIliS+s8Pq72s6HUQQRKbJ+9MCgYEA4hQl +1W/7nDI2SR4Q3UapQnaUjmDVxX5OD+E4RpKuRF6xF7Ao2CLZusMVo8WN8YiSQP2c +RnzQNKgAVy/1zlhaaQDTs2TmSy9iStbuNZ8P+Gh6kmQXuHxwPyURSmwdpgZdL3+D +8tt6pQNQ0vsLjA9VwHmzIT+rsxPmTxKNvBdNK/UCgYByP6zqyioJMDtYAfRkiAn7 +NIQLW0Z4ztvn2zgAyNoowPjNqgpgg/8t/xEm8tjzKg0y4bSwAnbSqa3s8JCrznKQ +QU1qpt8bXl6TenNeiYWIstA2zYvEbnbkz3b9cT7FSLrse7RsgR0bOQyc3QcKWl+5 +ZJEsrpxbCVV/cUXIObi8awKBgQDOI8rfk+0bXhlrkBOWf/CjnpYUQK2LF4C8MALt +Lp/hzWmyjLihYx2eknUv0Fl966ZXxidxiisaaDlvRlbeIGfHqK5fu9fUpE7+qH2p +vPCF81YYF1YdrLF4kiby8iQSl2juf1nj3kY1IhHXXnsH6Y+qIg24emLntXRhkyxT +XffK5QKBgGbzEvVgDkerw1SiefAaZnLumJJXBlKjJ00Sq8YLeViyFC/sr4EfG/cV +7VYRhBw3e7RcYSBAA7uv8i3iIeCFjFooIZUARqXk4+yW753tY5nSJTWfkR7Bp5Pa +9jKloxckbZKMjH23a+ABOxomY3l93KOBvjLvMYqccuREOwaT12cn +-----END RSA PRIVATE KEY----- +]]; + + alt_rsa_public_pem = [[ +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA4bt6kor2TomqRXfjCFe6 +T42ibatloyHntZCUdlDDAkUh4oJ/4jDCXAUMYqmEsZKCPXxUGQgrmSmNnJPEDMTq +3XLDsjhyN4stxEi0UVAiqqBkcEnkqbQIJSc9v5gpQF8IuJFWRvSNic0uClFL5W9R +2s5AHcOhdFYKeDuitqHT5r+dC7cyWZs5YleKaESxmK6i6wMVhL9adAilTuETyMH0 +yLSh+aXsPYhjns4AbjGmiKOjqd5wsPwllEg6rGcIUi/o79z9HN8yLMXq3XNFCCA8 +RI4Zh3cADI1I5fe6wk1ETN+30cDwdGQ+uQbaQrzqmKVRNjZcorMwBjsOX5AMQBFx +7wIDAQAB +-----END PUBLIC KEY----- +]]; +}; + +return test_keys; diff --git a/spec/net_resolvers_service_spec.lua b/spec/net_resolvers_service_spec.lua new file mode 100644 index 00000000..53ce4754 --- /dev/null +++ b/spec/net_resolvers_service_spec.lua @@ -0,0 +1,241 @@ +local set = require "util.set"; + +insulate("net.resolvers.service", function () + local adns = { + resolver = function () + return { + lookup = function (_, cb, qname, qtype, qclass) + if qname == "_xmpp-server._tcp.example.com" + and (qtype or "SRV") == "SRV" + and (qclass or "IN") == "IN" then + cb({ + { -- 60+35+60 + srv = { target = "xmpp0-a.example.com", port = 5228, priority = 0, weight = 60 }; + }; + { + srv = { target = "xmpp0-b.example.com", port = 5216, priority = 0, weight = 35 }; + }; + { + srv = { target = "xmpp0-c.example.com", port = 5200, priority = 0, weight = 0 }; + }; + { + srv = { target = "xmpp0-d.example.com", port = 5256, priority = 0, weight = 120 }; + }; + + { + srv = { target = "xmpp1-a.example.com", port = 5273, priority = 1, weight = 30 }; + }; + { + srv = { target = "xmpp1-b.example.com", port = 5274, priority = 1, weight = 30 }; + }; + + { + srv = { target = "xmpp2.example.com", port = 5275, priority = 2, weight = 0 }; + }; + }); + elseif qname == "_xmpp-server._tcp.single.example.com" + and (qtype or "SRV") == "SRV" + and (qclass or "IN") == "IN" then + cb({ + { + srv = { target = "xmpp0-a.example.com", port = 5269, priority = 0, weight = 0 }; + }; + }); + elseif qname == "_xmpp-server._tcp.half.example.com" + and (qtype or "SRV") == "SRV" + and (qclass or "IN") == "IN" then + cb({ + { + srv = { target = "xmpp0-a.example.com", port = 5269, priority = 0, weight = 0 }; + }; + { + srv = { target = "xmpp0-b.example.com", port = 5270, priority = 0, weight = 1 }; + }; + }); + elseif qtype == "A" then + local l = qname:match("%-(%a)%.example.com$") or "1"; + local d = ("%d"):format(l:byte()) + cb({ + { + a = "127.0.0."..d; + }; + }); + elseif qtype == "AAAA" then + local l = qname:match("%-(%a)%.example.com$") or "1"; + local d = ("%04d"):format(l:byte()) + cb({ + { + aaaa = "fdeb:9619:649e:c7d9::"..d; + }; + }); + else + cb(nil); + end + end; + }; + end; + }; + package.loaded["net.adns"] = mock(adns); + local resolver = require "net.resolvers.service"; + math.randomseed(os.time()); + it("works for 99% of deployments", function () + -- Most deployments only have a single SRV record, let's make + -- sure that works okay + + local expected_targets = set.new({ + -- xmpp0-a + "tcp4 127.0.0.97 5269"; + "tcp6 fdeb:9619:649e:c7d9::0097 5269"; + }); + local received_targets = set.new({}); + + local r = resolver.new("single.example.com", "xmpp-server"); + local done = false; + local function handle_target(...) + if ... == nil then + done = true; + -- No more targets + return; + end + received_targets:add(table.concat({ ... }, " ", 1, 3)); + end + r:next(handle_target); + while not done do + r:next(handle_target); + end + + -- We should have received all expected targets, and no unexpected + -- ones: + assert.truthy(set.xor(received_targets, expected_targets):empty()); + end); + + it("supports A/AAAA fallback", function () + -- Many deployments don't have any SRV records, so we should + -- fall back to A/AAAA records instead when that is the case + + local expected_targets = set.new({ + -- xmpp0-a + "tcp4 127.0.0.97 5269"; + "tcp6 fdeb:9619:649e:c7d9::0097 5269"; + }); + local received_targets = set.new({}); + + local r = resolver.new("xmpp0-a.example.com", "xmpp-server", "tcp", { default_port = 5269 }); + local done = false; + local function handle_target(...) + if ... == nil then + done = true; + -- No more targets + return; + end + received_targets:add(table.concat({ ... }, " ", 1, 3)); + end + r:next(handle_target); + while not done do + r:next(handle_target); + end + + -- We should have received all expected targets, and no unexpected + -- ones: + assert.truthy(set.xor(received_targets, expected_targets):empty()); + end); + + + it("works", function () + local expected_targets = set.new({ + -- xmpp0-a + "tcp4 127.0.0.97 5228"; + "tcp6 fdeb:9619:649e:c7d9::0097 5228"; + "tcp4 127.0.0.97 5273"; + "tcp6 fdeb:9619:649e:c7d9::0097 5273"; + + -- xmpp0-b + "tcp4 127.0.0.98 5274"; + "tcp6 fdeb:9619:649e:c7d9::0098 5274"; + "tcp4 127.0.0.98 5216"; + "tcp6 fdeb:9619:649e:c7d9::0098 5216"; + + -- xmpp0-c + "tcp4 127.0.0.99 5200"; + "tcp6 fdeb:9619:649e:c7d9::0099 5200"; + + -- xmpp0-d + "tcp4 127.0.0.100 5256"; + "tcp6 fdeb:9619:649e:c7d9::0100 5256"; + + -- xmpp2 + "tcp4 127.0.0.49 5275"; + "tcp6 fdeb:9619:649e:c7d9::0049 5275"; + + }); + local received_targets = set.new({}); + + local r = resolver.new("example.com", "xmpp-server"); + local done = false; + local function handle_target(...) + if ... == nil then + done = true; + -- No more targets + return; + end + received_targets:add(table.concat({ ... }, " ", 1, 3)); + end + r:next(handle_target); + while not done do + r:next(handle_target); + end + + -- We should have received all expected targets, and no unexpected + -- ones: + assert.truthy(set.xor(received_targets, expected_targets):empty()); + end); + + it("balances across weights correctly #slow", function () + -- This mimics many repeated connections to 'example.com' (mock + -- records defined above), and records the port number of the + -- first target. Therefore it (should) only return priority + -- 0 records, and the input data is constructed such that the + -- last two digits of the port number represent the percentage + -- of times that record should (on average) be picked first. + + -- To prevent random test failures, we test across a handful + -- of fixed (randomly selected) seeds. + for _, seed in ipairs({ 8401877, 3943829, 7830992 }) do + math.randomseed(seed); + + local results = {}; + local function run() + local run_results = {}; + local r = resolver.new("example.com", "xmpp-server"); + local function record_target(...) + if ... == nil then + -- No more targets + return; + end + run_results = { ... }; + end + r:next(record_target); + return run_results[3]; + end + + for _ = 1, 1000 do + local port = run(); + results[port] = (results[port] or 0) + 1; + end + + local ports = {}; + for port in pairs(results) do + table.insert(ports, port); + end + table.sort(ports); + for _, port in ipairs(ports) do + --print("PORT", port, tostring((results[port]/1000) * 100).."% hits (expected "..tostring(port-5200).."%)"); + local hit_pct = (results[port]/1000) * 100; + local expected_pct = port - 5200; + --print(hit_pct, expected_pct, math.abs(hit_pct - expected_pct)); + assert.is_true(math.abs(hit_pct - expected_pct) < 5); + end + --print("---"); + end + end); +end); diff --git a/spec/scansion/mam_extended.scs b/spec/scansion/mam_extended.scs index 2c6840df..70897737 100644 --- a/spec/scansion/mam_extended.scs +++ b/spec/scansion/mam_extended.scs @@ -45,8 +45,8 @@ Romeo sends: Romeo receives: <iq type="result" id="mamextmeta"> <metadata xmlns="urn:xmpp:mam:2"> - <start timestamp="2008-08-22T21:09:04Z" xmlns="urn:xmpp:mam:2" id="{scansion:any}"/> - <end timestamp="2008-08-22T21:09:04Z" xmlns="urn:xmpp:mam:2" id="{scansion:any}"/> + <start timestamp="2008-08-22T21:09:04.500000Z" xmlns="urn:xmpp:mam:2" id="{scansion:any}"/> + <end timestamp="2008-08-22T21:09:04.500000Z" xmlns="urn:xmpp:mam:2" id="{scansion:any}"/> </metadata> </iq> @@ -59,7 +59,7 @@ Romeo receives: <message to="${Romeo's full JID}"> <result xmlns="urn:xmpp:mam:2" queryid="q1" id="{scansion:any}"> <forwarded xmlns="urn:xmpp:forward:0"> - <delay stamp="2008-08-22T21:09:04Z" xmlns="urn:xmpp:delay"/> + <delay stamp="2008-08-22T21:09:04.500000Z" xmlns="urn:xmpp:delay"/> <message to="someone@localhost" xmlns="jabber:client" type="chat" xml:lang="en" id="chat01" from="${Romeo's full JID}"> <body>Hello</body> </message> @@ -71,7 +71,7 @@ Romeo receives: <message to="${Romeo's full JID}"> <result xmlns="urn:xmpp:mam:2" queryid="q1" id="{scansion:any}"> <forwarded xmlns="urn:xmpp:forward:0"> - <delay stamp="2008-08-22T21:09:04Z" xmlns="urn:xmpp:delay"/> + <delay stamp="2008-08-22T21:09:04.500000Z" xmlns="urn:xmpp:delay"/> <message to="someone@localhost" xmlns="jabber:client" type="chat" xml:lang="en" id="chat02" from="${Romeo's full JID}"> <body>U there?</body> </message> @@ -98,7 +98,7 @@ Romeo receives: <message to="${Romeo's full JID}"> <result xmlns="urn:xmpp:mam:2" queryid="q1" id="{scansion:any}"> <forwarded xmlns="urn:xmpp:forward:0"> - <delay stamp="2008-08-22T21:09:04Z" xmlns="urn:xmpp:delay"/> + <delay stamp="2008-08-22T21:09:04.500000Z" xmlns="urn:xmpp:delay"/> <message to="someone@localhost" xmlns="jabber:client" type="chat" xml:lang="en" id="chat02" from="${Romeo's full JID}"> <body>U there?</body> </message> @@ -110,7 +110,7 @@ Romeo receives: <message to="${Romeo's full JID}"> <result xmlns="urn:xmpp:mam:2" queryid="q1" id="{scansion:any}"> <forwarded xmlns="urn:xmpp:forward:0"> - <delay stamp="2008-08-22T21:09:04Z" xmlns="urn:xmpp:delay"/> + <delay stamp="2008-08-22T21:09:04.500000Z" xmlns="urn:xmpp:delay"/> <message to="someone@localhost" xmlns="jabber:client" type="chat" xml:lang="en" id="chat01" from="${Romeo's full JID}"> <body>Hello</body> </message> diff --git a/spec/scansion/prosody.cfg.lua b/spec/scansion/prosody.cfg.lua index 6901cc11..0779f883 100644 --- a/spec/scansion/prosody.cfg.lua +++ b/spec/scansion/prosody.cfg.lua @@ -6,8 +6,8 @@ function _G.os.time() end package.preload["util.time"] = function () return { - now = function () return 1219439344.1; end; - monotonic = function () return 0.1; end; + now = function () return 1219439344.5; end; + monotonic = function () return 0.5; end; } end diff --git a/spec/util_cache_spec.lua b/spec/util_cache_spec.lua index d57e25ac..2cb7b7dd 100644 --- a/spec/util_cache_spec.lua +++ b/spec/util_cache_spec.lua @@ -4,6 +4,20 @@ local cache = require "util.cache"; describe("util.cache", function() describe("#new()", function() it("should work", function() + do + local c = cache.new(1); + assert.is_not_nil(c); + + assert.has_error(function () + cache.new(0); + end); + assert.has_error(function () + cache.new(-1); + end); + assert.has_error(function () + cache.new("foo"); + end); + end local c = cache.new(5); @@ -314,7 +328,7 @@ describe("util.cache", function() end); - (_VERSION=="Lua 5.1" and pending or it)(":table works", function () + it(":table works", function () local t = cache.new(3):table(); assert.is.table(t); t["a"] = "1"; @@ -336,5 +350,43 @@ describe("util.cache", function() assert.spy(i).was_called_with("c", "3"); assert.spy(i).was_called_with("d", "4"); end); + + local function vs(t) + local vs_ = {}; + for v in t:values() do + vs_[#vs_+1] = v; + end + return vs_; + end + + it(":values works", function () + local t = cache.new(3); + t:set("k1", "v1"); + t:set("k2", "v2"); + assert.same({"v2", "v1"}, vs(t)); + t:set("k3", "v3"); + assert.same({"v3", "v2", "v1"}, vs(t)); + t:set("k4", "v4"); + assert.same({"v4", "v3", "v2"}, vs(t)); + end); + + it(":resize works", function () + local c = cache.new(5); + for i = 1, 5 do + c:set(("k%d"):format(i), ("v%d"):format(i)); + end + assert.same({"v5", "v4", "v3", "v2", "v1"}, vs(c)); + assert.has_error(function () + c:resize(-1); + end); + assert.has_error(function () + c:resize(0); + end); + assert.has_error(function () + c:resize("foo"); + end); + c:resize(3); + assert.same({"v5", "v4", "v3"}, vs(c)); + end); end); end); diff --git a/spec/util_crypto_spec.lua b/spec/util_crypto_spec.lua new file mode 100644 index 00000000..77d046ac --- /dev/null +++ b/spec/util_crypto_spec.lua @@ -0,0 +1,184 @@ +local test_keys = require "spec.inputs.test_keys"; + +describe("util.crypto", function () + local crypto = require "util.crypto"; + local random = require "util.random"; + + describe("generate_ed25519_keypair", function () + local keypair = crypto.generate_ed25519_keypair(); + assert.is_not_nil(keypair); + assert.equal("ED25519", keypair:get_type()); + end) + + describe("import_private_pem", function () + it("can import ECDSA keys", function () + local ecdsa_key = crypto.import_private_pem(test_keys.ecdsa_private_pem); + assert.equal("id-ecPublicKey", ecdsa_key:get_type()); + end); + + it("can import EdDSA (Ed25519) keys", function () + local ed25519_key = crypto.import_private_pem(crypto.generate_ed25519_keypair():private_pem()); + assert.equal("ED25519", ed25519_key:get_type()); + end); + + it("can import RSA keys", function () + -- TODO + end); + + it("rejects invalid keys", function () + assert.is_nil(crypto.import_private_pem(test_keys.eddsa_public_pem)); + assert.is_nil(crypto.import_private_pem(test_keys.ecdsa_public_pem)); + assert.is_nil(crypto.import_private_pem("foo")); + assert.is_nil(crypto.import_private_pem("")); + end); + end); + + describe("import_public_pem", function () + it("can import ECDSA public keys", function () + local ecdsa_key = crypto.import_public_pem(test_keys.ecdsa_public_pem); + assert.equal("id-ecPublicKey", ecdsa_key:get_type()); + end); + + it("can import EdDSA (Ed25519) public keys", function () + local ed25519_key = crypto.import_public_pem(test_keys.eddsa_public_pem); + assert.equal("ED25519", ed25519_key:get_type()); + end); + + it("can import RSA public keys", function () + -- TODO + end); + end); + + describe("PEM export", function () + it("works", function () + local ecdsa_key = crypto.import_public_pem(test_keys.ecdsa_public_pem); + assert.equal("id-ecPublicKey", ecdsa_key:get_type()); + assert.equal(test_keys.ecdsa_public_pem, ecdsa_key:public_pem()); + + assert.has_error(function () + -- Fails because private key is not available + ecdsa_key:private_pem(); + end); + + local ecdsa_private_key = crypto.import_private_pem(test_keys.ecdsa_private_pem); + assert.equal(test_keys.ecdsa_private_pem, ecdsa_private_key:private_pem()); + end); + end); + + describe("sign/verify with", function () + local test_cases = { + ed25519 = { + crypto.ed25519_sign, crypto.ed25519_verify; + key = crypto.import_private_pem(test_keys.eddsa_private_pem); + sig_length = 64; + }; + ecdsa = { + crypto.ecdsa_sha256_sign, crypto.ecdsa_sha256_verify; + key = crypto.import_private_pem(test_keys.ecdsa_private_pem); + }; + }; + for test_name, test in pairs(test_cases) do + local key = test.key; + describe(test_name, function () + it("works", function () + local sign, verify = test[1], test[2]; + local sig = assert(sign(key, "Hello world")); + assert.is_string(sig); + if test.sig_length then + assert.equal(test.sig_length, #sig); + end + + do + local ok = verify(key, "Hello world", sig); + assert.is_truthy(ok); + end + do -- Incorrect signature + local ok = verify(key, "Hello world", sig:sub(1, -2)..string.char((sig:byte(-1)+1)%255)); + assert.is_falsy(ok); + end + do -- Incorrect message + local ok = verify(key, "Hello earth", sig); + assert.is_falsy(ok); + end + do -- Incorrect message (embedded NUL) + local ok = verify(key, "Hello world\0foo", sig); + assert.is_falsy(ok); + end + end); + end); + end + end); + + describe("ECDSA signatures", function () + local hex = require "util.hex"; + local sig = hex.decode((([[ + 304402203e936e7b0bc62887e0e9d675afd08531a930384cfcf301 + f25d13053a2ebf141d02205a5a7c7b7ac5878d004cb79b17b39346 + 6b0cd1043718ffc31c153b971d213a8e + ]]):gsub("%s+", ""))); + it("can be parsed", function () + local r, s = crypto.parse_ecdsa_signature(sig, 32); + assert.is_string(r); + assert.is_string(s); + assert.equal(32, #r); + assert.equal(32, #s); + end); + it("fails to parse invalid signatures", function () + local invalid_sigs = { + ""; + "\000"; + string.rep("\000", 64); + string.rep("\000", 72); + string.rep("\000", 256); + string.rep("\255", 72); + string.rep("\255", 3); + }; + for _, invalid_sig in ipairs(invalid_sigs) do + local r, s = crypto.parse_ecdsa_signature(invalid_sig, 32); + assert.is_nil(r); + assert.is_nil(s); + end + end); + it("can be built", function () + local r, s = crypto.parse_ecdsa_signature(sig, 32); + local rebuilt_sig = crypto.build_ecdsa_signature(r, s); + assert.equal(sig, rebuilt_sig); + end); + end); + + describe("AES-GCM encryption", function () + it("works", function () + local message = "foo\0bar"; + local key_128_bit = random.bytes(16); + local key_256_bit = random.bytes(32); + local test_cases = { + { crypto.aes_128_gcm_encrypt, crypto.aes_128_gcm_decrypt, key = key_128_bit }; + { crypto.aes_256_gcm_encrypt, crypto.aes_256_gcm_decrypt, key = key_256_bit }; + }; + for _, params in pairs(test_cases) do + local iv = params.iv or random.bytes(12); + local encrypted = params[1](params.key, iv, message); + assert.not_equal(message, encrypted); + local decrypted = params[2](params.key, iv, encrypted); + assert.equal(message, decrypted); + end + end); + end); + + describe("AES-CTR encryption", function () + it("works", function () + local message = "foo\0bar hello world"; + local key_256_bit = random.bytes(32); + local test_cases = { + { crypto.aes_256_ctr_decrypt, crypto.aes_256_ctr_decrypt, key = key_256_bit }; + }; + for _, params in pairs(test_cases) do + local iv = params.iv or random.bytes(16); + local encrypted = params[1](params.key, iv, message); + assert.not_equal(message, encrypted); + local decrypted = params[2](params.key, iv, encrypted); + assert.equal(message, decrypted); + end + end); + end); +end); diff --git a/spec/util_dataforms_spec.lua b/spec/util_dataforms_spec.lua index 5293238a..ab402fdb 100644 --- a/spec/util_dataforms_spec.lua +++ b/spec/util_dataforms_spec.lua @@ -130,7 +130,7 @@ describe("util.dataforms", function () assert.truthy(st.is_stanza(xform)); assert.equal("x", xform.name); assert.equal("jabber:x:data", xform.attr.xmlns); - assert.equal("FORM_TYPE", xform:find("field@var")); + assert.equal("FORM_TYPE", xform:get_child_attr("field", nil, "var")); assert.equal("xmpp:prosody.im/spec/util.dataforms#1", xform:find("field/value#")); local allowed_direct_children = { title = true, diff --git a/spec/util_datamapper_spec.lua b/spec/util_datamapper_spec.lua index 51ccf127..cc0e80e1 100644 --- a/spec/util_datamapper_spec.lua +++ b/spec/util_datamapper_spec.lua @@ -15,22 +15,22 @@ describe("util.datamapper", function() setup(function() -- a convenience function for simple attributes, there's a few of them - local function attr() return {["$ref"]="#/$defs/attr"} end + local attr = {["$ref"]="#/$defs/attr"}; s = { ["$defs"] = { attr = { type = "string"; xml = { attribute = true } } }; type = "object"; xml = {name = "message"; namespace = "jabber:client"}; properties = { - to = attr(); - from = attr(); - type = attr(); - id = attr(); + to = attr; + from = attr; + type = attr; + id = attr; body = true; -- should be assumed to be a string lang = {type = "string"; xml = {attribute = true; prefix = "xml"}}; delay = { type = "object"; xml = {namespace = "urn:xmpp:delay"; name = "delay"}; - properties = {stamp = attr(); from = attr(); reason = {type = "string"; xml = {text = true}}}; + properties = {stamp = attr; from = attr; reason = {type = "string"; xml = {text = true}}}; }; state = { type = "string"; @@ -66,8 +66,8 @@ describe("util.datamapper", function() xml = {name = "stanza-id"; namespace = "urn:xmpp:sid:0"}; type = "object"; properties = { - id = attr(); - by = attr(); + id = attr; + by = attr; }; }; }; @@ -120,10 +120,10 @@ describe("util.datamapper", function() namespace = "jabber:client" }; properties = { - to = attr(); - from = attr(); - type = attr(); - id = attr(); + to = attr; + from = attr; + type = attr; + id = attr; disco = { type = "object"; xml = { diff --git a/spec/util_datetime_spec.lua b/spec/util_datetime_spec.lua index 497ab7d3..960e8aef 100644 --- a/spec/util_datetime_spec.lua +++ b/spec/util_datetime_spec.lua @@ -16,7 +16,10 @@ describe("util.datetime", function () assert.truthy(string.find(date(), "^%d%d%d%d%-%d%d%-%d%d$")); end); it("should work", function () - assert.equals(date(1136239445), "2006-01-02"); + assert.equals("2006-01-02", date(1136239445)); + end); + it("should ignore fractional parts", function () + assert.equals("2006-01-02", date(1136239445.5)); end); end); describe("#time", function () @@ -32,8 +35,14 @@ describe("util.datetime", function () assert.truthy(string.find(time(), "^%d%d:%d%d:%d%d")); end); it("should work", function () - assert.equals(time(1136239445), "22:04:05"); - end); + assert.equals("22:04:05", time(1136239445)); + end); + it("should handle precision", function () + assert.equal("14:46:31.158200", time(1660488391.1582)) + assert.equal("14:46:32.158200", time(1660488392.1582)) + assert.equal("14:46:33.158200", time(1660488393.1582)) + assert.equal("14:46:33.999900", time(1660488393.9999)) + end) end); describe("#datetime", function () local datetime = util_datetime.datetime; @@ -48,14 +57,23 @@ describe("util.datetime", function () assert.truthy(string.find(datetime(), "^%d%d%d%d%-%d%d%-%d%dT%d%d:%d%d:%d%d")); end); it("should work", function () - assert.equals(datetime(1136239445), "2006-01-02T22:04:05Z"); - end); + assert.equals("2006-01-02T22:04:05Z", datetime(1136239445)); + end); + it("should handle precision", function () + assert.equal("2022-08-14T14:46:31.158200Z", datetime(1660488391.1582)) + assert.equal("2022-08-14T14:46:32.158200Z", datetime(1660488392.1582)) + assert.equal("2022-08-14T14:46:33.158200Z", datetime(1660488393.1582)) + assert.equal("2022-08-14T14:46:33.999900Z", datetime(1660488393.9999)) + end) end); describe("#legacy", function () local legacy = util_datetime.legacy; it("should exist", function () assert.is_function(legacy); end); + it("should not add precision", function () + assert.equal("20220814T14:46:31", legacy(1660488391.1582)); + end); end); describe("#parse", function () local parse = util_datetime.parse; @@ -64,13 +82,23 @@ describe("util.datetime", function () end); it("should work", function () -- Timestamp used by Go - assert.equals(parse("2017-11-19T17:58:13Z"), 1511114293); - assert.equals(parse("2017-11-19T18:58:50+0100"), 1511114330); - assert.equals(parse("2006-01-02T15:04:05-0700"), 1136239445); + assert.equals(1511114293, parse("2017-11-19T17:58:13Z")); + assert.equals(1511114330, parse("2017-11-19T18:58:50+0100")); + assert.equals(1136239445, parse("2006-01-02T15:04:05-0700")); + assert.equals(1136239445, parse("2006-01-02T15:04:05-07")); end); it("should handle timezones", function () -- https://xmpp.org/extensions/xep-0082.html#example-2 and 3 assert.equals(parse("1969-07-21T02:56:15Z"), parse("1969-07-20T21:56:15-05:00")); end); + it("should handle precision", function () + -- floating point comparison is not an exact science + assert.truthy(math.abs(1660488392.1582 - parse("2022-08-14T14:46:32.158200Z")) < 0.001) + end) + it("should return nil when given invalid inputs", function () + assert.is_nil(parse(nil)); + assert.is_nil(parse("hello world")); + assert.is_nil(parse("2017-11-19T18:58:50$0100")); + end); end); end); diff --git a/spec/util_dbuffer_spec.lua b/spec/util_dbuffer_spec.lua index 83ca1823..0072a59d 100644 --- a/spec/util_dbuffer_spec.lua +++ b/spec/util_dbuffer_spec.lua @@ -6,6 +6,8 @@ describe("util.dbuffer", function () end); it("can be created", function () assert.truthy(dbuffer.new()); + assert.truthy(dbuffer.new(1)); + assert.truthy(dbuffer.new(1024)); end); it("won't create an empty buffer", function () assert.falsy(dbuffer.new(0)); @@ -15,10 +17,21 @@ describe("util.dbuffer", function () end); end); describe(":write", function () - local b = dbuffer.new(); + local b = dbuffer.new(10, 3); it("works", function () assert.truthy(b:write("hi")); end); + it("fails when the buffer is full", function () + local ret = b:write(" there world, this is a long piece of data"); + assert.is_falsy(ret); + end); + it("works when max_chunks is reached", function () + -- Chunks are an optimization, dbuffer should collapse chunks when needed + for _ = 1, 8 do + assert.truthy(b:write("!")); + end + assert.falsy(b:write("!")); -- Length reached + end); end); describe(":read", function () @@ -34,6 +47,14 @@ describe("util.dbuffer", function () assert.equal(" ", b:read()); assert.equal("world", b:read()); end); + it("fails when there is not enough data in the buffer", function () + local b = dbuffer.new(12); + b:write("hello"); + b:write(" "); + b:write("world"); + assert.is_falsy(b:read(12)); + assert.is_falsy(b:read(13)); + end); end); describe(":read_until", function () @@ -68,9 +89,46 @@ describe("util.dbuffer", function () assert.equal(5, b:len()); assert.equal("world", b:read(5)); end); + it("works across chunks", function () + assert.truthy(b:write("hello")); + assert.truthy(b:write(" ")); + assert.truthy(b:write("world")); + assert.truthy(b:discard(3)); + assert.equal(8, b:length()); + assert.truthy(b:discard(3)); + assert.equal(5, b:length()); + assert.equal("world", b:read(5)); + end); + it("can discard the entire buffer", function () + assert.equal(b:len(), 0); + assert.truthy(b:write("hello world")); + assert.truthy(b:discard(11)); + assert.equal(0, b:len()); + assert.truthy(b:write("hello world")); + assert.truthy(b:discard(12)); + assert.equal(0, b:len()); + assert.truthy(b:write("hello world")); + assert.truthy(b:discard(128)); + assert.equal(0, b:len()); + end); + it("works on an empty buffer", function () + assert.truthy(dbuffer.new():discard()); + assert.truthy(dbuffer.new():discard(0)); + assert.truthy(dbuffer.new():discard(1)); + end); end); describe(":collapse()", function () + it("works", function () + local b = dbuffer.new(); + b:write("hello"); + b:write(" "); + b:write("world"); + b:collapse(6); + local ret, bytes = b:read_chunk(); + assert.equal("hello ", ret); + assert.equal(6, bytes); + end); it("works on an empty buffer", function () local b = dbuffer.new(); b:collapse(); @@ -115,6 +173,11 @@ describe("util.dbuffer", function () end end end); + + it("works on an empty buffer", function () + local b = dbuffer.new(); + assert.equal("", b:sub(1, 12)); + end); end); describe(":byte", function () @@ -122,7 +185,11 @@ describe("util.dbuffer", function () local s = "hello world" local function test_byte(b, x, y) local string_result, buffer_result = {s:byte(x, y)}, {b:byte(x, y)}; - assert.same(string_result, buffer_result, ("buffer:byte(%d, %s) does not match string:byte()"):format(x, y and ("%d"):format(y) or "nil")); + assert.same( + string_result, + buffer_result, + ("buffer:byte(%s, %s) does not match string:byte()"):format(x and ("%d"):format(x) or "nil", y and ("%d"):format(y) or "nil") + ); end it("is equivalent to string:byte", function () @@ -132,6 +199,7 @@ describe("util.dbuffer", function () test_byte(b, 3); test_byte(b, -1); test_byte(b, -3); + test_byte(b, nil, 5); for i = -13, 13 do for j = -13, 13 do test_byte(b, i, j); diff --git a/spec/util_format_spec.lua b/spec/util_format_spec.lua index cb473b47..1016a215 100644 --- a/spec/util_format_spec.lua +++ b/spec/util_format_spec.lua @@ -333,29 +333,27 @@ describe("util.format", function() end); end); - if _VERSION > "Lua 5.1" then -- COMPAT no %a or %A in Lua 5.1 - describe("to %a", function () - it("works", function () - assert.equal("0x1.84p+6", format("%a", 97)) - assert.equal("-0x1.81c8p+13", format("%a", -12345)) - assert.equal("0x1.8p+0", format("%a", 1.5)) - assert.equal("0x1p+66", format("%a", 73786976294838206464)) - assert.equal("inf", format("%a", math.huge)) - assert.equal("0x1.fffffffcp+30", format("%a", 2147483647)) - end); - end); - - describe("to %A", function () - it("works", function () - assert.equal("0X1.84P+6", format("%A", 97)) - assert.equal("-0X1.81C8P+13", format("%A", -12345)) - assert.equal("0X1.8P+0", format("%A", 1.5)) - assert.equal("0X1P+66", format("%A", 73786976294838206464)) - assert.equal("INF", format("%A", math.huge)) - assert.equal("0X1.FFFFFFFCP+30", format("%A", 2147483647)) - end); - end); - end + describe("to %a", function () + it("works", function () + assert.equal("0x1.84p+6", format("%a", 97)) + assert.equal("-0x1.81c8p+13", format("%a", -12345)) + assert.equal("0x1.8p+0", format("%a", 1.5)) + assert.equal("0x1p+66", format("%a", 73786976294838206464)) + assert.equal("inf", format("%a", math.huge)) + assert.equal("0x1.fffffffcp+30", format("%a", 2147483647)) + end); + end); + + describe("to %A", function () + it("works", function () + assert.equal("0X1.84P+6", format("%A", 97)) + assert.equal("-0X1.81C8P+13", format("%A", -12345)) + assert.equal("0X1.8P+0", format("%A", 1.5)) + assert.equal("0X1P+66", format("%A", 73786976294838206464)) + assert.equal("INF", format("%A", math.huge)) + assert.equal("0X1.FFFFFFFCP+30", format("%A", 2147483647)) + end); + end); describe("to %e", function () it("works", function () diff --git a/spec/util_hashes_spec.lua b/spec/util_hashes_spec.lua index 51a4a79c..c4b1841e 100644 --- a/spec/util_hashes_spec.lua +++ b/spec/util_hashes_spec.lua @@ -4,6 +4,8 @@ local hex = require "util.hex"; -- Also see spec for util.hmac where HMAC test cases reside +--luacheck: ignore 631 + describe("PBKDF2-HMAC-SHA1", function () it("test vector 1", function () local P = "password" @@ -53,3 +55,56 @@ describe("PBKDF2-HMAC-SHA256", function () end); +describe("SHA-3", function () + describe("256", function () + it("works", function () + local expected = "a7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a" + assert.equal(expected, hashes.sha3_256("", true)); + end); + end); + describe("512", function () + it("works", function () + local expected = "a69f73cca23a9ac5c8b567dc185a756e97c982164fe25859e0d1dcc1475c80a615b2123af1f5f94c11e3e9402c3ac558f500199d95b6d3e301758586281dcd26" + assert.equal(expected, hashes.sha3_512("", true)); + end); + end); +end); + +describe("HKDF", function () + describe("HMAC-SHA256", function () + describe("RFC 5869", function () + it("test vector A.1", function () + local ikm = hex.decode("0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b"); + local salt = hex.decode("000102030405060708090a0b0c"); + local info = hex.decode("f0f1f2f3f4f5f6f7f8f9"); + + local expected = "3cb25f25faacd57a90434f64d0362f2a2d2d0a90cf1a5a4c5db02d56ecc4c5bf34007208d5b887185865"; + + local ret = hashes.hkdf_hmac_sha256(42, ikm, salt, info); + assert.equal(expected, hex.encode(ret)); + end); + + it("test vector A.2", function () + local ikm = hex.decode("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f"); + local salt = hex.decode("606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeaf"); + local info = hex.decode("b0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff"); + + local expected = "b11e398dc80327a1c8e7f78c596a49344f012eda2d4efad8a050cc4c19afa97c59045a99cac7827271cb41c65e590e09da3275600c2f09b8367793a9aca3db71cc30c58179ec3e87c14c01d5c1f3434f1d87"; + + local ret = hashes.hkdf_hmac_sha256(82, ikm, salt, info); + assert.equal(expected, hex.encode(ret)); + end); + + it("test vector A.3", function () + local ikm = hex.decode("0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b"); + local salt = ""; + local info = ""; + + local expected = "8da4e775a563c18f715f802a063c5a31b8a11f5c5ee1879ec3454e5f3c738d2d9d201395faa4b61a96c8"; + + local ret = hashes.hkdf_hmac_sha256(42, ikm, salt, info); + assert.equal(expected, hex.encode(ret)); + end); + end); + end); +end); diff --git a/spec/util_hashring_spec.lua b/spec/util_hashring_spec.lua index d8801774..4f6ec3a3 100644 --- a/spec/util_hashring_spec.lua +++ b/spec/util_hashring_spec.lua @@ -1,6 +1,7 @@ local hashring = require "util.hashring"; describe("util.hashring", function () + randomize(false); local sha256 = require "util.hashes".sha256; @@ -82,4 +83,11 @@ describe("util.hashring", function () end end); + it("should support values associated with nodes", function () + local r = hashring.new(128, sha256); + r:add_node("node1", { a = 1 }); + local node, value = r:get_node("foo"); + assert.is_equal("node1", node); + assert.same({ a = 1 }, value); + end); end); diff --git a/spec/util_iterators_spec.lua b/spec/util_iterators_spec.lua index 4cf6f19d..a4b85884 100644 --- a/spec/util_iterators_spec.lua +++ b/spec/util_iterators_spec.lua @@ -10,6 +10,14 @@ describe("util.iterators", function () end assert.same(output, expect); end); + it("should work with only a single iterator", function () + local expect = { "a", "b", "c" }; + local output = {}; + for x in iter.join(iter.values({"a", "b", "c"})) do + table.insert(output, x); + end + assert.same(output, expect); + end); end); describe("sorted_pairs", function () diff --git a/spec/util_jid_spec.lua b/spec/util_jid_spec.lua index b92ca06c..4f5fe356 100644 --- a/spec/util_jid_spec.lua +++ b/spec/util_jid_spec.lua @@ -48,6 +48,47 @@ describe("util.jid", function() end) end); + describe("#prepped_split()", function() + local function test(input_jid, expected_node, expected_server, expected_resource) + local rnode, rserver, rresource = jid.prepped_split(input_jid); + assert.are.equal(expected_node, rnode, "split("..tostring(input_jid)..") failed"); + assert.are.equal(expected_server, rserver, "split("..tostring(input_jid)..") failed"); + assert.are.equal(expected_resource, rresource, "split("..tostring(input_jid)..") failed"); + end + + it("should work", function() + -- Valid JIDs + test("node@server", "node", "server", nil ); + test("node@server/resource", "node", "server", "resource" ); + test("server", nil, "server", nil ); + test("server/resource", nil, "server", "resource" ); + test("server/resource@foo", nil, "server", "resource@foo" ); + test("server/resource@foo/bar", nil, "server", "resource@foo/bar"); + + -- Always invalid JIDs + test(nil, nil, nil, nil); + test("node@/server", nil, nil, nil); + test("@server", nil, nil, nil); + test("@server/resource", nil, nil, nil); + test("@/resource", nil, nil, nil); + test("@server/", nil, nil, nil); + test("server/", nil, nil, nil); + test("/resource", nil, nil, nil); + end); + it("should reject invalid arguments", function () + assert.has_error(function () jid.prepped_split(false) end) + end) + it("should strip empty root label", function () + test("node@server.", "node", "server", nil); + end); + it("should fail for JIDs that fail stringprep", function () + test("node@invalid-\128-server", nil, nil, nil); + test("node@invalid-\194\128-server", nil, nil, nil); + test("<invalid node>@server", nil, nil, nil); + test("node@server/invalid-\000-resource", nil, nil, nil); + end); + end); + describe("#bare()", function() it("should work", function() diff --git a/spec/util_jsonpointer_spec.lua b/spec/util_jsonpointer_spec.lua index ce07c7a1..75122296 100644 --- a/spec/util_jsonpointer_spec.lua +++ b/spec/util_jsonpointer_spec.lua @@ -21,9 +21,11 @@ describe("util.jsonpointer", function() }]]) end) it("works", function() + assert.is_nil(jp.resolve("string", "/string")) assert.same(example, jp.resolve(example, "")); assert.same({ "bar", "baz" }, jp.resolve(example, "/foo")); assert.same("bar", jp.resolve(example, "/foo/0")); + assert.same(nil, jp.resolve(example, "/foo/-")); assert.same(0, jp.resolve(example, "/")); assert.same(1, jp.resolve(example, "/a~1b")); assert.same(2, jp.resolve(example, "/c%d")); diff --git a/spec/util_jwt_spec.lua b/spec/util_jwt_spec.lua index b391a870..6946bdd3 100644 --- a/spec/util_jwt_spec.lua +++ b/spec/util_jwt_spec.lua @@ -1,4 +1,12 @@ local jwt = require "util.jwt"; +local test_keys = require "spec.inputs.test_keys"; + +local array = require "util.array"; +local iter = require "util.iterators"; +local set = require "util.set"; + +-- Ignore long lines. We have some long tokens embedded here. +--luacheck: ignore 631 describe("util.jwt", function () it("validates", function () @@ -8,6 +16,9 @@ describe("util.jwt", function () local ok, parsed = jwt.verify(key, token); assert.truthy(ok) assert.same({ payload = "this" }, parsed); + + + end); it("rejects invalid", function () local key = "secret"; @@ -16,5 +27,233 @@ describe("util.jwt", function () local ok = jwt.verify(key, token); assert.falsy(ok) end); + + local function jwt_reference_token(token) + return { + name = "jwt.io reference"; + token; + { -- payload + sub = "1234567890"; + name = "John Doe"; + admin = true; + iat = 1516239022; + }; + }; + end + + local untested_algorithms = set.new(array.collect(iter.keys(jwt._algorithms))); + + local test_cases = { + { + algorithm = "HS256"; + keys = { + { "your-256-bit-secret", "your-256-bit-secret" }; + { "another-secret", "another-secret" }; + }; + + jwt_reference_token [[eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyLCJhZG1pbiI6dHJ1ZX0.F-cvL2RcfQhUtCavIM7q7zYE8drmj2LJk0JRkrS6He4]]; + }; + { + algorithm = "HS384"; + keys = { + { "your-384-bit-secret", "your-384-bit-secret" }; + { "another-secret", "another-secret" }; + }; + + jwt_reference_token [[eyJhbGciOiJIUzM4NCIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.bQTnz6AuMJvmXXQsVPrxeQNvzDkimo7VNXxHeSBfClLufmCVZRUuyTwJF311JHuh]]; + }; + { + algorithm = "HS512"; + keys = { + { "your-512-bit-secret", "your-512-bit-secret" }; + { "another-secret", "another-secret" }; + }; + + jwt_reference_token [[eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.VFb0qJ1LRg_4ujbZoRMXnVkUgiuKq5KxWqNdbKq_G9Vvz-S1zZa9LPxtHWKa64zDl2ofkT8F6jBt_K4riU-fPg]]; + }; + { + algorithm = "ES256"; + keys = { + { test_keys.ecdsa_private_pem, test_keys.ecdsa_public_pem }; + { test_keys.alt_ecdsa_private_pem, test_keys.alt_ecdsa_public_pem }; + }; + { + name = "jwt.io reference"; + [[eyJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.tyh-VfuzIxCyGYDlkBA7DfyjrqmSHu6pQ2hoZuFqUSLPNY2N0mpHb3nk5K17HWP_3cYHBw7AhHale5wky6-sVA]]; + { -- payload + sub = "1234567890"; + name = "John Doe"; + admin = true; + iat = 1516239022; + }; + }; + }; + { + algorithm = "ES512"; + keys = { + { test_keys.ecdsa_521_private_pem, test_keys.ecdsa_521_public_pem }; + { test_keys.alt_ecdsa_521_private_pem, test_keys.alt_ecdsa_521_public_pem }; + }; + { + name = "jwt.io reference"; + [[eyJhbGciOiJFUzUxMiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.AbVUinMiT3J_03je8WTOIl-VdggzvoFgnOsdouAs-DLOtQzau9valrq-S6pETyi9Q18HH-EuwX49Q7m3KC0GuNBJAc9Tksulgsdq8GqwIqZqDKmG7hNmDzaQG1Dpdezn2qzv-otf3ZZe-qNOXUMRImGekfQFIuH_MjD2e8RZyww6lbZk]]; + { -- payload + sub = "1234567890"; + name = "John Doe"; + admin = true; + iat = 1516239022; + }; + }; + }; + { + algorithm = "RS256"; + keys = { + { test_keys.rsa_private_pem, test_keys.rsa_public_pem }; + { test_keys.alt_rsa_private_pem, test_keys.alt_rsa_public_pem }; + }; + { + name = "jwt.io reference"; + [[eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.NHVaYe26MbtOYhSKkoKYdFVomg4i8ZJd8_-RU8VNbftc4TSMb4bXP3l3YlNWACwyXPGffz5aXHc6lty1Y2t4SWRqGteragsVdZufDn5BlnJl9pdR_kdVFUsra2rWKEofkZeIC4yWytE58sMIihvo9H1ScmmVwBcQP6XETqYd0aSHp1gOa9RdUPDvoXQ5oqygTqVtxaDr6wUFKrKItgBMzWIdNZ6y7O9E0DhEPTbE9rfBo6KTFsHAZnMg4k68CDp2woYIaXbmYTWcvbzIuHO7_37GT79XdIwkm95QJ7hYC9RiwrV7mesbY4PAahERJawntho0my942XheVLmGwLMBkQ]]; + { -- payload + sub = "1234567890"; + name = "John Doe"; + admin = true; + iat = 1516239022; + }; + }; + }; + { + algorithm = "RS384"; + keys = { + { test_keys.rsa_private_pem, test_keys.rsa_public_pem }; + { test_keys.alt_rsa_private_pem, test_keys.alt_rsa_public_pem }; + }; + + jwt_reference_token [[eyJhbGciOiJSUzM4NCIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.o1hC1xYbJolSyh0-bOY230w22zEQSk5TiBfc-OCvtpI2JtYlW-23-8B48NpATozzMHn0j3rE0xVUldxShzy0xeJ7vYAccVXu2Gs9rnTVqouc-UZu_wJHkZiKBL67j8_61L6SXswzPAQu4kVDwAefGf5hyYBUM-80vYZwWPEpLI8K4yCBsF6I9N1yQaZAJmkMp_Iw371Menae4Mp4JusvBJS-s6LrmG2QbiZaFaxVJiW8KlUkWyUCns8-qFl5OMeYlgGFsyvvSHvXCzQrsEXqyCdS4tQJd73ayYA4SPtCb9clz76N1zE5WsV4Z0BYrxeb77oA7jJhh994RAPzCG0hmQ]]; + }; + { + algorithm = "RS512"; + keys = { + { test_keys.rsa_private_pem, test_keys.rsa_public_pem }; + { test_keys.alt_rsa_private_pem, test_keys.alt_rsa_public_pem }; + }; + + jwt_reference_token [[eyJhbGciOiJSUzUxMiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.jYW04zLDHfR1v7xdrW3lCGZrMIsVe0vWCfVkN2DRns2c3MN-mcp_-RE6TN9umSBYoNV-mnb31wFf8iun3fB6aDS6m_OXAiURVEKrPFNGlR38JSHUtsFzqTOj-wFrJZN4RwvZnNGSMvK3wzzUriZqmiNLsG8lktlEn6KA4kYVaM61_NpmPHWAjGExWv7cjHYupcjMSmR8uMTwN5UuAwgW6FRstCJEfoxwb0WKiyoaSlDuIiHZJ0cyGhhEmmAPiCwtPAwGeaL1yZMcp0p82cpTQ5Qb-7CtRov3N4DcOHgWYk6LomPR5j5cCkePAz87duqyzSMpCB0mCOuE3CU2VMtGeQ]]; + }; + { + algorithm = "PS256"; + keys = { + { test_keys.rsa_private_pem, test_keys.rsa_public_pem }; + { test_keys.alt_rsa_private_pem, test_keys.alt_rsa_public_pem }; + }; + + jwt_reference_token [[eyJhbGciOiJQUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.iOeNU4dAFFeBwNj6qdhdvm-IvDQrTa6R22lQVJVuWJxorJfeQww5Nwsra0PjaOYhAMj9jNMO5YLmud8U7iQ5gJK2zYyepeSuXhfSi8yjFZfRiSkelqSkU19I-Ja8aQBDbqXf2SAWA8mHF8VS3F08rgEaLCyv98fLLH4vSvsJGf6ueZSLKDVXz24rZRXGWtYYk_OYYTVgR1cg0BLCsuCvqZvHleImJKiWmtS0-CymMO4MMjCy_FIl6I56NqLE9C87tUVpo1mT-kbg5cHDD8I7MjCW5Iii5dethB4Vid3mZ6emKjVYgXrtkOQ-JyGMh6fnQxEFN1ft33GX2eRHluK9eg]]; + }; + { + algorithm = "PS384"; + keys = { + { test_keys.rsa_private_pem, test_keys.rsa_public_pem }; + { test_keys.alt_rsa_private_pem, test_keys.alt_rsa_public_pem }; + }; + + jwt_reference_token [[eyJhbGciOiJQUzM4NCIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.Lfe_aCQme_gQpUk9-6l9qesu0QYZtfdzfy08w8uqqPH_gnw-IVyQwyGLBHPFBJHMbifdSMxPjJjkCD0laIclhnBhowILu6k66_5Y2z78GHg8YjKocAvB-wSUiBhuV6hXVxE5emSjhfVz2OwiCk2bfk2hziRpkdMvfcITkCx9dmxHU6qcEIsTTHuH020UcGayB1-IoimnjTdCsV1y4CMr_ECDjBrqMdnontkqKRIM1dtmgYFsJM6xm7ewi_ksG_qZHhaoBkxQ9wq9OVQRGiSZYowCp73d2BF3jYMhdmv2JiaUz5jRvv6lVU7Quq6ylVAlSPxeov9voYHO1mgZFCY1kQ]]; + }; + { + algorithm = "PS512"; + keys = { + { test_keys.rsa_private_pem, test_keys.rsa_public_pem }; + { test_keys.alt_rsa_private_pem, test_keys.alt_rsa_public_pem }; + }; + + jwt_reference_token [[eyJhbGciOiJQUzUxMiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.J5W09-rNx0pt5_HBiydR-vOluS6oD-RpYNa8PVWwMcBDQSXiw6-EPW8iSsalXPspGj3ouQjAnOP_4-zrlUUlvUIt2T79XyNeiKuooyIFvka3Y5NnGiOUBHWvWcWp4RcQFMBrZkHtJM23sB5D7Wxjx0-HFeNk-Y3UJgeJVhg5NaWXypLkC4y0ADrUBfGAxhvGdRdULZivfvzuVtv6AzW6NRuEE6DM9xpoWX_4here-yvLS2YPiBTZ8xbB3axdM99LhES-n52lVkiX5AWg2JJkEROZzLMpaacA_xlbUz_zbIaOaoqk8gB5oO7kI6sZej3QAdGigQy-hXiRnW_L98d4GQ]]; + }; + }; + + local function do_verify_test(algorithm, verifying_key, token, expect_payload) + local verify = jwt.new_verifier(algorithm, verifying_key); + + assert.is_string(token); + local result = {verify(token)}; + if expect_payload then + assert.same({ + true; -- success + expect_payload; -- payload + }, result); + else + assert.same({ + false; + "signature-mismatch"; + }, result); + end + end + + local function do_sign_verify_test(algorithm, signing_key, verifying_key, expect_success, expect_token) + local sign = jwt.new_signer(algorithm, signing_key); + + local test_payload = { + sub = "1234567890"; + name = "John Doe"; + admin = true; + iat = 1516239022; + }; + + local token = sign(test_payload); + + if expect_token then + assert.equal(expect_token, token); + end + + do_verify_test(algorithm, verifying_key, token, expect_success and test_payload or false); + end + + + for _, algorithm_tests in ipairs(test_cases) do + local algorithm = algorithm_tests.algorithm; + local keypairs = algorithm_tests.keys; + + untested_algorithms:remove(algorithm); + + describe(algorithm, function () + describe("can do basic sign and verify", function () + for keypair_n, keypair in ipairs(keypairs) do + local signing_key, verifying_key = keypair[1], keypair[2]; + it(("(test key pair %d)"):format(keypair_n), function () + do_sign_verify_test(algorithm, signing_key, verifying_key, true); + end); + end + end); + + if #keypairs >= 2 then + it("rejects invalid tokens", function () + do_sign_verify_test(algorithm, keypairs[1][1], keypairs[2][2], false); + end); + else + pending("rejects invalid tokens", function () + error("Needs at least 2 key pairs"); + end); + end + + if #algorithm_tests > 0 then + for test_n, test_case in ipairs(algorithm_tests) do + it("can verify "..(test_case.name or (("test case %d"):format(test_n))), function () + do_verify_test( + algorithm, + test_case.verifying_key or keypairs[1][2], + test_case[1], + test_case[2] + ); + end); + end + else + pending("can verify reference tokens", function () + error("No test tokens provided"); + end); + end + end); + end + + for algorithm in untested_algorithms do + pending(algorithm.." tests", function () end); + end end); diff --git a/spec/util_paseto_spec.lua b/spec/util_paseto_spec.lua new file mode 100644 index 00000000..417278b1 --- /dev/null +++ b/spec/util_paseto_spec.lua @@ -0,0 +1,292 @@ +-- Ignore long lines in this file +--luacheck: ignore 631 + +describe("util.paseto", function () + local paseto = require "util.paseto"; + local json = require "util.json"; + local hex = require "util.hex"; + + describe("v3.local", function () + local function parse_test_cases(json_test_cases) + local input_cases = json.decode(json_test_cases); + local output_cases = {}; + for _, case in ipairs(input_cases) do + assert.is_string(case.name, "Bad test case: expected name"); + assert.is_nil(output_cases[case.name], "Bad test case: duplicate name"); + output_cases[case.name] = function () + local key = hex.decode(case.key); + local payload, err = paseto.v3_local.decrypt(case.token, key, case.footer, case["implicit-assertion"]); + if case["expect-fail"] then + assert.is_nil(payload); + else + assert.is_nil(err); + assert.same(json.decode(case.payload), payload); + end + end; + end + return output_cases; + end + + local test_cases = parse_test_cases [=[[ + { + "name": "3-E-1", + "expect-fail": false, + "key": "707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f", + "nonce": "0000000000000000000000000000000000000000000000000000000000000000", + "token": "v3.local.AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADbfcIURX_0pVZVU1mAESUzrKZAsRm2EsD6yBoZYn6cpVZNzSJOhSDN-sRaWjfLU-yn9OJH1J_B8GKtOQ9gSQlb8yk9Iza7teRdkiR89ZFyvPPsVjjFiepFUVcMa-LP18zV77f_crJrVXWa5PDNRkCSeHfBBeg", + "payload": "{\"data\":\"this is a secret message\",\"exp\":\"2022-01-01T00:00:00+00:00\"}", + "footer": "", + "implicit-assertion": "" + }, + { + "name": "3-E-2", + "expect-fail": false, + "key": "707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f", + "nonce": "0000000000000000000000000000000000000000000000000000000000000000", + "token": "v3.local.AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADbfcIURX_0pVZVU1mAESUzrKZAqhWxBMDgyBoZYn6cpVZNzSJOhSDN-sRaWjfLU-yn9OJH1J_B8GKtOQ9gSQlb8yk9IzZfaZpReVpHlDSwfuygx1riVXYVs-UjcrG_apl9oz3jCVmmJbRuKn5ZfD8mHz2db0A", + "payload": "{\"data\":\"this is a hidden message\",\"exp\":\"2022-01-01T00:00:00+00:00\"}", + "footer": "", + "implicit-assertion": "" + }, + { + "name": "3-E-3", + "expect-fail": false, + "nonce": "26f7553354482a1d91d4784627854b8da6b8042a7966523c2b404e8dbbe7f7f2", + "key": "707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f", + "token": "v3.local.JvdVM1RIKh2R1HhGJ4VLjaa4BCp5ZlI8K0BOjbvn9_LwY78vQnDait-Q-sjhF88dG2B0ROIIykcrGHn8wzPbTrqObHhyoKpjy3cwZQzLdiwRsdEK5SDvl02_HjWKJW2oqGMOQJlxnt5xyhQjFJomwnt7WW_7r2VT0G704ifult011-TgLCyQ2X8imQhniG_hAQ4BydM", + "payload": "{\"data\":\"this is a secret message\",\"exp\":\"2022-01-01T00:00:00+00:00\"}", + "footer": "", + "implicit-assertion": "" + }, + { + "name": "3-E-4", + "expect-fail": false, + "nonce": "26f7553354482a1d91d4784627854b8da6b8042a7966523c2b404e8dbbe7f7f2", + "key": "707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f", + "token": "v3.local.JvdVM1RIKh2R1HhGJ4VLjaa4BCp5ZlI8K0BOjbvn9_LwY78vQnDait-Q-sjhF88dG2B0X-4P3EcxGHn8wzPbTrqObHhyoKpjy3cwZQzLdiwRsdEK5SDvl02_HjWKJW2oqGMOQJlBZa_gOpVj4gv0M9lV6Pwjp8JS_MmaZaTA1LLTULXybOBZ2S4xMbYqYmDRhh3IgEk", + "payload": "{\"data\":\"this is a hidden message\",\"exp\":\"2022-01-01T00:00:00+00:00\"}", + "footer": "", + "implicit-assertion": "" + }, + { + "name": "3-E-5", + "expect-fail": false, + "nonce": "26f7553354482a1d91d4784627854b8da6b8042a7966523c2b404e8dbbe7f7f2", + "key": "707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f", + "token": "v3.local.JvdVM1RIKh2R1HhGJ4VLjaa4BCp5ZlI8K0BOjbvn9_LwY78vQnDait-Q-sjhF88dG2B0ROIIykcrGHn8wzPbTrqObHhyoKpjy3cwZQzLdiwRsdEK5SDvl02_HjWKJW2oqGMOQJlkYSIbXOgVuIQL65UMdW9WcjOpmqvjqD40NNzed-XPqn1T3w-bJvitYpUJL_rmihc.eyJraWQiOiJVYmtLOFk2aXY0R1poRnA2VHgzSVdMV0xmTlhTRXZKY2RUM3pkUjY1WVp4byJ9", + "payload": "{\"data\":\"this is a secret message\",\"exp\":\"2022-01-01T00:00:00+00:00\"}", + "footer": "{\"kid\":\"UbkK8Y6iv4GZhFp6Tx3IWLWLfNXSEvJcdT3zdR65YZxo\"}", + "implicit-assertion": "" + }, + { + "name": "3-E-6", + "expect-fail": false, + "nonce": "26f7553354482a1d91d4784627854b8da6b8042a7966523c2b404e8dbbe7f7f2", + "key": "707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f", + "token": "v3.local.JvdVM1RIKh2R1HhGJ4VLjaa4BCp5ZlI8K0BOjbvn9_LwY78vQnDait-Q-sjhF88dG2B0X-4P3EcxGHn8wzPbTrqObHhyoKpjy3cwZQzLdiwRsdEK5SDvl02_HjWKJW2oqGMOQJmSeEMphEWHiwtDKJftg41O1F8Hat-8kQ82ZIAMFqkx9q5VkWlxZke9ZzMBbb3Znfo.eyJraWQiOiJVYmtLOFk2aXY0R1poRnA2VHgzSVdMV0xmTlhTRXZKY2RUM3pkUjY1WVp4byJ9", + "payload": "{\"data\":\"this is a hidden message\",\"exp\":\"2022-01-01T00:00:00+00:00\"}", + "footer": "{\"kid\":\"UbkK8Y6iv4GZhFp6Tx3IWLWLfNXSEvJcdT3zdR65YZxo\"}", + "implicit-assertion": "" + }, + { + "name": "3-E-7", + "expect-fail": false, + "nonce": "26f7553354482a1d91d4784627854b8da6b8042a7966523c2b404e8dbbe7f7f2", + "key": "707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f", + "token": "v3.local.JvdVM1RIKh2R1HhGJ4VLjaa4BCp5ZlI8K0BOjbvn9_LwY78vQnDait-Q-sjhF88dG2B0ROIIykcrGHn8wzPbTrqObHhyoKpjy3cwZQzLdiwRsdEK5SDvl02_HjWKJW2oqGMOQJkzWACWAIoVa0bz7EWSBoTEnS8MvGBYHHo6t6mJunPrFR9JKXFCc0obwz5N-pxFLOc.eyJraWQiOiJVYmtLOFk2aXY0R1poRnA2VHgzSVdMV0xmTlhTRXZKY2RUM3pkUjY1WVp4byJ9", + "payload": "{\"data\":\"this is a secret message\",\"exp\":\"2022-01-01T00:00:00+00:00\"}", + "footer": "{\"kid\":\"UbkK8Y6iv4GZhFp6Tx3IWLWLfNXSEvJcdT3zdR65YZxo\"}", + "implicit-assertion": "{\"test-vector\":\"3-E-7\"}" + }, + { + "name": "3-E-8", + "expect-fail": false, + "nonce": "26f7553354482a1d91d4784627854b8da6b8042a7966523c2b404e8dbbe7f7f2", + "key": "707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f", + "token": "v3.local.JvdVM1RIKh2R1HhGJ4VLjaa4BCp5ZlI8K0BOjbvn9_LwY78vQnDait-Q-sjhF88dG2B0X-4P3EcxGHn8wzPbTrqObHhyoKpjy3cwZQzLdiwRsdEK5SDvl02_HjWKJW2oqGMOQJmZHSSKYR6AnPYJV6gpHtx6dLakIG_AOPhu8vKexNyrv5_1qoom6_NaPGecoiz6fR8.eyJraWQiOiJVYmtLOFk2aXY0R1poRnA2VHgzSVdMV0xmTlhTRXZKY2RUM3pkUjY1WVp4byJ9", + "payload": "{\"data\":\"this is a hidden message\",\"exp\":\"2022-01-01T00:00:00+00:00\"}", + "footer": "{\"kid\":\"UbkK8Y6iv4GZhFp6Tx3IWLWLfNXSEvJcdT3zdR65YZxo\"}", + "implicit-assertion": "{\"test-vector\":\"3-E-8\"}" + }, + { + "name": "3-E-9", + "expect-fail": false, + "nonce": "26f7553354482a1d91d4784627854b8da6b8042a7966523c2b404e8dbbe7f7f2", + "key": "707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f", + "token": "v3.local.JvdVM1RIKh2R1HhGJ4VLjaa4BCp5ZlI8K0BOjbvn9_LwY78vQnDait-Q-sjhF88dG2B0X-4P3EcxGHn8wzPbTrqObHhyoKpjy3cwZQzLdiwRsdEK5SDvl02_HjWKJW2oqGMOQJlk1nli0_wijTH_vCuRwckEDc82QWK8-lG2fT9wQF271sgbVRVPjm0LwMQZkvvamqU.YXJiaXRyYXJ5LXN0cmluZy10aGF0LWlzbid0LWpzb24", + "payload": "{\"data\":\"this is a hidden message\",\"exp\":\"2022-01-01T00:00:00+00:00\"}", + "footer": "arbitrary-string-that-isn't-json", + "implicit-assertion": "{\"test-vector\":\"3-E-9\"}" + }, + { + "name": "3-F-3", + "expect-fail": true, + "nonce": "26f7553354482a1d91d4784627854b8da6b8042a7966523c2b404e8dbbe7f7f2", + "key": "707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f", + "token": "v4.local.1JgN1UG8TFAYS49qsx8rxlwh-9E4ONUm3slJXYi5EibmzxpF0Q-du6gakjuyKCBX8TvnSLOKqCPu8Yh3WSa5yJWigPy33z9XZTJF2HQ9wlLDPtVn_Mu1pPxkTU50ZaBKblJBufRA.YXJiaXRyYXJ5LXN0cmluZy10aGF0LWlzbid0LWpzb24", + "payload": null, + "footer": "arbitrary-string-that-isn't-json", + "implicit-assertion": "{\"test-vector\":\"3-F-3\"}" + }, + { + "name": "3-F-4", + "expect-fail": true, + "key": "707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f", + "nonce": "0000000000000000000000000000000000000000000000000000000000000000", + "token": "v3.local.AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADbfcIURX_0pVZVU1mAESUzrKZAsRm2EsD6yBoZYn6cpVZNzSJOhSDN-sRaWjfLU-yn9OJH1J_B8GKtOQ9gSQlb8yk9Iza7teRdkiR89ZFyvPPsVjjFiepFUVcMa-LP18zV77f_crJrVXWa5PDNRkCSeHfBBeh", + "payload": null, + "footer": "", + "implicit-assertion": "" + }, + { + "name": "3-F-5", + "expect-fail": true, + "nonce": "26f7553354482a1d91d4784627854b8da6b8042a7966523c2b404e8dbbe7f7f2", + "key": "707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f", + "token": "v3.local.JvdVM1RIKh2R1HhGJ4VLjaa4BCp5ZlI8K0BOjbvn9_LwY78vQnDait-Q-sjhF88dG2B0ROIIykcrGHn8wzPbTrqObHhyoKpjy3cwZQzLdiwRsdEK5SDvl02_HjWKJW2oqGMOQJlkYSIbXOgVuIQL65UMdW9WcjOpmqvjqD40NNzed-XPqn1T3w-bJvitYpUJL_rmihc=.eyJraWQiOiJVYmtLOFk2aXY0R1poRnA2VHgzSVdMV0xmTlhTRXZKY2RUM3pkUjY1WVp4byJ9", + "payload": null, + "footer": "{\"kid\":\"UbkK8Y6iv4GZhFp6Tx3IWLWLfNXSEvJcdT3zdR65YZxo\"}", + "implicit-assertion": "" + } + ]]=]; + for name, test in pairs(test_cases) do + it("test case "..name, test); + end + + describe("basic sign/verify", function () + local key = paseto.v3_local.new_key(); + local sign, verify = paseto.v3_local.init(key); + + --luacheck: ignore 211/sign2 + local key2 = paseto.v3_local.new_key(); + local sign2, verify2 = paseto.v3_local.init(key2); + + it("works", function () + local payload = { foo = "hello world", b = { 1, 2, 3 } }; + + local tok = sign(payload); + assert.same(payload, verify(tok)); + assert.is_nil(verify2(tok)); + end); + + it("rejects tokens if implicit assertion fails", function () + local payload = { foo = "hello world", b = { 1, 2, 3 } }; + local tok = sign(payload, nil, "my-custom-assertion"); + assert.is_nil(verify(tok, nil, "my-incorrect-assertion")); + assert.is_nil(verify(tok, nil, nil)); + assert.same(payload, verify(tok, nil, "my-custom-assertion")); + end); + end); + end); + + describe("v4.public", function () + local function parse_test_cases(json_test_cases) + local input_cases = json.decode(json_test_cases); + local output_cases = {}; + for _, case in ipairs(input_cases) do + assert.is_string(case.name, "Bad test case: expected name"); + assert.is_nil(output_cases[case.name], "Bad test case: duplicate name"); + output_cases[case.name] = function () + local verify_key = paseto.v4_public.import_public_key(case["public-key-pem"]); + local payload, err = paseto.v4_public.verify(case.token, verify_key, case.footer, case["implicit-assertion"]); + if case["expect-fail"] then + assert.is_nil(payload); + else + assert.is_nil(err); + assert.same(json.decode(case.payload), payload); + end + end; + end + return output_cases; + end + + local test_cases = parse_test_cases [=[[ + { + "name": "4-S-1", + "expect-fail": false, + "public-key": "1eb9dbbbbc047c03fd70604e0071f0987e16b28b757225c11f00415d0e20b1a2", + "secret-key": "b4cbfb43df4ce210727d953e4a713307fa19bb7d9f85041438d9e11b942a37741eb9dbbbbc047c03fd70604e0071f0987e16b28b757225c11f00415d0e20b1a2", + "secret-key-seed": "b4cbfb43df4ce210727d953e4a713307fa19bb7d9f85041438d9e11b942a3774", + "secret-key-pem": "-----BEGIN PRIVATE KEY-----\nMC4CAQAwBQYDK2VwBCIEILTL+0PfTOIQcn2VPkpxMwf6Gbt9n4UEFDjZ4RuUKjd0\n-----END PRIVATE KEY-----", + "public-key-pem": "-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEAHrnbu7wEfAP9cGBOAHHwmH4Wsot1ciXBHwBBXQ4gsaI=\n-----END PUBLIC KEY-----", + "token": "v4.public.eyJkYXRhIjoidGhpcyBpcyBhIHNpZ25lZCBtZXNzYWdlIiwiZXhwIjoiMjAyMi0wMS0wMVQwMDowMDowMCswMDowMCJ9bg_XBBzds8lTZShVlwwKSgeKpLT3yukTw6JUz3W4h_ExsQV-P0V54zemZDcAxFaSeef1QlXEFtkqxT1ciiQEDA", + "payload": "{\"data\":\"this is a signed message\",\"exp\":\"2022-01-01T00:00:00+00:00\"}", + "footer": "", + "implicit-assertion": "" + }, + { + "name": "4-S-2", + "expect-fail": false, + "public-key": "1eb9dbbbbc047c03fd70604e0071f0987e16b28b757225c11f00415d0e20b1a2", + "secret-key": "b4cbfb43df4ce210727d953e4a713307fa19bb7d9f85041438d9e11b942a37741eb9dbbbbc047c03fd70604e0071f0987e16b28b757225c11f00415d0e20b1a2", + "secret-key-seed": "b4cbfb43df4ce210727d953e4a713307fa19bb7d9f85041438d9e11b942a3774", + "secret-key-pem": "-----BEGIN PRIVATE KEY-----\nMC4CAQAwBQYDK2VwBCIEILTL+0PfTOIQcn2VPkpxMwf6Gbt9n4UEFDjZ4RuUKjd0\n-----END PRIVATE KEY-----", + "public-key-pem": "-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEAHrnbu7wEfAP9cGBOAHHwmH4Wsot1ciXBHwBBXQ4gsaI=\n-----END PUBLIC KEY-----", + "token": "v4.public.eyJkYXRhIjoidGhpcyBpcyBhIHNpZ25lZCBtZXNzYWdlIiwiZXhwIjoiMjAyMi0wMS0wMVQwMDowMDowMCswMDowMCJ9v3Jt8mx_TdM2ceTGoqwrh4yDFn0XsHvvV_D0DtwQxVrJEBMl0F2caAdgnpKlt4p7xBnx1HcO-SPo8FPp214HDw.eyJraWQiOiJ6VmhNaVBCUDlmUmYyc25FY1Q3Z0ZUaW9lQTlDT2NOeTlEZmdMMVc2MGhhTiJ9", + "payload": "{\"data\":\"this is a signed message\",\"exp\":\"2022-01-01T00:00:00+00:00\"}", + "footer": "{\"kid\":\"zVhMiPBP9fRf2snEcT7gFTioeA9COcNy9DfgL1W60haN\"}", + "implicit-assertion": "" + }, + { + "name": "4-S-3", + "expect-fail": false, + "public-key": "1eb9dbbbbc047c03fd70604e0071f0987e16b28b757225c11f00415d0e20b1a2", + "secret-key": "b4cbfb43df4ce210727d953e4a713307fa19bb7d9f85041438d9e11b942a37741eb9dbbbbc047c03fd70604e0071f0987e16b28b757225c11f00415d0e20b1a2", + "secret-key-seed": "b4cbfb43df4ce210727d953e4a713307fa19bb7d9f85041438d9e11b942a3774", + "secret-key-pem": "-----BEGIN PRIVATE KEY-----\nMC4CAQAwBQYDK2VwBCIEILTL+0PfTOIQcn2VPkpxMwf6Gbt9n4UEFDjZ4RuUKjd0\n-----END PRIVATE KEY-----", + "public-key-pem": "-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEAHrnbu7wEfAP9cGBOAHHwmH4Wsot1ciXBHwBBXQ4gsaI=\n-----END PUBLIC KEY-----", + "token": "v4.public.eyJkYXRhIjoidGhpcyBpcyBhIHNpZ25lZCBtZXNzYWdlIiwiZXhwIjoiMjAyMi0wMS0wMVQwMDowMDowMCswMDowMCJ9NPWciuD3d0o5eXJXG5pJy-DiVEoyPYWs1YSTwWHNJq6DZD3je5gf-0M4JR9ipdUSJbIovzmBECeaWmaqcaP0DQ.eyJraWQiOiJ6VmhNaVBCUDlmUmYyc25FY1Q3Z0ZUaW9lQTlDT2NOeTlEZmdMMVc2MGhhTiJ9", + "payload": "{\"data\":\"this is a signed message\",\"exp\":\"2022-01-01T00:00:00+00:00\"}", + "footer": "{\"kid\":\"zVhMiPBP9fRf2snEcT7gFTioeA9COcNy9DfgL1W60haN\"}", + "implicit-assertion": "{\"test-vector\":\"4-S-3\"}" + }]]=]; + for name, test in pairs(test_cases) do + it("test case "..name, test); + end + + describe("basic sign/verify", function () + local function new_keypair() + local kp = paseto.v4_public.new_keypair(); + return kp:private_pem(), kp:public_pem(); + end + + local privkey1, pubkey1 = new_keypair(); + local privkey2, pubkey2 = new_keypair(); + local sign1, verify1 = paseto.v4_public.init(privkey1, pubkey1); + local sign2, verify2 = paseto.v4_public.init(privkey2, pubkey2); + + it("works", function () + local payload = { foo = "hello world", b = { 1, 2, 3 } }; + + local tok1 = sign1(payload); + assert.same(payload, verify1(tok1)); + assert.is_nil(verify2(tok1)); + + local tok2 = sign2(payload); + assert.same(payload, verify2(tok2)); + assert.is_nil(verify1(tok2)); + end); + + it("rejects tokens if implicit assertion fails", function () + local payload = { foo = "hello world", b = { 1, 2, 3 } }; + local tok = sign1(payload, nil, "my-custom-assertion"); + assert.is_nil(verify1(tok, nil, "my-incorrect-assertion")); + assert.is_nil(verify1(tok, nil, nil)); + assert.same(payload, verify1(tok, nil, "my-custom-assertion")); + end); + end); + end); + + describe("pae", function () + it("encodes correctly", function () + -- These test cases are taken from the PASETO docs + -- https://github.com/paseto-standard/paseto-spec/blob/master/docs/01-Protocol-Versions/Common.md + assert.equal("\x00\x00\x00\x00\x00\x00\x00\x00", paseto.pae{}); + assert.equal("\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", paseto.pae{''}); + assert.equal("\x01\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00test", paseto.pae{'test'}); + assert.has_errors(function () + paseto.pae("test"); + end); + end); + end); +end); diff --git a/spec/util_poll_spec.lua b/spec/util_poll_spec.lua index a763be90..05318453 100644 --- a/spec/util_poll_spec.lua +++ b/spec/util_poll_spec.lua @@ -1,6 +1,35 @@ -describe("util.poll", function () - it("loads", function () - require "util.poll" +describe("util.poll", function() + local poll; + setup(function() + poll = require "util.poll"; end); + it("loads", function() + assert.is_table(poll); + assert.is_function(poll.new); + assert.is_string(poll.api); + end); + describe("new", function() + local p; + setup(function() + p = poll.new(); + end) + it("times out", function () + local fd, err = p:wait(0); + assert.falsy(fd); + assert.equal("timeout", err); + end); + it("works", function() + -- stdout should be writable, right? + assert.truthy(p:add(1, false, true)); + local fd, r, w = p:wait(1); + assert.is_number(fd); + assert.is_boolean(r); + assert.is_boolean(w); + assert.equal(1, fd); + assert.falsy(r); + assert.truthy(w); + assert.truthy(p:del(1)); + end); + end) end); diff --git a/spec/util_promise_spec.lua b/spec/util_promise_spec.lua index 597b56f8..91b3d2d1 100644 --- a/spec/util_promise_spec.lua +++ b/spec/util_promise_spec.lua @@ -7,6 +7,11 @@ describe("util.promise", function () assert(promise.new()); end); end); + it("supplies a sensible tostring()", function () + local s = tostring(promise.new()); + assert.truthy(s:find("promise", 1, true)); + assert.truthy(s:find("pending", 1, true)); + end); it("notifies immediately for fulfilled promises", function () local p = promise.new(function (resolve) resolve("foo"); @@ -30,6 +35,27 @@ describe("util.promise", function () r("foo"); assert.spy(cb).was_called(1); end); + it("ignores resolve/reject of settled promises", function () + local res, rej; + local p = promise.new(function (resolve, reject) + res, rej = resolve, reject; + end); + local cb = spy.new(function (v) + assert.equal("foo", v); + end); + p:next(cb, cb); + assert.spy(cb).was_called(0); + res("foo"); + assert.spy(cb).was_called(1); + rej("bar"); + assert.spy(cb).was_called(1); + rej(promise.resolve("bar")); + assert.spy(cb).was_called(1); + res(promise.reject("bar")); + assert.spy(cb).was_called(1); + res(promise.resolve("bar")); + assert.spy(cb).was_called(1); + end); it("allows chaining :next() calls", function () local r; local result; @@ -438,6 +464,26 @@ describe("util.promise", function () { status = "rejected", reason = "this fails" }; }, result); end); + it("works when all promises reject", function () + local r1, r2; + local p1, p2 = promise.new(function (_, reject) r1 = reject end), promise.new(function (_, reject) r2 = reject end); + local p = promise.all_settled({ p1, p2 }); + + local result; + local cb = spy.new(function (v) + result = v; + end); + p:next(cb); + assert.spy(cb).was_called(0); + r2("this fails"); + assert.spy(cb).was_called(0); + r1("this fails too"); + assert.spy(cb).was_called(1); + assert.same({ + { status = "rejected", reason = "this fails too" }; + { status = "rejected", reason = "this fails" }; + }, result); + end); it("works with non-numeric keys", function () local r1, r2; local p1, p2 = promise.new(function (resolve) r1 = resolve end), promise.new(function (resolve) r2 = resolve end); diff --git a/spec/util_roles_spec.lua b/spec/util_roles_spec.lua new file mode 100644 index 00000000..44d2a977 --- /dev/null +++ b/spec/util_roles_spec.lua @@ -0,0 +1,134 @@ +describe("util.roles", function () + randomize(false); + local roles; + it("can be loaded", function () + roles = require "util.roles"; + end); + local test_role; + it("can create a new role", function () + test_role = roles.new(); + assert.is_not_nil(test_role); + assert.is_truthy(roles.is_role(test_role)); + end); + describe("role object", function () + it("can be initialized with permissions", function () + local test_role_2 = roles.new({ + permissions = { + perm1 = true; + perm2 = false; + }; + }); + assert.truthy(test_role_2:may("perm1")); + assert.falsy(test_role_2:may("perm2")); + end); + it("has a sensible tostring", function () + local test_role_2 = roles.new({ + id = "test-role-2"; + name = "Test Role 2"; + }); + assert.truthy(tostring(test_role_2):find(test_role_2.id, 1, true)); + assert.truthy(tostring(test_role_2):find("Test Role 2", 1, true)); + end); + it("is restrictive by default", function () + assert.falsy(test_role:may("my-permission")); + end); + it("allows you to set permissions", function () + test_role:set_permission("my-permission", true); + assert.truthy(test_role:may("my-permission")); + end); + it("allows you to set negative permissions", function () + test_role:set_permission("my-other-permission", false); + assert.falsy(test_role:may("my-other-permission")); + end); + it("does not allows you to override previously set permissions by default", function () + local ok, err = test_role:set_permission("my-permission", false); + assert.falsy(ok); + assert.is_equal("policy-already-exists", err); + -- Confirm old permission still in place + assert.truthy(test_role:may("my-permission")); + end); + it("allows you to explicitly override previously set permissions", function () + assert.truthy(test_role:set_permission("my-permission", false, true)); + assert.falsy(test_role:may("my-permission")); + end); + describe("inheritance", function () + local child_role; + it("works", function () + test_role:set_permission("inherited-permission", true); + child_role = roles.new({ + inherits = { test_role }; + }); + assert.truthy(child_role:may("inherited-permission")); + assert.falsy(child_role:may("my-permission")); + end); + it("allows listing policies", function () + local expected = { + ["my-permission"] = false; + ["my-other-permission"] = false; + ["inherited-permission"] = true; + }; + local received = {}; + for permission_name, permission_policy in child_role:policies() do + received[permission_name] = permission_policy; + end + assert.same(expected, received); + end); + it("supports multiple depths of inheritance", function () + local grandchild_role = roles.new({ + inherits = { child_role }; + }); + assert.truthy(grandchild_role:may("inherited-permission")); + end); + describe("supports ordered inheritance from multiple roles", function () + local parent_role = roles.new(); + local final_role = roles.new({ + -- Yes, the names are getting confusing. + -- btw, test_role is inherited through child_role. + inherits = { parent_role, child_role }; + }); + + local test_cases = { + -- { <final_role policy>, <parent_role policy>, <test_role policy> } + { true, nil, false, result = true }; + { nil, false, true, result = false }; + { nil, true, false, result = true }; + { nil, nil, false, result = false }; + { nil, nil, true, result = true }; + }; + + for n, test_case in ipairs(test_cases) do + it("(case "..n..")", function () + local perm_name = ("multi-inheritance-perm-%d"):format(n); + assert.truthy(final_role:set_permission(perm_name, test_case[1])); + assert.truthy(parent_role:set_permission(perm_name, test_case[2])); + assert.truthy(test_role:set_permission(perm_name, test_case[3])); + assert.equal(test_case.result, final_role:may(perm_name)); + end); + end + end); + it("updates child roles when parent roles change", function () + assert.truthy(child_role:may("inherited-permission")); + assert.truthy(test_role:set_permission("inherited-permission", false, true)); + assert.falsy(child_role:may("inherited-permission")); + end); + end); + describe("cloning", function () + local cloned_role; + it("works", function () + assert.truthy(test_role:set_permission("perm-1", true)); + cloned_role = test_role:clone(); + assert.truthy(cloned_role:may("perm-1")); + end); + it("isolates changes", function () + -- After cloning, changes in either the original or the clone + -- should not appear in the other. + assert.truthy(test_role:set_permission("perm-1", false, true)); + assert.truthy(test_role:set_permission("perm-2", true)); + assert.truthy(cloned_role:set_permission("perm-3", true)); + assert.truthy(cloned_role:may("perm-1")); + assert.falsy(cloned_role:may("perm-2")); + assert.falsy(test_role:may("perm-3")); + end); + end); + end); +end); diff --git a/spec/util_smqueue_spec.lua b/spec/util_smqueue_spec.lua index 0a02a60b..858b99c0 100644 --- a/spec/util_smqueue_spec.lua +++ b/spec/util_smqueue_spec.lua @@ -5,6 +5,9 @@ describe("util.smqueue", function() describe("#new()", function() it("should work", function() + assert.has_error(function () smqueue.new(-1) end); + assert.has_error(function () smqueue.new(0) end); + assert.not_has_error(function () smqueue.new(1) end); local q = smqueue.new(10); assert.truthy(q); end) diff --git a/spec/util_stanza_spec.lua b/spec/util_stanza_spec.lua index 8732a111..e11c70dd 100644 --- a/spec/util_stanza_spec.lua +++ b/spec/util_stanza_spec.lua @@ -314,6 +314,20 @@ describe("util.stanza", function() end) end) + describe("#add_error()", function () + describe("basics", function () + local s = st.stanza("custom", { xmlns = "urn:example:foo" }); + local e = s:add_error("cancel", "not-acceptable", "UNACCEPTABLE!!!! ONE MILLION YEARS DUNGEON!") + :tag("dungeon", { xmlns = "urn:uuid:c9026187-5b05-4e70-b265-c3b6338a7d0f", period="1000000years"}); + assert.equal(s, e); + local typ, cond, text, extra = e:get_error(); + assert.equal("cancel", typ); + assert.equal("not-acceptable", cond); + assert.equal("UNACCEPTABLE!!!! ONE MILLION YEARS DUNGEON!", text); + assert.is_nil(extra); + end) + end) + describe("should reject #invalid", function () local invalid_names = { ["empty string"] = "", ["characters"] = "<>"; diff --git a/spec/util_table_spec.lua b/spec/util_table_spec.lua index 76f54b69..a0535c08 100644 --- a/spec/util_table_spec.lua +++ b/spec/util_table_spec.lua @@ -12,6 +12,17 @@ describe("util.table", function () assert.same({ "lorem", "ipsum", "dolor", "sit", "amet", n = 5 }, u_table.pack("lorem", "ipsum", "dolor", "sit", "amet")); end); end); + + describe("move()", function () + it("works", function () + local t1 = { "apple", "banana", "carrot" }; + local t2 = { "cat", "donkey", "elephant" }; + local t3 = {}; + u_table.move(t1, 1, 3, 1, t3); + u_table.move(t2, 1, 3, 3, t3); + assert.same({ "apple", "banana", "cat", "donkey", "elephant" }, t3); + end); + end); end); diff --git a/spec/util_uuid_spec.lua b/spec/util_uuid_spec.lua index 95ae0a20..46400d00 100644 --- a/spec/util_uuid_spec.lua +++ b/spec/util_uuid_spec.lua @@ -5,7 +5,7 @@ local uuid = require "util.uuid"; describe("util.uuid", function() describe("#generate()", function() it("should work follow the UUID pattern", function() - -- https://tools.ietf.org/html/rfc4122#section-4.4 + -- https://www.rfc-editor.org/rfc/rfc4122.html#section-4.4 local pattern = "^" .. table.concat({ string.rep("%x", 8), diff --git a/teal-src/core/storagemanager.d.tl b/teal-src/core/storagemanager.d.tl new file mode 100644 index 00000000..3c8253b1 --- /dev/null +++ b/teal-src/core/storagemanager.d.tl @@ -0,0 +1,74 @@ +-- Storage local record API Description +-- +-- This is written as a TypedLua description + +-- Key-Value stores (the default) + +local stanza = require"util.stanza".stanza_t + +local record keyval_store + get : function ( keyval_store, string ) : any , string + set : function ( keyval_store, string, any ) : boolean, string +end + +-- Map stores (key-key-value stores) + +local record map_store + get : function ( map_store, string, any ) : any, string + set : function ( map_store, string, any, any ) : boolean, string + set_keys : function ( map_store, string, { any : any }) : boolean, string + remove : table +end + +-- Archive stores + +local record archive_query + start : number -- timestamp + ["end"]: number -- timestamp + with : string + after : string -- archive id + before : string -- archive id + total : boolean +end + +local record archive_store + -- Optional set of capabilities + caps : { + -- Optional total count of matching items returned as second return value from :find() + string : any + } + + -- Add to the archive + append : function ( archive_store, string, string, any, number, string ) : string, string + + -- Iterate over archive + type iterator = function () : string, any, number, string + find : function ( archive_store, string, archive_query ) : iterator, integer + + -- Removal of items. API like find. Optional + delete : function ( archive_store, string, archive_query ) : boolean | number, string + + -- Array of dates which do have messages (Optional) + dates : function ( archive_store, string ) : { string }, string + + -- Map of counts per "with" field + summary : function ( archive_store, string, archive_query ) : { string : integer }, string + + -- Map-store API + get : function ( archive_store, string, string ) : stanza, number, string + get : function ( archive_store, string, string ) : nil, string + set : function ( archive_store, string, string, stanza, number, string ) : boolean, string +end + +-- This represents moduleapi +local record coremodule + -- If the first string is omitted then the name of the module is used + -- The second string is one of "keyval" (default), "map" or "archive" + open_store : function (archive_store, string, string) : keyval_store, string + open_store : function (archive_store, string, string) : map_store, string + open_store : function (archive_store, string, string) : archive_store, string + + -- Other module methods omitted +end + +return coremodule diff --git a/teal-src/module.d.tl b/teal-src/module.d.tl index 67b2437c..24eb9558 100644 --- a/teal-src/module.d.tl +++ b/teal-src/module.d.tl @@ -62,7 +62,12 @@ global record moduleapi send_iq : function (moduleapi, st.stanza_t, util_session, number) broadcast : function (moduleapi, { string }, st.stanza_t, function) type timer_callback = function (number, ... : any) : number - add_timer : function (moduleapi, number, timer_callback, ... : any) + record timer_wrapper + stop : function (timer_wrapper) + disarm : function (timer_wrapper) + reschedule : function (timer_wrapper, number) + end + add_timer : function (moduleapi, number, timer_callback, ... : any) : timer_wrapper get_directory : function (moduleapi) : string enum file_mode "r" "w" "a" "r+" "w+" "a+" @@ -121,6 +126,11 @@ global record moduleapi path : string resource_path : string + -- access control + may : function (moduleapi, string, table|string) + default_permission : function (string, string) + default_permissions : function (string, { string }) + -- methods the module can add load : function () add_host : function (moduleapi) diff --git a/teal-src/net/http.d.tl b/teal-src/net/http.d.tl new file mode 100644 index 00000000..9135ec12 --- /dev/null +++ b/teal-src/net/http.d.tl @@ -0,0 +1,86 @@ +local Promise = require "util.promise".Promise; + +local record sslctx -- from LuaSec +end + +local record lib + + enum http_method + "GET" + "HEAD" + "POST" + "PUT" + "OPTIONS" + "DELETE" + -- etc? + end + + record http_client_options + sslctx : sslctx + end + + record http_options + id : string + onlystatus : boolean + body : string + method : http_method + headers : { string : string } + insecure : boolean + suppress_errors : boolean + streaming_handler : function + suppress_url : boolean + sslctx : sslctx + end + + record http_request + host : string + port : string + enum scheme + "http" + "https" + end + scheme : scheme + url : string + userinfo : string + path : string + + method : http_method + headers : { string : string } + + insecure : boolean + suppress_errors : boolean + streaming_handler : function + http : http_client + time : integer + id : string + callback : http_callback + end + + record http_response + end + + type http_callback = function (string, number, http_response, http_request) + + record http_client + options : http_client_options + request : function (http_client, string, http_options, http_callback) + end + + request : function (string, http_options, http_callback) : Promise, string + default : http_client + new : function (http_client_options) : http_client + events : table + -- COMPAT + urlencode : function (string) : string + urldecode : function (string) : string + formencode : function ({ string : string }) : string + formdecode : function (string) : { string : string } + destroy_request : function (http_request) + + enum available_features + "sni" + end + features : { available_features : boolean } +end + +return lib diff --git a/teal-src/net/http/codes.d.tl b/teal-src/net/http/codes.d.tl new file mode 100644 index 00000000..65d004fc --- /dev/null +++ b/teal-src/net/http/codes.d.tl @@ -0,0 +1,2 @@ +local type response_codes = { integer : string } +return response_codes diff --git a/teal-src/net/http/errors.d.tl b/teal-src/net/http/errors.d.tl new file mode 100644 index 00000000..a9b6ea6c --- /dev/null +++ b/teal-src/net/http/errors.d.tl @@ -0,0 +1,22 @@ +local record http_errors + enum known_conditions + "cancelled" + "connection-closed" + "certificate-chain-invalid" + "certificate-verify-failed" + "connection failed" + "invalid-url" + "unable to resolve service" + end + type registry_keys = known_conditions | integer + record error + type : string + condition : string + code : integer + text : string + end + registry : { registry_keys : error } + new : function (integer, known_conditions, table) + new : function (integer, string, table) +end +return http_errors diff --git a/teal-src/net/http/files.d.tl b/teal-src/net/http/files.d.tl new file mode 100644 index 00000000..d0ba5c1c --- /dev/null +++ b/teal-src/net/http/files.d.tl @@ -0,0 +1,14 @@ +local record serve_options + path : string + mime_map : { string : string } + cache_size : integer + cache_max_file_size : integer + index_files : { string } + directory_index : boolean +end + +local record http_files + serve : function(serve_options|string) : function +end + +return http_files diff --git a/teal-src/net/http/parser.d.tl b/teal-src/net/http/parser.d.tl new file mode 100644 index 00000000..1cd6ccf4 --- /dev/null +++ b/teal-src/net/http/parser.d.tl @@ -0,0 +1,58 @@ +local record httpstream + feed : function(httpstream, string) +end + +local type sink_cb = function () + +local record httppacket + enum http_method + "HEAD" + "GET" + "POST" + "PUT" + "DELETE" + "OPTIONS" + -- etc + end + method : http_method + record url_details + path : string + query : string + end + url : url_details + path : string + enum http_version + "1.0" + "1.1" + end + httpversion : http_version + headers : { string : string } + body : string | boolean + body_sink : sink_cb + chunked : boolean + partial : boolean +end + +local enum error_conditions + "cancelled" + "connection-closed" + "certificate-chain-invalid" + "certificate-verify-failed" + "connection failed" + "invalid-url" + "unable to resolve service" +end + +local type success_cb = function (httppacket) +local type error_cb = function (error_conditions) + +local enum stream_mode + "client" + "server" +end + +local record lib + new : function (success_cb, error_cb, stream_mode) : httpstream +end + +return lib diff --git a/teal-src/net/http/server.d.tl b/teal-src/net/http/server.d.tl new file mode 100644 index 00000000..5a83af7e --- /dev/null +++ b/teal-src/net/http/server.d.tl @@ -0,0 +1,6 @@ + +local record http_server + -- TODO +end + +return http_server diff --git a/teal-src/net/server.d.tl b/teal-src/net/server.d.tl new file mode 100644 index 00000000..bb61f677 --- /dev/null +++ b/teal-src/net/server.d.tl @@ -0,0 +1,65 @@ +local record server + record LuaSocketTCP + end + record LuaSecCTX + end + + record extra_settings + end + + record interface + end + enum socket_type + "tcp" + "tcp6" + "tcp4" + end + + record listeners + onconnect : function (interface) + ondetach : function (interface) + onattach : function (interface, string) + onincoming : function (interface, string, string) + ondrain : function (interface) + onreadtimeout : function (interface) + onstarttls : function (interface) + onstatus : function (interface, string) + ondisconnect : function (interface, string) + end + + get_backend : function () : string + + type port = string | integer + enum read_mode + "*a" + "*l" + end + type read_size = read_mode | integer + addserver : function (string, port, listeners, read_size, LuaSecCTX) : interface + addclient : function (string, port, listeners, read_size, LuaSecCTX, socket_type, extra_settings) : interface + record listen_config + read_size : read_size + tls_ctx : LuaSecCTX + tls_direct : boolean + sni_hosts : { string : LuaSecCTX } + end + listen : function (string, port, listeners, listen_config) : interface + enum quitting + "quitting" + end + loop : function () : quitting + closeall : function () + setquitting : function (boolean | quitting) + + wrapclient : function (LuaSocketTCP, string, port, listeners, read_size, LuaSecCTX, extra_settings) : interface + wrapserver : function (LuaSocketTCP, string, port, listeners, listen_config) : interface + watchfd : function (integer | LuaSocketTCP, function (interface), function (interface)) : interface + link : function () + + record config + end + set_config : function (config) + +end + +return server diff --git a/teal-src/plugins/mod_cron.tl b/teal-src/plugins/mod_cron.tl index f3b8f62f..7fa2a36b 100644 --- a/teal-src/plugins/mod_cron.tl +++ b/teal-src/plugins/mod_cron.tl @@ -88,8 +88,8 @@ local function run_task(task : task_spec) task:save(started_at); end -local task_runner = async.runner(run_task); -module:add_timer(1, function() : integer +local task_runner : async.runner_t<task_spec> = async.runner(run_task); +scheduled = module:add_timer(1, function() : integer module:log("info", "Running periodic tasks"); local delay = 3600; for host in pairs(active_hosts) do diff --git a/teal-src/util/array.d.tl b/teal-src/util/array.d.tl new file mode 100644 index 00000000..70bf2624 --- /dev/null +++ b/teal-src/util/array.d.tl @@ -0,0 +1,9 @@ +local record array_t<T> + { T } +end + +local record lib + metamethod __call : function () : array_t +end + +return lib diff --git a/teal-src/util/async.d.tl b/teal-src/util/async.d.tl new file mode 100644 index 00000000..a2e41cd6 --- /dev/null +++ b/teal-src/util/async.d.tl @@ -0,0 +1,42 @@ +local record lib + ready : function () : boolean + waiter : function (num : integer, allow_many : boolean) : function (), function () + guarder : function () : function (id : function ()) : function () | nil + record runner_t<T> + func : function (T) + thread : thread + enum state_e + -- from Lua manual + "running" + "suspended" + "normal" + "dead" + + -- from util.async + "ready" + "error" + end + state : state_e + notified_state : state_e + queue : { T } + type watcher_t = function (runner_t<T>, ... : any) + type watchers_t = { state_e : watcher_t } + data : any + id : string + + run : function (runner_t<T>, T) : boolean, state_e, integer + enqueue : function (runner_t<T>, T) : runner_t<T> + log : function (runner_t<T>, string, string, ... : any) + onready : function (runner_t<T>, function) : runner_t<T> + onready : function (runner_t<T>, function) : runner_t<T> + onwaiting : function (runner_t<T>, function) : runner_t<T> + onerror : function (runner_t<T>, function) : runner_t<T> + end + runner : function <T>(function (T), runner_t.watchers_t, any) : runner_t<T> + wait_for : function (any) : any, any + sleep : function (t:number) + + -- set_nexttick = function(new_next_tick) next_tick = new_next_tick; end; + -- set_schedule_function = function (new_schedule_function) schedule_task = new_schedule_function; end; +end +return lib diff --git a/teal-src/util/bitcompat.d.tl b/teal-src/util/bitcompat.d.tl new file mode 100644 index 00000000..18adf725 --- /dev/null +++ b/teal-src/util/bitcompat.d.tl @@ -0,0 +1,8 @@ +local record lib + band : function (integer, integer, ... : integer) : integer + bor : function (integer, integer, ... : integer) : integer + bxor : function (integer, integer, ... : integer) : integer + lshift : function (integer, integer) : integer + rshift : function (integer, integer) : integer +end +return lib diff --git a/teal-src/util/crypto.d.tl b/teal-src/util/crypto.d.tl new file mode 100644 index 00000000..cf0b0d1b --- /dev/null +++ b/teal-src/util/crypto.d.tl @@ -0,0 +1,29 @@ +local record lib + record key + private_pem : function (key) : string + public_pem : function (key) : string + get_type : function (key) : string + end + + generate_ed25519_keypair : function () : key + ed25519_sign : function (key, string) : string + ed25519_verify : function (key, string, string) : boolean + + ecdsa_sha256_sign : function (key, string) : string + ecdsa_sha256_verify : function (key, string, string) : boolean + parse_ecdsa_signature : function (string) : string, string + build_ecdsa_signature : function (string, string) : string + + import_private_pem : function (string) : key + import_public_pem : function (string) : key + + aes_128_gcm_encrypt : function (key, string, string) : string + aes_128_gcm_decrypt : function (key, string, string) : string + aes_256_gcm_encrypt : function (key, string, string) : string + aes_256_gcm_decrypt : function (key, string, string) : string + + + version : string + _LIBCRYPTO_VERSION : string +end +return lib diff --git a/teal-src/util/dataforms.d.tl b/teal-src/util/dataforms.d.tl index 9e4170fa..0eddf98e 100644 --- a/teal-src/util/dataforms.d.tl +++ b/teal-src/util/dataforms.d.tl @@ -1,51 +1,53 @@ local stanza_t = require "util.stanza".stanza_t -local enum form_type - "form" - "submit" - "cancel" - "result" -end - -local enum field_type - "boolean" - "fixed" - "hidden" - "jid-multi" - "jid-single" - "list-multi" - "list-single" - "text-multi" - "text-private" - "text-single" -end - -local record form_field - - type : field_type - var : string -- protocol name - name : string -- internal name - - label : string - desc : string - - datatype : string - range_min : number - range_max : number - - value : any -- depends on field_type - options : table -end - -local record dataform - title : string - instructions : string - { form_field } -- XXX https://github.com/teal-language/tl/pull/415 - - form : function ( dataform, table, form_type ) : stanza_t -end - local record lib + record dataform + title : string + instructions : string + + record form_field + + enum field_type + "boolean" + "fixed" + "hidden" + "jid-multi" + "jid-single" + "list-multi" + "list-single" + "text-multi" + "text-private" + "text-single" + end + + type : field_type + var : string -- protocol name + name : string -- internal name + + label : string + desc : string + + datatype : string + range_min : number + range_max : number + + value : any -- depends on field_type + options : table + end + + { form_field } + + enum form_type + "form" + "submit" + "cancel" + "result" + end + + form : function ( dataform, { string : any }, form_type ) : stanza_t + data : function ( dataform, stanza_t ) : { string : any } + end + new : function ( dataform ) : dataform end diff --git a/teal-src/util/datamapper.tl b/teal-src/util/datamapper.tl index 73b1dfc0..4ff3a02c 100644 --- a/teal-src/util/datamapper.tl +++ b/teal-src/util/datamapper.tl @@ -19,6 +19,8 @@ -- TODO s/number/integer/ once we have appropriate math.type() compat -- +if not math.type then require "util.mathcompat" end + local st = require "util.stanza"; local json = require"util.json" local pointer = require"util.jsonpointer"; @@ -133,10 +135,6 @@ local function unpack_propschema( propschema : schema_t, propname : string, curr end end - if current_ns == "urn:xmpp:reactions:0" and name == "reactions" then - assert(proptype=="array") - end - return proptype, value_where, name, namespace, prefix, single_attribute, enums end diff --git a/teal-src/util/datetime.d.tl b/teal-src/util/datetime.d.tl index 971e8f9c..9f770a73 100644 --- a/teal-src/util/datetime.d.tl +++ b/teal-src/util/datetime.d.tl @@ -1,11 +1,9 @@ --- TODO s/number/integer/ once Teal gets support for that - local record lib - date : function (t : integer) : string - datetime : function (t : integer) : string - time : function (t : integer) : string - legacy : function (t : integer) : string - parse : function (t : string) : integer + date : function (t : number) : string + datetime : function (t : number) : string + time : function (t : number) : string + legacy : function (t : number) : string + parse : function (t : string) : number end return lib diff --git a/teal-src/util/error.d.tl b/teal-src/util/error.d.tl index 05f52405..4c3a7196 100644 --- a/teal-src/util/error.d.tl +++ b/teal-src/util/error.d.tl @@ -38,7 +38,7 @@ local record protoerror code : integer end -local record error +local record Error type : error_type condition : error_condition text : string @@ -55,10 +55,10 @@ local type context = { string : any } local record error_registry_wrapper source : string registry : registry - new : function (string, context) : error - coerce : function (any, string) : any, error - wrap : function (error) : error - wrap : function (string, context) : error + new : function (string, context) : Error + coerce : function (any, string) : any, Error + wrap : function (Error) : Error + wrap : function (string, context) : Error is_error : function (any) : boolean end @@ -66,12 +66,12 @@ local record lib record configure_opt auto_inject_traceback : boolean end - new : function (protoerror, context, { string : protoerror }, string) : error + new : function (protoerror, context, { string : protoerror }, string) : Error init : function (string, string, registry | compact_registry) : error_registry_wrapper init : function (string, registry | compact_registry) : error_registry_wrapper is_error : function (any) : boolean - coerce : function (any, string) : any, error - from_stanza : function (table, context, string) : error + coerce : function (any, string) : any, Error + from_stanza : function (table, context, string) : Error configure : function end diff --git a/teal-src/util/hashes.d.tl b/teal-src/util/hashes.d.tl index cbb06f8e..5c249627 100644 --- a/teal-src/util/hashes.d.tl +++ b/teal-src/util/hashes.d.tl @@ -9,10 +9,18 @@ local record lib sha384 : hash sha512 : hash md5 : hash + sha3_256 : hash + sha3_512 : hash + blake2s256 : hash + blake2b512 : hash hmac_sha1 : hmac hmac_sha256 : hmac + hmac_sha224 : hmac + hmac_sha384 :hmac hmac_sha512 : hmac hmac_md5 : hmac + hmac_sha3_256 : hmac + hmac_sha3_512 : hmac scram_Hi_sha1 : kdf pbkdf2_hmac_sha1 : kdf pbkdf2_hmac_sha256 : kdf diff --git a/teal-src/util/hex.d.tl b/teal-src/util/hex.d.tl index 3b216a88..9d84540b 100644 --- a/teal-src/util/hex.d.tl +++ b/teal-src/util/hex.d.tl @@ -2,5 +2,7 @@ local type s2s = function (s : string) : string local record lib to : s2s from : s2s + encode : s2s + decode : s2s end return lib diff --git a/teal-src/util/human/io.d.tl b/teal-src/util/human/io.d.tl new file mode 100644 index 00000000..e4f64cd1 --- /dev/null +++ b/teal-src/util/human/io.d.tl @@ -0,0 +1,28 @@ +local record lib + getchar : function (n : integer) : string + getline : function () : string + getpass : function () : string + show_yesno : function (prompt : string) : boolean + read_password : function () : string + show_prompt : function (prompt : string) : boolean + printf : function (fmt : string, ... : any) + padleft : function (s : string, width : integer) : string + padright : function (s : string, width : integer) : string + + -- {K:V} vs T ? + record tablerow<K,V> + width : integer | string -- generate an 1..100 % enum? + title : string + mapper : function (V, {K:V}) : string + key : K + enum alignments + "left" + "right" + end + align : alignments + end + type getrow = function<K,V> ({ K : V }) : string + table : function<K,V> ({ tablerow<K,V> }, width : integer) : getrow<K,V> +end + +return lib diff --git a/teal-src/util/human/units.d.tl b/teal-src/util/human/units.d.tl index f6568d90..3db17c3a 100644 --- a/teal-src/util/human/units.d.tl +++ b/teal-src/util/human/units.d.tl @@ -1,5 +1,8 @@ local lib = record + enum logbase + "b" -- 1024 + end adjust : function (number, string) : number, string - format : function (number, string, string) : string + format : function (number, string, logbase) : string end return lib diff --git a/teal-src/util/jsonschema.tl b/teal-src/util/jsonschema.tl index 160c164c..14b04370 100644 --- a/teal-src/util/jsonschema.tl +++ b/teal-src/util/jsonschema.tl @@ -8,6 +8,8 @@ -- https://json-schema.org/draft/2020-12/json-schema-validation.html -- +if not math.type then require "util.mathcompat" end + local json = require"util.json" local null = json.null; diff --git a/teal-src/util/logger.d.tl b/teal-src/util/logger.d.tl new file mode 100644 index 00000000..db29adfd --- /dev/null +++ b/teal-src/util/logger.d.tl @@ -0,0 +1,18 @@ +local record util + enum loglevel + "debug" + "info" + "warn" + "error" + end + type logger = function ( loglevel, string, ...:any ) + type sink = function ( string, loglevel, string, ...:any ) + type simple_sink = function ( string, loglevel, string ) + init : function ( string ) : logger + make_logger : function ( string, loglevel ) : function ( string, ...:any ) + reset : function () + add_level_sink : function ( loglevel, sink ) + add_simple_sink : function ( simple_sink, { loglevel } ) +end + +return util diff --git a/teal-src/util/mathcompat.tl b/teal-src/util/mathcompat.tl new file mode 100644 index 00000000..1e3f9bab --- /dev/null +++ b/teal-src/util/mathcompat.tl @@ -0,0 +1,15 @@ +if not math.type then + local enum number_subtype + "float" "integer" + end + local function math_type(t:any) : number_subtype + if t is number then + if t % 1 == 0 and t ~= t+1 and t ~= t-1 then + return "integer" + else + return "float" + end + end + end + _G.math.type = math_type +end diff --git a/teal-src/util/promise.d.tl b/teal-src/util/promise.d.tl new file mode 100644 index 00000000..d895a828 --- /dev/null +++ b/teal-src/util/promise.d.tl @@ -0,0 +1,22 @@ + +local record lib + type resolve_func = function (any) + type promise_body = function (resolve_func, resolve_func) + + record Promise<A, B> + type on_resolved = function (A) : any + type on_rejected = function (B) : any + next : function (Promise, on_resolved, on_rejected) : Promise<any, any> + end + + new : function (promise_body) : Promise + resolve : function (any) : Promise + reject : function (any) : Promise + all : function ({ Promise }) : Promise + all_settled : function ({ Promise }) : Promise + race : function ({ Promise }) : Promise + try : function + is_promise : function(any) : boolean +end + +return lib diff --git a/teal-src/util/queue.d.tl b/teal-src/util/queue.d.tl new file mode 100644 index 00000000..cb8458e7 --- /dev/null +++ b/teal-src/util/queue.d.tl @@ -0,0 +1,21 @@ +local record lib + record queue<T> + size : integer + count : function (queue<T>) : integer + enum push_errors + "queue full" + end + + push : function (queue<T>, T) : boolean, push_errors + pop : function (queue<T>) : T + peek : function (queue<T>) : T + replace : function (queue<T>, T) : boolean, push_errors + type iterator = function (T, integer) : integer, T + items : function (queue<T>) : iterator, T, integer + type consume_iter = function (queue<T>) : T + consume : function (queue<T>) : consume_iter + end + + new : function<T> (size:integer, allow_wrapping:boolean) : queue<T> +end +return lib; diff --git a/teal-src/util/roles.d.tl b/teal-src/util/roles.d.tl new file mode 100644 index 00000000..fef4f88a --- /dev/null +++ b/teal-src/util/roles.d.tl @@ -0,0 +1,32 @@ +local record util_roles + + type context = any + + record Role + id : string + name : string + description : string + default : boolean + priority : number -- or integer? + permissions : { string : boolean } + + may : function (Role, string, context) + clone : function (Role, role_config) + set_permission : function (Role, string, boolean, boolean) + end + + is_role : function (any) : boolean + + record role_config + name : string + description : string + default : boolean + priority : number -- or integer? + inherits : { Role } + permissions : { string : boolean } + end + + new : function (role_config, Role) : Role +end + +return util_roles diff --git a/teal-src/util/serialization.d.tl b/teal-src/util/serialization.d.tl new file mode 100644 index 00000000..27f100c0 --- /dev/null +++ b/teal-src/util/serialization.d.tl @@ -0,0 +1,33 @@ +local record _M + enum preset + "debug" + "oneline" + "compact" + end + type fallback = function (any, string) : string + record config + preset : preset + fallback : fallback + fatal : boolean + keywords : { string : boolean } + indentwith : string + itemstart : string + itemsep : string + itemlast : string + tstart : string + tend : string + kstart : string + kend : string + equals : string + unquoted : boolean | string + hex : string + freeze : boolean + maxdepth : integer + multirefs : boolean + table_pairs : function + end + type serializer = function (any) : string + new : function (config|preset) : serializer + serialize : function (any, config|preset) : string +end +return _M diff --git a/teal-src/util/set.d.tl b/teal-src/util/set.d.tl new file mode 100644 index 00000000..1631eec4 --- /dev/null +++ b/teal-src/util/set.d.tl @@ -0,0 +1,21 @@ +local record lib + record Set<T> + add : function<T> (Set<T>, T) + contains : function<T> (Set<T>, T) : boolean + contains_set : function<T> (Set<T>, Set<T>) : boolean + items : function<T> (Set<T>) : function<T> (Set<T>, T) : T + add_list : function<T> (Set<T>, { T }) + include : function<T> (Set<T>, Set<T>) + exclude : function<T> (Set<T>, Set<T>) + empty : function<T> (Set<T>) : boolean + end + + new : function<T> ({ T }) : Set<T> + is_set : function (any) : boolean + union : function<T> (Set<T>, Set<T>) : Set <T> + difference : function<T> (Set<T>, Set<T>) : Set <T> + intersection : function<T> (Set<T>, Set<T>) : Set <T> + xor : function<T> (Set<T>, Set<T>) : Set <T> +end + +return lib diff --git a/teal-src/util/signal.d.tl b/teal-src/util/signal.d.tl index 8610aa7f..290cf08f 100644 --- a/teal-src/util/signal.d.tl +++ b/teal-src/util/signal.d.tl @@ -1,5 +1,5 @@ local record lib - enum signal + enum Signal "SIGABRT" "SIGALRM" "SIGBUS" @@ -33,9 +33,9 @@ local record lib "SIGXCPU" "SIGXFSZ" end - signal : function (integer | signal, function, boolean) : boolean - raise : function (integer | signal) - kill : function (integer, integer | signal) + signal : function (integer | Signal, function, boolean) : boolean + raise : function (integer | Signal) + kill : function (integer, integer | Signal) -- enum : integer end return lib diff --git a/teal-src/util/stanza.d.tl b/teal-src/util/stanza.d.tl index a358248a..e1ab2105 100644 --- a/teal-src/util/stanza.d.tl +++ b/teal-src/util/stanza.d.tl @@ -4,6 +4,39 @@ local record lib type childtags_iter = function () : stanza_t type maptags_cb = function ( stanza_t ) : stanza_t + + enum stanza_error_type + "auth" + "cancel" + "continue" + "modify" + "wait" + end + enum stanza_error_condition + "bad-request" + "conflict" + "feature-not-implemented" + "forbidden" + "gone" + "internal-server-error" + "item-not-found" + "jid-malformed" + "not-acceptable" + "not-allowed" + "not-authorized" + "policy-violation" + "recipient-unavailable" + "redirect" + "registration-required" + "remote-server-not-found" + "remote-server-timeout" + "resource-constraint" + "service-unavailable" + "subscription-required" + "undefined-condition" + "unexpected-request" + end + record stanza_t name : string attr : { string : string } @@ -16,6 +49,7 @@ local record lib tag : function ( stanza_t, string, { string : string } ) : stanza_t text : function ( stanza_t, string ) : stanza_t up : function ( stanza_t ) : stanza_t + at_top : function ( stanza_t ) : boolean reset : function ( stanza_t ) : stanza_t add_direct_child : function ( stanza_t, stanza_t ) add_child : function ( stanza_t, stanza_t ) @@ -24,6 +58,8 @@ local record lib get_child : function ( stanza_t, string, string ) : stanza_t get_text : function ( stanza_t ) : string get_child_text : function ( stanza_t, string, string ) : string + get_child_attr : function ( stanza_t, string, string ) : string + get_child_with_attr : function ( stanza_t, string, string, string, function (string) : boolean ) : string child_with_name : function ( stanza_t, string, string ) : stanza_t child_with_ns : function ( stanza_t, string, string ) : stanza_t children : function ( stanza_t ) : children_iter, stanza_t, integer @@ -35,7 +71,9 @@ local record lib pretty_print : function ( stanza_t ) : string pretty_top_tag : function ( stanza_t ) : string - get_error : function ( stanza_t ) : string, string, string, stanza_t + -- FIXME Represent util.error support + get_error : function ( stanza_t ) : stanza_error_type, stanza_error_condition, string, stanza_t + add_error : function ( stanza_t, stanza_error_type, stanza_error_condition, string, string ) indent : function ( stanza_t, integer, string ) : stanza_t end @@ -45,16 +83,61 @@ local record lib { serialized_stanza_t | string } end + record message_attr + ["xml:lang"] : string + from : string + id : string + to : string + type : message_type + enum message_type + "chat" + "error" + "groupchat" + "headline" + "normal" + end + end + + record presence_attr + ["xml:lang"] : string + from : string + id : string + to : string + type : presence_type + enum presence_type + "error" + "probe" + "subscribe" + "subscribed" + "unsubscribe" + "unsubscribed" + end + end + + record iq_attr + ["xml:lang"] : string + from : string + id : string + to : string + type : iq_type + enum iq_type + "error" + "get" + "result" + "set" + end + end + stanza : function ( string, { string : string } ) : stanza_t is_stanza : function ( any ) : boolean preserialize : function ( stanza_t ) : serialized_stanza_t deserialize : function ( serialized_stanza_t ) : stanza_t clone : function ( stanza_t, boolean ) : stanza_t - message : function ( { string : string }, string ) : stanza_t - iq : function ( { string : string } ) : stanza_t + message : function ( message_attr, string ) : stanza_t + iq : function ( iq_attr ) : stanza_t reply : function ( stanza_t ) : stanza_t - error_reply : function ( stanza_t, string, string, string, string ) - presence : function ( { string : string } ) : stanza_t + error_reply : function ( stanza_t, stanza_error_type, stanza_error_condition, string, string ) : stanza_t + presence : function ( presence_attr ) : stanza_t xml_escape : function ( string ) : string pretty_print : function ( string ) : string end diff --git a/teal-src/util/struct.d.tl b/teal-src/util/struct.d.tl new file mode 100644 index 00000000..201aaa23 --- /dev/null +++ b/teal-src/util/struct.d.tl @@ -0,0 +1,6 @@ +local record lib + pack : function (string, ...:any) : string + unpack : function(string, string, integer) : any... + size : function(string) : integer +end +return lib diff --git a/teal-src/util/table.d.tl b/teal-src/util/table.d.tl index 0ff5ed95..67e5d0f0 100644 --- a/teal-src/util/table.d.tl +++ b/teal-src/util/table.d.tl @@ -1,6 +1,7 @@ local record lib create : function (narr:integer, nrec:integer):table pack : function (...:any):{any} + move : function (table, integer, integer, integer, table) : table end return lib diff --git a/teal-src/util/termcolours.d.tl b/teal-src/util/termcolours.d.tl new file mode 100644 index 00000000..226259aa --- /dev/null +++ b/teal-src/util/termcolours.d.tl @@ -0,0 +1,7 @@ +local record lib + getstring : function (string, string) : string + getstyle : function (...:string) : string + setstyle : function (string) : string + tohtml : function (string) : string +end +return lib diff --git a/teal-src/util/timer.d.tl b/teal-src/util/timer.d.tl new file mode 100644 index 00000000..a6394cf3 --- /dev/null +++ b/teal-src/util/timer.d.tl @@ -0,0 +1,8 @@ +local record util_timer + record task end + type timer_callback = function (number) : number + add_task : function ( number, timer_callback, any ) : task + stop : function ( task ) + reschedule : function ( task, number ) : task +end +return util_timer diff --git a/teal-src/util/uuid.d.tl b/teal-src/util/uuid.d.tl index 45fd4312..284a4e4c 100644 --- a/teal-src/util/uuid.d.tl +++ b/teal-src/util/uuid.d.tl @@ -1,5 +1,5 @@ local record lib - get_nibbles : (number) : string + get_nibbles : function (number) : string generate : function () : string seed : function (string) diff --git a/tools/dnsregistry.lua b/tools/dnsregistry.lua index 3fd26628..41fb751b 100644 --- a/tools/dnsregistry.lua +++ b/tools/dnsregistry.lua @@ -22,7 +22,7 @@ for registry in registries:childtags("registry") do local record_desc = record:get_child_text("description"); local record_code = tonumber(record:get_child_text("value")); - if tostring(record):lower():match("reserved") or tostring(record):lower():match("reserved") then + if tostring(record):lower():match("reserved") or tostring(record):lower():match("unassigned") then record_code = nil; end diff --git a/tools/modtrace.lua b/tools/modtrace.lua index 45fa9f6a..f1927077 100644 --- a/tools/modtrace.lua +++ b/tools/modtrace.lua @@ -8,9 +8,9 @@ -- local dbuffer = require "tools.modtrace".trace("util.dbuffer"); -- -local t_pack = require "util.table".pack; +local t_pack = table.pack; local serialize = require "util.serialization".serialize; -local unpack = table.unpack or unpack; --luacheck: ignore 113 +local unpack = table.unpack; local set = require "util.set"; local serialize_cfg = { diff --git a/tools/test_mutants.sh.lua b/tools/test_mutants.sh.lua new file mode 100755 index 00000000..a0a55a8e --- /dev/null +++ b/tools/test_mutants.sh.lua @@ -0,0 +1,217 @@ +#!/bin/bash + +POLYGLOT=1--[===[ + +set -o pipefail + +if [[ "$#" == "0" ]]; then + echo "Lua mutation testing tool" + echo + echo "Usage:" + echo " $BASH_SOURCE MODULE_NAME SPEC_FILE" + echo + echo "Requires 'lua', 'ltokenp' and 'busted' in PATH" + exit 1; +fi + +MOD_NAME="$1" +MOD_FILE="$(lua "$BASH_SOURCE" resolve "$MOD_NAME")" + +if [[ "$MOD_FILE" == "" || ! -f "$MOD_FILE" ]]; then + echo "EE: Failed to locate module '$MOD_NAME' ($MOD_FILE)"; + exit 1; +fi + +SPEC_FILE="$2" + +if [[ "$SPEC_FILE" == "" ]]; then + SPEC_FILE="spec/${MOD_NAME/./_}_spec.lua" +fi + +if [[ "$SPEC_FILE" == "" || ! -f "$SPEC_FILE" ]]; then + echo "EE: Failed to find test spec file ($SPEC_FILE)" + exit 1; +fi + +if ! busted "$SPEC_FILE"; then + echo "EE: Tests fail on original source. Fix it"\!; + exit 1; +fi + +export MUTANT_N=0 +LIVING_MUTANTS=0 + +FILE_PREFIX="${MOD_FILE%.*}.mutant-" +FILE_SUFFIX=".${MOD_FILE##*.}" + +gen_mutant () { + echo "Generating mutant $2 to $3..." + ltokenp -s "$BASH_SOURCE" "$1" > "$3" + return "$?" +} + +# $1 = MOD_NAME, $2 = MUTANT_N, $3 = SPEC_FILE +test_mutant () { + ( + ulimit -m 131072 # 128MB + ulimit -t 16 # 16s + ulimit -f 32768 # 128MB (?) + exec busted --helper="$BASH_SOURCE" -Xhelper mutate="$1":"$2" "$3" + ) >/dev/null + return "$?"; +} + +MUTANT_FILE="${FILE_PREFIX}${MUTANT_N}${FILE_SUFFIX}" + +gen_mutant "$MOD_FILE" "$MUTANT_N" "$MUTANT_FILE" +while [[ "$?" == "0" ]]; do + if ! test_mutant "$MOD_NAME" "$MUTANT_N" "$SPEC_FILE"; then + echo "Tests successfully killed mutant $MUTANT_N"; + rm "$MUTANT_FILE"; + else + echo "Mutant $MUTANT_N lives on"\! + LIVING_MUTANTS=$((LIVING_MUTANTS+1)) + fi + MUTANT_N=$((MUTANT_N+1)) + MUTANT_FILE="${FILE_PREFIX}${MUTANT_N}${FILE_SUFFIX}" + gen_mutant "$MOD_FILE" "$MUTANT_N" "$MUTANT_FILE" +done + +if [[ "$?" != "2" ]]; then + echo "Failed: $?" + exit "$?"; +fi + +MUTANT_SCORE="$(lua -e "print(('%0.2f'):format((1-($LIVING_MUTANTS/$MUTANT_N))*100))")" +if test -f mutant-scores.txt; then + echo "$MOD_NAME $MUTANT_SCORE" >> mutant-scores.txt +fi +echo "$MOD_NAME: All $MUTANT_N mutants generated, $LIVING_MUTANTS survived (score: $MUTANT_SCORE%)" +rm "$MUTANT_FILE"; # Last file is always unmodified +exit 0; +]===] + +-- busted helper that runs mutations +if arg then + if arg[1] == "resolve" then + local filename = package.searchpath(assert(arg[2], "no module name given"), package.path); + if filename then + print(filename); + end + os.exit(filename and 0 or 1); + end + local mutants = {}; + + for i = 1, #arg do + local opt = arg[i]; + print("LOAD", i, opt) + local module_name, mutant_n = opt:match("^mutate=([^:]+):(%d+)"); + if module_name then + mutants[module_name] = tonumber(mutant_n); + end + end + + local orig_lua_searcher = package.searchers[2]; + + local function mutant_searcher(module_name) + local mutant_n = mutants[module_name]; + if not mutant_n then + return orig_lua_searcher(module_name); + end + local base_file, err = package.searchpath(module_name, package.path); + if not base_file then + return base_file, err; + end + local mutant_file = base_file:gsub("%.lua$", (".mutant-%d.lua"):format(mutant_n)); + return loadfile(mutant_file), mutant_file; + end + + if next(mutants) then + table.insert(package.searchers, 1, mutant_searcher); + end +end + +-- filter for ltokenp to mutate scripts +do + local last_output = {}; + local function emit(...) + last_output = {...}; + io.write(...) + io.write(" ") + return true; + end + + local did_mutate = false; + local count = -1; + local threshold = tonumber(os.getenv("MUTANT_N")) or 0; + local function should_mutate() + count = count + 1; + return count == threshold; + end + + local function mutate(name, value) + if name == "if" then + -- Bypass conditionals + if should_mutate() then + return emit("if true or"); + elseif should_mutate() then + return emit("if false and"); + end + elseif name == "<integer>" then + -- Introduce off-by-one errors + if should_mutate() then + return emit(("%d"):format(tonumber(value)+1)); + elseif should_mutate() then + return emit(("%d"):format(tonumber(value)-1)); + end + elseif name == "and" then + if should_mutate() then + return emit("or"); + end + elseif name == "or" then + if should_mutate() then + return emit("and"); + end + end + end + + local current_line_n, current_line_input, current_line_output = 0, {}, {}; + function FILTER(line_n,token,name,value) + if current_line_n ~= line_n then -- Finished a line, moving to the next? + if did_mutate and did_mutate.line == current_line_n then + -- The line we finished was mutated. Store the original and modified outputs. + did_mutate.line_original_src = table.concat(current_line_input, " "); + did_mutate.line_modified_src = table.concat(current_line_output, " "); + end + current_line_input = {}; + current_line_output = {}; + end + current_line_n = line_n; + if name == "<file>" then return; end + if name == "<eof>" then + if not did_mutate then + return os.exit(2); + else + emit(("\n-- Mutated line %d (changed '%s' to '%s'):\n"):format(did_mutate.line, did_mutate.original, did_mutate.modified)) + emit( ("-- Original: %s\n"):format(did_mutate.line_original_src)) + emit( ("-- Modified: %s\n"):format(did_mutate.line_modified_src)); + return; + end + end + if name == "<string>" then + value = string.format("%q",value); + end + if mutate(name, value) then + did_mutate = { + original = value; + modified = table.concat(last_output); + line = line_n; + }; + else + emit(value); + end + table.insert(current_line_input, value); + table.insert(current_line_output, table.concat(last_output)); + end +end + diff --git a/util-src/GNUmakefile b/util-src/GNUmakefile index 810f39f7..3f539387 100644 --- a/util-src/GNUmakefile +++ b/util-src/GNUmakefile @@ -8,7 +8,7 @@ TARGET?=../util/ ALL=encodings.so hashes.so net.so pposix.so signal.so table.so \ ringbuffer.so time.so poll.so compat.so strbitop.so \ - struct.so + struct.so crypto.so ifdef RANDOM ALL+=crand.so @@ -28,7 +28,7 @@ clean: encodings.o: CFLAGS+=$(IDNA_FLAGS) encodings.so: LDLIBS+=$(IDNA_LIBS) -hashes.so: LDLIBS+=$(OPENSSL_LIBS) +crypto.so hashes.so: LDLIBS+=$(OPENSSL_LIBS) crand.o: CFLAGS+=-DWITH_$(RANDOM) crand.so: LDLIBS+=$(RANDOM_LIBS) diff --git a/util-src/crand.c b/util-src/crand.c index 160ac1f6..c6f0a3ba 100644 --- a/util-src/crand.c +++ b/util-src/crand.c @@ -45,7 +45,7 @@ #endif /* This wasn't present before glibc 2.25 */ -int getrandom(void *buf, size_t buflen, unsigned int flags) { +static int getrandom(void *buf, size_t buflen, unsigned int flags) { return syscall(SYS_getrandom, buf, buflen, flags); } #else @@ -66,7 +66,7 @@ int getrandom(void *buf, size_t buflen, unsigned int flags) { #define SMALLBUFSIZ 32 #endif -int Lrandom(lua_State *L) { +static int Lrandom(lua_State *L) { char smallbuf[SMALLBUFSIZ]; char *buf = &smallbuf[0]; const lua_Integer l = luaL_checkinteger(L, 1); @@ -124,9 +124,7 @@ int Lrandom(lua_State *L) { } int luaopen_util_crand(lua_State *L) { -#if (LUA_VERSION_NUM > 501) luaL_checkversion(L); -#endif lua_createtable(L, 0, 2); lua_pushcfunction(L, Lrandom); diff --git a/util-src/crypto.c b/util-src/crypto.c new file mode 100644 index 00000000..d2427503 --- /dev/null +++ b/util-src/crypto.c @@ -0,0 +1,618 @@ +/* Prosody IM +-- Copyright (C) 2022 Matthew Wild +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- +*/ + +/* +* crypto.c +* Lua library for cryptographic operations using OpenSSL +*/ + +#include <string.h> +#include <stdlib.h> + +#ifdef _MSC_VER +typedef unsigned __int32 uint32_t; +#else +#include <inttypes.h> +#endif + +#include "lua.h" +#include "lauxlib.h" +#include <openssl/crypto.h> +#include <openssl/ecdsa.h> +#include <openssl/err.h> +#include <openssl/evp.h> +#include <openssl/obj_mac.h> +#include <openssl/pem.h> + +#if (LUA_VERSION_NUM == 501) +#define luaL_setfuncs(L, R, N) luaL_register(L, NULL, R) +#endif + +/* The max size of an encoded 'R' or 'S' value. P-521 = 521 bits = 66 bytes */ +#define MAX_ECDSA_SIG_INT_BYTES 66 + +#include "managed_pointer.h" + +#define PKEY_MT_TAG "util.crypto key" + +static BIO* new_memory_BIO(void) { + return BIO_new(BIO_s_mem()); +} + +MANAGED_POINTER_ALLOCATOR(new_managed_EVP_MD_CTX, EVP_MD_CTX*, EVP_MD_CTX_new, EVP_MD_CTX_free) +MANAGED_POINTER_ALLOCATOR(new_managed_BIO_s_mem, BIO*, new_memory_BIO, BIO_free) +MANAGED_POINTER_ALLOCATOR(new_managed_EVP_CIPHER_CTX, EVP_CIPHER_CTX*, EVP_CIPHER_CTX_new, EVP_CIPHER_CTX_free) + +#define CRYPTO_KEY_TYPE_ERR "unexpected key type: got '%s', expected '%s'" + +static EVP_PKEY* pkey_from_arg(lua_State *L, int idx, const int type, const int require_private) { + EVP_PKEY *pkey = *(EVP_PKEY**)luaL_checkudata(L, idx, PKEY_MT_TAG); + int got_type; + if(type || require_private) { + lua_getuservalue(L, idx); + if(type != 0) { + lua_getfield(L, -1, "type"); + got_type = lua_tointeger(L, -1); + if(got_type != type) { + const char *got_key_type_name = OBJ_nid2sn(got_type); + const char *want_key_type_name = OBJ_nid2sn(type); + lua_pushfstring(L, CRYPTO_KEY_TYPE_ERR, got_key_type_name, want_key_type_name); + luaL_argerror(L, idx, lua_tostring(L, -1)); + } + lua_pop(L, 1); + } + if(require_private != 0) { + lua_getfield(L, -1, "private"); + if(lua_toboolean(L, -1) != 1) { + luaL_argerror(L, idx, "private key expected, got public key only"); + } + lua_pop(L, 1); + } + lua_pop(L, 1); + } + return pkey; +} + +static int Lpkey_finalizer(lua_State *L) { + EVP_PKEY *pkey = pkey_from_arg(L, 1, 0, 0); + EVP_PKEY_free(pkey); + return 0; +} + +static int Lpkey_meth_get_type(lua_State *L) { + EVP_PKEY *pkey = pkey_from_arg(L, 1, 0, 0); + + int key_type = EVP_PKEY_id(pkey); + lua_pushstring(L, OBJ_nid2sn(key_type)); + return 1; +} + +static int base_evp_sign(lua_State *L, const int key_type, const EVP_MD *digest_type) { + EVP_PKEY *pkey = pkey_from_arg(L, 1, (key_type!=NID_rsassaPss)?key_type:NID_rsaEncryption, 1); + luaL_Buffer sigbuf; + + size_t msg_len; + const unsigned char* msg = (unsigned char*)lua_tolstring(L, 2, &msg_len); + + size_t sig_len; + unsigned char *sig = NULL; + EVP_MD_CTX *md_ctx = new_managed_EVP_MD_CTX(L); + + if(EVP_DigestSignInit(md_ctx, NULL, digest_type, NULL, pkey) != 1) { + lua_pushnil(L); + return 1; + } + if(key_type == NID_rsassaPss) { + EVP_PKEY_CTX_set_rsa_padding(EVP_MD_CTX_pkey_ctx(md_ctx), RSA_PKCS1_PSS_PADDING); + } + if(EVP_DigestSign(md_ctx, NULL, &sig_len, msg, msg_len) != 1) { + lua_pushnil(L); + return 1; + } + + // COMPAT w/ Lua 5.1 + luaL_buffinit(L, &sigbuf); + sig = memset(luaL_prepbuffer(&sigbuf), 0, sig_len); + + if(EVP_DigestSign(md_ctx, sig, &sig_len, msg, msg_len) != 1) { + lua_pushnil(L); + } + else { + luaL_addsize(&sigbuf, sig_len); + luaL_pushresult(&sigbuf); + return 1; + } + + return 1; +} + +static int base_evp_verify(lua_State *L, const int key_type, const EVP_MD *digest_type) { + EVP_PKEY *pkey = pkey_from_arg(L, 1, (key_type!=NID_rsassaPss)?key_type:NID_rsaEncryption, 0); + + size_t msg_len; + const unsigned char *msg = (unsigned char*)luaL_checklstring(L, 2, &msg_len); + + size_t sig_len; + const unsigned char *sig = (unsigned char*)luaL_checklstring(L, 3, &sig_len); + + EVP_MD_CTX *md_ctx = EVP_MD_CTX_new(); + + if(EVP_DigestVerifyInit(md_ctx, NULL, digest_type, NULL, pkey) != 1) { + lua_pushnil(L); + goto cleanup; + } + if(key_type == NID_rsassaPss) { + EVP_PKEY_CTX_set_rsa_padding(EVP_MD_CTX_pkey_ctx(md_ctx), RSA_PKCS1_PSS_PADDING); + } + int result = EVP_DigestVerify(md_ctx, sig, sig_len, msg, msg_len); + if(result == 0) { + lua_pushboolean(L, 0); + } else if(result != 1) { + lua_pushnil(L); + } + else { + lua_pushboolean(L, 1); + } +cleanup: + EVP_MD_CTX_free(md_ctx); + return 1; +} + +static int Lpkey_meth_public_pem(lua_State *L) { + char *data; + size_t bytes; + EVP_PKEY *pkey = pkey_from_arg(L, 1, 0, 0); + BIO *bio = new_managed_BIO_s_mem(L); + if(PEM_write_bio_PUBKEY(bio, pkey)) { + bytes = BIO_get_mem_data(bio, &data); + if (bytes > 0) { + lua_pushlstring(L, data, bytes); + } + else { + lua_pushnil(L); + } + } + else { + lua_pushnil(L); + } + return 1; +} + +static int Lpkey_meth_private_pem(lua_State *L) { + char *data; + size_t bytes; + EVP_PKEY *pkey = pkey_from_arg(L, 1, 0, 1); + BIO *bio = new_managed_BIO_s_mem(L); + + if(PEM_write_bio_PrivateKey(bio, pkey, NULL, NULL, 0, NULL, NULL)) { + bytes = BIO_get_mem_data(bio, &data); + if (bytes > 0) { + lua_pushlstring(L, data, bytes); + } + else { + lua_pushnil(L); + } + } + else { + lua_pushnil(L); + } + return 1; +} + +static int push_pkey(lua_State *L, EVP_PKEY *pkey, const int type, const int privkey) { + EVP_PKEY **ud = lua_newuserdata(L, sizeof(EVP_PKEY*)); + *ud = pkey; + luaL_newmetatable(L, PKEY_MT_TAG); + lua_setmetatable(L, -2); + + /* Set some info about the key and attach it as a user value */ + lua_newtable(L); + if(type != 0) { + lua_pushinteger(L, type); + lua_setfield(L, -2, "type"); + } + if(privkey != 0) { + lua_pushboolean(L, 1); + lua_setfield(L, -2, "private"); + } + lua_setuservalue(L, -2); + return 1; +} + +static int Lgenerate_ed25519_keypair(lua_State *L) { + EVP_PKEY *pkey = NULL; + EVP_PKEY_CTX *pctx = EVP_PKEY_CTX_new_id(EVP_PKEY_ED25519, NULL); + + /* Generate key */ + EVP_PKEY_keygen_init(pctx); + EVP_PKEY_keygen(pctx, &pkey); + EVP_PKEY_CTX_free(pctx); + + push_pkey(L, pkey, NID_ED25519, 1); + return 1; +} + +static int Limport_private_pem(lua_State *L) { + EVP_PKEY *pkey = NULL; + + size_t privkey_bytes; + const char* privkey_data; + BIO *bio = new_managed_BIO_s_mem(L); + + privkey_data = luaL_checklstring(L, 1, &privkey_bytes); + BIO_write(bio, privkey_data, privkey_bytes); + pkey = PEM_read_bio_PrivateKey(bio, NULL, NULL, NULL); + if (pkey) { + push_pkey(L, pkey, EVP_PKEY_id(pkey), 1); + } + else { + lua_pushnil(L); + } + + return 1; +} + +static int Limport_public_pem(lua_State *L) { + EVP_PKEY *pkey = NULL; + + size_t pubkey_bytes; + const char* pubkey_data; + BIO *bio = new_managed_BIO_s_mem(L); + + pubkey_data = luaL_checklstring(L, 1, &pubkey_bytes); + BIO_write(bio, pubkey_data, pubkey_bytes); + pkey = PEM_read_bio_PUBKEY(bio, NULL, NULL, NULL); + if (pkey) { + push_pkey(L, pkey, EVP_PKEY_id(pkey), 0); + } + else { + lua_pushnil(L); + } + + return 1; +} + +static int Led25519_sign(lua_State *L) { + return base_evp_sign(L, NID_ED25519, NULL); +} + +static int Led25519_verify(lua_State *L) { + return base_evp_verify(L, NID_ED25519, NULL); +} + +/* encrypt(key, iv, plaintext) */ +static int Levp_encrypt(lua_State *L, const EVP_CIPHER *cipher, const unsigned char expected_key_len, const unsigned char expected_iv_len, const size_t tag_len) { + EVP_CIPHER_CTX *ctx; + luaL_Buffer ciphertext_buffer; + + size_t key_len, iv_len, plaintext_len; + int ciphertext_len, final_len; + + const unsigned char *key = (unsigned char*)luaL_checklstring(L, 1, &key_len); + const unsigned char *iv = (unsigned char*)luaL_checklstring(L, 2, &iv_len); + const unsigned char *plaintext = (unsigned char*)luaL_checklstring(L, 3, &plaintext_len); + + if(key_len != expected_key_len) { + return luaL_error(L, "key must be %d bytes", expected_key_len); + } + if(iv_len != expected_iv_len) { + return luaL_error(L, "iv must be %d bytes", expected_iv_len); + } + if(lua_gettop(L) > 3) { + return luaL_error(L, "Expected 3 arguments, got %d", lua_gettop(L)); + } + + // Create and initialise the context + ctx = new_managed_EVP_CIPHER_CTX(L); + + // Initialise the encryption operation + if(1 != EVP_EncryptInit_ex(ctx, cipher, NULL, NULL, NULL)) { + return luaL_error(L, "Error while initializing encryption engine"); + } + + // Initialise key and IV + if(1 != EVP_EncryptInit_ex(ctx, NULL, NULL, key, iv)) { + return luaL_error(L, "Error while initializing key/iv"); + } + + luaL_buffinit(L, &ciphertext_buffer); + unsigned char *ciphertext = (unsigned char*)luaL_prepbuffsize(&ciphertext_buffer, plaintext_len+tag_len); + + if(1 != EVP_EncryptUpdate(ctx, ciphertext, &ciphertext_len, plaintext, plaintext_len)) { + return luaL_error(L, "Error while encrypting data"); + } + + /* + * Finalise the encryption. Normally ciphertext bytes may be written at + * this stage, but this does not occur in GCM mode + */ + if(1 != EVP_EncryptFinal_ex(ctx, ciphertext + ciphertext_len, &final_len)) { + return luaL_error(L, "Error while encrypting final data"); + } + if(final_len != 0) { + return luaL_error(L, "Non-zero final data"); + } + + if(tag_len > 0) { + /* Get the tag */ + if(1 != EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_GET_TAG, tag_len, ciphertext + ciphertext_len)) { + return luaL_error(L, "Unable to read AEAD tag of encrypted data"); + } + /* Append tag */ + luaL_addsize(&ciphertext_buffer, ciphertext_len + tag_len); + } else { + luaL_addsize(&ciphertext_buffer, ciphertext_len); + } + luaL_pushresult(&ciphertext_buffer); + + return 1; +} + +static int Laes_128_gcm_encrypt(lua_State *L) { + return Levp_encrypt(L, EVP_aes_128_gcm(), 16, 12, 16); +} + +static int Laes_256_gcm_encrypt(lua_State *L) { + return Levp_encrypt(L, EVP_aes_256_gcm(), 32, 12, 16); +} + +static int Laes_256_ctr_encrypt(lua_State *L) { + return Levp_encrypt(L, EVP_aes_256_ctr(), 32, 16, 0); +} + +/* decrypt(key, iv, ciphertext) */ +static int Levp_decrypt(lua_State *L, const EVP_CIPHER *cipher, const unsigned char expected_key_len, const unsigned char expected_iv_len, const size_t tag_len) { + EVP_CIPHER_CTX *ctx; + luaL_Buffer plaintext_buffer; + + size_t key_len, iv_len, ciphertext_len; + int plaintext_len, final_len; + + const unsigned char *key = (unsigned char*)luaL_checklstring(L, 1, &key_len); + const unsigned char *iv = (unsigned char*)luaL_checklstring(L, 2, &iv_len); + const unsigned char *ciphertext = (unsigned char*)luaL_checklstring(L, 3, &ciphertext_len); + + if(key_len != expected_key_len) { + return luaL_error(L, "key must be %d bytes", expected_key_len); + } + if(iv_len != expected_iv_len) { + return luaL_error(L, "iv must be %d bytes", expected_iv_len); + } + if(ciphertext_len <= tag_len) { + return luaL_error(L, "ciphertext must be at least %d bytes (including tag)", tag_len); + } + if(lua_gettop(L) > 3) { + return luaL_error(L, "Expected 3 arguments, got %d", lua_gettop(L)); + } + + /* Create and initialise the context */ + ctx = new_managed_EVP_CIPHER_CTX(L); + + /* Initialise the decryption operation. */ + if(!EVP_DecryptInit_ex(ctx, cipher, NULL, NULL, NULL)) { + return luaL_error(L, "Error while initializing decryption engine"); + } + + /* Initialise key and IV */ + if(!EVP_DecryptInit_ex(ctx, NULL, NULL, key, iv)) { + return luaL_error(L, "Error while initializing key/iv"); + } + + luaL_buffinit(L, &plaintext_buffer); + unsigned char *plaintext = (unsigned char*)luaL_prepbuffsize(&plaintext_buffer, ciphertext_len); + + /* + * Provide the message to be decrypted, and obtain the plaintext output. + * EVP_DecryptUpdate can be called multiple times if necessary + */ + if(!EVP_DecryptUpdate(ctx, plaintext, &plaintext_len, ciphertext, ciphertext_len-tag_len)) { + return luaL_error(L, "Error while decrypting data"); + } + + if(tag_len > 0) { + /* Set expected tag value. Works in OpenSSL 1.0.1d and later */ + if(!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_SET_TAG, tag_len, (unsigned char*)ciphertext + (ciphertext_len-tag_len))) { + return luaL_error(L, "Error while processing authentication tag"); + } + } + + /* + * Finalise the decryption. A positive return value indicates success, + * anything else is a failure - the plaintext is not trustworthy. + */ + int ret = EVP_DecryptFinal_ex(ctx, plaintext + plaintext_len, &final_len); + + if(ret <= 0) { + /* Verify failed */ + lua_pushnil(L); + lua_pushliteral(L, "verify-failed"); + return 2; + } + + luaL_addsize(&plaintext_buffer, plaintext_len + final_len); + luaL_pushresult(&plaintext_buffer); + return 1; +} + +static int Laes_128_gcm_decrypt(lua_State *L) { + return Levp_decrypt(L, EVP_aes_128_gcm(), 16, 12, 16); +} + +static int Laes_256_gcm_decrypt(lua_State *L) { + return Levp_decrypt(L, EVP_aes_256_gcm(), 32, 12, 16); +} + +static int Laes_256_ctr_decrypt(lua_State *L) { + return Levp_decrypt(L, EVP_aes_256_ctr(), 32, 16, 0); +} + +/* r, s = parse_ecdsa_sig(sig_der) */ +static int Lparse_ecdsa_signature(lua_State *L) { + ECDSA_SIG *sig; + size_t sig_der_len; + const unsigned char *sig_der = (unsigned char*)luaL_checklstring(L, 1, &sig_der_len); + const size_t sig_int_bytes = luaL_checkinteger(L, 2); + const BIGNUM *r, *s; + int rlen, slen; + unsigned char rb[MAX_ECDSA_SIG_INT_BYTES]; + unsigned char sb[MAX_ECDSA_SIG_INT_BYTES]; + + if(sig_int_bytes > MAX_ECDSA_SIG_INT_BYTES) { + luaL_error(L, "requested signature size exceeds supported limit"); + } + + sig = d2i_ECDSA_SIG(NULL, &sig_der, sig_der_len); + + if(sig == NULL) { + lua_pushnil(L); + return 1; + } + + ECDSA_SIG_get0(sig, &r, &s); + + rlen = BN_bn2binpad(r, rb, sig_int_bytes); + slen = BN_bn2binpad(s, sb, sig_int_bytes); + + if (rlen == -1 || slen == -1) { + ECDSA_SIG_free(sig); + luaL_error(L, "encoded integers exceed requested size"); + } + + ECDSA_SIG_free(sig); + + lua_pushlstring(L, (const char*)rb, rlen); + lua_pushlstring(L, (const char*)sb, slen); + + return 2; +} + +/* sig_der = build_ecdsa_signature(r, s) */ +static int Lbuild_ecdsa_signature(lua_State *L) { + ECDSA_SIG *sig = ECDSA_SIG_new(); + BIGNUM *r, *s; + luaL_Buffer sigbuf; + + size_t rlen, slen; + const unsigned char *rbin, *sbin; + + rbin = (unsigned char*)luaL_checklstring(L, 1, &rlen); + sbin = (unsigned char*)luaL_checklstring(L, 2, &slen); + + r = BN_bin2bn(rbin, (int)rlen, NULL); + s = BN_bin2bn(sbin, (int)slen, NULL); + + ECDSA_SIG_set0(sig, r, s); + + luaL_buffinit(L, &sigbuf); + + /* DER structure of an ECDSA signature has 7 bytes plus the integers themselves, + which may gain an extra byte once encoded */ + unsigned char *buffer = (unsigned char*)luaL_prepbuffsize(&sigbuf, (rlen+1)+(slen+1)+7); + int len = i2d_ECDSA_SIG(sig, &buffer); + luaL_addsize(&sigbuf, len); + luaL_pushresult(&sigbuf); + + ECDSA_SIG_free(sig); + + return 1; +} + +#define REG_SIGN_VERIFY(algorithm, digest) \ + { #algorithm "_" #digest "_sign", L ## algorithm ## _ ## digest ## _sign },\ + { #algorithm "_" #digest "_verify", L ## algorithm ## _ ## digest ## _verify }, + +#define IMPL_SIGN_VERIFY(algorithm, key_type, digest) \ + static int L ## algorithm ## _ ## digest ## _sign(lua_State *L) { \ + return base_evp_sign(L, key_type, EVP_ ## digest()); \ + } \ + static int L ## algorithm ## _ ## digest ## _verify(lua_State *L) { \ + return base_evp_verify(L, key_type, EVP_ ## digest()); \ + } + +IMPL_SIGN_VERIFY(ecdsa, NID_X9_62_id_ecPublicKey, sha256) +IMPL_SIGN_VERIFY(ecdsa, NID_X9_62_id_ecPublicKey, sha384) +IMPL_SIGN_VERIFY(ecdsa, NID_X9_62_id_ecPublicKey, sha512) + +IMPL_SIGN_VERIFY(rsassa_pkcs1, NID_rsaEncryption, sha256) +IMPL_SIGN_VERIFY(rsassa_pkcs1, NID_rsaEncryption, sha384) +IMPL_SIGN_VERIFY(rsassa_pkcs1, NID_rsaEncryption, sha512) + +IMPL_SIGN_VERIFY(rsassa_pss, NID_rsassaPss, sha256) +IMPL_SIGN_VERIFY(rsassa_pss, NID_rsassaPss, sha384) +IMPL_SIGN_VERIFY(rsassa_pss, NID_rsassaPss, sha512) + +static const luaL_Reg Reg[] = { + { "ed25519_sign", Led25519_sign }, + { "ed25519_verify", Led25519_verify }, + + REG_SIGN_VERIFY(ecdsa, sha256) + REG_SIGN_VERIFY(ecdsa, sha384) + REG_SIGN_VERIFY(ecdsa, sha512) + + REG_SIGN_VERIFY(rsassa_pkcs1, sha256) + REG_SIGN_VERIFY(rsassa_pkcs1, sha384) + REG_SIGN_VERIFY(rsassa_pkcs1, sha512) + + REG_SIGN_VERIFY(rsassa_pss, sha256) + REG_SIGN_VERIFY(rsassa_pss, sha384) + REG_SIGN_VERIFY(rsassa_pss, sha512) + + { "aes_128_gcm_encrypt", Laes_128_gcm_encrypt }, + { "aes_128_gcm_decrypt", Laes_128_gcm_decrypt }, + { "aes_256_gcm_encrypt", Laes_256_gcm_encrypt }, + { "aes_256_gcm_decrypt", Laes_256_gcm_decrypt }, + + { "aes_256_ctr_encrypt", Laes_256_ctr_encrypt }, + { "aes_256_ctr_decrypt", Laes_256_ctr_decrypt }, + + { "generate_ed25519_keypair", Lgenerate_ed25519_keypair }, + + { "import_private_pem", Limport_private_pem }, + { "import_public_pem", Limport_public_pem }, + + { "parse_ecdsa_signature", Lparse_ecdsa_signature }, + { "build_ecdsa_signature", Lbuild_ecdsa_signature }, + { NULL, NULL } +}; + +static const luaL_Reg KeyMethods[] = { + { "private_pem", Lpkey_meth_private_pem }, + { "public_pem", Lpkey_meth_public_pem }, + { "get_type", Lpkey_meth_get_type }, + { NULL, NULL } +}; + +static const luaL_Reg KeyMetatable[] = { + { "__gc", Lpkey_finalizer }, + { NULL, NULL } +}; + +LUALIB_API int luaopen_util_crypto(lua_State *L) { +#if (LUA_VERSION_NUM > 501) + luaL_checkversion(L); +#endif + + /* Initialize pkey metatable */ + luaL_newmetatable(L, PKEY_MT_TAG); + luaL_setfuncs(L, KeyMetatable, 0); + lua_newtable(L); + luaL_setfuncs(L, KeyMethods, 0); + lua_setfield(L, -2, "__index"); + lua_pop(L, 1); + + /* Initialize lib table */ + lua_newtable(L); + luaL_setfuncs(L, Reg, 0); + lua_pushliteral(L, "-3.14"); + lua_setfield(L, -2, "version"); +#ifdef OPENSSL_VERSION + lua_pushstring(L, OpenSSL_version(OPENSSL_VERSION)); + lua_setfield(L, -2, "_LIBCRYPTO_VERSION"); +#endif + return 1; +} diff --git a/util-src/encodings.c b/util-src/encodings.c index 72264da8..157d8526 100644 --- a/util-src/encodings.c +++ b/util-src/encodings.c @@ -21,9 +21,6 @@ #include "lua.h" #include "lauxlib.h" -#if (LUA_VERSION_NUM == 501) -#define luaL_setfuncs(L, R, N) luaL_register(L, NULL, R) -#endif #if (LUA_VERSION_NUM < 504) #define luaL_pushfail lua_pushnil #endif @@ -616,9 +613,7 @@ static const luaL_Reg Reg_idna[] = { /***************** end *****************/ LUALIB_API int luaopen_util_encodings(lua_State *L) { -#if (LUA_VERSION_NUM > 501) luaL_checkversion(L); -#endif #ifdef USE_STRINGPREP_ICU init_icu(); #endif diff --git a/util-src/hashes.c b/util-src/hashes.c index 8eefcd6b..94f8bcb2 100644 --- a/util-src/hashes.c +++ b/util-src/hashes.c @@ -28,13 +28,16 @@ typedef unsigned __int32 uint32_t; #include <openssl/md5.h> #include <openssl/hmac.h> #include <openssl/evp.h> +#include <openssl/kdf.h> +#include <openssl/err.h> -#if (LUA_VERSION_NUM == 501) -#define luaL_setfuncs(L, R, N) luaL_register(L, NULL, R) -#endif -#define HMAC_IPAD 0x36363636 -#define HMAC_OPAD 0x5c5c5c5c +/* Semi-arbitrary limit here. The actual theoretical limit +* is (255*(hash output octets)), but allocating 16KB on the +* stack when in practice we only ever request a few dozen +* bytes seems excessive. +*/ +#define MAX_HKDF_OUTPUT 256 static const char *hex_tab = "0123456789abcdef"; static void toHex(const unsigned char *in, int length, unsigned char *out) { @@ -46,94 +49,228 @@ static void toHex(const unsigned char *in, int length, unsigned char *out) { } } -#define MAKE_HASH_FUNCTION(myFunc, func, size) \ -static int myFunc(lua_State *L) { \ - size_t len; \ - const char *s = luaL_checklstring(L, 1, &len); \ - int hex_out = lua_toboolean(L, 2); \ - unsigned char hash[size], result[size*2]; \ - func((const unsigned char*)s, len, hash); \ - if (hex_out) { \ - toHex(hash, size, result); \ - lua_pushlstring(L, (char*)result, size*2); \ - } else { \ - lua_pushlstring(L, (char*)hash, size);\ - } \ - return 1; \ -} - -MAKE_HASH_FUNCTION(Lsha1, SHA1, SHA_DIGEST_LENGTH) -MAKE_HASH_FUNCTION(Lsha224, SHA224, SHA224_DIGEST_LENGTH) -MAKE_HASH_FUNCTION(Lsha256, SHA256, SHA256_DIGEST_LENGTH) -MAKE_HASH_FUNCTION(Lsha384, SHA384, SHA384_DIGEST_LENGTH) -MAKE_HASH_FUNCTION(Lsha512, SHA512, SHA512_DIGEST_LENGTH) -MAKE_HASH_FUNCTION(Lmd5, MD5, MD5_DIGEST_LENGTH) - -struct hash_desc { - int (*Init)(void *); - int (*Update)(void *, const void *, size_t); - int (*Final)(unsigned char *, void *); - size_t digestLength; - void *ctx, *ctxo; -}; +static int Levp_hash(lua_State *L, const EVP_MD *evp) { + size_t len; + unsigned int size = EVP_MAX_MD_SIZE; + const char *s = luaL_checklstring(L, 1, &len); + int hex_out = lua_toboolean(L, 2); -#define MAKE_HMAC_FUNCTION(myFunc, evp, size, type) \ -static int myFunc(lua_State *L) { \ - unsigned char hash[size], result[2*size]; \ - size_t key_len, msg_len; \ - unsigned int out_len; \ - const char *key = luaL_checklstring(L, 1, &key_len); \ - const char *msg = luaL_checklstring(L, 2, &msg_len); \ - const int hex_out = lua_toboolean(L, 3); \ - HMAC(evp(), key, key_len, (const unsigned char*)msg, msg_len, (unsigned char*)hash, &out_len); \ - if (hex_out) { \ - toHex(hash, out_len, result); \ - lua_pushlstring(L, (char*)result, out_len*2); \ - } else { \ - lua_pushlstring(L, (char*)hash, out_len); \ - } \ - return 1; \ -} - -MAKE_HMAC_FUNCTION(Lhmac_sha1, EVP_sha1, SHA_DIGEST_LENGTH, SHA_CTX) -MAKE_HMAC_FUNCTION(Lhmac_sha256, EVP_sha256, SHA256_DIGEST_LENGTH, SHA256_CTX) -MAKE_HMAC_FUNCTION(Lhmac_sha512, EVP_sha512, SHA512_DIGEST_LENGTH, SHA512_CTX) -MAKE_HMAC_FUNCTION(Lhmac_md5, EVP_md5, MD5_DIGEST_LENGTH, MD5_CTX) + unsigned char hash[EVP_MAX_MD_SIZE], result[EVP_MAX_MD_SIZE * 2]; + + EVP_MD_CTX *ctx = EVP_MD_CTX_new(); + + if(ctx == NULL) { + goto fail; + } + + if(!EVP_DigestInit_ex(ctx, evp, NULL)) { + goto fail; + } + + if(!EVP_DigestUpdate(ctx, s, len)) { + goto fail; + } + + if(!EVP_DigestFinal_ex(ctx, hash, &size)) { + goto fail; + } + + EVP_MD_CTX_free(ctx); + + if(hex_out) { + toHex(hash, size, result); + lua_pushlstring(L, (char *)result, size * 2); + } else { + lua_pushlstring(L, (char *)hash, size); + } + + return 1; + +fail: + EVP_MD_CTX_free(ctx); + return luaL_error(L, ERR_error_string(ERR_get_error(), NULL)); +} + +static int Lsha1(lua_State *L) { + return Levp_hash(L, EVP_sha1()); +} + +static int Lsha224(lua_State *L) { + return Levp_hash(L, EVP_sha224()); +} + +static int Lsha256(lua_State *L) { + return Levp_hash(L, EVP_sha256()); +} + +static int Lsha384(lua_State *L) { + return Levp_hash(L, EVP_sha384()); +} + +static int Lsha512(lua_State *L) { + return Levp_hash(L, EVP_sha512()); +} + +static int Lmd5(lua_State *L) { + return Levp_hash(L, EVP_md5()); +} + +static int Lblake2s256(lua_State *L) { + return Levp_hash(L, EVP_blake2s256()); +} + +static int Lblake2b512(lua_State *L) { + return Levp_hash(L, EVP_blake2b512()); +} + +static int Lsha3_256(lua_State *L) { + return Levp_hash(L, EVP_sha3_256()); +} + +static int Lsha3_512(lua_State *L) { + return Levp_hash(L, EVP_sha3_512()); +} + +static int Levp_hmac(lua_State *L, const EVP_MD *evp) { + unsigned char hash[EVP_MAX_MD_SIZE], result[EVP_MAX_MD_SIZE * 2]; + size_t key_len, msg_len; + unsigned int out_len = EVP_MAX_MD_SIZE; + const char *key = luaL_checklstring(L, 1, &key_len); + const char *msg = luaL_checklstring(L, 2, &msg_len); + const int hex_out = lua_toboolean(L, 3); + + if(HMAC(evp, key, key_len, (const unsigned char*)msg, msg_len, (unsigned char*)hash, &out_len) == NULL) { + goto fail; + } + + if(hex_out) { + toHex(hash, out_len, result); + lua_pushlstring(L, (char *)result, out_len * 2); + } else { + lua_pushlstring(L, (char *)hash, out_len); + } + + return 1; + +fail: + return luaL_error(L, ERR_error_string(ERR_get_error(), NULL)); +} + +static int Lhmac_sha1(lua_State *L) { + return Levp_hmac(L, EVP_sha1()); +} + +static int Lhmac_sha224(lua_State *L) { + return Levp_hmac(L, EVP_sha224()); +} + +static int Lhmac_sha256(lua_State *L) { + return Levp_hmac(L, EVP_sha256()); +} + +static int Lhmac_sha384(lua_State *L) { + return Levp_hmac(L, EVP_sha384()); +} + +static int Lhmac_sha512(lua_State *L) { + return Levp_hmac(L, EVP_sha512()); +} + +static int Lhmac_md5(lua_State *L) { + return Levp_hmac(L, EVP_md5()); +} + +static int Lhmac_sha3_256(lua_State *L) { + return Levp_hmac(L, EVP_sha3_256()); +} + +static int Lhmac_sha3_512(lua_State *L) { + return Levp_hmac(L, EVP_sha3_512()); +} + +static int Lhmac_blake2s256(lua_State *L) { + return Levp_hmac(L, EVP_blake2s256()); +} + +static int Lhmac_blake2b512(lua_State *L) { + return Levp_hmac(L, EVP_blake2b512()); +} -static int Lpbkdf2_sha1(lua_State *L) { - unsigned char out[SHA_DIGEST_LENGTH]; + +static int Levp_pbkdf2(lua_State *L, const EVP_MD *evp, size_t out_len) { + unsigned char out[EVP_MAX_MD_SIZE]; size_t pass_len, salt_len; const char *pass = luaL_checklstring(L, 1, &pass_len); const unsigned char *salt = (unsigned char *)luaL_checklstring(L, 2, &salt_len); const int iter = luaL_checkinteger(L, 3); - if(PKCS5_PBKDF2_HMAC(pass, pass_len, salt, salt_len, iter, EVP_sha1(), SHA_DIGEST_LENGTH, out) == 0) { - return luaL_error(L, "PKCS5_PBKDF2_HMAC() failed"); + if(PKCS5_PBKDF2_HMAC(pass, pass_len, salt, salt_len, iter, evp, out_len, out) == 0) { + return luaL_error(L, ERR_error_string(ERR_get_error(), NULL)); } - lua_pushlstring(L, (char *)out, SHA_DIGEST_LENGTH); + lua_pushlstring(L, (char *)out, out_len); return 1; } +static int Lpbkdf2_sha1(lua_State *L) { + return Levp_pbkdf2(L, EVP_sha1(), SHA_DIGEST_LENGTH); +} static int Lpbkdf2_sha256(lua_State *L) { - unsigned char out[SHA256_DIGEST_LENGTH]; + return Levp_pbkdf2(L, EVP_sha256(), SHA256_DIGEST_LENGTH); +} - size_t pass_len, salt_len; - const char *pass = luaL_checklstring(L, 1, &pass_len); - const unsigned char *salt = (unsigned char *)luaL_checklstring(L, 2, &salt_len); - const int iter = luaL_checkinteger(L, 3); - if(PKCS5_PBKDF2_HMAC(pass, pass_len, salt, salt_len, iter, EVP_sha256(), SHA256_DIGEST_LENGTH, out) == 0) { - return luaL_error(L, "PKCS5_PBKDF2_HMAC() failed"); +/* HKDF(length, input, salt, info) */ +static int Levp_hkdf(lua_State *L, const EVP_MD *evp) { + unsigned char out[MAX_HKDF_OUTPUT]; + + size_t input_len, salt_len, info_len; + size_t actual_out_len = luaL_checkinteger(L, 1); + const unsigned char *input = (unsigned char *)luaL_checklstring(L, 2, &input_len); + const unsigned char *salt = (unsigned char *)luaL_optlstring(L, 3, NULL, &salt_len); + const unsigned char *info = (unsigned char *)luaL_checklstring(L, 4, &info_len); + + if(actual_out_len > MAX_HKDF_OUTPUT) + return luaL_error(L, "desired output length %ul exceeds internal limit %ul", actual_out_len, MAX_HKDF_OUTPUT); + + EVP_PKEY_CTX *pctx = EVP_PKEY_CTX_new_id(EVP_PKEY_HKDF, NULL); + + if (EVP_PKEY_derive_init(pctx) <= 0) + return luaL_error(L, ERR_error_string(ERR_get_error(), NULL)); + + if (EVP_PKEY_CTX_set_hkdf_md(pctx, evp) <= 0) + return luaL_error(L, ERR_error_string(ERR_get_error(), NULL)); + + if(salt != NULL) { + if (EVP_PKEY_CTX_set1_hkdf_salt(pctx, salt, salt_len) <= 0) + return luaL_error(L, ERR_error_string(ERR_get_error(), NULL)); } - lua_pushlstring(L, (char *)out, SHA256_DIGEST_LENGTH); + if (EVP_PKEY_CTX_set1_hkdf_key(pctx, input, input_len) <= 0) + return luaL_error(L, ERR_error_string(ERR_get_error(), NULL)); + + if (EVP_PKEY_CTX_add1_hkdf_info(pctx, info, info_len) <= 0) + return luaL_error(L, ERR_error_string(ERR_get_error(), NULL)); + + if (EVP_PKEY_derive(pctx, out, &actual_out_len) <= 0) + return luaL_error(L, ERR_error_string(ERR_get_error(), NULL)); + + lua_pushlstring(L, (char *)out, actual_out_len); + return 1; } +static int Lhkdf_sha256(lua_State *L) { + return Levp_hkdf(L, EVP_sha256()); +} + +static int Lhkdf_sha384(lua_State *L) { + return Levp_hkdf(L, EVP_sha384()); +} + static int Lhash_equals(lua_State *L) { size_t len1, len2; const char *s1 = luaL_checklstring(L, 1, &len1); @@ -153,21 +290,31 @@ static const luaL_Reg Reg[] = { { "sha384", Lsha384 }, { "sha512", Lsha512 }, { "md5", Lmd5 }, + { "sha3_256", Lsha3_256 }, + { "sha3_512", Lsha3_512 }, + { "blake2s256", Lblake2s256 }, + { "blake2b512", Lblake2b512 }, { "hmac_sha1", Lhmac_sha1 }, + { "hmac_sha224", Lhmac_sha224 }, { "hmac_sha256", Lhmac_sha256 }, + { "hmac_sha384", Lhmac_sha384 }, { "hmac_sha512", Lhmac_sha512 }, { "hmac_md5", Lhmac_md5 }, + { "hmac_sha3_256", Lhmac_sha3_256 }, + { "hmac_sha3_512", Lhmac_sha3_512 }, + { "hmac_blake2s256", Lhmac_blake2s256 }, + { "hmac_blake2b512", Lhmac_blake2b512 }, { "scram_Hi_sha1", Lpbkdf2_sha1 }, /* COMPAT */ { "pbkdf2_hmac_sha1", Lpbkdf2_sha1 }, { "pbkdf2_hmac_sha256", Lpbkdf2_sha256 }, + { "hkdf_hmac_sha256", Lhkdf_sha256 }, + { "hkdf_hmac_sha384", Lhkdf_sha384 }, { "equals", Lhash_equals }, { NULL, NULL } }; LUALIB_API int luaopen_util_hashes(lua_State *L) { -#if (LUA_VERSION_NUM > 501) luaL_checkversion(L); -#endif lua_newtable(L); luaL_setfuncs(L, Reg, 0); lua_pushliteral(L, "-3.14"); diff --git a/util-src/managed_pointer.h b/util-src/managed_pointer.h new file mode 100644 index 00000000..213b5fd7 --- /dev/null +++ b/util-src/managed_pointer.h @@ -0,0 +1,61 @@ +/* managed_pointer.h + +These macros allow wrapping an allocator/deallocator into an object that is +owned and managed by the Lua garbage collector. + +Why? It is too easy to leak objects that need to be manually released, especially +when dealing with the Lua API which can throw errors from many operations. + +USAGE +----- + +For example, given an object that can be created or released with the following +functions: + + fancy_buffer* new_buffer(); + void free_buffer(fancy_buffer* p_buffer) + +You could declare a managed version like so: + + MANAGED_POINTER_ALLOCATOR(new_managed_buffer, fancy_buffer*, new_buffer, free_buffer) + +And then, when you need to create a new fancy_buffer in your code: + + fancy_buffer *my_buffer = new_managed_buffer(L); + +NOTES +----- + +Managed objects MUST NOT be freed manually. They will automatically be +freed during the next GC sweep after your function exits (even if via an error). + +The managed object is pushed onto the stack, but should generally be ignored, +but you'll need to bear this in mind when creating managed pointers in the +middle of a sequence of stack operations. +*/ + +#define MANAGED_POINTER_MT(wrapped_type) #wrapped_type "_managedptr_mt" + +#define MANAGED_POINTER_ALLOCATOR(name, wrapped_type, wrapped_alloc, wrapped_free) \ + static int _release_ ## name(lua_State *L) { \ + wrapped_type *p = (wrapped_type*)lua_topointer(L, 1); \ + if(*p != NULL) { \ + wrapped_free(*p); \ + } \ + return 0; \ + } \ + static wrapped_type name(lua_State *L) { \ + wrapped_type *p = (wrapped_type*)lua_newuserdata(L, sizeof(wrapped_type)); \ + if(luaL_newmetatable(L, MANAGED_POINTER_MT(wrapped_type)) != 0) { \ + lua_pushcfunction(L, _release_ ## name); \ + lua_setfield(L, -2, "__gc"); \ + } \ + lua_setmetatable(L, -2); \ + *p = wrapped_alloc(); \ + if(*p == NULL) { \ + lua_pushliteral(L, "not enough memory"); \ + lua_error(L); \ + } \ + return *p; \ + } + diff --git a/util-src/net.c b/util-src/net.c index d786e885..96b50e7b 100644 --- a/util-src/net.c +++ b/util-src/net.c @@ -30,9 +30,6 @@ #include <lua.h> #include <lauxlib.h> -#if (LUA_VERSION_NUM == 501) -#define luaL_setfuncs(L, R, N) luaL_register(L, NULL, R) -#endif #if (LUA_VERSION_NUM < 504) #define luaL_pushfail lua_pushnil #endif @@ -193,9 +190,7 @@ static int lc_ntop(lua_State *L) { } int luaopen_util_net(lua_State *L) { -#if (LUA_VERSION_NUM > 501) luaL_checkversion(L); -#endif luaL_Reg exports[] = { { "local_addresses", lc_local_addresses }, { "pton", lc_pton }, diff --git a/util-src/poll.c b/util-src/poll.c index 81caa953..7ea2e140 100644 --- a/util-src/poll.c +++ b/util-src/poll.c @@ -8,7 +8,6 @@ * */ -#include <unistd.h> #include <string.h> #include <errno.h> @@ -24,6 +23,7 @@ #endif #ifdef USE_EPOLL +#include <unistd.h> #include <sys/epoll.h> #ifndef MAX_EVENTS #define MAX_EVENTS 64 @@ -44,9 +44,6 @@ #define STATE_MT "util.poll<" POLL_BACKEND ">" -#if (LUA_VERSION_NUM == 501) -#define luaL_setmetatable(L, tname) luaL_getmetatable(L, tname); lua_setmetatable(L, -2) -#endif #if (LUA_VERSION_NUM < 504) #define luaL_pushfail lua_pushnil #endif @@ -564,9 +561,7 @@ static int Lnew(lua_State *L) { * Open library */ int luaopen_util_poll(lua_State *L) { -#if (LUA_VERSION_NUM > 501) luaL_checkversion(L); -#endif luaL_newmetatable(L, STATE_MT); { diff --git a/util-src/pposix.c b/util-src/pposix.c index a8e0720f..aac27d35 100644 --- a/util-src/pposix.c +++ b/util-src/pposix.c @@ -58,9 +58,6 @@ #include "lualib.h" #include "lauxlib.h" -#if (LUA_VERSION_NUM == 501) -#define luaL_setfuncs(L, R, N) luaL_register(L, NULL, R) -#endif #if (LUA_VERSION_NUM < 503) #define lua_isinteger(L, n) lua_isnumber(L, n) #endif @@ -829,9 +826,7 @@ static int lc_isatty(lua_State *L) { /* Register functions */ int luaopen_util_pposix(lua_State *L) { -#if (LUA_VERSION_NUM > 501) luaL_checkversion(L); -#endif luaL_Reg exports[] = { { "abort", lc_abort }, diff --git a/util-src/ringbuffer.c b/util-src/ringbuffer.c index 0f250c12..95c62de9 100644 --- a/util-src/ringbuffer.c +++ b/util-src/ringbuffer.c @@ -314,9 +314,7 @@ static int rb_new(lua_State *L) { } int luaopen_util_ringbuffer(lua_State *L) { -#if (LUA_VERSION_NUM > 501) luaL_checkversion(L); -#endif if(luaL_newmetatable(L, "ringbuffer_mt")) { lua_pushcfunction(L, rb_tostring); diff --git a/util-src/signal.c b/util-src/signal.c index 1a398fa0..b5ba16a9 100644 --- a/util-src/signal.c +++ b/util-src/signal.c @@ -36,9 +36,6 @@ #include "lua.h" #include "lauxlib.h" -#if (LUA_VERSION_NUM == 501) -#define luaL_setfuncs(L, R, N) luaL_register(L, NULL, R) -#endif #if (LUA_VERSION_NUM < 503) #define lua_isinteger(L, n) lua_isnumber(L, n) #endif @@ -381,9 +378,7 @@ static const struct luaL_Reg lsignal_lib[] = { }; int luaopen_util_signal(lua_State *L) { -#if (LUA_VERSION_NUM > 501) luaL_checkversion(L); -#endif int i = 0; /* add the library */ diff --git a/util-src/strbitop.c b/util-src/strbitop.c index 89fce661..722f5a2d 100644 --- a/util-src/strbitop.c +++ b/util-src/strbitop.c @@ -8,13 +8,10 @@ #include <lua.h> #include <lauxlib.h> -#if (LUA_VERSION_NUM == 501) -#define luaL_setfuncs(L, R, N) luaL_register(L, NULL, R) -#endif /* TODO Deduplicate code somehow */ -int strop_and(lua_State *L) { +static int strop_and(lua_State *L) { luaL_Buffer buf; size_t a, b, i; const char *str_a = luaL_checklstring(L, 1, &a); @@ -35,7 +32,7 @@ int strop_and(lua_State *L) { return 1; } -int strop_or(lua_State *L) { +static int strop_or(lua_State *L) { luaL_Buffer buf; size_t a, b, i; const char *str_a = luaL_checklstring(L, 1, &a); @@ -56,7 +53,7 @@ int strop_or(lua_State *L) { return 1; } -int strop_xor(lua_State *L) { +static int strop_xor(lua_State *L) { luaL_Buffer buf; size_t a, b, i; const char *str_a = luaL_checklstring(L, 1, &a); diff --git a/util-src/struct.c b/util-src/struct.c index e80df4e6..b236db5b 100644 --- a/util-src/struct.c +++ b/util-src/struct.c @@ -36,12 +36,6 @@ #include "lauxlib.h" -#if (LUA_VERSION_NUM >= 502) - -#define luaL_register(L,n,f) luaL_newlib(L,f) - -#endif - /* basic integer type */ #if !defined(STRUCT_INT) @@ -140,7 +134,7 @@ static int gettoalign (size_t len, Header *h, int opt, size_t size) { /* -** options to control endianess and alignment +** options to control endianness and alignment */ static void controloptions (lua_State *L, int opt, const char **fmt, Header *h) { @@ -392,7 +386,7 @@ static const struct luaL_Reg thislib[] = { LUALIB_API int luaopen_util_struct (lua_State *L); LUALIB_API int luaopen_util_struct (lua_State *L) { - luaL_register(L, "struct", thislib); + luaL_newlib(L, thislib); return 1; } diff --git a/util-src/table.c b/util-src/table.c index 9a9553fc..1cbb276d 100644 --- a/util-src/table.c +++ b/util-src/table.c @@ -1,11 +1,17 @@ #include <lua.h> #include <lauxlib.h> +#ifndef LUA_MAXINTEGER +#include <stdint.h> +#define LUA_MAXINTEGER PTRDIFF_MAX +#endif + static int Lcreate_table(lua_State *L) { lua_createtable(L, luaL_checkinteger(L, 1), luaL_checkinteger(L, 2)); return 1; } +/* COMPAT: w/ Lua pre-5.2 */ static int Lpack(lua_State *L) { unsigned int n_args = lua_gettop(L); lua_createtable(L, n_args, 1); @@ -20,14 +26,48 @@ static int Lpack(lua_State *L) { return 1; } +/* COMPAT: w/ Lua pre-5.4 */ +static int Lmove (lua_State *L) { + lua_Integer f = luaL_checkinteger(L, 2); + lua_Integer e = luaL_checkinteger(L, 3); + lua_Integer t = luaL_checkinteger(L, 4); + + int tt = !lua_isnoneornil(L, 5) ? 5 : 1; /* destination table */ + luaL_checktype(L, 1, LUA_TTABLE); + luaL_checktype(L, tt, LUA_TTABLE); + + if (e >= f) { /* otherwise, nothing to move */ + lua_Integer n, i; + luaL_argcheck(L, f > 0 || e < LUA_MAXINTEGER + f, 3, + "too many elements to move"); + n = e - f + 1; /* number of elements to move */ + luaL_argcheck(L, t <= LUA_MAXINTEGER - n + 1, 4, + "destination wrap around"); + if (t > e || t <= f || (tt != 1 && !lua_compare(L, 1, tt, LUA_OPEQ))) { + for (i = 0; i < n; i++) { + lua_rawgeti(L, 1, f + i); + lua_rawseti(L, tt, t + i); + } + } else { + for (i = n - 1; i >= 0; i--) { + lua_rawgeti(L, 1, f + i); + lua_rawseti(L, tt, t + i); + } + } + } + + lua_pushvalue(L, tt); /* return destination table */ + return 1; +} + int luaopen_util_table(lua_State *L) { -#if (LUA_VERSION_NUM > 501) luaL_checkversion(L); -#endif lua_createtable(L, 0, 2); lua_pushcfunction(L, Lcreate_table); lua_setfield(L, -2, "create"); lua_pushcfunction(L, Lpack); lua_setfield(L, -2, "pack"); + lua_pushcfunction(L, Lmove); + lua_setfield(L, -2, "move"); return 1; } diff --git a/util-src/windows.c b/util-src/windows.c index 57af79d5..2adb85f5 100644 --- a/util-src/windows.c +++ b/util-src/windows.c @@ -19,9 +19,6 @@ #include "lua.h" #include "lauxlib.h" -#if (LUA_VERSION_NUM == 501) -#define luaL_setfuncs(L, R, N) luaL_register(L, NULL, R) -#endif #if (LUA_VERSION_NUM < 504) #define luaL_pushfail lua_pushnil #endif @@ -106,9 +103,7 @@ static const luaL_Reg Reg[] = { }; LUALIB_API int luaopen_util_windows(lua_State *L) { -#if (LUA_VERSION_NUM > 501) luaL_checkversion(L); -#endif lua_newtable(L); luaL_setfuncs(L, Reg, 0); lua_pushliteral(L, "-3.14"); diff --git a/util/adminstream.lua b/util/adminstream.lua index 4075aa05..ba8ce0a0 100644 --- a/util/adminstream.lua +++ b/util/adminstream.lua @@ -145,7 +145,7 @@ local function new_connection(socket_path, listeners) -- constructor was exported instead of a module table. Due to the lack of a -- proper release of LuaSocket, distros have settled on shipping either the -- last RC tag or some commit since then. - -- Here we accomodate both variants. + -- Here we accommodate both variants. unix = { stream = unix }; end if type(unix) ~= "table" then diff --git a/util/array.lua b/util/array.lua index c33a5ef1..9d438940 100644 --- a/util/array.lua +++ b/util/array.lua @@ -8,6 +8,7 @@ local t_insert, t_sort, t_remove, t_concat = table.insert, table.sort, table.remove, table.concat; +local t_move = require "util.table".move; local setmetatable = setmetatable; local getmetatable = getmetatable; @@ -137,13 +138,11 @@ function array_base.slice(outa, ina, i, j) return outa; end - for idx = 1, 1+j-i do - outa[idx] = ina[i+(idx-1)]; - end + + t_move(ina, i, j, 1, outa); if ina == outa then - for idx = 2+j-i, #outa do - outa[idx] = nil; - end + -- Clear (nil) remainder of range + t_move(ina, #outa+1, #outa*2, 2+j-i, ina); end return outa; end @@ -209,10 +208,7 @@ function array_methods:shuffle() end function array_methods:append(ina) - local len, len2 = #self, #ina; - for i = 1, len2 do - self[len+i] = ina[i]; - end + t_move(ina, 1, #ina, #self+1, self); return self; end diff --git a/util/bitcompat.lua b/util/bitcompat.lua index 454181af..8f227354 100644 --- a/util/bitcompat.lua +++ b/util/bitcompat.lua @@ -5,12 +5,6 @@ -- Lua 5.2 has it by default if _G.bit32 then return _G.bit32; -else - -- Lua 5.1 may have it as a standalone module that can be installed - local ok, bitop = pcall(require, "bit32") - if ok then - return bitop; - end end do @@ -21,12 +15,4 @@ do end end -do - -- Lastly, try the LuaJIT bitop library - local ok, bitop = pcall(require, "bit") - if ok then - return bitop; - end -end - error "No bit module found. See https://prosody.im/doc/depends#bitop"; diff --git a/util/datamapper.lua b/util/datamapper.lua index 2378314c..e1484525 100644 --- a/util/datamapper.lua +++ b/util/datamapper.lua @@ -1,5 +1,9 @@ -- This file is generated from teal-src/util/datamapper.lua +if not math.type then + require("util.mathcompat") +end + local st = require("util.stanza"); local pointer = require("util.jsonpointer"); diff --git a/util/datetime.lua b/util/datetime.lua index 2d27ece4..6df146f4 100644 --- a/util/datetime.lua +++ b/util/datetime.lua @@ -12,31 +12,41 @@ local os_date = os.date; local os_time = os.time; local os_difftime = os.difftime; +local floor = math.floor; local tonumber = tonumber; local _ENV = nil; -- luacheck: std none local function date(t) - return os_date("!%Y-%m-%d", t); + return os_date("!%Y-%m-%d", t and floor(t) or nil); end local function datetime(t) - return os_date("!%Y-%m-%dT%H:%M:%SZ", t); + if t == nil or t % 1 == 0 then + return os_date("!%Y-%m-%dT%H:%M:%SZ", t); + end + local m = t % 1; + local s = floor(t); + return os_date("!%Y-%m-%dT%H:%M:%S.%%06dZ", s):format(floor(m * 1000000)); end local function time(t) - return os_date("!%H:%M:%S", t); + if t == nil or t % 1 == 0 then + return os_date("!%H:%M:%S", t); + end + local m = t % 1; + local s = floor(t); + return os_date("!%H:%M:%S.%%06d", s):format(floor(m * 1000000)); end local function legacy(t) - return os_date("!%Y%m%dT%H:%M:%S", t); + return os_date("!%Y%m%dT%H:%M:%S", t and floor(t) or nil); end local function parse(s) if s then - local year, month, day, hour, min, sec, tzd; - year, month, day, hour, min, sec, tzd = s:match("^(%d%d%d%d)%-?(%d%d)%-?(%d%d)T(%d%d):(%d%d):(%d%d)%.?%d*([Z+%-]?.*)$"); + local year, month, day, hour, min, sec, tzd = s:match("^(%d%d%d%d)%-?(%d%d)%-?(%d%d)T(%d%d):(%d%d):(%d%d%.?%d*)([Z+%-]?.*)$"); if year then local now = os_time(); local time_offset = os_difftime(os_time(os_date("*t", now)), os_time(os_date("!*t", now))); -- to deal with local timezone @@ -49,8 +59,9 @@ local function parse(s) tzd_offset = h * 60 * 60 + m * 60; if sign == "-" then tzd_offset = -tzd_offset; end end - sec = (sec + time_offset) - tzd_offset; - return os_time({year=year, month=month, day=day, hour=hour, min=min, sec=sec, isdst=false}); + local prec = sec%1; + sec = floor(sec + time_offset) - tzd_offset; + return os_time({year=year, month=month, day=day, hour=hour, min=min, sec=sec, isdst=false})+prec; end end end diff --git a/util/dbuffer.lua b/util/dbuffer.lua index 3ad5fdfe..0a36288d 100644 --- a/util/dbuffer.lua +++ b/util/dbuffer.lua @@ -91,18 +91,18 @@ function dbuffer_methods:read_until(char) end function dbuffer_methods:discard(requested_bytes) - if requested_bytes > self._length then - return nil; + if self._length == 0 then return true; end + if not requested_bytes or requested_bytes >= self._length then + self.front_consumed = 0; + self._length = 0; + for _ in self.items:consume() do end + return true; end local chunk, read_bytes = self:read_chunk(requested_bytes); - if chunk then - requested_bytes = requested_bytes - read_bytes; - if requested_bytes == 0 then -- Already read everything we need - return true; - end - else - return nil; + requested_bytes = requested_bytes - read_bytes; + if requested_bytes == 0 then -- Already read everything we need + return true; end while chunk do diff --git a/util/dependencies.lua b/util/dependencies.lua index d7836404..165468c5 100644 --- a/util/dependencies.lua +++ b/util/dependencies.lua @@ -32,10 +32,10 @@ local function missingdep(name, sources, msg, err) -- luacheck: ignore err end local function check_dependencies() - if _VERSION < "Lua 5.1" then + if _VERSION < "Lua 5.2" then print "***********************************" print("Unsupported Lua version: ".._VERSION); - print("At least Lua 5.1 is required."); + print("At least Lua 5.2 is required."); print "***********************************" return false; end @@ -155,7 +155,7 @@ local function log_warnings() if _VERSION > "Lua 5.4" then prosody.log("warn", "Support for %s is experimental, please report any issues", _VERSION); elseif _VERSION < "Lua 5.2" then - prosody.log("warn", "%s has several issues and support is being phased out, consider upgrading", _VERSION); + prosody.log("warn", "%s support is deprecated, upgrade as soon as possible", _VERSION); end local ssl = softreq"ssl"; if ssl then diff --git a/util/dnsregistry.lua b/util/dnsregistry.lua index 635b7e3a..c52abee9 100644 --- a/util/dnsregistry.lua +++ b/util/dnsregistry.lua @@ -1,5 +1,5 @@ -- Source: https://www.iana.org/assignments/dns-parameters/dns-parameters.xml --- Generated on 2022-02-02 +-- Generated on 2023-01-20 return { classes = { ["IN"] = 1; [1] = "IN"; @@ -61,7 +61,6 @@ return { ["NSEC3PARAM"] = 51; [51] = "NSEC3PARAM"; ["TLSA"] = 52; [52] = "TLSA"; ["SMIMEA"] = 53; [53] = "SMIMEA"; - ["Unassigned"] = 54; [54] = "Unassigned"; ["HIP"] = 55; [55] = "HIP"; ["NINFO"] = 56; [56] = "NINFO"; ["RKEY"] = 57; [57] = "RKEY"; diff --git a/util/envload.lua b/util/envload.lua index 6182a1f9..cf45b702 100644 --- a/util/envload.lua +++ b/util/envload.lua @@ -6,38 +6,19 @@ -- -- luacheck: ignore 113/setfenv 113/loadstring -local load, loadstring, setfenv = load, loadstring, setfenv; +local load = load; local io_open = io.open; -local envload; -local envloadfile; -if setfenv then - function envload(code, source, env) - local f, err = loadstring(code, source); - if f and env then setfenv(f, env); end - return f, err; - end - - function envloadfile(file, env) - local fh, err, errno = io_open(file); - if not fh then return fh, err, errno; end - local f, err = load(function () return fh:read(2048); end, "@"..file); - fh:close(); - if f and env then setfenv(f, env); end - return f, err; - end -else - function envload(code, source, env) - return load(code, source, nil, env); - end +local function envload(code, source, env) + return load(code, source, nil, env); +end - function envloadfile(file, env) - local fh, err, errno = io_open(file); - if not fh then return fh, err, errno; end - local f, err = load(fh:lines(2048), "@"..file, nil, env); - fh:close(); - return f, err; - end +local function envloadfile(file, env) + local fh, err, errno = io_open(file); + if not fh then return fh, err, errno; end + local f, err = load(fh:lines(2048), "@" .. file, nil, env); + fh:close(); + return f, err; end return { envload = envload, envloadfile = envloadfile }; diff --git a/util/format.lua b/util/format.lua index d709aada..0631f423 100644 --- a/util/format.lua +++ b/util/format.lua @@ -6,14 +6,12 @@ -- Provides some protection from e.g. CAPEC-135, CWE-117, CWE-134, CWE-93 local tostring = tostring; -local unpack = table.unpack or unpack; -- luacheck: ignore 113/unpack -local pack = require "util.table".pack; -- TODO table.pack in 5.2+ +local unpack = table.unpack; +local pack = table.pack; local valid_utf8 = require "util.encodings".utf8.valid; local type = type; local dump = require "util.serialization".new("debug"); -local num_type = math.type or function (n) - return n % 1 == 0 and n <= 9007199254740992 and n >= -9007199254740992 and "integer" or "float"; -end +local num_type = math.type; -- In Lua 5.3+ these formats throw an error if given a float local expects_integer = { c = true, d = true, i = true, o = true, u = true, X = true, x = true, }; @@ -35,7 +33,6 @@ local control_symbols = { ["\030"] = "\226\144\158", ["\031"] = "\226\144\159", ["\127"] = "\226\144\161", }; local supports_p = pcall(string.format, "%p", ""); -- >= Lua 5.4 -local supports_a = pcall(string.format, "%a", 0.0); -- > Lua 5.1 local function format(formatstring, ...) local args = pack(...); @@ -93,8 +90,6 @@ local function format(formatstring, ...) elseif expects_positive[option] and arg < 0 then args[i] = tostring(arg); return "[%s]"; - elseif (option == "a" or option == "A") and not supports_a then - return "%x"; else return -- acceptable number end diff --git a/util/hashring.lua b/util/hashring.lua index d4555669..5e71654b 100644 --- a/util/hashring.lua +++ b/util/hashring.lua @@ -1,3 +1,5 @@ +local it = require "util.iterators"; + local function generate_ring(nodes, num_replicas, hash) local new_ring = {}; for _, node_name in ipairs(nodes) do @@ -28,18 +30,22 @@ local function new(num_replicas, hash_function) return setmetatable({ nodes = {}, num_replicas = num_replicas, hash = hash_function }, hashring_mt); end; -function hashring_methods:add_node(name) +function hashring_methods:add_node(name, value) self.ring = nil; - self.nodes[name] = true; + self.nodes[name] = value == nil and true or value; table.insert(self.nodes, name); return true; end function hashring_methods:add_nodes(nodes) self.ring = nil; - for _, node_name in ipairs(nodes) do - if not self.nodes[node_name] then - self.nodes[node_name] = true; + local iter = pairs; + if nodes[1] then -- simple array? + iter = it.values; + end + for node_name, node_value in iter(nodes) do + if self.nodes[node_name] == nil then + self.nodes[node_name] = node_value == nil and true or node_value; table.insert(self.nodes, node_name); end end @@ -48,7 +54,7 @@ end function hashring_methods:remove_node(node_name) self.ring = nil; - if self.nodes[node_name] then + if self.nodes[node_name] ~= nil then for i, stored_node_name in ipairs(self.nodes) do if node_name == stored_node_name then self.nodes[node_name] = nil; @@ -69,18 +75,26 @@ end function hashring_methods:clone() local clone_hashring = new(self.num_replicas, self.hash); - clone_hashring:add_nodes(self.nodes); + for node_name, node_value in pairs(self.nodes) do + clone_hashring.nodes[node_name] = node_value; + end + clone_hashring.ring = nil; return clone_hashring; end function hashring_methods:get_node(key) + local node; local key_hash = self.hash(key); for _, replica_hash in ipairs(self.ring) do if key_hash < replica_hash then - return self.ring[replica_hash]; + node = self.ring[replica_hash]; + break; end end - return self.ring[self.ring[1]]; + if not node then + node = self.ring[self.ring[1]]; + end + return node, self.nodes[node]; end return { diff --git a/util/hmac.lua b/util/hmac.lua index 4cad17cc..ca030259 100644 --- a/util/hmac.lua +++ b/util/hmac.lua @@ -13,6 +13,10 @@ local hashes = require "util.hashes" return { md5 = hashes.hmac_md5, sha1 = hashes.hmac_sha1, + sha224 = hashes.hmac_sha224, sha256 = hashes.hmac_sha256, + sha384 = hashes.hmac_sha384, sha512 = hashes.hmac_sha512, + blake2s256 = hashes.hmac_blake2s256, + blake2b512 = hashes.hmac_blake2b512, }; diff --git a/util/human/io.lua b/util/human/io.lua index 7d7dea97..b272af71 100644 --- a/util/human/io.lua +++ b/util/human/io.lua @@ -8,7 +8,7 @@ end; local function getchar(n) local stty_ret = os.execute("stty raw -echo 2>/dev/null"); local ok, char; - if stty_ret == true or stty_ret == 0 then + if stty_ret then ok, char = pcall(io.read, n or 1); os.execute("stty sane"); else @@ -30,15 +30,12 @@ local function getline() end local function getpass() - local stty_ret, _, status_code = os.execute("stty -echo 2>/dev/null"); - if status_code then -- COMPAT w/ Lua 5.1 - stty_ret = status_code; - end - if stty_ret ~= 0 then + local stty_ret = os.execute("stty -echo 2>/dev/null"); + if not stty_ret then io.write("\027[08m"); -- ANSI 'hidden' text attribute end local ok, pass = pcall(io.read, "*l"); - if stty_ret == 0 then + if stty_ret then os.execute("stty sane"); else io.write("\027[00m"); diff --git a/util/human/units.lua b/util/human/units.lua index af233e98..329c8518 100644 --- a/util/human/units.lua +++ b/util/human/units.lua @@ -4,15 +4,7 @@ local math_floor = math.floor; local math_log = math.log; local math_max = math.max; local math_min = math.min; -local unpack = table.unpack or unpack; --luacheck: ignore 113 - -if math_log(10, 10) ~= 1 then - -- Lua 5.1 COMPAT - local log10 = math.log10; - function math_log(n, base) - return log10(n) / log10(base); - end -end +local unpack = table.unpack; local large = { "k", 1000, diff --git a/util/import.lua b/util/import.lua index 1007bc0a..0892e9b1 100644 --- a/util/import.lua +++ b/util/import.lua @@ -8,7 +8,7 @@ -local unpack = table.unpack or unpack; --luacheck: ignore 113 +local unpack = table.unpack; local t_insert = table.insert; function _G.import(module, ...) local m = package.loaded[module] or require(module); diff --git a/util/iterators.lua b/util/iterators.lua index c03c2fd6..eb4c54af 100644 --- a/util/iterators.lua +++ b/util/iterators.lua @@ -12,8 +12,8 @@ local it = {}; local t_insert = table.insert; local next = next; -local unpack = table.unpack or unpack; --luacheck: ignore 113 -local pack = table.pack or require "util.table".pack; +local unpack = table.unpack; +local pack = table.pack; local type = type; local table, setmetatable = table, setmetatable; @@ -240,7 +240,8 @@ function join_methods:prepend(f, s, var) end function it.join(f, s, var) - return setmetatable({ {f, s, var} }, join_mt); + local t = setmetatable({ {f, s, var} }, join_mt); + return t, { t, 1 }; end return it; diff --git a/util/jid.lua b/util/jid.lua index 694a6b1f..55567ea2 100644 --- a/util/jid.lua +++ b/util/jid.lua @@ -35,8 +35,7 @@ local function split(jid) if jid == nil then return; end local node, nodepos = match(jid, "^([^@/]+)@()"); local host, hostpos = match(jid, "^([^@/]+)()", nodepos); - if node ~= nil and host == nil then return nil, nil, nil; end - local resource = match(jid, "^/(.+)$", hostpos); + local resource = host and match(jid, "^/(.+)$", hostpos); if (host == nil) or ((resource == nil) and #jid >= hostpos) then return nil, nil, nil; end return node, host, resource; end @@ -91,9 +90,9 @@ local function compare(jid, acl) -- TODO compare to table of rules? local jid_node, jid_host, jid_resource = split(jid); local acl_node, acl_host, acl_resource = split(acl); - if ((acl_node ~= nil and acl_node == jid_node) or acl_node == nil) and - ((acl_host ~= nil and acl_host == jid_host) or acl_host == nil) and - ((acl_resource ~= nil and acl_resource == jid_resource) or acl_resource == nil) then + if (acl_node == nil or acl_node == jid_node) and + (acl_host == nil or acl_host == jid_host) and + (acl_resource == nil or acl_resource == jid_resource) then return true end return false @@ -111,6 +110,7 @@ local function resource(jid) return (select(3, split(jid))); end +-- TODO Forbid \20 at start and end of escaped output per XEP-0106 v1.1 local function escape(s) return s and (s:gsub("\\%x%x", backslash_escapes):gsub("[\"&'/:<>@ ]", escapes)); end local function unescape(s) return s and (s:gsub("\\%x%x", unescapes)); end diff --git a/util/jsonpointer.lua b/util/jsonpointer.lua index 9b871ae7..f1c354a4 100644 --- a/util/jsonpointer.lua +++ b/util/jsonpointer.lua @@ -1,6 +1,4 @@ -local m_type = math.type or function (n) - return n % 1 == 0 and n <= 9007199254740992 and n >= -9007199254740992 and "integer" or "float"; -end; +local m_type = math.type; local function unescape_token(escaped_token) local unescaped = escaped_token:gsub("~1", "/"):gsub("~0", "~") diff --git a/util/jwt.lua b/util/jwt.lua index bf106dfa..42a9f7f2 100644 --- a/util/jwt.lua +++ b/util/jwt.lua @@ -1,4 +1,5 @@ local s_gsub = string.gsub; +local crypto = require "util.crypto"; local json = require "util.json"; local hashes = require "util.hashes"; local base64_encode = require "util.encodings".base64.encode; @@ -13,17 +14,8 @@ local function unb64url(data) return base64_decode(s_gsub(data, "[-_]", b64url_rep).."=="); end -local static_header = b64url('{"alg":"HS256","typ":"JWT"}') .. '.'; - -local function sign(key, payload) - local encoded_payload = json.encode(payload); - local signed = static_header .. b64url(encoded_payload); - local signature = hashes.hmac_sha256(key, signed); - return signed .. "." .. b64url(signature); -end - local jwt_pattern = "^(([A-Za-z0-9-_]+)%.([A-Za-z0-9-_]+))%.([A-Za-z0-9-_]+)$" -local function verify(key, blob) +local function decode_jwt(blob, expected_alg) local signed, bheader, bpayload, signature = string.match(blob, jwt_pattern); if not signed then return nil, "invalid-encoding"; @@ -31,21 +23,197 @@ local function verify(key, blob) local header = json.decode(unb64url(bheader)); if not header or type(header) ~= "table" then return nil, "invalid-header"; - elseif header.alg ~= "HS256" then + elseif header.alg ~= expected_alg then return nil, "unsupported-algorithm"; end - if not secure_equals(b64url(hashes.hmac_sha256(key, signed)), signature) then - return false, "signature-mismatch"; - end - local payload, err = json.decode(unb64url(bpayload)); + return signed, signature, bpayload; +end + +local function new_static_header(algorithm_name) + return b64url('{"alg":"'..algorithm_name..'","typ":"JWT"}') .. '.'; +end + +local function decode_raw_payload(raw_payload) + local payload, err = json.decode(unb64url(raw_payload)); if err ~= nil then return nil, "json-decode-error"; + elseif type(payload) ~= "table" then + return nil, "invalid-payload-type"; end return true, payload; end +-- HS*** family +local function new_hmac_algorithm(name) + local static_header = new_static_header(name); + + local hmac = hashes["hmac_sha"..name:sub(-3)]; + + local function sign(key, payload) + local encoded_payload = json.encode(payload); + local signed = static_header .. b64url(encoded_payload); + local signature = hmac(key, signed); + return signed .. "." .. b64url(signature); + end + + local function verify(key, blob) + local signed, signature, raw_payload = decode_jwt(blob, name); + if not signed then return nil, signature; end -- nil, err + + if not secure_equals(b64url(hmac(key, signed)), signature) then + return false, "signature-mismatch"; + end + + return decode_raw_payload(raw_payload); + end + + local function load_key(key) + assert(type(key) == "string", "key must be string (long, random, secure)"); + return key; + end + + return { sign = sign, verify = verify, load_key = load_key }; +end + +local function new_crypto_algorithm(name, key_type, c_sign, c_verify, sig_encode, sig_decode) + local static_header = new_static_header(name); + + return { + sign = function (private_key, payload) + local encoded_payload = json.encode(payload); + local signed = static_header .. b64url(encoded_payload); + + local signature = c_sign(private_key, signed); + if sig_encode then + signature = sig_encode(signature); + end + + return signed.."."..b64url(signature); + end; + + verify = function (public_key, blob) + local signed, signature, raw_payload = decode_jwt(blob, name); + if not signed then return nil, signature; end -- nil, err + + signature = unb64url(signature); + if sig_decode and signature then + signature = sig_decode(signature); + end + if not signature then + return false, "signature-mismatch"; + end + + local verify_ok = c_verify(public_key, signed, signature); + if not verify_ok then + return false, "signature-mismatch"; + end + + return decode_raw_payload(raw_payload); + end; + + load_public_key = function (public_key_pem) + local key = assert(crypto.import_public_pem(public_key_pem)); + assert(key:get_type() == key_type, "incorrect key type"); + return key; + end; + + load_private_key = function (private_key_pem) + local key = assert(crypto.import_private_pem(private_key_pem)); + assert(key:get_type() == key_type, "incorrect key type"); + return key; + end; + }; +end + +-- RS***, PS*** +local rsa_sign_algos = { RS = "rsassa_pkcs1", PS = "rsassa_pss" }; +local function new_rsa_algorithm(name) + local family, digest_bits = name:match("^(..)(...)$"); + local c_sign = crypto[rsa_sign_algos[family].."_sha"..digest_bits.."_sign"]; + local c_verify = crypto[rsa_sign_algos[family].."_sha"..digest_bits.."_verify"]; + return new_crypto_algorithm(name, "rsaEncryption", c_sign, c_verify); +end + +-- ES*** +local function new_ecdsa_algorithm(name, c_sign, c_verify, sig_bytes) + local function encode_ecdsa_sig(der_sig) + local r, s = crypto.parse_ecdsa_signature(der_sig, sig_bytes); + return r..s; + end + + local expected_sig_length = sig_bytes*2; + local function decode_ecdsa_sig(jwk_sig) + if #jwk_sig ~= expected_sig_length then + return nil; + end + return crypto.build_ecdsa_signature(jwk_sig:sub(1, sig_bytes), jwk_sig:sub(sig_bytes+1)); + end + return new_crypto_algorithm(name, "id-ecPublicKey", c_sign, c_verify, encode_ecdsa_sig, decode_ecdsa_sig); +end + +local algorithms = { + HS256 = new_hmac_algorithm("HS256"), HS384 = new_hmac_algorithm("HS384"), HS512 = new_hmac_algorithm("HS512"); + ES256 = new_ecdsa_algorithm("ES256", crypto.ecdsa_sha256_sign, crypto.ecdsa_sha256_verify, 32); + ES512 = new_ecdsa_algorithm("ES512", crypto.ecdsa_sha512_sign, crypto.ecdsa_sha512_verify, 66); + RS256 = new_rsa_algorithm("RS256"), RS384 = new_rsa_algorithm("RS384"), RS512 = new_rsa_algorithm("RS512"); + PS256 = new_rsa_algorithm("PS256"), PS384 = new_rsa_algorithm("PS384"), PS512 = new_rsa_algorithm("PS512"); +}; + +local function new_signer(algorithm, key_input, options) + local impl = assert(algorithms[algorithm], "Unknown JWT algorithm: "..algorithm); + local key = (impl.load_private_key or impl.load_key)(key_input); + local sign = impl.sign; + local default_ttl = (options and options.default_ttl) or 3600; + return function (payload) + local issued_at; + if not payload.iat then + issued_at = os.time(); + payload.iat = issued_at; + end + if not payload.exp then + payload.exp = (issued_at or os.time()) + default_ttl; + end + return sign(key, payload); + end +end + +local function new_verifier(algorithm, key_input, options) + local impl = assert(algorithms[algorithm], "Unknown JWT algorithm: "..algorithm); + local key = (impl.load_public_key or impl.load_key)(key_input); + local verify = impl.verify; + local check_expiry = not (options and options.accept_expired); + local claim_verifier = options and options.claim_verifier; + return function (token) + local ok, payload = verify(key, token); + if ok then + local expires_at = check_expiry and payload.exp; + if expires_at then + if type(expires_at) ~= "number" then + return nil, "invalid-expiry"; + elseif expires_at < os.time() then + return nil, "token-expired"; + end + end + if claim_verifier and not claim_verifier(payload) then + return nil, "incorrect-claims"; + end + end + return ok, payload; + end +end + +local function init(algorithm, private_key, public_key, options) + return new_signer(algorithm, private_key, options), new_verifier(algorithm, public_key or private_key, options); +end + return { - sign = sign; - verify = verify; + init = init; + new_signer = new_signer; + new_verifier = new_verifier; + -- Exported mainly for tests + _algorithms = algorithms; + -- Deprecated + sign = algorithms.HS256.sign; + verify = algorithms.HS256.verify; }; diff --git a/util/logger.lua b/util/logger.lua index 20a5cef2..148b98dc 100644 --- a/util/logger.lua +++ b/util/logger.lua @@ -10,6 +10,7 @@ local pairs = pairs; local ipairs = ipairs; local require = require; +local t_remove = table.remove; local _ENV = nil; -- luacheck: std none @@ -78,6 +79,20 @@ local function add_simple_sink(simple_sink_function, levels) for _, level in ipairs(levels or {"debug", "info", "warn", "error"}) do add_level_sink(level, sink_function); end + return sink_function; +end + +local function remove_sink(sink_function) + local removed; + for level, sinks in pairs(level_sinks) do + for i = #sinks, 1, -1 do + if sinks[i] == sink_function then + t_remove(sinks, i); + removed = true; + end + end + end + return removed; end return { @@ -87,4 +102,5 @@ return { add_level_sink = add_level_sink; add_simple_sink = add_simple_sink; new = make_logger; + remove_sink = remove_sink; }; diff --git a/util/mathcompat.lua b/util/mathcompat.lua new file mode 100644 index 00000000..e8acb261 --- /dev/null +++ b/util/mathcompat.lua @@ -0,0 +1,13 @@ +if not math.type then + + local function math_type(t) + if type(t) == "number" then + if t % 1 == 0 and t ~= t + 1 and t ~= t - 1 then + return "integer" + else + return "float" + end + end + end + _G.math.type = math_type +end diff --git a/util/multitable.lua b/util/multitable.lua index 4f2cd972..0c292b45 100644 --- a/util/multitable.lua +++ b/util/multitable.lua @@ -9,7 +9,7 @@ local select = select; local t_insert = table.insert; local pairs, next, type = pairs, next, type; -local unpack = table.unpack or unpack; --luacheck: ignore 113 +local unpack = table.unpack; local _ENV = nil; -- luacheck: std none diff --git a/util/openmetrics.lua b/util/openmetrics.lua index c18e63e9..7bdbde9e 100644 --- a/util/openmetrics.lua +++ b/util/openmetrics.lua @@ -1,7 +1,7 @@ --[[ This module implements a subset of the OpenMetrics Internet Draft version 00. -URL: https://tools.ietf.org/html/draft-richih-opsawg-openmetrics-00 +URL: https://datatracker.ietf.org/doc/html/draft-richih-opsawg-openmetrics-00 The following metric types are supported: @@ -26,7 +26,7 @@ local log = require "util.logger".init("util.openmetrics"); local new_multitable = require "util.multitable".new; local iter_multitable = require "util.multitable".iter; local t_concat, t_insert = table.concat, table.insert; -local t_pack, t_unpack = require "util.table".pack, table.unpack or unpack; --luacheck: ignore 113/unpack +local t_pack, t_unpack = table.pack, table.unpack; -- BEGIN of Utility: "metric proxy" -- This allows to wrap a MetricFamily in a proxy which only provides the @@ -35,6 +35,7 @@ local t_pack, t_unpack = require "util.table".pack, table.unpack or unpack; --lu -- `with_partial_label` by the moduleapi in order to pre-set the `host` label -- on metrics created in non-global modules. local metric_proxy_mt = {} +metric_proxy_mt.__name = "metric_proxy" metric_proxy_mt.__index = metric_proxy_mt local function new_metric_proxy(metric_family, with_labels_proxy_fun) @@ -128,6 +129,7 @@ end -- BEGIN of generic MetricFamily implementation local metric_family_mt = {} +metric_family_mt.__name = "metric_family" metric_family_mt.__index = metric_family_mt local function histogram_metric_ctor(orig_ctor, buckets) @@ -278,6 +280,7 @@ local function compose_name(name, unit) end local metric_registry_mt = {} +metric_registry_mt.__name = "metric_registry" metric_registry_mt.__index = metric_registry_mt local function new_metric_registry(backend) diff --git a/util/openssl.lua b/util/openssl.lua index 32b5aea7..3acb4f04 100644 --- a/util/openssl.lua +++ b/util/openssl.lua @@ -166,8 +166,7 @@ do -- Lua to shell calls. setmetatable(_M, { __index = function(_, command) return function(opts) - local ret = os_execute(serialize(command, type(opts) == "table" and opts or {})); - return ret == true or ret == 0; + return os_execute(serialize(command, type(opts) == "table" and opts or {})); end; end; }); diff --git a/util/paseto.lua b/util/paseto.lua new file mode 100644 index 00000000..6cd29f68 --- /dev/null +++ b/util/paseto.lua @@ -0,0 +1,218 @@ +local crypto = require "util.crypto"; +local json = require "util.json"; +local hashes = require "util.hashes"; +local base64_encode = require "util.encodings".base64.encode; +local base64_decode = require "util.encodings".base64.decode; +local secure_equals = require "util.hashes".equals; +local bit = require "util.bitcompat"; +local hex = require "util.hex"; +local rand = require "util.random"; +local s_pack = require "util.struct".pack; + +local s_gsub = string.gsub; + +local v4_public = {}; + +local b64url_rep = { ["+"] = "-", ["/"] = "_", ["="] = "", ["-"] = "+", ["_"] = "/" }; +local function b64url(data) + return (s_gsub(base64_encode(data), "[+/=]", b64url_rep)); +end + +local valid_tails = { + nil; -- Always invalid + "^.[AQgw]$"; -- b??????00 + "^..[AQgwEUk0IYo4Mcs8]$"; -- b????0000 +} + +local function unb64url(data) + local rem = #data%4; + if data:sub(-1,-1) == "=" or rem == 1 or (rem > 1 and not data:sub(-rem):match(valid_tails[rem])) then + return nil; + end + return base64_decode(s_gsub(data, "[-_]", b64url_rep).."=="); +end + +local function le64(n) + return s_pack("<I8", bit.band(n, 0x7F)); +end + +local function pae(parts) + if type(parts) ~= "table" then + error("bad argument #1 to 'pae' (table expected, got "..type(parts)..")"); + end + local o = { le64(#parts) }; + for _, part in ipairs(parts) do + table.insert(o, le64(#part)..part); + end + return table.concat(o); +end + +function v4_public.sign(m, sk, f, i) + if type(m) ~= "table" then + return nil, "PASETO payloads must be a table"; + end + m = json.encode(m); + local h = "v4.public."; + local m2 = pae({ h, m, f or "", i or "" }); + local sig = crypto.ed25519_sign(sk, m2); + if not f or f == "" then + return h..b64url(m..sig); + else + return h..b64url(m..sig).."."..b64url(f); + end +end + +function v4_public.verify(tok, pk, expected_f, i) + local h, sm, f = tok:match("^(v4%.public%.)([^%.]+)%.?(.*)$"); + if not h then + return nil, "invalid-token-format"; + end + f = f and unb64url(f) or nil; + if expected_f then + if not f or not secure_equals(expected_f, f) then + return nil, "invalid-footer"; + end + end + local raw_sm = unb64url(sm); + if not raw_sm or #raw_sm <= 64 then + return nil, "invalid-token-format"; + end + local s, m = raw_sm:sub(-64), raw_sm:sub(1, -65); + local m2 = pae({ h, m, f or "", i or "" }); + local ok = crypto.ed25519_verify(pk, m2, s); + if not ok then + return nil, "invalid-token"; + end + local payload, err = json.decode(m); + if err ~= nil or type(payload) ~= "table" then + return nil, "json-decode-error"; + end + return payload; +end + +v4_public.import_private_key = crypto.import_private_pem; +v4_public.import_public_key = crypto.import_public_pem; +function v4_public.new_keypair() + return crypto.generate_ed25519_keypair(); +end + +function v4_public.init(private_key_pem, public_key_pem, options) + local sign, verify = v4_public.sign, v4_public.verify; + local public_key = public_key_pem and v4_public.import_public_key(public_key_pem); + local private_key = private_key_pem and v4_public.import_private_key(private_key_pem); + local default_footer = options and options.default_footer; + local default_assertion = options and options.default_implicit_assertion; + return private_key and function (token, token_footer, token_assertion) + return sign(token, private_key, token_footer or default_footer, token_assertion or default_assertion); + end, public_key and function (token, expected_footer, token_assertion) + return verify(token, public_key, expected_footer or default_footer, token_assertion or default_assertion); + end; +end + +function v4_public.new_signer(private_key_pem, options) + return (v4_public.init(private_key_pem, nil, options)); +end + +function v4_public.new_verifier(public_key_pem, options) + return (select(2, v4_public.init(nil, public_key_pem, options))); +end + +local v3_local = { _key_mt = {} }; + +local function v3_local_derive_keys(k, n) + local tmp = hashes.hkdf_hmac_sha384(48, k, nil, "paseto-encryption-key"..n); + local Ek = tmp:sub(1, 32); + local n2 = tmp:sub(33); + local Ak = hashes.hkdf_hmac_sha384(48, k, nil, "paseto-auth-key-for-aead"..n); + return Ek, Ak, n2; +end + +function v3_local.encrypt(m, k, f, i) + assert(#k == 32) + if type(m) ~= "table" then + return nil, "PASETO payloads must be a table"; + end + m = json.encode(m); + local h = "v3.local."; + local n = rand.bytes(32); + local Ek, Ak, n2 = v3_local_derive_keys(k, n); + + local c = crypto.aes_256_ctr_encrypt(Ek, n2, m); + local m2 = pae({ h, n, c, f or "", i or "" }); + local t = hashes.hmac_sha384(Ak, m2); + + if not f or f == "" then + return h..b64url(n..c..t); + else + return h..b64url(n..c..t).."."..b64url(f); + end +end + +function v3_local.decrypt(tok, k, expected_f, i) + assert(#k == 32) + + local h, sm, f = tok:match("^(v3%.local%.)([^%.]+)%.?(.*)$"); + if not h then + return nil, "invalid-token-format"; + end + f = f and unb64url(f) or nil; + if expected_f then + if not f or not secure_equals(expected_f, f) then + return nil, "invalid-footer"; + end + end + local m = unb64url(sm); + if not m or #m <= 80 then + return nil, "invalid-token-format"; + end + local n, c, t = m:sub(1, 32), m:sub(33, -49), m:sub(-48); + local Ek, Ak, n2 = v3_local_derive_keys(k, n); + local preAuth = pae({ h, n, c, f or "", i or "" }); + local t2 = hashes.hmac_sha384(Ak, preAuth); + if not secure_equals(t, t2) then + return nil, "invalid-token"; + end + local m2 = crypto.aes_256_ctr_decrypt(Ek, n2, c); + if not m2 then + return nil, "invalid-token"; + end + + local payload, err = json.decode(m2); + if err ~= nil or type(payload) ~= "table" then + return nil, "json-decode-error"; + end + return payload; +end + +function v3_local.new_key() + return "secret-token:paseto.v3.local:"..hex.encode(rand.bytes(32)); +end + +function v3_local.init(key, options) + local encoded_key = key:match("^secret%-token:paseto%.v3%.local:(%x+)$"); + if not encoded_key or #encoded_key ~= 64 then + return error("invalid key for v3.local"); + end + local raw_key = hex.decode(encoded_key); + local default_footer = options and options.default_footer; + local default_assertion = options and options.default_implicit_assertion; + return function (token, token_footer, token_assertion) + return v3_local.encrypt(token, raw_key, token_footer or default_footer, token_assertion or default_assertion); + end, function (token, token_footer, token_assertion) + return v3_local.decrypt(token, raw_key, token_footer or default_footer, token_assertion or default_assertion); + end; +end + +function v3_local.new_signer(key, options) + return (v3_local.init(key, options)); +end + +function v3_local.new_verifier(key, options) + return (select(2, v3_local.init(key, options))); +end + +return { + pae = pae; + v3_local = v3_local; + v4_public = v4_public; +}; diff --git a/util/promise.lua b/util/promise.lua index c4e166ed..f56502d2 100644 --- a/util/promise.lua +++ b/util/promise.lua @@ -2,7 +2,7 @@ local promise_methods = {}; local promise_mt = { __name = "promise", __index = promise_methods }; local xpcall = require "util.xpcall".xpcall; -local unpack = table.unpack or unpack; --luacheck: ignore 113 +local unpack = table.unpack; function promise_mt:__tostring() return "promise (" .. (self._state or "invalid") .. ")"; @@ -57,10 +57,7 @@ local function promise_settle(promise, new_state, new_next, cbs, value) end local function new_resolve_functions(p) - local resolved = false; local function _resolve(v) - if resolved then return; end - resolved = true; if is_promise(v) then v:next(new_resolve_functions(p)); elseif promise_settle(p, "fulfilled", next_fulfilled, p._pending_on_fulfilled, v) then @@ -69,8 +66,6 @@ local function new_resolve_functions(p) end local function _reject(e) - if resolved then return; end - resolved = true; if promise_settle(p, "rejected", next_rejected, p._pending_on_rejected, e) then p.reason = e; end diff --git a/util/prosodyctl.lua b/util/prosodyctl.lua index 4d49cd16..b3163799 100644 --- a/util/prosodyctl.lua +++ b/util/prosodyctl.lua @@ -224,8 +224,7 @@ local function call_luarocks(operation, mod, server) local ok, _, code = os.execute(render_cli("luarocks --lua-version={luav} {op} --tree={dir} {server&--server={server}} {mod?}", { dir = dir; op = operation; mod = mod; server = server; luav = _VERSION:match("5%.%d"); })); - if type(ok) == "number" then code = ok; end - return code; + return ok and code; end return { diff --git a/util/prosodyctl/cert.lua b/util/prosodyctl/cert.lua index 02c81585..ebc14a4e 100644 --- a/util/prosodyctl/cert.lua +++ b/util/prosodyctl/cert.lua @@ -179,7 +179,7 @@ local function copy(from, to, umask, owner, group) os.execute(("chown -c --reference=%s %s"):format(sh_esc(cert_basedir), sh_esc(to))); elseif owner and group then local ok = os.execute(("chown %s:%s %s"):format(sh_esc(owner), sh_esc(group), sh_esc(to))); - assert(ok == true or ok == 0, "Failed to change ownership of "..to); + assert(ok, "Failed to change ownership of "..to); end if old_umask then pposix.umask(old_umask); end return true; diff --git a/util/prosodyctl/check.lua b/util/prosodyctl/check.lua index e5566ff7..c1cf5ad1 100644 --- a/util/prosodyctl/check.lua +++ b/util/prosodyctl/check.lua @@ -155,7 +155,7 @@ local function check_turn_service(turn_service, ping_service) result.error = "TURN server did not response to allocation request: "..err; return result; elseif alloc_response:is_err_resp() then - result.error = ("TURN allocation failed: %d (%s)"):format(alloc_response:get_error()); + result.error = ("TURN server failed to create allocation: %d (%s)"):format(alloc_response:get_error()); return result; elseif not alloc_response:is_success_resp() then result.error = ("Unexpected TURN response: %d (%s)"):format(alloc_response:get_type()); diff --git a/util/prosodyctl/shell.lua b/util/prosodyctl/shell.lua index 8cf7df69..61050e4d 100644 --- a/util/prosodyctl/shell.lua +++ b/util/prosodyctl/shell.lua @@ -4,6 +4,8 @@ local st = require "util.stanza"; local path = require "util.paths"; local parse_args = require "util.argparse".parse; local unpack = table.unpack or _G.unpack; +local tc = require "util.termcolours"; +local isatty = require "util.pposix".isatty; local have_readline, readline = pcall(require, "readline"); @@ -27,7 +29,7 @@ local function read_line(prompt_string) end local function send_line(client, line) - client.send(st.stanza("repl-input"):text(line)); + client.send(st.stanza("repl-input", { width = os.getenv "COLUMNS" }):text(line)); end local function repl(client) @@ -64,6 +66,7 @@ end local function start(arg) --luacheck: ignore 212/arg local client = adminstream.client(); local opts, err, where = parse_args(arg); + local ttyout = isatty(io.stdout); if not opts then if err == "param-not-found" then @@ -77,8 +80,7 @@ local function start(arg) --luacheck: ignore 212/arg if arg[1] then if arg[2] then -- prosodyctl shell module reload foo bar.com --> module:reload("foo", "bar.com") - -- COMPAT Lua 5.1 doesn't have the separator argument to string.rep - arg[1] = string.format("%s:%s("..string.rep("%q, ", #arg-2):sub(1, -3)..")", unpack(arg)); + arg[1] = string.format("%s:%s("..string.rep("%q", #arg-2,", ")..")", unpack(arg)); end client.events.add_handler("connected", function() @@ -89,11 +91,15 @@ local function start(arg) --luacheck: ignore 212/arg local errors = 0; -- TODO This is weird, but works for now. client.events.add_handler("received", function(stanza) if stanza.name == "repl-output" or stanza.name == "repl-result" then + local dest = io.stdout; if stanza.attr.type == "error" then errors = errors + 1; - io.stderr:write(stanza:get_text(), "\n"); + dest = io.stderr; + end + if stanza.attr.eol == "0" then + dest:write(stanza:get_text()); else - print(stanza:get_text()); + dest:write(stanza:get_text(), "\n"); end end if stanza.name == "repl-result" then @@ -118,7 +124,11 @@ local function start(arg) --luacheck: ignore 212/arg client.events.add_handler("received", function (stanza) if stanza.name == "repl-output" or stanza.name == "repl-result" then local result_prefix = stanza.attr.type == "error" and "!" or "|"; - print(result_prefix.." "..stanza:get_text()); + local out = result_prefix.." "..stanza:get_text(); + if ttyout and stanza.attr.type == "error" then + out = tc.getstring(tc.getstyle("red"), out); + end + print(out); end if stanza.name == "repl-result" then repl(client); diff --git a/util/roles.lua b/util/roles.lua new file mode 100644 index 00000000..2c3a5026 --- /dev/null +++ b/util/roles.lua @@ -0,0 +1,110 @@ +local array = require "util.array"; +local it = require "util.iterators"; +local new_short_id = require "util.id".short; + +local role_methods = {}; +local role_mt = { + __index = role_methods; + __name = "role"; + __add = nil; +}; + +local function is_role(o) + local mt = getmetatable(o); + return mt == role_mt; +end + +local function _new_may(permissions, inherited_mays) + local n_inherited = inherited_mays and #inherited_mays; + return function (role, action, context) + -- Note: 'role' may be a descendent role, not only the one we're attached to + local policy = permissions[action]; + if policy ~= nil then + return policy; + end + if n_inherited then + for i = 1, n_inherited do + policy = inherited_mays[i](role, action, context); + if policy ~= nil then + return policy; + end + end + end + return nil; + end +end + +local permissions_key = {}; + +-- { +-- Required: +-- name = "My fancy role"; +-- +-- Optional: +-- inherits = { role_obj... } +-- default = true +-- priority = 100 +-- permissions = { +-- ["foo"] = true; -- allow +-- ["bar"] = false; -- deny +-- } +-- } +local function new(base_config, overrides) + local config = setmetatable(overrides or {}, { __index = base_config }); + local permissions = {}; + local inherited_mays; + if config.inherits then + inherited_mays = array.pluck(config.inherits, "may"); + end + local new_role = { + id = new_short_id(); + name = config.name; + description = config.description; + default = config.default; + priority = config.priority; + may = _new_may(permissions, inherited_mays); + inherits = config.inherits; + [permissions_key] = permissions; + }; + local desired_permissions = config.permissions or config[permissions_key]; + for k, v in pairs(desired_permissions or {}) do + permissions[k] = v; + end + return setmetatable(new_role, role_mt); +end + +function role_methods:clone(overrides) + return new(self, overrides); +end + +function role_methods:set_permission(permission_name, policy, overwrite) + local permissions = self[permissions_key]; + if overwrite ~= true and permissions[permission_name] ~= nil and permissions[permission_name] ~= policy then + return false, "policy-already-exists"; + end + permissions[permission_name] = policy; + return true; +end + +function role_methods:policies() + local policy_iterator, s, v = it.join(pairs(self[permissions_key])); + if self.inherits then + for _, inherited_role in ipairs(self.inherits) do + policy_iterator:append(inherited_role:policies()); + end + end + return policy_iterator, s, v; +end + +function role_mt.__tostring(self) + return ("role<[%s] %s>"):format(self.id or "nil", self.name or "[no name]"); +end + +function role_mt.__pairs(self) + return it.filter(permissions_key, next, self); +end + +return { + is_role = is_role; + new = new; +}; diff --git a/util/sasl/scram.lua b/util/sasl/scram.lua index 37abf4a4..4606d1fd 100644 --- a/util/sasl/scram.lua +++ b/util/sasl/scram.lua @@ -240,7 +240,7 @@ local function init(registerMechanism) -- register channel binding equivalent registerMechanism("SCRAM-"..hash_name.."-PLUS", {"plain", "scram_"..(hashprep(hash_name))}, - scram_gen(hash_name:lower(), hash, hmac_hash, get_auth_db, true), {"tls-unique"}); + scram_gen(hash_name:lower(), hash, hmac_hash, get_auth_db, true), {"tls-unique", "tls-exporter"}); end registerSCRAMMechanism("SHA-1", hashes.sha1, hashes.hmac_sha1, hashes.pbkdf2_hmac_sha1); diff --git a/util/serialization.lua b/util/serialization.lua index d310a3e8..e2e104f1 100644 --- a/util/serialization.lua +++ b/util/serialization.lua @@ -21,10 +21,12 @@ local to_hex = require "util.hex".to; local pcall = pcall; local envload = require"util.envload".envload; +if not math.type then + require "util.mathcompat" +end + local pos_inf, neg_inf = math.huge, -math.huge; -local m_type = math.type or function (n) - return n % 1 == 0 and n <= 9007199254740992 and n >= -9007199254740992 and "integer" or "float"; -end; +local m_type = math.type; local function rawpairs(t) return next, t, nil; diff --git a/util/session.lua b/util/session.lua index 25b22faf..d908476a 100644 --- a/util/session.lua +++ b/util/session.lua @@ -57,10 +57,16 @@ local function set_send(session) return session; end +local function set_role(session, role) + session.role = role; +end + return { new = new_session; + set_id = set_id; set_logger = set_logger; set_conn = set_conn; set_send = set_send; + set_role = set_role; } diff --git a/util/sql.lua b/util/sql.lua index 9d1c86ca..623d4ed4 100644 --- a/util/sql.lua +++ b/util/sql.lua @@ -99,6 +99,9 @@ end function engine:onconnect() -- luacheck: ignore 212/self -- Override from create_engine() end +function engine:ondisconnect() -- luacheck: ignore 212/self + -- Override from create_engine() +end function engine:prepquery(sql) if self.params.driver == "MySQL" then @@ -224,6 +227,7 @@ function engine:transaction(...) if not conn or not conn:ping() then log("debug", "Database connection was closed. Will reconnect and retry."); self.conn = nil; + self:ondisconnect(); log("debug", "Retrying SQL transaction [%s]", (...)); ok, ret, b, c = self:_transaction(...); log("debug", "SQL transaction retry %s", ok and "succeeded" or "failed"); @@ -365,8 +369,8 @@ local function db2uri(params) }; end -local function create_engine(_, params, onconnect) - return setmetatable({ url = db2uri(params), params = params, onconnect = onconnect }, engine_mt); +local function create_engine(_, params, onconnect, ondisconnect) + return setmetatable({ url = db2uri(params); params = params; onconnect = onconnect; ondisconnect = ondisconnect }, engine_mt); end return { diff --git a/util/sqlite3.lua b/util/sqlite3.lua new file mode 100644 index 00000000..dce45e13 --- /dev/null +++ b/util/sqlite3.lua @@ -0,0 +1,413 @@ + +-- luacheck: ignore 113/unpack 211 212 411 213 +local setmetatable, getmetatable = setmetatable, getmetatable; +local ipairs, unpack, select = ipairs, table.unpack or unpack, select; +local tonumber, tostring = tonumber, tostring; +local assert, xpcall, debug_traceback = assert, xpcall, debug.traceback; +local error = error +local type = type +local t_concat = table.concat; +local t_insert = table.insert; +local s_char = string.char; +local log = require "util.logger".init("sql"); + +local lsqlite3 = require "lsqlite3"; +local build_url = require "socket.url".build; +local ROW, DONE = lsqlite3.ROW, lsqlite3.DONE; + +-- from sqlite3.h, no copyright claimed +local sqlite_errors = require"util.error".init("util.sqlite3", { + -- FIXME xmpp error conditions? + [1] = { code = 1; type = "modify"; condition = "ERROR"; text = "Generic error" }; + [2] = { code = 2; type = "cancel"; condition = "INTERNAL"; text = "Internal logic error in SQLite" }; + [3] = { code = 3; type = "auth"; condition = "PERM"; text = "Access permission denied" }; + [4] = { code = 4; type = "cancel"; condition = "ABORT"; text = "Callback routine requested an abort" }; + [5] = { code = 5; type = "wait"; condition = "BUSY"; text = "The database file is locked" }; + [6] = { code = 6; type = "wait"; condition = "LOCKED"; text = "A table in the database is locked" }; + [7] = { code = 7; type = "wait"; condition = "NOMEM"; text = "A malloc() failed" }; + [8] = { code = 8; type = "cancel"; condition = "READONLY"; text = "Attempt to write a readonly database" }; + [9] = { code = 9; type = "cancel"; condition = "INTERRUPT"; text = "Operation terminated by sqlite3_interrupt()" }; + [10] = { code = 10; type = "wait"; condition = "IOERR"; text = "Some kind of disk I/O error occurred" }; + [11] = { code = 11; type = "cancel"; condition = "CORRUPT"; text = "The database disk image is malformed" }; + [12] = { code = 12; type = "modify"; condition = "NOTFOUND"; text = "Unknown opcode in sqlite3_file_control()" }; + [13] = { code = 13; type = "wait"; condition = "FULL"; text = "Insertion failed because database is full" }; + [14] = { code = 14; type = "auth"; condition = "CANTOPEN"; text = "Unable to open the database file" }; + [15] = { code = 15; type = "cancel"; condition = "PROTOCOL"; text = "Database lock protocol error" }; + [16] = { code = 16; type = "continue"; condition = "EMPTY"; text = "Internal use only" }; + [17] = { code = 17; type = "modify"; condition = "SCHEMA"; text = "The database schema changed" }; + [18] = { code = 18; type = "modify"; condition = "TOOBIG"; text = "String or BLOB exceeds size limit" }; + [19] = { code = 19; type = "modify"; condition = "CONSTRAINT"; text = "Abort due to constraint violation" }; + [20] = { code = 20; type = "modify"; condition = "MISMATCH"; text = "Data type mismatch" }; + [21] = { code = 21; type = "modify"; condition = "MISUSE"; text = "Library used incorrectly" }; + [22] = { code = 22; type = "cancel"; condition = "NOLFS"; text = "Uses OS features not supported on host" }; + [23] = { code = 23; type = "auth"; condition = "AUTH"; text = "Authorization denied" }; + [24] = { code = 24; type = "modify"; condition = "FORMAT"; text = "Not used" }; + [25] = { code = 25; type = "modify"; condition = "RANGE"; text = "2nd parameter to sqlite3_bind out of range" }; + [26] = { code = 26; type = "cancel"; condition = "NOTADB"; text = "File opened that is not a database file" }; + [27] = { code = 27; type = "continue"; condition = "NOTICE"; text = "Notifications from sqlite3_log()" }; + [28] = { code = 28; type = "continue"; condition = "WARNING"; text = "Warnings from sqlite3_log()" }; + [100] = { code = 100; type = "continue"; condition = "ROW"; text = "sqlite3_step() has another row ready" }; + [101] = { code = 101; type = "continue"; condition = "DONE"; text = "sqlite3_step() has finished executing" }; +}); + +local assert = function(cond, errno, err) + return assert(sqlite_errors.coerce(cond, err or errno)); +end +local _ENV = nil; +-- luacheck: std none + +local column_mt = {}; +local table_mt = {}; +local query_mt = {}; +--local op_mt = {}; +local index_mt = {}; + +local function is_column(x) return getmetatable(x)==column_mt; end +local function is_index(x) return getmetatable(x)==index_mt; end +local function is_table(x) return getmetatable(x)==table_mt; end +local function is_query(x) return getmetatable(x)==query_mt; end +local function Integer(n) return "Integer()" end +local function String(n) return "String()" end + +local function Column(definition) + return setmetatable(definition, column_mt); +end +local function Table(definition) + local c = {} + for i,col in ipairs(definition) do + if is_column(col) then + c[i], c[col.name] = col, col; + elseif is_index(col) then + col.table = definition.name; + end + end + return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt); +end +local function Index(definition) + return setmetatable(definition, index_mt); +end + +function table_mt:__tostring() + local s = { 'name="'..self.__table__.name..'"' } + for i,col in ipairs(self.__table__) do + s[#s+1] = tostring(col); + end + return 'Table{ '..t_concat(s, ", ")..' }' +end +table_mt.__index = {}; +function table_mt.__index:create(engine) + return engine:_create_table(self); +end +function table_mt:__call(...) + -- TODO +end +function column_mt:__tostring() + return 'Column{ name="'..self.name..'", type="'..self.type..'" }' +end +function index_mt:__tostring() + local s = 'Index{ name="'..self.name..'"'; + for i=1,#self do s = s..', "'..self[i]:gsub("[\\\"]", "\\%1")..'"'; end + return s..' }'; +-- return 'Index{ name="'..self.name..'", type="'..self.type..'" }' +end + +local function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end)); end +local function parse_url(url) + local scheme, secondpart, database = url:match("^([%w%+]+)://([^/]*)/?(.*)"); + assert(scheme, "Invalid URL format"); + local username, password, host, port; + local authpart, hostpart = secondpart:match("([^@]+)@([^@+])"); + if not authpart then hostpart = secondpart; end + if authpart then + username, password = authpart:match("([^:]*):(.*)"); + username = username or authpart; + password = password and urldecode(password); + end + if hostpart then + host, port = hostpart:match("([^:]*):(.*)"); + host = host or hostpart; + port = port and assert(tonumber(port), "Invalid URL format"); + end + return { + scheme = scheme:lower(); + username = username; password = password; + host = host; port = port; + database = #database > 0 and database or nil; + }; +end + +local engine = {}; +function engine:connect() + if self.conn then return true; end + + local params = self.params; + assert(params.driver == "SQLite3", "Only sqlite3 is supported"); + local dbh, err = sqlite_errors.coerce(lsqlite3.open(params.database)); + if not dbh then return nil, err; end + self.conn = dbh; + self.prepared = {}; + local ok, err = self:set_encoding(); + if not ok then + return ok, err; + end + local ok, err = self:onconnect(); + if ok == false then + return ok, err; + end + return true; +end +function engine:onconnect() + -- Override from create_engine() +end +function engine:ondisconnect() -- luacheck: ignore 212/self + -- Override from create_engine() +end +function engine:execute(sql, ...) + local success, err = self:connect(); + if not success then return success, err; end + local prepared = self.prepared; + + if select('#', ...) == 0 then + local ret = self.conn:exec(sql); + if ret ~= lsqlite3.OK then + local err = sqlite_errors.new(err); + err.text = self.conn:errmsg(); + return err; + end + return true; + end + + local stmt = prepared[sql]; + if not stmt then + local err; + stmt, err = self.conn:prepare(sql); + if not stmt then + err = sqlite_errors.new(err); + err.text = self.conn:errmsg(); + return stmt, err; + end + prepared[sql] = stmt; + end + + local ret = stmt:bind_values(...); + if ret ~= lsqlite3.OK then return nil, sqlite_errors.new(ret, { message = self.conn:errmsg() }); end + return stmt; +end + +local result_mt = { + __index = { + affected = function(self) return self.__affected; end; + rowcount = function(self) return self.__rowcount; end; + }, +}; + +local function iterator(table) + local i=0; + return function() + i=i+1; + local item=table[i]; + if item ~= nil then + return item; + end + end +end + +local function debugquery(where, sql, ...) + local i = 0; local a = {...} + sql = sql:gsub("\n?\t+", " "); + log("debug", "[%s] %s", where, (sql:gsub("%?", function () + i = i + 1; + local v = a[i]; + if type(v) == "string" then + v = ("'%s'"):format(v:gsub("'", "''")); + end + return tostring(v); + end))); +end + +function engine:execute_query(sql, ...) + local prepared = self.prepared; + local stmt = prepared[sql]; + if stmt and stmt:isopen() then + prepared[sql] = nil; -- Can't be used concurrently + else + stmt = assert(self.conn:prepare(sql)); + end + local ret = stmt:bind_values(...); + if ret ~= lsqlite3.OK then error(self.conn:errmsg()); end + local data, ret = {} + while stmt:step() == ROW do + t_insert(data, stmt:get_values()); + end + -- FIXME Error handling, BUSY, ERROR, MISUSE + if stmt:reset() == lsqlite3.OK then + prepared[sql] = stmt; + end + return setmetatable({ __data = data }, { __index = result_mt.__index, __call = iterator(data) }); +end +function engine:execute_update(sql, ...) + local prepared = self.prepared; + local stmt = prepared[sql]; + if not stmt or not stmt:isopen() then + stmt = assert(self.conn:prepare(sql)); + else + prepared[sql] = nil; + end + local ret = stmt:bind_values(...); + if ret ~= lsqlite3.OK then error(self.conn:errmsg()); end + local rowcount = 0; + repeat + ret = stmt:step(); + if ret == lsqlite3.ROW then + rowcount = rowcount + 1; + end + until ret ~= lsqlite3.ROW; + local affected = self.conn:changes(); + if stmt:reset() == lsqlite3.OK then + prepared[sql] = stmt; + end + return setmetatable({ __affected = affected, __rowcount = rowcount }, result_mt); +end +engine.insert = engine.execute_update; +engine.select = engine.execute_query; +engine.delete = engine.execute_update; +engine.update = engine.execute_update; +local function debugwrap(name, f) + return function (self, sql, ...) + debugquery(name, sql, ...) + return f(self, sql, ...) + end +end +function engine:debug(enable) + self._debug = enable; + if enable then + engine.insert = debugwrap("insert", engine.execute_update); + engine.select = debugwrap("select", engine.execute_query); + engine.delete = debugwrap("delete", engine.execute_update); + engine.update = debugwrap("update", engine.execute_update); + else + engine.insert = engine.execute_update; + engine.select = engine.execute_query; + engine.delete = engine.execute_update; + engine.update = engine.execute_update; + end +end +function engine:_(word) + local ret = self.conn:exec(word); + if ret ~= lsqlite3.OK then return nil, self.conn:errmsg(); end + return true; +end +function engine:_transaction(func, ...) + if not self.conn then + local a,b = self:connect(); + if not a then return a,b; end + end + --assert(not self.__transaction, "Recursive transactions not allowed"); + local ok, err = self:_"BEGIN"; + if not ok then return ok, err; end + self.__transaction = true; + local success, a, b, c = xpcall(func, debug_traceback, ...); + self.__transaction = nil; + if success then + log("debug", "SQL transaction success [%s]", tostring(func)); + local ok, err = self:_"COMMIT"; + if not ok then return ok, err; end -- commit failed + return success, a, b, c; + else + log("debug", "SQL transaction failure [%s]: %s", tostring(func), a); + if self.conn then self:_"ROLLBACK"; end + return success, a; + end +end +function engine:transaction(...) + local ok, ret = self:_transaction(...); + if not ok then + local conn = self.conn; + if not conn or not conn:isopen() then + self.conn = nil; + self:ondisconnect(); + ok, ret = self:_transaction(...); + end + end + return ok, ret; +end +function engine:_create_index(index) + local sql = "CREATE INDEX IF NOT EXISTS \""..index.name.."\" ON \""..index.table.."\" ("; + for i=1,#index do + sql = sql.."\""..index[i].."\""; + if i ~= #index then sql = sql..", "; end + end + sql = sql..");" + if index.unique then + sql = sql:gsub("^CREATE", "CREATE UNIQUE"); + end + if self._debug then + debugquery("create", sql); + end + return self:execute(sql); +end +function engine:_create_table(table) + local sql = "CREATE TABLE IF NOT EXISTS \""..table.name.."\" ("; + for i,col in ipairs(table.c) do + local col_type = col.type; + sql = sql.."\""..col.name.."\" "..col_type; + if col.nullable == false then sql = sql.." NOT NULL"; end + if col.primary_key == true then sql = sql.." PRIMARY KEY"; end + if col.auto_increment == true then + sql = sql.." AUTOINCREMENT"; + end + if i ~= #table.c then sql = sql..", "; end + end + sql = sql.. ");" + if self._debug then + debugquery("create", sql); + end + local success,err = self:execute(sql); + if not success then return success,err; end + for i,v in ipairs(table.__table__) do + if is_index(v) then + self:_create_index(v); + end + end + return success; +end +function engine:set_encoding() -- to UTF-8 + return self:transaction(function() + for encoding in self:select"PRAGMA encoding;" do + if encoding[1] == "UTF-8" then + self.charset = "utf8"; + end + end + end); +end +local engine_mt = { __index = engine }; + +local function db2uri(params) + return build_url{ + scheme = params.driver, + user = params.username, + password = params.password, + host = params.host, + port = params.port, + path = params.database, + }; +end + +local function create_engine(_, params, onconnect, ondisconnect) + assert(params.driver == "SQLite3", "Only SQLite3 is supported without LuaDBI"); + return setmetatable({ url = db2uri(params); params = params; onconnect = onconnect; ondisconnect = ondisconnect }, engine_mt); +end + +return { + is_column = is_column; + is_index = is_index; + is_table = is_table; + is_query = is_query; + Integer = Integer; + String = String; + Column = Column; + Table = Table; + Index = Index; + create_engine = create_engine; + db2uri = db2uri; +}; diff --git a/util/sslconfig.lua b/util/sslconfig.lua index 6074a1fb..0078365b 100644 --- a/util/sslconfig.lua +++ b/util/sslconfig.lua @@ -3,9 +3,12 @@ local type = type; local pairs = pairs; local rawset = rawset; +local rawget = rawget; +local error = error; local t_concat = table.concat; local t_insert = table.insert; local setmetatable = setmetatable; +local resolve_path = require"util.paths".resolve_relative_path; local _ENV = nil; -- luacheck: std none @@ -34,7 +37,7 @@ function handlers.options(config, field, new) options[value] = true; end end - config[field] = options; + rawset(config, field, options) end handlers.verifyext = handlers.options; @@ -70,6 +73,20 @@ finalisers.curveslist = finalisers.ciphers; -- TLS 1.3 ciphers finalisers.ciphersuites = finalisers.ciphers; +-- Path expansion +function finalisers.key(path, config) + if type(path) == "string" then + return resolve_path(config._basedir, path); + else + return nil + end +end +finalisers.certificate = finalisers.key; +finalisers.cafile = finalisers.key; +finalisers.capath = finalisers.key; +-- XXX: copied from core/certmanager.lua, but this seems odd, because it would remove a dhparam function from the config +finalisers.dhparam = finalisers.key; + -- protocol = "x" should enable only that protocol -- protocol = "x+" should enable x and later versions @@ -89,37 +106,81 @@ end -- Merge options from 'new' config into 'config' local function apply(config, new) + rawset(config, "_cache", nil); if type(new) == "table" then for field, value in pairs(new) do - (handlers[field] or rawset)(config, field, value); + -- exclude keys which are internal to the config builder + if field:sub(1, 1) ~= "_" then + (handlers[field] or rawset)(config, field, value); + end end end + return config end -- Finalize the config into the form LuaSec expects local function final(config) local output = { }; for field, value in pairs(config) do - output[field] = (finalisers[field] or id)(value); + -- exclude keys which are internal to the config builder + if field:sub(1, 1) ~= "_" then + output[field] = (finalisers[field] or id)(value, config); + end end -- Need to handle protocols last because it adds to the options list protocol(output); return output; end +local function build(config) + local cached = rawget(config, "_cache"); + if cached then + return cached, nil + end + + local ctx, err = rawget(config, "_context_factory")(config:final(), config); + if ctx then + rawset(config, "_cache", ctx); + end + return ctx, err +end + local sslopts_mt = { __index = { apply = apply; final = final; + build = build; }; + __newindex = function() + error("SSL config objects cannot be modified directly. Use :apply()") + end; }; -local function new() - return setmetatable({options={}}, sslopts_mt); + +-- passing basedir through everything is required to avoid sslconfig depending +-- on prosody.paths.config +local function new(context_factory, basedir) + return setmetatable({ + _context_factory = context_factory, + _basedir = basedir, + options={}, + }, sslopts_mt); end +local function clone(config) + local result = new(); + for k, v in pairs(config) do + -- note that we *do* copy the internal keys on clone -- we have to carry + -- both the factory and the cache with us + rawset(result, k, v); + end + return result +end + +sslopts_mt.__index.clone = clone; + return { apply = apply; final = final; - new = new; + _new = new; }; diff --git a/util/stanza.lua b/util/stanza.lua index 86b88169..0f8827d5 100644 --- a/util/stanza.lua +++ b/util/stanza.lua @@ -21,12 +21,15 @@ local type = type; local s_gsub = string.gsub; local s_sub = string.sub; local s_find = string.find; +local t_move = table.move or require "util.table".move; +local t_create = require"util.table".create; local valid_utf8 = require "util.encodings".utf8.valid; local do_pretty_printing, termcolours = pcall(require, "util.termcolours"); local xmlns_stanzas = "urn:ietf:params:xml:ns:xmpp-stanzas"; +local xmpp_stanzas_attr = { xmlns = xmlns_stanzas }; local _ENV = nil; -- luacheck: std none @@ -179,6 +182,14 @@ function stanza_mt:get_child_text(name, xmlns) return nil; end +function stanza_mt:get_child_attr(name, xmlns, attr) + local tag = self:get_child(name, xmlns); + if tag then + return tag.attr[attr]; + end + return nil; +end + function stanza_mt:child_with_name(name) for _, child in ipairs(self.tags) do if child.name == name then return child; end @@ -283,25 +294,33 @@ function stanza_mt:find(path) end local function _clone(stanza, only_top) - local attr, tags = {}, {}; + local attr = {}; for k,v in pairs(stanza.attr) do attr[k] = v; end local old_namespaces, namespaces = stanza.namespaces; if old_namespaces then namespaces = {}; for k,v in pairs(old_namespaces) do namespaces[k] = v; end end - local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags }; + local tags, new; + if only_top then + tags = {}; + new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags }; + else + tags = t_create(#stanza.tags, 0); + new = t_create(#stanza, 4); + new.name = stanza.name; + new.attr = attr; + new.namespaces = namespaces; + new.tags = tags; + end + + setmetatable(new, stanza_mt); if not only_top then - for i=1,#stanza do - local child = stanza[i]; - if child.name then - child = _clone(child); - t_insert(tags, child); - end - t_insert(new, child); - end + t_move(stanza, 1, #stanza, 1, new); + t_move(stanza.tags, 1, #stanza.tags, 1, tags); + new:maptags(_clone); end - return setmetatable(new, stanza_mt); + return new; end local function clone(stanza, only_top) @@ -387,6 +406,33 @@ function stanza_mt.get_error(stanza) return error_type, condition or "undefined-condition", text, extra_tag; end +function stanza_mt.add_error(stanza, error_type, condition, error_message, error_by) + local extra; + if type(error_type) == "table" then -- an util.error or similar object + if type(error_type.extra) == "table" then + extra = error_type.extra; + end + if type(error_type.context) == "table" and type(error_type.context.by) == "string" then error_by = error_type.context.by; end + error_type, condition, error_message = error_type.type, error_type.condition, error_type.text; + end + if stanza.attr.from == error_by then + error_by = nil; + end + stanza:tag("error", {type = error_type, by = error_by}) --COMPAT: Some day xmlns:stanzas goes here + :tag(condition, xmpp_stanzas_attr); + if extra and condition == "gone" and type(extra.uri) == "string" then + stanza:text(extra.uri); + end + stanza:up(); + if error_message then stanza:text_tag("text", error_message, xmpp_stanzas_attr); end + if extra and is_stanza(extra.tag) then + stanza:add_child(extra.tag); + elseif extra and extra.namespace and extra.condition then + stanza:tag(extra.condition, { xmlns = extra.namespace }):up(); + end + return stanza:up(); +end + local function preserialize(stanza) local s = { name = stanza.name, attr = stanza.attr }; for _, child in ipairs(stanza) do @@ -461,7 +507,6 @@ local function reply(orig) }); end -local xmpp_stanzas_attr = { xmlns = xmlns_stanzas }; local function error_reply(orig, error_type, condition, error_message, error_by) if not is_stanza(orig) then error("bad argument to error_reply: expected stanza, got "..type(orig)); @@ -470,30 +515,9 @@ local function error_reply(orig, error_type, condition, error_message, error_by) end local t = reply(orig); t.attr.type = "error"; - local extra; - if type(error_type) == "table" then -- an util.error or similar object - if type(error_type.extra) == "table" then - extra = error_type.extra; - end - if type(error_type.context) == "table" and type(error_type.context.by) == "string" then error_by = error_type.context.by; end - error_type, condition, error_message = error_type.type, error_type.condition, error_type.text; - end - if t.attr.from == error_by then - error_by = nil; - end - t:tag("error", {type = error_type, by = error_by}) --COMPAT: Some day xmlns:stanzas goes here - :tag(condition, xmpp_stanzas_attr); - if extra and condition == "gone" and type(extra.uri) == "string" then - t:text(extra.uri); - end - t:up(); - if error_message then t:text_tag("text", error_message, xmpp_stanzas_attr); end - if extra and is_stanza(extra.tag) then - t:add_child(extra.tag); - elseif extra and extra.namespace and extra.condition then - t:tag(extra.condition, { xmlns = extra.namespace }):up(); - end - return t; -- stanza ready for adding app-specific errors + t:add_error(error_type, condition, error_message, error_by); + t.last_add = { t[1] }; -- ready to add application-specific errors + return t; end local function presence(attr) diff --git a/util/startup.lua b/util/startup.lua index 545b6ae7..8be54884 100644 --- a/util/startup.lua +++ b/util/startup.lua @@ -277,6 +277,11 @@ function startup.init_global_state() startup.detect_platform(); startup.detect_installed(); _G.prosody = prosody; + + -- COMPAT Lua < 5.3 + if not math.type then + require "util.mathcompat" + end end function startup.setup_datadir() diff --git a/util/vcard.lua b/util/vcard.lua deleted file mode 100644 index e311f73f..00000000 --- a/util/vcard.lua +++ /dev/null @@ -1,574 +0,0 @@ --- Copyright (C) 2011-2014 Kim Alvefur --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - --- TODO --- Fix folding. - -local st = require "util.stanza"; -local t_insert, t_concat = table.insert, table.concat; -local type = type; -local pairs, ipairs = pairs, ipairs; - -local from_text, to_text, from_xep54, to_xep54; - -local line_sep = "\n"; - -local vCard_dtd; -- See end of file -local vCard4_dtd; - -local function vCard_esc(s) - return s:gsub("[,:;\\]", "\\%1"):gsub("\n","\\n"); -end - -local function vCard_unesc(s) - return s:gsub("\\?[\\nt:;,]", { - ["\\\\"] = "\\", - ["\\n"] = "\n", - ["\\r"] = "\r", - ["\\t"] = "\t", - ["\\:"] = ":", -- FIXME Shouldn't need to escape : in values, just params - ["\\;"] = ";", - ["\\,"] = ",", - [":"] = "\29", - [";"] = "\30", - [","] = "\31", - }); -end - -local function item_to_xep54(item) - local t = st.stanza(item.name, { xmlns = "vcard-temp" }); - - local prop_def = vCard_dtd[item.name]; - if prop_def == "text" then - t:text(item[1]); - elseif type(prop_def) == "table" then - if prop_def.types and item.TYPE then - if type(item.TYPE) == "table" then - for _,v in pairs(prop_def.types) do - for _,typ in pairs(item.TYPE) do - if typ:upper() == v then - t:tag(v):up(); - break; - end - end - end - else - t:tag(item.TYPE:upper()):up(); - end - end - - if prop_def.props then - for _,prop in pairs(prop_def.props) do - if item[prop] then - for _, v in ipairs(item[prop]) do - t:text_tag(prop, v); - end - end - end - end - - if prop_def.value then - t:text_tag(prop_def.value, item[1]); - elseif prop_def.values then - local prop_def_values = prop_def.values; - local repeat_last = prop_def_values.behaviour == "repeat-last" and prop_def_values[#prop_def_values]; - for i=1,#item do - t:text_tag(prop_def.values[i] or repeat_last, item[i]); - end - end - end - - return t; -end - -local function vcard_to_xep54(vCard) - local t = st.stanza("vCard", { xmlns = "vcard-temp" }); - for i=1,#vCard do - t:add_child(item_to_xep54(vCard[i])); - end - return t; -end - -function to_xep54(vCards) - if not vCards[1] or vCards[1].name then - return vcard_to_xep54(vCards) - else - local t = st.stanza("xCard", { xmlns = "vcard-temp" }); - for i=1,#vCards do - t:add_child(vcard_to_xep54(vCards[i])); - end - return t; - end -end - -function from_text(data) - data = data -- unfold and remove empty lines - :gsub("\r\n","\n") - :gsub("\n ", "") - :gsub("\n\n+","\n"); - local vCards = {}; - local current; - for line in data:gmatch("[^\n]+") do - line = vCard_unesc(line); - local name, params, value = line:match("^([-%a]+)(\30?[^\29]*)\29(.*)$"); - value = value:gsub("\29",":"); - if #params > 0 then - local _params = {}; - for k,isval,v in params:gmatch("\30([^=]+)(=?)([^\30]*)") do - k = k:upper(); - local _vt = {}; - for _p in v:gmatch("[^\31]+") do - _vt[#_vt+1]=_p - _vt[_p]=true; - end - if isval == "=" then - _params[k]=_vt; - else - _params[k]=true; - end - end - params = _params; - end - if name == "BEGIN" and value == "VCARD" then - current = {}; - vCards[#vCards+1] = current; - elseif name == "END" and value == "VCARD" then - current = nil; - elseif current and vCard_dtd[name] then - local dtd = vCard_dtd[name]; - local item = { name = name }; - t_insert(current, item); - local up = current; - current = item; - if dtd.types then - for _, t in ipairs(dtd.types) do - t = t:lower(); - if ( params.TYPE and params.TYPE[t] == true) - or params[t] == true then - current.TYPE=t; - end - end - end - if dtd.props then - for _, p in ipairs(dtd.props) do - if params[p] then - if params[p] == true then - current[p]=true; - else - for _, prop in ipairs(params[p]) do - current[p]=prop; - end - end - end - end - end - if dtd == "text" or dtd.value then - t_insert(current, value); - elseif dtd.values then - for p in ("\30"..value):gmatch("\30([^\30]*)") do - t_insert(current, p); - end - end - current = up; - end - end - return vCards; -end - -local function item_to_text(item) - local value = {}; - for i=1,#item do - value[i] = vCard_esc(item[i]); - end - value = t_concat(value, ";"); - - local params = ""; - for k,v in pairs(item) do - if type(k) == "string" and k ~= "name" then - params = params .. (";%s=%s"):format(k, type(v) == "table" and t_concat(v,",") or v); - end - end - - return ("%s%s:%s"):format(item.name, params, value) -end - -local function vcard_to_text(vcard) - local t={}; - t_insert(t, "BEGIN:VCARD") - for i=1,#vcard do - t_insert(t, item_to_text(vcard[i])); - end - t_insert(t, "END:VCARD") - return t_concat(t, line_sep); -end - -function to_text(vCards) - if vCards[1] and vCards[1].name then - return vcard_to_text(vCards) - else - local t = {}; - for i=1,#vCards do - t[i]=vcard_to_text(vCards[i]); - end - return t_concat(t, line_sep); - end -end - -local function from_xep54_item(item) - local prop_name = item.name; - local prop_def = vCard_dtd[prop_name]; - - local prop = { name = prop_name }; - - if prop_def == "text" then - prop[1] = item:get_text(); - elseif type(prop_def) == "table" then - if prop_def.value then --single item - prop[1] = item:get_child_text(prop_def.value) or ""; - elseif prop_def.values then --array - local value_names = prop_def.values; - if value_names.behaviour == "repeat-last" then - for i=1,#item.tags do - t_insert(prop, item.tags[i]:get_text() or ""); - end - else - for i=1,#value_names do - t_insert(prop, item:get_child_text(value_names[i]) or ""); - end - end - elseif prop_def.names then - local names = prop_def.names; - for i=1,#names do - if item:get_child(names[i]) then - prop[1] = names[i]; - break; - end - end - end - - if prop_def.props_verbatim then - for k,v in pairs(prop_def.props_verbatim) do - prop[k] = v; - end - end - - if prop_def.types then - local types = prop_def.types; - prop.TYPE = {}; - for i=1,#types do - if item:get_child(types[i]) then - t_insert(prop.TYPE, types[i]:lower()); - end - end - if #prop.TYPE == 0 then - prop.TYPE = nil; - end - end - - -- A key-value pair, within a key-value pair? - if prop_def.props then - local params = prop_def.props; - for i=1,#params do - local name = params[i] - local data = item:get_child_text(name); - if data then - prop[name] = prop[name] or {}; - t_insert(prop[name], data); - end - end - end - else - return nil - end - - return prop; -end - -local function from_xep54_vCard(vCard) - local tags = vCard.tags; - local t = {}; - for i=1,#tags do - t_insert(t, from_xep54_item(tags[i])); - end - return t -end - -function from_xep54(vCard) - if vCard.attr.xmlns ~= "vcard-temp" then - return nil, "wrong-xmlns"; - end - if vCard.name == "xCard" then -- A collection of vCards - local t = {}; - local vCards = vCard.tags; - for i=1,#vCards do - t[i] = from_xep54_vCard(vCards[i]); - end - return t - elseif vCard.name == "vCard" then -- A single vCard - return from_xep54_vCard(vCard) - end -end - -local vcard4 = { } - -function vcard4:text(node, params, value) -- luacheck: ignore 212/params - self:tag(node:lower()) - -- FIXME params - if type(value) == "string" then - self:text_tag("text", value); - elseif vcard4[node] then - vcard4[node](value); - end - self:up(); -end - -function vcard4.N(value) - for i, k in ipairs(vCard_dtd.N.values) do - value:text_tag(k, value[i]); - end -end - -local xmlns_vcard4 = "urn:ietf:params:xml:ns:vcard-4.0" - -local function item_to_vcard4(item) - local typ = item.name:lower(); - local t = st.stanza(typ, { xmlns = xmlns_vcard4 }); - - local prop_def = vCard4_dtd[typ]; - if prop_def == "text" then - t:text_tag("text", item[1]); - elseif prop_def == "uri" then - if item.ENCODING and item.ENCODING[1] == 'b' then - t:text_tag("uri", "data:;base64," .. item[1]); - else - t:text_tag("uri", item[1]); - end - elseif type(prop_def) == "table" then - if prop_def.values then - for i, v in ipairs(prop_def.values) do - t:text_tag(v:lower(), item[i]); - end - else - t:tag("unsupported",{xmlns="http://zash.se/protocol/vcardlib"}) - end - else - t:tag("unsupported",{xmlns="http://zash.se/protocol/vcardlib"}) - end - return t; -end - -local function vcard_to_vcard4xml(vCard) - local t = st.stanza("vcard", { xmlns = xmlns_vcard4 }); - for i=1,#vCard do - t:add_child(item_to_vcard4(vCard[i])); - end - return t; -end - -local function vcards_to_vcard4xml(vCards) - if not vCards[1] or vCards[1].name then - return vcard_to_vcard4xml(vCards) - else - local t = st.stanza("vcards", { xmlns = xmlns_vcard4 }); - for i=1,#vCards do - t:add_child(vcard_to_vcard4xml(vCards[i])); - end - return t; - end -end - --- This was adapted from http://xmpp.org/extensions/xep-0054.html#dtd -vCard_dtd = { - VERSION = "text", --MUST be 3.0, so parsing is redundant - FN = "text", - N = { - values = { - "FAMILY", - "GIVEN", - "MIDDLE", - "PREFIX", - "SUFFIX", - }, - }, - NICKNAME = "text", - PHOTO = { - props_verbatim = { ENCODING = { "b" } }, - props = { "TYPE" }, - value = "BINVAL", --{ "EXTVAL", }, - }, - BDAY = "text", - ADR = { - types = { - "HOME", - "WORK", - "POSTAL", - "PARCEL", - "DOM", - "INTL", - "PREF", - }, - values = { - "POBOX", - "EXTADD", - "STREET", - "LOCALITY", - "REGION", - "PCODE", - "CTRY", - } - }, - LABEL = { - types = { - "HOME", - "WORK", - "POSTAL", - "PARCEL", - "DOM", - "INTL", - "PREF", - }, - value = "LINE", - }, - TEL = { - types = { - "HOME", - "WORK", - "VOICE", - "FAX", - "PAGER", - "MSG", - "CELL", - "VIDEO", - "BBS", - "MODEM", - "ISDN", - "PCS", - "PREF", - }, - value = "NUMBER", - }, - EMAIL = { - types = { - "HOME", - "WORK", - "INTERNET", - "PREF", - "X400", - }, - value = "USERID", - }, - JABBERID = "text", - MAILER = "text", - TZ = "text", - GEO = { - values = { - "LAT", - "LON", - }, - }, - TITLE = "text", - ROLE = "text", - LOGO = "copy of PHOTO", - AGENT = "text", - ORG = { - values = { - behaviour = "repeat-last", - "ORGNAME", - "ORGUNIT", - } - }, - CATEGORIES = { - values = "KEYWORD", - }, - NOTE = "text", - PRODID = "text", - REV = "text", - SORTSTRING = "text", - SOUND = "copy of PHOTO", - UID = "text", - URL = "text", - CLASS = { - names = { -- The item.name is the value if it's one of these. - "PUBLIC", - "PRIVATE", - "CONFIDENTIAL", - }, - }, - KEY = { - props = { "TYPE" }, - value = "CRED", - }, - DESC = "text", -}; -vCard_dtd.LOGO = vCard_dtd.PHOTO; -vCard_dtd.SOUND = vCard_dtd.PHOTO; - -vCard4_dtd = { - source = "uri", - kind = "text", - xml = "text", - fn = "text", - n = { - values = { - "family", - "given", - "middle", - "prefix", - "suffix", - }, - }, - nickname = "text", - photo = "uri", - bday = "date-and-or-time", - anniversary = "date-and-or-time", - gender = "text", - adr = { - values = { - "pobox", - "ext", - "street", - "locality", - "region", - "code", - "country", - } - }, - tel = "text", - email = "text", - impp = "uri", - lang = "language-tag", - tz = "text", - geo = "uri", - title = "text", - role = "text", - logo = "uri", - org = "text", - member = "uri", - related = "uri", - categories = "text", - note = "text", - prodid = "text", - rev = "timestamp", - sound = "uri", - uid = "uri", - clientpidmap = "number, uuid", - url = "uri", - version = "text", - key = "uri", - fburl = "uri", - caladruri = "uri", - caluri = "uri", -}; - -return { - from_text = from_text; - to_text = to_text; - - from_xep54 = from_xep54; - to_xep54 = to_xep54; - - to_vcard4 = vcards_to_vcard4xml; -}; diff --git a/util/watchdog.lua b/util/watchdog.lua index 516e60e4..407028a5 100644 --- a/util/watchdog.lua +++ b/util/watchdog.lua @@ -1,6 +1,5 @@ local timer = require "util.timer"; local setmetatable = setmetatable; -local os_time = os.time; local _ENV = nil; -- luacheck: std none @@ -9,27 +8,35 @@ local watchdog_methods = {}; local watchdog_mt = { __index = watchdog_methods }; local function new(timeout, callback) - local watchdog = setmetatable({ timeout = timeout, last_reset = os_time(), callback = callback }, watchdog_mt); - timer.add_task(timeout+1, function (current_time) - local last_reset = watchdog.last_reset; - if not last_reset then - return; - end - local time_left = (last_reset + timeout) - current_time; - if time_left < 0 then - return watchdog:callback(); - end - return time_left + 1; - end); + local watchdog = setmetatable({ + timeout = timeout; + callback = callback; + timer_id = nil; + }, watchdog_mt); + + watchdog:reset(); -- Kick things off + return watchdog; end -function watchdog_methods:reset() - self.last_reset = os_time(); +function watchdog_methods:reset(new_timeout) + if new_timeout then + self.timeout = new_timeout; + end + if self.timer_id then + timer.reschedule(self.timer_id, self.timeout+1); + else + self.timer_id = timer.add_task(self.timeout+1, function () + return self:callback(); + end); + end end function watchdog_methods:cancel() - self.last_reset = nil; + if self.timer_id then + timer.stop(self.timer_id); + self.timer_id = nil; + end end return { diff --git a/util/x509.lua b/util/x509.lua index 76b50076..51ca3c96 100644 --- a/util/x509.lua +++ b/util/x509.lua @@ -11,12 +11,12 @@ -- IDN libraries complicate that. --- [TLS-CERTS] - http://tools.ietf.org/html/rfc6125 --- [XMPP-CORE] - http://tools.ietf.org/html/rfc6120 --- [SRV-ID] - http://tools.ietf.org/html/rfc4985 --- [IDNA] - http://tools.ietf.org/html/rfc5890 --- [LDAP] - http://tools.ietf.org/html/rfc4519 --- [PKIX] - http://tools.ietf.org/html/rfc5280 +-- [TLS-CERTS] - https://www.rfc-editor.org/rfc/rfc6125.html +-- [XMPP-CORE] - https://www.rfc-editor.org/rfc/rfc6120.html +-- [SRV-ID] - https://www.rfc-editor.org/rfc/rfc4985.html +-- [IDNA] - https://www.rfc-editor.org/rfc/rfc5890.html +-- [LDAP] - https://www.rfc-editor.org/rfc/rfc4519.html +-- [PKIX] - https://www.rfc-editor.org/rfc/rfc5280.html local nameprep = require "util.encodings".stringprep.nameprep; local idna_to_ascii = require "util.encodings".idna.to_ascii; |