diff options
Diffstat (limited to 'util')
50 files changed, 1857 insertions, 721 deletions
diff --git a/util/array.lua b/util/array.lua index 2d58e7fb..d4ab1771 100644 --- a/util/array.lua +++ b/util/array.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -11,8 +11,10 @@ local t_insert, t_sort, t_remove, t_concat local setmetatable = setmetatable; local math_random = math.random; +local math_floor = math.floor; local pairs, ipairs = pairs, ipairs; local tostring = tostring; +local type = type; local array = {}; local array_base = {}; @@ -59,13 +61,13 @@ function array_base.filter(outa, ina, func) write = write + 1; end end - + if inplace and write <= start_length then for i=write,start_length do outa[i] = nil; end end - + return outa; end @@ -84,6 +86,25 @@ function array_base.pluck(outa, ina, key) return outa; end +function array_base.reverse(outa, ina) + local len = #ina; + if ina == outa then + local middle = math_floor(len/2); + len = len + 1; + local o; -- opposite + for i = 1, middle do + o = len - i; + outa[i], outa[o] = outa[o], outa[i]; + end + else + local off = len + 1; + for i = 1, len do + outa[i] = ina[off - i]; + end + end + return outa; +end + --- These methods only mutate the array function array_methods:shuffle(outa, ina) local len = #self; @@ -94,15 +115,6 @@ function array_methods:shuffle(outa, ina) return self; end -function array_methods:reverse() - local len = #self-1; - for i=len,1,-1 do - self:push(self[i]); - self:pop(i); - end - return self; -end - function array_methods:append(array) local len,len2 = #self, #array; for i=1,len2 do @@ -157,7 +169,4 @@ for method, f in pairs(array_base) do end end -_G.array = array; -module("array"); - return array; diff --git a/util/async.lua b/util/async.lua new file mode 100644 index 00000000..968ec804 --- /dev/null +++ b/util/async.lua @@ -0,0 +1,158 @@ +local log = require "util.logger".init("util.async"); + +local function runner_continue(thread) + -- ASSUMPTION: runner is in 'waiting' state (but we don't have the runner to know for sure) + if coroutine.status(thread) ~= "suspended" then -- This should suffice + return false; + end + local ok, state, runner = coroutine.resume(thread); + if not ok then + local level = 0; + while debug.getinfo(thread, level, "") do level = level + 1; end + ok, runner = debug.getlocal(thread, level-1, 1); + local error_handler = runner.watchers.error; + if error_handler then error_handler(runner, debug.traceback(thread, state)); end + elseif state == "ready" then + -- If state is 'ready', it is our responsibility to update runner.state from 'waiting'. + -- We also have to :run(), because the queue might have further items that will not be + -- processed otherwise. FIXME: It's probably best to do this in a nexttick (0 timer). + runner.state = "ready"; + runner:run(); + end + return true; +end + +local function waiter(num) + local thread = coroutine.running(); + if not thread then + error("Not running in an async context, see http://prosody.im/doc/developers/async"); + end + num = num or 1; + local waiting; + return function () + if num == 0 then return; end -- already done + waiting = true; + coroutine.yield("wait"); + end, function () + num = num - 1; + if num == 0 and waiting then + runner_continue(thread); + elseif num < 0 then + error("done() called too many times"); + end + end; +end + +local function guarder() + local guards = {}; + return function (id, func) + local thread = coroutine.running(); + if not thread then + error("Not running in an async context, see http://prosody.im/doc/developers/async"); + end + local guard = guards[id]; + if not guard then + guard = {}; + guards[id] = guard; + log("debug", "New guard!"); + else + table.insert(guard, thread); + log("debug", "Guarded. %d threads waiting.", #guard) + coroutine.yield("wait"); + end + local function exit() + local next_waiting = table.remove(guard, 1); + if next_waiting then + log("debug", "guard: Executing next waiting thread (%d left)", #guard) + runner_continue(next_waiting); + else + log("debug", "Guard off duty.") + guards[id] = nil; + end + end + if func then + func(); + exit(); + return; + end + return exit; + end; +end + +local runner_mt = {}; +runner_mt.__index = runner_mt; + +local function runner_create_thread(func, self) + local thread = coroutine.create(function (self) + while true do + func(coroutine.yield("ready", self)); + end + end); + assert(coroutine.resume(thread, self)); -- Start it up, it will return instantly to wait for the first input + return thread; +end + +local empty_watchers = {}; +local function runner(func, watchers, data) + return setmetatable({ func = func, thread = false, state = "ready", notified_state = "ready", + queue = {}, watchers = watchers or empty_watchers, data = data } + , runner_mt); +end + +function runner_mt:run(input) + if input ~= nil then + table.insert(self.queue, input); + end + if self.state ~= "ready" then + return true, self.state, #self.queue; + end + + local q, thread = self.queue, self.thread; + if not thread or coroutine.status(thread) == "dead" then + thread = runner_create_thread(self.func, self); + self.thread = thread; + end + + local n, state, err = #q, self.state, nil; + self.state = "running"; + while n > 0 and state == "ready" do + local consumed; + for i = 1,n do + local input = q[i]; + local ok, new_state = coroutine.resume(thread, input); + if not ok then + consumed, state, err = i, "ready", debug.traceback(thread, new_state); + self.thread = nil; + break; + elseif new_state == "wait" then + consumed, state = i, "waiting"; + break; + end + end + if not consumed then consumed = n; end + if q[n+1] ~= nil then + n = #q; + end + for i = 1, n do + q[i] = q[consumed+i]; + end + n = #q; + end + self.state = state; + if err or state ~= self.notified_state then + if err then + state = "error" + else + self.notified_state = state; + end + local handler = self.watchers[state]; + if handler then handler(self, err); end + end + return true, state, n; +end + +function runner_mt:enqueue(input) + table.insert(self.queue, input); +end + +return { waiter = waiter, guarder = guarder, runner = runner }; diff --git a/util/caps.lua b/util/caps.lua index a61e7403..cd5ff9c0 100644 --- a/util/caps.lua +++ b/util/caps.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -12,9 +12,9 @@ local sha1 = require "util.hashes".sha1; local t_insert, t_sort, t_concat = table.insert, table.sort, table.concat; local ipairs = ipairs; -module "caps" +local _ENV = nil; -function calculate_hash(disco_info) +local function calculate_hash(disco_info) local identities, features, extensions = {}, {}, {}; for _, tag in ipairs(disco_info) do if tag.name == "identity" then @@ -58,4 +58,6 @@ function calculate_hash(disco_info) return ver, S; end -return _M; +return { + calculate_hash = calculate_hash; +}; diff --git a/util/dataforms.lua b/util/dataforms.lua index ee37157a..05846ab3 100644 --- a/util/dataforms.lua +++ b/util/dataforms.lua @@ -1,26 +1,26 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- local setmetatable = setmetatable; -local pairs, ipairs = pairs, ipairs; +local ipairs = ipairs; local tostring, type, next = tostring, type, next; local t_concat = table.concat; local st = require "util.stanza"; local jid_prep = require "util.jid".prep; -module "dataforms" +local _ENV = nil; local xmlns_forms = 'jabber:x:data'; local form_t = {}; local form_mt = { __index = form_t }; -function new(layout) +local function new(layout) return setmetatable(layout, form_mt); end @@ -32,13 +32,13 @@ function form_t.form(layout, data, formtype) if layout.instructions then form:tag("instructions"):text(layout.instructions):up(); end - for n, field in ipairs(layout) do + for _, field in ipairs(layout) do local field_type = field.type or "text-single"; -- Add field tag form:tag("field", { type = field_type, var = field.name, label = field.label }); local value = (data and data[field.name]) or field.value; - + if value then -- Add value, depending on type if field_type == "hidden" then @@ -102,11 +102,11 @@ function form_t.form(layout, data, formtype) end form:up(); end - + if field.required then form:tag("required"):up(); end - + -- Jump back up to list of fields form:up(); end @@ -145,30 +145,29 @@ function form_t.data(layout, stanza) return data; end -field_readers["text-single"] = - function (field_tag, required) - local data = field_tag:get_child_text("value"); - if data and #data > 0 then - return data - elseif required then - return nil, "Required value missing"; - end +local function simple_text(field_tag, required) + local data = field_tag:get_child_text("value"); + -- XEP-0004 does not say if an empty string is acceptable for a required value + -- so we will follow HTML5 which says that empty string means missing + if required and (data == nil or data == "") then + return nil, "Required value missing"; end + return data; -- Return whatever get_child_text returned, even if empty string +end -field_readers["text-private"] = - field_readers["text-single"]; +field_readers["text-single"] = simple_text; + +field_readers["text-private"] = simple_text; field_readers["jid-single"] = function (field_tag, required) - local raw_data = field_tag:get_child_text("value") + local raw_data, err = simple_text(field_tag, required); + if not raw_data then return raw_data, err; end local data = jid_prep(raw_data); - if data and #data > 0 then - return data - elseif raw_data then + if not data then return nil, "Invalid JID: " .. raw_data; - elseif required then - return nil, "Required value missing"; end + return data; end field_readers["jid-multi"] = @@ -212,8 +211,7 @@ field_readers["text-multi"] = return data, err; end -field_readers["list-single"] = - field_readers["text-single"]; +field_readers["list-single"] = simple_text; local boolean_values = { ["1"] = true, ["true"] = true, @@ -222,15 +220,13 @@ local boolean_values = { field_readers["boolean"] = function (field_tag, required) - local raw_value = field_tag:get_child_text("value"); - local value = boolean_values[raw_value ~= nil and raw_value]; - if value ~= nil then - return value; - elseif raw_value then - return nil, "Invalid boolean representation"; - elseif required then - return nil, "Required value missing"; + local raw_value, err = simple_text(field_tag, required); + if not raw_value then return raw_value, err; end + local value = boolean_values[raw_value]; + if value == nil then + return nil, "Invalid boolean representation:" .. raw_value; end + return value; end field_readers["hidden"] = @@ -238,7 +234,9 @@ field_readers["hidden"] = return field_tag:get_child_text("value"); end -return _M; +return { + new = new; +}; --[=[ diff --git a/util/datamanager.lua b/util/datamanager.lua index a107d95c..4b722851 100644 --- a/util/datamanager.lua +++ b/util/datamanager.lua @@ -43,7 +43,7 @@ pcall(function() fallocate = pposix.fallocate or fallocate; end); -module "datamanager" +local _ENV = nil; ---- utils ----- local encode, decode; @@ -74,7 +74,7 @@ local callbacks = {}; ------- API ------------- -function set_data_path(path) +local function set_data_path(path) log("debug", "Setting data path to: %s", path); data_path = path; end @@ -87,14 +87,14 @@ local function callback(username, host, datastore, data) return username, host, datastore, data; end -function add_callback(func) +local function add_callback(func) if not callbacks[func] then -- Would you really want to set the same callback more than once? callbacks[func] = true; callbacks[#callbacks+1] = func; return true; end end -function remove_callback(func) +local function remove_callback(func) if callbacks[func] then for i, f in ipairs(callbacks) do if f == func then @@ -106,7 +106,7 @@ function remove_callback(func) end end -function getpath(username, host, datastore, ext, create) +local function getpath(username, host, datastore, ext, create) ext = ext or "dat"; host = (host and encode(host)) or "_global"; username = username and encode(username); @@ -119,7 +119,7 @@ function getpath(username, host, datastore, ext, create) end end -function load(username, host, datastore) +local function load(username, host, datastore) local data, ret = envloadfile(getpath(username, host, datastore), {}); if not data then local mode = lfs.attributes(getpath(username, host, datastore), "mode"); @@ -175,7 +175,7 @@ if prosody and prosody.platform ~= "posix" then end end -function store(username, host, datastore, data) +local function store(username, host, datastore, data) if not data then data = {}; end @@ -209,7 +209,7 @@ function store(username, host, datastore, data) return true; end -function list_append(username, host, datastore, data) +local function list_append(username, host, datastore, data) if not data then return; end if callback(username, host, datastore) == false then return true; end -- save the datastore @@ -235,7 +235,7 @@ function list_append(username, host, datastore, data) return true; end -function list_store(username, host, datastore, data) +local function list_store(username, host, datastore, data) if not data then data = {}; end @@ -259,7 +259,7 @@ function list_store(username, host, datastore, data) return true; end -function list_load(username, host, datastore) +local function list_load(username, host, datastore) local items = {}; local data, ret = envloadfile(getpath(username, host, datastore, "list"), {item = function(i) t_insert(items, i); end}); if not data then @@ -287,7 +287,7 @@ local type_map = { list = "list"; } -function users(host, store, typ) +local function users(host, store, typ) typ = type_map[typ or "keyval"]; local store_dir = format("%s/%s/%s", data_path, encode(host), store); @@ -306,7 +306,7 @@ function users(host, store, typ) end, state; end -function stores(username, host, typ) +local function stores(username, host, typ) typ = type_map[typ or "keyval"]; local store_dir = format("%s/%s/", data_path, encode(host)); @@ -346,7 +346,7 @@ local function do_remove(path) return true end -function purge(username, host) +local function purge(username, host) local host_dir = format("%s/%s/", data_path, encode(host)); local ok, iter, state, var = pcall(lfs.dir, host_dir); if not ok then @@ -366,6 +366,19 @@ function purge(username, host) return #errs == 0, t_concat(errs, ", "); end -_M.path_decode = decode; -_M.path_encode = encode; -return _M; +return { + set_data_path = set_data_path; + add_callback = add_callback; + remove_callback = remove_callback; + getpath = getpath; + load = load; + store = store; + list_append = list_append; + list_store = list_store; + list_load = list_load; + users = users; + stores = stores; + purge = purge; + path_decode = decode; + path_encode = encode; +}; diff --git a/util/datetime.lua b/util/datetime.lua index a1f62a48..27f28ace 100644 --- a/util/datetime.lua +++ b/util/datetime.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -15,25 +15,25 @@ local os_difftime = os.difftime; local error = error; local tonumber = tonumber; -module "datetime" +local _ENV = nil; -function date(t) +local function date(t) return os_date("!%Y-%m-%d", t); end -function datetime(t) +local function datetime(t) return os_date("!%Y-%m-%dT%H:%M:%SZ", t); end -function time(t) +local function time(t) return os_date("!%H:%M:%S", t); end -function legacy(t) +local function legacy(t) return os_date("!%Y%m%dT%H:%M:%S", t); end -function parse(s) +local function parse(s) if s then local year, month, day, hour, min, sec, tzd; year, month, day, hour, min, sec, tzd = s:match("^(%d%d%d%d)%-?(%d%d)%-?(%d%d)T(%d%d):(%d%d):(%d%d)%.?%d*([Z+%-]?.*)$"); @@ -54,4 +54,10 @@ function parse(s) end end -return _M; +return { + date = date; + datetime = datetime; + time = time; + legacy = legacy; + parse = parse; +}; diff --git a/util/debug.lua b/util/debug.lua index bff0e347..a78524a9 100644 --- a/util/debug.lua +++ b/util/debug.lua @@ -13,7 +13,7 @@ local termcolours = require "util.termcolours"; local getstring = termcolours.getstring; local styles; do - _ = termcolours.getstyle; + local _ = termcolours.getstyle; styles = { boundary_padding = _("bright"); filename = _("bright", "blue"); @@ -22,20 +22,23 @@ do location = _("yellow"); }; end -module("debugx", package.seeall); -function get_locals_table(level) - level = level + 1; -- Skip this function itself +local function get_locals_table(thread, level) local locals = {}; for local_num = 1, math.huge do - local name, value = debug.getlocal(level, local_num); + local name, value; + if thread then + name, value = debug.getlocal(thread, level, local_num); + else + name, value = debug.getlocal(level+1, local_num); + end if not name then break; end table.insert(locals, { name = name, value = value }); end return locals; end -function get_upvalues_table(func) +local function get_upvalues_table(func) local upvalues = {}; if func then for upvalue_num = 1, math.huge do @@ -47,7 +50,7 @@ function get_upvalues_table(func) return upvalues; end -function string_from_var_table(var_table, max_line_len, indent_str) +local function string_from_var_table(var_table, max_line_len, indent_str) local var_string = {}; local col_pos = 0; max_line_len = max_line_len or math.huge; @@ -83,33 +86,25 @@ function string_from_var_table(var_table, max_line_len, indent_str) end end -function get_traceback_table(thread, start_level) +local function get_traceback_table(thread, start_level) local levels = {}; for level = start_level, math.huge do local info; if thread then - info = debug.getinfo(thread, level+1); + info = debug.getinfo(thread, level); else info = debug.getinfo(level+1); end if not info then break; end - + levels[(level-start_level)+1] = { level = level; info = info; - locals = get_locals_table(level+1); + locals = get_locals_table(thread, level+(thread and 0 or 1)); upvalues = get_upvalues_table(info.func); }; - end - return levels; -end - -function traceback(...) - local ok, ret = pcall(_traceback, ...); - if not ok then - return "Error in error handling: "..ret; end - return ret; + return levels; end local function build_source_boundary_marker(last_source_desc) @@ -117,7 +112,7 @@ local function build_source_boundary_marker(last_source_desc) return getstring(styles.boundary_padding, "v"..padding).." "..getstring(styles.filename, last_source_desc).." "..getstring(styles.boundary_padding, padding..(#last_source_desc%2==0 and "-v" or "v ")); end -function _traceback(thread, message, level) +local function _traceback(thread, message, level) -- Lua manual says: debug.traceback ([thread,] [message [, level]]) -- I fathom this to mean one of: @@ -134,15 +129,15 @@ function _traceback(thread, message, level) return nil; -- debug.traceback() does this end - level = level or 1; + level = level or 0; message = message and (message.."\n") or ""; - - -- +3 counts for this function, and the pcall() and wrapper above us - local levels = get_traceback_table(thread, level+3); - + + -- +3 counts for this function, and the pcall() and wrapper above us, the +1... I don't know. + local levels = get_traceback_table(thread, level+(thread == nil and 4 or 0)); + local last_source_desc; - + local lines = {}; for nlevel, level in ipairs(levels) do local info = level.info; @@ -171,9 +166,11 @@ function _traceback(thread, message, level) nlevel = nlevel-1; table.insert(lines, "\t"..(nlevel==0 and ">" or " ")..getstring(styles.level_num, "("..nlevel..") ")..line); local npadding = (" "):rep(#tostring(nlevel)); - local locals_str = string_from_var_table(level.locals, optimal_line_length, "\t "..npadding); - if locals_str then - table.insert(lines, "\t "..npadding.."Locals: "..locals_str); + if level.locals then + local locals_str = string_from_var_table(level.locals, optimal_line_length, "\t "..npadding); + if locals_str then + table.insert(lines, "\t "..npadding.."Locals: "..locals_str); + end end local upvalues_str = string_from_var_table(level.upvalues, optimal_line_length, "\t "..npadding); if upvalues_str then @@ -186,8 +183,23 @@ function _traceback(thread, message, level) return message.."stack traceback:\n"..table.concat(lines, "\n"); end -function use() +local function traceback(...) + local ok, ret = pcall(_traceback, ...); + if not ok then + return "Error in error handling: "..ret; + end + return ret; +end + +local function use() debug.traceback = traceback; end -return _M; +return { + get_locals_table = get_locals_table; + get_upvalues_table = get_upvalues_table; + string_from_var_table = string_from_var_table; + get_traceback_table = get_traceback_table; + traceback = traceback; + use = use; +}; diff --git a/util/dependencies.lua b/util/dependencies.lua index 4d50cf63..5e3b03d8 100644 --- a/util/dependencies.lua +++ b/util/dependencies.lua @@ -1,21 +1,19 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- -module("dependencies", package.seeall) - -function softreq(...) local ok, lib = pcall(require, ...); if ok then return lib; else return nil, lib; end end +local function softreq(...) local ok, lib = pcall(require, ...); if ok then return lib; else return nil, lib; end end -- Required to be able to find packages installed with luarocks if not softreq "luarocks.loader" then -- LuaRocks 2.x softreq "luarocks.require"; -- LuaRocks <1.x end -function missingdep(name, sources, msg) +local function missingdep(name, sources, msg) print(""); print("**************************"); print("Prosody was unable to find "..tostring(name)); @@ -35,7 +33,7 @@ function missingdep(name, sources, msg) print(""); end --- COMPAT w/pre-0.8 Debian: The Debian config file used to use +-- COMPAT w/pre-0.8 Debian: The Debian config file used to use -- util.ztact, which has been removed from Prosody in 0.8. This -- is to log an error for people who still use it, so they can -- update their configs. @@ -48,19 +46,19 @@ package.preload["util.ztact"] = function () end end; -function check_dependencies() - if _VERSION ~= "Lua 5.1" then +local function check_dependencies() + if _VERSION < "Lua 5.1" then print "***********************************" print("Unsupported Lua version: ".._VERSION); - print("Only Lua 5.1 is supported."); + print("At least Lua 5.1 is required."); print "***********************************" return false; end local fatal; - + local lxp = softreq "lxp" - + if not lxp then missingdep("luaexpat", { ["Debian/Ubuntu"] = "sudo apt-get install liblua5.1-expat0"; @@ -69,9 +67,9 @@ function check_dependencies() }); fatal = true; end - + local socket = softreq "socket" - + if not socket then missingdep("luasocket", { ["Debian/Ubuntu"] = "sudo apt-get install liblua5.1-socket2"; @@ -80,7 +78,7 @@ function check_dependencies() }); fatal = true; end - + local lfs, err = softreq "lfs" if not lfs then missingdep("luafilesystem", { @@ -90,9 +88,9 @@ function check_dependencies() }); fatal = true; end - + local ssl = softreq "ssl" - + if not ssl then missingdep("LuaSec", { ["Debian/Ubuntu"] = "http://prosody.im/download/start#debian_and_ubuntu"; @@ -100,7 +98,7 @@ function check_dependencies() ["Source"] = "http://www.inf.puc-rio.br/~brunoos/luasec/"; }, "SSL/TLS support will not be available"); end - + local encodings, err = softreq "util.encodings" if not encodings then if err:match("not found") then @@ -137,13 +135,18 @@ function check_dependencies() return not fatal; end -function log_warnings() +local function log_warnings() + if _VERSION > "Lua 5.1" then + log("warn", "Support for %s is experimental, please report any issues", _VERSION); + end + local ssl = softreq"ssl"; if ssl then local major, minor, veryminor, patched = ssl._VERSION:match("(%d+)%.(%d+)%.?(%d*)(M?)"); if not major or ((tonumber(major) == 0 and (tonumber(minor) or 0) <= 3 and (tonumber(veryminor) or 0) <= 2) and patched ~= "M") then log("error", "This version of LuaSec contains a known bug that causes disconnects, see http://prosody.im/doc/depends"); end end + local lxp = softreq"lxp"; if lxp then if not pcall(lxp.new, { StartDoctypeDecl = false }) then log("error", "The version of LuaExpat on your system leaves Prosody " @@ -162,4 +165,9 @@ function log_warnings() end end -return _M; +return { + softreq = softreq; + missingdep = missingdep; + check_dependencies = check_dependencies; + log_warnings = log_warnings; +}; diff --git a/util/events.lua b/util/events.lua index 412acccd..825ffb19 100644 --- a/util/events.lua +++ b/util/events.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -9,14 +9,17 @@ local pairs = pairs; local t_insert = table.insert; +local t_remove = table.remove; local t_sort = table.sort; local setmetatable = setmetatable; local next = next; -module "events" +local _ENV = nil; -function new() +local function new() local handlers = {}; + local global_wrappers; + local wrappers = {}; local event_map = {}; local function _rebuild_index(handlers, event) local _handlers = event_map[event]; @@ -50,6 +53,9 @@ function new() end end end; + local function get_handlers(event) + return handlers[event]; + end; local function add_handlers(handlers) for event, handler in pairs(handlers) do add_handler(event, handler); @@ -60,24 +66,91 @@ function new() remove_handler(event, handler); end end; - local function fire_event(event, ...) - local h = handlers[event]; + local function _fire_event(event_name, event_data) + local h = handlers[event_name]; if h then for i=1,#h do - local ret = h[i](...); + local ret = h[i](event_data); if ret ~= nil then return ret; end end end end; + local function fire_event(event_name, event_data) + local w = wrappers[event_name] or global_wrappers; + if w then + local curr_wrapper = #w; + local function c(event_name, event_data) + curr_wrapper = curr_wrapper - 1; + if curr_wrapper == 0 then + if global_wrappers == nil or w == global_wrappers then + return _fire_event(event_name, event_data); + end + w, curr_wrapper = global_wrappers, #global_wrappers; + return w[curr_wrapper](c, event_name, event_data); + else + return w[curr_wrapper](c, event_name, event_data); + end + end + return w[curr_wrapper](c, event_name, event_data); + end + return _fire_event(event_name, event_data); + end + local function add_wrapper(event_name, wrapper) + local w; + if event_name == false then + w = global_wrappers; + if not w then + w = {}; + global_wrappers = w; + end + else + w = wrappers[event_name]; + if not w then + w = {}; + wrappers[event_name] = w; + end + end + w[#w+1] = wrapper; + end + local function remove_wrapper(event_name, wrapper) + local w; + if event_name == false then + w = global_wrappers; + else + w = wrappers[event_name]; + end + if not w then return; end + for i = #w, 1 do + if w[i] == wrapper then + t_remove(w, i); + end + end + if #w == 0 then + if event_name == nil then + global_wrappers = nil; + else + wrappers[event_name] = nil; + end + end + end return { add_handler = add_handler; remove_handler = remove_handler; add_handlers = add_handlers; remove_handlers = remove_handlers; + get_handlers = get_handlers; + wrappers = { + add_handler = add_wrapper; + remove_handler = remove_wrapper; + }; + add_wrapper = add_wrapper; + remove_wrapper = remove_wrapper; fire_event = fire_event; _handlers = handlers; _event_map = event_map; }; end -return _M; +return { + new = new; +}; diff --git a/util/filters.lua b/util/filters.lua index 6290e53b..f405c0bd 100644 --- a/util/filters.lua +++ b/util/filters.lua @@ -1,22 +1,22 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- local t_insert, t_remove = table.insert, table.remove; -module "filters" +local _ENV = nil; local new_filter_hooks = {}; -function initialize(session) +local function initialize(session) if not session.filters then local filters = {}; session.filters = filters; - + function session.filter(type, data) local filter_list = filters[type]; if filter_list then @@ -28,19 +28,19 @@ function initialize(session) return data; end end - + for i=1,#new_filter_hooks do new_filter_hooks[i](session); end - + return session.filter; end -function add_filter(session, type, callback, priority) +local function add_filter(session, type, callback, priority) if not session.filters then initialize(session); end - + local filter_list = session.filters[type]; if not filter_list then filter_list = {}; @@ -48,19 +48,19 @@ function add_filter(session, type, callback, priority) elseif filter_list[callback] then return; -- Filter already added end - + priority = priority or 0; - + local i = 0; repeat i = i + 1; until not filter_list[i] or filter_list[filter_list[i]] < priority; - + t_insert(filter_list, i, callback); filter_list[callback] = priority; end -function remove_filter(session, type, callback) +local function remove_filter(session, type, callback) if not session.filters then return; end local filter_list = session.filters[type]; if filter_list and filter_list[callback] then @@ -74,11 +74,11 @@ function remove_filter(session, type, callback) end end -function add_filter_hook(callback) +local function add_filter_hook(callback) t_insert(new_filter_hooks, callback); end -function remove_filter_hook(callback) +local function remove_filter_hook(callback) for i=1,#new_filter_hooks do if new_filter_hooks[i] == callback then t_remove(new_filter_hooks, i); @@ -86,4 +86,10 @@ function remove_filter_hook(callback) end end -return _M; +return { + initialize = initialize; + add_filter = add_filter; + remove_filter = remove_filter; + add_filter_hook = add_filter_hook; + remove_filter_hook = remove_filter_hook; +}; diff --git a/util/helpers.lua b/util/helpers.lua index 08b86a7c..bf76d258 100644 --- a/util/helpers.lua +++ b/util/helpers.lua @@ -1,28 +1,18 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- local debug = require "util.debug"; -module("helpers", package.seeall); - -- Helper functions for debugging local log = require "util.logger".init("util.debug"); -function log_host_events(host) - return log_events(prosody.hosts[host].events, host); -end - -function revert_log_host_events(host) - return revert_log_events(prosody.hosts[host].events); -end - -function log_events(events, name, logger) +local function log_events(events, name, logger) local f = events.fire_event; if not f then error("Object does not appear to be a util.events object"); @@ -37,11 +27,19 @@ function log_events(events, name, logger) return events; end -function revert_log_events(events) +local function revert_log_events(events) events.fire_event, events[events.fire_event] = events[events.fire_event], nil; -- :)) end -function show_events(events, specific_event) +local function log_host_events(host) + return log_events(prosody.hosts[host].events, host); +end + +local function revert_log_host_events(host) + return revert_log_events(prosody.hosts[host].events); +end + +local function show_events(events, specific_event) local event_handlers = events._handlers; local events_array = {}; local event_handler_arrays = {}; @@ -70,7 +68,7 @@ function show_events(events, specific_event) return table.concat(events_array, "\n"); end -function get_upvalue(f, get_name) +local function get_upvalue(f, get_name) local i, name, value = 0; repeat i = i + 1; @@ -79,4 +77,11 @@ function get_upvalue(f, get_name) return value; end -return _M; +return { + log_host_events = log_host_events; + revert_log_host_events = revert_log_host_events; + log_events = log_events; + revert_log_events = revert_log_events; + show_events = show_events; + get_upvalue = get_upvalue; +}; diff --git a/util/hex.lua b/util/hex.lua new file mode 100644 index 00000000..4cc28d33 --- /dev/null +++ b/util/hex.lua @@ -0,0 +1,26 @@ +local s_char = string.char; +local s_format = string.format; +local s_gsub = string.gsub; +local s_lower = string.lower; + +local char_to_hex = {}; +local hex_to_char = {}; + +do + local char, hex; + for i = 0,255 do + char, hex = s_char(i), s_format("%02x", i); + char_to_hex[char] = hex; + hex_to_char[hex] = char; + end +end + +local function to(s) + return (s_gsub(s, ".", char_to_hex)); +end + +local function from(s) + return (s_gsub(s_lower(s), "%X*(%x%x)%X*", hex_to_char)); +end + +return { to = to, from = from } diff --git a/util/hmac.lua b/util/hmac.lua index 51211c7a..2c4cc6ef 100644 --- a/util/hmac.lua +++ b/util/hmac.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- diff --git a/util/import.lua b/util/import.lua index 81401e8b..174da0ca 100644 --- a/util/import.lua +++ b/util/import.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- diff --git a/util/interpolation.lua b/util/interpolation.lua new file mode 100644 index 00000000..315cc203 --- /dev/null +++ b/util/interpolation.lua @@ -0,0 +1,85 @@ +-- Simple template language +-- +-- The new() function takes a pattern and an escape function and returns +-- a render() function. Both are required. +-- +-- The function render() takes a string template and a table of values. +-- Sequences like {name} in the template string are substituted +-- with values from the table, optionally depending on a modifier +-- symbol. +-- +-- Variants are: +-- {name} is substituted for values["name"] and is escaped using the +-- second argument to new_render(). To disable the escaping, use {name!}. +-- {name.item} can be used to access table items. +-- To renter lists of items: {name# item number {idx} is {item} } +-- Or key-value pairs: {name% t[ {idx} ] = {item} } +-- To show a defaults for missing values {name? sub-template } can be used, +-- which renders a sub-template if values["name"] is false-ish. +-- {name& sub-template } does the opposite, the sub-template is rendered +-- if the selected value is anything but false or nil. + +local type, tostring = type, tostring; +local pairs, ipairs = pairs, ipairs; +local s_sub, s_gsub, s_match = string.sub, string.gsub, string.match; +local t_concat = table.concat; + +local function new_render(pat, escape, funcs) + -- assert(type(pat) == "string", "bad argument #1 to 'new_render' (string expected)"); + -- assert(type(escape) == "function", "bad argument #2 to 'new_render' (function expected)"); + local function render(template, values) + -- assert(type(template) == "string", "bad argument #1 to 'render' (string expected)"); + -- assert(type(values) == "table", "bad argument #2 to 'render' (table expected)"); + return (s_gsub(template, pat, function (block) + block = s_sub(block, 2, -2); + local name, opt, e = s_match(block, "^([%a_][%w_.]*)(%p?)()"); + if not name then return end + local value = values[name]; + if not value and name:find(".", 2, true) then + value = values; + for word in name:gmatch"[^.]+" do + value = value[word]; + if not value then break; end + end + end + if funcs then + while value ~= nil and opt == '|' do + local f; + f, opt, e = s_match(block, "^([%a_][%w_.]*)(%p?)()", e); + f = funcs[f]; + if f then value = f(value); end + end + end + if opt == '#' or opt == '%' then + if type(value) ~= "table" then return ""; end + local iter = opt == '#' and ipairs or pairs; + local out, i, subtpl = {}, 1, s_sub(block, e); + local subvalues = setmetatable({}, { __index = values }); + for idx, item in iter(value) do + subvalues.idx = idx; + subvalues.item = item; + out[i], i = render(subtpl, subvalues), i+1; + end + return t_concat(out); + elseif opt == '&' then + if not value then return ""; end + return render(s_sub(block, e), values); + elseif opt == '?' and not value then + return render(s_sub(block, e), values); + elseif value ~= nil then + if type(value) ~= "string" then + value = tostring(value); + end + if opt ~= '!' then + return escape(value); + end + return value; + end + end)); + end + return render; +end + +return { + new = new_render; +}; diff --git a/util/ip.lua b/util/ip.lua index 856bf034..d0ae07eb 100644 --- a/util/ip.lua +++ b/util/ip.lua @@ -12,7 +12,17 @@ local ip_mt = { __index = function (ip, key) return (ip_methods[key])(ip); end, local hex2bits = { ["0"] = "0000", ["1"] = "0001", ["2"] = "0010", ["3"] = "0011", ["4"] = "0100", ["5"] = "0101", ["6"] = "0110", ["7"] = "0111", ["8"] = "1000", ["9"] = "1001", ["A"] = "1010", ["B"] = "1011", ["C"] = "1100", ["D"] = "1101", ["E"] = "1110", ["F"] = "1111" }; local function new_ip(ipStr, proto) - if proto ~= "IPv4" and proto ~= "IPv6" then + if not proto then + local sep = ipStr:match("^%x+(.)"); + if sep == ":" or (not(sep) and ipStr:sub(1,1) == ":") then + proto = "IPv6" + elseif sep == "." then + proto = "IPv4" + end + if not proto then + return nil, "invalid address"; + end + elseif proto ~= "IPv4" and proto ~= "IPv6" then return nil, "invalid protocol"; end if proto == "IPv6" and ipStr:find('.', 1, true) then @@ -82,7 +92,7 @@ local function v6scope(ip) if ip:match("^[0:]*1$") then return 0x2; -- Link-local unicast: - elseif ip:match("^[Ff][Ee][89ABab]") then + elseif ip:match("^[Ff][Ee][89ABab]") then return 0x2; -- Site-local unicast: elseif ip:match("^[Ff][Ee][CcDdEeFf]") then @@ -192,5 +202,43 @@ function ip_methods:scope() return value; end +function ip_methods:private() + local private = self.scope ~= 0xE; + if not private and self.proto == "IPv4" then + local ip = self.addr; + local fields = {}; + ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end); + if fields[1] == 127 or fields[1] == 10 or (fields[1] == 192 and fields[2] == 168) + or (fields[1] == 172 and (fields[2] >= 16 or fields[2] <= 32)) then + private = true; + end + end + self.private = private; + return private; +end + +local function parse_cidr(cidr) + local bits; + local ip_len = cidr:find("/", 1, true); + if ip_len then + bits = tonumber(cidr:sub(ip_len+1, -1)); + cidr = cidr:sub(1, ip_len-1); + end + return new_ip(cidr), bits; +end + +local function match(ipA, ipB, bits) + local common_bits = commonPrefixLength(ipA, ipB); + if not bits then + return ipA == ipB; + end + if bits and ipB.proto == "IPv4" then + common_bits = common_bits - 96; -- v6 mapped addresses always share these bits + end + return common_bits >= bits; +end + return {new_ip = new_ip, - commonPrefixLength = commonPrefixLength}; + commonPrefixLength = commonPrefixLength, + parse_cidr = parse_cidr, + match=match}; diff --git a/util/iterators.lua b/util/iterators.lua index 1f6aacb8..aa9c3ec0 100644 --- a/util/iterators.lua +++ b/util/iterators.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -10,6 +10,10 @@ local it = {}; +local t_insert = table.insert; +local select, unpack, next = select, unpack, next; +local function pack(...) return { n = select("#", ...), ... }; end + -- Reverse an iterator function it.reverse(f, s, var) local results = {}; @@ -19,9 +23,9 @@ function it.reverse(f, s, var) local ret = { f(s, var) }; var = ret[1]; if var == nil then break; end - table.insert(results, 1, ret); + t_insert(results, 1, ret); end - + -- Then return our reverse one local i,max = 0, #results; return function (results) @@ -52,15 +56,15 @@ end -- Given an iterator, iterate only over unique items function it.unique(f, s, var) local set = {}; - + return function () while true do - local ret = { f(s, var) }; + local ret = pack(f(s, var)); var = ret[1]; if var == nil then break; end if not set[var] then set[var] = true; - return var; + return unpack(ret, 1, ret.n); end end end; @@ -69,14 +73,13 @@ end --[[ Return the number of items an iterator returns ]]-- function it.count(f, s, var) local x = 0; - + while true do - local ret = { f(s, var) }; - var = ret[1]; + var = f(s, var); if var == nil then break; end x = x + 1; end - + return x; end @@ -104,7 +107,7 @@ end function it.tail(n, f, s, var) local results, count = {}, 0; while true do - local ret = { f(s, var) }; + local ret = pack(f(s, var)); var = ret[1]; if var == nil then break; end results[(count%n)+1] = ret; @@ -117,9 +120,24 @@ function it.tail(n, f, s, var) return function () pos = pos + 1; if pos > n then return nil; end - return unpack(results[((count-1+pos)%n)+1]); + local ret = results[((count-1+pos)%n)+1]; + return unpack(ret, 1, ret.n); end - --return reverse(head(n, reverse(f, s, var))); + --return reverse(head(n, reverse(f, s, var))); -- ! +end + +function it.filter(filter, f, s, var) + if type(filter) ~= "function" then + local filter_value = filter; + function filter(x) return x ~= filter_value; end + end + return function (s, var) + local ret; + repeat ret = pack(f(s, var)); + var = ret[1]; + until var == nil or filter(unpack(ret, 1, ret.n)); + return unpack(ret, 1, ret.n); + end, s, var; end local function _ripairs_iter(t, key) if key > 1 then return key-1, t[key-1]; end end @@ -139,7 +157,7 @@ function it.to_array(f, s, var) while true do var = f(s, var); if var == nil then break; end - table.insert(t, var); + t_insert(t, var); end return t; end diff --git a/util/jid.lua b/util/jid.lua index 8e0a784c..696f51d8 100644 --- a/util/jid.lua +++ b/util/jid.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -23,9 +23,9 @@ local escapes = { local unescapes = {}; for k,v in pairs(escapes) do unescapes[v] = k; end -module "jid" +local _ENV = nil; -local function _split(jid) +local function split(jid) if not jid then return; end local node, nodepos = match(jid, "^([^@/]+)@()"); local host, hostpos = match(jid, "^([^@/]+)()", nodepos) @@ -34,18 +34,13 @@ local function _split(jid) if (not host) or ((not resource) and #jid >= hostpos) then return nil, nil, nil; end return node, host, resource; end -split = _split; -function bare(jid) - local node, host = _split(jid); - if node and host then - return node.."@"..host; - end - return host; +local function bare(jid) + return jid and match(jid, "^[^/]+"); end -local function _prepped_split(jid) - local node, host, resource = _split(jid); +local function prepped_split(jid) + local node, host, resource = split(jid); if host then if sub(host, -1, -1) == "." then -- Strip empty root label host = sub(host, 1, -2); @@ -63,39 +58,29 @@ local function _prepped_split(jid) return node, host, resource; end end -prepped_split = _prepped_split; -function prep(jid) - local node, host, resource = _prepped_split(jid); - if host then - if node then - host = node .. "@" .. host; - end - if resource then - host = host .. "/" .. resource; - end - end - return host; -end - -function join(node, host, resource) - if node and host and resource then +local function join(node, host, resource) + if not host then return end + if node and resource then return node.."@"..host.."/"..resource; - elseif node and host then + elseif node then return node.."@"..host; - elseif host and resource then + elseif resource then return host.."/"..resource; - elseif host then - return host; end - return nil; -- Invalid JID + return host; +end + +local function prep(jid) + local node, host, resource = prepped_split(jid); + return join(node, host, resource); end -function compare(jid, acl) +local function compare(jid, acl) -- compare jid to single acl rule -- TODO compare to table of rules? - local jid_node, jid_host, jid_resource = _split(jid); - local acl_node, acl_host, acl_resource = _split(acl); + local jid_node, jid_host, jid_resource = split(jid); + local acl_node, acl_host, acl_resource = split(acl); if ((acl_node ~= nil and acl_node == jid_node) or acl_node == nil) and ((acl_host ~= nil and acl_host == jid_host) or acl_host == nil) and ((acl_resource ~= nil and acl_resource == jid_resource) or acl_resource == nil) then @@ -104,7 +89,16 @@ function compare(jid, acl) return false end -function escape(s) return s and (s:gsub(".", escapes)); end -function unescape(s) return s and (s:gsub("\\%x%x", unescapes)); end +local function escape(s) return s and (s:gsub(".", escapes)); end +local function unescape(s) return s and (s:gsub("\\%x%x", unescapes)); end -return _M; +return { + split = split; + bare = bare; + prepped_split = prepped_split; + join = join; + prep = prep; + compare = compare; + escape = escape; + unescape = unescape; +}; diff --git a/util/json.lua b/util/json.lua index 82ebcc43..becd295d 100644 --- a/util/json.lua +++ b/util/json.lua @@ -13,7 +13,7 @@ local tostring, tonumber = tostring, tonumber; local pairs, ipairs = pairs, ipairs; local next = next; local error = error; -local newproxy, getmetatable, setmetatable = newproxy, getmetatable, setmetatable; +local getmetatable, setmetatable = getmetatable, setmetatable; local print = print; local has_array, array = pcall(require, "util.array"); @@ -22,10 +22,7 @@ local array_mt = has_array and getmetatable(array()) or {}; --module("json") local json = {}; -local null = newproxy and newproxy(true) or {}; -if getmetatable and getmetatable(null) then - getmetatable(null).__tostring = function() return "null"; end; -end +local null = setmetatable({}, { __tostring = function() return "null"; end; }); json.null = null; local escapes = { @@ -348,9 +345,9 @@ local first_escape = { function json.decode(json) json = json:gsub("\\.", first_escape) -- get rid of all escapes except \uXXXX, making string parsing much simpler --:gsub("[\r\n]", "\t"); -- \r\n\t are equivalent, we care about none of them, and none of them can be in strings - + -- TODO do encoding verification - + local val, index = _readvalue(json, 1); if val == nil then return val, index; end if json:find("[^ \t\r\n]", index) then return nil, "garbage at eof"; end diff --git a/util/logger.lua b/util/logger.lua index 26206d4d..3d1f1c8b 100644 --- a/util/logger.lua +++ b/util/logger.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -11,13 +11,13 @@ local pcall = pcall; local find = string.find; local ipairs, pairs, setmetatable = ipairs, pairs, setmetatable; -module "logger" +local _ENV = nil; local level_sinks = {}; local make_logger; -function init(name) +local function init(name) local log_debug = make_logger(name, "debug"); local log_info = make_logger(name, "info"); local log_warn = make_logger(name, "warn"); @@ -52,7 +52,7 @@ function make_logger(source_name, level) return logger; end -function reset() +local function reset() for level, handler_list in pairs(level_sinks) do -- Clear all handlers for this level for i = 1, #handler_list do @@ -61,7 +61,7 @@ function reset() end end -function add_level_sink(level, sink_function) +local function add_level_sink(level, sink_function) if not level_sinks[level] then level_sinks[level] = { sink_function }; else @@ -69,6 +69,10 @@ function add_level_sink(level, sink_function) end end -_M.new = make_logger; - -return _M; +return { + init = init; + make_logger = make_logger; + reset = reset; + add_level_sink = add_level_sink; + new = make_logger; +}; diff --git a/util/mercurial.lua b/util/mercurial.lua new file mode 100644 index 00000000..3f75c4c1 --- /dev/null +++ b/util/mercurial.lua @@ -0,0 +1,34 @@ + +local lfs = require"lfs"; + +local hg = { }; + +function hg.check_id(path) + if lfs.attributes(path, 'mode') ~= "directory" then + return nil, "not a directory"; + end + local hg_dirstate = io.open(path.."/.hg/dirstate"); + local hgid, hgrepo + if hg_dirstate then + hgid = ("%02x%02x%02x%02x%02x%02x"):format(hg_dirstate:read(6):byte(1, 6)); + hg_dirstate:close(); + local hg_changelog = io.open(path.."/.hg/store/00changelog.i"); + if hg_changelog then + hg_changelog:seek("set", 0x20); + hgrepo = ("%02x%02x%02x%02x%02x%02x"):format(hg_changelog:read(6):byte(1, 6)); + hg_changelog:close(); + end + else + local hg_archival,e = io.open(path.."/.hg_archival.txt"); + if hg_archival then + local repo = hg_archival:read("*l"); + local node = hg_archival:read("*l"); + hg_archival:close() + hgid = node and node:match("^node: (%x%x%x%x%x%x%x%x%x%x%x%x)") + hgrepo = repo and repo:match("^repo: (%x%x%x%x%x%x%x%x%x%x%x%x)") + end + end + return hgid, hgrepo; +end + +return hg; diff --git a/util/multitable.lua b/util/multitable.lua index dbf34d28..7a2d2b2a 100644 --- a/util/multitable.lua +++ b/util/multitable.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -10,7 +10,7 @@ local select = select; local t_insert = table.insert; local unpack, pairs, next, type = unpack, pairs, next, type; -module "multitable" +local _ENV = nil; local function get(self, ...) local t = self.data; @@ -126,7 +126,7 @@ local function search_add(self, results, ...) return results; end -function iter(self, ...) +local function iter(self, ...) local query = { ... }; local maxdepth = select("#", ...); local stack = { self.data }; @@ -161,7 +161,7 @@ function iter(self, ...) return it, self; end -function new() +local function new() return { data = {}; get = get; @@ -174,4 +174,7 @@ function new() }; end -return _M; +return { + iter = iter; + new = new; +}; diff --git a/util/paths.lua b/util/paths.lua new file mode 100644 index 00000000..89f4cad9 --- /dev/null +++ b/util/paths.lua @@ -0,0 +1,44 @@ +local t_concat = table.concat; + +local path_sep = package.config:sub(1,1); + +local path_util = {} + +-- Helper function to resolve relative paths (needed by config) +function path_util.resolve_relative_path(parent_path, path) + if path then + -- Some normalization + parent_path = parent_path:gsub("%"..path_sep.."+$", ""); + path = path:gsub("^%.%"..path_sep.."+", ""); + + local is_relative; + if path_sep == "/" and path:sub(1,1) ~= "/" then + is_relative = true; + elseif path_sep == "\\" and (path:sub(1,1) ~= "/" and (path:sub(2,3) ~= ":\\" and path:sub(2,3) ~= ":/")) then + is_relative = true; + end + if is_relative then + return parent_path..path_sep..path; + end + end + return path; +end + +-- Helper function to convert a glob to a Lua pattern +function path_util.glob_to_pattern(glob) + return "^"..glob:gsub("[%p*?]", function (c) + if c == "*" then + return ".*"; + elseif c == "?" then + return "."; + else + return "%"..c; + end + end).."$"; +end + +function path_util.join(...) + return t_concat({...}, path_sep); +end + +return path_util; diff --git a/util/pluginloader.lua b/util/pluginloader.lua index 112c0d52..0d7eafa7 100644 --- a/util/pluginloader.lua +++ b/util/pluginloader.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -17,9 +17,7 @@ end local io_open = io.open; local envload = require "util.envload".envload; -module "pluginloader" - -function load_file(names) +local function load_file(names) local file, err, path; for i=1,#plugin_dir do for j=1,#names do @@ -35,7 +33,7 @@ function load_file(names) return file, err; end -function load_resource(plugin, resource) +local function load_resource(plugin, resource) resource = resource or "mod_"..plugin..".lua"; local names = { @@ -48,7 +46,7 @@ function load_resource(plugin, resource) return load_file(names); end -function load_code(plugin, resource, env) +local function load_code(plugin, resource, env) local content, err = load_resource(plugin, resource); if not content then return content, err; end local path = err; @@ -57,4 +55,8 @@ function load_code(plugin, resource, env) return f, path; end -return _M; +return { + load_file = load_file; + load_resource = load_resource; + load_code = load_code; +}; diff --git a/util/prosodyctl.lua b/util/prosodyctl.lua index c6fe1986..f8f28644 100644 --- a/util/prosodyctl.lua +++ b/util/prosodyctl.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -29,25 +29,19 @@ local CFG_SOURCEDIR = _G.CFG_SOURCEDIR; local _G = _G; local prosody = prosody; -module "prosodyctl" - -- UI helpers -function show_message(msg, ...) - print(msg:format(...)); -end - -function show_warning(msg, ...) +local function show_message(msg, ...) print(msg:format(...)); end -function show_usage(usage, desc) +local function show_usage(usage, desc) print("Usage: ".._G.arg[0].." "..usage); if desc then print(" "..desc); end end -function getchar(n) +local function getchar(n) local stty_ret = os.execute("stty raw -echo 2>/dev/null"); local ok, char; if stty_ret == 0 then @@ -64,14 +58,14 @@ function getchar(n) end end -function getline() +local function getline() local ok, line = pcall(io.read, "*l"); if ok then return line; end end -function getpass() +local function getpass() local stty_ret = os.execute("stty -echo 2>/dev/null"); if stty_ret ~= 0 then io.write("\027[08m"); -- ANSI 'hidden' text attribute @@ -88,7 +82,7 @@ function getpass() end end -function show_yesno(prompt) +local function show_yesno(prompt) io.write(prompt, " "); local choice = getchar():lower(); io.write("\n"); @@ -99,7 +93,7 @@ function show_yesno(prompt) return (choice == "y"); end -function read_password() +local function read_password() local password; while true do io.write("Enter new password: "); @@ -120,7 +114,7 @@ function read_password() return password; end -function show_prompt(prompt) +local function show_prompt(prompt) io.write(prompt, " "); local line = getline(); line = line and line:gsub("\n$",""); @@ -128,7 +122,7 @@ function show_prompt(prompt) end -- Server control -function adduser(params) +local function adduser(params) local user, host, password = nodeprep(params.user), nameprep(params.host), params.password; if not user then return false, "invalid-username"; @@ -146,15 +140,15 @@ function adduser(params) if not(provider) or provider.name == "null" then usermanager.initialize_host(host); end - + local ok, errmsg = usermanager.create_user(user, password, host); if not ok then - return false, errmsg; + return false, errmsg or "creating-user-failed"; end return true; end -function user_exists(params) +local function user_exists(params) local user, host, password = nodeprep(params.user), nameprep(params.host), params.password; storagemanager.initialize_host(host); @@ -162,28 +156,28 @@ function user_exists(params) if not(provider) or provider.name == "null" then usermanager.initialize_host(host); end - + return usermanager.user_exists(user, host); end -function passwd(params) - if not _M.user_exists(params) then +local function passwd(params) + if not user_exists(params) then return false, "no-such-user"; end - - return _M.adduser(params); + + return adduser(params); end -function deluser(params) - if not _M.user_exists(params) then +local function deluser(params) + if not user_exists(params) then return false, "no-such-user"; end local user, host = nodeprep(params.user), nameprep(params.host); - + return usermanager.delete_user(user, host); end -function getpid() +local function getpid() local pidfile = config.get("*", "pidfile"); if not pidfile then return false, "no-pidfile"; @@ -192,35 +186,35 @@ function getpid() if type(pidfile) ~= "string" then return false, "invalid-pidfile"; end - - local modules_enabled = set.new(config.get("*", "modules_enabled")); - if not modules_enabled:contains("posix") then + + local modules_enabled = set.new(config.get("*", "modules_disabled")); + if prosody.platform ~= "posix" or modules_enabled:contains("posix") then return false, "no-posix"; end - + local file, err = io.open(pidfile, "r+"); if not file then return false, "pidfile-read-failed", err; end - + local locked, err = lfs.lock(file, "w"); if locked then file:close(); return false, "pidfile-not-locked"; end - + local pid = tonumber(file:read("*a")); file:close(); - + if not pid then return false, "invalid-pid"; end - + return true, pid; end -function isrunning() - local ok, pid, err = _M.getpid(); +local function isrunning() + local ok, pid, err = getpid(); if not ok then if pid == "pidfile-read-failed" or pid == "pidfile-not-locked" then -- Report as not running, since we can't open the pidfile @@ -232,8 +226,8 @@ function isrunning() return true, signal.kill(pid, 0) == 0; end -function start() - local ok, ret = _M.isrunning(); +local function start() + local ok, ret = isrunning(); if not ok then return ok, ret; end @@ -248,36 +242,55 @@ function start() return true; end -function stop() - local ok, ret = _M.isrunning(); +local function stop() + local ok, ret = isrunning(); if not ok then return ok, ret; end if not ret then return false, "not-running"; end - - local ok, pid = _M.getpid() + + local ok, pid = getpid() if not ok then return false, pid; end - + signal.kill(pid, signal.SIGTERM); return true; end -function reload() - local ok, ret = _M.isrunning(); +local function reload() + local ok, ret = isrunning(); if not ok then return ok, ret; end if not ret then return false, "not-running"; end - - local ok, pid = _M.getpid() + + local ok, pid = getpid() if not ok then return false, pid; end - + signal.kill(pid, signal.SIGHUP); return true; end -return _M; +return { + show_message = show_message; + show_warning = show_message; + show_usage = show_usage; + getchar = getchar; + getline = getline; + getpass = getpass; + show_yesno = show_yesno; + read_password = read_password; + show_prompt = show_prompt; + adduser = adduser; + user_exists = user_exists; + passwd = passwd; + deluser = deluser; + getpid = getpid; + isrunning = isrunning; + start = start; + stop = stop; + reload = reload; +}; diff --git a/util/pubsub.lua b/util/pubsub.lua index e1418c62..6d12690a 100644 --- a/util/pubsub.lua +++ b/util/pubsub.lua @@ -1,23 +1,27 @@ local events = require "util.events"; - -module("pubsub", package.seeall); +local t_remove = table.remove; local service = {}; local service_mt = { __index = service }; -local default_config = { +local default_config = { __index = { broadcaster = function () end; get_affiliation = function () end; capabilities = {}; -}; +} }; +local default_node_config = { __index = { + ["pubsub#max_items"] = "20"; +} }; -function new(config) +local function new(config) config = config or {}; return setmetatable({ - config = setmetatable(config, { __index = default_config }); + config = setmetatable(config, default_config); + node_defaults = setmetatable(config.node_defaults or {}, default_node_config); affiliations = {}; subscriptions = {}; nodes = {}; + data = {}; events = events.new(); }, service_mt); end @@ -29,13 +33,13 @@ end function service:may(node, actor, action) if actor == true then return true; end - + local node_obj = self.nodes[node]; local node_aff = node_obj and node_obj.affiliations[actor]; local service_aff = self.affiliations[actor] or self.config.get_affiliation(actor, node, action) or "none"; - + -- Check if node allows/forbids it local node_capabilities = node_obj and node_obj.capabilities; if node_capabilities then @@ -47,7 +51,7 @@ function service:may(node, actor, action) end end end - + -- Check service-wide capabilities instead local service_capabilities = self.config.capabilities; local caps = service_capabilities[node_aff or service_aff]; @@ -57,7 +61,7 @@ function service:may(node, actor, action) return can; end end - + return false; end @@ -202,7 +206,7 @@ function service:get_subscription(node, actor, jid) return true, node_obj.subscribers[jid]; end -function service:create(node, actor) +function service:create(node, actor, options) -- Access checking if not self:may(node, actor, "create") then return false, "forbidden"; @@ -211,17 +215,20 @@ function service:create(node, actor) if self.nodes[node] then return false, "conflict"; end - + + self.data[node] = {}; self.nodes[node] = { name = node; subscribers = {}; - config = {}; - data = {}; + config = setmetatable(options or {}, {__index=self.node_defaults}); affiliations = {}; }; + setmetatable(self.nodes[node], { __index = { data = self.data[node] } }); -- COMPAT + self.events.fire_event("node-created", { node = node, actor = actor }); local ok, err = self:set_affiliation(node, true, actor, "owner"); if not ok then self.nodes[node] = nil; + self.data[node] = nil; end return ok, err; end @@ -237,10 +244,31 @@ function service:delete(node, actor) return false, "item-not-found"; end self.nodes[node] = nil; + self.data[node] = nil; + self.events.fire_event("node-deleted", { node = node, actor = actor }); self.config.broadcaster("delete", node, node_obj.subscribers); return true; end +local function remove_item_by_id(data, id) + if not data[id] then return end + data[id] = nil; + for i, _id in ipairs(data) do + if id == _id then + t_remove(data, i); + return i; + end + end +end + +local function trim_items(data, max) + max = tonumber(max); + if not max or #data <= max then return end + repeat + data[t_remove(data, 1)] = nil; + until #data <= max +end + function service:publish(node, actor, id, item) -- Access checking if not self:may(node, actor, "publish") then @@ -258,9 +286,13 @@ function service:publish(node, actor, id, item) end node_obj = self.nodes[node]; end - node_obj.data[id] = item; + local node_data = self.data[node]; + remove_item_by_id(node_data, id); + node_data[#node_data + 1] = id; + node_data[id] = item; + trim_items(node_data, node_obj.config["pubsub#max_items"]); self.events.fire_event("item-published", { node = node, actor = actor, id = id, item = item }); - self.config.broadcaster("items", node, node_obj.subscribers, item); + self.config.broadcaster("items", node, node_obj.subscribers, item, actor); return true; end @@ -271,10 +303,11 @@ function service:retract(node, actor, id, retract) end -- local node_obj = self.nodes[node]; - if (not node_obj) or (not node_obj.data[id]) then + if (not node_obj) or (not self.data[node][id]) then return false, "item-not-found"; end - node_obj.data[id] = nil; + self.events.fire_event("item-retracted", { node = node, actor = actor, id = id }); + remove_item_by_id(self.data[node], id); if retract then self.config.broadcaster("items", node, node_obj.subscribers, retract); end @@ -291,7 +324,8 @@ function service:purge(node, actor, notify) if not node_obj then return false, "item-not-found"; end - node_obj.data = {}; -- Purge + self.data[node] = {}; -- Purge + self.events.fire_event("node-purged", { node = node, actor = actor }); if notify then self.config.broadcaster("purge", node, node_obj.subscribers); end @@ -309,9 +343,9 @@ function service:get_items(node, actor, id) return false, "item-not-found"; end if id then -- Restrict results to a single specific item - return true, { [id] = node_obj.data[id] }; + return true, { id, [id] = self.data[node][id] }; else - return true, node_obj.data; + return true, self.data[node]; end end @@ -388,4 +422,24 @@ function service:set_node_capabilities(node, actor, capabilities) return true; end -return _M; +function service:set_node_config(node, actor, new_config) + if not self:may(node, actor, "configure") then + return false, "forbidden"; + end + + local node_obj = self.nodes[node]; + if not node_obj then + return false, "item-not-found"; + end + + for k,v in pairs(new_config) do + node_obj.config[k] = v; + end + trim_items(self.data[node], node_obj.config["pubsub#max_items"]); + + return true; +end + +return { + new = new; +}; diff --git a/util/queue.lua b/util/queue.lua new file mode 100644 index 00000000..203da0e3 --- /dev/null +++ b/util/queue.lua @@ -0,0 +1,59 @@ +-- Prosody IM +-- Copyright (C) 2008-2015 Matthew Wild +-- Copyright (C) 2008-2015 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +-- Small ringbuffer library (i.e. an efficient FIFO queue with a size limit) +-- (because unbounded dynamically-growing queues are a bad thing...) + +local have_utable, utable = pcall(require, "util.table"); -- For pre-allocation of table + +local function new(size, allow_wrapping) + -- Head is next insert, tail is next read + local head, tail = 1, 1; + local items = 0; -- Number of stored items + local t = have_utable and utable.create(size, 0) or {}; -- Table to hold items + + return { + size = size; + count = function (self) return items; end; + push = function (self, item) + if items >= size then + if allow_wrapping then + tail = (tail%size)+1; -- Advance to next oldest item + items = items - 1; + else + return nil, "queue full"; + end + end + t[head] = item; + items = items + 1; + head = (head%size)+1; + return true; + end; + pop = function (self) + if items == 0 then + return nil; + end + local item; + item, t[tail] = t[tail], 0; + tail = (tail%size)+1; + items = items - 1; + return item; + end; + peek = function (self) + if items == 0 then + return nil; + end + return t[tail]; + end; + }; +end + +return { + new = new; +}; + diff --git a/util/random.lua b/util/random.lua new file mode 100644 index 00000000..5938a94f --- /dev/null +++ b/util/random.lua @@ -0,0 +1,43 @@ +-- Prosody IM +-- Copyright (C) 2008-2014 Matthew Wild +-- Copyright (C) 2008-2014 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local tostring = tostring; +local os_time = os.time; +local os_clock = os.clock; +local ceil = math.ceil; +local H = require "util.hashes".sha512; + +local last_uniq_time = 0; +local function uniq_time() + local new_uniq_time = os_time(); + if last_uniq_time >= new_uniq_time then new_uniq_time = last_uniq_time + 1; end + last_uniq_time = new_uniq_time; + return new_uniq_time; +end + +local function new_random(x) + return H(x..os_clock()..tostring({})); +end + +local buffer = new_random(uniq_time()); + +local function seed(x) + buffer = new_random(buffer..x); +end + +local function bytes(n) + if #buffer < n+4 then seed(uniq_time()); end + local r = buffer:sub(1, n); + buffer = buffer:sub(n+1); + return r; +end + +return { + seed = seed; + bytes = bytes; +}; diff --git a/util/sasl.lua b/util/sasl.lua index afb3861b..5845f34a 100644 --- a/util/sasl.lua +++ b/util/sasl.lua @@ -19,7 +19,7 @@ local setmetatable = setmetatable; local assert = assert; local require = require; -module "sasl" +local _ENV = nil; --[[ Authentication Backend Prototypes: @@ -27,19 +27,38 @@ Authentication Backend Prototypes: state = false : disabled state = true : enabled state = nil : non-existant + +Channel Binding: + +To enable support of channel binding in some mechanisms you need to provide appropriate callbacks in a table +at profile.cb. + +Example: + profile.cb["tls-unique"] = function(self) + return self.user + end + ]] local method = {}; method.__index = method; local mechanisms = {}; local backend_mechanism = {}; +local mechanism_channelbindings = {}; -- register a new SASL mechanims -function registerMechanism(name, backends, f) +local function registerMechanism(name, backends, f, cb_backends) assert(type(name) == "string", "Parameter name MUST be a string."); assert(type(backends) == "string" or type(backends) == "table", "Parameter backends MUST be either a string or a table."); assert(type(f) == "function", "Parameter f MUST be a function."); + if cb_backends then assert(type(cb_backends) == "table"); end mechanisms[name] = f + if cb_backends then + mechanism_channelbindings[name] = {}; + for _, cb_name in ipairs(cb_backends) do + mechanism_channelbindings[name][cb_name] = true; + end + end for _, backend_name in ipairs(backends) do if backend_mechanism[backend_name] == nil then backend_mechanism[backend_name] = {}; end t_insert(backend_mechanism[backend_name], name); @@ -47,7 +66,7 @@ function registerMechanism(name, backends, f) end -- create a new SASL object which can be used to authenticate clients -function new(realm, profile) +local function new(realm, profile) local mechanisms = profile.mechanisms; if not mechanisms then mechanisms = {}; @@ -63,6 +82,15 @@ function new(realm, profile) return setmetatable({ profile = profile, realm = realm, mechs = mechanisms }, method); end +-- add a channel binding handler +function method:add_cb_handler(name, f) + if type(self.profile.cb) ~= "table" then + self.profile.cb = {}; + end + self.profile.cb[name] = f; + return self; +end + -- get a fresh clone with the same realm and profile function method:clean_clone() return new(self.realm, self.profile) @@ -70,7 +98,23 @@ end -- get a list of possible SASL mechanims to use function method:mechanisms() - return self.mechs; + local current_mechs = {}; + for mech, _ in pairs(self.mechs) do + if mechanism_channelbindings[mech] then + if self.profile.cb then + local ok = false; + for cb_name, _ in pairs(self.profile.cb) do + if mechanism_channelbindings[mech][cb_name] then + ok = true; + end + end + if ok == true then current_mechs[mech] = true; end + end + else + current_mechs[mech] = true; + end + end + return current_mechs; end -- select a mechanism to use @@ -92,5 +136,9 @@ require "util.sasl.plain" .init(registerMechanism); require "util.sasl.digest-md5".init(registerMechanism); require "util.sasl.anonymous" .init(registerMechanism); require "util.sasl.scram" .init(registerMechanism); +require "util.sasl.external" .init(registerMechanism); -return _M; +return { + registerMechanism = registerMechanism; + new = new; +}; diff --git a/util/sasl/anonymous.lua b/util/sasl/anonymous.lua index ca5fe404..af05c0e7 100644 --- a/util/sasl/anonymous.lua +++ b/util/sasl/anonymous.lua @@ -16,7 +16,7 @@ local s_match = string.match; local log = require "util.logger".init("sasl"); local generate_uuid = require "util.uuid".generate; -module "sasl.anonymous" +local _ENV = nil; --========================= --SASL ANONYMOUS according to RFC 4505 @@ -39,8 +39,10 @@ local function anonymous(self, message) return "success" end -function init(registerMechanism) +local function init(registerMechanism) registerMechanism("ANONYMOUS", {"anonymous"}, anonymous); end -return _M; +return { + init = init; +} diff --git a/util/sasl/digest-md5.lua b/util/sasl/digest-md5.lua index 591d8537..695dd2a3 100644 --- a/util/sasl/digest-md5.lua +++ b/util/sasl/digest-md5.lua @@ -25,7 +25,7 @@ local log = require "util.logger".init("sasl"); local generate_uuid = require "util.uuid".generate; local nodeprep = require "util.encodings".stringprep.nodeprep; -module "sasl.digest-md5" +local _ENV = nil; --========================= --SASL DIGEST-MD5 according to RFC 2831 @@ -241,8 +241,10 @@ local function digest(self, message) end end -function init(registerMechanism) +local function init(registerMechanism) registerMechanism("DIGEST-MD5", {"plain"}, digest); end -return _M; +return { + init = init; +} diff --git a/util/sasl/external.lua b/util/sasl/external.lua new file mode 100644 index 00000000..5ba90190 --- /dev/null +++ b/util/sasl/external.lua @@ -0,0 +1,27 @@ +local saslprep = require "util.encodings".stringprep.saslprep; + +local _ENV = nil; + +local function external(self, message) + message = saslprep(message); + local state + self.username, state = self.profile.external(message); + + if state == false then + return "failure", "account-disabled"; + elseif state == nil then + return "failure", "not-authorized"; + elseif state == "expired" then + return "false", "credentials-expired"; + end + + return "success"; +end + +local function init(registerMechanism) + registerMechanism("EXTERNAL", {"external"}, external); +end + +return { + init = init; +} diff --git a/util/sasl/plain.lua b/util/sasl/plain.lua index c9ec2911..26e65335 100644 --- a/util/sasl/plain.lua +++ b/util/sasl/plain.lua @@ -16,7 +16,7 @@ local saslprep = require "util.encodings".stringprep.saslprep; local nodeprep = require "util.encodings".stringprep.nodeprep; local log = require "util.logger".init("sasl"); -module "sasl.plain" +local _ENV = nil; -- ================================ -- SASL PLAIN according to RFC 4616 @@ -82,8 +82,10 @@ local function plain(self, message) return "success"; end -function init(registerMechanism) +local function init(registerMechanism) registerMechanism("PLAIN", {"plain", "plain_test"}, plain); end -return _M; +return { + init = init; +} diff --git a/util/sasl/scram.lua b/util/sasl/scram.lua index cf2f0ede..a1b5117a 100644 --- a/util/sasl/scram.lua +++ b/util/sasl/scram.lua @@ -13,7 +13,6 @@ local s_match = string.match; local type = type -local string = string local base64 = require "util.encodings".base64; local hmac_sha1 = require "util.hashes".hmac_sha1; local sha1 = require "util.hashes".sha1; @@ -26,7 +25,7 @@ local t_concat = table.concat; local char = string.char; local byte = string.byte; -module "sasl.scram" +local _ENV = nil; --========================= --SASL SCRAM-SHA-1 according to RFC 5802 @@ -39,18 +38,14 @@ scram_{MECH}: function(username, realm) return stored_key, server_key, iteration_count, salt, state; end + +Supported Channel Binding Backends + +'tls-unique' according to RFC 5929 ]] local default_i = 4096 -local function bp( b ) - local result = "" - for i=1, b:len() do - result = result.."\\"..b:byte(i) - end - return result -end - local xor_map = {0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;1;0;3;2;5;4;7;6;9;8;11;10;13;12;15;14;2;3;0;1;6;7;4;5;10;11;8;9;14;15;12;13;3;2;1;0;7;6;5;4;11;10;9;8;15;14;13;12;4;5;6;7;0;1;2;3;12;13;14;15;8;9;10;11;5;4;7;6;1;0;3;2;13;12;15;14;9;8;11;10;6;7;4;5;2;3;0;1;14;15;12;13;10;11;8;9;7;6;5;4;3;2;1;0;15;14;13;12;11;10;9;8;8;9;10;11;12;13;14;15;0;1;2;3;4;5;6;7;9;8;11;10;13;12;15;14;1;0;3;2;5;4;7;6;10;11;8;9;14;15;12;13;2;3;0;1;6;7;4;5;11;10;9;8;15;14;13;12;3;2;1;0;7;6;5;4;12;13;14;15;8;9;10;11;4;5;6;7;0;1;2;3;13;12;15;14;9;8;11;10;5;4;7;6;1;0;3;2;14;15;12;13;10;11;8;9;6;7;4;5;2;3;0;1;15;14;13;12;11;10;9;8;7;6;5;4;3;2;1;0;}; local result = {}; @@ -73,11 +68,11 @@ local function validate_username(username, _nodeprep) return false end end - + -- replace =2C with , and =3D with = username = username:gsub("=2C", ","); username = username:gsub("=3D", "="); - + -- apply SASLprep username = saslprep(username); @@ -92,7 +87,7 @@ local function hashprep(hashname) return hashname:lower():gsub("-", "_"); end -function getAuthenticationDatabaseSHA1(password, salt, iteration_count) +local function getAuthenticationDatabaseSHA1(password, salt, iteration_count) if type(password) ~= "string" or type(salt) ~= "string" or type(iteration_count) ~= "number" then return false, "inappropriate argument types" end @@ -106,96 +101,131 @@ function getAuthenticationDatabaseSHA1(password, salt, iteration_count) end local function scram_gen(hash_name, H_f, HMAC_f) + local profile_name = "scram_" .. hashprep(hash_name); local function scram_hash(self, message) - if not self.state then self["state"] = {} end - + local support_channel_binding = false; + if self.profile.cb then support_channel_binding = true; end + if type(message) ~= "string" or #message == 0 then return "failure", "malformed-request" end - if not self.state.name then + local state = self.state; + if not state then -- we are processing client_first_message local client_first_message = message; - + -- TODO: fail if authzid is provided, since we don't support them yet - self.state["client_first_message"] = client_first_message; - self.state["gs2_cbind_flag"], self.state["authzid"], self.state["name"], self.state["clientnonce"] - = client_first_message:match("^(%a),(.*),n=(.*),r=([^,]*).*"); + local gs2_header, gs2_cbind_flag, gs2_cbind_name, authzid, client_first_message_bare, username, clientnonce + = s_match(client_first_message, "^(([pny])=?([^,]*),([^,]*),)(m?=?[^,]*,?n=([^,]*),r=([^,]*),?.*)$"); - -- we don't do any channel binding yet - if self.state.gs2_cbind_flag ~= "n" and self.state.gs2_cbind_flag ~= "y" then + if not gs2_cbind_flag then return "failure", "malformed-request"; end - if not self.state.name or not self.state.clientnonce then - return "failure", "malformed-request", "Channel binding isn't support at this time."; + if support_channel_binding and gs2_cbind_flag == "y" then + -- "y" -> client does support channel binding + -- but thinks the server does not. + return "failure", "malformed-request"; + end + + if gs2_cbind_flag == "n" then + -- "n" -> client doesn't support channel binding. + support_channel_binding = false; end - - self.state.name = validate_username(self.state.name, self.profile.nodeprep); - if not self.state.name then + + if support_channel_binding and gs2_cbind_flag == "p" then + -- check whether we support the proposed channel binding type + if not self.profile.cb[gs2_cbind_name] then + return "failure", "malformed-request", "Proposed channel binding type isn't supported."; + end + else + -- no channel binding, + gs2_cbind_name = nil; + end + + username = validate_username(username, self.profile.nodeprep); + if not username then log("debug", "Username violates either SASLprep or contains forbidden character sequences.") return "failure", "malformed-request", "Invalid username."; end - - self.state["servernonce"] = generate_uuid(); - + -- retreive credentials + local stored_key, server_key, salt, iteration_count; if self.profile.plain then - local password, state = self.profile.plain(self, self.state.name, self.realm) + local password, state = self.profile.plain(self, username, self.realm) if state == nil then return "failure", "not-authorized" elseif state == false then return "failure", "account-disabled" end - + password = saslprep(password); if not password then log("debug", "Password violates SASLprep."); return "failure", "not-authorized", "Invalid password." end - self.state.salt = generate_uuid(); - self.state.iteration_count = default_i; + salt = generate_uuid(); + iteration_count = default_i; local succ = false; - succ, self.state.stored_key, self.state.server_key = getAuthenticationDatabaseSHA1(password, self.state.salt, default_i, self.state.iteration_count); + succ, stored_key, server_key = getAuthenticationDatabaseSHA1(password, salt, iteration_count); if not succ then - log("error", "Generating authentication database failed. Reason: %s", self.state.stored_key); + log("error", "Generating authentication database failed. Reason: %s", stored_key); return "failure", "temporary-auth-failure"; end - elseif self.profile["scram_"..hashprep(hash_name)] then - local stored_key, server_key, iteration_count, salt, state = self.profile["scram_"..hashprep(hash_name)](self, self.state.name, self.realm); + elseif self.profile[profile_name] then + local state; + stored_key, server_key, iteration_count, salt, state = self.profile[profile_name](self, username, self.realm); if state == nil then return "failure", "not-authorized" elseif state == false then return "failure", "account-disabled" end - - self.state.stored_key = stored_key; - self.state.server_key = server_key; - self.state.iteration_count = iteration_count; - self.state.salt = salt end - - local server_first_message = "r="..self.state.clientnonce..self.state.servernonce..",s="..base64.encode(self.state.salt)..",i="..self.state.iteration_count; - self.state["server_first_message"] = server_first_message; + + local nonce = clientnonce .. generate_uuid(); + local server_first_message = "r="..nonce..",s="..base64.encode(salt)..",i="..iteration_count; + self.state = { + gs2_header = gs2_header; + gs2_cbind_name = gs2_cbind_name; + username = username; + nonce = nonce; + + server_key = server_key; + stored_key = stored_key; + client_first_message_bare = client_first_message_bare; + server_first_message = server_first_message; + } return "challenge", server_first_message else -- we are processing client_final_message local client_final_message = message; - - self.state["channelbinding"], self.state["nonce"], self.state["proof"] = client_final_message:match("^c=(.*),r=(.*),.*p=(.*)"); - - if not self.state.proof or not self.state.nonce or not self.state.channelbinding then + + local client_final_message_without_proof, channelbinding, nonce, proof + = s_match(client_final_message, "(c=([^,]*),r=([^,]*),?.-),p=(.*)$"); + + if not proof or not nonce or not channelbinding then return "failure", "malformed-request", "Missing an attribute(p, r or c) in SASL message."; end - if self.state.nonce ~= self.state.clientnonce..self.state.servernonce then + local client_gs2_header = base64.decode(channelbinding) + local our_client_gs2_header = state["gs2_header"] + if state.gs2_cbind_name then + -- we support channelbinding, so check if the value is valid + our_client_gs2_header = our_client_gs2_header .. self.profile.cb[state.gs2_cbind_name](self); + end + if client_gs2_header ~= our_client_gs2_header then + return "failure", "malformed-request", "Invalid channel binding value."; + end + + if nonce ~= state.nonce then return "failure", "malformed-request", "Wrong nonce in client-final-message."; end - - local ServerKey = self.state.server_key; - local StoredKey = self.state.stored_key; - - local AuthMessage = "n=" .. s_match(self.state.client_first_message,"n=(.+)") .. "," .. self.state.server_first_message .. "," .. s_match(client_final_message, "(.+),p=.+") + + local ServerKey = state.server_key; + local StoredKey = state.stored_key; + + local AuthMessage = state.client_first_message_bare .. "," .. state.server_first_message .. "," .. client_final_message_without_proof local ClientSignature = HMAC_f(StoredKey, AuthMessage) - local ClientKey = binaryXOR(ClientSignature, base64.decode(self.state.proof)) + local ClientKey = binaryXOR(ClientSignature, base64.decode(proof)) local ServerSignature = HMAC_f(ServerKey, AuthMessage) if StoredKey == H_f(ClientKey) then local server_final_message = "v="..base64.encode(ServerSignature); - self["username"] = self.state.name; + self["username"] = state.username; return "success", server_final_message; else return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated."; @@ -205,12 +235,18 @@ local function scram_gen(hash_name, H_f, HMAC_f) return scram_hash; end -function init(registerMechanism) +local function init(registerMechanism) local function registerSCRAMMechanism(hash_name, hash, hmac_hash) registerMechanism("SCRAM-"..hash_name, {"plain", "scram_"..(hashprep(hash_name))}, scram_gen(hash_name:lower(), hash, hmac_hash)); + + -- register channel binding equivalent + registerMechanism("SCRAM-"..hash_name.."-PLUS", {"plain", "scram_"..(hashprep(hash_name))}, scram_gen(hash_name:lower(), hash, hmac_hash), {"tls-unique"}); end registerSCRAMMechanism("SHA-1", sha1, hmac_sha1); end -return _M; +return { + getAuthenticationDatabaseSHA1 = getAuthenticationDatabaseSHA1; + init = init; +} diff --git a/util/sasl_cyrus.lua b/util/sasl_cyrus.lua index 19684587..4e9a4af5 100644 --- a/util/sasl_cyrus.lua +++ b/util/sasl_cyrus.lua @@ -60,7 +60,7 @@ local sasl_errstring = { }; setmetatable(sasl_errstring, { __index = function() return "undefined error!" end }); -module "sasl_cyrus" +local _ENV = nil; local method = {}; method.__index = method; @@ -78,11 +78,11 @@ local function init(service_name) end -- create a new SASL object which can be used to authenticate clients --- host_fqdn may be nil in which case gethostname() gives the value. +-- host_fqdn may be nil in which case gethostname() gives the value. -- For GSSAPI, this determines the hostname in the service ticket (after -- reverse DNS canonicalization, only if [libdefaults] rdns = true which --- is the default). -function new(realm, service_name, app_name, host_fqdn) +-- is the default). +local function new(realm, service_name, app_name, host_fqdn) init(app_name or service_name); @@ -163,4 +163,6 @@ function method:process(message) end end -return _M; +return { + new = new; +}; diff --git a/util/serialization.lua b/util/serialization.lua index 8a259184..206f5fbb 100644 --- a/util/serialization.lua +++ b/util/serialization.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -11,18 +11,16 @@ local type = type; local tostring = tostring; local t_insert = table.insert; local t_concat = table.concat; -local error = error; local pairs = pairs; local next = next; -local loadstring = loadstring; local pcall = pcall; local debug_traceback = debug.traceback; local log = require "util.logger".init("serialization"); local envload = require"util.envload".envload; -module "serialization" +local _ENV = nil; local indent = function(i) return string_rep("\t", i); @@ -73,16 +71,16 @@ local function _simplesave(o, ind, t, func) end end -function append(t, o) +local function append(t, o) _simplesave(o, 1, t, t.write or t_insert); return t; end -function serialize(o) +local function serialize(o) return t_concat(append({}, o)); end -function deserialize(str) +local function deserialize(str) if type(str) ~= "string" then return nil; end str = "return "..str; local f, err = envload(str, "@data", {}); @@ -92,4 +90,8 @@ function deserialize(str) return ret; end -return _M; +return { + append = append; + serialize = serialize; + deserialize = deserialize; +}; diff --git a/util/set.lua b/util/set.lua index fa065a9c..c136a522 100644 --- a/util/set.lua +++ b/util/set.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -10,86 +10,49 @@ local ipairs, pairs, setmetatable, next, tostring = ipairs, pairs, setmetatable, next, tostring; local t_concat = table.concat; -module "set" +local _ENV = nil; local set_mt = {}; function set_mt.__call(set, _, k) return next(set._items, k); end -function set_mt.__add(set1, set2) - return _M.union(set1, set2); -end -function set_mt.__sub(set1, set2) - return _M.difference(set1, set2); -end -function set_mt.__div(set, func) - local new_set = _M.new(); - local items, new_items = set._items, new_set._items; - for item in pairs(items) do - local new_item = func(item); - if new_item ~= nil then - new_items[new_item] = true; - end - end - return new_set; -end -function set_mt.__eq(set1, set2) - local set1, set2 = set1._items, set2._items; - for item in pairs(set1) do - if not set2[item] then - return false; - end - end - - for item in pairs(set2) do - if not set1[item] then - return false; - end - end - - return true; -end -function set_mt.__tostring(set) - local s, items = { }, set._items; - for item in pairs(items) do - s[#s+1] = tostring(item); - end - return t_concat(s, ", "); -end local items_mt = {}; function items_mt.__call(items, _, k) return next(items, k); end -function new(list) +local function new(list) local items = setmetatable({}, items_mt); local set = { _items = items }; - + + -- We access the set through an upvalue in these methods, so ignore 'self' being unused + --luacheck: ignore 212/self + function set:add(item) items[item] = true; end - + function set:contains(item) return items[item]; end - + function set:items() - return items; + return next, items; end - + function set:remove(item) items[item] = nil; end - - function set:add_list(list) - if list then - for _, item in ipairs(list) do + + function set:add_list(item_list) + if item_list then + for _, item in ipairs(item_list) do items[item] = true; end end end - + function set:include(otherset) for item in otherset do items[item] = true; @@ -101,22 +64,22 @@ function new(list) items[item] = nil; end end - + function set:empty() return not next(items); end - + if list then set:add_list(list); end - + return setmetatable(set, set_mt); end -function union(set1, set2) +local function union(set1, set2) local set = new(); local items = set._items; - + for item in pairs(set1._items) do items[item] = true; end @@ -124,14 +87,14 @@ function union(set1, set2) for item in pairs(set2._items) do items[item] = true; end - + return set; end -function difference(set1, set2) +local function difference(set1, set2) local set = new(); local items = set._items; - + for item in pairs(set1._items) do items[item] = (not set2._items[item]) or nil; end @@ -139,21 +102,68 @@ function difference(set1, set2) return set; end -function intersection(set1, set2) +local function intersection(set1, set2) local set = new(); local items = set._items; - + set1, set2 = set1._items, set2._items; - + for item in pairs(set1) do items[item] = (not not set2[item]) or nil; end - + return set; end -function xor(set1, set2) +local function xor(set1, set2) return union(set1, set2) - intersection(set1, set2); end -return _M; +function set_mt.__add(set1, set2) + return union(set1, set2); +end +function set_mt.__sub(set1, set2) + return difference(set1, set2); +end +function set_mt.__div(set, func) + local new_set = new(); + local items, new_items = set._items, new_set._items; + for item in pairs(items) do + local new_item = func(item); + if new_item ~= nil then + new_items[new_item] = true; + end + end + return new_set; +end +function set_mt.__eq(set1, set2) + set1, set2 = set1._items, set2._items; + for item in pairs(set1) do + if not set2[item] then + return false; + end + end + + for item in pairs(set2) do + if not set1[item] then + return false; + end + end + + return true; +end +function set_mt.__tostring(set) + local s, items = { }, set._items; + for item in pairs(items) do + s[#s+1] = tostring(item); + end + return t_concat(s, ", "); +end + +return { + new = new; + union = union; + difference = difference; + intersection = intersection; + xor = xor; +}; diff --git a/util/sql.lua b/util/sql.lua index f360d6d0..2d5e1774 100644 --- a/util/sql.lua +++ b/util/sql.lua @@ -13,7 +13,7 @@ local DBI = require "DBI"; DBI.Drivers(); local build_url = require "socket.url".build; -module("sql") +local _ENV = nil; local column_mt = {}; local table_mt = {}; @@ -21,42 +21,17 @@ local query_mt = {}; --local op_mt = {}; local index_mt = {}; -function is_column(x) return getmetatable(x)==column_mt; end -function is_index(x) return getmetatable(x)==index_mt; end -function is_table(x) return getmetatable(x)==table_mt; end -function is_query(x) return getmetatable(x)==query_mt; end ---function is_op(x) return getmetatable(x)==op_mt; end ---function expr(...) return setmetatable({...}, op_mt); end -function Integer(n) return "Integer()" end -function String(n) return "String()" end +local function is_column(x) return getmetatable(x)==column_mt; end +local function is_index(x) return getmetatable(x)==index_mt; end +local function is_table(x) return getmetatable(x)==table_mt; end +local function is_query(x) return getmetatable(x)==query_mt; end +local function Integer(n) return "Integer()" end +local function String(n) return "String()" end ---[[local ops = { - __add = function(a, b) return "("..a.."+"..b..")" end; - __sub = function(a, b) return "("..a.."-"..b..")" end; - __mul = function(a, b) return "("..a.."*"..b..")" end; - __div = function(a, b) return "("..a.."/"..b..")" end; - __mod = function(a, b) return "("..a.."%"..b..")" end; - __pow = function(a, b) return "POW("..a..","..b..")" end; - __unm = function(a) return "NOT("..a..")" end; - __len = function(a) return "COUNT("..a..")" end; - __eq = function(a, b) return "("..a.."=="..b..")" end; - __lt = function(a, b) return "("..a.."<"..b..")" end; - __le = function(a, b) return "("..a.."<="..b..")" end; -}; - -local functions = { - -}; - -local cmap = { - [Integer] = Integer(); - [String] = String(); -};]] - -function Column(definition) +local function Column(definition) return setmetatable(definition, column_mt); end -function Table(definition) +local function Table(definition) local c = {} for i,col in ipairs(definition) do if is_column(col) then @@ -67,7 +42,7 @@ function Table(definition) end return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt); end -function Index(definition) +local function Index(definition) return setmetatable(definition, index_mt); end @@ -94,7 +69,6 @@ function index_mt:__tostring() return s..' }'; -- return 'Index{ name="'..self.name..'", type="'..self.type..'" }' end --- local function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end)); end local function parse_url(url) @@ -121,32 +95,13 @@ local function parse_url(url) }; end ---[[local session = {}; - -function session.query(...) - local rets = {...}; - local query = setmetatable({ __rets = rets, __filters }, query_mt); - return query; -end --- - -local function db2uri(params) - return build_url{ - scheme = params.driver, - user = params.username, - password = params.password, - host = params.host, - port = params.port, - path = params.database, - }; -end]] - local engine = {}; function engine:connect() if self.conn then return true; end local params = self.params; assert(params.driver, "no driver") + log("debug", "Connecting to [%s] %s...", params.driver, params.database); local dbh, err = DBI.Connect( params.driver, params.database, params.username, params.password, @@ -156,8 +111,19 @@ function engine:connect() dbh:autocommit(false); -- don't commit automatically self.conn = dbh; self.prepared = {}; + local ok, err = self:set_encoding(); + if not ok then + return ok, err; + end + local ok, err = self:onconnect(); + if ok == false then + return ok, err; + end return true; end +function engine:onconnect() + -- Override from create_engine() +end function engine:execute(sql, ...) local success, err = self:connect(); if not success then return success, err; end @@ -177,8 +143,8 @@ function engine:execute(sql, ...) end local result_mt = { __index = { - affected = function(self) return self.__affected; end; - rowcount = function(self) return self.__rowcount; end; + affected = function(self) return self.__stmt:affected(); end; + rowcount = function(self) return self.__stmt:rowcount(); end; } }; function engine:execute_query(sql, ...) @@ -200,7 +166,7 @@ function engine:execute_update(sql, ...) prepared[sql] = stmt; end assert(stmt:execute(...)); - return setmetatable({ __affected = stmt:affected(), __rowcount = stmt:rowcount() }, result_mt); + return setmetatable({ __stmt = stmt }, result_mt); end engine.insert = engine.execute_update; engine.select = engine.execute_query; @@ -208,12 +174,13 @@ engine.delete = engine.execute_update; engine.update = engine.execute_update; function engine:_transaction(func, ...) if not self.conn then - local a,b = self:connect(); - if not a then return a,b; end + local ok, err = self:connect(); + if not ok then return ok, err; end end --assert(not self.__transaction, "Recursive transactions not allowed"); local args, n_args = {...}, select("#", ...); local function f() return func(unpack(args, 1, n_args)); end + log("debug", "SQL transaction begin [%s]", tostring(func)); self.__transaction = true; local success, a, b, c = xpcall(f, debug_traceback); self.__transaction = nil; @@ -229,15 +196,15 @@ function engine:_transaction(func, ...) end end function engine:transaction(...) - local a,b = self:_transaction(...); - if not a then + local ok, ret = self:_transaction(...); + if not ok then local conn = self.conn; if not conn or not conn:ping() then self.conn = nil; - a,b = self:_transaction(...); + ok, ret = self:_transaction(...); end end - return a,b; + return ok, ret; end function engine:_create_index(index) local sql = "CREATE INDEX `"..index.name.."` ON `"..index.table.."` ("; @@ -251,19 +218,39 @@ function engine:_create_index(index) elseif self.params.driver == "MySQL" then sql = sql:gsub("`([,)])", "`(20)%1"); end + if index.unique then + sql = sql:gsub("^CREATE", "CREATE UNIQUE"); + end --print(sql); return self:execute(sql); end function engine:_create_table(table) local sql = "CREATE TABLE `"..table.name.."` ("; for i,col in ipairs(table.c) do - sql = sql.."`"..col.name.."` "..col.type; + local col_type = col.type; + if col_type == "MEDIUMTEXT" and self.params.driver ~= "MySQL" then + col_type = "TEXT"; -- MEDIUMTEXT is MySQL-specific + end + if col.auto_increment == true and self.params.driver == "PostgreSQL" then + col_type = "BIGSERIAL"; + end + sql = sql.."`"..col.name.."` "..col_type; if col.nullable == false then sql = sql.." NOT NULL"; end + if col.primary_key == true then sql = sql.." PRIMARY KEY"; end + if col.auto_increment == true then + if self.params.driver == "MySQL" then + sql = sql.." AUTO_INCREMENT"; + elseif self.params.driver == "SQLite3" then + sql = sql.." AUTOINCREMENT"; + end + end if i ~= #table.c then sql = sql..", "; end end sql = sql.. ");" if self.params.driver == "PostgreSQL" then sql = sql:gsub("`", "\""); + elseif self.params.driver == "MySQL" then + sql = sql:gsub(";$", (" CHARACTER SET '%s' COLLATE '%s_bin';"):format(self.charset, self.charset)); end local success,err = self:execute(sql); if not success then return success,err; end @@ -274,6 +261,46 @@ function engine:_create_table(table) end return success; end +function engine:set_encoding() -- to UTF-8 + local driver = self.params.driver; + if driver == "SQLite3" then + return self:transaction(function() + if self:select"PRAGMA encoding;"()[1] == "UTF-8" then + self.charset = "utf8"; + end + end); + end + local set_names_query = "SET NAMES '%s';" + local charset = "utf8"; + if driver == "MySQL" then + local ok, charsets = self:transaction(function() + return self:select"SELECT `CHARACTER_SET_NAME` FROM `information_schema`.`CHARACTER_SETS` WHERE `CHARACTER_SET_NAME` LIKE 'utf8%' ORDER BY MAXLEN DESC LIMIT 1;"; + end); + local row = ok and charsets(); + charset = row and row[1] or charset; + set_names_query = set_names_query:gsub(";$", (" COLLATE '%s';"):format(charset.."_bin")); + end + self.charset = charset; + log("debug", "Using encoding '%s' for database connection", charset); + local ok, err = self:transaction(function() return self:execute(set_names_query:format(charset)); end); + if not ok then + return ok, err; + end + + if driver == "MySQL" then + local ok, actual_charset = self:transaction(function () + return self:select"SHOW SESSION VARIABLES LIKE 'character_set_client'"; + end); + for row in actual_charset do + if row[2] ~= charset then + log("error", "MySQL %s is actually %q (expected %q)", row[1], row[2], charset); + return false, "Failed to set connection encoding"; + end + end + end + + return true; +end local engine_mt = { __index = engine }; local function db2uri(params) @@ -286,55 +313,21 @@ local function db2uri(params) path = params.database, }; end -local engine_cache = {}; -- TODO make weak valued -function create_engine(self, params) - local url = db2uri(params); - if not engine_cache[url] then - local engine = setmetatable({ url = url, params = params }, engine_mt); - engine_cache[url] = engine; - end - return engine_cache[url]; -end - - ---[[Users = Table { - name="users"; - Column { name="user_id", type=String(), primary_key=true }; -}; -print(Users) -print(Users.c.user_id)]] ---local engine = create_engine('postgresql://scott:tiger@localhost:5432/mydatabase'); ---[[local engine = create_engine{ driver = "SQLite3", database = "./alchemy.sqlite" }; - -local i = 0; -for row in assert(engine:execute("select * from sqlite_master")):rows(true) do - i = i+1; - print(i); - for k,v in pairs(row) do - print("",k,v); - end +local function create_engine(self, params, onconnect) + return setmetatable({ url = db2uri(params), params = params, onconnect = onconnect }, engine_mt); end -print("---") -Prosody = Table { - name="prosody"; - Column { name="host", type="TEXT", nullable=false }; - Column { name="user", type="TEXT", nullable=false }; - Column { name="store", type="TEXT", nullable=false }; - Column { name="key", type="TEXT", nullable=false }; - Column { name="type", type="TEXT", nullable=false }; - Column { name="value", type="TEXT", nullable=false }; - Index { name="prosody_index", "host", "user", "store", "key" }; +return { + is_column = is_column; + is_index = is_index; + is_table = is_table; + is_query = is_query; + Integer = Integer; + String = String; + Column = Column; + Table = Table; + Index = Index; + create_engine = create_engine; + db2uri = db2uri; }; ---print(Prosody); -assert(engine:transaction(function() - assert(Prosody:create(engine)); -end)); - -for row in assert(engine:execute("select user from prosody")):rows(true) do - print("username:", row['username']) -end ---result.close();]] - -return _M; diff --git a/util/sslconfig.lua b/util/sslconfig.lua new file mode 100644 index 00000000..71f27c94 --- /dev/null +++ b/util/sslconfig.lua @@ -0,0 +1,95 @@ +local type = type; +local pairs = pairs; +local rawset = rawset; +local t_concat = table.concat; +local t_insert = table.insert; +local setmetatable = setmetatable; + +local _ENV = nil; + +local handlers = { }; +local finalisers = { }; +local id = function (v) return v end + +function handlers.options(a, k, b) + local o = a[k] or { }; + if type(b) ~= "table" then b = { b } end + for key, value in pairs(b) do + if value == true or value == false then + o[key] = value; + else + o[value] = true; + end + end + a[k] = o; +end + +handlers.verify = handlers.options; +handlers.verifyext = handlers.options; + +function finalisers.options(a) + local o = {}; + for opt, enable in pairs(a) do + if enable then + o[#o+1] = opt; + end + end + return o; +end + +finalisers.verify = finalisers.options; +finalisers.verifyext = finalisers.options; + +function finalisers.ciphers(a) + if type(a) == "table" then + return t_concat(a, ":"); + end + return a; +end + +local protocols = { "sslv2", "sslv3", "tlsv1", "tlsv1_1", "tlsv1_2" }; +for i = 1, #protocols do protocols[protocols[i] .. "+"] = i - 1; end + +local function protocol(a) + local min_protocol = protocols[a.protocol]; + if min_protocol then + a.protocol = "sslv23"; + for i = 1, min_protocol do + t_insert(a.options, "no_"..protocols[i]); + end + end +end + +local function apply(a, b) + if type(b) == "table" then + for k,v in pairs(b) do + (handlers[k] or rawset)(a, k, v); + end + end +end + +local function final(a) + local f = { }; + for k,v in pairs(a) do + f[k] = (finalisers[k] or id)(v); + end + protocol(f); + return f; +end + +local sslopts_mt = { + __index = { + apply = apply; + final = final; + }; +}; + +local function new() + return setmetatable({options={}}, sslopts_mt); +end + +return { + apply = apply; + final = final; + new = new; +}; diff --git a/util/stanza.lua b/util/stanza.lua index 7c214210..90422a06 100644 --- a/util/stanza.lua +++ b/util/stanza.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -35,13 +35,12 @@ end local xmlns_stanzas = "urn:ietf:params:xml:ns:xmpp-stanzas"; -module "stanza" +local _ENV = nil; -stanza_mt = { __type = "stanza" }; +local stanza_mt = { __type = "stanza" }; stanza_mt.__index = stanza_mt; -local stanza_mt = stanza_mt; -function stanza(name, attr) +local function stanza(name, attr) local stanza = { name = name, attr = attr or {}, tags = {} }; return setmetatable(stanza, stanza_mt); end @@ -99,7 +98,7 @@ function stanza_mt:get_child(name, xmlns) if (not name or child.name == name) and ((not xmlns and self.attr.xmlns == child.attr.xmlns) or child.attr.xmlns == xmlns) then - + return child; end end @@ -152,7 +151,7 @@ end function stanza_mt:maptags(callback) local tags, curr_tag = self.tags, 1; local n_children, n_tags = #self, #tags; - + local i = 1; while curr_tag <= n_tags and n_tags > 0 do if self[i] == tags[curr_tag] then @@ -200,12 +199,8 @@ function stanza_mt:find(path) end -local xml_escape -do - local escape_table = { ["'"] = "'", ["\""] = """, ["<"] = "<", [">"] = ">", ["&"] = "&" }; - function xml_escape(str) return (s_gsub(str, "['&<>\"]", escape_table)); end - _M.xml_escape = xml_escape; -end +local escape_table = { ["'"] = "'", ["\""] = """, ["<"] = "<", [">"] = ">", ["&"] = "&" }; +local function xml_escape(str) return (s_gsub(str, "['&<>\"]", escape_table)); end local function _dostring(t, buf, self, xml_escape, parentns) local nsid = 0; @@ -258,13 +253,13 @@ end function stanza_mt.get_error(stanza) local type, condition, text; - + local error_tag = stanza:get_child("error"); if not error_tag then return nil, nil, nil; end type = error_tag.attr.type; - + for _, child in ipairs(error_tag.tags) do if child.attr.xmlns == xmlns_stanzas then if not text and child.name == "text" then @@ -280,15 +275,13 @@ function stanza_mt.get_error(stanza) return type, condition or "undefined-condition", text; end -do - local id = 0; - function new_id() - id = id + 1; - return "lx"..id; - end +local id = 0; +local function new_id() + id = id + 1; + return "lx"..id; end -function preserialize(stanza) +local function preserialize(stanza) local s = { name = stanza.name, attr = stanza.attr }; for _, child in ipairs(stanza) do if type(child) == "table" then @@ -300,7 +293,7 @@ function preserialize(stanza) return s; end -function deserialize(stanza) +local function deserialize(stanza) -- Set metatable if stanza then local attr = stanza.attr; @@ -333,55 +326,52 @@ function deserialize(stanza) stanza.tags = tags; end end - + return stanza; end -local function _clone(stanza) +local function clone(stanza) local attr, tags = {}, {}; for k,v in pairs(stanza.attr) do attr[k] = v; end local new = { name = stanza.name, attr = attr, tags = tags }; for i=1,#stanza do local child = stanza[i]; if child.name then - child = _clone(child); + child = clone(child); t_insert(tags, child); end t_insert(new, child); end return setmetatable(new, stanza_mt); end -clone = _clone; -function message(attr, body) +local function message(attr, body) if not body then return stanza("message", attr); else return stanza("message", attr):tag("body"):text(body):up(); end end -function iq(attr) +local function iq(attr) if attr and not attr.id then attr.id = new_id(); end return stanza("iq", attr or { id = new_id() }); end -function reply(orig) +local function reply(orig) return stanza(orig.name, orig.attr and { to = orig.attr.from, from = orig.attr.to, id = orig.attr.id, type = ((orig.name == "iq" and "result") or orig.attr.type) }); end -do - local xmpp_stanzas_attr = { xmlns = xmlns_stanzas }; - function error_reply(orig, type, condition, message) - local t = reply(orig); - t.attr.type = "error"; - t:tag("error", {type = type}) --COMPAT: Some day xmlns:stanzas goes here - :tag(condition, xmpp_stanzas_attr):up(); - if (message) then t:tag("text", xmpp_stanzas_attr):text(message):up(); end - return t; -- stanza ready for adding app-specific errors - end +local xmpp_stanzas_attr = { xmlns = xmlns_stanzas }; +local function error_reply(orig, type, condition, message) + local t = reply(orig); + t.attr.type = "error"; + t:tag("error", {type = type}) --COMPAT: Some day xmlns:stanzas goes here + :tag(condition, xmpp_stanzas_attr):up(); + if (message) then t:tag("text", xmpp_stanzas_attr):text(message):up(); end + return t; -- stanza ready for adding app-specific errors end -function presence(attr) +local function presence(attr) return stanza("presence", attr); end @@ -390,7 +380,7 @@ if do_pretty_printing then local style_attrv = getstyle("red"); local style_tagname = getstyle("red"); local style_punc = getstyle("magenta"); - + local attr_format = " "..getstring(style_attrk, "%s")..getstring(style_punc, "=")..getstring(style_attrv, "'%s'"); local top_tag_format = getstring(style_punc, "<")..getstring(style_tagname, "%s").."%s"..getstring(style_punc, ">"); --local tag_format = getstring(style_punc, "<")..getstring(style_tagname, "%s").."%s"..getstring(style_punc, ">").."%s"..getstring(style_punc, "</")..getstring(style_tagname, "%s")..getstring(style_punc, ">"); @@ -411,7 +401,7 @@ if do_pretty_printing then end return s_format(tag_format, t.name, attr_string, children_text, t.name); end - + function stanza_mt.pretty_top_tag(t) local attr_string = ""; if t.attr then @@ -425,4 +415,17 @@ else stanza_mt.pretty_top_tag = stanza_mt.top_tag; end -return _M; +return { + stanza_mt = stanza_mt; + stanza = stanza; + new_id = new_id; + preserialize = preserialize; + deserialize = deserialize; + clone = clone; + message = message; + iq = iq; + reply = reply; + error_reply = error_reply; + presence = presence; + xml_escape = xml_escape; +}; diff --git a/util/statistics.lua b/util/statistics.lua new file mode 100644 index 00000000..26355026 --- /dev/null +++ b/util/statistics.lua @@ -0,0 +1,160 @@ +local t_sort = table.sort +local m_floor = math.floor; +local time = require "socket".gettime; + +local function nop_function() end + +local function percentile(arr, length, pc) + local n = pc/100 * (length + 1); + local k, d = m_floor(n), n%1; + if k == 0 then + return arr[1] or 0; + elseif k >= length then + return arr[length]; + end + return arr[k] + d*(arr[k+1] - arr[k]); +end + +local function new_registry(config) + config = config or {}; + local duration_sample_interval = config.duration_sample_interval or 5; + local duration_max_samples = config.duration_max_stored_samples or 5000; + + local function get_distribution_stats(events, n_actual_events, since, new_time, units) + local n_stored_events = #events; + t_sort(events); + local sum = 0; + for i = 1, n_stored_events do + sum = sum + events[i]; + end + + return { + samples = events; + sample_count = n_stored_events; + count = n_actual_events, + rate = n_actual_events/(new_time-since); + average = n_stored_events > 0 and sum/n_stored_events or 0, + min = events[1] or 0, + max = events[n_stored_events] or 0, + units = units, + }; + end + + + local registry = {}; + local methods; + methods = { + amount = function (name, initial) + local v = initial or 0; + registry[name..":amount"] = function () return "amount", v; end + return function (new_v) v = new_v; end + end; + counter = function (name, initial) + local v = initial or 0; + registry[name..":amount"] = function () return "amount", v; end + return function (delta) + v = v + delta; + end; + end; + rate = function (name) + local since, n = time(), 0; + registry[name..":rate"] = function () + local t = time(); + local stats = { + rate = n/(t-since); + count = n; + }; + since, n = t, 0; + return "rate", stats.rate, stats; + end; + return function () + n = n + 1; + end; + end; + distribution = function (name, unit, type) + type = type or "distribution"; + local events, last_event = {}, 0; + local n_actual_events = 0; + local since = time(); + + registry[name..":"..type] = function () + local new_time = time(); + local stats = get_distribution_stats(events, n_actual_events, since, new_time, unit); + events, last_event = {}, 0; + n_actual_events = 0; + since = new_time; + return type, stats.average, stats; + end; + + return function (value) + n_actual_events = n_actual_events + 1; + if n_actual_events%duration_sample_interval == 1 then + last_event = (last_event%duration_max_samples) + 1; + events[last_event] = value; + end + end; + end; + sizes = function (name) + return methods.distribution(name, "bytes", "size"); + end; + times = function (name) + local events, last_event = {}, 0; + local n_actual_events = 0; + local since = time(); + + registry[name..":duration"] = function () + local new_time = time(); + local stats = get_distribution_stats(events, n_actual_events, since, new_time, "seconds"); + events, last_event = {}, 0; + n_actual_events = 0; + since = new_time; + return "duration", stats.average, stats; + end; + + return function () + n_actual_events = n_actual_events + 1; + if n_actual_events%duration_sample_interval ~= 1 then + return nop_function; + end + + local start_time = time(); + return function () + local end_time = time(); + local duration = end_time - start_time; + last_event = (last_event%duration_max_samples) + 1; + events[last_event] = duration; + end + end; + end; + + get_stats = function () + return registry; + end; + }; + return methods; +end + +return { + new = new_registry; + get_histogram = function (duration, n_buckets) + n_buckets = n_buckets or 100; + local events, n_events = duration.samples, duration.sample_count; + if not (events and n_events) then + return nil, "not a valid distribution stat"; + end + local histogram = {}; + + for i = 1, 100, 100/n_buckets do + histogram[i] = percentile(events, n_events, i); + end + return histogram; + end; + + get_percentile = function (duration, pc) + local events, n_events = duration.samples, duration.sample_count; + if not (events and n_events) then + return nil, "not a valid distribution stat"; + end + return percentile(events, n_events, pc); + end; +} diff --git a/util/template.lua b/util/template.lua index 66d4fca7..a26dd7ca 100644 --- a/util/template.lua +++ b/util/template.lua @@ -9,7 +9,7 @@ local debug = debug; local t_remove = table.remove; local parse_xml = require "util.xml".parse; -module("template") +local _ENV = nil; local function trim_xml(stanza) for i=#stanza,1,-1 do diff --git a/util/termcolours.lua b/util/termcolours.lua index 6ef3b689..a1c01aa5 100644 --- a/util/termcolours.lua +++ b/util/termcolours.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -19,7 +19,7 @@ if os.getenv("WINDIR") then end local orig_color = windows and windows.get_consolecolor and windows.get_consolecolor(); -module "termcolours" +local _ENV = nil; local stylemap = { reset = 0; bright = 1, dim = 2, underscore = 4, blink = 5, reverse = 7, hidden = 8; @@ -45,7 +45,7 @@ local cssmap = { }; local fmt_string = char(0x1B).."[%sm%s"..char(0x1B).."[0m"; -function getstring(style, text) +local function getstring(style, text) if style then return format(fmt_string, style, text); else @@ -53,7 +53,7 @@ function getstring(style, text) end end -function getstyle(...) +local function getstyle(...) local styles, result = { ... }, {}; for i, style in ipairs(styles) do style = stylemap[style]; @@ -65,7 +65,7 @@ function getstyle(...) end local last = "0"; -function setstyle(style) +local function setstyle(style) style = style or "0"; if style ~= last then io_write("\27["..style.."m"); @@ -95,8 +95,13 @@ local function ansi2css(ansi_codes) return "</span><span style='"..t_concat(css, ";").."'>"; end -function tohtml(input) +local function tohtml(input) return input:gsub("\027%[(.-)m", ansi2css); end -return _M; +return { + getstring = getstring; + getstyle = getstyle; + setstyle = setstyle; + tohtml = tohtml; +}; diff --git a/util/throttle.lua b/util/throttle.lua index 55e1d07b..3d3f5d2d 100644 --- a/util/throttle.lua +++ b/util/throttle.lua @@ -3,7 +3,7 @@ local gettime = require "socket".gettime; local setmetatable = setmetatable; local floor = math.floor; -module "throttle" +local _ENV = nil; local throttle = {}; local throttle_mt = { __index = throttle }; @@ -39,8 +39,10 @@ function throttle:poll(cost, split) end end -function create(max, period) +local function create(max, period) return setmetatable({ rate = max / period, max = max, t = 0, balance = max }, throttle_mt); end -return _M; +return { + create = create; +}; diff --git a/util/timer.lua b/util/timer.lua index af1e57b6..3713625d 100644 --- a/util/timer.lua +++ b/util/timer.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -17,7 +17,7 @@ local type = type; local data = {}; local new_data = {}; -module "timer" +local _ENV = nil; local _add_task; if not server.event then @@ -42,7 +42,7 @@ if not server.event then end new_data = {}; end - + local next_time = math_huge; for i, d in pairs(data) do local t, callback = d[1], d[2]; @@ -78,6 +78,6 @@ else end end -add_task = _add_task; - -return _M; +return { + add_task = _add_task; +}; diff --git a/util/uuid.lua b/util/uuid.lua index 796c8ee4..e10fc0f7 100644 --- a/util/uuid.lua +++ b/util/uuid.lua @@ -1,50 +1,32 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- +local random = require "util.random"; +local random_bytes = random.bytes; +local hex = require "util.hex".to; +local m_ceil = math.ceil; -local m_random = math.random; -local tostring = tostring; -local os_time = os.time; -local os_clock = os.clock; -local sha1 = require "util.hashes".sha1; - -module "uuid" - -local last_uniq_time = 0; -local function uniq_time() - local new_uniq_time = os_time(); - if last_uniq_time >= new_uniq_time then new_uniq_time = last_uniq_time + 1; end - last_uniq_time = new_uniq_time; - return new_uniq_time; -end - -local function new_random(x) - return sha1(x..os_clock()..tostring({}), true); -end - -local buffer = new_random(uniq_time()); -local function _seed(x) - buffer = new_random(buffer..x); -end local function get_nibbles(n) - if #buffer < n then _seed(uniq_time()); end - local r = buffer:sub(0, n); - buffer = buffer:sub(n+1); - return r; + return hex(random_bytes(m_ceil(n/2))):sub(1, n); end + local function get_twobits() return ("%x"):format(get_nibbles(1):byte() % 4 + 8); end -function generate() +local function generate() -- generate RFC 4122 complaint UUIDs (version 4 - random) return get_nibbles(8).."-"..get_nibbles(4).."-4"..get_nibbles(3).."-"..(get_twobits())..get_nibbles(3).."-"..get_nibbles(12); end -seed = _seed; -return _M; +return { + get_nibbles=get_nibbles; + generate = generate ; + -- COMPAT + seed = random.seed; +}; diff --git a/util/watchdog.lua b/util/watchdog.lua index bcb2e274..aa8c6486 100644 --- a/util/watchdog.lua +++ b/util/watchdog.lua @@ -2,12 +2,12 @@ local timer = require "util.timer"; local setmetatable = setmetatable; local os_time = os.time; -module "watchdog" +local _ENV = nil; local watchdog_methods = {}; local watchdog_mt = { __index = watchdog_methods }; -function new(timeout, callback) +local function new(timeout, callback) local watchdog = setmetatable({ timeout = timeout, last_reset = os_time(), callback = callback }, watchdog_mt); timer.add_task(timeout+1, function (current_time) local last_reset = watchdog.last_reset; @@ -31,4 +31,6 @@ function watchdog_methods:cancel() self.last_reset = nil; end -return _M; +return { + new = new; +}; diff --git a/util/x509.lua b/util/x509.lua index 19d4ec6d..f228b201 100644 --- a/util/x509.lua +++ b/util/x509.lua @@ -20,13 +20,11 @@ local nameprep = require "util.encodings".stringprep.nameprep; local idna_to_ascii = require "util.encodings".idna.to_ascii; +local base64 = require "util.encodings".base64; local log = require "util.logger".init("x509"); -local pairs, ipairs = pairs, ipairs; local s_format = string.format; -local t_insert = table.insert; -local t_concat = table.concat; -module "x509" +local _ENV = nil; local oid_commonname = "2.5.4.3"; -- [LDAP] 2.3 local oid_subjectaltname = "2.5.29.17"; -- [PKIX] 4.2.1.6 @@ -149,7 +147,10 @@ local function compare_srvname(host, service, asserted_names) return false end -function verify_identity(host, service, cert) +local function verify_identity(host, service, cert) + if cert.setencode then + cert:setencode("utf8"); + end local ext = cert:extensions() if ext[oid_subjectaltname] then local sans = ext[oid_subjectaltname]; @@ -161,7 +162,9 @@ function verify_identity(host, service, cert) if sans[oid_xmppaddr] then had_supported_altnames = true - if compare_xmppaddr(host, sans[oid_xmppaddr]) then return true end + if service == "_xmpp-client" or service == "_xmpp-server" then + if compare_xmppaddr(host, sans[oid_xmppaddr]) then return true end + end end if sans[oid_dnssrv] then @@ -212,4 +215,27 @@ function verify_identity(host, service, cert) return false end -return _M; +local pat = "%-%-%-%-%-BEGIN ([A-Z ]+)%-%-%-%-%-\r?\n".. +"([0-9A-Za-z+/=\r\n]*)\r?\n%-%-%-%-%-END %1%-%-%-%-%-"; + +local function pem2der(pem) + local typ, data = pem:match(pat); + if typ and data then + return base64.decode(data), typ; + end +end + +local wrap = ('.'):rep(64); +local envelope = "-----BEGIN %s-----\n%s\n-----END %s-----\n" + +local function der2pem(data, typ) + typ = typ and typ:upper() or "CERTIFICATE"; + data = base64.encode(data); + return s_format(envelope, typ, data:gsub(wrap, '%0\n', (#data-1)/64), typ); +end + +return { + verify_identity = verify_identity; + pem2der = pem2der; + der2pem = der2pem; +}; diff --git a/util/xml.lua b/util/xml.lua index 076490fa..733d821a 100644 --- a/util/xml.lua +++ b/util/xml.lua @@ -2,7 +2,7 @@ local st = require "util.stanza"; local lxp = require "lxp"; -module("xml") +local _ENV = nil; local parse_xml = (function() local ns_prefixes = { @@ -11,6 +11,7 @@ local parse_xml = (function() local ns_separator = "\1"; local ns_pattern = "^([^"..ns_separator.."]*)"..ns_separator.."?(.*)$"; return function(xml) + --luacheck: ignore 212/self local handler = {}; local stanza = st.stanza("root"); function handler:StartElement(tagname, attr) @@ -26,8 +27,8 @@ local parse_xml = (function() attr[i] = nil; local ns, nm = k:match(ns_pattern); if nm ~= "" then - ns = ns_prefixes[ns]; - if ns then + ns = ns_prefixes[ns]; + if ns then attr[ns..":"..nm] = attr[k]; attr[k] = nil; end @@ -38,7 +39,7 @@ local parse_xml = (function() function handler:CharacterData(data) stanza:text(data); end - function handler:EndElement(tagname) + function handler:EndElement() stanza:up(); end local parser = lxp.new(handler, "\1"); @@ -53,5 +54,6 @@ local parse_xml = (function() end; end)(); -parse = parse_xml; -return _M; +return { + parse = parse_xml; +}; diff --git a/util/xmppstream.lua b/util/xmppstream.lua index 138c86b7..7be63285 100644 --- a/util/xmppstream.lua +++ b/util/xmppstream.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -24,7 +24,7 @@ local lxp_supports_bytecount = not not lxp.new({}).getcurrentbytecount; local default_stanza_size_limit = 1024*1024*10; -- 10MB -module "xmppstream" +local _ENV = nil; local new_parser = lxp.new; @@ -40,29 +40,26 @@ local xmlns_streams = "http://etherx.jabber.org/streams"; local ns_separator = "\1"; local ns_pattern = "^([^"..ns_separator.."]*)"..ns_separator.."?(.*)$"; -_M.ns_separator = ns_separator; -_M.ns_pattern = ns_pattern; - local function dummy_cb() end -function new_sax_handlers(session, stream_callbacks, cb_handleprogress) +local function new_sax_handlers(session, stream_callbacks, cb_handleprogress) local xml_handlers = {}; - + local cb_streamopened = stream_callbacks.streamopened; local cb_streamclosed = stream_callbacks.streamclosed; local cb_error = stream_callbacks.error or function(session, e, stanza) error("XML stream error: "..tostring(e)..(stanza and ": "..tostring(stanza) or ""),2); end; local cb_handlestanza = stream_callbacks.handlestanza; cb_handleprogress = cb_handleprogress or dummy_cb; - + local stream_ns = stream_callbacks.stream_ns or xmlns_streams; local stream_tag = stream_callbacks.stream_tag or "stream"; if stream_ns ~= "" then stream_tag = stream_ns..ns_separator..stream_tag; end local stream_error_tag = stream_ns..ns_separator..(stream_callbacks.error_tag or "error"); - + local stream_default_ns = stream_callbacks.default_ns; - + local stack = {}; local chardata, stanza = {}; local stanza_size = 0; @@ -82,7 +79,7 @@ function new_sax_handlers(session, stream_callbacks, cb_handleprogress) attr.xmlns = curr_ns; non_streamns_depth = non_streamns_depth + 1; end - + for i=1,#attr do local k = attr[i]; attr[i] = nil; @@ -92,7 +89,7 @@ function new_sax_handlers(session, stream_callbacks, cb_handleprogress) attr[k] = nil; end end - + if not stanza then --if we are not currently inside a stanza if lxp_supports_bytecount then stanza_size = self:getcurrentbytecount(); @@ -116,7 +113,7 @@ function new_sax_handlers(session, stream_callbacks, cb_handleprogress) if curr_ns == "jabber:client" and name ~= "iq" and name ~= "presence" and name ~= "message" then cb_error(session, "invalid-top-level-element"); end - + stanza = setmetatable({ name = name, attr = attr, tags = {} }, stanza_mt); else -- we are inside a stanza, so add a tag if lxp_supports_bytecount then @@ -205,26 +202,26 @@ function new_sax_handlers(session, stream_callbacks, cb_handleprogress) error("Failed to abort parsing"); end end - + if lxp_supports_doctype then xml_handlers.StartDoctypeDecl = restricted_handler; end xml_handlers.Comment = restricted_handler; xml_handlers.ProcessingInstruction = restricted_handler; - + local function reset() stanza, chardata, stanza_size = nil, {}, 0; stack = {}; end - + local function set_session(stream, new_session) session = new_session; end - + return xml_handlers, { reset = reset, set_session = set_session }; end -function new(session, stream_callbacks, stanza_size_limit) +local function new(session, stream_callbacks, stanza_size_limit) -- Used to track parser progress (e.g. to enforce size limits) local n_outstanding_bytes = 0; local handle_progress; @@ -241,6 +238,25 @@ function new(session, stream_callbacks, stanza_size_limit) local parser = new_parser(handlers, ns_separator, false); local parse = parser.parse; + function session.open_stream(session, from, to) + local send = session.sends2s or session.send; + + local attr = { + ["xmlns:stream"] = "http://etherx.jabber.org/streams", + ["xml:lang"] = "en", + xmlns = stream_callbacks.default_ns, + version = session.version and (session.version > 0 and "1.0" or nil), + id = session.streamid, + from = from or session.host, to = to, + }; + if session.stream_attrs then + session:stream_attrs(from, to, attr) + end + send("<?xml version='1.0'?>"); + send(st.stanza("stream:stream", attr):top_tag()); + return true; + end + return { reset = function () parser = new_parser(handlers, ns_separator, false); @@ -262,4 +278,9 @@ function new(session, stream_callbacks, stanza_size_limit) }; end -return _M; +return { + ns_separator = ns_separator; + ns_pattern = ns_pattern; + new_sax_handlers = new_sax_handlers; + new = new; +}; |