diff options
Diffstat (limited to 'util')
54 files changed, 3138 insertions, 576 deletions
diff --git a/util/adhoc.lua b/util/adhoc.lua index 17c9eee5..d81b8242 100644 --- a/util/adhoc.lua +++ b/util/adhoc.lua @@ -1,3 +1,5 @@ +-- luacheck: ignore 212/self + local function new_simple_form(form, result_handler) return function(self, data, state) if state then diff --git a/util/array.lua b/util/array.lua index 150b4355..0b60a4fd 100644 --- a/util/array.lua +++ b/util/array.lua @@ -19,7 +19,13 @@ local type = type; local array = {}; local array_base = {}; local array_methods = {}; -local array_mt = { __index = array_methods, __tostring = function (self) return "{"..self:concat(", ").."}"; end }; +local array_mt = { + __index = array_methods; + __name = "array"; + __tostring = function (self) return "{"..self:concat(", ").."}"; end; +}; + +function array_mt:__freeze() return self; end local function new_array(self, t, _s, _var) if type(t) == "function" then -- Assume iterator @@ -46,6 +52,19 @@ function array_mt.__eq(a, b) return true; end +function array_mt.__div(a1, func) + local a2 = new_array(); + local o = 0; + for i = 1, #a1 do + local new_value = func(a1[i]); + if new_value ~= nil then + o = o + 1; + a2[o] = new_value; + end + end + return a2; +end + setmetatable(array, { __call = new_array }); -- Read-only methods @@ -53,6 +72,12 @@ function array_methods:random() return self[math_random(1, #self)]; end +-- Return a random value excluding the one at idx +function array_methods:random_other(idx) + local max = #self; + return self[((math.random(1, max-1)+(idx-1))%max)+1]; +end + -- These methods can be called two ways: -- array.method(existing_array, [params [, ...]]) -- Create new array for result -- existing_array:method([params, ...]) -- Transform existing array into result diff --git a/util/async.lua b/util/async.lua new file mode 100644 index 00000000..20397785 --- /dev/null +++ b/util/async.lua @@ -0,0 +1,254 @@ +local logger = require "util.logger"; +local log = logger.init("util.async"); +local new_id = require "util.id".short; +local xpcall = require "util.xpcall".xpcall; + +local function checkthread() + local thread, main = coroutine.running(); + if not thread or main then + error("Not running in an async context, see https://prosody.im/doc/developers/util/async"); + end + return thread; +end + +local function runner_from_thread(thread) + local level = 0; + -- Find the 'level' of the top-most function (0 == current level, 1 == caller, ...) + while debug.getinfo(thread, level, "") do level = level + 1; end + local name, runner = debug.getlocal(thread, level-1, 1); + if name ~= "self" or type(runner) ~= "table" or runner.thread ~= thread then + return nil; + end + return runner; +end + +local function call_watcher(runner, watcher_name, ...) + local watcher = runner.watchers[watcher_name]; + if not watcher then + return false; + end + runner:log("debug", "Calling '%s' watcher", watcher_name); + local ok, err = xpcall(watcher, debug.traceback, runner, ...); + if not ok then + runner:log("error", "Error in '%s' watcher: %s", watcher_name, err); + return nil, err; + end + return true; +end + +local function runner_continue(thread) + -- ASSUMPTION: runner is in 'waiting' state (but we don't have the runner to know for sure) + if coroutine.status(thread) ~= "suspended" then -- This should suffice + log("error", "unexpected async state: thread not suspended"); + return false; + end + local ok, state, runner = coroutine.resume(thread); + if not ok then + local err = state; + -- Running the coroutine failed, which means we have to find the runner manually, + -- in order to inform the error handler + runner = runner_from_thread(thread); + if not runner then + log("error", "unexpected async state: unable to locate runner during error handling"); + return false; + end + call_watcher(runner, "error", debug.traceback(thread, err)); + runner.state, runner.thread = "ready", nil; + return runner:run(); + elseif state == "ready" then + -- If state is 'ready', it is our responsibility to update runner.state from 'waiting'. + -- We also have to :run(), because the queue might have further items that will not be + -- processed otherwise. FIXME: It's probably best to do this in a nexttick (0 timer). + runner.state = "ready"; + runner:run(); + end + return true; +end + +local function waiter(num) + local thread = checkthread(); + num = num or 1; + local waiting; + return function () + if num == 0 then return; end -- already done + waiting = true; + coroutine.yield("wait"); + end, function () + num = num - 1; + if num == 0 and waiting then + runner_continue(thread); + elseif num < 0 then + error("done() called too many times"); + end + end; +end + +local function guarder() + local guards = {}; + local default_id = {}; + return function (id, func) + id = id or default_id; + local thread = checkthread(); + local guard = guards[id]; + if not guard then + guard = {}; + guards[id] = guard; + log("debug", "New guard!"); + else + table.insert(guard, thread); + log("debug", "Guarded. %d threads waiting.", #guard) + coroutine.yield("wait"); + end + local function exit() + local next_waiting = table.remove(guard, 1); + if next_waiting then + log("debug", "guard: Executing next waiting thread (%d left)", #guard) + runner_continue(next_waiting); + else + log("debug", "Guard off duty.") + guards[id] = nil; + end + end + if func then + func(); + exit(); + return; + end + return exit; + end; +end + +local runner_mt = {}; +runner_mt.__index = runner_mt; + +local function runner_create_thread(func, self) + local thread = coroutine.create(function (self) -- luacheck: ignore 432/self + while true do + func(coroutine.yield("ready", self)); + end + end); + debug.sethook(thread, debug.gethook()); + assert(coroutine.resume(thread, self)); -- Start it up, it will return instantly to wait for the first input + return thread; +end + +local function default_error_watcher(runner, err) + runner:log("error", "Encountered error: %s", err); + error(err); +end +local function default_func(f) f(); end +local function runner(func, watchers, data) + local id = new_id(); + local _log = logger.init("runner" .. id); + return setmetatable({ func = func or default_func, thread = false, state = "ready", notified_state = "ready", + queue = {}, watchers = watchers or { error = default_error_watcher }, data = data, id = id, _log = _log; } + , runner_mt); +end + +-- Add a task item for the runner to process +function runner_mt:run(input) + if input ~= nil then + table.insert(self.queue, input); + --self:log("debug", "queued new work item, %d items queued", #self.queue); + end + if self.state ~= "ready" then + -- The runner is busy. Indicate that the task item has been + -- queued, and return information about the current runner state + return true, self.state, #self.queue; + end + + local q, thread = self.queue, self.thread; + if not thread or coroutine.status(thread) == "dead" then + self:log("debug", "creating new coroutine"); + -- Create a new coroutine for this runner + thread = runner_create_thread(self.func, self); + self.thread = thread; + end + + -- Process task item(s) while the queue is not empty, and we're not blocked + local n, state, err = #q, self.state, nil; + self.state = "running"; + --self:log("debug", "running main loop"); + while n > 0 and state == "ready" and not err do + local consumed; + -- Loop through queue items, and attempt to run them + for i = 1,n do + local queued_input = q[i]; + local ok, new_state = coroutine.resume(thread, queued_input); + if not ok then + -- There was an error running the coroutine, save the error, mark runner as ready to begin again + consumed, state, err = i, "ready", debug.traceback(thread, new_state); + self.thread = nil; + break; + elseif new_state == "wait" then + -- Runner is blocked on waiting for a task item to complete + consumed, state = i, "waiting"; + break; + end + end + -- Loop ended - either queue empty because all tasks passed without blocking (consumed == nil) + -- or runner is blocked/errored, and consumed will contain the number of tasks processed so far + if not consumed then consumed = n; end + -- Remove consumed items from the queue array + if q[n+1] ~= nil then + n = #q; + end + for i = 1, n do + q[i] = q[consumed+i]; + end + n = #q; + end + -- Runner processed all items it can, so save current runner state + self.state = state; + if err or state ~= self.notified_state then + self:log("debug", "changed state from %s to %s", self.notified_state, err and ("error ("..state..")") or state); + if err then + state = "error" + else + self.notified_state = state; + end + local handler = self.watchers[state]; + if handler then handler(self, err); end + end + if n > 0 then + return self:run(); + end + return true, state, n; +end + +-- Add a task item to the queue without invoking the runner, even if it is idle +function runner_mt:enqueue(input) + table.insert(self.queue, input); + self:log("debug", "queued new work item, %d items queued", #self.queue); + return self; +end + +function runner_mt:log(level, fmt, ...) + return self._log(level, fmt, ...); +end + +function runner_mt:onready(f) + self.watchers.ready = f; + return self; +end + +function runner_mt:onwaiting(f) + self.watchers.waiting = f; + return self; +end + +function runner_mt:onerror(f) + self.watchers.error = f; + return self; +end + +local function ready() + return pcall(checkthread); +end + +return { + ready = ready; + waiter = waiter; + guarder = guarder; + runner = runner; +}; diff --git a/util/cache.lua b/util/cache.lua index 9c141bb6..a5fd5e6d 100644 --- a/util/cache.lua +++ b/util/cache.lua @@ -116,6 +116,25 @@ function cache_methods:tail() return tail.key, tail.value; end +function cache_methods:resize(new_size) + new_size = assert(tonumber(new_size), "cache size must be a number"); + new_size = math.floor(new_size); + assert(new_size > 0, "cache size must be greater than zero"); + local on_evict = self._on_evict; + while self._count > new_size do + local tail = self._tail; + local evicted_key, evicted_value = tail.key, tail.value; + if on_evict ~= nil and (on_evict == false or on_evict(evicted_key, evicted_value) == false) then + -- Cache is full, and we're not allowed to evict + return false; + end + _remove(self, tail); + self._data[evicted_key] = nil; + end + self.size = new_size; + return true; +end + function cache_methods:table() --luacheck: ignore 212/t if not self.proxy_table then @@ -139,6 +158,13 @@ function cache_methods:table() return self.proxy_table; end +function cache_methods:clear() + self._data = {}; + self._count = 0; + self._head = nil; + self._tail = nil; +end + local function new(size, on_evict) size = assert(tonumber(size), "cache size must be a number"); size = math.floor(size); diff --git a/util/caps.lua b/util/caps.lua index cd5ff9c0..de492edb 100644 --- a/util/caps.lua +++ b/util/caps.lua @@ -13,6 +13,7 @@ local t_insert, t_sort, t_concat = table.insert, table.sort, table.concat; local ipairs = ipairs; local _ENV = nil; +-- luacheck: std none local function calculate_hash(disco_info) local identities, features, extensions = {}, {}, {}; diff --git a/util/dataforms.lua b/util/dataforms.lua index 469ce976..052d6a55 100644 --- a/util/dataforms.lua +++ b/util/dataforms.lua @@ -8,14 +8,17 @@ local setmetatable = setmetatable; local ipairs = ipairs; -local tostring, type, next = tostring, type, next; +local type, next = type, next; +local tonumber = tonumber; local t_concat = table.concat; local st = require "util.stanza"; local jid_prep = require "util.jid".prep; local _ENV = nil; +-- luacheck: std none local xmlns_forms = 'jabber:x:data'; +local xmlns_validate = 'http://jabber.org/protocol/xdata-validate'; local form_t = {}; local form_mt = { __index = form_t }; @@ -25,21 +28,76 @@ local function new(layout) end function form_t.form(layout, data, formtype) - local form = st.stanza("x", { xmlns = xmlns_forms, type = formtype or "form" }); - if layout.title then - form:tag("title"):text(layout.title):up(); + if not formtype then formtype = "form" end + local form = st.stanza("x", { xmlns = xmlns_forms, type = formtype }); + if formtype == "cancel" then + return form; end - if layout.instructions then - form:tag("instructions"):text(layout.instructions):up(); + if formtype ~= "submit" then + if layout.title then + form:tag("title"):text(layout.title):up(); + end + if layout.instructions then + form:tag("instructions"):text(layout.instructions):up(); + end end for _, field in ipairs(layout) do local field_type = field.type or "text-single"; -- Add field tag - form:tag("field", { type = field_type, var = field.name, label = field.label }); + form:tag("field", { type = field_type, var = field.var or field.name, label = formtype ~= "submit" and field.label or nil }); + + if formtype ~= "submit" then + if field.desc then + form:text_tag("desc", field.desc); + end + end + + if formtype == "form" and field.datatype then + form:tag("validate", { xmlns = xmlns_validate, datatype = field.datatype }); + -- <basic/> assumed + form:up(); + end + + + local value = field.value; + local options = field.options; + + if data and data[field.name] ~= nil then + value = data[field.name]; - local value = (data and data[field.name]) or field.value; + if formtype == "form" and type(value) == "table" + and (field_type == "list-single" or field_type == "list-multi") then + -- Allow passing dynamically generated options as values + options, value = value, nil; + end + end + + if formtype == "form" and options then + local defaults = {}; + for _, val in ipairs(options) do + if type(val) == "table" then + form:tag("option", { label = val.label }):tag("value"):text(val.value):up():up(); + if val.default then + defaults[#defaults+1] = val.value; + end + else + form:tag("option", { label= val }):tag("value"):text(val):up():up(); + end + end + if not value then + if field_type == "list-single" then + value = defaults[1]; + elseif field_type == "list-multi" then + value = defaults; + end + end + end - if value then + if value ~= nil then + if type(value) == "number" then + -- TODO validate that this is ok somehow, eg check field.datatype + value = ("%g"):format(value); + end -- Add value, depending on type if field_type == "hidden" then if type(value) == "table" then @@ -48,7 +106,7 @@ function form_t.form(layout, data, formtype) :add_child(value) :up(); else - form:tag("value"):text(tostring(value)):up(); + form:tag("value"):text(value):up(); end elseif field_type == "boolean" then form:tag("value"):text((value and "1") or "0"):up(); @@ -68,40 +126,10 @@ function form_t.form(layout, data, formtype) form:tag("value"):text(line):up(); end elseif field_type == "list-single" then - if formtype ~= "result" then - local has_default = false; - for _, val in ipairs(field.options or value) do - if type(val) == "table" then - form:tag("option", { label = val.label }):tag("value"):text(val.value):up():up(); - if value == val.value or val.default and (not has_default) then - form:tag("value"):text(val.value):up(); - has_default = true; - end - else - form:tag("option", { label= val }):tag("value"):text(tostring(val)):up():up(); - end - end - end - if (field.options or formtype == "result") and value then - form:tag("value"):text(value):up(); - end + form:tag("value"):text(value):up(); elseif field_type == "list-multi" then - if formtype ~= "result" then - for _, val in ipairs(field.options or value) do - if type(val) == "table" then - form:tag("option", { label = val.label }):tag("value"):text(val.value):up():up(); - if not field.options and val.default then - form:tag("value"):text(val.value):up(); - end - else - form:tag("option", { label= val }):tag("value"):text(tostring(val)):up():up(); - end - end - end - if (field.options or formtype == "result") and value then - for _, val in ipairs(value) do - form:tag("value"):text(val):up(); - end + for _, val in ipairs(value) do + form:tag("value"):text(val):up(); end end end @@ -115,7 +143,7 @@ function form_t.form(layout, data, formtype) form:up(); end - if field.required then + if formtype == "form" and field.required then form:tag("required"):up(); end @@ -126,8 +154,9 @@ function form_t.form(layout, data, formtype) end local field_readers = {}; +local data_validators = {}; -function form_t.data(layout, stanza) +function form_t.data(layout, stanza, current) local data = {}; local errors = {}; local present = {}; @@ -135,21 +164,33 @@ function form_t.data(layout, stanza) for _, field in ipairs(layout) do local tag; for field_tag in stanza:childtags("field") do - if field.name == field_tag.attr.var then + if (field.var or field.name) == field_tag.attr.var then tag = field_tag; break; end end if not tag then - if field.required then + if current and current[field.name] ~= nil then + data[field.name] = current[field.name]; + elseif field.required then errors[field.name] = "Required value missing"; end - else + elseif field.name then present[field.name] = true; local reader = field_readers[field.type]; if reader then - data[field.name], errors[field.name] = reader(tag, field.required); + local value, err = reader(tag, field.required); + local validator = field.datatype and data_validators[field.datatype]; + if value ~= nil and validator then + local valid, ret = validator(value, field); + if valid then + value = ret; + else + value, err = nil, ret or ("Invalid value for data of type " .. field.datatype); + end + end + data[field.name], errors[field.name] = value, err; end end end @@ -248,8 +289,35 @@ field_readers["hidden"] = return field_tag:get_child_text("value"); end +data_validators["xs:integer"] = + function (data) + local n = tonumber(data); + if not n then + return false, "not a number"; + elseif n % 1 ~= 0 then + return false, "not an integer"; + end + return true, n; + end + + +local function get_form_type(form) + if not st.is_stanza(form) then + return nil, "not a stanza object"; + elseif form.attr.xmlns ~= "jabber:x:data" or form.name ~= "x" then + return nil, "not a dataform element"; + end + for field in form:childtags("field") do + if field.attr.var == "FORM_TYPE" then + return field:get_child_text("value"); + end + end + return ""; +end + return { new = new; + get_type = get_form_type; }; diff --git a/util/datamanager.lua b/util/datamanager.lua index bd8fb7bb..cf96887b 100644 --- a/util/datamanager.lua +++ b/util/datamanager.lua @@ -40,9 +40,10 @@ pcall(function() end); local _ENV = nil; +-- luacheck: std none ---- utils ----- -local encode, decode; +local encode, decode, store_encode; do local urlcodes = setmetatable({}, { __index = function (t, k) t[k] = char(tonumber(k, 16)); return t[k]; end }); @@ -53,6 +54,12 @@ do encode = function (s) return s and (s:gsub("%W", function (c) return format("%%%02x", c:byte()); end)); end + + -- Special encode function for store names, which historically were unencoded. + -- All currently known stores use a-z and underscore, so this one preserves underscores. + store_encode = function (s) + return s and (s:gsub("[^_%w]", function (c) return format("%%%02x", c:byte()); end)); + end end if not atomic_append then @@ -119,6 +126,7 @@ local function getpath(username, host, datastore, ext, create) ext = ext or "dat"; host = (host and encode(host)) or "_global"; username = username and encode(username); + datastore = store_encode(datastore); if username then if create then mkdir(mkdir(mkdir(data_path).."/"..host).."/"..datastore); end return format("%s/%s/%s/%s.%s", data_path, host, datastore, username, ext); diff --git a/util/datetime.lua b/util/datetime.lua index abb4e867..06be9fc2 100644 --- a/util/datetime.lua +++ b/util/datetime.lua @@ -15,6 +15,7 @@ local os_difftime = os.difftime; local tonumber = tonumber; local _ENV = nil; +-- luacheck: std none local function date(t) return os_date("!%Y-%m-%d", t); diff --git a/util/debug.lua b/util/debug.lua index 00f476d0..9a28395a 100644 --- a/util/debug.lua +++ b/util/debug.lua @@ -47,6 +47,7 @@ local function get_upvalues_table(func) for upvalue_num = 1, math.huge do local name, value = debug.getupvalue(func, upvalue_num); if not name then break; end + if name == "" then name = ("[%d]"):format(upvalue_num); end table.insert(upvalues, { name = name, value = value }); end end @@ -112,7 +113,9 @@ end local function build_source_boundary_marker(last_source_desc) local padding = string.rep("-", math.floor(((optimal_line_length - 6) - #last_source_desc)/2)); - return getstring(styles.boundary_padding, "v"..padding).." "..getstring(styles.filename, last_source_desc).." "..getstring(styles.boundary_padding, padding..(#last_source_desc%2==0 and "-v" or "v ")); + return getstring(styles.boundary_padding, "v"..padding).." ".. + getstring(styles.filename, last_source_desc).." ".. + getstring(styles.boundary_padding, padding..(#last_source_desc%2==0 and "-v" or "v ")); end local function _traceback(thread, message, level) @@ -142,9 +145,9 @@ local function _traceback(thread, message, level) local last_source_desc; local lines = {}; - for nlevel, level in ipairs(levels) do - local info = level.info; - local line = "..."; + for nlevel, current_level in ipairs(levels) do + local info = current_level.info; + local line; local func_type = info.namewhat.." "; local source_desc = (info.short_src == "[C]" and "C code") or info.short_src or "Unknown"; if func_type == " " then func_type = ""; end; @@ -160,7 +163,9 @@ local function _traceback(thread, message, level) if func_type == "global " or func_type == "local " then func_type = func_type.."function "; end - line = "[Lua] "..getstring(styles.location, info.short_src.." line "..info.currentline).." in "..func_type..getstring(styles.funcname, name).." (defined on line "..info.linedefined..")"; + line = "[Lua] "..getstring(styles.location, info.short_src.." line ".. + info.currentline).." in "..func_type..getstring(styles.funcname, name).. + " (defined on line "..info.linedefined..")"; end if source_desc ~= last_source_desc then -- Venturing into a new source, add marker for previous last_source_desc = source_desc; @@ -169,13 +174,13 @@ local function _traceback(thread, message, level) nlevel = nlevel-1; table.insert(lines, "\t"..(nlevel==0 and ">" or " ")..getstring(styles.level_num, "("..nlevel..") ")..line); local npadding = (" "):rep(#tostring(nlevel)); - if level.locals then - local locals_str = string_from_var_table(level.locals, optimal_line_length, "\t "..npadding); + if current_level.locals then + local locals_str = string_from_var_table(current_level.locals, optimal_line_length, "\t "..npadding); if locals_str then table.insert(lines, "\t "..npadding.."Locals: "..locals_str); end end - local upvalues_str = string_from_var_table(level.upvalues, optimal_line_length, "\t "..npadding); + local upvalues_str = string_from_var_table(current_level.upvalues, optimal_line_length, "\t "..npadding); if upvalues_str then table.insert(lines, "\t "..npadding.."Upvals: "..upvalues_str); end diff --git a/util/dependencies.lua b/util/dependencies.lua index de840241..7c7b938e 100644 --- a/util/dependencies.lua +++ b/util/dependencies.lua @@ -28,24 +28,11 @@ local function missingdep(name, sources, msg) end print(""); print(msg or (name.." is required for Prosody to run, so we will now exit.")); - print("More help can be found on our website, at http://prosody.im/doc/depends"); + print("More help can be found on our website, at https://prosody.im/doc/depends"); print("**************************"); print(""); end --- COMPAT w/pre-0.8 Debian: The Debian config file used to use --- util.ztact, which has been removed from Prosody in 0.8. This --- is to log an error for people who still use it, so they can --- update their configs. -package.preload["util.ztact"] = function () - if not package.loaded["core.loggingmanager"] then - error("util.ztact has been removed from Prosody and you need to fix your config " - .."file. More information can be found at http://prosody.im/doc/packagers#ztact", 0); - else - error("module 'util.ztact' has been deprecated in Prosody 0.8."); - end -end; - local function check_dependencies() if _VERSION < "Lua 5.1" then print "***********************************" @@ -77,6 +64,10 @@ local function check_dependencies() ["Source"] = "http://www.tecgraf.puc-rio.br/~diego/professional/luasocket/"; }); fatal = true; + elseif not socket.tcp4 then + -- COMPAT LuaSocket before being IP-version agnostic + socket.tcp4 = socket.tcp; + socket.udp4 = socket.udp; end local lfs, err = softreq "lfs" @@ -156,7 +147,7 @@ local function log_warnings() if ssl then local major, minor, veryminor, patched = ssl._VERSION:match("(%d+)%.(%d+)%.?(%d*)(M?)"); if not major or ((tonumber(major) == 0 and (tonumber(minor) or 0) <= 3 and (tonumber(veryminor) or 0) <= 2) and patched ~= "M") then - prosody.log("error", "This version of LuaSec contains a known bug that causes disconnects, see http://prosody.im/doc/depends"); + prosody.log("error", "This version of LuaSec contains a known bug that causes disconnects, see https://prosody.im/doc/depends"); end end local lxp = softreq"lxp"; @@ -165,7 +156,7 @@ local function log_warnings() prosody.log("error", "The version of LuaExpat on your system leaves Prosody " .."vulnerable to denial-of-service attacks. You should upgrade to " .."LuaExpat 1.3.0 or higher as soon as possible. See " - .."http://prosody.im/doc/depends#luaexpat for more information."); + .."https://prosody.im/doc/depends#luaexpat for more information."); end if not lxp.new({}).getcurrentbytecount then prosody.log("error", "The version of LuaExpat on your system does not support " @@ -173,7 +164,7 @@ local function log_warnings() .."networks (e.g. the internet) vulnerable to denial-of-service " .."attacks. You should upgrade to LuaExpat 1.3.0 or higher as " .."soon as possible. See " - .."http://prosody.im/doc/depends#luaexpat for more information."); + .."https://prosody.im/doc/depends#luaexpat for more information."); end end end diff --git a/util/envload.lua b/util/envload.lua index 926f20c0..6182a1f9 100644 --- a/util/envload.lua +++ b/util/envload.lua @@ -4,7 +4,7 @@ -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- --- luacheck: ignore 113/setfenv +-- luacheck: ignore 113/setfenv 113/loadstring local load, loadstring, setfenv = load, loadstring, setfenv; local io_open = io.open; diff --git a/util/events.lua b/util/events.lua index 6e13619c..0bf0ddcb 100644 --- a/util/events.lua +++ b/util/events.lua @@ -15,6 +15,7 @@ local setmetatable = setmetatable; local next = next; local _ENV = nil; +-- luacheck: std none local function new() -- Map event name to ordered list of handlers (lazily built): handlers[event_name] = array_of_handler_functions @@ -26,7 +27,7 @@ local function new() -- Event map: event_map[handler_function] = priority_number local event_map = {}; -- Called on-demand to build handlers entries - local function _rebuild_index(handlers, event) + local function _rebuild_index(self, event) local _handlers = event_map[event]; if not _handlers or next(_handlers) == nil then return; end local index = {}; @@ -34,7 +35,7 @@ local function new() t_insert(index, handler); end t_sort(index, function(a, b) return _handlers[a] > _handlers[b]; end); - handlers[event] = index; + self[event] = index; return index; end; setmetatable(handlers, { __index = _rebuild_index }); @@ -61,13 +62,13 @@ local function new() local function get_handlers(event) return handlers[event]; end; - local function add_handlers(handlers) - for event, handler in pairs(handlers) do + local function add_handlers(self) + for event, handler in pairs(self) do add_handler(event, handler); end end; - local function remove_handlers(handlers) - for event, handler in pairs(handlers) do + local function remove_handlers(self) + for event, handler in pairs(self) do remove_handler(event, handler); end end; @@ -81,6 +82,7 @@ local function new() end end; local function fire_event(event_name, event_data) + -- luacheck: ignore 432/event_name 432/event_data local w = wrappers[event_name] or global_wrappers; if w then local curr_wrapper = #w; diff --git a/util/filters.lua b/util/filters.lua index f405c0bd..f30dfd9c 100644 --- a/util/filters.lua +++ b/util/filters.lua @@ -9,6 +9,7 @@ local t_insert, t_remove = table.insert, table.remove; local _ENV = nil; +-- luacheck: std none local new_filter_hooks = {}; diff --git a/util/format.lua b/util/format.lua index 5f2b12be..c5e513fa 100644 --- a/util/format.lua +++ b/util/format.lua @@ -4,11 +4,10 @@ local tostring = tostring; local select = select; -local assert = assert; -local unpack = unpack; +local unpack = table.unpack or unpack; -- luacheck: ignore 113/unpack local type = type; -local function format(format, ...) +local function format(formatstring, ...) local args, args_length = { ... }, select('#', ...); -- format specifier spec: @@ -25,7 +24,7 @@ local function format(format, ...) -- process each format specifier local i = 0; - format = format:gsub("%%[^cdiouxXaAeEfgGqs%%]*[cdiouxXaAeEfgGqs%%]", function(spec) + formatstring = formatstring:gsub("%%[^cdiouxXaAeEfgGqs%%]*[cdiouxXaAeEfgGqs%%]", function(spec) if spec ~= "%%" then i = i + 1; local arg = args[i]; @@ -54,21 +53,12 @@ local function format(format, ...) else args[i] = tostring(arg); end - format = format .. " [%s]" + formatstring = formatstring .. " [%s]" end - return format:format(unpack(args)); -end - -local function test() - assert(format("%s", "hello") == "hello"); - assert(format("%s") == "<nil>"); - assert(format("%s", true) == "true"); - assert(format("%d", true) == "[true]"); - assert(format("%%", true) == "% [true]"); + return formatstring:format(unpack(args)); end return { format = format; - test = test; }; diff --git a/util/http.lua b/util/http.lua index f7259920..cfb89193 100644 --- a/util/http.lua +++ b/util/http.lua @@ -57,8 +57,19 @@ local function contains_token(field, token) return field:find(","..token:lower()..",", 1, true) ~= nil; end +local function normalize_path(path, is_dir) + if is_dir then + if path:sub(-1,-1) ~= "/" then path = path.."/"; end + else + if path:sub(-1,-1) == "/" then path = path:sub(1, -2); end + end + if path:sub(1,1) ~= "/" then path = "/"..path; end + return path; +end + return { urlencode = urlencode, urldecode = urldecode; formencode = formencode, formdecode = formdecode; contains_token = contains_token; + normalize_path = normalize_path; }; diff --git a/util/import.lua b/util/import.lua index c2b9dce1..8ecfe43c 100644 --- a/util/import.lua +++ b/util/import.lua @@ -8,9 +8,9 @@ -local unpack = table.unpack or unpack; --luacheck: ignore 113 +local unpack = table.unpack or unpack; --luacheck: ignore 113 143 local t_insert = table.insert; -function import(module, ...) +function _G.import(module, ...) local m = package.loaded[module] or require(module); if type(m) == "table" and ... then local ret = {}; diff --git a/util/indexedbheap.lua b/util/indexedbheap.lua new file mode 100644 index 00000000..7f193d54 --- /dev/null +++ b/util/indexedbheap.lua @@ -0,0 +1,157 @@ + +local setmetatable = setmetatable; +local math_floor = math.floor; +local t_remove = table.remove; + +local function _heap_insert(self, item, sync, item2, index) + local pos = #self + 1; + while true do + local half_pos = math_floor(pos / 2); + if half_pos == 0 or item > self[half_pos] then break; end + self[pos] = self[half_pos]; + sync[pos] = sync[half_pos]; + index[sync[pos]] = pos; + pos = half_pos; + end + self[pos] = item; + sync[pos] = item2; + index[item2] = pos; +end + +local function _percolate_up(self, k, sync, index) + local tmp = self[k]; + local tmp_sync = sync[k]; + while k ~= 1 do + local parent = math_floor(k/2); + if tmp < self[parent] then break; end + self[k] = self[parent]; + sync[k] = sync[parent]; + index[sync[k]] = k; + k = parent; + end + self[k] = tmp; + sync[k] = tmp_sync; + index[tmp_sync] = k; + return k; +end + +local function _percolate_down(self, k, sync, index) + local tmp = self[k]; + local tmp_sync = sync[k]; + local size = #self; + local child = 2*k; + while 2*k <= size do + if child ~= size and self[child] > self[child + 1] then + child = child + 1; + end + if tmp > self[child] then + self[k] = self[child]; + sync[k] = sync[child]; + index[sync[k]] = k; + else + break; + end + + k = child; + child = 2*k; + end + self[k] = tmp; + sync[k] = tmp_sync; + index[tmp_sync] = k; + return k; +end + +local function _heap_pop(self, sync, index) + local size = #self; + if size == 0 then return nil; end + + local result = self[1]; + local result_sync = sync[1]; + index[result_sync] = nil; + if size == 1 then + self[1] = nil; + sync[1] = nil; + return result, result_sync; + end + self[1] = t_remove(self); + sync[1] = t_remove(sync); + index[sync[1]] = 1; + + _percolate_down(self, 1, sync, index); + + return result, result_sync; +end + +local indexed_heap = {}; + +function indexed_heap:insert(item, priority, id) + if id == nil then + id = self.current_id; + self.current_id = id + 1; + end + self.items[id] = item; + _heap_insert(self.priorities, priority, self.ids, id, self.index); + return id; +end +function indexed_heap:pop() + local priority, id = _heap_pop(self.priorities, self.ids, self.index); + if id then + local item = self.items[id]; + self.items[id] = nil; + return priority, item, id; + end +end +function indexed_heap:peek() + return self.priorities[1]; +end +function indexed_heap:reprioritize(id, priority) + local k = self.index[id]; + if k == nil then return; end + self.priorities[k] = priority; + + k = _percolate_up(self.priorities, k, self.ids, self.index); + _percolate_down(self.priorities, k, self.ids, self.index); +end +function indexed_heap:remove_index(k) + local result = self.priorities[k]; + if result == nil then return; end + + local result_sync = self.ids[k]; + local item = self.items[result_sync]; + local size = #self.priorities; + + self.priorities[k] = self.priorities[size]; + self.ids[k] = self.ids[size]; + self.index[self.ids[k]] = k; + + t_remove(self.priorities); + t_remove(self.ids); + + self.index[result_sync] = nil; + self.items[result_sync] = nil; + + if size > k then + k = _percolate_up(self.priorities, k, self.ids, self.index); + _percolate_down(self.priorities, k, self.ids, self.index); + end + + return result, item, result_sync; +end +function indexed_heap:remove(id) + return self:remove_index(self.index[id]); +end + +local mt = { __index = indexed_heap }; + +local _M = { + create = function() + return setmetatable({ + ids = {}; -- heap of ids, sync'd with priorities + items = {}; -- map id->items + priorities = {}; -- heap of priorities + index = {}; -- map of id->index of id in ids + current_id = 1.5 + }, mt); + end +}; +return _M; diff --git a/util/ip.lua b/util/ip.lua index 81a98ef7..0ec9e297 100644 --- a/util/ip.lua +++ b/util/ip.lua @@ -5,69 +5,76 @@ -- COPYING file in the source package for more information. -- +local net = require "util.net"; +local hex = require "util.hex"; + local ip_methods = {}; -local ip_mt = { __index = function (ip, key) return (ip_methods[key])(ip); end, - __tostring = function (ip) return ip.addr; end, - __eq = function (ipA, ipB) return ipA.addr == ipB.addr; end}; -local hex2bits = { ["0"] = "0000", ["1"] = "0001", ["2"] = "0010", ["3"] = "0011", ["4"] = "0100", ["5"] = "0101", ["6"] = "0110", ["7"] = "0111", ["8"] = "1000", ["9"] = "1001", ["A"] = "1010", ["B"] = "1011", ["C"] = "1100", ["D"] = "1101", ["E"] = "1110", ["F"] = "1111" }; + +local ip_mt = { + __index = function (ip, key) + local method = ip_methods[key]; + if not method then return nil; end + local ret = method(ip); + ip[key] = ret; + return ret; + end, + __tostring = function (ip) return ip.addr; end, + __eq = function (ipA, ipB) return ipA.packed == ipB.packed; end +}; + +local hex2bits = { + ["0"] = "0000", ["1"] = "0001", ["2"] = "0010", ["3"] = "0011", + ["4"] = "0100", ["5"] = "0101", ["6"] = "0110", ["7"] = "0111", + ["8"] = "1000", ["9"] = "1001", ["A"] = "1010", ["B"] = "1011", + ["C"] = "1100", ["D"] = "1101", ["E"] = "1110", ["F"] = "1111", +}; local function new_ip(ipStr, proto) - if not proto then - local sep = ipStr:match("^%x+(.)"); - if sep == ":" or (not(sep) and ipStr:sub(1,1) == ":") then - proto = "IPv6" - elseif sep == "." then - proto = "IPv4" - end - if not proto then - return nil, "invalid address"; - end - elseif proto ~= "IPv4" and proto ~= "IPv6" then - return nil, "invalid protocol"; - end local zone; - if proto == "IPv6" and ipStr:find('%', 1, true) then + if (not proto or proto == "IPv6") and ipStr:find('%', 1, true) then ipStr, zone = ipStr:match("^(.-)%%(.*)"); end - if proto == "IPv6" and ipStr:find('.', 1, true) then - local changed; - ipStr, changed = ipStr:gsub(":(%d+)%.(%d+)%.(%d+)%.(%d+)$", function(a,b,c,d) - return (":%04X:%04X"):format(a*256+b,c*256+d); - end); - if changed ~= 1 then return nil, "invalid-address"; end + + local packed, err = net.pton(ipStr); + if not packed then return packed, err end + if proto == "IPv6" and #packed ~= 16 then + return nil, "invalid-ipv6"; + elseif proto == "IPv4" and #packed ~= 4 then + return nil, "invalid-ipv4"; + elseif not proto then + if #packed == 16 then + proto = "IPv6"; + elseif #packed == 4 then + proto = "IPv4"; + else + return nil, "unknown protocol"; + end + elseif proto ~= "IPv6" and proto ~= "IPv4" then + return nil, "invalid protocol"; end - return setmetatable({ addr = ipStr, proto = proto, zone = zone }, ip_mt); + return setmetatable({ addr = ipStr, packed = packed, proto = proto, zone = zone }, ip_mt); +end + +function ip_methods:normal() + return net.ntop(self.packed); end -local function toBits(ip) - local result = ""; - local fields = {}; +function ip_methods.bits(ip) + return hex.to(ip.packed):upper():gsub(".", hex2bits); +end + +function ip_methods.bits_full(ip) if ip.proto == "IPv4" then ip = ip.toV4mapped; end - ip = (ip.addr):upper(); - ip:gsub("([^:]*):?", function (c) fields[#fields + 1] = c end); - if not ip:match(":$") then fields[#fields] = nil; end - for i, field in ipairs(fields) do - if field:len() == 0 and i ~= 1 and i ~= #fields then - for _ = 1, 16 * (9 - #fields) do - result = result .. "0"; - end - else - for _ = 1, 4 - field:len() do - result = result .. "0000"; - end - for j = 1, field:len() do - result = result .. hex2bits[field:sub(j, j)]; - end - end - end - return result; + return ip.bits; end +local match; + local function commonPrefixLength(ipA, ipB) - ipA, ipB = toBits(ipA), toBits(ipB); + ipA, ipB = ipA.bits_full, ipB.bits_full; for i = 1, 128 do if ipA:sub(i,i) ~= ipB:sub(i,i) then return i-1; @@ -76,56 +83,60 @@ local function commonPrefixLength(ipA, ipB) return 128; end +-- Instantiate once +local loopback = new_ip("::1"); +local loopback4 = new_ip("127.0.0.0"); +local sixtofour = new_ip("2002::"); +local teredo = new_ip("2001::"); +local linklocal = new_ip("fe80::"); +local linklocal4 = new_ip("169.254.0.0"); +local uniquelocal = new_ip("fc00::"); +local sitelocal = new_ip("fec0::"); +local sixbone = new_ip("3ffe::"); +local defaultunicast = new_ip("::"); +local multicast = new_ip("ff00::"); +local ipv6mapped = new_ip("::ffff:0:0"); + local function v4scope(ip) - local fields = {}; - ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end); - -- Loopback: - if fields[1] == 127 then + if match(ip, loopback4, 8) then return 0x2; - -- Link-local unicast: - elseif fields[1] == 169 and fields[2] == 254 then + elseif match(ip, linklocal4) then return 0x2; - -- Global unicast: - else + else -- Global unicast return 0xE; end end local function v6scope(ip) - -- Loopback: - if ip:match("^[0:]*1$") then + if ip == loopback then return 0x2; - -- Link-local unicast: - elseif ip:match("^[Ff][Ee][89ABab]") then + elseif match(ip, linklocal, 10) then return 0x2; - -- Site-local unicast: - elseif ip:match("^[Ff][Ee][CcDdEeFf]") then + elseif match(ip, sitelocal, 10) then return 0x5; - -- Multicast: - elseif ip:match("^[Ff][Ff]") then - return tonumber("0x"..ip:sub(4,4)); - -- Global unicast: - else + elseif match(ip, multicast, 10) then + return ip.packed:byte(2) % 0x10; + else -- Global unicast return 0xE; end end local function label(ip) - if commonPrefixLength(ip, new_ip("::1", "IPv6")) == 128 then + if ip == loopback then return 0; - elseif commonPrefixLength(ip, new_ip("2002::", "IPv6")) >= 16 then + elseif match(ip, sixtofour, 16) then return 2; - elseif commonPrefixLength(ip, new_ip("2001::", "IPv6")) >= 32 then + elseif match(ip, teredo, 32) then return 5; - elseif commonPrefixLength(ip, new_ip("fc00::", "IPv6")) >= 7 then + elseif match(ip, uniquelocal, 7) then return 13; - elseif commonPrefixLength(ip, new_ip("fec0::", "IPv6")) >= 10 then + elseif match(ip, sitelocal, 10) then return 11; - elseif commonPrefixLength(ip, new_ip("3ffe::", "IPv6")) >= 16 then + elseif match(ip, sixbone, 16) then return 12; - elseif commonPrefixLength(ip, new_ip("::", "IPv6")) >= 96 then + elseif match(ip, defaultunicast, 96) then return 3; - elseif commonPrefixLength(ip, new_ip("::ffff:0:0", "IPv6")) >= 96 then + elseif match(ip, ipv6mapped, 96) then return 4; else return 1; @@ -133,91 +144,67 @@ local function label(ip) end local function precedence(ip) - if commonPrefixLength(ip, new_ip("::1", "IPv6")) == 128 then + if ip == loopback then return 50; - elseif commonPrefixLength(ip, new_ip("2002::", "IPv6")) >= 16 then + elseif match(ip, sixtofour, 16) then return 30; - elseif commonPrefixLength(ip, new_ip("2001::", "IPv6")) >= 32 then + elseif match(ip, teredo, 32) then return 5; - elseif commonPrefixLength(ip, new_ip("fc00::", "IPv6")) >= 7 then + elseif match(ip, uniquelocal, 7) then return 3; - elseif commonPrefixLength(ip, new_ip("fec0::", "IPv6")) >= 10 then + elseif match(ip, sitelocal, 10) then return 1; - elseif commonPrefixLength(ip, new_ip("3ffe::", "IPv6")) >= 16 then + elseif match(ip, sixbone, 16) then return 1; - elseif commonPrefixLength(ip, new_ip("::", "IPv6")) >= 96 then + elseif match(ip, defaultunicast, 96) then return 1; - elseif commonPrefixLength(ip, new_ip("::ffff:0:0", "IPv6")) >= 96 then + elseif match(ip, ipv6mapped, 96) then return 35; else return 40; end end -local function toV4mapped(ip) - local fields = {}; - local ret = "::ffff:"; - ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end); - ret = ret .. ("%02x"):format(fields[1]); - ret = ret .. ("%02x"):format(fields[2]); - ret = ret .. ":" - ret = ret .. ("%02x"):format(fields[3]); - ret = ret .. ("%02x"):format(fields[4]); - return new_ip(ret, "IPv6"); -end - function ip_methods:toV4mapped() if self.proto ~= "IPv4" then return nil, "No IPv4 address" end - local value = toV4mapped(self.addr); - self.toV4mapped = value; + local value = new_ip("::ffff:" .. self.normal); return value; end function ip_methods:label() - local value; if self.proto == "IPv4" then - value = label(self.toV4mapped); + return label(self.toV4mapped); else - value = label(self); + return label(self); end - self.label = value; - return value; end function ip_methods:precedence() - local value; if self.proto == "IPv4" then - value = precedence(self.toV4mapped); + return precedence(self.toV4mapped); else - value = precedence(self); + return precedence(self); end - self.precedence = value; - return value; end function ip_methods:scope() - local value; if self.proto == "IPv4" then - value = v4scope(self.addr); + return v4scope(self); else - value = v6scope(self.addr); + return v6scope(self); end - self.scope = value; - return value; end +local rfc1918_8 = new_ip("10.0.0.0"); +local rfc1918_12 = new_ip("172.16.0.0"); +local rfc1918_16 = new_ip("192.168.0.0"); +local rfc6598 = new_ip("100.64.0.0"); + function ip_methods:private() local private = self.scope ~= 0xE; if not private and self.proto == "IPv4" then - local ip = self.addr; - local fields = {}; - ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end); - if fields[1] == 127 or fields[1] == 10 or (fields[1] == 192 and fields[2] == 168) - or (fields[1] == 172 and (fields[2] >= 16 or fields[2] <= 32)) then - private = true; - end + return match(self, rfc1918_8, 8) or match(self, rfc1918_12, 12) or match(self, rfc1918_16) or match(self, rfc6598, 10); end - self.private = private; return private; end @@ -231,15 +218,26 @@ local function parse_cidr(cidr) return new_ip(cidr), bits; end -local function match(ipA, ipB, bits) - local common_bits = commonPrefixLength(ipA, ipB); - if bits and ipB.proto == "IPv4" then - common_bits = common_bits - 96; -- v6 mapped addresses always share these bits +function match(ipA, ipB, bits) + if not bits or bits >= 128 or ipB.proto == "IPv4" and bits >= 32 then + return ipA == ipB; + elseif bits < 1 then + return true; + end + if ipA.proto ~= ipB.proto then + if ipA.proto == "IPv4" then + ipA = ipA.toV4mapped; + elseif ipB.proto == "IPv4" then + ipB = ipB.toV4mapped; + bits = bits + (128 - 32); + end end - return common_bits >= (bits or 128); + return ipA.bits:sub(1, bits) == ipB.bits:sub(1, bits); end -return {new_ip = new_ip, +return { + new_ip = new_ip, commonPrefixLength = commonPrefixLength, parse_cidr = parse_cidr, - match=match}; + match = match, +}; diff --git a/util/iterators.lua b/util/iterators.lua index bd150ff2..302cca36 100644 --- a/util/iterators.lua +++ b/util/iterators.lua @@ -12,8 +12,13 @@ local it = {}; local t_insert = table.insert; local select, next = select, next; -local unpack = table.unpack or unpack; --luacheck: ignore 113 -local pack = table.pack or function (...) return { n = select("#", ...), ... }; end +local unpack = table.unpack or unpack; --luacheck: ignore 113 143 +local pack = table.pack or function (...) return { n = select("#", ...), ... }; end -- luacheck: ignore 143 +local type = type; +local table, setmetatable = table, setmetatable; + +local _ENV = nil; +--luacheck: std none -- Reverse an iterator function it.reverse(f, s, var) @@ -172,6 +177,19 @@ function it.to_array(f, s, var) return t; end +function it.sorted_pairs(t, sort_func) + local keys = it.to_array(it.keys(t)); + table.sort(keys, sort_func); + local i = 0; + return function () + i = i + 1; + local key = keys[i]; + if key ~= nil then + return key, t[key]; + end + end; +end + -- Treat the return of an iterator as key,value pairs, -- and build a table function it.to_table(f, s, var) @@ -184,4 +202,45 @@ function it.to_table(f, s, var) return t; end +local function _join_iter(j_s, j_var) + local iterators, current_idx = j_s[1], j_s[2]; + local f, s, var = unpack(iterators[current_idx], 1, 3); + if j_var ~= nil then + var = j_var; + end + local ret = pack(f(s, var)); + local var1 = ret[1]; + if var1 == nil then + -- End of this iterator, advance to next + if current_idx == #iterators then + -- No more iterators, return nil + return; + end + j_s[2] = current_idx + 1; + return _join_iter(j_s); + end + return unpack(ret, 1, ret.n); +end +local join_methods = {}; +local join_mt = { + __index = join_methods; + __call = function (t, s, var) --luacheck: ignore 212/t + return _join_iter(s, var); + end; +}; + +function join_methods:append(f, s, var) + table.insert(self, { f, s, var }); + return self, { self, 1 }; +end + +function join_methods:prepend(f, s, var) + table.insert(self, { f, s, var }, 1); + return self, { self, 1 }; +end + +function it.join(f, s, var) + return setmetatable({ {f, s, var} }, join_mt); +end + return it; diff --git a/util/jid.lua b/util/jid.lua index f402b7f4..ec31f180 100644 --- a/util/jid.lua +++ b/util/jid.lua @@ -25,11 +25,12 @@ local unescapes = {}; for k,v in pairs(escapes) do unescapes[v] = k; end local _ENV = nil; +-- luacheck: std none local function split(jid) if not jid then return; end local node, nodepos = match(jid, "^([^@/]+)@()"); - local host, hostpos = match(jid, "^([^@/]+)()", nodepos) + local host, hostpos = match(jid, "^([^@/]+)()", nodepos); if node and not host then return nil, nil, nil; end local resource = match(jid, "^/(.+)$", hostpos); if (not host) or ((not resource) and #jid >= hostpos) then return nil, nil, nil; end diff --git a/util/json.lua b/util/json.lua index cba54e8e..a750da2e 100644 --- a/util/json.lua +++ b/util/json.lua @@ -7,10 +7,10 @@ -- local type = type; -local t_insert, t_concat, t_remove, t_sort = table.insert, table.concat, table.remove, table.sort; +local t_insert, t_concat, t_remove = table.insert, table.concat, table.remove; local s_char = string.char; local tostring, tonumber = tostring, tonumber; -local pairs, ipairs = pairs, ipairs; +local pairs, ipairs, spairs = pairs, ipairs, require "util.iterators".sorted_pairs; local next = next; local getmetatable, setmetatable = getmetatable, setmetatable; local print = print; @@ -27,9 +27,6 @@ module.null = null; local escapes = { ["\""] = "\\\"", ["\\"] = "\\\\", ["\b"] = "\\b", ["\f"] = "\\f", ["\n"] = "\\n", ["\r"] = "\\r", ["\t"] = "\\t"}; -local unescapes = { - ["\""] = "\"", ["\\"] = "\\", ["/"] = "/", - b = "\b", f = "\f", n = "\n", r = "\r", t = "\t"}; for i=0,31 do local ch = s_char(i); if not escapes[ch] then escapes[ch] = ("\\u%.4X"):format(i); end @@ -98,25 +95,12 @@ function tablesave(o, buffer) if next(__hash) ~= nil or next(hash) ~= nil or next(__array) == nil then t_insert(buffer, "{"); local mark = #buffer; - if buffer.ordered then - local keys = {}; - for k in pairs(hash) do - t_insert(keys, k); - end - t_sort(keys); - for _,k in ipairs(keys) do - stringsave(k, buffer); - t_insert(buffer, ":"); - simplesave(hash[k], buffer); - t_insert(buffer, ","); - end - else - for k,v in pairs(hash) do - stringsave(k, buffer); - t_insert(buffer, ":"); - simplesave(v, buffer); - t_insert(buffer, ","); - end + local _pairs = buffer.ordered and spairs or pairs; + for k,v in _pairs(hash) do + stringsave(k, buffer); + t_insert(buffer, ":"); + simplesave(v, buffer); + t_insert(buffer, ","); end if next(__hash) ~= nil then t_insert(buffer, "\"__hash\":["); @@ -263,8 +247,9 @@ end local function _unescape_func(x) x = x:match("%x%x%x%x", 3); if x then - --if x >= 0xD800 and x <= 0xDFFF then _unescape_error = true; end -- bad surrogate pair - return codepoint_to_utf8(tonumber(x, 16)); + local codepoint = tonumber(x, 16) + if codepoint >= 0xD800 and codepoint <= 0xDFFF then _unescape_error = true; end -- bad surrogate pair + return codepoint_to_utf8(codepoint); end _unescape_error = true; end @@ -276,7 +261,7 @@ function _readstring(json, index) --if s:find("[%z-\31]") then return nil, "control char in string"; end -- FIXME handle control characters _unescape_error = nil; - --s = s:gsub("\\u[dD][89abAB]%x%x\\u[dD][cdefCDEF]%x%x", _unescape_surrogate_func); + s = s:gsub("\\u[dD][89abAB]%x%x\\u[dD][cdefCDEF]%x%x", _unescape_surrogate_func); -- FIXME handle escapes beyond BMP s = s:gsub("\\u.?.?.?.?", _unescape_func); if _unescape_error then return nil, "invalid escape"; end diff --git a/util/logger.lua b/util/logger.lua index e72b29bc..20a5cef2 100644 --- a/util/logger.lua +++ b/util/logger.lua @@ -8,8 +8,11 @@ -- luacheck: ignore 213/level local pairs = pairs; +local ipairs = ipairs; +local require = require; local _ENV = nil; +-- luacheck: std none local level_sinks = {}; @@ -67,10 +70,21 @@ local function add_level_sink(level, sink_function) end end +local function add_simple_sink(simple_sink_function, levels) + local format = require "util.format".format; + local function sink_function(name, level, msg, ...) + return simple_sink_function(name, level, format(msg, ...)); + end + for _, level in ipairs(levels or {"debug", "info", "warn", "error"}) do + add_level_sink(level, sink_function); + end +end + return { init = init; make_logger = make_logger; reset = reset; add_level_sink = add_level_sink; + add_simple_sink = add_simple_sink; new = make_logger; }; diff --git a/util/multitable.lua b/util/multitable.lua index e4321d3d..8d32ed8a 100644 --- a/util/multitable.lua +++ b/util/multitable.lua @@ -9,9 +9,10 @@ local select = select; local t_insert = table.insert; local pairs, next, type = pairs, next, type; -local unpack = table.unpack or unpack; --luacheck: ignore 113 +local unpack = table.unpack or unpack; --luacheck: ignore 113 143 local _ENV = nil; +-- luacheck: std none local function get(self, ...) local t = self.data; @@ -132,7 +133,7 @@ local function iter(self, ...) local maxdepth = select("#", ...); local stack = { self.data }; local keys = { }; - local function it(self) + local function it(self) -- luacheck: ignore 432/self local depth = #stack; local key = next(stack[depth], keys[depth]); if key == nil then -- Go up the stack diff --git a/util/openssl.lua b/util/openssl.lua index 703c6d15..32b5aea7 100644 --- a/util/openssl.lua +++ b/util/openssl.lua @@ -114,7 +114,7 @@ function ssl_config:add_xmppAddr(host) s_format("%s;%s", oid_xmppaddr, utf8string(host))); end -function ssl_config:from_prosody(hosts, config, certhosts) +function ssl_config:from_prosody(hosts, config, certhosts) -- luacheck: ignore 431/config -- TODO Decide if this should go elsewhere local found_matching_hosts = false; for i = 1, #certhosts do diff --git a/util/pluginloader.lua b/util/pluginloader.lua index 004855f0..9ab8f245 100644 --- a/util/pluginloader.lua +++ b/util/pluginloader.lua @@ -5,6 +5,7 @@ -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- +-- luacheck: ignore 113/CFG_PLUGINDIR local dir_sep, path_sep = package.config:match("^(%S+)%s(%S+)"); local plugin_dir = {}; diff --git a/util/presence.lua b/util/presence.lua index f6370354..8d1ae2d9 100644 --- a/util/presence.lua +++ b/util/presence.lua @@ -13,7 +13,6 @@ local function select_top_resources(user) local recipients = {}; for _, session in pairs(user.sessions) do -- find resource with greatest priority if session.presence then - -- TODO check active privacy list for session local p = session.priority; if p > priority then priority = p; diff --git a/util/promise.lua b/util/promise.lua new file mode 100644 index 00000000..07c9c4dc --- /dev/null +++ b/util/promise.lua @@ -0,0 +1,152 @@ +local promise_methods = {}; +local promise_mt = { __name = "promise", __index = promise_methods }; + +local xpcall = require "util.xpcall".xpcall; + +function promise_mt:__tostring() + return "promise (" .. (self._state or "invalid") .. ")"; +end + +local function is_promise(o) + local mt = getmetatable(o); + return mt == promise_mt; +end + +local function wrap_handler(f, resolve, reject, default) + if not f then + return default; + end + return function (param) + local ok, ret = xpcall(f, debug.traceback, param); + if ok then + resolve(ret); + else + reject(ret); + end + return true; + end; +end + +local function next_pending(self, on_fulfilled, on_rejected, resolve, reject) + table.insert(self._pending_on_fulfilled, wrap_handler(on_fulfilled, resolve, reject, resolve)); + table.insert(self._pending_on_rejected, wrap_handler(on_rejected, resolve, reject, reject)); +end + +local function next_fulfilled(promise, on_fulfilled, on_rejected, resolve, reject) -- luacheck: ignore 212/on_rejected + wrap_handler(on_fulfilled, resolve, reject, resolve)(promise.value); +end + +local function next_rejected(promise, on_fulfilled, on_rejected, resolve, reject) -- luacheck: ignore 212/on_fulfilled + wrap_handler(on_rejected, resolve, reject, reject)(promise.reason); +end + +local function promise_settle(promise, new_state, new_next, cbs, value) + if promise._state ~= "pending" then + return; + end + promise._state = new_state; + promise._next = new_next; + for _, cb in ipairs(cbs) do + cb(value); + end + return true; +end + +local function new_resolve_functions(p) + local resolved = false; + local function _resolve(v) + if resolved then return; end + resolved = true; + if is_promise(v) then + v:next(new_resolve_functions(p)); + elseif promise_settle(p, "fulfilled", next_fulfilled, p._pending_on_fulfilled, v) then + p.value = v; + end + + end + local function _reject(e) + if resolved then return; end + resolved = true; + if promise_settle(p, "rejected", next_rejected, p._pending_on_rejected, e) then + p.reason = e; + end + end + return _resolve, _reject; +end + +local function new(f) + local p = setmetatable({ _state = "pending", _next = next_pending, _pending_on_fulfilled = {}, _pending_on_rejected = {} }, promise_mt); + if f then + local resolve, reject = new_resolve_functions(p); + local ok, ret = pcall(f, resolve, reject); + if not ok and p._state == "pending" then + reject(ret); + end + end + return p; +end + +local function all(promises) + return new(function (resolve, reject) + local count, total, results = 0, #promises, {}; + for i = 1, total do + promises[i]:next(function (v) + results[i] = v; + count = count + 1; + if count == total then + resolve(results); + end + end, reject); + end + end); +end + +local function race(promises) + return new(function (resolve, reject) + for i = 1, #promises do + promises[i]:next(resolve, reject); + end + end); +end + +local function resolve(v) + return new(function (_resolve) + _resolve(v); + end); +end + +local function reject(v) + return new(function (_, _reject) + _reject(v); + end); +end + +local function try(f) + return resolve():next(function () return f(); end); +end + +function promise_methods:next(on_fulfilled, on_rejected) + return new(function (resolve, reject) --luacheck: ignore 431/resolve 431/reject + self:_next(on_fulfilled, on_rejected, resolve, reject); + end); +end + +function promise_methods:catch(on_rejected) + return self:next(nil, on_rejected); +end + +function promise_methods:finally(on_finally) + local function _on_finally(value) on_finally(); return value; end + local function _on_catch_finally(err) on_finally(); return reject(err); end + return self:next(_on_finally, _on_catch_finally); +end + +return { + new = new; + resolve = resolve; + reject = reject; + all = all; + race = race; + try = try; + is_promise = is_promise; +} diff --git a/util/prosodyctl.lua b/util/prosodyctl.lua index 8ae051ae..5f0c4d12 100644 --- a/util/prosodyctl.lua +++ b/util/prosodyctl.lua @@ -24,8 +24,6 @@ local io, os = io, os; local print = print; local tonumber = tonumber; -local CFG_SOURCEDIR = _G.CFG_SOURCEDIR; - local _G = _G; local prosody = prosody; @@ -66,7 +64,10 @@ local function getline() end local function getpass() - local stty_ret = os.execute("stty -echo 2>/dev/null"); + local stty_ret, _, status_code = os.execute("stty -echo 2>/dev/null"); + if status_code then -- COMPAT w/ Lua 5.1 + stty_ret = status_code; + end if stty_ret ~= 0 then io.write("\027[08m"); -- ANSI 'hidden' text attribute end @@ -189,8 +190,8 @@ local function getpid() pidfile = config.resolve_relative_path(prosody.paths.data, pidfile); - local modules_enabled = set.new(config.get("*", "modules_disabled")); - if prosody.platform ~= "posix" or modules_enabled:contains("posix") then + local modules_disabled = set.new(config.get("*", "modules_disabled")); + if prosody.platform ~= "posix" or modules_disabled:contains("posix") then return false, "no-posix"; end @@ -228,7 +229,7 @@ local function isrunning() return true, signal.kill(pid, 0) == 0; end -local function start() +local function start(source_dir) local ok, ret = isrunning(); if not ok then return ok, ret; @@ -236,10 +237,10 @@ local function start() if ret then return false, "already-running"; end - if not CFG_SOURCEDIR then + if not source_dir then os.execute("./prosody"); else - os.execute(CFG_SOURCEDIR.."/../../bin/prosody"); + os.execute(source_dir.."/../../bin/prosody"); end return true; end diff --git a/util/pubsub.lua b/util/pubsub.lua index 1db917d8..fafae50a 100644 --- a/util/pubsub.lua +++ b/util/pubsub.lua @@ -1,50 +1,212 @@ local events = require "util.events"; local cache = require "util.cache"; -local service = {}; -local service_mt = { __index = service }; +local service_mt = {}; -local default_config = { __index = { - itemstore = function (config) return cache.new(tonumber(config["pubsub#max_items"])) end; +local default_config = { + itemstore = function (config, _) return cache.new(config["max_items"]) end; broadcaster = function () end; + itemcheck = function () return true; end; get_affiliation = function () end; - capabilities = {}; -} }; -local default_node_config = { __index = { - ["pubsub#max_items"] = "20"; -} }; + normalize_jid = function (jid) return jid; end; + capabilities = { + outcast = { + create = false; + publish = false; + retract = false; + get_nodes = false; + + subscribe = false; + unsubscribe = false; + get_subscription = true; + get_subscriptions = true; + get_items = false; + + subscribe_other = false; + unsubscribe_other = false; + get_subscription_other = false; + get_subscriptions_other = false; + + be_subscribed = false; + be_unsubscribed = true; + + set_affiliation = false; + }; + none = { + create = false; + publish = false; + retract = false; + get_nodes = true; + + subscribe = true; + unsubscribe = true; + get_subscription = true; + get_subscriptions = true; + get_items = false; + + subscribe_other = false; + unsubscribe_other = false; + get_subscription_other = false; + get_subscriptions_other = false; + + be_subscribed = true; + be_unsubscribed = true; + + set_affiliation = false; + }; + member = { + create = false; + publish = false; + retract = false; + get_nodes = true; + + subscribe = true; + unsubscribe = true; + get_subscription = true; + get_subscriptions = true; + get_items = true; + + subscribe_other = false; + unsubscribe_other = false; + get_subscription_other = false; + get_subscriptions_other = false; + + be_subscribed = true; + be_unsubscribed = true; + + set_affiliation = false; + }; + publisher = { + create = false; + publish = true; + retract = true; + get_nodes = true; + get_configuration = true; + + subscribe = true; + unsubscribe = true; + get_subscription = true; + get_subscriptions = true; + get_items = true; + + subscribe_other = false; + unsubscribe_other = false; + get_subscription_other = false; + get_subscriptions_other = false; + + be_subscribed = true; + be_unsubscribed = true; + + set_affiliation = false; + }; + owner = { + create = true; + publish = true; + retract = true; + delete = true; + get_nodes = true; + configure = true; + get_configuration = true; + + subscribe = true; + unsubscribe = true; + get_subscription = true; + get_subscriptions = true; + get_items = true; + + subscribe_other = true; + unsubscribe_other = true; + get_subscription_other = true; + get_subscriptions_other = true; + + be_subscribed = true; + be_unsubscribed = true; + + set_affiliation = true; + }; + }; +}; +local default_config_mt = { __index = default_config }; + +local default_node_config = { + ["persist_items"] = false; + ["max_items"] = 20; + ["access_model"] = "open"; + ["publish_model"] = "publishers"; +}; +local default_node_config_mt = { __index = default_node_config }; + +-- Storage helper functions + +local function load_node_from_store(service, node_name) + local node = service.config.nodestore:get(node_name); + node.config = setmetatable(node.config or {}, {__index=service.node_defaults}); + return node; +end + +local function save_node_to_store(service, node) + return service.config.nodestore:set(node.name, { + name = node.name; + config = node.config; + subscribers = node.subscribers; + affiliations = node.affiliations; + }); +end + +local function delete_node_in_store(service, node_name) + return service.config.nodestore:set(node_name, nil); +end + +-- Create and return a new service object local function new(config) config = config or {}; - return setmetatable({ - config = setmetatable(config, default_config); - node_defaults = setmetatable(config.node_defaults or {}, default_node_config); + + local service = setmetatable({ + config = setmetatable(config, default_config_mt); + node_defaults = setmetatable(config.node_defaults or {}, default_node_config_mt); affiliations = {}; subscriptions = {}; nodes = {}; data = {}; events = events.new(); }, service_mt); + + -- Load nodes from storage, if we have a store and it supports iterating over stored items + if config.nodestore and config.nodestore.users then + for node_name in config.nodestore:users() do + service.nodes[node_name] = load_node_from_store(service, node_name); + service.data[node_name] = config.itemstore(service.nodes[node_name].config, node_name); + end + end + + return service; end -function service:jids_equal(jid1, jid2) +--- Service methods + +local service = {}; +service_mt.__index = service; + +function service:jids_equal(jid1, jid2) --> boolean local normalize = self.config.normalize_jid; return normalize(jid1) == normalize(jid2); end -function service:may(node, actor, action) +function service:may(node, actor, action) --> boolean if actor == true then return true; end local node_obj = self.nodes[node]; - local node_aff = node_obj and node_obj.affiliations[actor]; + local node_aff = node_obj and (node_obj.affiliations[actor] + or node_obj.affiliations[self.config.normalize_jid(actor)]); local service_aff = self.affiliations[actor] - or self.config.get_affiliation(actor, node, action) - or "none"; + or self.config.get_affiliation(actor, node, action); + local default_aff = self:get_default_affiliation(node, actor) or "none"; -- Check if node allows/forbids it local node_capabilities = node_obj and node_obj.capabilities; if node_capabilities then - local caps = node_capabilities[node_aff or service_aff]; + local caps = node_capabilities[node_aff or service_aff or default_aff]; if caps then local can = caps[action]; if can ~= nil then @@ -55,7 +217,7 @@ function service:may(node, actor, action) -- Check service-wide capabilities instead local service_capabilities = self.config.capabilities; - local caps = service_capabilities[node_aff or service_aff]; + local caps = service_capabilities[node_aff or service_aff or default_aff]; if caps then local can = caps[action]; if can ~= nil then @@ -66,7 +228,29 @@ function service:may(node, actor, action) return false; end -function service:set_affiliation(node, actor, jid, affiliation) +function service:get_default_affiliation(node, actor) --> affiliation + local node_obj = self.nodes[node]; + local access_model = node_obj and node_obj.config.access_model + or self.node_defaults.access_model; + + if access_model == "open" then + return "member"; + elseif access_model == "whitelist" then + return "outcast"; + end + + if self.config.access_models then + local check = self.config.access_models[access_model]; + if check then + local aff = check(actor); + if aff then + return aff; + end + end + end +end + +function service:set_affiliation(node, actor, jid, affiliation) --> ok, err -- Access checking if not self:may(node, actor, "set_affiliation") then return false, "forbidden"; @@ -76,7 +260,18 @@ function service:set_affiliation(node, actor, jid, affiliation) if not node_obj then return false, "item-not-found"; end + jid = self.config.normalize_jid(jid); + local old_affiliation = node_obj.affiliations[jid]; node_obj.affiliations[jid] = affiliation; + + if self.config.nodestore then + local ok, err = save_node_to_store(self, node_obj); + if not ok then + node_obj.affiliations[jid] = old_affiliation; + return ok, "internal-server-error"; + end + end + local _, jid_sub = self:get_subscription(node, true, jid); if not jid_sub and not self:may(node, jid, "be_unsubscribed") then local ok, err = self:add_subscription(node, true, jid); @@ -92,7 +287,7 @@ function service:set_affiliation(node, actor, jid, affiliation) return true; end -function service:add_subscription(node, actor, jid, options) +function service:add_subscription(node, actor, jid, options) --> ok, err -- Access checking local cap; if actor == true or jid == actor or self:jids_equal(actor, jid) then @@ -119,6 +314,7 @@ function service:add_subscription(node, actor, jid, options) node_obj = self.nodes[node]; end end + local old_subscription = node_obj.subscribers[jid]; node_obj.subscribers[jid] = options or true; local normal_jid = self.config.normalize_jid(jid); local subs = self.subscriptions[normal_jid]; @@ -131,11 +327,21 @@ function service:add_subscription(node, actor, jid, options) else self.subscriptions[normal_jid] = { [jid] = { [node] = true } }; end - self.events.fire_event("subscription-added", { node = node, jid = jid, normalized_jid = normal_jid, options = options }); + + if self.config.nodestore then + local ok, err = save_node_to_store(self, node_obj); + if not ok then + node_obj.subscribers[jid] = old_subscription; + self.subscriptions[normal_jid][jid][node] = old_subscription and true or nil; + return ok, "internal-server-error"; + end + end + + self.events.fire_event("subscription-added", { service = self, node = node, jid = jid, normalized_jid = normal_jid, options = options }); return true; end -function service:remove_subscription(node, actor, jid) +function service:remove_subscription(node, actor, jid) --> ok, err -- Access checking local cap; if actor == true or jid == actor or self:jids_equal(actor, jid) then @@ -157,6 +363,7 @@ function service:remove_subscription(node, actor, jid) if not node_obj.subscribers[jid] then return false, "not-subscribed"; end + local old_subscription = node_obj.subscribers[jid]; node_obj.subscribers[jid] = nil; local normal_jid = self.config.normalize_jid(jid); local subs = self.subscriptions[normal_jid]; @@ -172,23 +379,21 @@ function service:remove_subscription(node, actor, jid) self.subscriptions[normal_jid] = nil; end end - self.events.fire_event("subscription-removed", { node = node, jid = jid, normalized_jid = normal_jid }); - return true; -end -function service:remove_all_subscriptions(actor, jid) - local normal_jid = self.config.normalize_jid(jid); - local subs = self.subscriptions[normal_jid] - subs = subs and subs[jid]; - if subs then - for node in pairs(subs) do - self:remove_subscription(node, true, jid); + if self.config.nodestore then + local ok, err = save_node_to_store(self, node_obj); + if not ok then + node_obj.subscribers[jid] = old_subscription; + self.subscriptions[normal_jid][jid][node] = old_subscription and true or nil; + return ok, "internal-server-error"; end end + + self.events.fire_event("subscription-removed", { service = self, node = node, jid = jid, normalized_jid = normal_jid }); return true; end -function service:get_subscription(node, actor, jid) +function service:get_subscription(node, actor, jid) --> (true, subscription) or (false, err) -- Access checking local cap; if actor == true or jid == actor or self:jids_equal(actor, jid) then @@ -207,7 +412,7 @@ function service:get_subscription(node, actor, jid) return true, node_obj.subscribers[jid]; end -function service:create(node, actor, options) +function service:create(node, actor, options) --> ok, err -- Access checking if not self:may(node, actor, "create") then return false, "forbidden"; @@ -223,17 +428,30 @@ function service:create(node, actor, options) config = setmetatable(options or {}, {__index=self.node_defaults}); affiliations = {}; }; - self.data[node] = self.config.itemstore(self.nodes[node].config); - self.events.fire_event("node-created", { node = node, actor = actor }); - local ok, err = self:set_affiliation(node, true, actor, "owner"); - if not ok then - self.nodes[node] = nil; - self.data[node] = nil; + + if self.config.nodestore then + local ok, err = save_node_to_store(self, self.nodes[node]); + if not ok then + self.nodes[node] = nil; + return ok, "internal-server-error"; + end end - return ok, err; + + self.data[node] = self.config.itemstore(self.nodes[node].config, node); + self.events.fire_event("node-created", { service = self, node = node, actor = actor }); + if actor ~= true then + local ok, err = self:set_affiliation(node, true, actor, "owner"); + if not ok then + self.nodes[node] = nil; + self.data[node] = nil; + return ok, err; + end + end + + return true; end -function service:delete(node, actor) +function service:delete(node, actor) --> ok, err -- Access checking if not self:may(node, actor, "delete") then return false, "forbidden"; @@ -244,15 +462,52 @@ function service:delete(node, actor) return false, "item-not-found"; end self.nodes[node] = nil; + if self.data[node] and self.data[node].clear then + self.data[node]:clear(); + end self.data[node] = nil; - self.events.fire_event("node-deleted", { node = node, actor = actor }); - self.config.broadcaster("delete", node, node_obj.subscribers); + + if self.config.nodestore then + local ok, err = delete_node_in_store(self, node); + if not ok then + self.nodes[node] = nil; + return ok, err; + end + end + + self.events.fire_event("node-deleted", { service = self, node = node, actor = actor }); + self.config.broadcaster("delete", node, node_obj.subscribers, nil, actor, node_obj, self); return true; end -function service:publish(node, actor, id, item) +-- Used to check that the config of a node is as expected (i.e. 'publish-options') +local function check_preconditions(node_config, required_config) + if not (node_config and required_config) then + return false; + end + for config_field, value in pairs(required_config) do + if node_config[config_field] ~= value then + return false; + end + end + return true; +end + +function service:publish(node, actor, id, item, requested_config) --> ok, err -- Access checking - if not self:may(node, actor, "publish") then + local may_publish = false; + + if self:may(node, actor, "publish") then + may_publish = true; + else + local node_obj = self.nodes[node]; + local publish_model = node_obj and node_obj.config.publish_model; + if publish_model == "open" + or (publish_model == "subscribers" and node_obj.subscribers[actor]) then + may_publish = true; + end + end + if not may_publish then return false, "forbidden"; end -- @@ -261,23 +516,34 @@ function service:publish(node, actor, id, item) if not self.config.autocreate_on_publish then return false, "item-not-found"; end - local ok, err = self:create(node, true); + local ok, err = self:create(node, true, requested_config); if not ok then return ok, err; end node_obj = self.nodes[node]; + elseif requested_config and not requested_config._defaults_only then + -- Check that node has the requested config before we publish + if not check_preconditions(node_obj.config, requested_config) then + return false, "precondition-not-met"; + end + end + if not self.config.itemcheck(item) then + return nil, "invalid-item"; end local node_data = self.data[node]; local ok = node_data:set(id, item); if not ok then return nil, "internal-server-error"; end - self.events.fire_event("item-published", { node = node, actor = actor, id = id, item = item }); - self.config.broadcaster("items", node, node_obj.subscribers, item, actor); + if type(ok) == "string" then id = ok; end + local event_data = { service = self, node = node, actor = actor, id = id, item = item }; + self.events.fire_event("item-published/"..node, event_data); + self.events.fire_event("item-published", event_data); + self.config.broadcaster("items", node, node_obj.subscribers, item, actor, node_obj, self); return true; end -function service:retract(node, actor, id, retract) +function service:retract(node, actor, id, retract) --> ok, err -- Access checking if not self:may(node, actor, "retract") then return false, "forbidden"; @@ -291,14 +557,14 @@ function service:retract(node, actor, id, retract) if not ok then return nil, "internal-server-error"; end - self.events.fire_event("item-retracted", { node = node, actor = actor, id = id }); + self.events.fire_event("item-retracted", { service = self, node = node, actor = actor, id = id }); if retract then - self.config.broadcaster("items", node, node_obj.subscribers, retract); + self.config.broadcaster("retract", node, node_obj.subscribers, retract, actor, node_obj, self); end return true end -function service:purge(node, actor, notify) +function service:purge(node, actor, notify) --> ok, err -- Access checking if not self:may(node, actor, "retract") then return false, "forbidden"; @@ -308,15 +574,19 @@ function service:purge(node, actor, notify) if not node_obj then return false, "item-not-found"; end - self.data[node] = self.config.itemstore(self.nodes[node].config); - self.events.fire_event("node-purged", { node = node, actor = actor }); + if self.data[node] and self.data[node].clear then + self.data[node]:clear() + else + self.data[node] = self.config.itemstore(self.nodes[node].config, node); + end + self.events.fire_event("node-purged", { service = self, node = node, actor = actor }); if notify then - self.config.broadcaster("purge", node, node_obj.subscribers); + self.config.broadcaster("purge", node, node_obj.subscribers, nil, actor, node_obj, self); end return true end -function service:get_items(node, actor, id) +function service:get_items(node, actor, id) --> (true, { id, [id] = node }) or (false, err) -- Access checking if not self:may(node, actor, "get_items") then return false, "forbidden"; @@ -327,7 +597,11 @@ function service:get_items(node, actor, id) return false, "item-not-found"; end if id then -- Restrict results to a single specific item - return true, { id, [id] = self.data[node]:get(id) }; + local with_id = self.data[node]:get(id); + if not with_id then + return true, { }; + end + return true, { id, [id] = with_id }; else local data = {} for key, value in self.data[node]:items() do @@ -338,7 +612,23 @@ function service:get_items(node, actor, id) end end -function service:get_nodes(actor) +function service:get_last_item(node, actor) --> (true, id, node) or (false, err) + -- Access checking + if not self:may(node, actor, "get_items") then + return false, "forbidden"; + end + -- + + -- Check node exists + if not self.nodes[node] then + return false, "item-not-found"; + end + + -- Returns success, id, item + return true, self.data[node]:head(); +end + +function service:get_nodes(actor) --> (true, map) or (false, err) -- Access checking if not self:may(nil, actor, "get_nodes") then return false, "forbidden"; @@ -347,7 +637,30 @@ function service:get_nodes(actor) return true, self.nodes; end -function service:get_subscriptions(node, actor, jid) +local function flatten_subscriptions(ret, serv, subs, node, node_obj) + for subscribed_jid, subscribed_nodes in pairs(subs) do + if node then -- Return only subscriptions to this node + if subscribed_nodes[node] then + ret[#ret+1] = { + node = node; + jid = subscribed_jid; + subscription = node_obj.subscribers[subscribed_jid]; + }; + end + else -- Return subscriptions to all nodes + local nodes = serv.nodes; + for subscribed_node in pairs(subscribed_nodes) do + ret[#ret+1] = { + node = subscribed_node; + jid = subscribed_jid; + subscription = nodes[subscribed_node].subscribers[subscribed_jid]; + }; + end + end + end +end + +function service:get_subscriptions(node, actor, jid) --> (true, array) or (false, err) -- Access checking local cap; if actor == true or jid == actor or self:jids_equal(actor, jid) then @@ -366,38 +679,25 @@ function service:get_subscriptions(node, actor, jid) return false, "item-not-found"; end end + local ret = {}; + if jid == nil then + for _, subs in pairs(self.subscriptions) do + flatten_subscriptions(ret, self, subs, node, node_obj) + end + return true, ret; + end local normal_jid = self.config.normalize_jid(jid); local subs = self.subscriptions[normal_jid]; -- We return the subscription object from the node to save -- a get_subscription() call for each node. - local ret = {}; if subs then - for subscribed_jid, subscribed_nodes in pairs(subs) do - if node then -- Return only subscriptions to this node - if subscribed_nodes[node] then - ret[#ret+1] = { - node = node; - jid = subscribed_jid; - subscription = node_obj.subscribers[subscribed_jid]; - }; - end - else -- Return subscriptions to all nodes - local nodes = self.nodes; - for subscribed_node in pairs(subscribed_nodes) do - ret[#ret+1] = { - node = subscribed_node; - jid = subscribed_jid; - subscription = nodes[subscribed_node].subscribers[subscribed_jid]; - }; - end - end - end + flatten_subscriptions(ret, self, subs, node, node_obj) end return true, ret; end -- Access models only affect 'none' affiliation caps, service/default access level... -function service:set_node_capabilities(node, actor, capabilities) +function service:set_node_capabilities(node, actor, capabilities) --> ok, err -- Access checking if not self:may(node, actor, "configure") then return false, "forbidden"; @@ -411,7 +711,7 @@ function service:set_node_capabilities(node, actor, capabilities) return true; end -function service:set_node_config(node, actor, new_config) +function service:set_node_config(node, actor, new_config) --> ok, err if not self:may(node, actor, "configure") then return false, "forbidden"; end @@ -421,17 +721,71 @@ function service:set_node_config(node, actor, new_config) return false, "item-not-found"; end - for k,v in pairs(new_config) do - node_obj.config[k] = v; + setmetatable(new_config, {__index=self.node_defaults}) + + if self.config.check_node_config then + local ok = self.config.check_node_config(node, actor, new_config); + if not ok then + return false, "not-acceptable"; + end + end + + local old_config = node_obj.config; + node_obj.config = new_config; + + if self.config.nodestore then + local ok, err = save_node_to_store(self, node_obj); + if not ok then + node_obj.config = old_config; + return ok, "internal-server-error"; + end + end + + if old_config["access_model"] ~= node_obj.config["access_model"] then + for subscriber in pairs(node_obj.subscribers) do + if not self:may(node, subscriber, "be_subscribed") then + local ok, err = self:remove_subscription(node, true, subscriber); + if not ok then + node_obj.config = old_config; + return ok, err; + end + end + end end - local new_data = self.config.itemstore(self.nodes[node].config); - for key, value in self.data[node]:items() do - new_data:set(key, value); + + if old_config["persist_items"] ~= node_obj.config["persist_items"] then + self.data[node] = self.config.itemstore(self.nodes[node].config, node); + elseif old_config["max_items"] ~= node_obj.config["max_items"] then + self.data[node]:resize(self.nodes[node].config["max_items"]); end - self.data[node] = new_data; + return true; end +function service:get_node_config(node, actor) --> (true, config) or (false, err) + if not self:may(node, actor, "get_configuration") then + return false, "forbidden"; + end + + local node_obj = self.nodes[node]; + if not node_obj then + return false, "item-not-found"; + end + + local config_table = {}; + for k, v in pairs(default_node_config) do + config_table[k] = v; + end + for k, v in pairs(self.node_defaults) do + config_table[k] = v; + end + for k, v in pairs(node_obj.config) do + config_table[k] = v; + end + + return true, config_table; +end + return { new = new; }; diff --git a/util/random.lua b/util/random.lua index b2d0000d..d8a84514 100644 --- a/util/random.lua +++ b/util/random.lua @@ -11,9 +11,6 @@ if ok then return crand; end local urandom, urandom_err = io.open("/dev/urandom", "r"); -local function seed() -end - local function bytes(n) return urandom:read(n); end @@ -25,7 +22,6 @@ if not urandom then end return { - seed = seed; bytes = bytes; _source = "/dev/urandom"; }; diff --git a/util/sasl.lua b/util/sasl.lua index 5845f34a..50851405 100644 --- a/util/sasl.lua +++ b/util/sasl.lua @@ -20,6 +20,7 @@ local assert = assert; local require = require; local _ENV = nil; +-- luacheck: std none --[[ Authentication Backend Prototypes: @@ -42,7 +43,7 @@ Example: local method = {}; method.__index = method; -local mechanisms = {}; +local registered_mechanisms = {}; local backend_mechanism = {}; local mechanism_channelbindings = {}; @@ -52,7 +53,7 @@ local function registerMechanism(name, backends, f, cb_backends) assert(type(backends) == "string" or type(backends) == "table", "Parameter backends MUST be either a string or a table."); assert(type(f) == "function", "Parameter f MUST be a function."); if cb_backends then assert(type(cb_backends) == "table"); end - mechanisms[name] = f + registered_mechanisms[name] = f if cb_backends then mechanism_channelbindings[name] = {}; for _, cb_name in ipairs(cb_backends) do @@ -70,7 +71,7 @@ local function new(realm, profile) local mechanisms = profile.mechanisms; if not mechanisms then mechanisms = {}; - for backend, f in pairs(profile) do + for backend in pairs(profile) do if backend_mechanism[backend] then for _, mechanism in ipairs(backend_mechanism[backend]) do mechanisms[mechanism] = true; @@ -128,7 +129,7 @@ end -- feed new messages to process into the library function method:process(message) --if message == "" or message == nil then return "failure", "malformed-request" end - return mechanisms[self.selected](self, message); + return registered_mechanisms[self.selected](self, message); end -- load the mechanisms diff --git a/util/sasl/anonymous.lua b/util/sasl/anonymous.lua index 6201db32..de98a5e2 100644 --- a/util/sasl/anonymous.lua +++ b/util/sasl/anonymous.lua @@ -12,9 +12,10 @@ -- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -local generate_uuid = require "util.uuid".generate; +local generate_random_id = require "util.id".medium; local _ENV = nil; +-- luacheck: std none --========================= --SASL ANONYMOUS according to RFC 4505 @@ -28,10 +29,10 @@ anonymous: end ]] -local function anonymous(self, message) +local function anonymous(self, message) -- luacheck: ignore 212/message local username; repeat - username = generate_uuid(); + username = generate_random_id():lower(); until self.profile.anonymous(self, username, self.realm); self.username = username; return "success" diff --git a/util/sasl/digest-md5.lua b/util/sasl/digest-md5.lua index 695dd2a3..7542a037 100644 --- a/util/sasl/digest-md5.lua +++ b/util/sasl/digest-md5.lua @@ -26,6 +26,7 @@ local generate_uuid = require "util.uuid".generate; local nodeprep = require "util.encodings".stringprep.nodeprep; local _ENV = nil; +-- luacheck: std none --========================= --SASL DIGEST-MD5 according to RFC 2831 diff --git a/util/sasl/external.lua b/util/sasl/external.lua index 5ba90190..ce50743e 100644 --- a/util/sasl/external.lua +++ b/util/sasl/external.lua @@ -1,6 +1,7 @@ local saslprep = require "util.encodings".stringprep.saslprep; local _ENV = nil; +-- luacheck: std none local function external(self, message) message = saslprep(message); diff --git a/util/sasl/plain.lua b/util/sasl/plain.lua index cd59b1ac..00c6bd20 100644 --- a/util/sasl/plain.lua +++ b/util/sasl/plain.lua @@ -17,6 +17,7 @@ local nodeprep = require "util.encodings".stringprep.nodeprep; local log = require "util.logger".init("sasl"); local _ENV = nil; +-- luacheck: std none -- ================================ -- SASL PLAIN according to RFC 4616 diff --git a/util/sasl/scram.lua b/util/sasl/scram.lua index 4e20dbb9..043f328b 100644 --- a/util/sasl/scram.lua +++ b/util/sasl/scram.lua @@ -26,6 +26,7 @@ local char = string.char; local byte = string.byte; local _ENV = nil; +-- luacheck: std none --========================= --SASL SCRAM-SHA-1 according to RFC 5802 @@ -46,7 +47,18 @@ Supported Channel Binding Backends local default_i = 4096 -local xor_map = {0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;1;0;3;2;5;4;7;6;9;8;11;10;13;12;15;14;2;3;0;1;6;7;4;5;10;11;8;9;14;15;12;13;3;2;1;0;7;6;5;4;11;10;9;8;15;14;13;12;4;5;6;7;0;1;2;3;12;13;14;15;8;9;10;11;5;4;7;6;1;0;3;2;13;12;15;14;9;8;11;10;6;7;4;5;2;3;0;1;14;15;12;13;10;11;8;9;7;6;5;4;3;2;1;0;15;14;13;12;11;10;9;8;8;9;10;11;12;13;14;15;0;1;2;3;4;5;6;7;9;8;11;10;13;12;15;14;1;0;3;2;5;4;7;6;10;11;8;9;14;15;12;13;2;3;0;1;6;7;4;5;11;10;9;8;15;14;13;12;3;2;1;0;7;6;5;4;12;13;14;15;8;9;10;11;4;5;6;7;0;1;2;3;13;12;15;14;9;8;11;10;5;4;7;6;1;0;3;2;14;15;12;13;10;11;8;9;6;7;4;5;2;3;0;1;15;14;13;12;11;10;9;8;7;6;5;4;3;2;1;0;}; +local xor_map = { + 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,1,0,3,2,5,4,7,6,9,8,11,10, + 13,12,15,14,2,3,0,1,6,7,4,5,10,11,8,9,14,15,12,13,3,2,1,0,7,6,5, + 4,11,10,9,8,15,14,13,12,4,5,6,7,0,1,2,3,12,13,14,15,8,9,10,11,5, + 4,7,6,1,0,3,2,13,12,15,14,9,8,11,10,6,7,4,5,2,3,0,1,14,15,12,13, + 10,11,8,9,7,6,5,4,3,2,1,0,15,14,13,12,11,10,9,8,8,9,10,11,12,13, + 14,15,0,1,2,3,4,5,6,7,9,8,11,10,13,12,15,14,1,0,3,2,5,4,7,6,10, + 11,8,9,14,15,12,13,2,3,0,1,6,7,4,5,11,10,9,8,15,14,13,12,3,2,1, + 0,7,6,5,4,12,13,14,15,8,9,10,11,4,5,6,7,0,1,2,3,13,12,15,14,9,8, + 11,10,5,4,7,6,1,0,3,2,14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1,15, + 14,13,12,11,10,9,8,7,6,5,4,3,2,1,0, +}; local result = {}; local function binaryXOR( a, b ) @@ -148,7 +160,7 @@ local function scram_gen(hash_name, H_f, HMAC_f) end self.username = username; - -- retreive credentials + -- retrieve credentials local stored_key, server_key, salt, iteration_count; if self.profile.plain then local password, status = self.profile.plain(self, username, self.realm) @@ -237,10 +249,14 @@ end local function init(registerMechanism) local function registerSCRAMMechanism(hash_name, hash, hmac_hash) - registerMechanism("SCRAM-"..hash_name, {"plain", "scram_"..(hashprep(hash_name))}, scram_gen(hash_name:lower(), hash, hmac_hash)); + registerMechanism("SCRAM-"..hash_name, + {"plain", "scram_"..(hashprep(hash_name))}, + scram_gen(hash_name:lower(), hash, hmac_hash)); -- register channel binding equivalent - registerMechanism("SCRAM-"..hash_name.."-PLUS", {"plain", "scram_"..(hashprep(hash_name))}, scram_gen(hash_name:lower(), hash, hmac_hash), {"tls-unique"}); + registerMechanism("SCRAM-"..hash_name.."-PLUS", + {"plain", "scram_"..(hashprep(hash_name))}, + scram_gen(hash_name:lower(), hash, hmac_hash), {"tls-unique"}); end registerSCRAMMechanism("SHA-1", sha1, hmac_sha1); diff --git a/util/sasl_cyrus.lua b/util/sasl_cyrus.lua index 4e9a4af5..a6bd0628 100644 --- a/util/sasl_cyrus.lua +++ b/util/sasl_cyrus.lua @@ -61,6 +61,7 @@ local sasl_errstring = { setmetatable(sasl_errstring, { __index = function() return "undefined error!" end }); local _ENV = nil; +-- luacheck: std none local method = {}; method.__index = method; diff --git a/util/serialization.lua b/util/serialization.lua index 206f5fbb..dd6a2a2b 100644 --- a/util/serialization.lua +++ b/util/serialization.lua @@ -1,89 +1,271 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain +-- Copyright (C) 2018 Kim Alvefur -- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- -local string_rep = string.rep; -local type = type; -local tostring = tostring; -local t_insert = table.insert; +local getmetatable = getmetatable; +local next, type = next, type; +local s_format = string.format; +local s_gsub = string.gsub; +local s_rep = string.rep; +local s_char = string.char; +local s_match = string.match; local t_concat = table.concat; -local pairs = pairs; -local next = next; local pcall = pcall; - -local debug_traceback = debug.traceback; -local log = require "util.logger".init("serialization"); local envload = require"util.envload".envload; -local _ENV = nil; +local pos_inf, neg_inf = math.huge, -math.huge; +-- luacheck: ignore 143/math +local m_type = math.type or function (n) + return n % 1 == 0 and n <= 9007199254740992 and n >= -9007199254740992 and "integer" or "float"; +end; + +local char_to_hex = {}; +for i = 0,255 do + char_to_hex[s_char(i)] = s_format("%02x", i); +end + +local function to_hex(s) + return (s_gsub(s, ".", char_to_hex)); +end + +local function fatal_error(obj, why) + error("Can't serialize "..type(obj) .. (why and ": ".. why or "")); +end -local indent = function(i) - return string_rep("\t", i); +local function nonfatal_fallback(x, why) + return s_format("{__type=%q,__error=%q}", type(x), why or "fail"); end -local function basicSerialize (o) - if type(o) == "number" or type(o) == "boolean" then - -- no need to check for NaN, as that's not a valid table index - if o == 1/0 then return "(1/0)"; - elseif o == -1/0 then return "(-1/0)"; - else return tostring(o); end - else -- assume it is a string -- FIXME make sure it's a string. throw an error otherwise. - return (("%q"):format(tostring(o)):gsub("\\\n", "\\n")); + +local string_escapes = { + ['\a'] = [[\a]]; ['\b'] = [[\b]]; + ['\f'] = [[\f]]; ['\n'] = [[\n]]; + ['\r'] = [[\r]]; ['\t'] = [[\t]]; + ['\v'] = [[\v]]; ['\\'] = [[\\]]; + ['\"'] = [[\"]]; ['\''] = [[\']]; +} + +for i = 0, 255 do + local c = s_char(i); + if not string_escapes[c] then + string_escapes[c] = s_format("\\%03d", i); end end -local function _simplesave(o, ind, t, func) - if type(o) == "number" then - if o ~= o then func(t, "(0/0)"); - elseif o == 1/0 then func(t, "(1/0)"); - elseif o == -1/0 then func(t, "(-1/0)"); - else func(t, tostring(o)); end - elseif type(o) == "string" then - func(t, (("%q"):format(o):gsub("\\\n", "\\n"))); - elseif type(o) == "table" then - if next(o) ~= nil then - func(t, "{\n"); - for k,v in pairs(o) do - func(t, indent(ind)); - func(t, "["); - func(t, basicSerialize(k)); - func(t, "] = "); - if ind == 0 then - _simplesave(v, 0, t, func); + +local default_keywords = { + ["do"] = true; ["and"] = true; ["else"] = true; ["break"] = true; + ["if"] = true; ["end"] = true; ["goto"] = true; ["false"] = true; + ["in"] = true; ["for"] = true; ["then"] = true; ["local"] = true; + ["or"] = true; ["nil"] = true; ["true"] = true; ["until"] = true; + ["elseif"] = true; ["function"] = true; ["not"] = true; + ["repeat"] = true; ["return"] = true; ["while"] = true; +}; + +local function new(opt) + if type(opt) ~= "table" then + opt = { preset = opt }; + end + + local types = { + table = true; + string = true; + number = true; + boolean = true; + ["nil"] = true; + }; + + -- presets + if opt.preset == "debug" then + opt.preset = "oneline"; + opt.freeze = true; + opt.fatal = false; + opt.fallback = nonfatal_fallback; + opt.unquoted = true; + end + if opt.preset == "oneline" then + opt.indentwith = opt.indentwith or ""; + opt.itemstart = opt.itemstart or " "; + opt.itemlast = opt.itemlast or ""; + opt.tend = opt.tend or " }"; + elseif opt.preset == "compact" then + opt.indentwith = opt.indentwith or ""; + opt.itemstart = opt.itemstart or ""; + opt.itemlast = opt.itemlast or ""; + opt.equals = opt.equals or "="; + opt.unquoted = true; + end + + local fallback = opt.fallback or opt.fatal == false and nonfatal_fallback or fatal_error; + + local function ser(v) + return (types[type(v)] or fallback)(v); + end + + local keywords = opt.keywords or default_keywords; + + -- indented + local indentwith = opt.indentwith or "\t"; + local itemstart = opt.itemstart or "\n"; + local itemsep = opt.itemsep or ";"; + local itemlast = opt.itemlast or ";\n"; + local tstart = opt.tstart or "{"; + local tend = opt.tend or "}"; + local kstart = opt.kstart or "["; + local kend = opt.kend or "]"; + local equals = opt.equals or " = "; + local unquoted = opt.unquoted == true and "^[%a_][%w_]*$" or opt.unquoted; + local hex = opt.hex; + local freeze = opt.freeze; + local maxdepth = opt.maxdepth or 127; + local multirefs = opt.multiref; + + -- serialize one table, recursively + -- t - table being serialized + -- o - array where tokens are added, concatenate to get final result + -- - also used to detect cycles + -- l - position in o of where to insert next token + -- d - depth, used for indentation + local function serialize_table(t, o, l, d) + if o[t] then + o[l], l = fallback(t, "table has multiple references"), l + 1; + return l; + elseif d > maxdepth then + o[l], l = fallback(t, "max table depth reached"), l + 1; + return l; + end + + -- Keep track of table loops + local ot = t; -- reference pre-freeze + o[t] = true; + o[ot] = true; + + if freeze == true then + -- opportunity to do pre-serialization + local mt = getmetatable(t); + if type(mt) == "table" then + local tag = mt.__name; + local fr = mt.__freeze; + + if type(fr) == "function" then + t = fr(t); + if type(tag) == "string" then + o[l], l = tag, l + 1; + end + end + end + end + + o[l], l = tstart, l + 1; + local indent = s_rep(indentwith, d); + local numkey = 1; + local ktyp, vtyp; + for k,v in next,t do + o[l], l = itemstart, l + 1; + o[l], l = indent, l + 1; + ktyp, vtyp = type(k), type(v); + if k == numkey then + -- next index in array part + -- assuming that these are found in order + numkey = numkey + 1; + elseif unquoted and ktyp == "string" and + not keywords[k] and s_match(k, unquoted) then + -- unquoted keys + o[l], l = k, l + 1; + o[l], l = equals, l + 1; + else + -- quoted keys + o[l], l = kstart, l + 1; + if ktyp == "table" then + l = serialize_table(k, o, l, d+1); else - _simplesave(v, ind+1, t, func); + o[l], l = ser(k), l + 1; end - func(t, ";\n"); + -- = + o[l], o[l+1], l = kend, equals, l + 2; + end + + -- the value + if vtyp == "table" then + l = serialize_table(v, o, l, d+1); + else + o[l], l = ser(v), l + 1; + end + -- last item? + if next(t, k) ~= nil then + o[l], l = itemsep, l + 1; + else + o[l], l = itemlast, l + 1; end - func(t, indent(ind-1)); - func(t, "}"); - else - func(t, "{}"); end - elseif type(o) == "boolean" then - func(t, (o and "true" or "false")); + if next(t) ~= nil then + o[l], l = s_rep(indentwith, d-1), l + 1; + end + o[l], l = tend, l +1; + + if multirefs then + o[t] = nil; + o[ot] = nil; + end + + return l; + end + + function types.table(t) + local o = {}; + serialize_table(t, o, 1, 1); + return t_concat(o); + end + + local function serialize_string(s) + return '"' .. s_gsub(s, "[%z\1-\31\"\'\\\127-\255]", string_escapes) .. '"'; + end + + if type(hex) == "string" then + function types.string(s) + local esc = serialize_string(s); + if #esc > (#s*2+2+#hex) then + return hex .. '"' .. to_hex(s) .. '"'; + end + return esc; + end else - log("error", "cannot serialize a %s: %s", type(o), debug_traceback()) - func(t, "nil"); + types.string = serialize_string; end -end -local function append(t, o) - _simplesave(o, 1, t, t.write or t_insert); - return t; -end + function types.number(t) + if m_type(t) == "integer" then + return s_format("%d", t); + elseif t == pos_inf then + return "(1/0)"; + elseif t == neg_inf then + return "(-1/0)"; + elseif t ~= t then + return "(0/0)"; + end + return s_format("%.18g", t); + end + + -- Are these faster than tostring? + types["nil"] = function() + return "nil"; + end + + function types.boolean(t) + return t and "true" or "false"; + end -local function serialize(o) - return t_concat(append({}, o)); + return ser; end local function deserialize(str) if type(str) ~= "string" then return nil; end str = "return "..str; - local f, err = envload(str, "@data", {}); + local f, err = envload(str, "=serialized data", {}); if not f then return nil, err; end local success, ret = pcall(f); if not success then return nil, ret; end @@ -91,7 +273,9 @@ local function deserialize(str) end return { - append = append; - serialize = serialize; + new = new; + serialize = function (x, opt) + return new(opt)(x); + end; deserialize = deserialize; }; diff --git a/util/set.lua b/util/set.lua index c136a522..02fabc6a 100644 --- a/util/set.lua +++ b/util/set.lua @@ -11,8 +11,9 @@ local ipairs, pairs, setmetatable, next, tostring = local t_concat = table.concat; local _ENV = nil; +-- luacheck: std none -local set_mt = {}; +local set_mt = { __name = "set" }; function set_mt.__call(set, _, k) return next(set._items, k); end @@ -22,6 +23,14 @@ function items_mt.__call(items, _, k) return next(items, k); end +function set_mt:__freeze() + local a, i = {}, 1; + for item in self._items do + a[i], i = item, i+1; + end + return a; +end + local function new(list) local items = setmetatable({}, items_mt); local set = { _items = items }; diff --git a/util/sql.lua b/util/sql.lua index d964025e..47900102 100644 --- a/util/sql.lua +++ b/util/sql.lua @@ -1,11 +1,11 @@ local setmetatable, getmetatable = setmetatable, getmetatable; -local ipairs, unpack, select = ipairs, table.unpack or unpack, select; --luacheck: ignore 113 -local tonumber, tostring = tonumber, tostring; +local ipairs = ipairs; +local tostring = tostring; local type = type; -local assert, pcall, xpcall, debug_traceback = assert, pcall, xpcall, debug.traceback; +local assert, pcall, debug_traceback = assert, pcall, debug.traceback; +local xpcall = require "util.xpcall".xpcall; local t_concat = table.concat; -local s_char = string.char; local log = require "util.logger".init("sql"); local DBI = require "DBI"; @@ -15,6 +15,7 @@ DBI.Drivers(); local build_url = require "socket.url".build; local _ENV = nil; +-- luacheck: std none local column_mt = {}; local table_mt = {}; @@ -58,9 +59,6 @@ table_mt.__index = {}; function table_mt.__index:create(engine) return engine:_create_table(self); end -function table_mt:__call(...) - -- TODO -end function column_mt:__tostring() return 'Column{ name="'..self.name..'", type="'..self.type..'" }' end @@ -71,31 +69,6 @@ function index_mt:__tostring() -- return 'Index{ name="'..self.name..'", type="'..self.type..'" }' end -local function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end)); end -local function parse_url(url) - local scheme, secondpart, database = url:match("^([%w%+]+)://([^/]*)/?(.*)"); - assert(scheme, "Invalid URL format"); - local username, password, host, port; - local authpart, hostpart = secondpart:match("([^@]+)@([^@+])"); - if not authpart then hostpart = secondpart; end - if authpart then - username, password = authpart:match("([^:]*):(.*)"); - username = username or authpart; - password = password and urldecode(password); - end - if hostpart then - host, port = hostpart:match("([^:]*):(.*)"); - host = host or hostpart; - port = port and assert(tonumber(port), "Invalid URL format"); - end - return { - scheme = scheme:lower(); - username = username; password = password; - host = host; port = port; - database = #database > 0 and database or nil; - }; -end - local engine = {}; function engine:connect() if self.conn then return true; end @@ -123,7 +96,7 @@ function engine:connect() end return true; end -function engine:onconnect() +function engine:onconnect() -- luacheck: ignore 212/self -- Override from create_engine() end @@ -148,6 +121,7 @@ function engine:execute(sql, ...) prepared[sql] = stmt; end + -- luacheck: ignore 411/success local success, err = stmt:execute(...); if not success then return success, err; end return stmt; @@ -161,14 +135,14 @@ local result_mt = { __index = { local function debugquery(where, sql, ...) local i = 0; local a = {...} sql = sql:gsub("\n?\t+", " "); - log("debug", "[%s] %s", where, sql:gsub("%?", function () + log("debug", "[%s] %s", where, (sql:gsub("%?", function () i = i + 1; local v = a[i]; if type(v) == "string" then v = ("'%s'"):format(v:gsub("'", "''")); end return tostring(v); - end)); + end))); end function engine:execute_query(sql, ...) @@ -227,11 +201,9 @@ function engine:_transaction(func, ...) if not ok then return ok, err; end end --assert(not self.__transaction, "Recursive transactions not allowed"); - local args, n_args = {...}, select("#", ...); - local function f() return func(unpack(args, 1, n_args)); end log("debug", "SQL transaction begin [%s]", tostring(func)); self.__transaction = true; - local success, a, b, c = xpcall(f, handleerr); + local success, a, b, c = xpcall(func, handleerr, ...); self.__transaction = nil; if success then log("debug", "SQL transaction success [%s]", tostring(func)); @@ -335,7 +307,12 @@ function engine:set_encoding() -- to UTF-8 local charset = "utf8"; if driver == "MySQL" then self:transaction(function() - for row in self:select"SELECT \"CHARACTER_SET_NAME\" FROM \"information_schema\".\"CHARACTER_SETS\" WHERE \"CHARACTER_SET_NAME\" LIKE 'utf8%' ORDER BY MAXLEN DESC LIMIT 1;" do + for row in self:select[[ + SELECT "CHARACTER_SET_NAME" + FROM "information_schema"."CHARACTER_SETS" + WHERE "CHARACTER_SET_NAME" LIKE 'utf8%' + ORDER BY MAXLEN DESC LIMIT 1; + ]] do charset = row and row[1] or charset; end end); @@ -379,7 +356,7 @@ local function db2uri(params) }; end -local function create_engine(self, params, onconnect) +local function create_engine(_, params, onconnect) return setmetatable({ url = db2uri(params), params = params, onconnect = onconnect }, engine_mt); end diff --git a/util/sslconfig.lua b/util/sslconfig.lua index 4c4e1d48..a5827a76 100644 --- a/util/sslconfig.lua +++ b/util/sslconfig.lua @@ -8,6 +8,7 @@ local t_insert = table.insert; local setmetatable = setmetatable; local _ENV = nil; +-- luacheck: std none local handlers = { }; local finalisers = { }; @@ -69,7 +70,7 @@ finalisers.curveslist = finalisers.ciphers; -- protocol = "x" should enable only that protocol -- protocol = "x+" should enable x and later versions -local protocols = { "sslv2", "sslv3", "tlsv1", "tlsv1_1", "tlsv1_2" }; +local protocols = { "sslv2", "sslv3", "tlsv1", "tlsv1_1", "tlsv1_2", "tlsv1_3" }; for i = 1, #protocols do protocols[protocols[i] .. "+"] = i - 1; end -- this interacts with ssl.options as well to add no_x diff --git a/util/stanza.lua b/util/stanza.lua index 07365144..8d199912 100644 --- a/util/stanza.lua +++ b/util/stanza.lua @@ -7,6 +7,7 @@ -- +local error = error; local t_insert = table.insert; local t_remove = table.remove; local t_concat = table.concat; @@ -23,6 +24,8 @@ local s_sub = string.sub; local s_find = string.find; local os = os; +local valid_utf8 = require "util.encodings".utf8.valid; + local do_pretty_printing = not os.getenv("WINDIR"); local getstyle, getstring; if do_pretty_printing then @@ -37,12 +40,52 @@ end local xmlns_stanzas = "urn:ietf:params:xml:ns:xmpp-stanzas"; local _ENV = nil; +-- luacheck: std none -local stanza_mt = { __type = "stanza" }; +local stanza_mt = { __name = "stanza" }; stanza_mt.__index = stanza_mt; -local function new_stanza(name, attr) - local stanza = { name = name, attr = attr or {}, tags = {} }; +local function check_name(name, name_type) + if type(name) ~= "string" then + error("invalid "..name_type.." name: expected string, got "..type(name)); + elseif #name == 0 then + error("invalid "..name_type.." name: empty string"); + elseif s_find(name, "[<>& '\"]") then + error("invalid "..name_type.." name: contains invalid characters"); + elseif not valid_utf8(name) then + error("invalid "..name_type.." name: contains invalid utf8"); + end +end + +local function check_text(text, text_type) + if type(text) ~= "string" then + error("invalid "..text_type.." value: expected string, got "..type(text)); + elseif not valid_utf8(text) then + error("invalid "..text_type.." value: contains invalid utf8"); + end +end + +local function check_attr(attr) + if attr ~= nil then + if type(attr) ~= "table" then + error("invalid attributes, expected table got "..type(attr)); + end + for k, v in pairs(attr) do + check_name(k, "attribute"); + check_text(v, "attribute"); + if type(v) ~= "string" then + error("invalid attribute value for '"..k.."': expected string, got "..type(v)); + elseif not valid_utf8(v) then + error("invalid attribute value for '"..k.."': contains invalid utf8"); + end + end + end +end + +local function new_stanza(name, attr, namespaces) + check_name(name, "tag"); + check_attr(attr); + local stanza = { name = name, attr = attr or {}, namespaces = namespaces, tags = {} }; return setmetatable(stanza, stanza_mt); end @@ -58,8 +101,12 @@ function stanza_mt:body(text, attr) return self:tag("body", attr):text(text); end -function stanza_mt:tag(name, attrs) - local s = new_stanza(name, attrs); +function stanza_mt:text_tag(name, text, attr, namespaces) + return self:tag(name, attr, namespaces):text(text):up(); +end + +function stanza_mt:tag(name, attr, namespaces) + local s = new_stanza(name, attr, namespaces); local last_add = self.last_add; if not last_add then last_add = {}; self.last_add = last_add; end (last_add[#last_add] or self):add_direct_child(s); @@ -68,8 +115,10 @@ function stanza_mt:tag(name, attrs) end function stanza_mt:text(text) - local last_add = self.last_add; - (last_add and last_add[#last_add] or self):add_direct_child(text); + if text ~= nil and text ~= "" then + local last_add = self.last_add; + (last_add and last_add[#last_add] or self):add_direct_child(text); + end return self; end @@ -85,10 +134,13 @@ function stanza_mt:reset() end function stanza_mt:add_direct_child(child) - if type(child) == "table" then + if is_stanza(child) then t_insert(self.tags, child); + t_insert(self, child); + else + check_text(child, "text"); + t_insert(self, child); end - t_insert(self, child); end function stanza_mt:add_child(child) @@ -165,6 +217,7 @@ end function stanza_mt:maptags(callback) local tags, curr_tag = self.tags, 1; local n_children, n_tags = #self, #tags; + local max_iterations = n_children + 1; local i = 1; while curr_tag <= n_tags and n_tags > 0 do @@ -184,6 +237,11 @@ function stanza_mt:maptags(callback) curr_tag = curr_tag + 1; end i = i + 1; + if i > max_iterations then + -- COMPAT: Hopefully temporary guard against #981 while we + -- figure out the root cause + error("Invalid stanza state! Please report this error."); + end end return self; end @@ -289,12 +347,6 @@ function stanza_mt.get_error(stanza) return error_type, condition or "undefined-condition", text; end -local id = 0; -local function new_id() - id = id + 1; - return "lx"..id; -end - local function preserialize(stanza) local s = { name = stanza.name, attr = stanza.attr }; for _, child in ipairs(stanza) do @@ -307,6 +359,8 @@ local function preserialize(stanza) return s; end +stanza_mt.__freeze = preserialize; + local function deserialize(stanza) -- Set metatable if stanza then @@ -344,14 +398,19 @@ local function deserialize(stanza) return stanza; end -local function clone(stanza) +local function _clone(stanza) local attr, tags = {}, {}; for k,v in pairs(stanza.attr) do attr[k] = v; end - local new = { name = stanza.name, attr = attr, tags = tags }; + local old_namespaces, namespaces = stanza.namespaces; + if old_namespaces then + namespaces = {}; + for k,v in pairs(old_namespaces) do namespaces[k] = v; end + end + local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags }; for i=1,#stanza do local child = stanza[i]; if child.name then - child = clone(child); + child = _clone(child); t_insert(tags, child); end t_insert(new, child); @@ -359,6 +418,13 @@ local function clone(stanza) return setmetatable(new, stanza_mt); end +local function clone(stanza) + if not is_stanza(stanza) then + error("bad argument to clone: expected stanza, got "..type(stanza)); + end + return _clone(stanza); +end + local function message(attr, body) if not body then return new_stanza("message", attr); @@ -367,12 +433,20 @@ local function message(attr, body) end end local function iq(attr) - if attr and not attr.id then attr.id = new_id(); end - return new_stanza("iq", attr or { id = new_id() }); + if not (attr and attr.id) then + error("iq stanzas require an id attribute"); + end + return new_stanza("iq", attr); end local function reply(orig) - return new_stanza(orig.name, orig.attr and { to = orig.attr.from, from = orig.attr.to, id = orig.attr.id, type = ((orig.name == "iq" and "result") or orig.attr.type) }); + return new_stanza(orig.name, + orig.attr and { + to = orig.attr.from, + from = orig.attr.to, + id = orig.attr.id, + type = ((orig.name == "iq" and "result") or orig.attr.type) + }); end local xmpp_stanzas_attr = { xmlns = xmlns_stanzas }; @@ -433,7 +507,6 @@ return { stanza_mt = stanza_mt; stanza = new_stanza; is_stanza = is_stanza; - new_id = new_id; preserialize = preserialize; deserialize = deserialize; clone = clone; diff --git a/util/startup.lua b/util/startup.lua new file mode 100644 index 00000000..e92867dc --- /dev/null +++ b/util/startup.lua @@ -0,0 +1,552 @@ +-- Ignore the CFG_* variables +-- luacheck: ignore 113/CFG_CONFIGDIR 113/CFG_SOURCEDIR 113/CFG_DATADIR 113/CFG_PLUGINDIR +local startup = {}; + +local prosody = { events = require "util.events".new() }; +local logger = require "util.logger"; +local log = logger.init("startup"); + +local config = require "core.configmanager"; + +local dependencies = require "util.dependencies"; + +local original_logging_config; + +function startup.read_config() + local filenames = {}; + + local filename; + if arg[1] == "--config" and arg[2] then + table.insert(filenames, arg[2]); + if CFG_CONFIGDIR then + table.insert(filenames, CFG_CONFIGDIR.."/"..arg[2]); + end + table.remove(arg, 1); table.remove(arg, 1); + elseif os.getenv("PROSODY_CONFIG") then -- Passed by prosodyctl + table.insert(filenames, os.getenv("PROSODY_CONFIG")); + else + table.insert(filenames, (CFG_CONFIGDIR or ".").."/prosody.cfg.lua"); + end + for _,_filename in ipairs(filenames) do + filename = _filename; + local file = io.open(filename); + if file then + file:close(); + prosody.config_file = filename; + CFG_CONFIGDIR = filename:match("^(.*)[\\/][^\\/]*$"); -- luacheck: ignore 111 + break; + end + end + prosody.config_file = filename + local ok, level, err = config.load(filename); + if not ok then + print("\n"); + print("**************************"); + if level == "parser" then + print("A problem occurred while reading the config file "..filename); + print(""); + local err_line, err_message = tostring(err):match("%[string .-%]:(%d*): (.*)"); + if err:match("chunk has too many syntax levels$") then + print("An Include statement in a config file is including an already-included"); + print("file and causing an infinite loop. An Include statement in a config file is..."); + else + print("Error"..(err_line and (" on line "..err_line) or "")..": "..(err_message or tostring(err))); + end + print(""); + elseif level == "file" then + print("Prosody was unable to find the configuration file."); + print("We looked for: "..filename); + print("A sample config file is included in the Prosody download called prosody.cfg.lua.dist"); + print("Copy or rename it to prosody.cfg.lua and edit as necessary."); + end + print("More help on configuring Prosody can be found at https://prosody.im/doc/configure"); + print("Good luck!"); + print("**************************"); + print(""); + os.exit(1); + end + prosody.config_loaded = true; +end + +function startup.check_dependencies() + if not dependencies.check_dependencies() then + os.exit(1); + end +end + +-- luacheck: globals socket server + +function startup.load_libraries() + -- Load socket framework + -- luacheck: ignore 111/server 111/socket + socket = require "socket"; + server = require "net.server" +end + +function startup.init_logging() + -- Initialize logging + local loggingmanager = require "core.loggingmanager" + loggingmanager.reload_logging(); + prosody.events.add_handler("reopen-log-files", function () + loggingmanager.reload_logging(); + prosody.events.fire_event("logging-reloaded"); + end); +end + +function startup.log_dependency_warnings() + dependencies.log_warnings(); +end + +function startup.sanity_check() + for host, host_config in pairs(config.getconfig()) do + if host ~= "*" + and host_config.enabled ~= false + and not host_config.component_module then + return; + end + end + log("error", "No enabled VirtualHost entries found in the config file."); + log("error", "At least one active host is required for Prosody to function. Exiting..."); + os.exit(1); +end + +function startup.sandbox_require() + -- Replace require() with one that doesn't pollute _G, required + -- for neat sandboxing of modules + -- luacheck: ignore 113/getfenv 111/require + local _realG = _G; + local _real_require = require; + local getfenv = getfenv or function (f) + -- FIXME: This is a hack to replace getfenv() in Lua 5.2 + local name, env = debug.getupvalue(debug.getinfo(f or 1).func, 1); + if name == "_ENV" then + return env; + end + end + function require(...) -- luacheck: ignore 121 + local curr_env = getfenv(2); + local curr_env_mt = getmetatable(curr_env); + local _realG_mt = getmetatable(_realG); + if curr_env_mt and curr_env_mt.__index and not curr_env_mt.__newindex and _realG_mt then + local old_newindex, old_index; + old_newindex, _realG_mt.__newindex = _realG_mt.__newindex, curr_env; + old_index, _realG_mt.__index = _realG_mt.__index, function (_G, k) -- luacheck: ignore 212/_G + return rawget(curr_env, k); + end; + local ret = _real_require(...); + _realG_mt.__newindex = old_newindex; + _realG_mt.__index = old_index; + return ret; + end + return _real_require(...); + end +end + +function startup.set_function_metatable() + local mt = {}; + function mt.__index(f, upvalue) + local i, name, value = 0; + repeat + i = i + 1; + name, value = debug.getupvalue(f, i); + until name == upvalue or name == nil; + return value; + end + function mt.__newindex(f, upvalue, value) + local i, name = 0; + repeat + i = i + 1; + name = debug.getupvalue(f, i); + until name == upvalue or name == nil; + if name then + debug.setupvalue(f, i, value); + end + end + function mt.__tostring(f) + local info = debug.getinfo(f); + return ("function(%s:%d)"):format(info.short_src:match("[^\\/]*$"), info.linedefined); + end + debug.setmetatable(function() end, mt); +end + +function startup.detect_platform() + prosody.platform = "unknown"; + if os.getenv("WINDIR") then + prosody.platform = "windows"; + elseif package.config:sub(1,1) == "/" then + prosody.platform = "posix"; + end +end + +function startup.detect_installed() + prosody.installed = nil; + if CFG_SOURCEDIR and (prosody.platform == "windows" or CFG_SOURCEDIR:match("^/")) then + prosody.installed = true; + end +end + +function startup.init_global_state() + -- luacheck: ignore 121 + prosody.bare_sessions = {}; + prosody.full_sessions = {}; + prosody.hosts = {}; + + -- COMPAT: These globals are deprecated + -- luacheck: ignore 111/bare_sessions 111/full_sessions 111/hosts + bare_sessions = prosody.bare_sessions; + full_sessions = prosody.full_sessions; + hosts = prosody.hosts; + + prosody.paths = { source = CFG_SOURCEDIR, config = CFG_CONFIGDIR or ".", + plugins = CFG_PLUGINDIR or "plugins", data = "data" }; + + prosody.arg = _G.arg; + + _G.log = logger.init("general"); + prosody.log = logger.init("general"); + + startup.detect_platform(); + startup.detect_installed(); + _G.prosody = prosody; +end + +function startup.setup_datadir() + prosody.paths.data = config.get("*", "data_path") or CFG_DATADIR or "data"; +end + +function startup.setup_plugindir() + local custom_plugin_paths = config.get("*", "plugin_paths"); + if custom_plugin_paths then + local path_sep = package.config:sub(3,3); + -- path1;path2;path3;defaultpath... + -- luacheck: ignore 111 + CFG_PLUGINDIR = table.concat(custom_plugin_paths, path_sep)..path_sep..(CFG_PLUGINDIR or "plugins"); + prosody.paths.plugins = CFG_PLUGINDIR; + end +end + +function startup.chdir() + if prosody.installed then + -- Change working directory to data path. + require "lfs".chdir(prosody.paths.data); + end +end + +function startup.add_global_prosody_functions() + -- Function to reload the config file + function prosody.reload_config() + log("info", "Reloading configuration file"); + prosody.events.fire_event("reloading-config"); + local ok, level, err = config.load(prosody.config_file); + if not ok then + if level == "parser" then + log("error", "There was an error parsing the configuration file: %s", tostring(err)); + elseif level == "file" then + log("error", "Couldn't read the config file when trying to reload: %s", tostring(err)); + end + else + prosody.events.fire_event("config-reloaded", { + filename = prosody.config_file, + config = config.getconfig(), + }); + end + return ok, (err and tostring(level)..": "..tostring(err)) or nil; + end + + -- Function to reopen logfiles + function prosody.reopen_logfiles() + log("info", "Re-opening log files"); + prosody.events.fire_event("reopen-log-files"); + end + + -- Function to initiate prosody shutdown + function prosody.shutdown(reason, code) + log("info", "Shutting down: %s", reason or "unknown reason"); + prosody.shutdown_reason = reason; + prosody.shutdown_code = code; + prosody.events.fire_event("server-stopping", { + reason = reason; + code = code; + }); + server.setquitting(true); + end +end + +function startup.load_secondary_libraries() + --- Load and initialise core modules + require "util.import" + require "util.xmppstream" + require "core.stanza_router" + require "core.statsmanager" + require "core.hostmanager" + require "core.portmanager" + require "core.modulemanager" + require "core.usermanager" + require "core.rostermanager" + require "core.sessionmanager" + package.loaded['core.componentmanager'] = setmetatable({},{__index=function() + -- COMPAT which version? + log("warn", "componentmanager is deprecated: %s", debug.traceback():match("\n[^\n]*\n[ \t]*([^\n]*)")); + return function() end + end}); + + require "util.array" + require "util.datetime" + require "util.iterators" + require "util.timer" + require "util.helpers" + + pcall(require, "util.signal") -- Not on Windows + + -- Commented to protect us from + -- the second kind of people + --[[ + pcall(require, "remdebug.engine"); + if remdebug then remdebug.engine.start() end + ]] + + require "util.stanza" + require "util.jid" +end + +function startup.init_http_client() + local http = require "net.http" + local config_ssl = config.get("*", "ssl") or {} + local https_client = config.get("*", "client_https_ssl") + http.default.options.sslctx = require "core.certmanager".create_context("client_https port 0", "client", + { capath = config_ssl.capath, cafile = config_ssl.cafile, verify = "peer", }, https_client); +end + +function startup.init_data_store() + require "core.storagemanager"; +end + +function startup.prepare_to_start() + log("info", "Prosody is using the %s backend for connection handling", server.get_backend()); + -- Signal to modules that we are ready to start + prosody.events.fire_event("server-starting"); + prosody.start_time = os.time(); +end + +function startup.init_global_protection() + -- Catch global accesses + -- luacheck: ignore 212/t + local locked_globals_mt = { + __index = function (t, k) log("warn", "%s", debug.traceback("Attempt to read a non-existent global '"..tostring(k).."'", 2)); end; + __newindex = function (t, k, v) error("Attempt to set a global: "..tostring(k).." = "..tostring(v), 2); end; + }; + + function prosody.unlock_globals() + setmetatable(_G, nil); + end + + function prosody.lock_globals() + setmetatable(_G, locked_globals_mt); + end + + -- And lock now... + prosody.lock_globals(); +end + +function startup.read_version() + -- Try to determine version + local version_file = io.open((CFG_SOURCEDIR or ".").."/prosody.version"); + prosody.version = "unknown"; + if version_file then + prosody.version = version_file:read("*a"):gsub("%s*$", ""); + version_file:close(); + if #prosody.version == 12 and prosody.version:match("^[a-f0-9]+$") then + prosody.version = "hg:"..prosody.version; + end + else + local hg = require"util.mercurial"; + local hgid = hg.check_id(CFG_SOURCEDIR or "."); + if hgid then prosody.version = "hg:" .. hgid; end + end +end + +function startup.log_greeting() + log("info", "Hello and welcome to Prosody version %s", prosody.version); +end + +function startup.notify_started() + prosody.events.fire_event("server-started"); +end + +-- Override logging config (used by prosodyctl) +function startup.force_console_logging() + original_logging_config = config.get("*", "log"); + config.set("*", "log", { { levels = { min = os.getenv("PROSODYCTL_LOG_LEVEL") or "info" }, to = "console" } }); +end + +function startup.switch_user() + -- Switch away from root and into the prosody user -- + -- NOTE: This function is only used by prosodyctl. + -- The prosody process is built with the assumption that + -- it is already started as the appropriate user. + + local want_pposix_version = "0.4.0"; + local have_pposix, pposix = pcall(require, "util.pposix"); + + if have_pposix and pposix then + if pposix._VERSION ~= want_pposix_version then + print(string.format("Unknown version (%s) of binary pposix module, expected %s", + tostring(pposix._VERSION), want_pposix_version)); + os.exit(1); + end + prosody.current_uid = pposix.getuid(); + local arg_root = arg[1] == "--root"; + if arg_root then table.remove(arg, 1); end + if prosody.current_uid == 0 and config.get("*", "run_as_root") ~= true and not arg_root then + -- We haz root! + local desired_user = config.get("*", "prosody_user") or "prosody"; + local desired_group = config.get("*", "prosody_group") or desired_user; + local ok, err = pposix.setgid(desired_group); + if ok then + ok, err = pposix.initgroups(desired_user); + end + if ok then + ok, err = pposix.setuid(desired_user); + if ok then + -- Yay! + prosody.switched_user = true; + end + end + if not prosody.switched_user then + -- Boo! + print("Warning: Couldn't switch to Prosody user/group '"..tostring(desired_user).."'/'"..tostring(desired_group).."': "..tostring(err)); + else + -- Make sure the Prosody user can read the config + local conf, err, errno = io.open(prosody.config_file); + if conf then + conf:close(); + else + print("The config file is not readable by the '"..desired_user.."' user."); + print("Prosody will not be able to read it."); + print("Error was "..err); + os.exit(1); + end + end + end + + -- Set our umask to protect data files + pposix.umask(config.get("*", "umask") or "027"); + pposix.setenv("HOME", prosody.paths.data); + pposix.setenv("PROSODY_CONFIG", prosody.config_file); + else + print("Error: Unable to load pposix module. Check that Prosody is installed correctly.") + print("For more help send the below error to us through https://prosody.im/discuss"); + print(tostring(pposix)) + os.exit(1); + end +end + +function startup.check_unwriteable() + local function test_writeable(filename) + local f, err = io.open(filename, "a"); + if not f then + return false, err; + end + f:close(); + return true; + end + + local unwriteable_files = {}; + if type(original_logging_config) == "string" and original_logging_config:sub(1,1) ~= "*" then + local ok, err = test_writeable(original_logging_config); + if not ok then + table.insert(unwriteable_files, err); + end + elseif type(original_logging_config) == "table" then + for _, rule in ipairs(original_logging_config) do + if rule.filename then + local ok, err = test_writeable(rule.filename); + if not ok then + table.insert(unwriteable_files, err); + end + end + end + end + + if #unwriteable_files > 0 then + print("One of more of the Prosody log files are not"); + print("writeable, please correct the errors and try"); + print("starting prosodyctl again."); + print(""); + for _, err in ipairs(unwriteable_files) do + print(err); + end + print(""); + os.exit(1); + end +end + +function startup.make_host(hostname) + return { + type = "local", + events = prosody.events, + modules = {}, + sessions = {}, + users = require "core.usermanager".new_null_provider(hostname) + }; +end + +function startup.make_dummy_hosts() + -- When running under prosodyctl, we don't want to + -- fully initialize the server, so we populate prosody.hosts + -- with just enough things for most code to work correctly + -- luacheck: ignore 122/hosts + prosody.core_post_stanza = function () end; -- TODO: mod_router! + + for hostname in pairs(config.getconfig()) do + prosody.hosts[hostname] = startup.make_host(hostname); + end +end + +-- prosodyctl only +function startup.prosodyctl() + startup.init_global_state(); + startup.read_config(); + startup.force_console_logging(); + startup.init_logging(); + startup.setup_plugindir(); + startup.setup_datadir(); + startup.chdir(); + startup.read_version(); + startup.switch_user(); + startup.check_dependencies(); + startup.log_dependency_warnings(); + startup.check_unwriteable(); + startup.load_libraries(); + startup.init_http_client(); + startup.make_dummy_hosts(); +end + +function startup.prosody() + -- These actions are in a strict order, as many depend on + -- previous steps to have already been performed + startup.init_global_state(); + startup.read_config(); + startup.init_logging(); + startup.sanity_check(); + startup.sandbox_require(); + startup.set_function_metatable(); + startup.check_dependencies(); + startup.init_logging(); + startup.load_libraries(); + startup.setup_plugindir(); + startup.setup_datadir(); + startup.chdir(); + startup.add_global_prosody_functions(); + startup.read_version(); + startup.log_greeting(); + startup.log_dependency_warnings(); + startup.load_secondary_libraries(); + startup.init_http_client(); + startup.init_data_store(); + startup.init_global_protection(); + startup.prepare_to_start(); + startup.notify_started(); +end + +return startup; diff --git a/util/template.lua b/util/template.lua index 04ebb93d..c11037c5 100644 --- a/util/template.lua +++ b/util/template.lua @@ -4,12 +4,13 @@ local setmetatable = setmetatable; local pairs = pairs; local ipairs = ipairs; local error = error; -local loadstring = loadstring; +local envload = require "util.envload".envload; local debug = debug; local t_remove = table.remove; local parse_xml = require "util.xml".parse; local _ENV = nil; +-- luacheck: std none local function trim_xml(stanza) for i=#stanza,1,-1 do @@ -72,7 +73,7 @@ local function create_cloner(stanza, chunkname) src = src.."local _"..i.."="..lookup[i]..";"; end src = src.."return "..name..";end"; - local f,err = loadstring(src, chunkname); + local f,err = envload(src, chunkname); if not f then error(err); end return f(setmetatable, stanza_mt); end diff --git a/util/termcolours.lua b/util/termcolours.lua index 23c9156b..829d84af 100644 --- a/util/termcolours.lua +++ b/util/termcolours.lua @@ -26,6 +26,7 @@ end local orig_color = windows and windows.get_consolecolor and windows.get_consolecolor(); local _ENV = nil; +-- luacheck: std none local stylemap = { reset = 0; bright = 1, dim = 2, underscore = 4, blink = 5, reverse = 7, hidden = 8; diff --git a/util/throttle.lua b/util/throttle.lua index 1012f78a..d2036e9e 100644 --- a/util/throttle.lua +++ b/util/throttle.lua @@ -3,6 +3,7 @@ local gettime = require "util.time".now local setmetatable = setmetatable; local _ENV = nil; +-- luacheck: std none local throttle = {}; local throttle_mt = { __index = throttle }; diff --git a/util/time.lua b/util/time.lua deleted file mode 100644 index 84cff877..00000000 --- a/util/time.lua +++ /dev/null @@ -1,8 +0,0 @@ --- Import gettime() from LuaSocket, as a way to access high-resolution time --- in a platform-independent way - -local socket_gettime = require "socket".gettime; - -return { - now = socket_gettime; -} diff --git a/util/timer.lua b/util/timer.lua index 7e2e9414..4670e196 100644 --- a/util/timer.lua +++ b/util/timer.lua @@ -6,78 +6,102 @@ -- COPYING file in the source package for more information. -- +local indexedbheap = require "util.indexedbheap"; +local log = require "util.logger".init("timer"); local server = require "net.server"; -local math_min = math.min -local math_huge = math.huge local get_time = require "util.time".now -local t_insert = table.insert; -local pairs = pairs; local type = type; - -local data = {}; -local new_data = {}; +local debug_traceback = debug.traceback; +local tostring = tostring; +local xpcall = require "util.xpcall".xpcall; +local math_max = math.max; local _ENV = nil; +-- luacheck: std none -local _add_task; -if not server.event then - function _add_task(delay, callback) - local current_time = get_time(); - delay = delay + current_time; - if delay >= current_time then - t_insert(new_data, {delay, callback}); - else - local r = callback(current_time); - if r and type(r) == "number" then - return _add_task(r, callback); - end +local _add_task = server.add_task; + +local _server_timer; +local _active_timers = 0; +local h = indexedbheap.create(); +local params = {}; +local next_time = nil; +local function _traceback_handler(err) log("error", "Traceback[timer]: %s", debug_traceback(tostring(err), 2)); end +local function _on_timer(now) + local peek; + while true do + peek = h:peek(); + if peek == nil or peek > now then break; end + local _, callback, id = h:pop(); + local param = params[id]; + params[id] = nil; + --item(now, id, _param); + local success, err = xpcall(callback, _traceback_handler, now, id, param); + if success and type(err) == "number" then + h:insert(callback, err + now, id); -- re-add + params[id] = param; end end - server._addtimer(function() - local current_time = get_time(); - if #new_data > 0 then - for _, d in pairs(new_data) do - t_insert(data, d); - end - new_data = {}; - end + if peek ~= nil and _active_timers > 1 and peek == next_time then + -- Another instance of _on_timer already set next_time to the same value, + -- so it should be safe to not renew this timer event + peek = nil; + else + next_time = peek; + end - local next_time = math_huge; - for i, d in pairs(data) do - local t, callback = d[1], d[2]; - if t <= current_time then - data[i] = nil; - local r = callback(current_time); - if type(r) == "number" then - _add_task(r, callback); - next_time = math_min(next_time, r); - end - else - next_time = math_min(next_time, t - current_time); - end - end - return next_time; - end); -else - local event = server.event; - local event_base = server.event_base; - local EVENT_LEAVE = (event.core and event.core.LEAVE) or -1; + if peek then + -- peek is the time of the next event + return peek - now; + end + _active_timers = _active_timers - 1; +end +local function add_task(delay, callback, param) + local current_time = get_time(); + local event_time = current_time + delay; - function _add_task(delay, callback) - local event_handle; - event_handle = event_base:addevent(nil, 0, function () - local ret = callback(get_time()); - if ret then - return 0, ret; - elseif event_handle then - return EVENT_LEAVE; - end + local id = h:insert(callback, event_time); + params[id] = param; + if next_time == nil or event_time < next_time then + next_time = event_time; + if _server_timer then + _server_timer:close(); + _server_timer = nil; + else + _active_timers = _active_timers + 1; + end + _server_timer = _add_task(next_time - current_time, _on_timer); + end + return id; +end +local function stop(id) + params[id] = nil; + local result, item, result_sync = h:remove(id); + local peek = h:peek(); + if peek ~= next_time and _server_timer then + next_time = peek; + _server_timer:close(); + if next_time ~= nil then + _server_timer = _add_task(math_max(next_time - get_time(), 0), _on_timer); end - , delay); end + return result, item, result_sync; +end +local function reschedule(id, delay) + local current_time = get_time(); + local event_time = current_time + delay; + h:reprioritize(id, delay); + if next_time == nil or event_time < next_time then + next_time = event_time; + _add_task(next_time - current_time, _on_timer); + end + return id; end return { - add_task = _add_task; + add_task = add_task; + stop = stop; + reschedule = reschedule; }; + diff --git a/util/vcard.lua b/util/vcard.lua new file mode 100644 index 00000000..bb299fab --- /dev/null +++ b/util/vcard.lua @@ -0,0 +1,574 @@ +-- Copyright (C) 2011-2014 Kim Alvefur +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +-- TODO +-- Fix folding. + +local st = require "util.stanza"; +local t_insert, t_concat = table.insert, table.concat; +local type = type; +local pairs, ipairs = pairs, ipairs; + +local from_text, to_text, from_xep54, to_xep54; + +local line_sep = "\n"; + +local vCard_dtd; -- See end of file +local vCard4_dtd; + +local function vCard_esc(s) + return s:gsub("[,:;\\]", "\\%1"):gsub("\n","\\n"); +end + +local function vCard_unesc(s) + return s:gsub("\\?[\\nt:;,]", { + ["\\\\"] = "\\", + ["\\n"] = "\n", + ["\\r"] = "\r", + ["\\t"] = "\t", + ["\\:"] = ":", -- FIXME Shouldn't need to espace : in values, just params + ["\\;"] = ";", + ["\\,"] = ",", + [":"] = "\29", + [";"] = "\30", + [","] = "\31", + }); +end + +local function item_to_xep54(item) + local t = st.stanza(item.name, { xmlns = "vcard-temp" }); + + local prop_def = vCard_dtd[item.name]; + if prop_def == "text" then + t:text(item[1]); + elseif type(prop_def) == "table" then + if prop_def.types and item.TYPE then + if type(item.TYPE) == "table" then + for _,v in pairs(prop_def.types) do + for _,typ in pairs(item.TYPE) do + if typ:upper() == v then + t:tag(v):up(); + break; + end + end + end + else + t:tag(item.TYPE:upper()):up(); + end + end + + if prop_def.props then + for _,prop in pairs(prop_def.props) do + if item[prop] then + for _, v in ipairs(item[prop]) do + t:text_tag(prop, v); + end + end + end + end + + if prop_def.value then + t:text_tag(prop_def.value, item[1]); + elseif prop_def.values then + local prop_def_values = prop_def.values; + local repeat_last = prop_def_values.behaviour == "repeat-last" and prop_def_values[#prop_def_values]; + for i=1,#item do + t:text_tag(prop_def.values[i] or repeat_last, item[i]); + end + end + end + + return t; +end + +local function vcard_to_xep54(vCard) + local t = st.stanza("vCard", { xmlns = "vcard-temp" }); + for i=1,#vCard do + t:add_child(item_to_xep54(vCard[i])); + end + return t; +end + +function to_xep54(vCards) + if not vCards[1] or vCards[1].name then + return vcard_to_xep54(vCards) + else + local t = st.stanza("xCard", { xmlns = "vcard-temp" }); + for i=1,#vCards do + t:add_child(vcard_to_xep54(vCards[i])); + end + return t; + end +end + +function from_text(data) + data = data -- unfold and remove empty lines + :gsub("\r\n","\n") + :gsub("\n ", "") + :gsub("\n\n+","\n"); + local vCards = {}; + local current; + for line in data:gmatch("[^\n]+") do + line = vCard_unesc(line); + local name, params, value = line:match("^([-%a]+)(\30?[^\29]*)\29(.*)$"); + value = value:gsub("\29",":"); + if #params > 0 then + local _params = {}; + for k,isval,v in params:gmatch("\30([^=]+)(=?)([^\30]*)") do + k = k:upper(); + local _vt = {}; + for _p in v:gmatch("[^\31]+") do + _vt[#_vt+1]=_p + _vt[_p]=true; + end + if isval == "=" then + _params[k]=_vt; + else + _params[k]=true; + end + end + params = _params; + end + if name == "BEGIN" and value == "VCARD" then + current = {}; + vCards[#vCards+1] = current; + elseif name == "END" and value == "VCARD" then + current = nil; + elseif current and vCard_dtd[name] then + local dtd = vCard_dtd[name]; + local item = { name = name }; + t_insert(current, item); + local up = current; + current = item; + if dtd.types then + for _, t in ipairs(dtd.types) do + t = t:lower(); + if ( params.TYPE and params.TYPE[t] == true) + or params[t] == true then + current.TYPE=t; + end + end + end + if dtd.props then + for _, p in ipairs(dtd.props) do + if params[p] then + if params[p] == true then + current[p]=true; + else + for _, prop in ipairs(params[p]) do + current[p]=prop; + end + end + end + end + end + if dtd == "text" or dtd.value then + t_insert(current, value); + elseif dtd.values then + for p in ("\30"..value):gmatch("\30([^\30]*)") do + t_insert(current, p); + end + end + current = up; + end + end + return vCards; +end + +local function item_to_text(item) + local value = {}; + for i=1,#item do + value[i] = vCard_esc(item[i]); + end + value = t_concat(value, ";"); + + local params = ""; + for k,v in pairs(item) do + if type(k) == "string" and k ~= "name" then + params = params .. (";%s=%s"):format(k, type(v) == "table" and t_concat(v,",") or v); + end + end + + return ("%s%s:%s"):format(item.name, params, value) +end + +local function vcard_to_text(vcard) + local t={}; + t_insert(t, "BEGIN:VCARD") + for i=1,#vcard do + t_insert(t, item_to_text(vcard[i])); + end + t_insert(t, "END:VCARD") + return t_concat(t, line_sep); +end + +function to_text(vCards) + if vCards[1] and vCards[1].name then + return vcard_to_text(vCards) + else + local t = {}; + for i=1,#vCards do + t[i]=vcard_to_text(vCards[i]); + end + return t_concat(t, line_sep); + end +end + +local function from_xep54_item(item) + local prop_name = item.name; + local prop_def = vCard_dtd[prop_name]; + + local prop = { name = prop_name }; + + if prop_def == "text" then + prop[1] = item:get_text(); + elseif type(prop_def) == "table" then + if prop_def.value then --single item + prop[1] = item:get_child_text(prop_def.value) or ""; + elseif prop_def.values then --array + local value_names = prop_def.values; + if value_names.behaviour == "repeat-last" then + for i=1,#item.tags do + t_insert(prop, item.tags[i]:get_text() or ""); + end + else + for i=1,#value_names do + t_insert(prop, item:get_child_text(value_names[i]) or ""); + end + end + elseif prop_def.names then + local names = prop_def.names; + for i=1,#names do + if item:get_child(names[i]) then + prop[1] = names[i]; + break; + end + end + end + + if prop_def.props_verbatim then + for k,v in pairs(prop_def.props_verbatim) do + prop[k] = v; + end + end + + if prop_def.types then + local types = prop_def.types; + prop.TYPE = {}; + for i=1,#types do + if item:get_child(types[i]) then + t_insert(prop.TYPE, types[i]:lower()); + end + end + if #prop.TYPE == 0 then + prop.TYPE = nil; + end + end + + -- A key-value pair, within a key-value pair? + if prop_def.props then + local params = prop_def.props; + for i=1,#params do + local name = params[i] + local data = item:get_child_text(name); + if data then + prop[name] = prop[name] or {}; + t_insert(prop[name], data); + end + end + end + else + return nil + end + + return prop; +end + +local function from_xep54_vCard(vCard) + local tags = vCard.tags; + local t = {}; + for i=1,#tags do + t_insert(t, from_xep54_item(tags[i])); + end + return t +end + +function from_xep54(vCard) + if vCard.attr.xmlns ~= "vcard-temp" then + return nil, "wrong-xmlns"; + end + if vCard.name == "xCard" then -- A collection of vCards + local t = {}; + local vCards = vCard.tags; + for i=1,#vCards do + t[i] = from_xep54_vCard(vCards[i]); + end + return t + elseif vCard.name == "vCard" then -- A single vCard + return from_xep54_vCard(vCard) + end +end + +local vcard4 = { } + +function vcard4:text(node, params, value) -- luacheck: ignore 212/params + self:tag(node:lower()) + -- FIXME params + if type(value) == "string" then + self:text_tag("text", value); + elseif vcard4[node] then + vcard4[node](value); + end + self:up(); +end + +function vcard4.N(value) + for i, k in ipairs(vCard_dtd.N.values) do + value:text_tag(k, value[i]); + end +end + +local xmlns_vcard4 = "urn:ietf:params:xml:ns:vcard-4.0" + +local function item_to_vcard4(item) + local typ = item.name:lower(); + local t = st.stanza(typ, { xmlns = xmlns_vcard4 }); + + local prop_def = vCard4_dtd[typ]; + if prop_def == "text" then + t:text_tag("text", item[1]); + elseif prop_def == "uri" then + if item.ENCODING and item.ENCODING[1] == 'b' then + t:text_tag("uri", "data:;base64," .. item[1]); + else + t:text_tag("uri", item[1]); + end + elseif type(prop_def) == "table" then + if prop_def.values then + for i, v in ipairs(prop_def.values) do + t:text_tag(v:lower(), item[i]); + end + else + t:tag("unsupported",{xmlns="http://zash.se/protocol/vcardlib"}) + end + else + t:tag("unsupported",{xmlns="http://zash.se/protocol/vcardlib"}) + end + return t; +end + +local function vcard_to_vcard4xml(vCard) + local t = st.stanza("vcard", { xmlns = xmlns_vcard4 }); + for i=1,#vCard do + t:add_child(item_to_vcard4(vCard[i])); + end + return t; +end + +local function vcards_to_vcard4xml(vCards) + if not vCards[1] or vCards[1].name then + return vcard_to_vcard4xml(vCards) + else + local t = st.stanza("vcards", { xmlns = xmlns_vcard4 }); + for i=1,#vCards do + t:add_child(vcard_to_vcard4xml(vCards[i])); + end + return t; + end +end + +-- This was adapted from http://xmpp.org/extensions/xep-0054.html#dtd +vCard_dtd = { + VERSION = "text", --MUST be 3.0, so parsing is redundant + FN = "text", + N = { + values = { + "FAMILY", + "GIVEN", + "MIDDLE", + "PREFIX", + "SUFFIX", + }, + }, + NICKNAME = "text", + PHOTO = { + props_verbatim = { ENCODING = { "b" } }, + props = { "TYPE" }, + value = "BINVAL", --{ "EXTVAL", }, + }, + BDAY = "text", + ADR = { + types = { + "HOME", + "WORK", + "POSTAL", + "PARCEL", + "DOM", + "INTL", + "PREF", + }, + values = { + "POBOX", + "EXTADD", + "STREET", + "LOCALITY", + "REGION", + "PCODE", + "CTRY", + } + }, + LABEL = { + types = { + "HOME", + "WORK", + "POSTAL", + "PARCEL", + "DOM", + "INTL", + "PREF", + }, + value = "LINE", + }, + TEL = { + types = { + "HOME", + "WORK", + "VOICE", + "FAX", + "PAGER", + "MSG", + "CELL", + "VIDEO", + "BBS", + "MODEM", + "ISDN", + "PCS", + "PREF", + }, + value = "NUMBER", + }, + EMAIL = { + types = { + "HOME", + "WORK", + "INTERNET", + "PREF", + "X400", + }, + value = "USERID", + }, + JABBERID = "text", + MAILER = "text", + TZ = "text", + GEO = { + values = { + "LAT", + "LON", + }, + }, + TITLE = "text", + ROLE = "text", + LOGO = "copy of PHOTO", + AGENT = "text", + ORG = { + values = { + behaviour = "repeat-last", + "ORGNAME", + "ORGUNIT", + } + }, + CATEGORIES = { + values = "KEYWORD", + }, + NOTE = "text", + PRODID = "text", + REV = "text", + SORTSTRING = "text", + SOUND = "copy of PHOTO", + UID = "text", + URL = "text", + CLASS = { + names = { -- The item.name is the value if it's one of these. + "PUBLIC", + "PRIVATE", + "CONFIDENTIAL", + }, + }, + KEY = { + props = { "TYPE" }, + value = "CRED", + }, + DESC = "text", +}; +vCard_dtd.LOGO = vCard_dtd.PHOTO; +vCard_dtd.SOUND = vCard_dtd.PHOTO; + +vCard4_dtd = { + source = "uri", + kind = "text", + xml = "text", + fn = "text", + n = { + values = { + "family", + "given", + "middle", + "prefix", + "suffix", + }, + }, + nickname = "text", + photo = "uri", + bday = "date-and-or-time", + anniversary = "date-and-or-time", + gender = "text", + adr = { + values = { + "pobox", + "ext", + "street", + "locality", + "region", + "code", + "country", + } + }, + tel = "text", + email = "text", + impp = "uri", + lang = "language-tag", + tz = "text", + geo = "uri", + title = "text", + role = "text", + logo = "uri", + org = "text", + member = "uri", + related = "uri", + categories = "text", + note = "text", + prodid = "text", + rev = "timestamp", + sound = "uri", + uid = "uri", + clientpidmap = "number, uuid", + url = "uri", + version = "text", + key = "uri", + fburl = "uri", + caladruri = "uri", + caluri = "uri", +}; + +return { + from_text = from_text; + to_text = to_text; + + from_xep54 = from_xep54; + to_xep54 = to_xep54; + + to_vcard4 = vcards_to_vcard4xml; +}; diff --git a/util/watchdog.lua b/util/watchdog.lua index aa8c6486..516e60e4 100644 --- a/util/watchdog.lua +++ b/util/watchdog.lua @@ -3,6 +3,7 @@ local setmetatable = setmetatable; local os_time = os.time; local _ENV = nil; +-- luacheck: std none local watchdog_methods = {}; local watchdog_mt = { __index = watchdog_methods }; diff --git a/util/x509.lua b/util/x509.lua index f228b201..15cc4d3c 100644 --- a/util/x509.lua +++ b/util/x509.lua @@ -25,6 +25,7 @@ local log = require "util.logger".init("x509"); local s_format = string.format; local _ENV = nil; +-- luacheck: std none local oid_commonname = "2.5.4.3"; -- [LDAP] 2.3 local oid_subjectaltname = "2.5.29.17"; -- [PKIX] 4.2.1.6 diff --git a/util/xml.lua b/util/xml.lua index 733d821a..dac3f6fe 100644 --- a/util/xml.lua +++ b/util/xml.lua @@ -1,8 +1,11 @@ local st = require "util.stanza"; local lxp = require "lxp"; +local t_insert = table.insert; +local t_remove = table.remove; local _ENV = nil; +-- luacheck: std none local parse_xml = (function() local ns_prefixes = { @@ -14,6 +17,21 @@ local parse_xml = (function() --luacheck: ignore 212/self local handler = {}; local stanza = st.stanza("root"); + local namespaces = {}; + local prefixes = {}; + function handler:StartNamespaceDecl(prefix, url) + if prefix ~= nil then + t_insert(namespaces, url); + t_insert(prefixes, prefix); + end + end + function handler:EndNamespaceDecl(prefix) + if prefix ~= nil then + -- we depend on each StartNamespaceDecl having a paired EndNamespaceDecl + t_remove(namespaces); + t_remove(prefixes); + end + end function handler:StartElement(tagname, attr) local curr_ns,name = tagname:match(ns_pattern); if name == "" then @@ -34,7 +52,11 @@ local parse_xml = (function() end end end - stanza:tag(name, attr); + local n = {} + for i=1,#namespaces do + n[prefixes[i]] = namespaces[i]; + end + stanza:tag(name, attr, n); end function handler:CharacterData(data) stanza:text(data); diff --git a/util/xmppstream.lua b/util/xmppstream.lua index 7be63285..58cbd18e 100644 --- a/util/xmppstream.lua +++ b/util/xmppstream.lua @@ -25,6 +25,7 @@ local lxp_supports_bytecount = not not lxp.new({}).getcurrentbytecount; local default_stanza_size_limit = 1024*1024*10; -- 10MB local _ENV = nil; +-- luacheck: std none local new_parser = lxp.new; @@ -47,7 +48,10 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress) local cb_streamopened = stream_callbacks.streamopened; local cb_streamclosed = stream_callbacks.streamclosed; - local cb_error = stream_callbacks.error or function(session, e, stanza) error("XML stream error: "..tostring(e)..(stanza and ": "..tostring(stanza) or ""),2); end; + local cb_error = stream_callbacks.error or + function(_, e, stanza) + error("XML stream error: "..tostring(e)..(stanza and ": "..tostring(stanza) or ""),2); + end; local cb_handlestanza = stream_callbacks.handlestanza; cb_handleprogress = cb_handleprogress or dummy_cb; @@ -126,13 +130,7 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress) t_insert(oldstanza.tags, stanza); end end - if lxp_supports_xmldecl then - function xml_handlers:XmlDecl(version, encoding, standalone) - if lxp_supports_bytecount then - cb_handleprogress(self:getcurrentbytecount()); - end - end - end + function xml_handlers:StartCdataSection() if lxp_supports_bytecount then if stanza then @@ -203,6 +201,18 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress) end end + if lxp_supports_xmldecl then + function xml_handlers:XmlDecl(version, encoding, standalone) + if lxp_supports_bytecount then + cb_handleprogress(self:getcurrentbytecount()); + end + if (encoding and encoding:lower() ~= "utf-8") + or (standalone == "no") + or (version and version ~= "1.0") then + return restricted_handler(self); + end + end + end if lxp_supports_doctype then xml_handlers.StartDoctypeDecl = restricted_handler; end @@ -214,7 +224,7 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress) stack = {}; end - local function set_session(stream, new_session) + local function set_session(stream, new_session) -- luacheck: ignore 212/stream session = new_session; end @@ -238,7 +248,7 @@ local function new(session, stream_callbacks, stanza_size_limit) local parser = new_parser(handlers, ns_separator, false); local parse = parser.parse; - function session.open_stream(session, from, to) + function session.open_stream(session, from, to) -- luacheck: ignore 432/session local send = session.sends2s or session.send; local attr = { @@ -264,14 +274,19 @@ local function new(session, stream_callbacks, stanza_size_limit) n_outstanding_bytes = 0; meta.reset(); end, - feed = function (self, data) + feed = function (self, data) -- luacheck: ignore 212/self if lxp_supports_bytecount then n_outstanding_bytes = n_outstanding_bytes + #data; end - local ok, err = parse(parser, data); + local _parser = parser; + local ok, err = parse(_parser, data); if lxp_supports_bytecount and n_outstanding_bytes > stanza_size_limit then return nil, "stanza-too-large"; end + if parser ~= _parser then + _parser:parse(); + _parser:close(); + end return ok, err; end, set_session = meta.set_session; diff --git a/util/xpcall.lua b/util/xpcall.lua new file mode 100644 index 00000000..d2fc5011 --- /dev/null +++ b/util/xpcall.lua @@ -0,0 +1,9 @@ +local xpcall = xpcall; + +if select(2, xpcall(function (x) return x end, function () end, "test")) ~= "test" then + xpcall = require"util.compat".xpcall; +end + +return { + xpcall = xpcall; +}; |