diff options
Diffstat (limited to 'util')
50 files changed, 2196 insertions, 358 deletions
diff --git a/util/adhoc.lua b/util/adhoc.lua index 17c9eee5..d81b8242 100644 --- a/util/adhoc.lua +++ b/util/adhoc.lua @@ -1,3 +1,5 @@ +-- luacheck: ignore 212/self + local function new_simple_form(form, result_handler) return function(self, data, state) if state then diff --git a/util/array.lua b/util/array.lua index 150b4355..1a8ffec7 100644 --- a/util/array.lua +++ b/util/array.lua @@ -19,7 +19,7 @@ local type = type; local array = {}; local array_base = {}; local array_methods = {}; -local array_mt = { __index = array_methods, __tostring = function (self) return "{"..self:concat(", ").."}"; end }; +local array_mt = { __index = array_methods, __name = "array", __tostring = function (self) return "{"..self:concat(", ").."}"; end }; local function new_array(self, t, _s, _var) if type(t) == "function" then -- Assume iterator diff --git a/util/async.lua b/util/async.lua new file mode 100644 index 00000000..46a64ecb --- /dev/null +++ b/util/async.lua @@ -0,0 +1,252 @@ +local logger = require "util.logger"; +local log = logger.init("util.async"); +local new_id = require "util.id".short; + +local function checkthread() + local thread, main = coroutine.running(); + if not thread or main then + error("Not running in an async context, see https://prosody.im/doc/developers/util/async"); + end + return thread; +end + +local function runner_from_thread(thread) + local level = 0; + -- Find the 'level' of the top-most function (0 == current level, 1 == caller, ...) + while debug.getinfo(thread, level, "") do level = level + 1; end + local name, runner = debug.getlocal(thread, level-1, 1); + if name ~= "self" or type(runner) ~= "table" or runner.thread ~= thread then + return nil; + end + return runner; +end + +local function call_watcher(runner, watcher_name, ...) + local watcher = runner.watchers[watcher_name]; + if not watcher then + return false; + end + runner:log("debug", "Calling '%s' watcher", watcher_name); + local ok, err = pcall(watcher, runner, ...); -- COMPAT: Switch to xpcall after Lua 5.1 + if not ok then + runner:log("error", "Error in '%s' watcher: %s", watcher_name, err); + return nil, err; + end + return true; +end + +local function runner_continue(thread) + -- ASSUMPTION: runner is in 'waiting' state (but we don't have the runner to know for sure) + if coroutine.status(thread) ~= "suspended" then -- This should suffice + log("error", "unexpected async state: thread not suspended"); + return false; + end + local ok, state, runner = coroutine.resume(thread); + if not ok then + local err = state; + -- Running the coroutine failed, which means we have to find the runner manually, + -- in order to inform the error handler + runner = runner_from_thread(thread); + if not runner then + log("error", "unexpected async state: unable to locate runner during error handling"); + return false; + end + call_watcher(runner, "error", debug.traceback(thread, err)); + runner.state, runner.thread = "ready", nil; + return runner:run(); + elseif state == "ready" then + -- If state is 'ready', it is our responsibility to update runner.state from 'waiting'. + -- We also have to :run(), because the queue might have further items that will not be + -- processed otherwise. FIXME: It's probably best to do this in a nexttick (0 timer). + runner.state = "ready"; + runner:run(); + end + return true; +end + +local function waiter(num) + local thread = checkthread(); + num = num or 1; + local waiting; + return function () + if num == 0 then return; end -- already done + waiting = true; + coroutine.yield("wait"); + end, function () + num = num - 1; + if num == 0 and waiting then + runner_continue(thread); + elseif num < 0 then + error("done() called too many times"); + end + end; +end + +local function guarder() + local guards = {}; + local default_id = {}; + return function (id, func) + id = id or default_id; + local thread = checkthread(); + local guard = guards[id]; + if not guard then + guard = {}; + guards[id] = guard; + log("debug", "New guard!"); + else + table.insert(guard, thread); + log("debug", "Guarded. %d threads waiting.", #guard) + coroutine.yield("wait"); + end + local function exit() + local next_waiting = table.remove(guard, 1); + if next_waiting then + log("debug", "guard: Executing next waiting thread (%d left)", #guard) + runner_continue(next_waiting); + else + log("debug", "Guard off duty.") + guards[id] = nil; + end + end + if func then + func(); + exit(); + return; + end + return exit; + end; +end + +local runner_mt = {}; +runner_mt.__index = runner_mt; + +local function runner_create_thread(func, self) + local thread = coroutine.create(function (self) -- luacheck: ignore 432/self + while true do + func(coroutine.yield("ready", self)); + end + end); + assert(coroutine.resume(thread, self)); -- Start it up, it will return instantly to wait for the first input + return thread; +end + +local function default_error_watcher(runner, err) + runner:log("error", "Encountered error: %s", err); + error(err); +end +local function default_func(f) f(); end +local function runner(func, watchers, data) + local id = new_id(); + local _log = logger.init("runner" .. id); + return setmetatable({ func = func or default_func, thread = false, state = "ready", notified_state = "ready", + queue = {}, watchers = watchers or { error = default_error_watcher }, data = data, id = id, _log = _log; } + , runner_mt); +end + +-- Add a task item for the runner to process +function runner_mt:run(input) + if input ~= nil then + table.insert(self.queue, input); + self:log("debug", "queued new work item, %d items queued", #self.queue); + end + if self.state ~= "ready" then + -- The runner is busy. Indicate that the task item has been + -- queued, and return information about the current runner state + return true, self.state, #self.queue; + end + + local q, thread = self.queue, self.thread; + if not thread or coroutine.status(thread) == "dead" then + self:log("debug", "creating new coroutine"); + -- Create a new coroutine for this runner + thread = runner_create_thread(self.func, self); + self.thread = thread; + end + + -- Process task item(s) while the queue is not empty, and we're not blocked + local n, state, err = #q, self.state, nil; + self.state = "running"; + self:log("debug", "running main loop"); + while n > 0 and state == "ready" and not err do + local consumed; + -- Loop through queue items, and attempt to run them + for i = 1,n do + local queued_input = q[i]; + local ok, new_state = coroutine.resume(thread, queued_input); + if not ok then + -- There was an error running the coroutine, save the error, mark runner as ready to begin again + consumed, state, err = i, "ready", debug.traceback(thread, new_state); + self.thread = nil; + break; + elseif new_state == "wait" then + -- Runner is blocked on waiting for a task item to complete + consumed, state = i, "waiting"; + break; + end + end + -- Loop ended - either queue empty because all tasks passed without blocking (consumed == nil) + -- or runner is blocked/errored, and consumed will contain the number of tasks processed so far + if not consumed then consumed = n; end + -- Remove consumed items from the queue array + if q[n+1] ~= nil then + n = #q; + end + for i = 1, n do + q[i] = q[consumed+i]; + end + n = #q; + end + -- Runner processed all items it can, so save current runner state + self.state = state; + if err or state ~= self.notified_state then + self:log("debug", "changed state from %s to %s", self.notified_state, err and ("error ("..state..")") or state); + if err then + state = "error" + else + self.notified_state = state; + end + local handler = self.watchers[state]; + if handler then handler(self, err); end + end + if n > 0 then + return self:run(); + end + return true, state, n; +end + +-- Add a task item to the queue without invoking the runner, even if it is idle +function runner_mt:enqueue(input) + table.insert(self.queue, input); + self:log("debug", "queued new work item, %d items queued", #self.queue); + return self; +end + +function runner_mt:log(level, fmt, ...) + return self._log(level, fmt, ...); +end + +function runner_mt:onready(f) + self.watchers.ready = f; + return self; +end + +function runner_mt:onwaiting(f) + self.watchers.waiting = f; + return self; +end + +function runner_mt:onerror(f) + self.watchers.error = f; + return self; +end + +local function ready() + return pcall(checkthread); +end + +return { + ready = ready; + waiter = waiter; + guarder = guarder; + runner = runner; +}; diff --git a/util/cache.lua b/util/cache.lua index 9c141bb6..a5fd5e6d 100644 --- a/util/cache.lua +++ b/util/cache.lua @@ -116,6 +116,25 @@ function cache_methods:tail() return tail.key, tail.value; end +function cache_methods:resize(new_size) + new_size = assert(tonumber(new_size), "cache size must be a number"); + new_size = math.floor(new_size); + assert(new_size > 0, "cache size must be greater than zero"); + local on_evict = self._on_evict; + while self._count > new_size do + local tail = self._tail; + local evicted_key, evicted_value = tail.key, tail.value; + if on_evict ~= nil and (on_evict == false or on_evict(evicted_key, evicted_value) == false) then + -- Cache is full, and we're not allowed to evict + return false; + end + _remove(self, tail); + self._data[evicted_key] = nil; + end + self.size = new_size; + return true; +end + function cache_methods:table() --luacheck: ignore 212/t if not self.proxy_table then @@ -139,6 +158,13 @@ function cache_methods:table() return self.proxy_table; end +function cache_methods:clear() + self._data = {}; + self._count = 0; + self._head = nil; + self._tail = nil; +end + local function new(size, on_evict) size = assert(tonumber(size), "cache size must be a number"); size = math.floor(size); diff --git a/util/caps.lua b/util/caps.lua index cd5ff9c0..de492edb 100644 --- a/util/caps.lua +++ b/util/caps.lua @@ -13,6 +13,7 @@ local t_insert, t_sort, t_concat = table.insert, table.sort, table.concat; local ipairs = ipairs; local _ENV = nil; +-- luacheck: std none local function calculate_hash(disco_info) local identities, features, extensions = {}, {}, {}; diff --git a/util/dataforms.lua b/util/dataforms.lua index 469ce976..7b832f78 100644 --- a/util/dataforms.lua +++ b/util/dataforms.lua @@ -8,12 +8,13 @@ local setmetatable = setmetatable; local ipairs = ipairs; -local tostring, type, next = tostring, type, next; +local type, next = type, next; local t_concat = table.concat; local st = require "util.stanza"; local jid_prep = require "util.jid".prep; local _ENV = nil; +-- luacheck: std none local xmlns_forms = 'jabber:x:data'; @@ -48,7 +49,7 @@ function form_t.form(layout, data, formtype) :add_child(value) :up(); else - form:tag("value"):text(tostring(value)):up(); + form:tag("value"):text(value):up(); end elseif field_type == "boolean" then form:tag("value"):text((value and "1") or "0"):up(); @@ -78,7 +79,7 @@ function form_t.form(layout, data, formtype) has_default = true; end else - form:tag("option", { label= val }):tag("value"):text(tostring(val)):up():up(); + form:tag("option", { label= val }):tag("value"):text(val):up():up(); end end end @@ -94,7 +95,7 @@ function form_t.form(layout, data, formtype) form:tag("value"):text(val.value):up(); end else - form:tag("option", { label= val }):tag("value"):text(tostring(val)):up():up(); + form:tag("option", { label= val }):tag("value"):text(val):up():up(); end end end @@ -248,8 +249,24 @@ field_readers["hidden"] = return field_tag:get_child_text("value"); end + +local function get_form_type(form) + if not st.is_stanza(form) then + return nil, "not a stanza object"; + elseif form.attr.xmlns ~= "jabber:x:data" or form.name ~= "x" then + return nil, "not a dataform element"; + end + for field in form:childtags("field") do + if field.attr.var == "FORM_TYPE" then + return field:get_child_text("value"); + end + end + return ""; +end + return { new = new; + get_type = get_form_type; }; diff --git a/util/datamanager.lua b/util/datamanager.lua index bd8fb7bb..cf96887b 100644 --- a/util/datamanager.lua +++ b/util/datamanager.lua @@ -40,9 +40,10 @@ pcall(function() end); local _ENV = nil; +-- luacheck: std none ---- utils ----- -local encode, decode; +local encode, decode, store_encode; do local urlcodes = setmetatable({}, { __index = function (t, k) t[k] = char(tonumber(k, 16)); return t[k]; end }); @@ -53,6 +54,12 @@ do encode = function (s) return s and (s:gsub("%W", function (c) return format("%%%02x", c:byte()); end)); end + + -- Special encode function for store names, which historically were unencoded. + -- All currently known stores use a-z and underscore, so this one preserves underscores. + store_encode = function (s) + return s and (s:gsub("[^_%w]", function (c) return format("%%%02x", c:byte()); end)); + end end if not atomic_append then @@ -119,6 +126,7 @@ local function getpath(username, host, datastore, ext, create) ext = ext or "dat"; host = (host and encode(host)) or "_global"; username = username and encode(username); + datastore = store_encode(datastore); if username then if create then mkdir(mkdir(mkdir(data_path).."/"..host).."/"..datastore); end return format("%s/%s/%s/%s.%s", data_path, host, datastore, username, ext); diff --git a/util/datetime.lua b/util/datetime.lua index abb4e867..06be9fc2 100644 --- a/util/datetime.lua +++ b/util/datetime.lua @@ -15,6 +15,7 @@ local os_difftime = os.difftime; local tonumber = tonumber; local _ENV = nil; +-- luacheck: std none local function date(t) return os_date("!%Y-%m-%d", t); diff --git a/util/debug.lua b/util/debug.lua index 00f476d0..9a28395a 100644 --- a/util/debug.lua +++ b/util/debug.lua @@ -47,6 +47,7 @@ local function get_upvalues_table(func) for upvalue_num = 1, math.huge do local name, value = debug.getupvalue(func, upvalue_num); if not name then break; end + if name == "" then name = ("[%d]"):format(upvalue_num); end table.insert(upvalues, { name = name, value = value }); end end @@ -112,7 +113,9 @@ end local function build_source_boundary_marker(last_source_desc) local padding = string.rep("-", math.floor(((optimal_line_length - 6) - #last_source_desc)/2)); - return getstring(styles.boundary_padding, "v"..padding).." "..getstring(styles.filename, last_source_desc).." "..getstring(styles.boundary_padding, padding..(#last_source_desc%2==0 and "-v" or "v ")); + return getstring(styles.boundary_padding, "v"..padding).." ".. + getstring(styles.filename, last_source_desc).." ".. + getstring(styles.boundary_padding, padding..(#last_source_desc%2==0 and "-v" or "v ")); end local function _traceback(thread, message, level) @@ -142,9 +145,9 @@ local function _traceback(thread, message, level) local last_source_desc; local lines = {}; - for nlevel, level in ipairs(levels) do - local info = level.info; - local line = "..."; + for nlevel, current_level in ipairs(levels) do + local info = current_level.info; + local line; local func_type = info.namewhat.." "; local source_desc = (info.short_src == "[C]" and "C code") or info.short_src or "Unknown"; if func_type == " " then func_type = ""; end; @@ -160,7 +163,9 @@ local function _traceback(thread, message, level) if func_type == "global " or func_type == "local " then func_type = func_type.."function "; end - line = "[Lua] "..getstring(styles.location, info.short_src.." line "..info.currentline).." in "..func_type..getstring(styles.funcname, name).." (defined on line "..info.linedefined..")"; + line = "[Lua] "..getstring(styles.location, info.short_src.." line ".. + info.currentline).." in "..func_type..getstring(styles.funcname, name).. + " (defined on line "..info.linedefined..")"; end if source_desc ~= last_source_desc then -- Venturing into a new source, add marker for previous last_source_desc = source_desc; @@ -169,13 +174,13 @@ local function _traceback(thread, message, level) nlevel = nlevel-1; table.insert(lines, "\t"..(nlevel==0 and ">" or " ")..getstring(styles.level_num, "("..nlevel..") ")..line); local npadding = (" "):rep(#tostring(nlevel)); - if level.locals then - local locals_str = string_from_var_table(level.locals, optimal_line_length, "\t "..npadding); + if current_level.locals then + local locals_str = string_from_var_table(current_level.locals, optimal_line_length, "\t "..npadding); if locals_str then table.insert(lines, "\t "..npadding.."Locals: "..locals_str); end end - local upvalues_str = string_from_var_table(level.upvalues, optimal_line_length, "\t "..npadding); + local upvalues_str = string_from_var_table(current_level.upvalues, optimal_line_length, "\t "..npadding); if upvalues_str then table.insert(lines, "\t "..npadding.."Upvals: "..upvalues_str); end diff --git a/util/dependencies.lua b/util/dependencies.lua index de840241..9b0afd77 100644 --- a/util/dependencies.lua +++ b/util/dependencies.lua @@ -28,7 +28,7 @@ local function missingdep(name, sources, msg) end print(""); print(msg or (name.." is required for Prosody to run, so we will now exit.")); - print("More help can be found on our website, at http://prosody.im/doc/depends"); + print("More help can be found on our website, at https://prosody.im/doc/depends"); print("**************************"); print(""); end @@ -40,7 +40,7 @@ end package.preload["util.ztact"] = function () if not package.loaded["core.loggingmanager"] then error("util.ztact has been removed from Prosody and you need to fix your config " - .."file. More information can be found at http://prosody.im/doc/packagers#ztact", 0); + .."file. More information can be found at https://prosody.im/doc/packagers#ztact", 0); else error("module 'util.ztact' has been deprecated in Prosody 0.8."); end @@ -156,7 +156,7 @@ local function log_warnings() if ssl then local major, minor, veryminor, patched = ssl._VERSION:match("(%d+)%.(%d+)%.?(%d*)(M?)"); if not major or ((tonumber(major) == 0 and (tonumber(minor) or 0) <= 3 and (tonumber(veryminor) or 0) <= 2) and patched ~= "M") then - prosody.log("error", "This version of LuaSec contains a known bug that causes disconnects, see http://prosody.im/doc/depends"); + prosody.log("error", "This version of LuaSec contains a known bug that causes disconnects, see https://prosody.im/doc/depends"); end end local lxp = softreq"lxp"; @@ -165,7 +165,7 @@ local function log_warnings() prosody.log("error", "The version of LuaExpat on your system leaves Prosody " .."vulnerable to denial-of-service attacks. You should upgrade to " .."LuaExpat 1.3.0 or higher as soon as possible. See " - .."http://prosody.im/doc/depends#luaexpat for more information."); + .."https://prosody.im/doc/depends#luaexpat for more information."); end if not lxp.new({}).getcurrentbytecount then prosody.log("error", "The version of LuaExpat on your system does not support " @@ -173,7 +173,7 @@ local function log_warnings() .."networks (e.g. the internet) vulnerable to denial-of-service " .."attacks. You should upgrade to LuaExpat 1.3.0 or higher as " .."soon as possible. See " - .."http://prosody.im/doc/depends#luaexpat for more information."); + .."https://prosody.im/doc/depends#luaexpat for more information."); end end end diff --git a/util/envload.lua b/util/envload.lua index 926f20c0..6182a1f9 100644 --- a/util/envload.lua +++ b/util/envload.lua @@ -4,7 +4,7 @@ -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- --- luacheck: ignore 113/setfenv +-- luacheck: ignore 113/setfenv 113/loadstring local load, loadstring, setfenv = load, loadstring, setfenv; local io_open = io.open; diff --git a/util/events.lua b/util/events.lua index 6e13619c..0bf0ddcb 100644 --- a/util/events.lua +++ b/util/events.lua @@ -15,6 +15,7 @@ local setmetatable = setmetatable; local next = next; local _ENV = nil; +-- luacheck: std none local function new() -- Map event name to ordered list of handlers (lazily built): handlers[event_name] = array_of_handler_functions @@ -26,7 +27,7 @@ local function new() -- Event map: event_map[handler_function] = priority_number local event_map = {}; -- Called on-demand to build handlers entries - local function _rebuild_index(handlers, event) + local function _rebuild_index(self, event) local _handlers = event_map[event]; if not _handlers or next(_handlers) == nil then return; end local index = {}; @@ -34,7 +35,7 @@ local function new() t_insert(index, handler); end t_sort(index, function(a, b) return _handlers[a] > _handlers[b]; end); - handlers[event] = index; + self[event] = index; return index; end; setmetatable(handlers, { __index = _rebuild_index }); @@ -61,13 +62,13 @@ local function new() local function get_handlers(event) return handlers[event]; end; - local function add_handlers(handlers) - for event, handler in pairs(handlers) do + local function add_handlers(self) + for event, handler in pairs(self) do add_handler(event, handler); end end; - local function remove_handlers(handlers) - for event, handler in pairs(handlers) do + local function remove_handlers(self) + for event, handler in pairs(self) do remove_handler(event, handler); end end; @@ -81,6 +82,7 @@ local function new() end end; local function fire_event(event_name, event_data) + -- luacheck: ignore 432/event_name 432/event_data local w = wrappers[event_name] or global_wrappers; if w then local curr_wrapper = #w; diff --git a/util/filters.lua b/util/filters.lua index f405c0bd..f30dfd9c 100644 --- a/util/filters.lua +++ b/util/filters.lua @@ -9,6 +9,7 @@ local t_insert, t_remove = table.insert, table.remove; local _ENV = nil; +-- luacheck: std none local new_filter_hooks = {}; diff --git a/util/format.lua b/util/format.lua index 5f2b12be..c5e513fa 100644 --- a/util/format.lua +++ b/util/format.lua @@ -4,11 +4,10 @@ local tostring = tostring; local select = select; -local assert = assert; -local unpack = unpack; +local unpack = table.unpack or unpack; -- luacheck: ignore 113/unpack local type = type; -local function format(format, ...) +local function format(formatstring, ...) local args, args_length = { ... }, select('#', ...); -- format specifier spec: @@ -25,7 +24,7 @@ local function format(format, ...) -- process each format specifier local i = 0; - format = format:gsub("%%[^cdiouxXaAeEfgGqs%%]*[cdiouxXaAeEfgGqs%%]", function(spec) + formatstring = formatstring:gsub("%%[^cdiouxXaAeEfgGqs%%]*[cdiouxXaAeEfgGqs%%]", function(spec) if spec ~= "%%" then i = i + 1; local arg = args[i]; @@ -54,21 +53,12 @@ local function format(format, ...) else args[i] = tostring(arg); end - format = format .. " [%s]" + formatstring = formatstring .. " [%s]" end - return format:format(unpack(args)); -end - -local function test() - assert(format("%s", "hello") == "hello"); - assert(format("%s") == "<nil>"); - assert(format("%s", true) == "true"); - assert(format("%d", true) == "[true]"); - assert(format("%%", true) == "% [true]"); + return formatstring:format(unpack(args)); end return { format = format; - test = test; }; diff --git a/util/import.lua b/util/import.lua index c2b9dce1..8ecfe43c 100644 --- a/util/import.lua +++ b/util/import.lua @@ -8,9 +8,9 @@ -local unpack = table.unpack or unpack; --luacheck: ignore 113 +local unpack = table.unpack or unpack; --luacheck: ignore 113 143 local t_insert = table.insert; -function import(module, ...) +function _G.import(module, ...) local m = package.loaded[module] or require(module); if type(m) == "table" and ... then local ret = {}; diff --git a/util/indexedbheap.lua b/util/indexedbheap.lua new file mode 100644 index 00000000..7f193d54 --- /dev/null +++ b/util/indexedbheap.lua @@ -0,0 +1,157 @@ + +local setmetatable = setmetatable; +local math_floor = math.floor; +local t_remove = table.remove; + +local function _heap_insert(self, item, sync, item2, index) + local pos = #self + 1; + while true do + local half_pos = math_floor(pos / 2); + if half_pos == 0 or item > self[half_pos] then break; end + self[pos] = self[half_pos]; + sync[pos] = sync[half_pos]; + index[sync[pos]] = pos; + pos = half_pos; + end + self[pos] = item; + sync[pos] = item2; + index[item2] = pos; +end + +local function _percolate_up(self, k, sync, index) + local tmp = self[k]; + local tmp_sync = sync[k]; + while k ~= 1 do + local parent = math_floor(k/2); + if tmp < self[parent] then break; end + self[k] = self[parent]; + sync[k] = sync[parent]; + index[sync[k]] = k; + k = parent; + end + self[k] = tmp; + sync[k] = tmp_sync; + index[tmp_sync] = k; + return k; +end + +local function _percolate_down(self, k, sync, index) + local tmp = self[k]; + local tmp_sync = sync[k]; + local size = #self; + local child = 2*k; + while 2*k <= size do + if child ~= size and self[child] > self[child + 1] then + child = child + 1; + end + if tmp > self[child] then + self[k] = self[child]; + sync[k] = sync[child]; + index[sync[k]] = k; + else + break; + end + + k = child; + child = 2*k; + end + self[k] = tmp; + sync[k] = tmp_sync; + index[tmp_sync] = k; + return k; +end + +local function _heap_pop(self, sync, index) + local size = #self; + if size == 0 then return nil; end + + local result = self[1]; + local result_sync = sync[1]; + index[result_sync] = nil; + if size == 1 then + self[1] = nil; + sync[1] = nil; + return result, result_sync; + end + self[1] = t_remove(self); + sync[1] = t_remove(sync); + index[sync[1]] = 1; + + _percolate_down(self, 1, sync, index); + + return result, result_sync; +end + +local indexed_heap = {}; + +function indexed_heap:insert(item, priority, id) + if id == nil then + id = self.current_id; + self.current_id = id + 1; + end + self.items[id] = item; + _heap_insert(self.priorities, priority, self.ids, id, self.index); + return id; +end +function indexed_heap:pop() + local priority, id = _heap_pop(self.priorities, self.ids, self.index); + if id then + local item = self.items[id]; + self.items[id] = nil; + return priority, item, id; + end +end +function indexed_heap:peek() + return self.priorities[1]; +end +function indexed_heap:reprioritize(id, priority) + local k = self.index[id]; + if k == nil then return; end + self.priorities[k] = priority; + + k = _percolate_up(self.priorities, k, self.ids, self.index); + _percolate_down(self.priorities, k, self.ids, self.index); +end +function indexed_heap:remove_index(k) + local result = self.priorities[k]; + if result == nil then return; end + + local result_sync = self.ids[k]; + local item = self.items[result_sync]; + local size = #self.priorities; + + self.priorities[k] = self.priorities[size]; + self.ids[k] = self.ids[size]; + self.index[self.ids[k]] = k; + + t_remove(self.priorities); + t_remove(self.ids); + + self.index[result_sync] = nil; + self.items[result_sync] = nil; + + if size > k then + k = _percolate_up(self.priorities, k, self.ids, self.index); + _percolate_down(self.priorities, k, self.ids, self.index); + end + + return result, item, result_sync; +end +function indexed_heap:remove(id) + return self:remove_index(self.index[id]); +end + +local mt = { __index = indexed_heap }; + +local _M = { + create = function() + return setmetatable({ + ids = {}; -- heap of ids, sync'd with priorities + items = {}; -- map id->items + priorities = {}; -- heap of priorities + index = {}; -- map of id->index of id in ids + current_id = 1.5 + }, mt); + end +}; +return _M; diff --git a/util/ip.lua b/util/ip.lua index 81a98ef7..0ec9e297 100644 --- a/util/ip.lua +++ b/util/ip.lua @@ -5,69 +5,76 @@ -- COPYING file in the source package for more information. -- +local net = require "util.net"; +local hex = require "util.hex"; + local ip_methods = {}; -local ip_mt = { __index = function (ip, key) return (ip_methods[key])(ip); end, - __tostring = function (ip) return ip.addr; end, - __eq = function (ipA, ipB) return ipA.addr == ipB.addr; end}; -local hex2bits = { ["0"] = "0000", ["1"] = "0001", ["2"] = "0010", ["3"] = "0011", ["4"] = "0100", ["5"] = "0101", ["6"] = "0110", ["7"] = "0111", ["8"] = "1000", ["9"] = "1001", ["A"] = "1010", ["B"] = "1011", ["C"] = "1100", ["D"] = "1101", ["E"] = "1110", ["F"] = "1111" }; + +local ip_mt = { + __index = function (ip, key) + local method = ip_methods[key]; + if not method then return nil; end + local ret = method(ip); + ip[key] = ret; + return ret; + end, + __tostring = function (ip) return ip.addr; end, + __eq = function (ipA, ipB) return ipA.packed == ipB.packed; end +}; + +local hex2bits = { + ["0"] = "0000", ["1"] = "0001", ["2"] = "0010", ["3"] = "0011", + ["4"] = "0100", ["5"] = "0101", ["6"] = "0110", ["7"] = "0111", + ["8"] = "1000", ["9"] = "1001", ["A"] = "1010", ["B"] = "1011", + ["C"] = "1100", ["D"] = "1101", ["E"] = "1110", ["F"] = "1111", +}; local function new_ip(ipStr, proto) - if not proto then - local sep = ipStr:match("^%x+(.)"); - if sep == ":" or (not(sep) and ipStr:sub(1,1) == ":") then - proto = "IPv6" - elseif sep == "." then - proto = "IPv4" - end - if not proto then - return nil, "invalid address"; - end - elseif proto ~= "IPv4" and proto ~= "IPv6" then - return nil, "invalid protocol"; - end local zone; - if proto == "IPv6" and ipStr:find('%', 1, true) then + if (not proto or proto == "IPv6") and ipStr:find('%', 1, true) then ipStr, zone = ipStr:match("^(.-)%%(.*)"); end - if proto == "IPv6" and ipStr:find('.', 1, true) then - local changed; - ipStr, changed = ipStr:gsub(":(%d+)%.(%d+)%.(%d+)%.(%d+)$", function(a,b,c,d) - return (":%04X:%04X"):format(a*256+b,c*256+d); - end); - if changed ~= 1 then return nil, "invalid-address"; end + + local packed, err = net.pton(ipStr); + if not packed then return packed, err end + if proto == "IPv6" and #packed ~= 16 then + return nil, "invalid-ipv6"; + elseif proto == "IPv4" and #packed ~= 4 then + return nil, "invalid-ipv4"; + elseif not proto then + if #packed == 16 then + proto = "IPv6"; + elseif #packed == 4 then + proto = "IPv4"; + else + return nil, "unknown protocol"; + end + elseif proto ~= "IPv6" and proto ~= "IPv4" then + return nil, "invalid protocol"; end - return setmetatable({ addr = ipStr, proto = proto, zone = zone }, ip_mt); + return setmetatable({ addr = ipStr, packed = packed, proto = proto, zone = zone }, ip_mt); +end + +function ip_methods:normal() + return net.ntop(self.packed); end -local function toBits(ip) - local result = ""; - local fields = {}; +function ip_methods.bits(ip) + return hex.to(ip.packed):upper():gsub(".", hex2bits); +end + +function ip_methods.bits_full(ip) if ip.proto == "IPv4" then ip = ip.toV4mapped; end - ip = (ip.addr):upper(); - ip:gsub("([^:]*):?", function (c) fields[#fields + 1] = c end); - if not ip:match(":$") then fields[#fields] = nil; end - for i, field in ipairs(fields) do - if field:len() == 0 and i ~= 1 and i ~= #fields then - for _ = 1, 16 * (9 - #fields) do - result = result .. "0"; - end - else - for _ = 1, 4 - field:len() do - result = result .. "0000"; - end - for j = 1, field:len() do - result = result .. hex2bits[field:sub(j, j)]; - end - end - end - return result; + return ip.bits; end +local match; + local function commonPrefixLength(ipA, ipB) - ipA, ipB = toBits(ipA), toBits(ipB); + ipA, ipB = ipA.bits_full, ipB.bits_full; for i = 1, 128 do if ipA:sub(i,i) ~= ipB:sub(i,i) then return i-1; @@ -76,56 +83,60 @@ local function commonPrefixLength(ipA, ipB) return 128; end +-- Instantiate once +local loopback = new_ip("::1"); +local loopback4 = new_ip("127.0.0.0"); +local sixtofour = new_ip("2002::"); +local teredo = new_ip("2001::"); +local linklocal = new_ip("fe80::"); +local linklocal4 = new_ip("169.254.0.0"); +local uniquelocal = new_ip("fc00::"); +local sitelocal = new_ip("fec0::"); +local sixbone = new_ip("3ffe::"); +local defaultunicast = new_ip("::"); +local multicast = new_ip("ff00::"); +local ipv6mapped = new_ip("::ffff:0:0"); + local function v4scope(ip) - local fields = {}; - ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end); - -- Loopback: - if fields[1] == 127 then + if match(ip, loopback4, 8) then return 0x2; - -- Link-local unicast: - elseif fields[1] == 169 and fields[2] == 254 then + elseif match(ip, linklocal4) then return 0x2; - -- Global unicast: - else + else -- Global unicast return 0xE; end end local function v6scope(ip) - -- Loopback: - if ip:match("^[0:]*1$") then + if ip == loopback then return 0x2; - -- Link-local unicast: - elseif ip:match("^[Ff][Ee][89ABab]") then + elseif match(ip, linklocal, 10) then return 0x2; - -- Site-local unicast: - elseif ip:match("^[Ff][Ee][CcDdEeFf]") then + elseif match(ip, sitelocal, 10) then return 0x5; - -- Multicast: - elseif ip:match("^[Ff][Ff]") then - return tonumber("0x"..ip:sub(4,4)); - -- Global unicast: - else + elseif match(ip, multicast, 10) then + return ip.packed:byte(2) % 0x10; + else -- Global unicast return 0xE; end end local function label(ip) - if commonPrefixLength(ip, new_ip("::1", "IPv6")) == 128 then + if ip == loopback then return 0; - elseif commonPrefixLength(ip, new_ip("2002::", "IPv6")) >= 16 then + elseif match(ip, sixtofour, 16) then return 2; - elseif commonPrefixLength(ip, new_ip("2001::", "IPv6")) >= 32 then + elseif match(ip, teredo, 32) then return 5; - elseif commonPrefixLength(ip, new_ip("fc00::", "IPv6")) >= 7 then + elseif match(ip, uniquelocal, 7) then return 13; - elseif commonPrefixLength(ip, new_ip("fec0::", "IPv6")) >= 10 then + elseif match(ip, sitelocal, 10) then return 11; - elseif commonPrefixLength(ip, new_ip("3ffe::", "IPv6")) >= 16 then + elseif match(ip, sixbone, 16) then return 12; - elseif commonPrefixLength(ip, new_ip("::", "IPv6")) >= 96 then + elseif match(ip, defaultunicast, 96) then return 3; - elseif commonPrefixLength(ip, new_ip("::ffff:0:0", "IPv6")) >= 96 then + elseif match(ip, ipv6mapped, 96) then return 4; else return 1; @@ -133,91 +144,67 @@ local function label(ip) end local function precedence(ip) - if commonPrefixLength(ip, new_ip("::1", "IPv6")) == 128 then + if ip == loopback then return 50; - elseif commonPrefixLength(ip, new_ip("2002::", "IPv6")) >= 16 then + elseif match(ip, sixtofour, 16) then return 30; - elseif commonPrefixLength(ip, new_ip("2001::", "IPv6")) >= 32 then + elseif match(ip, teredo, 32) then return 5; - elseif commonPrefixLength(ip, new_ip("fc00::", "IPv6")) >= 7 then + elseif match(ip, uniquelocal, 7) then return 3; - elseif commonPrefixLength(ip, new_ip("fec0::", "IPv6")) >= 10 then + elseif match(ip, sitelocal, 10) then return 1; - elseif commonPrefixLength(ip, new_ip("3ffe::", "IPv6")) >= 16 then + elseif match(ip, sixbone, 16) then return 1; - elseif commonPrefixLength(ip, new_ip("::", "IPv6")) >= 96 then + elseif match(ip, defaultunicast, 96) then return 1; - elseif commonPrefixLength(ip, new_ip("::ffff:0:0", "IPv6")) >= 96 then + elseif match(ip, ipv6mapped, 96) then return 35; else return 40; end end -local function toV4mapped(ip) - local fields = {}; - local ret = "::ffff:"; - ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end); - ret = ret .. ("%02x"):format(fields[1]); - ret = ret .. ("%02x"):format(fields[2]); - ret = ret .. ":" - ret = ret .. ("%02x"):format(fields[3]); - ret = ret .. ("%02x"):format(fields[4]); - return new_ip(ret, "IPv6"); -end - function ip_methods:toV4mapped() if self.proto ~= "IPv4" then return nil, "No IPv4 address" end - local value = toV4mapped(self.addr); - self.toV4mapped = value; + local value = new_ip("::ffff:" .. self.normal); return value; end function ip_methods:label() - local value; if self.proto == "IPv4" then - value = label(self.toV4mapped); + return label(self.toV4mapped); else - value = label(self); + return label(self); end - self.label = value; - return value; end function ip_methods:precedence() - local value; if self.proto == "IPv4" then - value = precedence(self.toV4mapped); + return precedence(self.toV4mapped); else - value = precedence(self); + return precedence(self); end - self.precedence = value; - return value; end function ip_methods:scope() - local value; if self.proto == "IPv4" then - value = v4scope(self.addr); + return v4scope(self); else - value = v6scope(self.addr); + return v6scope(self); end - self.scope = value; - return value; end +local rfc1918_8 = new_ip("10.0.0.0"); +local rfc1918_12 = new_ip("172.16.0.0"); +local rfc1918_16 = new_ip("192.168.0.0"); +local rfc6598 = new_ip("100.64.0.0"); + function ip_methods:private() local private = self.scope ~= 0xE; if not private and self.proto == "IPv4" then - local ip = self.addr; - local fields = {}; - ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end); - if fields[1] == 127 or fields[1] == 10 or (fields[1] == 192 and fields[2] == 168) - or (fields[1] == 172 and (fields[2] >= 16 or fields[2] <= 32)) then - private = true; - end + return match(self, rfc1918_8, 8) or match(self, rfc1918_12, 12) or match(self, rfc1918_16) or match(self, rfc6598, 10); end - self.private = private; return private; end @@ -231,15 +218,26 @@ local function parse_cidr(cidr) return new_ip(cidr), bits; end -local function match(ipA, ipB, bits) - local common_bits = commonPrefixLength(ipA, ipB); - if bits and ipB.proto == "IPv4" then - common_bits = common_bits - 96; -- v6 mapped addresses always share these bits +function match(ipA, ipB, bits) + if not bits or bits >= 128 or ipB.proto == "IPv4" and bits >= 32 then + return ipA == ipB; + elseif bits < 1 then + return true; + end + if ipA.proto ~= ipB.proto then + if ipA.proto == "IPv4" then + ipA = ipA.toV4mapped; + elseif ipB.proto == "IPv4" then + ipB = ipB.toV4mapped; + bits = bits + (128 - 32); + end end - return common_bits >= (bits or 128); + return ipA.bits:sub(1, bits) == ipB.bits:sub(1, bits); end -return {new_ip = new_ip, +return { + new_ip = new_ip, commonPrefixLength = commonPrefixLength, parse_cidr = parse_cidr, - match=match}; + match = match, +}; diff --git a/util/iterators.lua b/util/iterators.lua index bd150ff2..5d16d8c1 100644 --- a/util/iterators.lua +++ b/util/iterators.lua @@ -12,8 +12,13 @@ local it = {}; local t_insert = table.insert; local select, next = select, next; -local unpack = table.unpack or unpack; --luacheck: ignore 113 -local pack = table.pack or function (...) return { n = select("#", ...), ... }; end +local unpack = table.unpack or unpack; --luacheck: ignore 113 143 +local pack = table.pack or function (...) return { n = select("#", ...), ... }; end -- luacheck: ignore 143 +local type = type; +local table, setmetatable = table, setmetatable; + +local _ENV = nil; +--luacheck: std none -- Reverse an iterator function it.reverse(f, s, var) @@ -184,4 +189,45 @@ function it.to_table(f, s, var) return t; end +local function _join_iter(j_s, j_var) + local iterators, current_idx = j_s[1], j_s[2]; + local f, s, var = unpack(iterators[current_idx], 1, 3); + if j_var ~= nil then + var = j_var; + end + local ret = pack(f(s, var)); + local var1 = ret[1]; + if var1 == nil then + -- End of this iterator, advance to next + if current_idx == #iterators then + -- No more iterators, return nil + return; + end + j_s[2] = current_idx + 1; + return _join_iter(j_s); + end + return unpack(ret, 1, ret.n); +end +local join_methods = {}; +local join_mt = { + __index = join_methods; + __call = function (t, s, var) --luacheck: ignore 212/t + return _join_iter(s, var); + end; +}; + +function join_methods:append(f, s, var) + table.insert(self, { f, s, var }); + return self, { self, 1 }; +end + +function join_methods:prepend(f, s, var) + table.insert(self, { f, s, var }, 1); + return self, { self, 1 }; +end + +function it.join(f, s, var) + return setmetatable({ {f, s, var} }, join_mt); +end + return it; diff --git a/util/jid.lua b/util/jid.lua index f402b7f4..37c48193 100644 --- a/util/jid.lua +++ b/util/jid.lua @@ -25,6 +25,7 @@ local unescapes = {}; for k,v in pairs(escapes) do unescapes[v] = k; end local _ENV = nil; +-- luacheck: std none local function split(jid) if not jid then return; end diff --git a/util/json.lua b/util/json.lua index cba54e8e..05af709a 100644 --- a/util/json.lua +++ b/util/json.lua @@ -27,9 +27,6 @@ module.null = null; local escapes = { ["\""] = "\\\"", ["\\"] = "\\\\", ["\b"] = "\\b", ["\f"] = "\\f", ["\n"] = "\\n", ["\r"] = "\\r", ["\t"] = "\\t"}; -local unescapes = { - ["\""] = "\"", ["\\"] = "\\", ["/"] = "/", - b = "\b", f = "\f", n = "\n", r = "\r", t = "\t"}; for i=0,31 do local ch = s_char(i); if not escapes[ch] then escapes[ch] = ("\\u%.4X"):format(i); end @@ -263,8 +260,9 @@ end local function _unescape_func(x) x = x:match("%x%x%x%x", 3); if x then - --if x >= 0xD800 and x <= 0xDFFF then _unescape_error = true; end -- bad surrogate pair - return codepoint_to_utf8(tonumber(x, 16)); + local codepoint = tonumber(x, 16) + if codepoint >= 0xD800 and codepoint <= 0xDFFF then _unescape_error = true; end -- bad surrogate pair + return codepoint_to_utf8(codepoint); end _unescape_error = true; end @@ -276,7 +274,7 @@ function _readstring(json, index) --if s:find("[%z-\31]") then return nil, "control char in string"; end -- FIXME handle control characters _unescape_error = nil; - --s = s:gsub("\\u[dD][89abAB]%x%x\\u[dD][cdefCDEF]%x%x", _unescape_surrogate_func); + s = s:gsub("\\u[dD][89abAB]%x%x\\u[dD][cdefCDEF]%x%x", _unescape_surrogate_func); -- FIXME handle escapes beyond BMP s = s:gsub("\\u.?.?.?.?", _unescape_func); if _unescape_error then return nil, "invalid escape"; end diff --git a/util/logger.lua b/util/logger.lua index e72b29bc..20a5cef2 100644 --- a/util/logger.lua +++ b/util/logger.lua @@ -8,8 +8,11 @@ -- luacheck: ignore 213/level local pairs = pairs; +local ipairs = ipairs; +local require = require; local _ENV = nil; +-- luacheck: std none local level_sinks = {}; @@ -67,10 +70,21 @@ local function add_level_sink(level, sink_function) end end +local function add_simple_sink(simple_sink_function, levels) + local format = require "util.format".format; + local function sink_function(name, level, msg, ...) + return simple_sink_function(name, level, format(msg, ...)); + end + for _, level in ipairs(levels or {"debug", "info", "warn", "error"}) do + add_level_sink(level, sink_function); + end +end + return { init = init; make_logger = make_logger; reset = reset; add_level_sink = add_level_sink; + add_simple_sink = add_simple_sink; new = make_logger; }; diff --git a/util/multitable.lua b/util/multitable.lua index e4321d3d..8d32ed8a 100644 --- a/util/multitable.lua +++ b/util/multitable.lua @@ -9,9 +9,10 @@ local select = select; local t_insert = table.insert; local pairs, next, type = pairs, next, type; -local unpack = table.unpack or unpack; --luacheck: ignore 113 +local unpack = table.unpack or unpack; --luacheck: ignore 113 143 local _ENV = nil; +-- luacheck: std none local function get(self, ...) local t = self.data; @@ -132,7 +133,7 @@ local function iter(self, ...) local maxdepth = select("#", ...); local stack = { self.data }; local keys = { }; - local function it(self) + local function it(self) -- luacheck: ignore 432/self local depth = #stack; local key = next(stack[depth], keys[depth]); if key == nil then -- Go up the stack diff --git a/util/openssl.lua b/util/openssl.lua index 703c6d15..32b5aea7 100644 --- a/util/openssl.lua +++ b/util/openssl.lua @@ -114,7 +114,7 @@ function ssl_config:add_xmppAddr(host) s_format("%s;%s", oid_xmppaddr, utf8string(host))); end -function ssl_config:from_prosody(hosts, config, certhosts) +function ssl_config:from_prosody(hosts, config, certhosts) -- luacheck: ignore 431/config -- TODO Decide if this should go elsewhere local found_matching_hosts = false; for i = 1, #certhosts do diff --git a/util/pluginloader.lua b/util/pluginloader.lua index 004855f0..9ab8f245 100644 --- a/util/pluginloader.lua +++ b/util/pluginloader.lua @@ -5,6 +5,7 @@ -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- +-- luacheck: ignore 113/CFG_PLUGINDIR local dir_sep, path_sep = package.config:match("^(%S+)%s(%S+)"); local plugin_dir = {}; diff --git a/util/presence.lua b/util/presence.lua index f6370354..8d1ae2d9 100644 --- a/util/presence.lua +++ b/util/presence.lua @@ -13,7 +13,6 @@ local function select_top_resources(user) local recipients = {}; for _, session in pairs(user.sessions) do -- find resource with greatest priority if session.presence then - -- TODO check active privacy list for session local p = session.priority; if p > priority then priority = p; diff --git a/util/prosodyctl.lua b/util/prosodyctl.lua index 8ae051ae..5f0c4d12 100644 --- a/util/prosodyctl.lua +++ b/util/prosodyctl.lua @@ -24,8 +24,6 @@ local io, os = io, os; local print = print; local tonumber = tonumber; -local CFG_SOURCEDIR = _G.CFG_SOURCEDIR; - local _G = _G; local prosody = prosody; @@ -66,7 +64,10 @@ local function getline() end local function getpass() - local stty_ret = os.execute("stty -echo 2>/dev/null"); + local stty_ret, _, status_code = os.execute("stty -echo 2>/dev/null"); + if status_code then -- COMPAT w/ Lua 5.1 + stty_ret = status_code; + end if stty_ret ~= 0 then io.write("\027[08m"); -- ANSI 'hidden' text attribute end @@ -189,8 +190,8 @@ local function getpid() pidfile = config.resolve_relative_path(prosody.paths.data, pidfile); - local modules_enabled = set.new(config.get("*", "modules_disabled")); - if prosody.platform ~= "posix" or modules_enabled:contains("posix") then + local modules_disabled = set.new(config.get("*", "modules_disabled")); + if prosody.platform ~= "posix" or modules_disabled:contains("posix") then return false, "no-posix"; end @@ -228,7 +229,7 @@ local function isrunning() return true, signal.kill(pid, 0) == 0; end -local function start() +local function start(source_dir) local ok, ret = isrunning(); if not ok then return ok, ret; @@ -236,10 +237,10 @@ local function start() if ret then return false, "already-running"; end - if not CFG_SOURCEDIR then + if not source_dir then os.execute("./prosody"); else - os.execute(CFG_SOURCEDIR.."/../../bin/prosody"); + os.execute(source_dir.."/../../bin/prosody"); end return true; end diff --git a/util/pubsub.lua b/util/pubsub.lua index 1db917d8..86dd8e14 100644 --- a/util/pubsub.lua +++ b/util/pubsub.lua @@ -1,32 +1,71 @@ local events = require "util.events"; local cache = require "util.cache"; -local service = {}; -local service_mt = { __index = service }; +local service_mt = {}; -local default_config = { __index = { - itemstore = function (config) return cache.new(tonumber(config["pubsub#max_items"])) end; +local default_config = { + itemstore = function (config, _) return cache.new(config["max_items"]) end; broadcaster = function () end; + itemcheck = function () return true; end; get_affiliation = function () end; + normalize_jid = function (jid) return jid; end; capabilities = {}; -} }; -local default_node_config = { __index = { - ["pubsub#max_items"] = "20"; -} }; +}; +local default_config_mt = { __index = default_config }; + +local default_node_config = { + ["persist_items"] = false; + ["max_items"] = 20; +}; +local default_node_config_mt = { __index = default_node_config }; + +-- Storage helper functions + +local function load_node_from_store(nodestore, node_name) + local node = nodestore:get(node_name); + node.config = setmetatable(node.config or {}, default_node_config_mt); + return node; +end +local function save_node_to_store(nodestore, node) + return nodestore:set(node.name, { + name = node.name; + config = node.config; + subscribers = node.subscribers; + affiliations = node.affiliations; + }); +end + +-- Create and return a new service object local function new(config) config = config or {}; - return setmetatable({ - config = setmetatable(config, default_config); - node_defaults = setmetatable(config.node_defaults or {}, default_node_config); + + local service = setmetatable({ + config = setmetatable(config, default_config_mt); + node_defaults = setmetatable(config.node_defaults or {}, default_node_config_mt); affiliations = {}; subscriptions = {}; nodes = {}; data = {}; events = events.new(); }, service_mt); + + -- Load nodes from storage, if we have a store and it supports iterating over stored items + if config.nodestore and config.nodestore.users then + for node_name in config.nodestore:users() do + service.nodes[node_name] = load_node_from_store(config.nodestore, node_name); + service.data[node_name] = config.itemstore(service.nodes[node_name].config, node_name); + end + end + + return service; end +--- Service methods + +local service = {}; +service_mt.__index = service; + function service:jids_equal(jid1, jid2) local normalize = self.config.normalize_jid; return normalize(jid1) == normalize(jid2); @@ -36,7 +75,8 @@ function service:may(node, actor, action) if actor == true then return true; end local node_obj = self.nodes[node]; - local node_aff = node_obj and node_obj.affiliations[actor]; + local node_aff = node_obj and (node_obj.affiliations[actor] + or node_obj.affiliations[self.config.normalize_jid(actor)]); local service_aff = self.affiliations[actor] or self.config.get_affiliation(actor, node, action) or "none"; @@ -176,18 +216,6 @@ function service:remove_subscription(node, actor, jid) return true; end -function service:remove_all_subscriptions(actor, jid) - local normal_jid = self.config.normalize_jid(jid); - local subs = self.subscriptions[normal_jid] - subs = subs and subs[jid]; - if subs then - for node in pairs(subs) do - self:remove_subscription(node, true, jid); - end - end - return true; -end - function service:get_subscription(node, actor, jid) -- Access checking local cap; @@ -223,14 +251,27 @@ function service:create(node, actor, options) config = setmetatable(options or {}, {__index=self.node_defaults}); affiliations = {}; }; - self.data[node] = self.config.itemstore(self.nodes[node].config); + + if self.config.nodestore then + local ok, err = save_node_to_store(self.config.nodestore, self.nodes[node]); + if not ok then + self.nodes[node] = nil; + return ok, err; + end + end + + self.data[node] = self.config.itemstore(self.nodes[node].config, node); self.events.fire_event("node-created", { node = node, actor = actor }); - local ok, err = self:set_affiliation(node, true, actor, "owner"); - if not ok then - self.nodes[node] = nil; - self.data[node] = nil; + if actor ~= true then + local ok, err = self:set_affiliation(node, true, actor, "owner"); + if not ok then + self.nodes[node] = nil; + self.data[node] = nil; + return ok, err; + end end - return ok, err; + + return true; end function service:delete(node, actor) @@ -244,9 +285,12 @@ function service:delete(node, actor) return false, "item-not-found"; end self.nodes[node] = nil; + if self.data[node] and self.data[node].clear then + self.data[node]:clear(); + end self.data[node] = nil; self.events.fire_event("node-deleted", { node = node, actor = actor }); - self.config.broadcaster("delete", node, node_obj.subscribers); + self.config.broadcaster("delete", node, node_obj.subscribers, nil, actor, node_obj, self); return true; end @@ -267,13 +311,17 @@ function service:publish(node, actor, id, item) end node_obj = self.nodes[node]; end + if not self.config.itemcheck(item) then + return nil, "internal-server-error"; + end local node_data = self.data[node]; local ok = node_data:set(id, item); if not ok then return nil, "internal-server-error"; end + if type(ok) == "string" then id = ok; end self.events.fire_event("item-published", { node = node, actor = actor, id = id, item = item }); - self.config.broadcaster("items", node, node_obj.subscribers, item, actor); + self.config.broadcaster("items", node, node_obj.subscribers, item, actor, node_obj, self); return true; end @@ -293,7 +341,7 @@ function service:retract(node, actor, id, retract) end self.events.fire_event("item-retracted", { node = node, actor = actor, id = id }); if retract then - self.config.broadcaster("items", node, node_obj.subscribers, retract); + self.config.broadcaster("items", node, node_obj.subscribers, retract, actor, node_obj, self); end return true end @@ -308,10 +356,14 @@ function service:purge(node, actor, notify) if not node_obj then return false, "item-not-found"; end - self.data[node] = self.config.itemstore(self.nodes[node].config); + if self.data[node] and self.data[node].clear then + self.data[node]:clear() + else + self.data[node] = self.config.itemstore(self.nodes[node].config, node); + end self.events.fire_event("node-purged", { node = node, actor = actor }); if notify then - self.config.broadcaster("purge", node, node_obj.subscribers); + self.config.broadcaster("purge", node, node_obj.subscribers, nil, actor, node_obj, self); end return true end @@ -327,7 +379,11 @@ function service:get_items(node, actor, id) return false, "item-not-found"; end if id then -- Restrict results to a single specific item - return true, { id, [id] = self.data[node]:get(id) }; + local with_id = self.data[node]:get(id); + if not with_id then + return true, { }; + end + return true, { id, [id] = with_id }; else local data = {} for key, value in self.data[node]:items() do @@ -338,6 +394,15 @@ function service:get_items(node, actor, id) end end +function service:get_last_item(node, actor) + -- Access checking + if not self:may(node, actor, "get_items") then + return false, "forbidden"; + end + -- + return true, self.data[node]:tail(); +end + function service:get_nodes(actor) -- Access checking if not self:may(nil, actor, "get_nodes") then @@ -421,14 +486,14 @@ function service:set_node_config(node, actor, new_config) return false, "item-not-found"; end - for k,v in pairs(new_config) do - node_obj.config[k] = v; + if new_config["persist_items"] ~= node_obj.config["persist_items"] then + self.data[node] = self.config.itemstore(self.nodes[node].config, node); + elseif new_config["max_items"] ~= node_obj.config["max_items"] then + self.data[node]:resize(new_config["max_items"]); end - local new_data = self.config.itemstore(self.nodes[node].config); - for key, value in self.data[node]:items() do - new_data:set(key, value); - end - self.data[node] = new_data; + + node_obj.config = setmetatable(new_config, {__index=self.node_defaults}); + return true; end diff --git a/util/random.lua b/util/random.lua index b2d0000d..d8a84514 100644 --- a/util/random.lua +++ b/util/random.lua @@ -11,9 +11,6 @@ if ok then return crand; end local urandom, urandom_err = io.open("/dev/urandom", "r"); -local function seed() -end - local function bytes(n) return urandom:read(n); end @@ -25,7 +22,6 @@ if not urandom then end return { - seed = seed; bytes = bytes; _source = "/dev/urandom"; }; diff --git a/util/sasl.lua b/util/sasl.lua index 5845f34a..50851405 100644 --- a/util/sasl.lua +++ b/util/sasl.lua @@ -20,6 +20,7 @@ local assert = assert; local require = require; local _ENV = nil; +-- luacheck: std none --[[ Authentication Backend Prototypes: @@ -42,7 +43,7 @@ Example: local method = {}; method.__index = method; -local mechanisms = {}; +local registered_mechanisms = {}; local backend_mechanism = {}; local mechanism_channelbindings = {}; @@ -52,7 +53,7 @@ local function registerMechanism(name, backends, f, cb_backends) assert(type(backends) == "string" or type(backends) == "table", "Parameter backends MUST be either a string or a table."); assert(type(f) == "function", "Parameter f MUST be a function."); if cb_backends then assert(type(cb_backends) == "table"); end - mechanisms[name] = f + registered_mechanisms[name] = f if cb_backends then mechanism_channelbindings[name] = {}; for _, cb_name in ipairs(cb_backends) do @@ -70,7 +71,7 @@ local function new(realm, profile) local mechanisms = profile.mechanisms; if not mechanisms then mechanisms = {}; - for backend, f in pairs(profile) do + for backend in pairs(profile) do if backend_mechanism[backend] then for _, mechanism in ipairs(backend_mechanism[backend]) do mechanisms[mechanism] = true; @@ -128,7 +129,7 @@ end -- feed new messages to process into the library function method:process(message) --if message == "" or message == nil then return "failure", "malformed-request" end - return mechanisms[self.selected](self, message); + return registered_mechanisms[self.selected](self, message); end -- load the mechanisms diff --git a/util/sasl/anonymous.lua b/util/sasl/anonymous.lua index 6201db32..de98a5e2 100644 --- a/util/sasl/anonymous.lua +++ b/util/sasl/anonymous.lua @@ -12,9 +12,10 @@ -- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -local generate_uuid = require "util.uuid".generate; +local generate_random_id = require "util.id".medium; local _ENV = nil; +-- luacheck: std none --========================= --SASL ANONYMOUS according to RFC 4505 @@ -28,10 +29,10 @@ anonymous: end ]] -local function anonymous(self, message) +local function anonymous(self, message) -- luacheck: ignore 212/message local username; repeat - username = generate_uuid(); + username = generate_random_id():lower(); until self.profile.anonymous(self, username, self.realm); self.username = username; return "success" diff --git a/util/sasl/digest-md5.lua b/util/sasl/digest-md5.lua index 695dd2a3..7542a037 100644 --- a/util/sasl/digest-md5.lua +++ b/util/sasl/digest-md5.lua @@ -26,6 +26,7 @@ local generate_uuid = require "util.uuid".generate; local nodeprep = require "util.encodings".stringprep.nodeprep; local _ENV = nil; +-- luacheck: std none --========================= --SASL DIGEST-MD5 according to RFC 2831 diff --git a/util/sasl/external.lua b/util/sasl/external.lua index 5ba90190..ce50743e 100644 --- a/util/sasl/external.lua +++ b/util/sasl/external.lua @@ -1,6 +1,7 @@ local saslprep = require "util.encodings".stringprep.saslprep; local _ENV = nil; +-- luacheck: std none local function external(self, message) message = saslprep(message); diff --git a/util/sasl/plain.lua b/util/sasl/plain.lua index cd59b1ac..00c6bd20 100644 --- a/util/sasl/plain.lua +++ b/util/sasl/plain.lua @@ -17,6 +17,7 @@ local nodeprep = require "util.encodings".stringprep.nodeprep; local log = require "util.logger".init("sasl"); local _ENV = nil; +-- luacheck: std none -- ================================ -- SASL PLAIN according to RFC 4616 diff --git a/util/sasl/scram.lua b/util/sasl/scram.lua index 4e20dbb9..043f328b 100644 --- a/util/sasl/scram.lua +++ b/util/sasl/scram.lua @@ -26,6 +26,7 @@ local char = string.char; local byte = string.byte; local _ENV = nil; +-- luacheck: std none --========================= --SASL SCRAM-SHA-1 according to RFC 5802 @@ -46,7 +47,18 @@ Supported Channel Binding Backends local default_i = 4096 -local xor_map = {0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;1;0;3;2;5;4;7;6;9;8;11;10;13;12;15;14;2;3;0;1;6;7;4;5;10;11;8;9;14;15;12;13;3;2;1;0;7;6;5;4;11;10;9;8;15;14;13;12;4;5;6;7;0;1;2;3;12;13;14;15;8;9;10;11;5;4;7;6;1;0;3;2;13;12;15;14;9;8;11;10;6;7;4;5;2;3;0;1;14;15;12;13;10;11;8;9;7;6;5;4;3;2;1;0;15;14;13;12;11;10;9;8;8;9;10;11;12;13;14;15;0;1;2;3;4;5;6;7;9;8;11;10;13;12;15;14;1;0;3;2;5;4;7;6;10;11;8;9;14;15;12;13;2;3;0;1;6;7;4;5;11;10;9;8;15;14;13;12;3;2;1;0;7;6;5;4;12;13;14;15;8;9;10;11;4;5;6;7;0;1;2;3;13;12;15;14;9;8;11;10;5;4;7;6;1;0;3;2;14;15;12;13;10;11;8;9;6;7;4;5;2;3;0;1;15;14;13;12;11;10;9;8;7;6;5;4;3;2;1;0;}; +local xor_map = { + 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,1,0,3,2,5,4,7,6,9,8,11,10, + 13,12,15,14,2,3,0,1,6,7,4,5,10,11,8,9,14,15,12,13,3,2,1,0,7,6,5, + 4,11,10,9,8,15,14,13,12,4,5,6,7,0,1,2,3,12,13,14,15,8,9,10,11,5, + 4,7,6,1,0,3,2,13,12,15,14,9,8,11,10,6,7,4,5,2,3,0,1,14,15,12,13, + 10,11,8,9,7,6,5,4,3,2,1,0,15,14,13,12,11,10,9,8,8,9,10,11,12,13, + 14,15,0,1,2,3,4,5,6,7,9,8,11,10,13,12,15,14,1,0,3,2,5,4,7,6,10, + 11,8,9,14,15,12,13,2,3,0,1,6,7,4,5,11,10,9,8,15,14,13,12,3,2,1, + 0,7,6,5,4,12,13,14,15,8,9,10,11,4,5,6,7,0,1,2,3,13,12,15,14,9,8, + 11,10,5,4,7,6,1,0,3,2,14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1,15, + 14,13,12,11,10,9,8,7,6,5,4,3,2,1,0, +}; local result = {}; local function binaryXOR( a, b ) @@ -148,7 +160,7 @@ local function scram_gen(hash_name, H_f, HMAC_f) end self.username = username; - -- retreive credentials + -- retrieve credentials local stored_key, server_key, salt, iteration_count; if self.profile.plain then local password, status = self.profile.plain(self, username, self.realm) @@ -237,10 +249,14 @@ end local function init(registerMechanism) local function registerSCRAMMechanism(hash_name, hash, hmac_hash) - registerMechanism("SCRAM-"..hash_name, {"plain", "scram_"..(hashprep(hash_name))}, scram_gen(hash_name:lower(), hash, hmac_hash)); + registerMechanism("SCRAM-"..hash_name, + {"plain", "scram_"..(hashprep(hash_name))}, + scram_gen(hash_name:lower(), hash, hmac_hash)); -- register channel binding equivalent - registerMechanism("SCRAM-"..hash_name.."-PLUS", {"plain", "scram_"..(hashprep(hash_name))}, scram_gen(hash_name:lower(), hash, hmac_hash), {"tls-unique"}); + registerMechanism("SCRAM-"..hash_name.."-PLUS", + {"plain", "scram_"..(hashprep(hash_name))}, + scram_gen(hash_name:lower(), hash, hmac_hash), {"tls-unique"}); end registerSCRAMMechanism("SHA-1", sha1, hmac_sha1); diff --git a/util/sasl_cyrus.lua b/util/sasl_cyrus.lua index 4e9a4af5..a6bd0628 100644 --- a/util/sasl_cyrus.lua +++ b/util/sasl_cyrus.lua @@ -61,6 +61,7 @@ local sasl_errstring = { setmetatable(sasl_errstring, { __index = function() return "undefined error!" end }); local _ENV = nil; +-- luacheck: std none local method = {}; method.__index = method; diff --git a/util/serialization.lua b/util/serialization.lua index 206f5fbb..54c8110f 100644 --- a/util/serialization.lua +++ b/util/serialization.lua @@ -21,6 +21,7 @@ local log = require "util.logger".init("serialization"); local envload = require"util.envload".envload; local _ENV = nil; +-- luacheck: std none local indent = function(i) return string_rep("\t", i); diff --git a/util/set.lua b/util/set.lua index c136a522..a4f20138 100644 --- a/util/set.lua +++ b/util/set.lua @@ -11,8 +11,9 @@ local ipairs, pairs, setmetatable, next, tostring = local t_concat = table.concat; local _ENV = nil; +-- luacheck: std none -local set_mt = {}; +local set_mt = { __name = "set" }; function set_mt.__call(set, _, k) return next(set._items, k); end diff --git a/util/sql.lua b/util/sql.lua index d964025e..67a5d683 100644 --- a/util/sql.lua +++ b/util/sql.lua @@ -1,11 +1,10 @@ local setmetatable, getmetatable = setmetatable, getmetatable; -local ipairs, unpack, select = ipairs, table.unpack or unpack, select; --luacheck: ignore 113 -local tonumber, tostring = tonumber, tostring; +local ipairs, unpack, select = ipairs, table.unpack or unpack, select; --luacheck: ignore 113 143 +local tostring = tostring; local type = type; local assert, pcall, xpcall, debug_traceback = assert, pcall, xpcall, debug.traceback; local t_concat = table.concat; -local s_char = string.char; local log = require "util.logger".init("sql"); local DBI = require "DBI"; @@ -15,6 +14,7 @@ DBI.Drivers(); local build_url = require "socket.url".build; local _ENV = nil; +-- luacheck: std none local column_mt = {}; local table_mt = {}; @@ -58,9 +58,6 @@ table_mt.__index = {}; function table_mt.__index:create(engine) return engine:_create_table(self); end -function table_mt:__call(...) - -- TODO -end function column_mt:__tostring() return 'Column{ name="'..self.name..'", type="'..self.type..'" }' end @@ -71,31 +68,6 @@ function index_mt:__tostring() -- return 'Index{ name="'..self.name..'", type="'..self.type..'" }' end -local function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end)); end -local function parse_url(url) - local scheme, secondpart, database = url:match("^([%w%+]+)://([^/]*)/?(.*)"); - assert(scheme, "Invalid URL format"); - local username, password, host, port; - local authpart, hostpart = secondpart:match("([^@]+)@([^@+])"); - if not authpart then hostpart = secondpart; end - if authpart then - username, password = authpart:match("([^:]*):(.*)"); - username = username or authpart; - password = password and urldecode(password); - end - if hostpart then - host, port = hostpart:match("([^:]*):(.*)"); - host = host or hostpart; - port = port and assert(tonumber(port), "Invalid URL format"); - end - return { - scheme = scheme:lower(); - username = username; password = password; - host = host; port = port; - database = #database > 0 and database or nil; - }; -end - local engine = {}; function engine:connect() if self.conn then return true; end @@ -123,7 +95,7 @@ function engine:connect() end return true; end -function engine:onconnect() +function engine:onconnect() -- luacheck: ignore 212/self -- Override from create_engine() end @@ -148,6 +120,7 @@ function engine:execute(sql, ...) prepared[sql] = stmt; end + -- luacheck: ignore 411/success local success, err = stmt:execute(...); if not success then return success, err; end return stmt; @@ -161,14 +134,14 @@ local result_mt = { __index = { local function debugquery(where, sql, ...) local i = 0; local a = {...} sql = sql:gsub("\n?\t+", " "); - log("debug", "[%s] %s", where, sql:gsub("%?", function () + log("debug", "[%s] %s", where, (sql:gsub("%?", function () i = i + 1; local v = a[i]; if type(v) == "string" then v = ("'%s'"):format(v:gsub("'", "''")); end return tostring(v); - end)); + end))); end function engine:execute_query(sql, ...) @@ -335,7 +308,12 @@ function engine:set_encoding() -- to UTF-8 local charset = "utf8"; if driver == "MySQL" then self:transaction(function() - for row in self:select"SELECT \"CHARACTER_SET_NAME\" FROM \"information_schema\".\"CHARACTER_SETS\" WHERE \"CHARACTER_SET_NAME\" LIKE 'utf8%' ORDER BY MAXLEN DESC LIMIT 1;" do + for row in self:select[[ + SELECT "CHARACTER_SET_NAME" + FROM "information_schema"."CHARACTER_SETS" + WHERE "CHARACTER_SET_NAME" LIKE 'utf8%' + ORDER BY MAXLEN DESC LIMIT 1; + ]] do charset = row and row[1] or charset; end end); @@ -379,7 +357,7 @@ local function db2uri(params) }; end -local function create_engine(self, params, onconnect) +local function create_engine(_, params, onconnect) return setmetatable({ url = db2uri(params), params = params, onconnect = onconnect }, engine_mt); end diff --git a/util/sslconfig.lua b/util/sslconfig.lua index 4c4e1d48..5c685f7d 100644 --- a/util/sslconfig.lua +++ b/util/sslconfig.lua @@ -8,6 +8,7 @@ local t_insert = table.insert; local setmetatable = setmetatable; local _ENV = nil; +-- luacheck: std none local handlers = { }; local finalisers = { }; diff --git a/util/stanza.lua b/util/stanza.lua index 2191fa8e..4f91d5e9 100644 --- a/util/stanza.lua +++ b/util/stanza.lua @@ -7,6 +7,7 @@ -- +local error = error; local t_insert = table.insert; local t_remove = table.remove; local t_concat = table.concat; @@ -23,6 +24,8 @@ local s_sub = string.sub; local s_find = string.find; local os = os; +local valid_utf8 = require "util.encodings".utf8.valid; + local do_pretty_printing = not os.getenv("WINDIR"); local getstyle, getstring; if do_pretty_printing then @@ -37,12 +40,52 @@ end local xmlns_stanzas = "urn:ietf:params:xml:ns:xmpp-stanzas"; local _ENV = nil; +-- luacheck: std none -local stanza_mt = { __type = "stanza" }; +local stanza_mt = { __name = "stanza" }; stanza_mt.__index = stanza_mt; -local function new_stanza(name, attr) - local stanza = { name = name, attr = attr or {}, tags = {} }; +local function check_name(name, name_type) + if type(name) ~= "string" then + error("invalid "..name_type.." name: expected string, got "..type(name)); + elseif #name == 0 then + error("invalid "..name_type.." name: empty string"); + elseif s_find(name, "[<>& '\"]") then + error("invalid "..name_type.." name: contains invalid characters"); + elseif not valid_utf8(name) then + error("invalid "..name_type.." name: contains invalid utf8"); + end +end + +local function check_text(text, text_type) + if type(text) ~= "string" then + error("invalid "..text_type.." value: expected string, got "..type(text)); + elseif not valid_utf8(text) then + error("invalid "..text_type.." value: contains invalid utf8"); + end +end + +local function check_attr(attr) + if attr ~= nil then + if type(attr) ~= "table" then + error("invalid attributes, expected table got "..type(attr)); + end + for k, v in pairs(attr) do + check_name(k, "attribute"); + check_text(v, "attribute"); + if type(v) ~= "string" then + error("invalid attribute value for '"..k.."': expected string, got "..type(v)); + elseif not valid_utf8(v) then + error("invalid attribute value for '"..k.."': contains invalid utf8"); + end + end + end +end + +local function new_stanza(name, attr, namespaces) + check_name(name, "tag"); + check_attr(attr); + local stanza = { name = name, attr = attr or {}, namespaces = namespaces, tags = {} }; return setmetatable(stanza, stanza_mt); end @@ -58,8 +101,12 @@ function stanza_mt:body(text, attr) return self:tag("body", attr):text(text); end -function stanza_mt:tag(name, attrs) - local s = new_stanza(name, attrs); +function stanza_mt:text_tag(name, text, attr, namespaces) + return self:tag(name, attr, namespaces):text(text):up(); +end + +function stanza_mt:tag(name, attr, namespaces) + local s = new_stanza(name, attr, namespaces); local last_add = self.last_add; if not last_add then last_add = {}; self.last_add = last_add; end (last_add[#last_add] or self):add_direct_child(s); @@ -68,8 +115,11 @@ function stanza_mt:tag(name, attrs) end function stanza_mt:text(text) - local last_add = self.last_add; - (last_add and last_add[#last_add] or self):add_direct_child(text); + if text ~= nil and text ~= "" then + check_text(text, "text"); + local last_add = self.last_add; + (last_add and last_add[#last_add] or self):add_direct_child(text); + end return self; end @@ -337,7 +387,12 @@ end local function clone(stanza) local attr, tags = {}, {}; for k,v in pairs(stanza.attr) do attr[k] = v; end - local new = { name = stanza.name, attr = attr, tags = tags }; + local old_namespaces, namespaces = stanza.namespaces; + if old_namespaces then + namespaces = {}; + for k,v in pairs(old_namespaces) do namespaces[k] = v; end + end + local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags }; for i=1,#stanza do local child = stanza[i]; if child.name then @@ -362,7 +417,13 @@ local function iq(attr) end local function reply(orig) - return new_stanza(orig.name, orig.attr and { to = orig.attr.from, from = orig.attr.to, id = orig.attr.id, type = ((orig.name == "iq" and "result") or orig.attr.type) }); + return new_stanza(orig.name, + orig.attr and { + to = orig.attr.from, + from = orig.attr.to, + id = orig.attr.id, + type = ((orig.name == "iq" and "result") or orig.attr.type) + }); end local xmpp_stanzas_attr = { xmlns = xmlns_stanzas }; diff --git a/util/startup.lua b/util/startup.lua new file mode 100644 index 00000000..451a0587 --- /dev/null +++ b/util/startup.lua @@ -0,0 +1,550 @@ +-- Ignore the CFG_* variables +-- luacheck: ignore 113/CFG_CONFIGDIR 113/CFG_SOURCEDIR 113/CFG_DATADIR 113/CFG_PLUGINDIR +local startup = {}; + +local prosody = { events = require "util.events".new() }; +local logger = require "util.logger"; +local log = logger.init("startup"); + +local config = require "core.configmanager"; + +local dependencies = require "util.dependencies"; + +local original_logging_config; + +function startup.read_config() + local filenames = {}; + + local filename; + if arg[1] == "--config" and arg[2] then + table.insert(filenames, arg[2]); + if CFG_CONFIGDIR then + table.insert(filenames, CFG_CONFIGDIR.."/"..arg[2]); + end + table.remove(arg, 1); table.remove(arg, 1); + elseif os.getenv("PROSODY_CONFIG") then -- Passed by prosodyctl + table.insert(filenames, os.getenv("PROSODY_CONFIG")); + else + table.insert(filenames, (CFG_CONFIGDIR or ".").."/prosody.cfg.lua"); + end + for _,_filename in ipairs(filenames) do + filename = _filename; + local file = io.open(filename); + if file then + file:close(); + prosody.config_file = filename; + CFG_CONFIGDIR = filename:match("^(.*)[\\/][^\\/]*$"); -- luacheck: ignore 111 + break; + end + end + prosody.config_file = filename + local ok, level, err = config.load(filename); + if not ok then + print("\n"); + print("**************************"); + if level == "parser" then + print("A problem occurred while reading the config file "..filename); + print(""); + local err_line, err_message = tostring(err):match("%[string .-%]:(%d*): (.*)"); + if err:match("chunk has too many syntax levels$") then + print("An Include statement in a config file is including an already-included"); + print("file and causing an infinite loop. An Include statement in a config file is..."); + else + print("Error"..(err_line and (" on line "..err_line) or "")..": "..(err_message or tostring(err))); + end + print(""); + elseif level == "file" then + print("Prosody was unable to find the configuration file."); + print("We looked for: "..filename); + print("A sample config file is included in the Prosody download called prosody.cfg.lua.dist"); + print("Copy or rename it to prosody.cfg.lua and edit as necessary."); + end + print("More help on configuring Prosody can be found at https://prosody.im/doc/configure"); + print("Good luck!"); + print("**************************"); + print(""); + os.exit(1); + end +end + +function startup.check_dependencies() + if not dependencies.check_dependencies() then + os.exit(1); + end +end + +-- luacheck: globals socket server + +function startup.load_libraries() + -- Load socket framework + -- luacheck: ignore 111/server 111/socket + socket = require "socket"; + server = require "net.server" +end + +function startup.init_logging() + -- Initialize logging + local loggingmanager = require "core.loggingmanager" + loggingmanager.reload_logging(); + prosody.events.add_handler("reopen-log-files", function () + loggingmanager.reload_logging(); + prosody.events.fire_event("logging-reloaded"); + end); +end + +function startup.log_dependency_warnings() + dependencies.log_warnings(); +end + +function startup.sanity_check() + for host, host_config in pairs(config.getconfig()) do + if host ~= "*" + and host_config.enabled ~= false + and not host_config.component_module then + return; + end + end + log("error", "No enabled VirtualHost entries found in the config file."); + log("error", "At least one active host is required for Prosody to function. Exiting..."); + os.exit(1); +end + +function startup.sandbox_require() + -- Replace require() with one that doesn't pollute _G, required + -- for neat sandboxing of modules + -- luacheck: ignore 113/getfenv 111/require + local _realG = _G; + local _real_require = require; + local getfenv = getfenv or function (f) + -- FIXME: This is a hack to replace getfenv() in Lua 5.2 + local name, env = debug.getupvalue(debug.getinfo(f or 1).func, 1); + if name == "_ENV" then + return env; + end + end + function require(...) -- luacheck: ignore 121 + local curr_env = getfenv(2); + local curr_env_mt = getmetatable(curr_env); + local _realG_mt = getmetatable(_realG); + if curr_env_mt and curr_env_mt.__index and not curr_env_mt.__newindex and _realG_mt then + local old_newindex, old_index; + old_newindex, _realG_mt.__newindex = _realG_mt.__newindex, curr_env; + old_index, _realG_mt.__index = _realG_mt.__index, function (_G, k) -- luacheck: ignore 212/_G + return rawget(curr_env, k); + end; + local ret = _real_require(...); + _realG_mt.__newindex = old_newindex; + _realG_mt.__index = old_index; + return ret; + end + return _real_require(...); + end +end + +function startup.set_function_metatable() + local mt = {}; + function mt.__index(f, upvalue) + local i, name, value = 0; + repeat + i = i + 1; + name, value = debug.getupvalue(f, i); + until name == upvalue or name == nil; + return value; + end + function mt.__newindex(f, upvalue, value) + local i, name = 0; + repeat + i = i + 1; + name = debug.getupvalue(f, i); + until name == upvalue or name == nil; + if name then + debug.setupvalue(f, i, value); + end + end + function mt.__tostring(f) + local info = debug.getinfo(f); + return ("function(%s:%d)"):format(info.short_src:match("[^\\/]*$"), info.linedefined); + end + debug.setmetatable(function() end, mt); +end + +function startup.detect_platform() + prosody.platform = "unknown"; + if os.getenv("WINDIR") then + prosody.platform = "windows"; + elseif package.config:sub(1,1) == "/" then + prosody.platform = "posix"; + end +end + +function startup.detect_installed() + prosody.installed = nil; + if CFG_SOURCEDIR and (prosody.platform == "windows" or CFG_SOURCEDIR:match("^/")) then + prosody.installed = true; + end +end + +function startup.init_global_state() + -- luacheck: ignore 121 + prosody.bare_sessions = {}; + prosody.full_sessions = {}; + prosody.hosts = {}; + + -- COMPAT: These globals are deprecated + -- luacheck: ignore 111/bare_sessions 111/full_sessions 111/hosts + bare_sessions = prosody.bare_sessions; + full_sessions = prosody.full_sessions; + hosts = prosody.hosts; + + prosody.paths = { source = CFG_SOURCEDIR, config = CFG_CONFIGDIR or ".", + plugins = CFG_PLUGINDIR or "plugins", data = "data" }; + + prosody.arg = _G.arg; + + _G.log = logger.init("general"); + prosody.log = logger.init("general"); + + startup.detect_platform(); + startup.detect_installed(); + _G.prosody = prosody; +end + +function startup.setup_datadir() + prosody.paths.data = config.get("*", "data_path") or CFG_DATADIR or "data"; +end + +function startup.setup_plugindir() + local custom_plugin_paths = config.get("*", "plugin_paths"); + if custom_plugin_paths then + local path_sep = package.config:sub(3,3); + -- path1;path2;path3;defaultpath... + -- luacheck: ignore 111 + CFG_PLUGINDIR = table.concat(custom_plugin_paths, path_sep)..path_sep..(CFG_PLUGINDIR or "plugins"); + prosody.paths.plugins = CFG_PLUGINDIR; + end +end + +function startup.chdir() + if prosody.installed then + -- Change working directory to data path. + require "lfs".chdir(prosody.paths.data); + end +end + +function startup.add_global_prosody_functions() + -- Function to reload the config file + function prosody.reload_config() + log("info", "Reloading configuration file"); + prosody.events.fire_event("reloading-config"); + local ok, level, err = config.load(prosody.config_file); + if not ok then + if level == "parser" then + log("error", "There was an error parsing the configuration file: %s", tostring(err)); + elseif level == "file" then + log("error", "Couldn't read the config file when trying to reload: %s", tostring(err)); + end + else + prosody.events.fire_event("config-reloaded", { + filename = prosody.config_file, + config = config.getconfig(), + }); + end + return ok, (err and tostring(level)..": "..tostring(err)) or nil; + end + + -- Function to reopen logfiles + function prosody.reopen_logfiles() + log("info", "Re-opening log files"); + prosody.events.fire_event("reopen-log-files"); + end + + -- Function to initiate prosody shutdown + function prosody.shutdown(reason, code) + log("info", "Shutting down: %s", reason or "unknown reason"); + prosody.shutdown_reason = reason; + prosody.shutdown_code = code; + prosody.events.fire_event("server-stopping", { + reason = reason; + code = code; + }); + server.setquitting(true); + end +end + +function startup.load_secondary_libraries() + --- Load and initialise core modules + require "util.import" + require "util.xmppstream" + require "core.stanza_router" + require "core.statsmanager" + require "core.hostmanager" + require "core.portmanager" + require "core.modulemanager" + require "core.usermanager" + require "core.rostermanager" + require "core.sessionmanager" + package.loaded['core.componentmanager'] = setmetatable({},{__index=function() + log("warn", "componentmanager is deprecated: %s", debug.traceback():match("\n[^\n]*\n[ \t]*([^\n]*)")); + return function() end + end}); + + require "util.array" + require "util.datetime" + require "util.iterators" + require "util.timer" + require "util.helpers" + + pcall(require, "util.signal") -- Not on Windows + + -- Commented to protect us from + -- the second kind of people + --[[ + pcall(require, "remdebug.engine"); + if remdebug then remdebug.engine.start() end + ]] + + require "util.stanza" + require "util.jid" +end + +function startup.init_http_client() + local http = require "net.http" + local config_ssl = config.get("*", "ssl") or {} + local https_client = config.get("*", "client_https_ssl") + http.default.options.sslctx = require "core.certmanager".create_context("client_https port 0", "client", + { capath = config_ssl.capath, cafile = config_ssl.cafile, verify = "peer", }, https_client); +end + +function startup.init_data_store() + require "core.storagemanager"; +end + +function startup.prepare_to_start() + log("info", "Prosody is using the %s backend for connection handling", server.get_backend()); + -- Signal to modules that we are ready to start + prosody.events.fire_event("server-starting"); + prosody.start_time = os.time(); +end + +function startup.init_global_protection() + -- Catch global accesses + -- luacheck: ignore 212/t + local locked_globals_mt = { + __index = function (t, k) log("warn", "%s", debug.traceback("Attempt to read a non-existent global '"..tostring(k).."'", 2)); end; + __newindex = function (t, k, v) error("Attempt to set a global: "..tostring(k).." = "..tostring(v), 2); end; + }; + + function prosody.unlock_globals() + setmetatable(_G, nil); + end + + function prosody.lock_globals() + setmetatable(_G, locked_globals_mt); + end + + -- And lock now... + prosody.lock_globals(); +end + +function startup.read_version() + -- Try to determine version + local version_file = io.open((CFG_SOURCEDIR or ".").."/prosody.version"); + prosody.version = "unknown"; + if version_file then + prosody.version = version_file:read("*a"):gsub("%s*$", ""); + version_file:close(); + if #prosody.version == 12 and prosody.version:match("^[a-f0-9]+$") then + prosody.version = "hg:"..prosody.version; + end + else + local hg = require"util.mercurial"; + local hgid = hg.check_id(CFG_SOURCEDIR or "."); + if hgid then prosody.version = "hg:" .. hgid; end + end +end + +function startup.log_greeting() + log("info", "Hello and welcome to Prosody version %s", prosody.version); +end + +function startup.notify_started() + prosody.events.fire_event("server-started"); +end + +-- Override logging config (used by prosodyctl) +function startup.force_console_logging() + original_logging_config = config.get("*", "log"); + config.set("*", "log", { { levels = { min = os.getenv("PROSODYCTL_LOG_LEVEL") or "info" }, to = "console" } }); +end + +function startup.switch_user() + -- Switch away from root and into the prosody user -- + -- NOTE: This function is only used by prosodyctl. + -- The prosody process is built with the assumption that + -- it is already started as the appropriate user. + + local want_pposix_version = "0.4.0"; + local have_pposix, pposix = pcall(require, "util.pposix"); + + if have_pposix and pposix then + if pposix._VERSION ~= want_pposix_version then + print(string.format("Unknown version (%s) of binary pposix module, expected %s", + tostring(pposix._VERSION), want_pposix_version)); + os.exit(1); + end + prosody.current_uid = pposix.getuid(); + local arg_root = arg[1] == "--root"; + if arg_root then table.remove(arg, 1); end + if prosody.current_uid == 0 and config.get("*", "run_as_root") ~= true and not arg_root then + -- We haz root! + local desired_user = config.get("*", "prosody_user") or "prosody"; + local desired_group = config.get("*", "prosody_group") or desired_user; + local ok, err = pposix.setgid(desired_group); + if ok then + ok, err = pposix.initgroups(desired_user); + end + if ok then + ok, err = pposix.setuid(desired_user); + if ok then + -- Yay! + prosody.switched_user = true; + end + end + if not prosody.switched_user then + -- Boo! + print("Warning: Couldn't switch to Prosody user/group '"..tostring(desired_user).."'/'"..tostring(desired_group).."': "..tostring(err)); + else + -- Make sure the Prosody user can read the config + local conf, err, errno = io.open(prosody.config_file); + if conf then + conf:close(); + else + print("The config file is not readable by the '"..desired_user.."' user."); + print("Prosody will not be able to read it."); + print("Error was "..err); + os.exit(1); + end + end + end + + -- Set our umask to protect data files + pposix.umask(config.get("*", "umask") or "027"); + pposix.setenv("HOME", prosody.paths.data); + pposix.setenv("PROSODY_CONFIG", prosody.config_file); + else + print("Error: Unable to load pposix module. Check that Prosody is installed correctly.") + print("For more help send the below error to us through https://prosody.im/discuss"); + print(tostring(pposix)) + os.exit(1); + end +end + +function startup.check_unwriteable() + local function test_writeable(filename) + local f, err = io.open(filename, "a"); + if not f then + return false, err; + end + f:close(); + return true; + end + + local unwriteable_files = {}; + if type(original_logging_config) == "string" and original_logging_config:sub(1,1) ~= "*" then + local ok, err = test_writeable(original_logging_config); + if not ok then + table.insert(unwriteable_files, err); + end + elseif type(original_logging_config) == "table" then + for _, rule in ipairs(original_logging_config) do + if rule.filename then + local ok, err = test_writeable(rule.filename); + if not ok then + table.insert(unwriteable_files, err); + end + end + end + end + + if #unwriteable_files > 0 then + print("One of more of the Prosody log files are not"); + print("writeable, please correct the errors and try"); + print("starting prosodyctl again."); + print(""); + for _, err in ipairs(unwriteable_files) do + print(err); + end + print(""); + os.exit(1); + end +end + +function startup.make_host(hostname) + return { + type = "local", + events = prosody.events, + modules = {}, + sessions = {}, + users = require "core.usermanager".new_null_provider(hostname) + }; +end + +function startup.make_dummy_hosts() + -- When running under prosodyctl, we don't want to + -- fully initialize the server, so we populate prosody.hosts + -- with just enough things for most code to work correctly + -- luacheck: ignore 122/hosts + prosody.core_post_stanza = function () end; -- TODO: mod_router! + + for hostname in pairs(config.getconfig()) do + prosody.hosts[hostname] = startup.make_host(hostname); + end +end + +-- prosodyctl only +function startup.prosodyctl() + startup.init_global_state(); + startup.read_config(); + startup.force_console_logging(); + startup.init_logging(); + startup.setup_plugindir(); + startup.setup_datadir(); + startup.chdir(); + startup.read_version(); + startup.switch_user(); + startup.check_dependencies(); + startup.log_dependency_warnings(); + startup.check_unwriteable(); + startup.load_libraries(); + startup.init_http_client(); + startup.make_dummy_hosts(); +end + +function startup.prosody() + -- These actions are in a strict order, as many depend on + -- previous steps to have already been performed + startup.init_global_state(); + startup.read_config(); + startup.init_logging(); + startup.sanity_check(); + startup.sandbox_require(); + startup.set_function_metatable(); + startup.check_dependencies(); + startup.init_logging(); + startup.load_libraries(); + startup.setup_plugindir(); + startup.setup_datadir(); + startup.chdir(); + startup.add_global_prosody_functions(); + startup.read_version(); + startup.log_greeting(); + startup.log_dependency_warnings(); + startup.load_secondary_libraries(); + startup.init_http_client(); + startup.init_data_store(); + startup.init_global_protection(); + startup.prepare_to_start(); + startup.notify_started(); +end + +return startup; diff --git a/util/template.lua b/util/template.lua index 04ebb93d..c11037c5 100644 --- a/util/template.lua +++ b/util/template.lua @@ -4,12 +4,13 @@ local setmetatable = setmetatable; local pairs = pairs; local ipairs = ipairs; local error = error; -local loadstring = loadstring; +local envload = require "util.envload".envload; local debug = debug; local t_remove = table.remove; local parse_xml = require "util.xml".parse; local _ENV = nil; +-- luacheck: std none local function trim_xml(stanza) for i=#stanza,1,-1 do @@ -72,7 +73,7 @@ local function create_cloner(stanza, chunkname) src = src.."local _"..i.."="..lookup[i]..";"; end src = src.."return "..name..";end"; - local f,err = loadstring(src, chunkname); + local f,err = envload(src, chunkname); if not f then error(err); end return f(setmetatable, stanza_mt); end diff --git a/util/termcolours.lua b/util/termcolours.lua index 23c9156b..829d84af 100644 --- a/util/termcolours.lua +++ b/util/termcolours.lua @@ -26,6 +26,7 @@ end local orig_color = windows and windows.get_consolecolor and windows.get_consolecolor(); local _ENV = nil; +-- luacheck: std none local stylemap = { reset = 0; bright = 1, dim = 2, underscore = 4, blink = 5, reverse = 7, hidden = 8; diff --git a/util/throttle.lua b/util/throttle.lua index 1012f78a..d2036e9e 100644 --- a/util/throttle.lua +++ b/util/throttle.lua @@ -3,6 +3,7 @@ local gettime = require "util.time".now local setmetatable = setmetatable; local _ENV = nil; +-- luacheck: std none local throttle = {}; local throttle_mt = { __index = throttle }; diff --git a/util/timer.lua b/util/timer.lua index 7e2e9414..424d44fa 100644 --- a/util/timer.lua +++ b/util/timer.lua @@ -6,78 +6,114 @@ -- COPYING file in the source package for more information. -- +local indexedbheap = require "util.indexedbheap"; +local log = require "util.logger".init("timer"); local server = require "net.server"; -local math_min = math.min -local math_huge = math.huge local get_time = require "util.time".now -local t_insert = table.insert; -local pairs = pairs; +local async = require "util.async"; local type = type; - -local data = {}; -local new_data = {}; +local debug_traceback = debug.traceback; +local tostring = tostring; +local xpcall = xpcall; +local math_max = math.max; local _ENV = nil; +-- luacheck: std none -local _add_task; -if not server.event then - function _add_task(delay, callback) - local current_time = get_time(); - delay = delay + current_time; - if delay >= current_time then - t_insert(new_data, {delay, callback}); - else - local r = callback(current_time); - if r and type(r) == "number" then - return _add_task(r, callback); - end +local _add_task = server.add_task; + +local _server_timer; +local _active_timers = 0; +local h = indexedbheap.create(); +local params = {}; +local next_time = nil; +local _id, _callback, _now, _param; +local function _call() return _callback(_now, _id, _param); end +local function _traceback_handler(err) log("error", "Traceback[timer]: %s", debug_traceback(tostring(err), 2)); end +local function _on_timer(now) + local peek; + while true do + peek = h:peek(); + if peek == nil or peek > now then break; end + local _; + _, _callback, _id = h:pop(); + _now = now; + _param = params[_id]; + params[_id] = nil; + --item(now, id, _param); -- FIXME pcall + local success, err = xpcall(_call, _traceback_handler); + if success and type(err) == "number" then + h:insert(_callback, err + now, _id); -- re-add + params[_id] = _param; end end - server._addtimer(function() - local current_time = get_time(); - if #new_data > 0 then - for _, d in pairs(new_data) do - t_insert(data, d); - end - new_data = {}; - end + if peek ~= nil and _active_timers > 1 and peek == next_time then + -- Another instance of _on_timer already set next_time to the same value, + -- so it should be safe to not renew this timer event + peek = nil; + else + next_time = peek; + end - local next_time = math_huge; - for i, d in pairs(data) do - local t, callback = d[1], d[2]; - if t <= current_time then - data[i] = nil; - local r = callback(current_time); - if type(r) == "number" then - _add_task(r, callback); - next_time = math_min(next_time, r); - end - else - next_time = math_min(next_time, t - current_time); - end - end - return next_time; - end); -else - local event = server.event; - local event_base = server.event_base; - local EVENT_LEAVE = (event.core and event.core.LEAVE) or -1; + if peek then + -- peek is the time of the next event + return peek - now; + end + _active_timers = _active_timers - 1; +end +local function add_task(delay, callback, param) + local current_time = get_time(); + local event_time = current_time + delay; - function _add_task(delay, callback) - local event_handle; - event_handle = event_base:addevent(nil, 0, function () - local ret = callback(get_time()); - if ret then - return 0, ret; - elseif event_handle then - return EVENT_LEAVE; - end + local id = h:insert(callback, event_time); + params[id] = param; + if next_time == nil or event_time < next_time then + next_time = event_time; + if _server_timer then + _server_timer:close(); + _server_timer = nil; + else + _active_timers = _active_timers + 1; end - , delay); + _server_timer = _add_task(next_time - current_time, _on_timer); end + return id; +end +local function stop(id) + params[id] = nil; + local result, item, result_sync = h:remove(id); + local peek = h:peek(); + if peek ~= next_time and _server_timer then + next_time = peek; + _server_timer:close(); + if next_time ~= nil then + _server_timer = _add_task(math_max(next_time - get_time(), 0), _on_timer); + end + end + return result, item, result_sync; +end +local function reschedule(id, delay) + local current_time = get_time(); + local event_time = current_time + delay; + h:reprioritize(id, delay); + if next_time == nil or event_time < next_time then + next_time = event_time; + _add_task(next_time - current_time, _on_timer); + end + return id; +end + +local function sleep(s) + local wait, done = async.waiter(); + add_task(s, done); + wait(); end return { - add_task = _add_task; + add_task = add_task; + stop = stop; + reschedule = reschedule; + sleep = sleep; }; + diff --git a/util/vcard.lua b/util/vcard.lua new file mode 100644 index 00000000..51758c41 --- /dev/null +++ b/util/vcard.lua @@ -0,0 +1,572 @@ +-- Copyright (C) 2011-2014 Kim Alvefur +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +-- TODO +-- Fix folding. + +local st = require "util.stanza"; +local t_insert, t_concat = table.insert, table.concat; +local type = type; +local pairs, ipairs = pairs, ipairs; + +local from_text, to_text, from_xep54, to_xep54; + +local line_sep = "\n"; + +local vCard_dtd; -- See end of file +local vCard4_dtd; + +local function vCard_esc(s) + return s:gsub("[,:;\\]", "\\%1"):gsub("\n","\\n"); +end + +local function vCard_unesc(s) + return s:gsub("\\?[\\nt:;,]", { + ["\\\\"] = "\\", + ["\\n"] = "\n", + ["\\r"] = "\r", + ["\\t"] = "\t", + ["\\:"] = ":", -- FIXME Shouldn't need to espace : in values, just params + ["\\;"] = ";", + ["\\,"] = ",", + [":"] = "\29", + [";"] = "\30", + [","] = "\31", + }); +end + +local function item_to_xep54(item) + local t = st.stanza(item.name, { xmlns = "vcard-temp" }); + + local prop_def = vCard_dtd[item.name]; + if prop_def == "text" then + t:text(item[1]); + elseif type(prop_def) == "table" then + if prop_def.types and item.TYPE then + if type(item.TYPE) == "table" then + for _,v in pairs(prop_def.types) do + for _,typ in pairs(item.TYPE) do + if typ:upper() == v then + t:tag(v):up(); + break; + end + end + end + else + t:tag(item.TYPE:upper()):up(); + end + end + + if prop_def.props then + for _,v in pairs(prop_def.props) do + if item[v] then + t:tag(v):up(); + end + end + end + + if prop_def.value then + t:tag(prop_def.value):text(item[1]):up(); + elseif prop_def.values then + local prop_def_values = prop_def.values; + local repeat_last = prop_def_values.behaviour == "repeat-last" and prop_def_values[#prop_def_values]; + for i=1,#item do + t:tag(prop_def.values[i] or repeat_last):text(item[i]):up(); + end + end + end + + return t; +end + +local function vcard_to_xep54(vCard) + local t = st.stanza("vCard", { xmlns = "vcard-temp" }); + for i=1,#vCard do + t:add_child(item_to_xep54(vCard[i])); + end + return t; +end + +function to_xep54(vCards) + if not vCards[1] or vCards[1].name then + return vcard_to_xep54(vCards) + else + local t = st.stanza("xCard", { xmlns = "vcard-temp" }); + for i=1,#vCards do + t:add_child(vcard_to_xep54(vCards[i])); + end + return t; + end +end + +function from_text(data) + data = data -- unfold and remove empty lines + :gsub("\r\n","\n") + :gsub("\n ", "") + :gsub("\n\n+","\n"); + local vCards = {}; + local current; + for line in data:gmatch("[^\n]+") do + line = vCard_unesc(line); + local name, params, value = line:match("^([-%a]+)(\30?[^\29]*)\29(.*)$"); + value = value:gsub("\29",":"); + if #params > 0 then + local _params = {}; + for k,isval,v in params:gmatch("\30([^=]+)(=?)([^\30]*)") do + k = k:upper(); + local _vt = {}; + for _p in v:gmatch("[^\31]+") do + _vt[#_vt+1]=_p + _vt[_p]=true; + end + if isval == "=" then + _params[k]=_vt; + else + _params[k]=true; + end + end + params = _params; + end + if name == "BEGIN" and value == "VCARD" then + current = {}; + vCards[#vCards+1] = current; + elseif name == "END" and value == "VCARD" then + current = nil; + elseif current and vCard_dtd[name] then + local dtd = vCard_dtd[name]; + local item = { name = name }; + t_insert(current, item); + local up = current; + current = item; + if dtd.types then + for _, t in ipairs(dtd.types) do + t = t:lower(); + if ( params.TYPE and params.TYPE[t] == true) + or params[t] == true then + current.TYPE=t; + end + end + end + if dtd.props then + for _, p in ipairs(dtd.props) do + if params[p] then + if params[p] == true then + current[p]=true; + else + for _, prop in ipairs(params[p]) do + current[p]=prop; + end + end + end + end + end + if dtd == "text" or dtd.value then + t_insert(current, value); + elseif dtd.values then + for p in ("\30"..value):gmatch("\30([^\30]*)") do + t_insert(current, p); + end + end + current = up; + end + end + return vCards; +end + +local function item_to_text(item) + local value = {}; + for i=1,#item do + value[i] = vCard_esc(item[i]); + end + value = t_concat(value, ";"); + + local params = ""; + for k,v in pairs(item) do + if type(k) == "string" and k ~= "name" then + params = params .. (";%s=%s"):format(k, type(v) == "table" and t_concat(v,",") or v); + end + end + + return ("%s%s:%s"):format(item.name, params, value) +end + +local function vcard_to_text(vcard) + local t={}; + t_insert(t, "BEGIN:VCARD") + for i=1,#vcard do + t_insert(t, item_to_text(vcard[i])); + end + t_insert(t, "END:VCARD") + return t_concat(t, line_sep); +end + +function to_text(vCards) + if vCards[1] and vCards[1].name then + return vcard_to_text(vCards) + else + local t = {}; + for i=1,#vCards do + t[i]=vcard_to_text(vCards[i]); + end + return t_concat(t, line_sep); + end +end + +local function from_xep54_item(item) + local prop_name = item.name; + local prop_def = vCard_dtd[prop_name]; + + local prop = { name = prop_name }; + + if prop_def == "text" then + prop[1] = item:get_text(); + elseif type(prop_def) == "table" then + if prop_def.value then --single item + prop[1] = item:get_child_text(prop_def.value) or ""; + elseif prop_def.values then --array + local value_names = prop_def.values; + if value_names.behaviour == "repeat-last" then + for i=1,#item.tags do + t_insert(prop, item.tags[i]:get_text() or ""); + end + else + for i=1,#value_names do + t_insert(prop, item:get_child_text(value_names[i]) or ""); + end + end + elseif prop_def.names then + local names = prop_def.names; + for i=1,#names do + if item:get_child(names[i]) then + prop[1] = names[i]; + break; + end + end + end + + if prop_def.props_verbatim then + for k,v in pairs(prop_def.props_verbatim) do + prop[k] = v; + end + end + + if prop_def.types then + local types = prop_def.types; + prop.TYPE = {}; + for i=1,#types do + if item:get_child(types[i]) then + t_insert(prop.TYPE, types[i]:lower()); + end + end + if #prop.TYPE == 0 then + prop.TYPE = nil; + end + end + + -- A key-value pair, within a key-value pair? + if prop_def.props then + local params = prop_def.props; + for i=1,#params do + local name = params[i] + local data = item:get_child_text(name); + if data then + prop[name] = prop[name] or {}; + t_insert(prop[name], data); + end + end + end + else + return nil + end + + return prop; +end + +local function from_xep54_vCard(vCard) + local tags = vCard.tags; + local t = {}; + for i=1,#tags do + t_insert(t, from_xep54_item(tags[i])); + end + return t +end + +function from_xep54(vCard) + if vCard.attr.xmlns ~= "vcard-temp" then + return nil, "wrong-xmlns"; + end + if vCard.name == "xCard" then -- A collection of vCards + local t = {}; + local vCards = vCard.tags; + for i=1,#vCards do + t[i] = from_xep54_vCard(vCards[i]); + end + return t + elseif vCard.name == "vCard" then -- A single vCard + return from_xep54_vCard(vCard) + end +end + +local vcard4 = { } + +function vcard4:text(node, params, value) -- luacheck: ignore 212/params + self:tag(node:lower()) + -- FIXME params + if type(value) == "string" then + self:tag("text"):text(value):up() + elseif vcard4[node] then + vcard4[node](value); + end + self:up(); +end + +function vcard4.N(value) + for i, k in ipairs(vCard_dtd.N.values) do + value:tag(k):text(value[i]):up(); + end +end + +local xmlns_vcard4 = "urn:ietf:params:xml:ns:vcard-4.0" + +local function item_to_vcard4(item) + local typ = item.name:lower(); + local t = st.stanza(typ, { xmlns = xmlns_vcard4 }); + + local prop_def = vCard4_dtd[typ]; + if prop_def == "text" then + t:tag("text"):text(item[1]):up(); + elseif prop_def == "uri" then + if item.ENCODING and item.ENCODING[1] == 'b' then + t:tag("uri"):text("data:;base64,"):text(item[1]):up(); + else + t:tag("uri"):text(item[1]):up(); + end + elseif type(prop_def) == "table" then + if prop_def.values then + for i, v in ipairs(prop_def.values) do + t:tag(v:lower()):text(item[i] or ""):up(); + end + else + t:tag("unsupported",{xmlns="http://zash.se/protocol/vcardlib"}) + end + else + t:tag("unsupported",{xmlns="http://zash.se/protocol/vcardlib"}) + end + return t; +end + +local function vcard_to_vcard4xml(vCard) + local t = st.stanza("vcard", { xmlns = xmlns_vcard4 }); + for i=1,#vCard do + t:add_child(item_to_vcard4(vCard[i])); + end + return t; +end + +local function vcards_to_vcard4xml(vCards) + if not vCards[1] or vCards[1].name then + return vcard_to_vcard4xml(vCards) + else + local t = st.stanza("vcards", { xmlns = xmlns_vcard4 }); + for i=1,#vCards do + t:add_child(vcard_to_vcard4xml(vCards[i])); + end + return t; + end +end + +-- This was adapted from http://xmpp.org/extensions/xep-0054.html#dtd +vCard_dtd = { + VERSION = "text", --MUST be 3.0, so parsing is redundant + FN = "text", + N = { + values = { + "FAMILY", + "GIVEN", + "MIDDLE", + "PREFIX", + "SUFFIX", + }, + }, + NICKNAME = "text", + PHOTO = { + props_verbatim = { ENCODING = { "b" } }, + props = { "TYPE" }, + value = "BINVAL", --{ "EXTVAL", }, + }, + BDAY = "text", + ADR = { + types = { + "HOME", + "WORK", + "POSTAL", + "PARCEL", + "DOM", + "INTL", + "PREF", + }, + values = { + "POBOX", + "EXTADD", + "STREET", + "LOCALITY", + "REGION", + "PCODE", + "CTRY", + } + }, + LABEL = { + types = { + "HOME", + "WORK", + "POSTAL", + "PARCEL", + "DOM", + "INTL", + "PREF", + }, + value = "LINE", + }, + TEL = { + types = { + "HOME", + "WORK", + "VOICE", + "FAX", + "PAGER", + "MSG", + "CELL", + "VIDEO", + "BBS", + "MODEM", + "ISDN", + "PCS", + "PREF", + }, + value = "NUMBER", + }, + EMAIL = { + types = { + "HOME", + "WORK", + "INTERNET", + "PREF", + "X400", + }, + value = "USERID", + }, + JABBERID = "text", + MAILER = "text", + TZ = "text", + GEO = { + values = { + "LAT", + "LON", + }, + }, + TITLE = "text", + ROLE = "text", + LOGO = "copy of PHOTO", + AGENT = "text", + ORG = { + values = { + behaviour = "repeat-last", + "ORGNAME", + "ORGUNIT", + } + }, + CATEGORIES = { + values = "KEYWORD", + }, + NOTE = "text", + PRODID = "text", + REV = "text", + SORTSTRING = "text", + SOUND = "copy of PHOTO", + UID = "text", + URL = "text", + CLASS = { + names = { -- The item.name is the value if it's one of these. + "PUBLIC", + "PRIVATE", + "CONFIDENTIAL", + }, + }, + KEY = { + props = { "TYPE" }, + value = "CRED", + }, + DESC = "text", +}; +vCard_dtd.LOGO = vCard_dtd.PHOTO; +vCard_dtd.SOUND = vCard_dtd.PHOTO; + +vCard4_dtd = { + source = "uri", + kind = "text", + xml = "text", + fn = "text", + n = { + values = { + "family", + "given", + "middle", + "prefix", + "suffix", + }, + }, + nickname = "text", + photo = "uri", + bday = "date-and-or-time", + anniversary = "date-and-or-time", + gender = "text", + adr = { + values = { + "pobox", + "ext", + "street", + "locality", + "region", + "code", + "country", + } + }, + tel = "text", + email = "text", + impp = "uri", + lang = "language-tag", + tz = "text", + geo = "uri", + title = "text", + role = "text", + logo = "uri", + org = "text", + member = "uri", + related = "uri", + categories = "text", + note = "text", + prodid = "text", + rev = "timestamp", + sound = "uri", + uid = "uri", + clientpidmap = "number, uuid", + url = "uri", + version = "text", + key = "uri", + fburl = "uri", + caladruri = "uri", + caluri = "uri", +}; + +return { + from_text = from_text; + to_text = to_text; + + from_xep54 = from_xep54; + to_xep54 = to_xep54; + + to_vcard4 = vcards_to_vcard4xml; +}; diff --git a/util/watchdog.lua b/util/watchdog.lua index aa8c6486..516e60e4 100644 --- a/util/watchdog.lua +++ b/util/watchdog.lua @@ -3,6 +3,7 @@ local setmetatable = setmetatable; local os_time = os.time; local _ENV = nil; +-- luacheck: std none local watchdog_methods = {}; local watchdog_mt = { __index = watchdog_methods }; diff --git a/util/x509.lua b/util/x509.lua index f228b201..15cc4d3c 100644 --- a/util/x509.lua +++ b/util/x509.lua @@ -25,6 +25,7 @@ local log = require "util.logger".init("x509"); local s_format = string.format; local _ENV = nil; +-- luacheck: std none local oid_commonname = "2.5.4.3"; -- [LDAP] 2.3 local oid_subjectaltname = "2.5.29.17"; -- [PKIX] 4.2.1.6 diff --git a/util/xml.lua b/util/xml.lua index 733d821a..dac3f6fe 100644 --- a/util/xml.lua +++ b/util/xml.lua @@ -1,8 +1,11 @@ local st = require "util.stanza"; local lxp = require "lxp"; +local t_insert = table.insert; +local t_remove = table.remove; local _ENV = nil; +-- luacheck: std none local parse_xml = (function() local ns_prefixes = { @@ -14,6 +17,21 @@ local parse_xml = (function() --luacheck: ignore 212/self local handler = {}; local stanza = st.stanza("root"); + local namespaces = {}; + local prefixes = {}; + function handler:StartNamespaceDecl(prefix, url) + if prefix ~= nil then + t_insert(namespaces, url); + t_insert(prefixes, prefix); + end + end + function handler:EndNamespaceDecl(prefix) + if prefix ~= nil then + -- we depend on each StartNamespaceDecl having a paired EndNamespaceDecl + t_remove(namespaces); + t_remove(prefixes); + end + end function handler:StartElement(tagname, attr) local curr_ns,name = tagname:match(ns_pattern); if name == "" then @@ -34,7 +52,11 @@ local parse_xml = (function() end end end - stanza:tag(name, attr); + local n = {} + for i=1,#namespaces do + n[prefixes[i]] = namespaces[i]; + end + stanza:tag(name, attr, n); end function handler:CharacterData(data) stanza:text(data); diff --git a/util/xmppstream.lua b/util/xmppstream.lua index 7be63285..8c7851a5 100644 --- a/util/xmppstream.lua +++ b/util/xmppstream.lua @@ -25,6 +25,7 @@ local lxp_supports_bytecount = not not lxp.new({}).getcurrentbytecount; local default_stanza_size_limit = 1024*1024*10; -- 10MB local _ENV = nil; +-- luacheck: std none local new_parser = lxp.new; @@ -47,7 +48,10 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress) local cb_streamopened = stream_callbacks.streamopened; local cb_streamclosed = stream_callbacks.streamclosed; - local cb_error = stream_callbacks.error or function(session, e, stanza) error("XML stream error: "..tostring(e)..(stanza and ": "..tostring(stanza) or ""),2); end; + local cb_error = stream_callbacks.error or + function(_, e, stanza) + error("XML stream error: "..tostring(e)..(stanza and ": "..tostring(stanza) or ""),2); + end; local cb_handlestanza = stream_callbacks.handlestanza; cb_handleprogress = cb_handleprogress or dummy_cb; @@ -128,6 +132,9 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress) end if lxp_supports_xmldecl then function xml_handlers:XmlDecl(version, encoding, standalone) + session.xml_version = version; + session.xml_encoding = encoding; + session.xml_standalone = standalone; if lxp_supports_bytecount then cb_handleprogress(self:getcurrentbytecount()); end @@ -214,7 +221,7 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress) stack = {}; end - local function set_session(stream, new_session) + local function set_session(stream, new_session) -- luacheck: ignore 212/stream session = new_session; end @@ -238,7 +245,7 @@ local function new(session, stream_callbacks, stanza_size_limit) local parser = new_parser(handlers, ns_separator, false); local parse = parser.parse; - function session.open_stream(session, from, to) + function session.open_stream(session, from, to) -- luacheck: ignore 432/session local send = session.sends2s or session.send; local attr = { @@ -264,7 +271,7 @@ local function new(session, stream_callbacks, stanza_size_limit) n_outstanding_bytes = 0; meta.reset(); end, - feed = function (self, data) + feed = function (self, data) -- luacheck: ignore 212/self if lxp_supports_bytecount then n_outstanding_bytes = n_outstanding_bytes + #data; end |