aboutsummaryrefslogtreecommitdiffstats
path: root/util
diff options
context:
space:
mode:
Diffstat (limited to 'util')
-rw-r--r--util/adhoc.lua31
-rw-r--r--util/array.lua163
-rw-r--r--util/caps.lua61
-rw-r--r--util/dataforms.lua244
-rw-r--r--util/datamanager.lua388
-rw-r--r--util/datetime.lua57
-rw-r--r--util/debug.lua193
-rw-r--r--util/dependencies.lua149
-rw-r--r--util/envload.lua34
-rw-r--r--util/events.lua83
-rw-r--r--util/filters.lua87
-rw-r--r--util/helpers.lua82
-rw-r--r--util/hmac.lua15
-rw-r--r--util/http.lua64
-rw-r--r--util/import.lua9
-rw-r--r--util/ip.lua189
-rw-r--r--util/iterators.lua159
-rw-r--r--util/jid.lua107
-rw-r--r--util/json.lua412
-rw-r--r--util/logger.lua83
-rw-r--r--util/multitable.lua177
-rw-r--r--util/openssl.lua172
-rw-r--r--util/pluginloader.lua60
-rw-r--r--util/prosodyctl.lua279
-rw-r--r--util/pubsub.lua388
-rw-r--r--util/rfc6724.lua142
-rw-r--r--util/sasl.lua119
-rw-r--r--util/sasl/anonymous.lua46
-rw-r--r--util/sasl/digest-md5.lua248
-rw-r--r--util/sasl/plain.lua89
-rw-r--r--util/sasl/scram.lua216
-rw-r--r--util/sasl_cyrus.lua166
-rw-r--r--util/serialization.lua95
-rw-r--r--util/set.lua159
-rw-r--r--util/sql.lua340
-rw-r--r--util/stanza.lua392
-rw-r--r--util/template.lua97
-rw-r--r--util/termcolours.lua102
-rw-r--r--util/throttle.lua46
-rw-r--r--util/timer.lua83
-rw-r--r--util/uuid.lua47
-rw-r--r--util/watchdog.lua34
-rw-r--r--util/x509.lua215
-rw-r--r--util/xml.lua57
-rw-r--r--util/xmppstream.lua191
45 files changed, 6410 insertions, 160 deletions
diff --git a/util/adhoc.lua b/util/adhoc.lua
new file mode 100644
index 00000000..671e85cf
--- /dev/null
+++ b/util/adhoc.lua
@@ -0,0 +1,31 @@
+local function new_simple_form(form, result_handler)
+ return function(self, data, state)
+ if state then
+ if data.action == "cancel" then
+ return { status = "canceled" };
+ end
+ local fields, err = form:data(data.form);
+ return result_handler(fields, err, data);
+ else
+ return { status = "executing", actions = {"next", "complete", default = "complete"}, form = form }, "executing";
+ end
+ end
+end
+
+local function new_initial_data_form(form, initial_data, result_handler)
+ return function(self, data, state)
+ if state then
+ if data.action == "cancel" then
+ return { status = "canceled" };
+ end
+ local fields, err = form:data(data.form);
+ return result_handler(fields, err, data);
+ else
+ return { status = "executing", actions = {"next", "complete", default = "complete"},
+ form = { layout = form, values = initial_data() } }, "executing";
+ end
+ end
+end
+
+return { new_simple_form = new_simple_form,
+ new_initial_data_form = new_initial_data_form };
diff --git a/util/array.lua b/util/array.lua
new file mode 100644
index 00000000..b78fb98f
--- /dev/null
+++ b/util/array.lua
@@ -0,0 +1,163 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+local t_insert, t_sort, t_remove, t_concat
+ = table.insert, table.sort, table.remove, table.concat;
+
+local setmetatable = setmetatable;
+local math_random = math.random;
+local pairs, ipairs = pairs, ipairs;
+local tostring = tostring;
+
+local array = {};
+local array_base = {};
+local array_methods = {};
+local array_mt = { __index = array_methods, __tostring = function (array) return array:concat(", "); end };
+
+local function new_array(self, t, _s, _var)
+ if type(t) == "function" then -- Assume iterator
+ t = self.collect(t, _s, _var);
+ end
+ return setmetatable(t or {}, array_mt);
+end
+
+function array_mt.__add(a1, a2)
+ local res = new_array();
+ return res:append(a1):append(a2);
+end
+
+setmetatable(array, { __call = new_array });
+
+-- Read-only methods
+function array_methods:random()
+ return self[math_random(1,#self)];
+end
+
+-- These methods can be called two ways:
+-- array.method(existing_array, [params [, ...]]) -- Create new array for result
+-- existing_array:method([params, ...]) -- Transform existing array into result
+--
+function array_base.map(outa, ina, func)
+ for k,v in ipairs(ina) do
+ outa[k] = func(v);
+ end
+ return outa;
+end
+
+function array_base.filter(outa, ina, func)
+ local inplace, start_length = ina == outa, #ina;
+ local write = 1;
+ for read=1,start_length do
+ local v = ina[read];
+ if func(v) then
+ outa[write] = v;
+ write = write + 1;
+ end
+ end
+
+ if inplace and write <= start_length then
+ for i=write,start_length do
+ outa[i] = nil;
+ end
+ end
+
+ return outa;
+end
+
+function array_base.sort(outa, ina, ...)
+ if ina ~= outa then
+ outa:append(ina);
+ end
+ t_sort(outa, ...);
+ return outa;
+end
+
+function array_base.pluck(outa, ina, key)
+ for i=1,#ina do
+ outa[i] = ina[i][key];
+ end
+ return outa;
+end
+
+--- These methods only mutate the array
+function array_methods:shuffle(outa, ina)
+ local len = #self;
+ for i=1,#self do
+ local r = math_random(i,len);
+ self[i], self[r] = self[r], self[i];
+ end
+ return self;
+end
+
+function array_methods:reverse()
+ local len = #self-1;
+ for i=len,1,-1 do
+ self:push(self[i]);
+ self:pop(i);
+ end
+ return self;
+end
+
+function array_methods:append(array)
+ local len,len2 = #self, #array;
+ for i=1,len2 do
+ self[len+i] = array[i];
+ end
+ return self;
+end
+
+function array_methods:push(x)
+ t_insert(self, x);
+ return self;
+end
+
+function array_methods:pop(x)
+ local v = self[x];
+ t_remove(self, x);
+ return v;
+end
+
+function array_methods:concat(sep)
+ return t_concat(array.map(self, tostring), sep);
+end
+
+function array_methods:length()
+ return #self;
+end
+
+--- These methods always create a new array
+function array.collect(f, s, var)
+ local t = {};
+ while true do
+ var = f(s, var);
+ if var == nil then break; end
+ t_insert(t, var);
+ end
+ return setmetatable(t, array_mt);
+end
+
+---
+
+-- Setup methods from array_base
+for method, f in pairs(array_base) do
+ local base_method = f;
+ -- Setup global array method which makes new array
+ array[method] = function (old_a, ...)
+ local a = new_array();
+ return base_method(a, old_a, ...);
+ end
+ -- Setup per-array (mutating) method
+ array_methods[method] = function (self, ...)
+ return base_method(self, self, ...);
+ end
+end
+
+_G.array = array;
+module("array");
+
+return array;
diff --git a/util/caps.lua b/util/caps.lua
new file mode 100644
index 00000000..a61e7403
--- /dev/null
+++ b/util/caps.lua
@@ -0,0 +1,61 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+local base64 = require "util.encodings".base64.encode;
+local sha1 = require "util.hashes".sha1;
+
+local t_insert, t_sort, t_concat = table.insert, table.sort, table.concat;
+local ipairs = ipairs;
+
+module "caps"
+
+function calculate_hash(disco_info)
+ local identities, features, extensions = {}, {}, {};
+ for _, tag in ipairs(disco_info) do
+ if tag.name == "identity" then
+ t_insert(identities, (tag.attr.category or "").."\0"..(tag.attr.type or "").."\0"..(tag.attr["xml:lang"] or "").."\0"..(tag.attr.name or ""));
+ elseif tag.name == "feature" then
+ t_insert(features, tag.attr.var or "");
+ elseif tag.name == "x" and tag.attr.xmlns == "jabber:x:data" then
+ local form = {};
+ local FORM_TYPE;
+ for _, field in ipairs(tag.tags) do
+ if field.name == "field" and field.attr.var then
+ local values = {};
+ for _, val in ipairs(field.tags) do
+ val = #val.tags == 0 and val:get_text();
+ if val then t_insert(values, val); end
+ end
+ t_sort(values);
+ if field.attr.var == "FORM_TYPE" then
+ FORM_TYPE = values[1];
+ elseif #values > 0 then
+ t_insert(form, field.attr.var.."\0"..t_concat(values, "<"));
+ else
+ t_insert(form, field.attr.var);
+ end
+ end
+ end
+ t_sort(form);
+ form = t_concat(form, "<");
+ if FORM_TYPE then form = FORM_TYPE.."\0"..form; end
+ t_insert(extensions, form);
+ end
+ end
+ t_sort(identities);
+ t_sort(features);
+ t_sort(extensions);
+ if #identities > 0 then identities = t_concat(identities, "<"):gsub("%z", "/").."<"; else identities = ""; end
+ if #features > 0 then features = t_concat(features, "<").."<"; else features = ""; end
+ if #extensions > 0 then extensions = t_concat(extensions, "<"):gsub("%z", "<").."<"; else extensions = ""; end
+ local S = identities..features..extensions;
+ local ver = base64(sha1(S));
+ return ver, S;
+end
+
+return _M;
diff --git a/util/dataforms.lua b/util/dataforms.lua
new file mode 100644
index 00000000..8634e337
--- /dev/null
+++ b/util/dataforms.lua
@@ -0,0 +1,244 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+local setmetatable = setmetatable;
+local pairs, ipairs = pairs, ipairs;
+local tostring, type, next = tostring, type, next;
+local t_concat = table.concat;
+local st = require "util.stanza";
+local jid_prep = require "util.jid".prep;
+
+module "dataforms"
+
+local xmlns_forms = 'jabber:x:data';
+
+local form_t = {};
+local form_mt = { __index = form_t };
+
+function new(layout)
+ return setmetatable(layout, form_mt);
+end
+
+function form_t.form(layout, data, formtype)
+ local form = st.stanza("x", { xmlns = xmlns_forms, type = formtype or "form" });
+ if layout.title then
+ form:tag("title"):text(layout.title):up();
+ end
+ if layout.instructions then
+ form:tag("instructions"):text(layout.instructions):up();
+ end
+ for n, field in ipairs(layout) do
+ local field_type = field.type or "text-single";
+ -- Add field tag
+ form:tag("field", { type = field_type, var = field.name, label = field.label });
+
+ local value = (data and data[field.name]) or field.value;
+
+ if value then
+ -- Add value, depending on type
+ if field_type == "hidden" then
+ if type(value) == "table" then
+ -- Assume an XML snippet
+ form:tag("value")
+ :add_child(value)
+ :up();
+ else
+ form:tag("value"):text(tostring(value)):up();
+ end
+ elseif field_type == "boolean" then
+ form:tag("value"):text((value and "1") or "0"):up();
+ elseif field_type == "fixed" then
+ form:tag("value"):text(value):up();
+ elseif field_type == "jid-multi" then
+ for _, jid in ipairs(value) do
+ form:tag("value"):text(jid):up();
+ end
+ elseif field_type == "jid-single" then
+ form:tag("value"):text(value):up();
+ elseif field_type == "text-single" or field_type == "text-private" then
+ form:tag("value"):text(value):up();
+ elseif field_type == "text-multi" then
+ -- Split into multiple <value> tags, one for each line
+ for line in value:gmatch("([^\r\n]+)\r?\n*") do
+ form:tag("value"):text(line):up();
+ end
+ elseif field_type == "list-single" then
+ local has_default = false;
+ for _, val in ipairs(value) do
+ if type(val) == "table" then
+ form:tag("option", { label = val.label }):tag("value"):text(val.value):up():up();
+ if val.default and (not has_default) then
+ form:tag("value"):text(val.value):up();
+ has_default = true;
+ end
+ else
+ form:tag("option", { label= val }):tag("value"):text(tostring(val)):up():up();
+ end
+ end
+ elseif field_type == "list-multi" then
+ for _, val in ipairs(value) do
+ if type(val) == "table" then
+ form:tag("option", { label = val.label }):tag("value"):text(val.value):up():up();
+ if val.default then
+ form:tag("value"):text(val.value):up();
+ end
+ else
+ form:tag("option", { label= val }):tag("value"):text(tostring(val)):up():up();
+ end
+ end
+ end
+ end
+
+ if field.required then
+ form:tag("required"):up();
+ end
+
+ -- Jump back up to list of fields
+ form:up();
+ end
+ return form;
+end
+
+local field_readers = {};
+
+function form_t.data(layout, stanza)
+ local data = {};
+ local errors = {};
+
+ for _, field in ipairs(layout) do
+ local tag;
+ for field_tag in stanza:childtags() do
+ if field.name == field_tag.attr.var then
+ tag = field_tag;
+ break;
+ end
+ end
+
+ if not tag then
+ if field.required then
+ errors[field.name] = "Required value missing";
+ end
+ else
+ local reader = field_readers[field.type];
+ if reader then
+ data[field.name], errors[field.name] = reader(tag, field.required);
+ end
+ end
+ end
+ if next(errors) then
+ return data, errors;
+ end
+ return data;
+end
+
+field_readers["text-single"] =
+ function (field_tag, required)
+ local data = field_tag:get_child_text("value");
+ if data and #data > 0 then
+ return data
+ elseif required then
+ return nil, "Required value missing";
+ end
+ end
+
+field_readers["text-private"] =
+ field_readers["text-single"];
+
+field_readers["jid-single"] =
+ function (field_tag, required)
+ local raw_data = field_tag:get_child_text("value")
+ local data = jid_prep(raw_data);
+ if data and #data > 0 then
+ return data
+ elseif raw_data then
+ return nil, "Invalid JID: " .. raw_data;
+ elseif required then
+ return nil, "Required value missing";
+ end
+ end
+
+field_readers["jid-multi"] =
+ function (field_tag, required)
+ local result = {};
+ local err = {};
+ for value_tag in field_tag:childtags("value") do
+ local raw_value = value_tag:get_text();
+ local value = jid_prep(raw_value);
+ result[#result+1] = value;
+ if raw_value and not value then
+ err[#err+1] = ("Invalid JID: " .. raw_value);
+ end
+ end
+ if #result > 0 then
+ return result, (#err > 0 and t_concat(err, "\n") or nil);
+ elseif required then
+ return nil, "Required value missing";
+ end
+ end
+
+field_readers["list-multi"] =
+ function (field_tag, required)
+ local result = {};
+ for value in field_tag:childtags("value") do
+ result[#result+1] = value:get_text();
+ end
+ return result, (required and #result == 0 and "Required value missing" or nil);
+ end
+
+field_readers["text-multi"] =
+ function (field_tag, required)
+ local data, err = field_readers["list-multi"](field_tag, required);
+ if data then
+ data = t_concat(data, "\n");
+ end
+ return data, err;
+ end
+
+field_readers["list-single"] =
+ field_readers["text-single"];
+
+local boolean_values = {
+ ["1"] = true, ["true"] = true,
+ ["0"] = false, ["false"] = false,
+};
+
+field_readers["boolean"] =
+ function (field_tag, required)
+ local raw_value = field_tag:get_child_text("value");
+ local value = boolean_values[raw_value ~= nil and raw_value];
+ if value ~= nil then
+ return value;
+ elseif raw_value then
+ return nil, "Invalid boolean representation";
+ elseif required then
+ return nil, "Required value missing";
+ end
+ end
+
+field_readers["hidden"] =
+ function (field_tag)
+ return field_tag:get_child_text("value");
+ end
+
+return _M;
+
+
+--[=[
+
+Layout:
+{
+
+ title = "MUC Configuration",
+ instructions = [[Use this form to configure options for this MUC room.]],
+
+ { name = "FORM_TYPE", type = "hidden", required = true };
+ { name = "field-name", type = "field-type", required = false };
+}
+
+
+--]=]
diff --git a/util/datamanager.lua b/util/datamanager.lua
index be63673e..4a4d62b3 100644
--- a/util/datamanager.lua
+++ b/util/datamanager.lua
@@ -1,20 +1,53 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+
local format = string.format;
-local setmetatable, type = setmetatable, type;
-local pairs = pairs;
+local setmetatable = setmetatable;
+local ipairs = ipairs;
local char = string.char;
-local loadfile, setfenv, pcall = loadfile, setfenv, pcall;
-local log = log;
+local pcall = pcall;
+local log = require "util.logger".init("datamanager");
local io_open = io.open;
+local os_remove = os.remove;
+local os_rename = os.rename;
+local tonumber = tonumber;
+local next = next;
+local t_insert = table.insert;
+local t_concat = table.concat;
+local envloadfile = require"util.envload".envloadfile;
+local serialize = require "util.serialization".serialize;
+local path_separator = assert ( package.config:match ( "^([^\n]+)" ) , "package.config not in standard form" ) -- Extract directory seperator from package.config (an undocumented string that comes with lua)
+local lfs = require "lfs";
+local prosody = prosody;
-module "datamanager"
+local raw_mkdir = lfs.mkdir;
+local function fallocate(f, offset, len)
+ -- This assumes that current position == offset
+ local fake_data = (" "):rep(len);
+ local ok, msg = f:write(fake_data);
+ if not ok then
+ return ok, msg;
+ end
+ f:seek("set", offset);
+ return true;
+end;
+pcall(function()
+ local pposix = require "util.pposix";
+ raw_mkdir = pposix.mkdir or raw_mkdir; -- Doesn't trample on umask
+ fallocate = pposix.fallocate or fallocate;
+end);
+module "datamanager"
---- utils -----
local encode, decode;
-
-local log = function (type, msg) return log(type, "datamanager", msg); end
-
-do
+do
local urlcodes = setmetatable({}, { __index = function (t, k) t[k] = char(tonumber("0x"..k)); return t[k]; end });
decode = function (s)
@@ -22,58 +55,313 @@ do
end
encode = function (s)
- return s and (s:gsub("%W", function (c) return format("%%%x", c:byte()); end));
- end
-end
-
-local function basicSerialize (o)
- if type(o) == "number" or type(o) == "boolean" then
- return tostring(o)
- else -- assume it is a string
- return string.format("%q", tostring(o))
- end
-end
-
-
-local function simplesave (f, o)
- if type(o) == "number" then
- f:write(o)
- elseif type(o) == "string" then
- f:write(format("%q", o))
- elseif type(o) == "table" then
- f:write("{\n")
- for k,v in pairs(o) do
- f:write(" [", format("%q", k), "] = ")
- simplesave(f, v)
- f:write(",\n")
- end
- f:write("}\n")
- else
- error("cannot serialize a " .. type(o))
- end
- end
-
+ return s and (s:gsub("%W", function (c) return format("%%%02x", c:byte()); end));
+ end
+end
+
+local _mkdir = {};
+local function mkdir(path)
+ path = path:gsub("/", path_separator); -- TODO as an optimization, do this during path creation rather than here
+ if not _mkdir[path] then
+ raw_mkdir(path);
+ _mkdir[path] = true;
+ end
+ return path;
+end
+
+local data_path = (prosody and prosody.paths and prosody.paths.data) or ".";
+local callbacks = {};
+
------- API -------------
-function getpath(username, host, datastore)
- return format("data/%s/%s/%s.dat", encode(host), datastore, encode(username));
+function set_data_path(path)
+ log("debug", "Setting data path to: %s", path);
+ data_path = path;
+end
+
+local function callback(username, host, datastore, data)
+ for _, f in ipairs(callbacks) do
+ username, host, datastore, data = f(username, host, datastore, data);
+ if username == false then break; end
+ end
+
+ return username, host, datastore, data;
+end
+function add_callback(func)
+ if not callbacks[func] then -- Would you really want to set the same callback more than once?
+ callbacks[func] = true;
+ callbacks[#callbacks+1] = func;
+ return true;
+ end
+end
+function remove_callback(func)
+ if callbacks[func] then
+ for i, f in ipairs(callbacks) do
+ if f == func then
+ callbacks[i] = nil;
+ callbacks[f] = nil;
+ return true;
+ end
+ end
+ end
+end
+
+function getpath(username, host, datastore, ext, create)
+ ext = ext or "dat";
+ host = (host and encode(host)) or "_global";
+ username = username and encode(username);
+ if username then
+ if create then mkdir(mkdir(mkdir(data_path).."/"..host).."/"..datastore); end
+ return format("%s/%s/%s/%s.%s", data_path, host, datastore, username, ext);
+ else
+ if create then mkdir(mkdir(data_path).."/"..host); end
+ return format("%s/%s/%s.%s", data_path, host, datastore, ext);
+ end
end
function load(username, host, datastore)
- local data, ret = loadfile(getpath(username, host, datastore));
- if not data then log("warn", "Failed to load "..datastore.." storage ('"..ret.."') for user: "..username.."@"..host); return nil; end
- setfenv(data, {});
+ local data, ret = envloadfile(getpath(username, host, datastore), {});
+ if not data then
+ local mode = lfs.attributes(getpath(username, host, datastore), "mode");
+ if not mode then
+ log("debug", "Assuming empty %s storage ('%s') for user: %s@%s", datastore, ret, username or "nil", host or "nil");
+ return nil;
+ else -- file exists, but can't be read
+ -- TODO more detailed error checking and logging?
+ log("error", "Failed to load %s storage ('%s') for user: %s@%s", datastore, ret, username or "nil", host or "nil");
+ return nil, "Error reading storage";
+ end
+ end
+
local success, ret = pcall(data);
- if not success then log("error", "Unable to load "..datastore.." storage ('"..ret.."') for user: "..username.."@"..host); return nil; end
+ if not success then
+ log("error", "Unable to load %s storage ('%s') for user: %s@%s", datastore, ret, username or "nil", host or "nil");
+ return nil, "Error reading storage";
+ end
return ret;
end
+local function atomic_store(filename, data)
+ local scratch = filename.."~";
+ local f, ok, msg;
+ repeat
+ f, msg = io_open(scratch, "w");
+ if not f then break end
+
+ ok, msg = f:write(data);
+ if not ok then break end
+
+ ok, msg = f:close();
+ if not ok then break end
+
+ return os_rename(scratch, filename);
+ until false;
+
+ -- Cleanup
+ if f then f:close(); end
+ os_remove(scratch);
+ return nil, msg;
+end
+
+if prosody.platform ~= "posix" then
+ -- os.rename does not overwrite existing files on Windows
+ -- TODO We could use Transactional NTFS on Vista and above
+ function atomic_store(filename, data)
+ local f, err = io_open(filename, "w");
+ if not f then return f, err; end
+ local ok, msg = f:write(data);
+ if not ok then f:close(); return ok, msg; end
+ return f:close();
+ end
+end
+
function store(username, host, datastore, data)
- local f, msg = io_open(getpath(username, host, datastore), "w+");
- if not f then log("error", "Unable to write to "..datastore.." storage ('"..msg.."') for user: "..username.."@"..host); return nil; end
- f:write("return ");
- simplesave(f, data);
+ if not data then
+ data = {};
+ end
+
+ username, host, datastore, data = callback(username, host, datastore, data);
+ if username == false then
+ return true; -- Don't save this data at all
+ end
+
+ -- save the datastore
+ local d = "return " .. serialize(data) .. ";\n";
+ local mkdir_cache_cleared;
+ repeat
+ local ok, msg = atomic_store(getpath(username, host, datastore, nil, true), d);
+ if not ok then
+ if not mkdir_cache_cleared then -- We may need to recreate a removed directory
+ _mkdir = {};
+ mkdir_cache_cleared = true;
+ else
+ log("error", "Unable to write to %s storage ('%s') for user: %s@%s", datastore, msg, username or "nil", host or "nil");
+ return nil, "Error saving to storage";
+ end
+ end
+ if next(data) == nil then -- try to delete empty datastore
+ log("debug", "Removing empty %s datastore for user %s@%s", datastore, username or "nil", host or "nil");
+ os_remove(getpath(username, host, datastore));
+ end
+ -- we write data even when we are deleting because lua doesn't have a
+ -- platform independent way of checking for non-exisitng files
+ until ok;
+ return true;
+end
+
+function list_append(username, host, datastore, data)
+ if not data then return; end
+ if callback(username, host, datastore) == false then return true; end
+ -- save the datastore
+ local f, msg = io_open(getpath(username, host, datastore, "list", true), "r+");
+ if not f then
+ f, msg = io_open(getpath(username, host, datastore, "list", true), "w");
+ end
+ if not f then
+ log("error", "Unable to write to %s storage ('%s') for user: %s@%s", datastore, msg, username or "nil", host or "nil");
+ return;
+ end
+ local data = "item(" .. serialize(data) .. ");\n";
+ local pos = f:seek("end");
+ local ok, msg = fallocate(f, pos, #data);
+ f:seek("set", pos);
+ if ok then
+ f:write(data);
+ else
+ log("error", "Unable to write to %s storage ('%s') for user: %s@%s", datastore, msg, username or "nil", host or "nil");
+ return ok, msg;
+ end
f:close();
return true;
end
+function list_store(username, host, datastore, data)
+ if not data then
+ data = {};
+ end
+ if callback(username, host, datastore) == false then return true; end
+ -- save the datastore
+ local d = {};
+ for _, item in ipairs(data) do
+ d[#d+1] = "item(" .. serialize(item) .. ");\n";
+ end
+ local ok, msg = atomic_store(getpath(username, host, datastore, "list", true), t_concat(d));
+ if not ok then
+ log("error", "Unable to write to %s storage ('%s') for user: %s@%s", datastore, msg, username or "nil", host or "nil");
+ return;
+ end
+ if next(data) == nil then -- try to delete empty datastore
+ log("debug", "Removing empty %s datastore for user %s@%s", datastore, username or "nil", host or "nil");
+ os_remove(getpath(username, host, datastore, "list"));
+ end
+ -- we write data even when we are deleting because lua doesn't have a
+ -- platform independent way of checking for non-exisitng files
+ return true;
+end
+
+function list_load(username, host, datastore)
+ local items = {};
+ local data, ret = envloadfile(getpath(username, host, datastore, "list"), {item = function(i) t_insert(items, i); end});
+ if not data then
+ local mode = lfs.attributes(getpath(username, host, datastore, "list"), "mode");
+ if not mode then
+ log("debug", "Assuming empty %s storage ('%s') for user: %s@%s", datastore, ret, username or "nil", host or "nil");
+ return nil;
+ else -- file exists, but can't be read
+ -- TODO more detailed error checking and logging?
+ log("error", "Failed to load %s storage ('%s') for user: %s@%s", datastore, ret, username or "nil", host or "nil");
+ return nil, "Error reading storage";
+ end
+ end
+
+ local success, ret = pcall(data);
+ if not success then
+ log("error", "Unable to load %s storage ('%s') for user: %s@%s", datastore, ret, username or "nil", host or "nil");
+ return nil, "Error reading storage";
+ end
+ return items;
+end
+
+local type_map = {
+ keyval = "dat";
+ list = "list";
+}
+
+function users(host, store, typ)
+ typ = type_map[typ or "keyval"];
+ local store_dir = format("%s/%s/%s", data_path, encode(host), store);
+
+ local mode, err = lfs.attributes(store_dir, "mode");
+ if not mode then
+ return function() log("debug", err or (store_dir .. " does not exist")) end
+ end
+ local next, state = lfs.dir(store_dir);
+ return function(state)
+ for node in next, state do
+ local file, ext = node:match("^(.*)%.([dalist]+)$");
+ if file and ext == typ then
+ return decode(file);
+ end
+ end
+ end, state;
+end
+
+function stores(username, host, typ)
+ typ = type_map[typ or "keyval"];
+ local store_dir = format("%s/%s/", data_path, encode(host));
+
+ local mode, err = lfs.attributes(store_dir, "mode");
+ if not mode then
+ return function() log("debug", err or (store_dir .. " does not exist")) end
+ end
+ local next, state = lfs.dir(store_dir);
+ return function(state)
+ for node in next, state do
+ if not node:match"^%." then
+ if username == true then
+ if lfs.attributes(store_dir..node, "mode") == "directory" then
+ return decode(node);
+ end
+ elseif username then
+ local store = decode(node)
+ if lfs.attributes(getpath(username, host, store, typ), "mode") then
+ return store;
+ end
+ elseif lfs.attributes(node, "mode") == "file" then
+ local file, ext = node:match("^(.*)%.([dalist]+)$");
+ if ext == typ then
+ return decode(file)
+ end
+ end
+ end
+ end
+ end, state;
+end
+
+local function do_remove(path)
+ local ok, err = os_remove(path);
+ if not ok and lfs.attributes(path, "mode") then
+ return ok, err;
+ end
+ return true
+end
+
+function purge(username, host)
+ local host_dir = format("%s/%s/", data_path, encode(host));
+ local errs = {};
+ for file in lfs.dir(host_dir) do
+ if lfs.attributes(host_dir..file, "mode") == "directory" then
+ local store = decode(file);
+ local ok, err = do_remove(getpath(username, host, store));
+ if not ok then errs[#errs+1] = err; end
+
+ local ok, err = do_remove(getpath(username, host, store, "list"));
+ if not ok then errs[#errs+1] = err; end
+ end
+ end
+ return #errs == 0, t_concat(errs, ", ");
+end
+
+_M.path_decode = decode;
+_M.path_encode = encode;
+return _M;
diff --git a/util/datetime.lua b/util/datetime.lua
new file mode 100644
index 00000000..a1f62a48
--- /dev/null
+++ b/util/datetime.lua
@@ -0,0 +1,57 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+
+-- XEP-0082: XMPP Date and Time Profiles
+
+local os_date = os.date;
+local os_time = os.time;
+local os_difftime = os.difftime;
+local error = error;
+local tonumber = tonumber;
+
+module "datetime"
+
+function date(t)
+ return os_date("!%Y-%m-%d", t);
+end
+
+function datetime(t)
+ return os_date("!%Y-%m-%dT%H:%M:%SZ", t);
+end
+
+function time(t)
+ return os_date("!%H:%M:%S", t);
+end
+
+function legacy(t)
+ return os_date("!%Y%m%dT%H:%M:%S", t);
+end
+
+function parse(s)
+ if s then
+ local year, month, day, hour, min, sec, tzd;
+ year, month, day, hour, min, sec, tzd = s:match("^(%d%d%d%d)%-?(%d%d)%-?(%d%d)T(%d%d):(%d%d):(%d%d)%.?%d*([Z+%-]?.*)$");
+ if year then
+ local time_offset = os_difftime(os_time(os_date("*t")), os_time(os_date("!*t"))); -- to deal with local timezone
+ local tzd_offset = 0;
+ if tzd ~= "" and tzd ~= "Z" then
+ local sign, h, m = tzd:match("([+%-])(%d%d):?(%d*)");
+ if not sign then return; end
+ if #m ~= 2 then m = "0"; end
+ h, m = tonumber(h), tonumber(m);
+ tzd_offset = h * 60 * 60 + m * 60;
+ if sign == "-" then tzd_offset = -tzd_offset; end
+ end
+ sec = (sec + time_offset) - tzd_offset;
+ return os_time({year=year, month=month, day=day, hour=hour, min=min, sec=sec, isdst=false});
+ end
+ end
+end
+
+return _M;
diff --git a/util/debug.lua b/util/debug.lua
new file mode 100644
index 00000000..bff0e347
--- /dev/null
+++ b/util/debug.lua
@@ -0,0 +1,193 @@
+-- Variables ending with these names will not
+-- have their values printed ('password' includes
+-- 'new_password', etc.)
+local censored_names = {
+ password = true;
+ passwd = true;
+ pass = true;
+ pwd = true;
+};
+local optimal_line_length = 65;
+
+local termcolours = require "util.termcolours";
+local getstring = termcolours.getstring;
+local styles;
+do
+ _ = termcolours.getstyle;
+ styles = {
+ boundary_padding = _("bright");
+ filename = _("bright", "blue");
+ level_num = _("green");
+ funcname = _("yellow");
+ location = _("yellow");
+ };
+end
+module("debugx", package.seeall);
+
+function get_locals_table(level)
+ level = level + 1; -- Skip this function itself
+ local locals = {};
+ for local_num = 1, math.huge do
+ local name, value = debug.getlocal(level, local_num);
+ if not name then break; end
+ table.insert(locals, { name = name, value = value });
+ end
+ return locals;
+end
+
+function get_upvalues_table(func)
+ local upvalues = {};
+ if func then
+ for upvalue_num = 1, math.huge do
+ local name, value = debug.getupvalue(func, upvalue_num);
+ if not name then break; end
+ table.insert(upvalues, { name = name, value = value });
+ end
+ end
+ return upvalues;
+end
+
+function string_from_var_table(var_table, max_line_len, indent_str)
+ local var_string = {};
+ local col_pos = 0;
+ max_line_len = max_line_len or math.huge;
+ indent_str = "\n"..(indent_str or "");
+ for _, var in ipairs(var_table) do
+ local name, value = var.name, var.value;
+ if name:sub(1,1) ~= "(" then
+ if type(value) == "string" then
+ if censored_names[name:match("%a+$")] then
+ value = "<hidden>";
+ else
+ value = ("%q"):format(value);
+ end
+ else
+ value = tostring(value);
+ end
+ if #value > max_line_len then
+ value = value:sub(1, max_line_len-3).."…";
+ end
+ local str = ("%s = %s"):format(name, tostring(value));
+ col_pos = col_pos + #str;
+ if col_pos > max_line_len then
+ table.insert(var_string, indent_str);
+ col_pos = 0;
+ end
+ table.insert(var_string, str);
+ end
+ end
+ if #var_string == 0 then
+ return nil;
+ else
+ return "{ "..table.concat(var_string, ", "):gsub(indent_str..", ", indent_str).." }";
+ end
+end
+
+function get_traceback_table(thread, start_level)
+ local levels = {};
+ for level = start_level, math.huge do
+ local info;
+ if thread then
+ info = debug.getinfo(thread, level+1);
+ else
+ info = debug.getinfo(level+1);
+ end
+ if not info then break; end
+
+ levels[(level-start_level)+1] = {
+ level = level;
+ info = info;
+ locals = get_locals_table(level+1);
+ upvalues = get_upvalues_table(info.func);
+ };
+ end
+ return levels;
+end
+
+function traceback(...)
+ local ok, ret = pcall(_traceback, ...);
+ if not ok then
+ return "Error in error handling: "..ret;
+ end
+ return ret;
+end
+
+local function build_source_boundary_marker(last_source_desc)
+ local padding = string.rep("-", math.floor(((optimal_line_length - 6) - #last_source_desc)/2));
+ return getstring(styles.boundary_padding, "v"..padding).." "..getstring(styles.filename, last_source_desc).." "..getstring(styles.boundary_padding, padding..(#last_source_desc%2==0 and "-v" or "v "));
+end
+
+function _traceback(thread, message, level)
+
+ -- Lua manual says: debug.traceback ([thread,] [message [, level]])
+ -- I fathom this to mean one of:
+ -- ()
+ -- (thread)
+ -- (message, level)
+ -- (thread, message, level)
+
+ if thread == nil then -- Defaults
+ thread, message, level = coroutine.running(), message, level;
+ elseif type(thread) == "string" then
+ thread, message, level = coroutine.running(), thread, message;
+ elseif type(thread) ~= "thread" then
+ return nil; -- debug.traceback() does this
+ end
+
+ level = level or 1;
+
+ message = message and (message.."\n") or "";
+
+ -- +3 counts for this function, and the pcall() and wrapper above us
+ local levels = get_traceback_table(thread, level+3);
+
+ local last_source_desc;
+
+ local lines = {};
+ for nlevel, level in ipairs(levels) do
+ local info = level.info;
+ local line = "...";
+ local func_type = info.namewhat.." ";
+ local source_desc = (info.short_src == "[C]" and "C code") or info.short_src or "Unknown";
+ if func_type == " " then func_type = ""; end;
+ if info.short_src == "[C]" then
+ line = "[ C ] "..func_type.."C function "..getstring(styles.location, (info.name and ("%q"):format(info.name) or "(unknown name)"));
+ elseif info.what == "main" then
+ line = "[Lua] "..getstring(styles.location, info.short_src.." line "..info.currentline);
+ else
+ local name = info.name or " ";
+ if name ~= " " then
+ name = ("%q"):format(name);
+ end
+ if func_type == "global " or func_type == "local " then
+ func_type = func_type.."function ";
+ end
+ line = "[Lua] "..getstring(styles.location, info.short_src.." line "..info.currentline).." in "..func_type..getstring(styles.funcname, name).." (defined on line "..info.linedefined..")";
+ end
+ if source_desc ~= last_source_desc then -- Venturing into a new source, add marker for previous
+ last_source_desc = source_desc;
+ table.insert(lines, "\t "..build_source_boundary_marker(last_source_desc));
+ end
+ nlevel = nlevel-1;
+ table.insert(lines, "\t"..(nlevel==0 and ">" or " ")..getstring(styles.level_num, "("..nlevel..") ")..line);
+ local npadding = (" "):rep(#tostring(nlevel));
+ local locals_str = string_from_var_table(level.locals, optimal_line_length, "\t "..npadding);
+ if locals_str then
+ table.insert(lines, "\t "..npadding.."Locals: "..locals_str);
+ end
+ local upvalues_str = string_from_var_table(level.upvalues, optimal_line_length, "\t "..npadding);
+ if upvalues_str then
+ table.insert(lines, "\t "..npadding.."Upvals: "..upvalues_str);
+ end
+ end
+
+-- table.insert(lines, "\t "..build_source_boundary_marker(last_source_desc));
+
+ return message.."stack traceback:\n"..table.concat(lines, "\n");
+end
+
+function use()
+ debug.traceback = traceback;
+end
+
+return _M;
diff --git a/util/dependencies.lua b/util/dependencies.lua
new file mode 100644
index 00000000..53d2719d
--- /dev/null
+++ b/util/dependencies.lua
@@ -0,0 +1,149 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+module("dependencies", package.seeall)
+
+function softreq(...) local ok, lib = pcall(require, ...); if ok then return lib; else return nil, lib; end end
+
+-- Required to be able to find packages installed with luarocks
+if not softreq "luarocks.loader" then -- LuaRocks 2.x
+ softreq "luarocks.require"; -- LuaRocks <1.x
+end
+
+function missingdep(name, sources, msg)
+ print("");
+ print("**************************");
+ print("Prosody was unable to find "..tostring(name));
+ print("This package can be obtained in the following ways:");
+ print("");
+ local longest_platform = 0;
+ for platform in pairs(sources) do
+ longest_platform = math.max(longest_platform, #platform);
+ end
+ for platform, source in pairs(sources) do
+ print("", platform..":"..(" "):rep(4+longest_platform-#platform)..source);
+ end
+ print("");
+ print(msg or (name.." is required for Prosody to run, so we will now exit."));
+ print("More help can be found on our website, at http://prosody.im/doc/depends");
+ print("**************************");
+ print("");
+end
+
+-- COMPAT w/pre-0.8 Debian: The Debian config file used to use
+-- util.ztact, which has been removed from Prosody in 0.8. This
+-- is to log an error for people who still use it, so they can
+-- update their configs.
+package.preload["util.ztact"] = function ()
+ if not package.loaded["core.loggingmanager"] then
+ error("util.ztact has been removed from Prosody and you need to fix your config "
+ .."file. More information can be found at http://prosody.im/doc/packagers#ztact", 0);
+ else
+ error("module 'util.ztact' has been deprecated in Prosody 0.8.");
+ end
+end;
+
+function check_dependencies()
+ local fatal;
+
+ local lxp = softreq "lxp"
+
+ if not lxp then
+ missingdep("luaexpat", {
+ ["Debian/Ubuntu"] = "sudo apt-get install liblua5.1-expat0";
+ ["luarocks"] = "luarocks install luaexpat";
+ ["Source"] = "http://www.keplerproject.org/luaexpat/";
+ });
+ fatal = true;
+ end
+
+ local socket = softreq "socket"
+
+ if not socket then
+ missingdep("luasocket", {
+ ["Debian/Ubuntu"] = "sudo apt-get install liblua5.1-socket2";
+ ["luarocks"] = "luarocks install luasocket";
+ ["Source"] = "http://www.tecgraf.puc-rio.br/~diego/professional/luasocket/";
+ });
+ fatal = true;
+ end
+
+ local lfs, err = softreq "lfs"
+ if not lfs then
+ missingdep("luafilesystem", {
+ ["luarocks"] = "luarocks install luafilesystem";
+ ["Debian/Ubuntu"] = "sudo apt-get install liblua5.1-filesystem0";
+ ["Source"] = "http://www.keplerproject.org/luafilesystem/";
+ });
+ fatal = true;
+ end
+
+ local ssl = softreq "ssl"
+
+ if not ssl then
+ missingdep("LuaSec", {
+ ["Debian/Ubuntu"] = "http://prosody.im/download/start#debian_and_ubuntu";
+ ["luarocks"] = "luarocks install luasec";
+ ["Source"] = "http://www.inf.puc-rio.br/~brunoos/luasec/";
+ }, "SSL/TLS support will not be available");
+ end
+
+ local encodings, err = softreq "util.encodings"
+ if not encodings then
+ if err:match("not found") then
+ missingdep("util.encodings", { ["Windows"] = "Make sure you have encodings.dll from the Prosody distribution in util/";
+ ["GNU/Linux"] = "Run './configure' and 'make' in the Prosody source directory to build util/encodings.so";
+ });
+ else
+ print "***********************************"
+ print("util/encodings couldn't be loaded. Check that you have a recent version of libidn");
+ print ""
+ print("The full error was:");
+ print(err)
+ print "***********************************"
+ end
+ fatal = true;
+ end
+
+ local hashes, err = softreq "util.hashes"
+ if not hashes then
+ if err:match("not found") then
+ missingdep("util.hashes", { ["Windows"] = "Make sure you have hashes.dll from the Prosody distribution in util/";
+ ["GNU/Linux"] = "Run './configure' and 'make' in the Prosody source directory to build util/hashes.so";
+ });
+ else
+ print "***********************************"
+ print("util/hashes couldn't be loaded. Check that you have a recent version of OpenSSL (libcrypto in particular)");
+ print ""
+ print("The full error was:");
+ print(err)
+ print "***********************************"
+ end
+ fatal = true;
+ end
+ return not fatal;
+end
+
+function log_warnings()
+ if ssl then
+ local major, minor, veryminor, patched = ssl._VERSION:match("(%d+)%.(%d+)%.?(%d*)(M?)");
+ if not major or ((tonumber(major) == 0 and (tonumber(minor) or 0) <= 3 and (tonumber(veryminor) or 0) <= 2) and patched ~= "M") then
+ log("error", "This version of LuaSec contains a known bug that causes disconnects, see http://prosody.im/doc/depends");
+ end
+ end
+ if lxp then
+ if not pcall(lxp.new, { StartDoctypeDecl = false }) then
+ log("error", "The version of LuaExpat on your system leaves Prosody "
+ .."vulnerable to denial-of-service attacks. You should upgrade to "
+ .."LuaExpat 1.1.1 or higher as soon as possible. See "
+ .."http://prosody.im/doc/depends#luaexpat for more information.");
+ end
+ end
+end
+
+return _M;
diff --git a/util/envload.lua b/util/envload.lua
new file mode 100644
index 00000000..53e28348
--- /dev/null
+++ b/util/envload.lua
@@ -0,0 +1,34 @@
+-- Prosody IM
+-- Copyright (C) 2008-2011 Florian Zeitz
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+local load, loadstring, loadfile, setfenv = load, loadstring, loadfile, setfenv;
+local envload;
+local envloadfile;
+
+if setfenv then
+ function envload(code, source, env)
+ local f, err = loadstring(code, source);
+ if f and env then setfenv(f, env); end
+ return f, err;
+ end
+
+ function envloadfile(file, env)
+ local f, err = loadfile(file);
+ if f and env then setfenv(f, env); end
+ return f, err;
+ end
+else
+ function envload(code, source, env)
+ return load(code, source, nil, env);
+ end
+
+ function envloadfile(file, env)
+ return loadfile(file, nil, env);
+ end
+end
+
+return { envload = envload, envloadfile = envloadfile };
diff --git a/util/events.lua b/util/events.lua
new file mode 100644
index 00000000..412acccd
--- /dev/null
+++ b/util/events.lua
@@ -0,0 +1,83 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+
+local pairs = pairs;
+local t_insert = table.insert;
+local t_sort = table.sort;
+local setmetatable = setmetatable;
+local next = next;
+
+module "events"
+
+function new()
+ local handlers = {};
+ local event_map = {};
+ local function _rebuild_index(handlers, event)
+ local _handlers = event_map[event];
+ if not _handlers or next(_handlers) == nil then return; end
+ local index = {};
+ for handler in pairs(_handlers) do
+ t_insert(index, handler);
+ end
+ t_sort(index, function(a, b) return _handlers[a] > _handlers[b]; end);
+ handlers[event] = index;
+ return index;
+ end;
+ setmetatable(handlers, { __index = _rebuild_index });
+ local function add_handler(event, handler, priority)
+ local map = event_map[event];
+ if map then
+ map[handler] = priority or 0;
+ else
+ map = {[handler] = priority or 0};
+ event_map[event] = map;
+ end
+ handlers[event] = nil;
+ end;
+ local function remove_handler(event, handler)
+ local map = event_map[event];
+ if map then
+ map[handler] = nil;
+ handlers[event] = nil;
+ if next(map) == nil then
+ event_map[event] = nil;
+ end
+ end
+ end;
+ local function add_handlers(handlers)
+ for event, handler in pairs(handlers) do
+ add_handler(event, handler);
+ end
+ end;
+ local function remove_handlers(handlers)
+ for event, handler in pairs(handlers) do
+ remove_handler(event, handler);
+ end
+ end;
+ local function fire_event(event, ...)
+ local h = handlers[event];
+ if h then
+ for i=1,#h do
+ local ret = h[i](...);
+ if ret ~= nil then return ret; end
+ end
+ end
+ end;
+ return {
+ add_handler = add_handler;
+ remove_handler = remove_handler;
+ add_handlers = add_handlers;
+ remove_handlers = remove_handlers;
+ fire_event = fire_event;
+ _handlers = handlers;
+ _event_map = event_map;
+ };
+end
+
+return _M;
diff --git a/util/filters.lua b/util/filters.lua
new file mode 100644
index 00000000..d143666b
--- /dev/null
+++ b/util/filters.lua
@@ -0,0 +1,87 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+local t_insert, t_remove = table.insert, table.remove;
+
+module "filters"
+
+local new_filter_hooks = {};
+
+function initialize(session)
+ if not session.filters then
+ local filters = {};
+ session.filters = filters;
+
+ function session.filter(type, data)
+ local filter_list = filters[type];
+ if filter_list then
+ for i = 1, #filter_list do
+ data = filter_list[i](data, session);
+ if data == nil then break; end
+ end
+ end
+ return data;
+ end
+ end
+
+ for i=1,#new_filter_hooks do
+ new_filter_hooks[i](session);
+ end
+
+ return session.filter;
+end
+
+function add_filter(session, type, callback, priority)
+ if not session.filters then
+ initialize(session);
+ end
+
+ local filter_list = session.filters[type];
+ if not filter_list then
+ filter_list = {};
+ session.filters[type] = filter_list;
+ end
+
+ priority = priority or 0;
+
+ local i = 0;
+ repeat
+ i = i + 1;
+ until not filter_list[i] or filter_list[filter_list[i]] >= priority;
+
+ t_insert(filter_list, i, callback);
+ filter_list[callback] = priority;
+end
+
+function remove_filter(session, type, callback)
+ if not session.filters then return; end
+ local filter_list = session.filters[type];
+ if filter_list and filter_list[callback] then
+ for i=1, #filter_list do
+ if filter_list[i] == callback then
+ t_remove(filter_list, i);
+ filter_list[callback] = nil;
+ return true;
+ end
+ end
+ end
+end
+
+function add_filter_hook(callback)
+ t_insert(new_filter_hooks, callback);
+end
+
+function remove_filter_hook(callback)
+ for i=1,#new_filter_hooks do
+ if new_filter_hooks[i] == callback then
+ t_remove(new_filter_hooks, i);
+ end
+ end
+end
+
+return _M;
diff --git a/util/helpers.lua b/util/helpers.lua
new file mode 100644
index 00000000..08b86a7c
--- /dev/null
+++ b/util/helpers.lua
@@ -0,0 +1,82 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+local debug = require "util.debug";
+
+module("helpers", package.seeall);
+
+-- Helper functions for debugging
+
+local log = require "util.logger".init("util.debug");
+
+function log_host_events(host)
+ return log_events(prosody.hosts[host].events, host);
+end
+
+function revert_log_host_events(host)
+ return revert_log_events(prosody.hosts[host].events);
+end
+
+function log_events(events, name, logger)
+ local f = events.fire_event;
+ if not f then
+ error("Object does not appear to be a util.events object");
+ end
+ logger = logger or log;
+ name = name or tostring(events);
+ function events.fire_event(event, ...)
+ logger("debug", "%s firing event: %s", name, event);
+ return f(event, ...);
+ end
+ events[events.fire_event] = f;
+ return events;
+end
+
+function revert_log_events(events)
+ events.fire_event, events[events.fire_event] = events[events.fire_event], nil; -- :))
+end
+
+function show_events(events, specific_event)
+ local event_handlers = events._handlers;
+ local events_array = {};
+ local event_handler_arrays = {};
+ for event in pairs(events._event_map) do
+ local handlers = event_handlers[event];
+ if handlers and (event == specific_event or not specific_event) then
+ table.insert(events_array, event);
+ local handler_strings = {};
+ for i, handler in ipairs(handlers) do
+ local upvals = debug.string_from_var_table(debug.get_upvalues_table(handler));
+ handler_strings[i] = " "..i..": "..tostring(handler)..(upvals and ("\n "..upvals) or "");
+ end
+ event_handler_arrays[event] = handler_strings;
+ end
+ end
+ table.sort(events_array);
+ local i = 1;
+ while i <= #events_array do
+ local handlers = event_handler_arrays[events_array[i]];
+ for j=#handlers, 1, -1 do
+ table.insert(events_array, i+1, handlers[j]);
+ end
+ if i > 1 then events_array[i] = "\n"..events_array[i]; end
+ i = i + #handlers + 1
+ end
+ return table.concat(events_array, "\n");
+end
+
+function get_upvalue(f, get_name)
+ local i, name, value = 0;
+ repeat
+ i = i + 1;
+ name, value = debug.getupvalue(f, i);
+ until name == get_name or name == nil;
+ return value;
+end
+
+return _M;
diff --git a/util/hmac.lua b/util/hmac.lua
new file mode 100644
index 00000000..51211c7a
--- /dev/null
+++ b/util/hmac.lua
@@ -0,0 +1,15 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+-- COMPAT: Only for external pre-0.9 modules
+
+local hashes = require "util.hashes"
+
+return { md5 = hashes.hmac_md5,
+ sha1 = hashes.hmac_sha1,
+ sha256 = hashes.hmac_sha256 };
diff --git a/util/http.lua b/util/http.lua
new file mode 100644
index 00000000..f7259920
--- /dev/null
+++ b/util/http.lua
@@ -0,0 +1,64 @@
+-- Prosody IM
+-- Copyright (C) 2013 Florian Zeitz
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+local format, char = string.format, string.char;
+local pairs, ipairs, tonumber = pairs, ipairs, tonumber;
+local t_insert, t_concat = table.insert, table.concat;
+
+local function urlencode(s)
+ return s and (s:gsub("[^a-zA-Z0-9.~_-]", function (c) return format("%%%02x", c:byte()); end));
+end
+local function urldecode(s)
+ return s and (s:gsub("%%(%x%x)", function (c) return char(tonumber(c,16)); end));
+end
+
+local function _formencodepart(s)
+ return s and (s:gsub("%W", function (c)
+ if c ~= " " then
+ return format("%%%02x", c:byte());
+ else
+ return "+";
+ end
+ end));
+end
+
+local function formencode(form)
+ local result = {};
+ if form[1] then -- Array of ordered { name, value }
+ for _, field in ipairs(form) do
+ t_insert(result, _formencodepart(field.name).."=".._formencodepart(field.value));
+ end
+ else -- Unordered map of name -> value
+ for name, value in pairs(form) do
+ t_insert(result, _formencodepart(name).."=".._formencodepart(value));
+ end
+ end
+ return t_concat(result, "&");
+end
+
+local function formdecode(s)
+ if not s:match("=") then return urldecode(s); end
+ local r = {};
+ for k, v in s:gmatch("([^=&]*)=([^&]*)") do
+ k, v = k:gsub("%+", "%%20"), v:gsub("%+", "%%20");
+ k, v = urldecode(k), urldecode(v);
+ t_insert(r, { name = k, value = v });
+ r[k] = v;
+ end
+ return r;
+end
+
+local function contains_token(field, token)
+ field = ","..field:gsub("[ \t]", ""):lower()..",";
+ return field:find(","..token:lower()..",", 1, true) ~= nil;
+end
+
+return {
+ urlencode = urlencode, urldecode = urldecode;
+ formencode = formencode, formdecode = formdecode;
+ contains_token = contains_token;
+};
diff --git a/util/import.lua b/util/import.lua
index 6aab0d45..81401e8b 100644
--- a/util/import.lua
+++ b/util/import.lua
@@ -1,3 +1,12 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+
local t_insert = table.insert;
function import(module, ...)
diff --git a/util/ip.lua b/util/ip.lua
new file mode 100644
index 00000000..de287b16
--- /dev/null
+++ b/util/ip.lua
@@ -0,0 +1,189 @@
+-- Prosody IM
+-- Copyright (C) 2008-2011 Florian Zeitz
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+local ip_methods = {};
+local ip_mt = { __index = function (ip, key) return (ip_methods[key])(ip); end,
+ __tostring = function (ip) return ip.addr; end,
+ __eq = function (ipA, ipB) return ipA.addr == ipB.addr; end};
+local hex2bits = { ["0"] = "0000", ["1"] = "0001", ["2"] = "0010", ["3"] = "0011", ["4"] = "0100", ["5"] = "0101", ["6"] = "0110", ["7"] = "0111", ["8"] = "1000", ["9"] = "1001", ["A"] = "1010", ["B"] = "1011", ["C"] = "1100", ["D"] = "1101", ["E"] = "1110", ["F"] = "1111" };
+
+local function new_ip(ipStr, proto)
+ if proto ~= "IPv4" and proto ~= "IPv6" then
+ return nil, "invalid protocol";
+ end
+
+ return setmetatable({ addr = ipStr, proto = proto }, ip_mt);
+end
+
+local function toBits(ip)
+ local result = "";
+ local fields = {};
+ if ip.proto == "IPv4" then
+ ip = ip.toV4mapped;
+ end
+ ip = (ip.addr):upper();
+ ip:gsub("([^:]*):?", function (c) fields[#fields + 1] = c end);
+ if not ip:match(":$") then fields[#fields] = nil; end
+ for i, field in ipairs(fields) do
+ if field:len() == 0 and i ~= 1 and i ~= #fields then
+ for i = 1, 16 * (9 - #fields) do
+ result = result .. "0";
+ end
+ else
+ for i = 1, 4 - field:len() do
+ result = result .. "0000";
+ end
+ for i = 1, field:len() do
+ result = result .. hex2bits[field:sub(i,i)];
+ end
+ end
+ end
+ return result;
+end
+
+local function commonPrefixLength(ipA, ipB)
+ ipA, ipB = toBits(ipA), toBits(ipB);
+ for i = 1, 128 do
+ if ipA:sub(i,i) ~= ipB:sub(i,i) then
+ return i-1;
+ end
+ end
+ return 128;
+end
+
+local function v4scope(ip)
+ local fields = {};
+ ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end);
+ -- Loopback:
+ if fields[1] == 127 then
+ return 0x2;
+ -- Link-local unicast:
+ elseif fields[1] == 169 and fields[2] == 254 then
+ return 0x2;
+ -- Global unicast:
+ else
+ return 0xE;
+ end
+end
+
+local function v6scope(ip)
+ -- Loopback:
+ if ip:match("^[0:]*1$") then
+ return 0x2;
+ -- Link-local unicast:
+ elseif ip:match("^[Ff][Ee][89ABab]") then
+ return 0x2;
+ -- Site-local unicast:
+ elseif ip:match("^[Ff][Ee][CcDdEeFf]") then
+ return 0x5;
+ -- Multicast:
+ elseif ip:match("^[Ff][Ff]") then
+ return tonumber("0x"..ip:sub(4,4));
+ -- Global unicast:
+ else
+ return 0xE;
+ end
+end
+
+local function label(ip)
+ if commonPrefixLength(ip, new_ip("::1", "IPv6")) == 128 then
+ return 0;
+ elseif commonPrefixLength(ip, new_ip("2002::", "IPv6")) >= 16 then
+ return 2;
+ elseif commonPrefixLength(ip, new_ip("2001::", "IPv6")) >= 32 then
+ return 5;
+ elseif commonPrefixLength(ip, new_ip("fc00::", "IPv6")) >= 7 then
+ return 13;
+ elseif commonPrefixLength(ip, new_ip("fec0::", "IPv6")) >= 10 then
+ return 11;
+ elseif commonPrefixLength(ip, new_ip("3ffe::", "IPv6")) >= 16 then
+ return 12;
+ elseif commonPrefixLength(ip, new_ip("::", "IPv6")) >= 96 then
+ return 3;
+ elseif commonPrefixLength(ip, new_ip("::ffff:0:0", "IPv6")) >= 96 then
+ return 4;
+ else
+ return 1;
+ end
+end
+
+local function precedence(ip)
+ if commonPrefixLength(ip, new_ip("::1", "IPv6")) == 128 then
+ return 50;
+ elseif commonPrefixLength(ip, new_ip("2002::", "IPv6")) >= 16 then
+ return 30;
+ elseif commonPrefixLength(ip, new_ip("2001::", "IPv6")) >= 32 then
+ return 5;
+ elseif commonPrefixLength(ip, new_ip("fc00::", "IPv6")) >= 7 then
+ return 3;
+ elseif commonPrefixLength(ip, new_ip("fec0::", "IPv6")) >= 10 then
+ return 1;
+ elseif commonPrefixLength(ip, new_ip("3ffe::", "IPv6")) >= 16 then
+ return 1;
+ elseif commonPrefixLength(ip, new_ip("::", "IPv6")) >= 96 then
+ return 1;
+ elseif commonPrefixLength(ip, new_ip("::ffff:0:0", "IPv6")) >= 96 then
+ return 35;
+ else
+ return 40;
+ end
+end
+
+local function toV4mapped(ip)
+ local fields = {};
+ local ret = "::ffff:";
+ ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end);
+ ret = ret .. ("%02x"):format(fields[1]);
+ ret = ret .. ("%02x"):format(fields[2]);
+ ret = ret .. ":"
+ ret = ret .. ("%02x"):format(fields[3]);
+ ret = ret .. ("%02x"):format(fields[4]);
+ return new_ip(ret, "IPv6");
+end
+
+function ip_methods:toV4mapped()
+ if self.proto ~= "IPv4" then return nil, "No IPv4 address" end
+ local value = toV4mapped(self.addr);
+ self.toV4mapped = value;
+ return value;
+end
+
+function ip_methods:label()
+ local value;
+ if self.proto == "IPv4" then
+ value = label(self.toV4mapped);
+ else
+ value = label(self);
+ end
+ self.label = value;
+ return value;
+end
+
+function ip_methods:precedence()
+ local value;
+ if self.proto == "IPv4" then
+ value = precedence(self.toV4mapped);
+ else
+ value = precedence(self);
+ end
+ self.precedence = value;
+ return value;
+end
+
+function ip_methods:scope()
+ local value;
+ if self.proto == "IPv4" then
+ value = v4scope(self.addr);
+ else
+ value = v6scope(self.addr);
+ end
+ self.scope = value;
+ return value;
+end
+
+return {new_ip = new_ip,
+ commonPrefixLength = commonPrefixLength};
diff --git a/util/iterators.lua b/util/iterators.lua
new file mode 100644
index 00000000..1f6aacb8
--- /dev/null
+++ b/util/iterators.lua
@@ -0,0 +1,159 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+--[[ Iterators ]]--
+
+local it = {};
+
+-- Reverse an iterator
+function it.reverse(f, s, var)
+ local results = {};
+
+ -- First call the normal iterator
+ while true do
+ local ret = { f(s, var) };
+ var = ret[1];
+ if var == nil then break; end
+ table.insert(results, 1, ret);
+ end
+
+ -- Then return our reverse one
+ local i,max = 0, #results;
+ return function (results)
+ if i<max then
+ i = i + 1;
+ return unpack(results[i]);
+ end
+ end, results;
+end
+
+-- Iterate only over keys in a table
+local function _keys_it(t, key)
+ return (next(t, key));
+end
+function it.keys(t)
+ return _keys_it, t;
+end
+
+-- Iterate only over values in a table
+function it.values(t)
+ local key, val;
+ return function (t)
+ key, val = next(t, key);
+ return val;
+ end, t;
+end
+
+-- Given an iterator, iterate only over unique items
+function it.unique(f, s, var)
+ local set = {};
+
+ return function ()
+ while true do
+ local ret = { f(s, var) };
+ var = ret[1];
+ if var == nil then break; end
+ if not set[var] then
+ set[var] = true;
+ return var;
+ end
+ end
+ end;
+end
+
+--[[ Return the number of items an iterator returns ]]--
+function it.count(f, s, var)
+ local x = 0;
+
+ while true do
+ local ret = { f(s, var) };
+ var = ret[1];
+ if var == nil then break; end
+ x = x + 1;
+ end
+
+ return x;
+end
+
+-- Return the first n items an iterator returns
+function it.head(n, f, s, var)
+ local c = 0;
+ return function (s, var)
+ if c >= n then
+ return nil;
+ end
+ c = c + 1;
+ return f(s, var);
+ end, s;
+end
+
+-- Skip the first n items an iterator returns
+function it.skip(n, f, s, var)
+ for i=1,n do
+ var = f(s, var);
+ end
+ return f, s, var;
+end
+
+-- Return the last n items an iterator returns
+function it.tail(n, f, s, var)
+ local results, count = {}, 0;
+ while true do
+ local ret = { f(s, var) };
+ var = ret[1];
+ if var == nil then break; end
+ results[(count%n)+1] = ret;
+ count = count + 1;
+ end
+
+ if n > count then n = count; end
+
+ local pos = 0;
+ return function ()
+ pos = pos + 1;
+ if pos > n then return nil; end
+ return unpack(results[((count-1+pos)%n)+1]);
+ end
+ --return reverse(head(n, reverse(f, s, var)));
+end
+
+local function _ripairs_iter(t, key) if key > 1 then return key-1, t[key-1]; end end
+function it.ripairs(t)
+ return _ripairs_iter, t, #t+1;
+end
+
+local function _range_iter(max, curr) if curr < max then return curr + 1; end end
+function it.range(x, y)
+ if not y then x, y = 1, x; end -- Default to 1..x if y not given
+ return _range_iter, y, x-1;
+end
+
+-- Convert the values returned by an iterator to an array
+function it.to_array(f, s, var)
+ local t, var = {};
+ while true do
+ var = f(s, var);
+ if var == nil then break; end
+ table.insert(t, var);
+ end
+ return t;
+end
+
+-- Treat the return of an iterator as key,value pairs,
+-- and build a table
+function it.to_table(f, s, var)
+ local t, var2 = {};
+ while true do
+ var, var2 = f(s, var);
+ if var == nil then break; end
+ t[var] = var2;
+ end
+ return t;
+end
+
+return it;
diff --git a/util/jid.lua b/util/jid.lua
index 28dd1a92..4c4371d8 100644
--- a/util/jid.lua
+++ b/util/jid.lua
@@ -1,12 +1,107 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+
local match = string.match;
+local nodeprep = require "util.encodings".stringprep.nodeprep;
+local nameprep = require "util.encodings".stringprep.nameprep;
+local resourceprep = require "util.encodings".stringprep.resourceprep;
+
+local escapes = {
+ [" "] = "\\20"; ['"'] = "\\22";
+ ["&"] = "\\26"; ["'"] = "\\27";
+ ["/"] = "\\2f"; [":"] = "\\3a";
+ ["<"] = "\\3c"; [">"] = "\\3e";
+ ["@"] = "\\40"; ["\\"] = "\\5c";
+};
+local unescapes = {};
+for k,v in pairs(escapes) do unescapes[v] = k; end
module "jid"
-function split(jid)
- if not jid then return nil; end
- local node = match(jid, "^([^@]+)@");
- local server = (node and match(jid, ".-@([^@/]+)")) or match(jid, "^([^@/]+)");
- local resource = match(jid, "/(.+)$");
- return node, server, resource;
+local function _split(jid)
+ if not jid then return; end
+ local node, nodepos = match(jid, "^([^@/]+)@()");
+ local host, hostpos = match(jid, "^([^@/]+)()", nodepos)
+ if node and not host then return nil, nil, nil; end
+ local resource = match(jid, "^/(.+)$", hostpos);
+ if (not host) or ((not resource) and #jid >= hostpos) then return nil, nil, nil; end
+ return node, host, resource;
+end
+split = _split;
+
+function bare(jid)
+ local node, host = _split(jid);
+ if node and host then
+ return node.."@"..host;
+ end
+ return host;
+end
+
+local function _prepped_split(jid)
+ local node, host, resource = _split(jid);
+ if host then
+ host = nameprep(host);
+ if not host then return; end
+ if node then
+ node = nodeprep(node);
+ if not node then return; end
+ end
+ if resource then
+ resource = resourceprep(resource);
+ if not resource then return; end
+ end
+ return node, host, resource;
+ end
end
+prepped_split = _prepped_split;
+
+function prep(jid)
+ local node, host, resource = _prepped_split(jid);
+ if host then
+ if node then
+ host = node .. "@" .. host;
+ end
+ if resource then
+ host = host .. "/" .. resource;
+ end
+ end
+ return host;
+end
+
+function join(node, host, resource)
+ if node and host and resource then
+ return node.."@"..host.."/"..resource;
+ elseif node and host then
+ return node.."@"..host;
+ elseif host and resource then
+ return host.."/"..resource;
+ elseif host then
+ return host;
+ end
+ return nil; -- Invalid JID
+end
+
+function compare(jid, acl)
+ -- compare jid to single acl rule
+ -- TODO compare to table of rules?
+ local jid_node, jid_host, jid_resource = _split(jid);
+ local acl_node, acl_host, acl_resource = _split(acl);
+ if ((acl_node ~= nil and acl_node == jid_node) or acl_node == nil) and
+ ((acl_host ~= nil and acl_host == jid_host) or acl_host == nil) and
+ ((acl_resource ~= nil and acl_resource == jid_resource) or acl_resource == nil) then
+ return true
+ end
+ return false
+end
+
+function escape(s) return s and (s:gsub(".", escapes)); end
+function unescape(s) return s and (s:gsub("\\%x%x", unescapes)); end
+
+return _M;
diff --git a/util/json.lua b/util/json.lua
new file mode 100644
index 00000000..9c2dd2c6
--- /dev/null
+++ b/util/json.lua
@@ -0,0 +1,412 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+local type = type;
+local t_insert, t_concat, t_remove, t_sort = table.insert, table.concat, table.remove, table.sort;
+local s_char = string.char;
+local tostring, tonumber = tostring, tonumber;
+local pairs, ipairs = pairs, ipairs;
+local next = next;
+local error = error;
+local newproxy, getmetatable = newproxy, getmetatable;
+local print = print;
+
+local has_array, array = pcall(require, "util.array");
+local array_mt = hasarray and getmetatable(array()) or {};
+
+--module("json")
+local json = {};
+
+local null = newproxy and newproxy(true) or {};
+if getmetatable and getmetatable(null) then
+ getmetatable(null).__tostring = function() return "null"; end;
+end
+json.null = null;
+
+local escapes = {
+ ["\""] = "\\\"", ["\\"] = "\\\\", ["\b"] = "\\b",
+ ["\f"] = "\\f", ["\n"] = "\\n", ["\r"] = "\\r", ["\t"] = "\\t"};
+local unescapes = {
+ ["\""] = "\"", ["\\"] = "\\", ["/"] = "/",
+ b = "\b", f = "\f", n = "\n", r = "\r", t = "\t"};
+for i=0,31 do
+ local ch = s_char(i);
+ if not escapes[ch] then escapes[ch] = ("\\u%.4X"):format(i); end
+end
+
+local function codepoint_to_utf8(code)
+ if code < 0x80 then return s_char(code); end
+ local bits0_6 = code % 64;
+ if code < 0x800 then
+ local bits6_5 = (code - bits0_6) / 64;
+ return s_char(0x80 + 0x40 + bits6_5, 0x80 + bits0_6);
+ end
+ local bits0_12 = code % 4096;
+ local bits6_6 = (bits0_12 - bits0_6) / 64;
+ local bits12_4 = (code - bits0_12) / 4096;
+ return s_char(0x80 + 0x40 + 0x20 + bits12_4, 0x80 + bits6_6, 0x80 + bits0_6);
+end
+
+local valid_types = {
+ number = true,
+ string = true,
+ table = true,
+ boolean = true
+};
+local special_keys = {
+ __array = true;
+ __hash = true;
+};
+
+local simplesave, tablesave, arraysave, stringsave;
+
+function stringsave(o, buffer)
+ -- FIXME do proper utf-8 and binary data detection
+ t_insert(buffer, "\""..(o:gsub(".", escapes)).."\"");
+end
+
+function arraysave(o, buffer)
+ t_insert(buffer, "[");
+ if next(o) then
+ for i,v in ipairs(o) do
+ simplesave(v, buffer);
+ t_insert(buffer, ",");
+ end
+ t_remove(buffer);
+ end
+ t_insert(buffer, "]");
+end
+
+function tablesave(o, buffer)
+ local __array = {};
+ local __hash = {};
+ local hash = {};
+ for i,v in ipairs(o) do
+ __array[i] = v;
+ end
+ for k,v in pairs(o) do
+ local ktype, vtype = type(k), type(v);
+ if valid_types[vtype] or v == null then
+ if ktype == "string" and not special_keys[k] then
+ hash[k] = v;
+ elseif (valid_types[ktype] or k == null) and __array[k] == nil then
+ __hash[k] = v;
+ end
+ end
+ end
+ if next(__hash) ~= nil or next(hash) ~= nil or next(__array) == nil then
+ t_insert(buffer, "{");
+ local mark = #buffer;
+ if buffer.ordered then
+ local keys = {};
+ for k in pairs(hash) do
+ t_insert(keys, k);
+ end
+ t_sort(keys);
+ for _,k in ipairs(keys) do
+ stringsave(k, buffer);
+ t_insert(buffer, ":");
+ simplesave(hash[k], buffer);
+ t_insert(buffer, ",");
+ end
+ else
+ for k,v in pairs(hash) do
+ stringsave(k, buffer);
+ t_insert(buffer, ":");
+ simplesave(v, buffer);
+ t_insert(buffer, ",");
+ end
+ end
+ if next(__hash) ~= nil then
+ t_insert(buffer, "\"__hash\":[");
+ for k,v in pairs(__hash) do
+ simplesave(k, buffer);
+ t_insert(buffer, ",");
+ simplesave(v, buffer);
+ t_insert(buffer, ",");
+ end
+ t_remove(buffer);
+ t_insert(buffer, "]");
+ t_insert(buffer, ",");
+ end
+ if next(__array) then
+ t_insert(buffer, "\"__array\":");
+ arraysave(__array, buffer);
+ t_insert(buffer, ",");
+ end
+ if mark ~= #buffer then t_remove(buffer); end
+ t_insert(buffer, "}");
+ else
+ arraysave(__array, buffer);
+ end
+end
+
+function simplesave(o, buffer)
+ local t = type(o);
+ if t == "number" then
+ t_insert(buffer, tostring(o));
+ elseif t == "string" then
+ stringsave(o, buffer);
+ elseif t == "table" then
+ local mt = getmetatable(o);
+ if mt == array_mt then
+ arraysave(o, buffer);
+ else
+ tablesave(o, buffer);
+ end
+ elseif t == "boolean" then
+ t_insert(buffer, (o and "true" or "false"));
+ else
+ t_insert(buffer, "null");
+ end
+end
+
+function json.encode(obj)
+ local t = {};
+ simplesave(obj, t);
+ return t_concat(t);
+end
+function json.encode_ordered(obj)
+ local t = { ordered = true };
+ simplesave(obj, t);
+ return t_concat(t);
+end
+function json.encode_array(obj)
+ local t = {};
+ arraysave(obj, t);
+ return t_concat(t);
+end
+
+-----------------------------------
+
+
+function json.decode(json)
+ json = json.." "; -- appending a space ensures valid json wouldn't touch EOF
+ local pos = 1;
+ local current = {};
+ local stack = {};
+ local ch, peek;
+ local function next()
+ ch = json:sub(pos, pos);
+ if ch == "" then error("Unexpected EOF"); end
+ pos = pos+1;
+ peek = json:sub(pos, pos);
+ return ch;
+ end
+
+ local function skipwhitespace()
+ while ch and (ch == "\r" or ch == "\n" or ch == "\t" or ch == " ") do
+ next();
+ end
+ end
+ local function skiplinecomment()
+ repeat next(); until not(ch) or ch == "\r" or ch == "\n";
+ skipwhitespace();
+ end
+ local function skipstarcomment()
+ next(); next(); -- skip '/', '*'
+ while peek and ch ~= "*" and peek ~= "/" do next(); end
+ if not peek then error("eof in star comment") end
+ next(); next(); -- skip '*', '/'
+ skipwhitespace();
+ end
+ local function skipstuff()
+ while true do
+ skipwhitespace();
+ if ch == "/" and peek == "*" then
+ skipstarcomment();
+ elseif ch == "/" and peek == "/" then
+ skiplinecomment();
+ else
+ return;
+ end
+ end
+ end
+
+ local readvalue;
+ local function readarray()
+ local t = setmetatable({}, array_mt);
+ next(); -- skip '['
+ skipstuff();
+ if ch == "]" then next(); return t; end
+ t_insert(t, readvalue());
+ while true do
+ skipstuff();
+ if ch == "]" then next(); return t; end
+ if not ch then error("eof while reading array");
+ elseif ch == "," then next();
+ elseif ch then error("unexpected character in array, comma expected"); end
+ if not ch then error("eof while reading array"); end
+ t_insert(t, readvalue());
+ end
+ end
+
+ local function checkandskip(c)
+ local x = ch or "eof";
+ if x ~= c then error("unexpected "..x..", '"..c.."' expected"); end
+ next();
+ end
+ local function readliteral(lit, val)
+ for c in lit:gmatch(".") do
+ checkandskip(c);
+ end
+ return val;
+ end
+ local function readstring()
+ local s = "";
+ checkandskip("\"");
+ while ch do
+ while ch and ch ~= "\\" and ch ~= "\"" do
+ s = s..ch; next();
+ end
+ if ch == "\\" then
+ next();
+ if unescapes[ch] then
+ s = s..unescapes[ch];
+ next();
+ elseif ch == "u" then
+ local seq = "";
+ for i=1,4 do
+ next();
+ if not ch then error("unexpected eof in string"); end
+ if not ch:match("[0-9a-fA-F]") then error("invalid unicode escape sequence in string"); end
+ seq = seq..ch;
+ end
+ s = s..codepoint_to_utf8(tonumber(seq, 16));
+ next();
+ else error("invalid escape sequence in string"); end
+ end
+ if ch == "\"" then
+ next();
+ return s;
+ end
+ end
+ error("eof while reading string");
+ end
+ local function readnumber()
+ local s = "";
+ if ch == "-" then
+ s = s..ch; next();
+ if not ch:match("[0-9]") then error("number format error"); end
+ end
+ if ch == "0" then
+ s = s..ch; next();
+ if ch:match("[0-9]") then error("number format error"); end
+ else
+ while ch and ch:match("[0-9]") do
+ s = s..ch; next();
+ end
+ end
+ if ch == "." then
+ s = s..ch; next();
+ if not ch:match("[0-9]") then error("number format error"); end
+ while ch and ch:match("[0-9]") do
+ s = s..ch; next();
+ end
+ if ch == "e" or ch == "E" then
+ s = s..ch; next();
+ if ch == "+" or ch == "-" then
+ s = s..ch; next();
+ if not ch:match("[0-9]") then error("number format error"); end
+ while ch and ch:match("[0-9]") do
+ s = s..ch; next();
+ end
+ end
+ end
+ end
+ return tonumber(s);
+ end
+ local function readmember(t)
+ skipstuff();
+ local k = readstring();
+ skipstuff();
+ checkandskip(":");
+ t[k] = readvalue();
+ end
+ local function fixobject(obj)
+ local __array = obj.__array;
+ if __array then
+ obj.__array = nil;
+ for i,v in ipairs(__array) do
+ t_insert(obj, v);
+ end
+ end
+ local __hash = obj.__hash;
+ if __hash then
+ obj.__hash = nil;
+ local k;
+ for i,v in ipairs(__hash) do
+ if k ~= nil then
+ obj[k] = v; k = nil;
+ else
+ k = v;
+ end
+ end
+ end
+ return obj;
+ end
+ local function readobject()
+ local t = {};
+ next(); -- skip '{'
+ skipstuff();
+ if ch == "}" then next(); return t; end
+ if not ch then error("eof while reading object"); end
+ readmember(t);
+ while true do
+ skipstuff();
+ if ch == "}" then next(); return fixobject(t); end
+ if not ch then error("eof while reading object");
+ elseif ch == "," then next();
+ elseif ch then error("unexpected character in object, comma expected"); end
+ if not ch then error("eof while reading object"); end
+ readmember(t);
+ end
+ end
+
+ function readvalue()
+ skipstuff();
+ while ch do
+ if ch == "{" then
+ return readobject();
+ elseif ch == "[" then
+ return readarray();
+ elseif ch == "\"" then
+ return readstring();
+ elseif ch:match("[%-0-9%.]") then
+ return readnumber();
+ elseif ch == "n" then
+ return readliteral("null", null);
+ elseif ch == "t" then
+ return readliteral("true", true);
+ elseif ch == "f" then
+ return readliteral("false", false);
+ else
+ error("invalid character at value start: "..ch);
+ end
+ end
+ error("eof while reading value");
+ end
+ next();
+ return readvalue();
+end
+
+function json.test(object)
+ local encoded = json.encode(object);
+ local decoded = json.decode(encoded);
+ local recoded = json.encode(decoded);
+ if encoded ~= recoded then
+ print("FAILED");
+ print("encoded:", encoded);
+ print("recoded:", recoded);
+ else
+ print(encoded);
+ end
+ return encoded == recoded;
+end
+
+return json;
diff --git a/util/logger.lua b/util/logger.lua
index 623ceb67..26206d4d 100644
--- a/util/logger.lua
+++ b/util/logger.lua
@@ -1,23 +1,74 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+local pcall = pcall;
+
+local find = string.find;
+local ipairs, pairs, setmetatable = ipairs, pairs, setmetatable;
-local format = string.format;
-local print = print;
-local debug = debug;
-local tostring = tostring;
module "logger"
+local level_sinks = {};
+
+local make_logger;
+
function init(name)
- name = nil; -- While this line is not commented, will automatically fill in file/line number info
- return function (level, message, ...)
- if not name then
- local inf = debug.getinfo(3, 'Snl');
- level = level .. ","..tostring(inf.short_src):match("[^/]*$")..":"..inf.currentline;
- end
- if ... then
- print(level, format(message, ...));
- else
- print(level, message);
- end
+ local log_debug = make_logger(name, "debug");
+ local log_info = make_logger(name, "info");
+ local log_warn = make_logger(name, "warn");
+ local log_error = make_logger(name, "error");
+
+ return function (level, message, ...)
+ if level == "debug" then
+ return log_debug(message, ...);
+ elseif level == "info" then
+ return log_info(message, ...);
+ elseif level == "warn" then
+ return log_warn(message, ...);
+ elseif level == "error" then
+ return log_error(message, ...);
end
+ end
+end
+
+function make_logger(source_name, level)
+ local level_handlers = level_sinks[level];
+ if not level_handlers then
+ level_handlers = {};
+ level_sinks[level] = level_handlers;
+ end
+
+ local logger = function (message, ...)
+ for i = 1,#level_handlers do
+ level_handlers[i](source_name, level, message, ...);
+ end
+ end
+
+ return logger;
+end
+
+function reset()
+ for level, handler_list in pairs(level_sinks) do
+ -- Clear all handlers for this level
+ for i = 1, #handler_list do
+ handler_list[i] = nil;
+ end
+ end
end
-return _M; \ No newline at end of file
+function add_level_sink(level, sink_function)
+ if not level_sinks[level] then
+ level_sinks[level] = { sink_function };
+ else
+ level_sinks[level][#level_sinks[level] + 1 ] = sink_function;
+ end
+end
+
+_M.new = make_logger;
+
+return _M;
diff --git a/util/multitable.lua b/util/multitable.lua
new file mode 100644
index 00000000..dbf34d28
--- /dev/null
+++ b/util/multitable.lua
@@ -0,0 +1,177 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+local select = select;
+local t_insert = table.insert;
+local unpack, pairs, next, type = unpack, pairs, next, type;
+
+module "multitable"
+
+local function get(self, ...)
+ local t = self.data;
+ for n = 1,select('#', ...) do
+ t = t[select(n, ...)];
+ if not t then break; end
+ end
+ return t;
+end
+
+local function add(self, ...)
+ local t = self.data;
+ local count = select('#', ...);
+ for n = 1,count-1 do
+ local key = select(n, ...);
+ local tab = t[key];
+ if not tab then tab = {}; t[key] = tab; end
+ t = tab;
+ end
+ t_insert(t, (select(count, ...)));
+end
+
+local function set(self, ...)
+ local t = self.data;
+ local count = select('#', ...);
+ for n = 1,count-2 do
+ local key = select(n, ...);
+ local tab = t[key];
+ if not tab then tab = {}; t[key] = tab; end
+ t = tab;
+ end
+ t[(select(count-1, ...))] = (select(count, ...));
+end
+
+local function r(t, n, _end, ...)
+ if t == nil then return; end
+ local k = select(n, ...);
+ if n == _end then
+ t[k] = nil;
+ return;
+ end
+ if k then
+ local v = t[k];
+ if v then
+ r(v, n+1, _end, ...);
+ if not next(v) then
+ t[k] = nil;
+ end
+ end
+ else
+ for _,b in pairs(t) do
+ r(b, n+1, _end, ...);
+ if not next(b) then
+ t[_] = nil;
+ end
+ end
+ end
+end
+
+local function remove(self, ...)
+ local _end = select('#', ...);
+ for n = _end,1 do
+ if select(n, ...) then _end = n; break; end
+ end
+ r(self.data, 1, _end, ...);
+end
+
+
+local function s(t, n, results, _end, ...)
+ if t == nil then return; end
+ local k = select(n, ...);
+ if n == _end then
+ if k == nil then
+ for _, v in pairs(t) do
+ t_insert(results, v);
+ end
+ else
+ t_insert(results, t[k]);
+ end
+ return;
+ end
+ if k then
+ local v = t[k];
+ if v then
+ s(v, n+1, results, _end, ...);
+ end
+ else
+ for _,b in pairs(t) do
+ s(b, n+1, results, _end, ...);
+ end
+ end
+end
+
+-- Search for keys, nil == wildcard
+local function search(self, ...)
+ local _end = select('#', ...);
+ for n = _end,1 do
+ if select(n, ...) then _end = n; break; end
+ end
+ local results = {};
+ s(self.data, 1, results, _end, ...);
+ return results;
+end
+
+-- Append results to an existing list
+local function search_add(self, results, ...)
+ if not results then results = {}; end
+ local _end = select('#', ...);
+ for n = _end,1 do
+ if select(n, ...) then _end = n; break; end
+ end
+ s(self.data, 1, results, _end, ...);
+ return results;
+end
+
+function iter(self, ...)
+ local query = { ... };
+ local maxdepth = select("#", ...);
+ local stack = { self.data };
+ local keys = { };
+ local function it(self)
+ local depth = #stack;
+ local key = next(stack[depth], keys[depth]);
+ if key == nil then -- Go up the stack
+ stack[depth], keys[depth] = nil, nil;
+ if depth > 1 then
+ return it(self);
+ end
+ return; -- The end
+ else
+ keys[depth] = key;
+ end
+ local value = stack[depth][key];
+ if query[depth] == nil or key == query[depth] then
+ if depth == maxdepth then -- Result
+ local result = {}; -- Collect keys forming path to result
+ for i = 1, depth do
+ result[i] = keys[i];
+ end
+ result[depth+1] = value;
+ return unpack(result, 1, depth+1);
+ elseif type(value) == "table" then
+ t_insert(stack, value); -- Descend
+ end
+ end
+ return it(self);
+ end;
+ return it, self;
+end
+
+function new()
+ return {
+ data = {};
+ get = get;
+ add = add;
+ set = set;
+ remove = remove;
+ search = search;
+ search_add = search_add;
+ iter = iter;
+ };
+end
+
+return _M;
diff --git a/util/openssl.lua b/util/openssl.lua
new file mode 100644
index 00000000..ef3fba96
--- /dev/null
+++ b/util/openssl.lua
@@ -0,0 +1,172 @@
+local type, tostring, pairs, ipairs = type, tostring, pairs, ipairs;
+local t_insert, t_concat = table.insert, table.concat;
+local s_format = string.format;
+
+local oid_xmppaddr = "1.3.6.1.5.5.7.8.5"; -- [XMPP-CORE]
+local oid_dnssrv = "1.3.6.1.5.5.7.8.7"; -- [SRV-ID]
+
+local idna_to_ascii = require "util.encodings".idna.to_ascii;
+
+local _M = {};
+local config = {};
+_M.config = config;
+
+local ssl_config = {};
+local ssl_config_mt = {__index=ssl_config};
+
+function config.new()
+ return setmetatable({
+ req = {
+ distinguished_name = "distinguished_name",
+ req_extensions = "v3_extensions",
+ x509_extensions = "v3_extensions",
+ prompt = "no",
+ },
+ distinguished_name = {
+ countryName = "GB",
+ -- stateOrProvinceName = "",
+ localityName = "The Internet",
+ organizationName = "Your Organisation",
+ organizationalUnitName = "XMPP Department",
+ commonName = "example.com",
+ emailAddress = "xmpp@example.com",
+ },
+ v3_extensions = {
+ basicConstraints = "CA:FALSE",
+ keyUsage = "digitalSignature,keyEncipherment",
+ extendedKeyUsage = "serverAuth,clientAuth",
+ subjectAltName = "@subject_alternative_name",
+ },
+ subject_alternative_name = {
+ DNS = {},
+ otherName = {},
+ },
+ }, ssl_config_mt);
+end
+
+local DN_order = {
+ "countryName";
+ "stateOrProvinceName";
+ "localityName";
+ "streetAddress";
+ "organizationName";
+ "organizationalUnitName";
+ "commonName";
+ "emailAddress";
+}
+_M._DN_order = DN_order;
+function ssl_config:serialize()
+ local s = "";
+ for k, t in pairs(self) do
+ s = s .. ("[%s]\n"):format(k);
+ if k == "subject_alternative_name" then
+ for san, n in pairs(t) do
+ for i = 1,#n do
+ s = s .. s_format("%s.%d = %s\n", san, i -1, n[i]);
+ end
+ end
+ elseif k == "distinguished_name" then
+ for i=1,#DN_order do
+ local k = DN_order[i]
+ local v = t[k];
+ if v then
+ s = s .. ("%s = %s\n"):format(k, v);
+ end
+ end
+ else
+ for k, v in pairs(t) do
+ s = s .. ("%s = %s\n"):format(k, v);
+ end
+ end
+ s = s .. "\n";
+ end
+ return s;
+end
+
+local function utf8string(s)
+ -- This is how we tell openssl not to encode UTF-8 strings as fake Latin1
+ return s_format("FORMAT:UTF8,UTF8:%s", s);
+end
+
+local function ia5string(s)
+ return s_format("IA5STRING:%s", s);
+end
+
+_M.util = {
+ utf8string = utf8string,
+ ia5string = ia5string,
+};
+
+function ssl_config:add_dNSName(host)
+ t_insert(self.subject_alternative_name.DNS, idna_to_ascii(host));
+end
+
+function ssl_config:add_sRVName(host, service)
+ t_insert(self.subject_alternative_name.otherName,
+ s_format("%s;%s", oid_dnssrv, ia5string("_" .. service .."." .. idna_to_ascii(host))));
+end
+
+function ssl_config:add_xmppAddr(host)
+ t_insert(self.subject_alternative_name.otherName,
+ s_format("%s;%s", oid_xmppaddr, utf8string(host)));
+end
+
+function ssl_config:from_prosody(hosts, config, certhosts)
+ -- TODO Decide if this should go elsewhere
+ local found_matching_hosts = false;
+ for i = 1,#certhosts do
+ local certhost = certhosts[i];
+ for name in pairs(hosts) do
+ if name == certhost or name:sub(-1-#certhost) == "."..certhost then
+ found_matching_hosts = true;
+ self:add_dNSName(name);
+ --print(name .. "#component_module: " .. (config.get(name, "component_module") or "nil"));
+ if config.get(name, "component_module") == nil then
+ self:add_sRVName(name, "xmpp-client");
+ end
+ --print(name .. "#anonymous_login: " .. tostring(config.get(name, "anonymous_login")));
+ if not (config.get(name, "anonymous_login") or
+ config.get(name, "authentication") == "anonymous") then
+ self:add_sRVName(name, "xmpp-server");
+ end
+ self:add_xmppAddr(name);
+ end
+ end
+ end
+ if not found_matching_hosts then
+ return nil, "no-matching-hosts";
+ end
+end
+
+do -- Lua to shell calls.
+ local function shell_escape(s)
+ return s:gsub("'",[['\'']]);
+ end
+
+ local function serialize(f,o)
+ local r = {"openssl", f};
+ for k,v in pairs(o) do
+ if type(k) == "string" then
+ t_insert(r, ("-%s"):format(k));
+ if v ~= true then
+ t_insert(r, ("'%s'"):format(shell_escape(tostring(v))));
+ end
+ end
+ end
+ for _,v in ipairs(o) do
+ t_insert(r, ("'%s'"):format(shell_escape(tostring(v))));
+ end
+ return t_concat(r, " ");
+ end
+
+ local os_execute = os.execute;
+ setmetatable(_M, {
+ __index=function(_,f)
+ return function(opts)
+ return 0 == os_execute(serialize(f, type(opts) == "table" and opts or {}));
+ end;
+ end;
+ });
+end
+
+return _M;
diff --git a/util/pluginloader.lua b/util/pluginloader.lua
new file mode 100644
index 00000000..c10fdf65
--- /dev/null
+++ b/util/pluginloader.lua
@@ -0,0 +1,60 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+local dir_sep, path_sep = package.config:match("^(%S+)%s(%S+)");
+local plugin_dir = {};
+for path in (CFG_PLUGINDIR or "./plugins/"):gsub("[/\\]", dir_sep):gmatch("[^"..path_sep.."]+") do
+ path = path..dir_sep; -- add path separator to path end
+ path = path:gsub(dir_sep..dir_sep.."+", dir_sep); -- coalesce multiple separaters
+ plugin_dir[#plugin_dir + 1] = path;
+end
+
+local io_open = io.open;
+local envload = require "util.envload".envload;
+
+module "pluginloader"
+
+function load_file(names)
+ local file, err, path;
+ for i=1,#plugin_dir do
+ for j=1,#names do
+ path = plugin_dir[i]..names[j];
+ file, err = io_open(path);
+ if file then
+ local content = file:read("*a");
+ file:close();
+ return content, path;
+ end
+ end
+ end
+ return file, err;
+end
+
+function load_resource(plugin, resource)
+ resource = resource or "mod_"..plugin..".lua";
+
+ local names = {
+ "mod_"..plugin.."/"..plugin.."/"..resource; -- mod_hello/hello/mod_hello.lua
+ "mod_"..plugin.."/"..resource; -- mod_hello/mod_hello.lua
+ plugin.."/"..resource; -- hello/mod_hello.lua
+ resource; -- mod_hello.lua
+ };
+
+ return load_file(names);
+end
+
+function load_code(plugin, resource, env)
+ local content, err = load_resource(plugin, resource);
+ if not content then return content, err; end
+ local path = err;
+ local f, err = envload(content, "@"..path, env);
+ if not f then return f, err; end
+ return f, path;
+end
+
+return _M;
diff --git a/util/prosodyctl.lua b/util/prosodyctl.lua
new file mode 100644
index 00000000..b80a69f2
--- /dev/null
+++ b/util/prosodyctl.lua
@@ -0,0 +1,279 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+
+local config = require "core.configmanager";
+local encodings = require "util.encodings";
+local stringprep = encodings.stringprep;
+local storagemanager = require "core.storagemanager";
+local usermanager = require "core.usermanager";
+local signal = require "util.signal";
+local set = require "util.set";
+local lfs = require "lfs";
+local pcall = pcall;
+local type = type;
+
+local nodeprep, nameprep = stringprep.nodeprep, stringprep.nameprep;
+
+local io, os = io, os;
+local print = print;
+local tostring, tonumber = tostring, tonumber;
+
+local CFG_SOURCEDIR = _G.CFG_SOURCEDIR;
+
+local _G = _G;
+local prosody = prosody;
+
+module "prosodyctl"
+
+-- UI helpers
+function show_message(msg, ...)
+ print(msg:format(...));
+end
+
+function show_warning(msg, ...)
+ print(msg:format(...));
+end
+
+function show_usage(usage, desc)
+ print("Usage: ".._G.arg[0].." "..usage);
+ if desc then
+ print(" "..desc);
+ end
+end
+
+function getchar(n)
+ local stty_ret = os.execute("stty raw -echo 2>/dev/null");
+ local ok, char;
+ if stty_ret == 0 then
+ ok, char = pcall(io.read, n or 1);
+ os.execute("stty sane");
+ else
+ ok, char = pcall(io.read, "*l");
+ if ok then
+ char = char:sub(1, n or 1);
+ end
+ end
+ if ok then
+ return char;
+ end
+end
+
+function getline()
+ local ok, line = pcall(io.read, "*l");
+ if ok then
+ return line;
+ end
+end
+
+function getpass()
+ local stty_ret = os.execute("stty -echo 2>/dev/null");
+ if stty_ret ~= 0 then
+ io.write("\027[08m"); -- ANSI 'hidden' text attribute
+ end
+ local ok, pass = pcall(io.read, "*l");
+ if stty_ret == 0 then
+ os.execute("stty sane");
+ else
+ io.write("\027[00m");
+ end
+ io.write("\n");
+ if ok then
+ return pass;
+ end
+end
+
+function show_yesno(prompt)
+ io.write(prompt, " ");
+ local choice = getchar():lower();
+ io.write("\n");
+ if not choice:match("%a") then
+ choice = prompt:match("%[.-(%U).-%]$");
+ if not choice then return nil; end
+ end
+ return (choice == "y");
+end
+
+function read_password()
+ local password;
+ while true do
+ io.write("Enter new password: ");
+ password = getpass();
+ if not password then
+ show_message("No password - cancelled");
+ return;
+ end
+ io.write("Retype new password: ");
+ if getpass() ~= password then
+ if not show_yesno [=[Passwords did not match, try again? [Y/n]]=] then
+ return;
+ end
+ else
+ break;
+ end
+ end
+ return password;
+end
+
+function show_prompt(prompt)
+ io.write(prompt, " ");
+ local line = getline();
+ line = line and line:gsub("\n$","");
+ return (line and #line > 0) and line or nil;
+end
+
+-- Server control
+function adduser(params)
+ local user, host, password = nodeprep(params.user), nameprep(params.host), params.password;
+ if not user then
+ return false, "invalid-username";
+ elseif not host then
+ return false, "invalid-hostname";
+ end
+
+ local host_session = prosody.hosts[host];
+ if not host_session then
+ return false, "no-such-host";
+ end
+
+ storagemanager.initialize_host(host);
+ local provider = host_session.users;
+ if not(provider) or provider.name == "null" then
+ usermanager.initialize_host(host);
+ end
+
+ local ok, errmsg = usermanager.create_user(user, password, host);
+ if not ok then
+ return false, errmsg;
+ end
+ return true;
+end
+
+function user_exists(params)
+ local user, host, password = nodeprep(params.user), nameprep(params.host), params.password;
+
+ storagemanager.initialize_host(host);
+ local provider = prosody.hosts[host].users;
+ if not(provider) or provider.name == "null" then
+ usermanager.initialize_host(host);
+ end
+
+ return usermanager.user_exists(user, host);
+end
+
+function passwd(params)
+ if not _M.user_exists(params) then
+ return false, "no-such-user";
+ end
+
+ return _M.adduser(params);
+end
+
+function deluser(params)
+ if not _M.user_exists(params) then
+ return false, "no-such-user";
+ end
+ local user, host = nodeprep(params.user), nameprep(params.host);
+
+ return usermanager.delete_user(user, host);
+end
+
+function getpid()
+ local pidfile = config.get("*", "pidfile");
+ if not pidfile then
+ return false, "no-pidfile";
+ end
+
+ local modules_enabled = set.new(config.get("*", "modules_enabled"));
+ if not modules_enabled:contains("posix") then
+ return false, "no-posix";
+ end
+
+ local file, err = io.open(pidfile, "r+");
+ if not file then
+ return false, "pidfile-read-failed", err;
+ end
+
+ local locked, err = lfs.lock(file, "w");
+ if locked then
+ file:close();
+ return false, "pidfile-not-locked";
+ end
+
+ local pid = tonumber(file:read("*a"));
+ file:close();
+
+ if not pid then
+ return false, "invalid-pid";
+ end
+
+ return true, pid;
+end
+
+function isrunning()
+ local ok, pid, err = _M.getpid();
+ if not ok then
+ if pid == "pidfile-read-failed" or pid == "pidfile-not-locked" then
+ -- Report as not running, since we can't open the pidfile
+ -- (it probably doesn't exist)
+ return true, false;
+ end
+ return ok, pid;
+ end
+ return true, signal.kill(pid, 0) == 0;
+end
+
+function start()
+ local ok, ret = _M.isrunning();
+ if not ok then
+ return ok, ret;
+ end
+ if ret then
+ return false, "already-running";
+ end
+ if not CFG_SOURCEDIR then
+ os.execute("./prosody");
+ else
+ os.execute(CFG_SOURCEDIR.."/../../bin/prosody");
+ end
+ return true;
+end
+
+function stop()
+ local ok, ret = _M.isrunning();
+ if not ok then
+ return ok, ret;
+ end
+ if not ret then
+ return false, "not-running";
+ end
+
+ local ok, pid = _M.getpid()
+ if not ok then return false, pid; end
+
+ signal.kill(pid, signal.SIGTERM);
+ return true;
+end
+
+function reload()
+ local ok, ret = _M.isrunning();
+ if not ok then
+ return ok, ret;
+ end
+ if not ret then
+ return false, "not-running";
+ end
+
+ local ok, pid = _M.getpid()
+ if not ok then return false, pid; end
+
+ signal.kill(pid, signal.SIGHUP);
+ return true;
+end
+
+return _M;
diff --git a/util/pubsub.lua b/util/pubsub.lua
new file mode 100644
index 00000000..e7fc86b1
--- /dev/null
+++ b/util/pubsub.lua
@@ -0,0 +1,388 @@
+local events = require "util.events";
+
+module("pubsub", package.seeall);
+
+local service = {};
+local service_mt = { __index = service };
+
+local default_config = {
+ broadcaster = function () end;
+ get_affiliation = function () end;
+ capabilities = {};
+};
+
+function new(config)
+ config = config or {};
+ return setmetatable({
+ config = setmetatable(config, { __index = default_config });
+ affiliations = {};
+ subscriptions = {};
+ nodes = {};
+ events = events.new();
+ }, service_mt);
+end
+
+function service:jids_equal(jid1, jid2)
+ local normalize = self.config.normalize_jid;
+ return normalize(jid1) == normalize(jid2);
+end
+
+function service:may(node, actor, action)
+ if actor == true then return true; end
+
+ local node_obj = self.nodes[node];
+ local node_aff = node_obj and node_obj.affiliations[actor];
+ local service_aff = self.affiliations[actor]
+ or self.config.get_affiliation(actor, node, action)
+ or "none";
+
+ -- Check if node allows/forbids it
+ local node_capabilities = node_obj and node_obj.capabilities;
+ if node_capabilities then
+ local caps = node_capabilities[node_aff or service_aff];
+ if caps then
+ local can = caps[action];
+ if can ~= nil then
+ return can;
+ end
+ end
+ end
+
+ -- Check service-wide capabilities instead
+ local service_capabilities = self.config.capabilities;
+ local caps = service_capabilities[node_aff or service_aff];
+ if caps then
+ local can = caps[action];
+ if can ~= nil then
+ return can;
+ end
+ end
+
+ return false;
+end
+
+function service:set_affiliation(node, actor, jid, affiliation)
+ -- Access checking
+ if not self:may(node, actor, "set_affiliation") then
+ return false, "forbidden";
+ end
+ --
+ local node_obj = self.nodes[node];
+ if not node_obj then
+ return false, "item-not-found";
+ end
+ node_obj.affiliations[jid] = affiliation;
+ local _, jid_sub = self:get_subscription(node, true, jid);
+ if not jid_sub and not self:may(node, jid, "be_unsubscribed") then
+ local ok, err = self:add_subscription(node, true, jid);
+ if not ok then
+ return ok, err;
+ end
+ elseif jid_sub and not self:may(node, jid, "be_subscribed") then
+ local ok, err = self:add_subscription(node, true, jid);
+ if not ok then
+ return ok, err;
+ end
+ end
+ return true;
+end
+
+function service:add_subscription(node, actor, jid, options)
+ -- Access checking
+ local cap;
+ if actor == true or jid == actor or self:jids_equal(actor, jid) then
+ cap = "subscribe";
+ else
+ cap = "subscribe_other";
+ end
+ if not self:may(node, actor, cap) then
+ return false, "forbidden";
+ end
+ if not self:may(node, jid, "be_subscribed") then
+ return false, "forbidden";
+ end
+ --
+ local node_obj = self.nodes[node];
+ if not node_obj then
+ if not self.config.autocreate_on_subscribe then
+ return false, "item-not-found";
+ else
+ local ok, err = self:create(node, true);
+ if not ok then
+ return ok, err;
+ end
+ node_obj = self.nodes[node];
+ end
+ end
+ node_obj.subscribers[jid] = options or true;
+ local normal_jid = self.config.normalize_jid(jid);
+ local subs = self.subscriptions[normal_jid];
+ if subs then
+ if not subs[jid] then
+ subs[jid] = { [node] = true };
+ else
+ subs[jid][node] = true;
+ end
+ else
+ self.subscriptions[normal_jid] = { [jid] = { [node] = true } };
+ end
+ self.events.fire_event("subscription-added", { node = node, jid = jid, normalized_jid = normal_jid, options = options });
+ return true;
+end
+
+function service:remove_subscription(node, actor, jid)
+ -- Access checking
+ local cap;
+ if actor == true or jid == actor or self:jids_equal(actor, jid) then
+ cap = "unsubscribe";
+ else
+ cap = "unsubscribe_other";
+ end
+ if not self:may(node, actor, cap) then
+ return false, "forbidden";
+ end
+ if not self:may(node, jid, "be_unsubscribed") then
+ return false, "forbidden";
+ end
+ --
+ local node_obj = self.nodes[node];
+ if not node_obj then
+ return false, "item-not-found";
+ end
+ if not node_obj.subscribers[jid] then
+ return false, "not-subscribed";
+ end
+ node_obj.subscribers[jid] = nil;
+ local normal_jid = self.config.normalize_jid(jid);
+ local subs = self.subscriptions[normal_jid];
+ if subs then
+ local jid_subs = subs[jid];
+ if jid_subs then
+ jid_subs[node] = nil;
+ if next(jid_subs) == nil then
+ subs[jid] = nil;
+ end
+ end
+ if next(subs) == nil then
+ self.subscriptions[normal_jid] = nil;
+ end
+ end
+ self.events.fire_event("subscription-removed", { node = node, jid = jid, normalized_jid = normal_jid });
+ return true;
+end
+
+function service:remove_all_subscriptions(actor, jid)
+ local normal_jid = self.config.normalize_jid(jid);
+ local subs = self.subscriptions[normal_jid]
+ subs = subs and subs[jid];
+ if subs then
+ for node in pairs(subs) do
+ self:remove_subscription(node, true, jid);
+ end
+ end
+ return true;
+end
+
+function service:get_subscription(node, actor, jid)
+ -- Access checking
+ local cap;
+ if actor == true or jid == actor or self:jids_equal(actor, jid) then
+ cap = "get_subscription";
+ else
+ cap = "get_subscription_other";
+ end
+ if not self:may(node, actor, cap) then
+ return false, "forbidden";
+ end
+ --
+ local node_obj = self.nodes[node];
+ if not node_obj then
+ return false, "item-not-found";
+ end
+ return true, node_obj.subscribers[jid];
+end
+
+function service:create(node, actor)
+ -- Access checking
+ if not self:may(node, actor, "create") then
+ return false, "forbidden";
+ end
+ --
+ if self.nodes[node] then
+ return false, "conflict";
+ end
+
+ self.nodes[node] = {
+ name = node;
+ subscribers = {};
+ config = {};
+ data = {};
+ affiliations = {};
+ };
+ local ok, err = self:set_affiliation(node, true, actor, "owner");
+ if not ok then
+ self.nodes[node] = nil;
+ end
+ return ok, err;
+end
+
+function service:delete(node, actor)
+ -- Access checking
+ if not self:may(node, actor, "delete") then
+ return false, "forbidden";
+ end
+ --
+ local node_obj = self.nodes[node];
+ self.nodes[node] = nil;
+ self.config.broadcaster("delete", node, node_obj.subscribers);
+ return true;
+end
+
+function service:publish(node, actor, id, item)
+ -- Access checking
+ if not self:may(node, actor, "publish") then
+ return false, "forbidden";
+ end
+ --
+ local node_obj = self.nodes[node];
+ if not node_obj then
+ if not self.config.autocreate_on_publish then
+ return false, "item-not-found";
+ end
+ local ok, err = self:create(node, true);
+ if not ok then
+ return ok, err;
+ end
+ node_obj = self.nodes[node];
+ end
+ node_obj.data[id] = item;
+ self.events.fire_event("item-published", { node = node, actor = actor, id = id, item = item });
+ self.config.broadcaster("items", node, node_obj.subscribers, item);
+ return true;
+end
+
+function service:retract(node, actor, id, retract)
+ -- Access checking
+ if not self:may(node, actor, "retract") then
+ return false, "forbidden";
+ end
+ --
+ local node_obj = self.nodes[node];
+ if (not node_obj) or (not node_obj.data[id]) then
+ return false, "item-not-found";
+ end
+ node_obj.data[id] = nil;
+ if retract then
+ self.config.broadcaster("items", node, node_obj.subscribers, retract);
+ end
+ return true
+end
+
+function service:purge(node, actor, notify)
+ -- Access checking
+ if not self:may(node, actor, "retract") then
+ return false, "forbidden";
+ end
+ --
+ local node_obj = self.nodes[node];
+ if not node_obj then
+ return false, "item-not-found";
+ end
+ node_obj.data = {}; -- Purge
+ if notify then
+ self.config.broadcaster("purge", node, node_obj.subscribers);
+ end
+ return true
+end
+
+function service:get_items(node, actor, id)
+ -- Access checking
+ if not self:may(node, actor, "get_items") then
+ return false, "forbidden";
+ end
+ --
+ local node_obj = self.nodes[node];
+ if not node_obj then
+ return false, "item-not-found";
+ end
+ if id then -- Restrict results to a single specific item
+ return true, { [id] = node_obj.data[id] };
+ else
+ return true, node_obj.data;
+ end
+end
+
+function service:get_nodes(actor)
+ -- Access checking
+ if not self:may(nil, actor, "get_nodes") then
+ return false, "forbidden";
+ end
+ --
+ return true, self.nodes;
+end
+
+function service:get_subscriptions(node, actor, jid)
+ -- Access checking
+ local cap;
+ if actor == true or jid == actor or self:jids_equal(actor, jid) then
+ cap = "get_subscriptions";
+ else
+ cap = "get_subscriptions_other";
+ end
+ if not self:may(node, actor, cap) then
+ return false, "forbidden";
+ end
+ --
+ local node_obj;
+ if node then
+ node_obj = self.nodes[node];
+ if not node_obj then
+ return false, "item-not-found";
+ end
+ end
+ local normal_jid = self.config.normalize_jid(jid);
+ local subs = self.subscriptions[normal_jid];
+ -- We return the subscription object from the node to save
+ -- a get_subscription() call for each node.
+ local ret = {};
+ if subs then
+ for jid, subscribed_nodes in pairs(subs) do
+ if node then -- Return only subscriptions to this node
+ if subscribed_nodes[node] then
+ ret[#ret+1] = {
+ node = subscribed_nodes[node];
+ jid = jid;
+ subscription = node_obj.subscribers[jid];
+ };
+ end
+ else -- Return subscriptions to all nodes
+ local nodes = self.nodes;
+ for subscribed_node in pairs(subscribed_nodes) do
+ ret[#ret+1] = {
+ node = subscribed_node;
+ jid = jid;
+ subscription = nodes[subscribed_node].subscribers[jid];
+ };
+ end
+ end
+ end
+ end
+ return true, ret;
+end
+
+-- Access models only affect 'none' affiliation caps, service/default access level...
+function service:set_node_capabilities(node, actor, capabilities)
+ -- Access checking
+ if not self:may(node, actor, "configure") then
+ return false, "forbidden";
+ end
+ --
+ local node_obj = self.nodes[node];
+ if not node_obj then
+ return false, "item-not-found";
+ end
+ node_obj.capabilities = capabilities;
+ return true;
+end
+
+return _M;
diff --git a/util/rfc6724.lua b/util/rfc6724.lua
new file mode 100644
index 00000000..c8aec631
--- /dev/null
+++ b/util/rfc6724.lua
@@ -0,0 +1,142 @@
+-- Prosody IM
+-- Copyright (C) 2011-2013 Florian Zeitz
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+-- This is used to sort destination addresses by preference
+-- during S2S connections.
+-- We can't hand this off to getaddrinfo, since it blocks
+
+local ip_commonPrefixLength = require"util.ip".commonPrefixLength
+local new_ip = require"util.ip".new_ip;
+
+local function commonPrefixLength(ipA, ipB)
+ local len = ip_commonPrefixLength(ipA, ipB);
+ return len < 64 and len or 64;
+end
+
+local function t_sort(t, comp)
+ for i = 1, (#t - 1) do
+ for j = (i + 1), #t do
+ local a, b = t[i], t[j];
+ if not comp(a,b) then
+ t[i], t[j] = b, a;
+ end
+ end
+ end
+end
+
+local function source(dest, candidates)
+ local function comp(ipA, ipB)
+ -- Rule 1: Prefer same address
+ if dest == ipA then
+ return true;
+ elseif dest == ipB then
+ return false;
+ end
+
+ -- Rule 2: Prefer appropriate scope
+ if ipA.scope < ipB.scope then
+ if ipA.scope < dest.scope then
+ return false;
+ else
+ return true;
+ end
+ elseif ipA.scope > ipB.scope then
+ if ipB.scope < dest.scope then
+ return true;
+ else
+ return false;
+ end
+ end
+
+ -- Rule 3: Avoid deprecated addresses
+ -- XXX: No way to determine this
+ -- Rule 4: Prefer home addresses
+ -- XXX: Mobility Address related, no way to determine this
+ -- Rule 5: Prefer outgoing interface
+ -- XXX: Interface to address relation. No way to determine this
+ -- Rule 6: Prefer matching label
+ if ipA.label == dest.label and ipB.label ~= dest.label then
+ return true;
+ elseif ipB.label == dest.label and ipA.label ~= dest.label then
+ return false;
+ end
+
+ -- Rule 7: Prefer temporary addresses (over public ones)
+ -- XXX: No way to determine this
+ -- Rule 8: Use longest matching prefix
+ if commonPrefixLength(ipA, dest) > commonPrefixLength(ipB, dest) then
+ return true;
+ else
+ return false;
+ end
+ end
+
+ t_sort(candidates, comp);
+ return candidates[1];
+end
+
+local function destination(candidates, sources)
+ local sourceAddrs = {};
+ local function comp(ipA, ipB)
+ local ipAsource = sourceAddrs[ipA];
+ local ipBsource = sourceAddrs[ipB];
+ -- Rule 1: Avoid unusable destinations
+ -- XXX: No such information
+ -- Rule 2: Prefer matching scope
+ if ipA.scope == ipAsource.scope and ipB.scope ~= ipBsource.scope then
+ return true;
+ elseif ipA.scope ~= ipAsource.scope and ipB.scope == ipBsource.scope then
+ return false;
+ end
+
+ -- Rule 3: Avoid deprecated addresses
+ -- XXX: No way to determine this
+ -- Rule 4: Prefer home addresses
+ -- XXX: Mobility Address related, no way to determine this
+ -- Rule 5: Prefer matching label
+ if ipAsource.label == ipA.label and ipBsource.label ~= ipB.label then
+ return true;
+ elseif ipBsource.label == ipB.label and ipAsource.label ~= ipA.label then
+ return false;
+ end
+
+ -- Rule 6: Prefer higher precedence
+ if ipA.precedence > ipB.precedence then
+ return true;
+ elseif ipA.precedence < ipB.precedence then
+ return false;
+ end
+
+ -- Rule 7: Prefer native transport
+ -- XXX: No way to determine this
+ -- Rule 8: Prefer smaller scope
+ if ipA.scope < ipB.scope then
+ return true;
+ elseif ipA.scope > ipB.scope then
+ return false;
+ end
+
+ -- Rule 9: Use longest matching prefix
+ if commonPrefixLength(ipA, ipAsource) > commonPrefixLength(ipB, ipBsource) then
+ return true;
+ elseif commonPrefixLength(ipA, ipAsource) < commonPrefixLength(ipB, ipBsource) then
+ return false;
+ end
+
+ -- Rule 10: Otherwise, leave order unchanged
+ return true;
+ end
+ for _, ip in ipairs(candidates) do
+ sourceAddrs[ip] = source(ip, sources);
+ end
+
+ t_sort(candidates, comp);
+ return candidates;
+end
+
+return {source = source,
+ destination = destination};
diff --git a/util/sasl.lua b/util/sasl.lua
index dbd6326a..afb3861b 100644
--- a/util/sasl.lua
+++ b/util/sasl.lua
@@ -1,43 +1,96 @@
+-- sasl.lua v0.4
+-- Copyright (C) 2008-2010 Tobias Markmann
+--
+-- All rights reserved.
+--
+-- Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+--
+-- * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+-- * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+-- * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+--
+-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+local pairs, ipairs = pairs, ipairs;
+local t_insert = table.insert;
+local type = type
+local setmetatable = setmetatable;
+local assert = assert;
+local require = require;
-local base64 = require "base64"
-local log = require "util.logger".init("sasl");
-local tostring = tostring;
-local st = require "util.stanza";
-local s_match = string.match;
module "sasl"
+--[[
+Authentication Backend Prototypes:
+
+state = false : disabled
+state = true : enabled
+state = nil : non-existant
+]]
+
+local method = {};
+method.__index = method;
+local mechanisms = {};
+local backend_mechanism = {};
+
+-- register a new SASL mechanims
+function registerMechanism(name, backends, f)
+ assert(type(name) == "string", "Parameter name MUST be a string.");
+ assert(type(backends) == "string" or type(backends) == "table", "Parameter backends MUST be either a string or a table.");
+ assert(type(f) == "function", "Parameter f MUST be a function.");
+ mechanisms[name] = f
+ for _, backend_name in ipairs(backends) do
+ if backend_mechanism[backend_name] == nil then backend_mechanism[backend_name] = {}; end
+ t_insert(backend_mechanism[backend_name], name);
+ end
+end
+
+-- create a new SASL object which can be used to authenticate clients
+function new(realm, profile)
+ local mechanisms = profile.mechanisms;
+ if not mechanisms then
+ mechanisms = {};
+ for backend, f in pairs(profile) do
+ if backend_mechanism[backend] then
+ for _, mechanism in ipairs(backend_mechanism[backend]) do
+ mechanisms[mechanism] = true;
+ end
+ end
+ end
+ profile.mechanisms = mechanisms;
+ end
+ return setmetatable({ profile = profile, realm = realm, mechs = mechanisms }, method);
+end
-local function new_plain(onAuth, onSuccess, onFail, onWrite)
- local object = { mechanism = "PLAIN", onAuth = onAuth, onSuccess = onSuccess, onFail = onFail,
- onWrite = onWrite}
- --local challenge = base64.encode("");
- --onWrite(st.stanza("challenge", {xmlns = "urn:ietf:params:xml:ns:xmpp-sasl"}):text(challenge))
- object.feed = function(self, stanza)
- if stanza.name ~= "response" and stanza.name ~= "auth" then self.onFail("invalid-stanza-tag") end
- if stanza.attr.xmlns ~= "urn:ietf:params:xml:ns:xmpp-sasl" then self.onFail("invalid-stanza-namespace") end
- local response = base64.decode(stanza[1])
- local authorization = s_match(response, "([^&%z]+)")
- local authentication = s_match(response, "%z([^&%z]+)%z")
- local password = s_match(response, "%z[^&%z]+%z([^&%z]+)")
- if self.onAuth(authentication, password) == true then
- self.onWrite(st.stanza("success", {xmlns = "urn:ietf:params:xml:ns:xmpp-sasl"}))
- self.onSuccess(authentication)
- else
- self.onWrite(st.stanza("failure", {xmlns = "urn:ietf:params:xml:ns:xmpp-sasl"}):tag("temporary-auth-failure"));
- end
- end
- return object
+-- get a fresh clone with the same realm and profile
+function method:clean_clone()
+ return new(self.realm, self.profile)
end
+-- get a list of possible SASL mechanims to use
+function method:mechanisms()
+ return self.mechs;
+end
-function new(mechanism, onAuth, onSuccess, onFail, onWrite)
- local object
- if mechanism == "PLAIN" then object = new_plain(onAuth, onSuccess, onFail, onWrite)
- else
- log("debug", "Unsupported SASL mechanism: "..tostring(mechanism));
- onFail("unsupported-mechanism")
+-- select a mechanism to use
+function method:select(mechanism)
+ if not self.selected and self.mechs[mechanism] then
+ self.selected = mechanism;
+ return true;
end
- return object
end
-return _M; \ No newline at end of file
+-- feed new messages to process into the library
+function method:process(message)
+ --if message == "" or message == nil then return "failure", "malformed-request" end
+ return mechanisms[self.selected](self, message);
+end
+
+-- load the mechanisms
+require "util.sasl.plain" .init(registerMechanism);
+require "util.sasl.digest-md5".init(registerMechanism);
+require "util.sasl.anonymous" .init(registerMechanism);
+require "util.sasl.scram" .init(registerMechanism);
+
+return _M;
diff --git a/util/sasl/anonymous.lua b/util/sasl/anonymous.lua
new file mode 100644
index 00000000..ca5fe404
--- /dev/null
+++ b/util/sasl/anonymous.lua
@@ -0,0 +1,46 @@
+-- sasl.lua v0.4
+-- Copyright (C) 2008-2010 Tobias Markmann
+--
+-- All rights reserved.
+--
+-- Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+--
+-- * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+-- * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+-- * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+--
+-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+local s_match = string.match;
+
+local log = require "util.logger".init("sasl");
+local generate_uuid = require "util.uuid".generate;
+
+module "sasl.anonymous"
+
+--=========================
+--SASL ANONYMOUS according to RFC 4505
+
+--[[
+Supported Authentication Backends
+
+anonymous:
+ function(username, realm)
+ return true; --for normal usage just return true; if you don't like the supplied username you can return false.
+ end
+]]
+
+local function anonymous(self, message)
+ local username;
+ repeat
+ username = generate_uuid();
+ until self.profile.anonymous(self, username, self.realm);
+ self.username = username;
+ return "success"
+end
+
+function init(registerMechanism)
+ registerMechanism("ANONYMOUS", {"anonymous"}, anonymous);
+end
+
+return _M;
diff --git a/util/sasl/digest-md5.lua b/util/sasl/digest-md5.lua
new file mode 100644
index 00000000..591d8537
--- /dev/null
+++ b/util/sasl/digest-md5.lua
@@ -0,0 +1,248 @@
+-- sasl.lua v0.4
+-- Copyright (C) 2008-2010 Tobias Markmann
+--
+-- All rights reserved.
+--
+-- Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+--
+-- * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+-- * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+-- * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+--
+-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+local tostring = tostring;
+local type = type;
+
+local s_gmatch = string.gmatch;
+local s_match = string.match;
+local t_concat = table.concat;
+local t_insert = table.insert;
+local to_byte, to_char = string.byte, string.char;
+
+local md5 = require "util.hashes".md5;
+local log = require "util.logger".init("sasl");
+local generate_uuid = require "util.uuid".generate;
+local nodeprep = require "util.encodings".stringprep.nodeprep;
+
+module "sasl.digest-md5"
+
+--=========================
+--SASL DIGEST-MD5 according to RFC 2831
+
+--[[
+Supported Authentication Backends
+
+digest_md5:
+ function(username, domain, realm, encoding) -- domain and realm are usually the same; for some broken
+ -- implementations it's not
+ return digesthash, state;
+ end
+
+digest_md5_test:
+ function(username, domain, realm, encoding, digesthash)
+ return true or false, state;
+ end
+]]
+
+local function digest(self, message)
+ --TODO complete support for authzid
+
+ local function serialize(message)
+ local data = ""
+
+ -- testing all possible values
+ if message["realm"] then data = data..[[realm="]]..message.realm..[[",]] end
+ if message["nonce"] then data = data..[[nonce="]]..message.nonce..[[",]] end
+ if message["qop"] then data = data..[[qop="]]..message.qop..[[",]] end
+ if message["charset"] then data = data..[[charset=]]..message.charset.."," end
+ if message["algorithm"] then data = data..[[algorithm=]]..message.algorithm.."," end
+ if message["rspauth"] then data = data..[[rspauth=]]..message.rspauth.."," end
+ data = data:gsub(",$", "")
+ return data
+ end
+
+ local function utf8tolatin1ifpossible(passwd)
+ local i = 1;
+ while i <= #passwd do
+ local passwd_i = to_byte(passwd:sub(i, i));
+ if passwd_i > 0x7F then
+ if passwd_i < 0xC0 or passwd_i > 0xC3 then
+ return passwd;
+ end
+ i = i + 1;
+ passwd_i = to_byte(passwd:sub(i, i));
+ if passwd_i < 0x80 or passwd_i > 0xBF then
+ return passwd;
+ end
+ end
+ i = i + 1;
+ end
+
+ local p = {};
+ local j = 0;
+ i = 1;
+ while (i <= #passwd) do
+ local passwd_i = to_byte(passwd:sub(i, i));
+ if passwd_i > 0x7F then
+ i = i + 1;
+ local passwd_i_1 = to_byte(passwd:sub(i, i));
+ t_insert(p, to_char(passwd_i%4*64 + passwd_i_1%64)); -- I'm so clever
+ else
+ t_insert(p, to_char(passwd_i));
+ end
+ i = i + 1;
+ end
+ return t_concat(p);
+ end
+ local function latin1toutf8(str)
+ local p = {};
+ for ch in s_gmatch(str, ".") do
+ ch = to_byte(ch);
+ if (ch < 0x80) then
+ t_insert(p, to_char(ch));
+ elseif (ch < 0xC0) then
+ t_insert(p, to_char(0xC2, ch));
+ else
+ t_insert(p, to_char(0xC3, ch - 64));
+ end
+ end
+ return t_concat(p);
+ end
+ local function parse(data)
+ local message = {}
+ -- COMPAT: %z in the pattern to work around jwchat bug (sends "charset=utf-8\0")
+ for k, v in s_gmatch(data, [[([%w%-]+)="?([^",%z]*)"?,?]]) do -- FIXME The hacky regex makes me shudder
+ message[k] = v;
+ end
+ return message;
+ end
+
+ if not self.nonce then
+ self.nonce = generate_uuid();
+ self.step = 0;
+ self.nonce_count = {};
+ end
+
+ self.step = self.step + 1;
+ if (self.step == 1) then
+ local challenge = serialize({ nonce = self.nonce,
+ qop = "auth",
+ charset = "utf-8",
+ algorithm = "md5-sess",
+ realm = self.realm});
+ return "challenge", challenge;
+ elseif (self.step == 2) then
+ local response = parse(message);
+ -- check for replay attack
+ if response["nc"] then
+ if self.nonce_count[response["nc"]] then return "failure", "not-authorized" end
+ end
+
+ -- check for username, it's REQUIRED by RFC 2831
+ local username = response["username"];
+ local _nodeprep = self.profile.nodeprep;
+ if username and _nodeprep ~= false then
+ username = (_nodeprep or nodeprep)(username); -- FIXME charset
+ end
+ if not username or username == "" then
+ return "failure", "malformed-request";
+ end
+ self.username = username;
+
+ -- check for nonce, ...
+ if not response["nonce"] then
+ return "failure", "malformed-request";
+ else
+ -- check if it's the right nonce
+ if response["nonce"] ~= tostring(self.nonce) then return "failure", "malformed-request" end
+ end
+
+ if not response["cnonce"] then return "failure", "malformed-request", "Missing entry for cnonce in SASL message." end
+ if not response["qop"] then response["qop"] = "auth" end
+
+ if response["realm"] == nil or response["realm"] == "" then
+ response["realm"] = "";
+ elseif response["realm"] ~= self.realm then
+ return "failure", "not-authorized", "Incorrect realm value";
+ end
+
+ local decoder;
+ if response["charset"] == nil then
+ decoder = utf8tolatin1ifpossible;
+ elseif response["charset"] ~= "utf-8" then
+ return "failure", "incorrect-encoding", "The client's response uses "..response["charset"].." for encoding with isn't supported by sasl.lua. Supported encodings are latin or utf-8.";
+ end
+
+ local domain = "";
+ local protocol = "";
+ if response["digest-uri"] then
+ protocol, domain = response["digest-uri"]:match("(%w+)/(.*)$");
+ if protocol == nil or domain == nil then return "failure", "malformed-request" end
+ else
+ return "failure", "malformed-request", "Missing entry for digest-uri in SASL message."
+ end
+
+ --TODO maybe realm support
+ local Y, state;
+ if self.profile.plain then
+ local password, state = self.profile.plain(self, response["username"], self.realm)
+ if state == nil then return "failure", "not-authorized"
+ elseif state == false then return "failure", "account-disabled" end
+ Y = md5(response["username"]..":"..response["realm"]..":"..password);
+ elseif self.profile["digest-md5"] then
+ Y, state = self.profile["digest-md5"](self, response["username"], self.realm, response["realm"], response["charset"])
+ if state == nil then return "failure", "not-authorized"
+ elseif state == false then return "failure", "account-disabled" end
+ elseif self.profile["digest-md5-test"] then
+ -- TODO
+ end
+ --local password_encoding, Y = self.credentials_handler("DIGEST-MD5", response["username"], self.realm, response["realm"], decoder);
+ --if Y == nil then return "failure", "not-authorized"
+ --elseif Y == false then return "failure", "account-disabled" end
+ local A1 = "";
+ if response.authzid then
+ if response.authzid == self.username or response.authzid == self.username.."@"..self.realm then
+ -- COMPAT
+ log("warn", "Client is violating RFC 3920 (section 6.1, point 7).");
+ A1 = Y..":"..response["nonce"]..":"..response["cnonce"]..":"..response.authzid;
+ else
+ return "failure", "invalid-authzid";
+ end
+ else
+ A1 = Y..":"..response["nonce"]..":"..response["cnonce"];
+ end
+ local A2 = "AUTHENTICATE:"..protocol.."/"..domain;
+
+ local HA1 = md5(A1, true);
+ local HA2 = md5(A2, true);
+
+ local KD = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":"..HA2;
+ local response_value = md5(KD, true);
+
+ if response_value == response["response"] then
+ -- calculate rspauth
+ A2 = ":"..protocol.."/"..domain;
+
+ HA1 = md5(A1, true);
+ HA2 = md5(A2, true);
+
+ KD = HA1..":"..response["nonce"]..":"..response["nc"]..":"..response["cnonce"]..":"..response["qop"]..":"..HA2
+ local rspauth = md5(KD, true);
+ self.authenticated = true;
+ --TODO: considering sending the rspauth in a success node for saving one roundtrip; allowed according to http://tools.ietf.org/html/draft-saintandre-rfc3920bis-09#section-7.3.6
+ return "challenge", serialize({rspauth = rspauth});
+ else
+ return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated."
+ end
+ elseif self.step == 3 then
+ if self.authenticated ~= nil then return "success"
+ else return "failure", "malformed-request" end
+ end
+end
+
+function init(registerMechanism)
+ registerMechanism("DIGEST-MD5", {"plain"}, digest);
+end
+
+return _M;
diff --git a/util/sasl/plain.lua b/util/sasl/plain.lua
new file mode 100644
index 00000000..c9ec2911
--- /dev/null
+++ b/util/sasl/plain.lua
@@ -0,0 +1,89 @@
+-- sasl.lua v0.4
+-- Copyright (C) 2008-2010 Tobias Markmann
+--
+-- All rights reserved.
+--
+-- Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+--
+-- * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+-- * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+-- * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+--
+-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+local s_match = string.match;
+local saslprep = require "util.encodings".stringprep.saslprep;
+local nodeprep = require "util.encodings".stringprep.nodeprep;
+local log = require "util.logger".init("sasl");
+
+module "sasl.plain"
+
+-- ================================
+-- SASL PLAIN according to RFC 4616
+
+--[[
+Supported Authentication Backends
+
+plain:
+ function(username, realm)
+ return password, state;
+ end
+
+plain_test:
+ function(username, password, realm)
+ return true or false, state;
+ end
+]]
+
+local function plain(self, message)
+ if not message then
+ return "failure", "malformed-request";
+ end
+
+ local authorization, authentication, password = s_match(message, "^([^%z]*)%z([^%z]+)%z([^%z]+)");
+
+ if not authorization then
+ return "failure", "malformed-request";
+ end
+
+ -- SASLprep password and authentication
+ authentication = saslprep(authentication);
+ password = saslprep(password);
+
+ if (not password) or (password == "") or (not authentication) or (authentication == "") then
+ log("debug", "Username or password violates SASLprep.");
+ return "failure", "malformed-request", "Invalid username or password.";
+ end
+
+ local _nodeprep = self.profile.nodeprep;
+ if _nodeprep ~= false then
+ authentication = (_nodeprep or nodeprep)(authentication);
+ if not authentication or authentication == "" then
+ return "failure", "malformed-request", "Invalid username or password."
+ end
+ end
+
+ local correct, state = false, false;
+ if self.profile.plain then
+ local correct_password;
+ correct_password, state = self.profile.plain(self, authentication, self.realm);
+ correct = (correct_password == password);
+ elseif self.profile.plain_test then
+ correct, state = self.profile.plain_test(self, authentication, password, self.realm);
+ end
+
+ self.username = authentication
+ if state == false then
+ return "failure", "account-disabled";
+ elseif state == nil or not correct then
+ return "failure", "not-authorized", "Unable to authorize you with the authentication credentials you've sent.";
+ end
+
+ return "success";
+end
+
+function init(registerMechanism)
+ registerMechanism("PLAIN", {"plain", "plain_test"}, plain);
+end
+
+return _M;
diff --git a/util/sasl/scram.lua b/util/sasl/scram.lua
new file mode 100644
index 00000000..cf2f0ede
--- /dev/null
+++ b/util/sasl/scram.lua
@@ -0,0 +1,216 @@
+-- sasl.lua v0.4
+-- Copyright (C) 2008-2010 Tobias Markmann
+--
+-- All rights reserved.
+--
+-- Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+--
+-- * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+-- * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+-- * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+--
+-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+local s_match = string.match;
+local type = type
+local string = string
+local base64 = require "util.encodings".base64;
+local hmac_sha1 = require "util.hashes".hmac_sha1;
+local sha1 = require "util.hashes".sha1;
+local Hi = require "util.hashes".scram_Hi_sha1;
+local generate_uuid = require "util.uuid".generate;
+local saslprep = require "util.encodings".stringprep.saslprep;
+local nodeprep = require "util.encodings".stringprep.nodeprep;
+local log = require "util.logger".init("sasl");
+local t_concat = table.concat;
+local char = string.char;
+local byte = string.byte;
+
+module "sasl.scram"
+
+--=========================
+--SASL SCRAM-SHA-1 according to RFC 5802
+
+--[[
+Supported Authentication Backends
+
+scram_{MECH}:
+ -- MECH being a standard hash name (like those at IANA's hash registry) with '-' replaced with '_'
+ function(username, realm)
+ return stored_key, server_key, iteration_count, salt, state;
+ end
+]]
+
+local default_i = 4096
+
+local function bp( b )
+ local result = ""
+ for i=1, b:len() do
+ result = result.."\\"..b:byte(i)
+ end
+ return result
+end
+
+local xor_map = {0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;1;0;3;2;5;4;7;6;9;8;11;10;13;12;15;14;2;3;0;1;6;7;4;5;10;11;8;9;14;15;12;13;3;2;1;0;7;6;5;4;11;10;9;8;15;14;13;12;4;5;6;7;0;1;2;3;12;13;14;15;8;9;10;11;5;4;7;6;1;0;3;2;13;12;15;14;9;8;11;10;6;7;4;5;2;3;0;1;14;15;12;13;10;11;8;9;7;6;5;4;3;2;1;0;15;14;13;12;11;10;9;8;8;9;10;11;12;13;14;15;0;1;2;3;4;5;6;7;9;8;11;10;13;12;15;14;1;0;3;2;5;4;7;6;10;11;8;9;14;15;12;13;2;3;0;1;6;7;4;5;11;10;9;8;15;14;13;12;3;2;1;0;7;6;5;4;12;13;14;15;8;9;10;11;4;5;6;7;0;1;2;3;13;12;15;14;9;8;11;10;5;4;7;6;1;0;3;2;14;15;12;13;10;11;8;9;6;7;4;5;2;3;0;1;15;14;13;12;11;10;9;8;7;6;5;4;3;2;1;0;};
+
+local result = {};
+local function binaryXOR( a, b )
+ for i=1, #a do
+ local x, y = byte(a, i), byte(b, i);
+ local lowx, lowy = x % 16, y % 16;
+ local hix, hiy = (x - lowx) / 16, (y - lowy) / 16;
+ local lowr, hir = xor_map[lowx * 16 + lowy + 1], xor_map[hix * 16 + hiy + 1];
+ local r = hir * 16 + lowr;
+ result[i] = char(r)
+ end
+ return t_concat(result);
+end
+
+local function validate_username(username, _nodeprep)
+ -- check for forbidden char sequences
+ for eq in username:gmatch("=(.?.?)") do
+ if eq ~= "2C" and eq ~= "3D" then
+ return false
+ end
+ end
+
+ -- replace =2C with , and =3D with =
+ username = username:gsub("=2C", ",");
+ username = username:gsub("=3D", "=");
+
+ -- apply SASLprep
+ username = saslprep(username);
+
+ if username and _nodeprep ~= false then
+ username = (_nodeprep or nodeprep)(username);
+ end
+
+ return username and #username>0 and username;
+end
+
+local function hashprep(hashname)
+ return hashname:lower():gsub("-", "_");
+end
+
+function getAuthenticationDatabaseSHA1(password, salt, iteration_count)
+ if type(password) ~= "string" or type(salt) ~= "string" or type(iteration_count) ~= "number" then
+ return false, "inappropriate argument types"
+ end
+ if iteration_count < 4096 then
+ log("warn", "Iteration count < 4096 which is the suggested minimum according to RFC 5802.")
+ end
+ local salted_password = Hi(password, salt, iteration_count);
+ local stored_key = sha1(hmac_sha1(salted_password, "Client Key"))
+ local server_key = hmac_sha1(salted_password, "Server Key");
+ return true, stored_key, server_key
+end
+
+local function scram_gen(hash_name, H_f, HMAC_f)
+ local function scram_hash(self, message)
+ if not self.state then self["state"] = {} end
+
+ if type(message) ~= "string" or #message == 0 then return "failure", "malformed-request" end
+ if not self.state.name then
+ -- we are processing client_first_message
+ local client_first_message = message;
+
+ -- TODO: fail if authzid is provided, since we don't support them yet
+ self.state["client_first_message"] = client_first_message;
+ self.state["gs2_cbind_flag"], self.state["authzid"], self.state["name"], self.state["clientnonce"]
+ = client_first_message:match("^(%a),(.*),n=(.*),r=([^,]*).*");
+
+ -- we don't do any channel binding yet
+ if self.state.gs2_cbind_flag ~= "n" and self.state.gs2_cbind_flag ~= "y" then
+ return "failure", "malformed-request";
+ end
+
+ if not self.state.name or not self.state.clientnonce then
+ return "failure", "malformed-request", "Channel binding isn't support at this time.";
+ end
+
+ self.state.name = validate_username(self.state.name, self.profile.nodeprep);
+ if not self.state.name then
+ log("debug", "Username violates either SASLprep or contains forbidden character sequences.")
+ return "failure", "malformed-request", "Invalid username.";
+ end
+
+ self.state["servernonce"] = generate_uuid();
+
+ -- retreive credentials
+ if self.profile.plain then
+ local password, state = self.profile.plain(self, self.state.name, self.realm)
+ if state == nil then return "failure", "not-authorized"
+ elseif state == false then return "failure", "account-disabled" end
+
+ password = saslprep(password);
+ if not password then
+ log("debug", "Password violates SASLprep.");
+ return "failure", "not-authorized", "Invalid password."
+ end
+
+ self.state.salt = generate_uuid();
+ self.state.iteration_count = default_i;
+
+ local succ = false;
+ succ, self.state.stored_key, self.state.server_key = getAuthenticationDatabaseSHA1(password, self.state.salt, default_i, self.state.iteration_count);
+ if not succ then
+ log("error", "Generating authentication database failed. Reason: %s", self.state.stored_key);
+ return "failure", "temporary-auth-failure";
+ end
+ elseif self.profile["scram_"..hashprep(hash_name)] then
+ local stored_key, server_key, iteration_count, salt, state = self.profile["scram_"..hashprep(hash_name)](self, self.state.name, self.realm);
+ if state == nil then return "failure", "not-authorized"
+ elseif state == false then return "failure", "account-disabled" end
+
+ self.state.stored_key = stored_key;
+ self.state.server_key = server_key;
+ self.state.iteration_count = iteration_count;
+ self.state.salt = salt
+ end
+
+ local server_first_message = "r="..self.state.clientnonce..self.state.servernonce..",s="..base64.encode(self.state.salt)..",i="..self.state.iteration_count;
+ self.state["server_first_message"] = server_first_message;
+ return "challenge", server_first_message
+ else
+ -- we are processing client_final_message
+ local client_final_message = message;
+
+ self.state["channelbinding"], self.state["nonce"], self.state["proof"] = client_final_message:match("^c=(.*),r=(.*),.*p=(.*)");
+
+ if not self.state.proof or not self.state.nonce or not self.state.channelbinding then
+ return "failure", "malformed-request", "Missing an attribute(p, r or c) in SASL message.";
+ end
+
+ if self.state.nonce ~= self.state.clientnonce..self.state.servernonce then
+ return "failure", "malformed-request", "Wrong nonce in client-final-message.";
+ end
+
+ local ServerKey = self.state.server_key;
+ local StoredKey = self.state.stored_key;
+
+ local AuthMessage = "n=" .. s_match(self.state.client_first_message,"n=(.+)") .. "," .. self.state.server_first_message .. "," .. s_match(client_final_message, "(.+),p=.+")
+ local ClientSignature = HMAC_f(StoredKey, AuthMessage)
+ local ClientKey = binaryXOR(ClientSignature, base64.decode(self.state.proof))
+ local ServerSignature = HMAC_f(ServerKey, AuthMessage)
+
+ if StoredKey == H_f(ClientKey) then
+ local server_final_message = "v="..base64.encode(ServerSignature);
+ self["username"] = self.state.name;
+ return "success", server_final_message;
+ else
+ return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated.";
+ end
+ end
+ end
+ return scram_hash;
+end
+
+function init(registerMechanism)
+ local function registerSCRAMMechanism(hash_name, hash, hmac_hash)
+ registerMechanism("SCRAM-"..hash_name, {"plain", "scram_"..(hashprep(hash_name))}, scram_gen(hash_name:lower(), hash, hmac_hash));
+ end
+
+ registerSCRAMMechanism("SHA-1", sha1, hmac_sha1);
+end
+
+return _M;
diff --git a/util/sasl_cyrus.lua b/util/sasl_cyrus.lua
new file mode 100644
index 00000000..19684587
--- /dev/null
+++ b/util/sasl_cyrus.lua
@@ -0,0 +1,166 @@
+-- sasl.lua v0.4
+-- Copyright (C) 2008-2009 Tobias Markmann
+--
+-- All rights reserved.
+--
+-- Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+--
+-- * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+-- * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+-- * Neither the name of Tobias Markmann nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+--
+-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+local cyrussasl = require "cyrussasl";
+local log = require "util.logger".init("sasl_cyrus");
+
+local setmetatable = setmetatable
+
+local pcall = pcall
+local s_match, s_gmatch = string.match, string.gmatch
+
+local sasl_errstring = {
+ -- SASL result codes --
+ [1] = "another step is needed in authentication";
+ [0] = "successful result";
+ [-1] = "generic failure";
+ [-2] = "memory shortage failure";
+ [-3] = "overflowed buffer";
+ [-4] = "mechanism not supported";
+ [-5] = "bad protocol / cancel";
+ [-6] = "can't request info until later in exchange";
+ [-7] = "invalid parameter supplied";
+ [-8] = "transient failure (e.g., weak key)";
+ [-9] = "integrity check failed";
+ [-12] = "SASL library not initialized";
+
+ -- client only codes --
+ [2] = "needs user interaction";
+ [-10] = "server failed mutual authentication step";
+ [-11] = "mechanism doesn't support requested feature";
+
+ -- server only codes --
+ [-13] = "authentication failure";
+ [-14] = "authorization failure";
+ [-15] = "mechanism too weak for this user";
+ [-16] = "encryption needed to use mechanism";
+ [-17] = "One time use of a plaintext password will enable requested mechanism for user";
+ [-18] = "passphrase expired, has to be reset";
+ [-19] = "account disabled";
+ [-20] = "user not found";
+ [-23] = "version mismatch with plug-in";
+ [-24] = "remote authentication server unavailable";
+ [-26] = "user exists, but no verifier for user";
+
+ -- codes for password setting --
+ [-21] = "passphrase locked";
+ [-22] = "requested change was not needed";
+ [-27] = "passphrase is too weak for security policy";
+ [-28] = "user supplied passwords not permitted";
+};
+setmetatable(sasl_errstring, { __index = function() return "undefined error!" end });
+
+module "sasl_cyrus"
+
+local method = {};
+method.__index = method;
+local initialized = false;
+
+local function init(service_name)
+ if not initialized then
+ local st, errmsg = pcall(cyrussasl.server_init, service_name);
+ if st then
+ initialized = true;
+ else
+ log("error", "Failed to initialize Cyrus SASL: %s", errmsg);
+ end
+ end
+end
+
+-- create a new SASL object which can be used to authenticate clients
+-- host_fqdn may be nil in which case gethostname() gives the value.
+-- For GSSAPI, this determines the hostname in the service ticket (after
+-- reverse DNS canonicalization, only if [libdefaults] rdns = true which
+-- is the default).
+function new(realm, service_name, app_name, host_fqdn)
+
+ init(app_name or service_name);
+
+ local st, ret = pcall(cyrussasl.server_new, service_name, host_fqdn, realm, nil, nil)
+ if not st then
+ log("error", "Creating SASL server connection failed: %s", ret);
+ return nil;
+ end
+
+ local sasl_i = { realm = realm, service_name = service_name, cyrus = ret };
+
+ if cyrussasl.set_canon_cb then
+ local c14n_cb = function (user)
+ local node = s_match(user, "^([^@]+)");
+ log("debug", "Canonicalizing username %s to %s", user, node)
+ return node
+ end
+ cyrussasl.set_canon_cb(sasl_i.cyrus, c14n_cb);
+ end
+
+ cyrussasl.setssf(sasl_i.cyrus, 0, 0xffffffff)
+ local mechanisms = {};
+ local cyrus_mechs = cyrussasl.listmech(sasl_i.cyrus, nil, "", " ", "");
+ for w in s_gmatch(cyrus_mechs, "[^ ]+") do
+ mechanisms[w] = true;
+ end
+ sasl_i.mechs = mechanisms;
+ return setmetatable(sasl_i, method);
+end
+
+-- get a fresh clone with the same realm and service name
+function method:clean_clone()
+ return new(self.realm, self.service_name)
+end
+
+-- get a list of possible SASL mechanims to use
+function method:mechanisms()
+ return self.mechs;
+end
+
+-- select a mechanism to use
+function method:select(mechanism)
+ if not self.selected and self.mechs[mechanism] then
+ self.selected = mechanism;
+ return true;
+ end
+end
+
+-- feed new messages to process into the library
+function method:process(message)
+ local err;
+ local data;
+
+ if not self.first_step_done then
+ err, data = cyrussasl.server_start(self.cyrus, self.selected, message or "")
+ self.first_step_done = true;
+ else
+ err, data = cyrussasl.server_step(self.cyrus, message or "")
+ end
+
+ self.username = cyrussasl.get_username(self.cyrus)
+
+ if (err == 0) then -- SASL_OK
+ if self.require_provisioning and not self.require_provisioning(self.username) then
+ return "failure", "not-authorized", "User authenticated successfully, but not provisioned for XMPP";
+ end
+ return "success", data
+ elseif (err == 1) then -- SASL_CONTINUE
+ return "challenge", data
+ elseif (err == -4) then -- SASL_NOMECH
+ log("debug", "SASL mechanism not available from remote end")
+ return "failure", "invalid-mechanism", "SASL mechanism not available"
+ elseif (err == -13) then -- SASL_BADAUTH
+ return "failure", "not-authorized", sasl_errstring[err];
+ else
+ log("debug", "Got SASL error condition %d: %s", err, sasl_errstring[err]);
+ return "failure", "undefined-condition", sasl_errstring[err];
+ end
+end
+
+return _M;
diff --git a/util/serialization.lua b/util/serialization.lua
new file mode 100644
index 00000000..8a259184
--- /dev/null
+++ b/util/serialization.lua
@@ -0,0 +1,95 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+local string_rep = string.rep;
+local type = type;
+local tostring = tostring;
+local t_insert = table.insert;
+local t_concat = table.concat;
+local error = error;
+local pairs = pairs;
+local next = next;
+
+local loadstring = loadstring;
+local pcall = pcall;
+
+local debug_traceback = debug.traceback;
+local log = require "util.logger".init("serialization");
+local envload = require"util.envload".envload;
+
+module "serialization"
+
+local indent = function(i)
+ return string_rep("\t", i);
+end
+local function basicSerialize (o)
+ if type(o) == "number" or type(o) == "boolean" then
+ -- no need to check for NaN, as that's not a valid table index
+ if o == 1/0 then return "(1/0)";
+ elseif o == -1/0 then return "(-1/0)";
+ else return tostring(o); end
+ else -- assume it is a string -- FIXME make sure it's a string. throw an error otherwise.
+ return (("%q"):format(tostring(o)):gsub("\\\n", "\\n"));
+ end
+end
+local function _simplesave(o, ind, t, func)
+ if type(o) == "number" then
+ if o ~= o then func(t, "(0/0)");
+ elseif o == 1/0 then func(t, "(1/0)");
+ elseif o == -1/0 then func(t, "(-1/0)");
+ else func(t, tostring(o)); end
+ elseif type(o) == "string" then
+ func(t, (("%q"):format(o):gsub("\\\n", "\\n")));
+ elseif type(o) == "table" then
+ if next(o) ~= nil then
+ func(t, "{\n");
+ for k,v in pairs(o) do
+ func(t, indent(ind));
+ func(t, "[");
+ func(t, basicSerialize(k));
+ func(t, "] = ");
+ if ind == 0 then
+ _simplesave(v, 0, t, func);
+ else
+ _simplesave(v, ind+1, t, func);
+ end
+ func(t, ";\n");
+ end
+ func(t, indent(ind-1));
+ func(t, "}");
+ else
+ func(t, "{}");
+ end
+ elseif type(o) == "boolean" then
+ func(t, (o and "true" or "false"));
+ else
+ log("error", "cannot serialize a %s: %s", type(o), debug_traceback())
+ func(t, "nil");
+ end
+end
+
+function append(t, o)
+ _simplesave(o, 1, t, t.write or t_insert);
+ return t;
+end
+
+function serialize(o)
+ return t_concat(append({}, o));
+end
+
+function deserialize(str)
+ if type(str) ~= "string" then return nil; end
+ str = "return "..str;
+ local f, err = envload(str, "@data", {});
+ if not f then return nil, err; end
+ local success, ret = pcall(f);
+ if not success then return nil, ret; end
+ return ret;
+end
+
+return _M;
diff --git a/util/set.lua b/util/set.lua
new file mode 100644
index 00000000..7f45526e
--- /dev/null
+++ b/util/set.lua
@@ -0,0 +1,159 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+local ipairs, pairs, setmetatable, next, tostring =
+ ipairs, pairs, setmetatable, next, tostring;
+local t_concat = table.concat;
+
+module "set"
+
+local set_mt = {};
+function set_mt.__call(set, _, k)
+ return next(set._items, k);
+end
+function set_mt.__add(set1, set2)
+ return _M.union(set1, set2);
+end
+function set_mt.__sub(set1, set2)
+ return _M.difference(set1, set2);
+end
+function set_mt.__div(set, func)
+ local new_set, new_items = _M.new();
+ local items, new_items = set._items, new_set._items;
+ for item in pairs(items) do
+ local new_item = func(item);
+ if new_item ~= nil then
+ new_items[new_item] = true;
+ end
+ end
+ return new_set;
+end
+function set_mt.__eq(set1, set2)
+ local set1, set2 = set1._items, set2._items;
+ for item in pairs(set1) do
+ if not set2[item] then
+ return false;
+ end
+ end
+
+ for item in pairs(set2) do
+ if not set1[item] then
+ return false;
+ end
+ end
+
+ return true;
+end
+function set_mt.__tostring(set)
+ local s, items = { }, set._items;
+ for item in pairs(items) do
+ s[#s+1] = tostring(item);
+ end
+ return t_concat(s, ", ");
+end
+
+local items_mt = {};
+function items_mt.__call(items, _, k)
+ return next(items, k);
+end
+
+function new(list)
+ local items = setmetatable({}, items_mt);
+ local set = { _items = items };
+
+ function set:add(item)
+ items[item] = true;
+ end
+
+ function set:contains(item)
+ return items[item];
+ end
+
+ function set:items()
+ return items;
+ end
+
+ function set:remove(item)
+ items[item] = nil;
+ end
+
+ function set:add_list(list)
+ if list then
+ for _, item in ipairs(list) do
+ items[item] = true;
+ end
+ end
+ end
+
+ function set:include(otherset)
+ for item in pairs(otherset) do
+ items[item] = true;
+ end
+ end
+
+ function set:exclude(otherset)
+ for item in pairs(otherset) do
+ items[item] = nil;
+ end
+ end
+
+ function set:empty()
+ return not next(items);
+ end
+
+ if list then
+ set:add_list(list);
+ end
+
+ return setmetatable(set, set_mt);
+end
+
+function union(set1, set2)
+ local set = new();
+ local items = set._items;
+
+ for item in pairs(set1._items) do
+ items[item] = true;
+ end
+
+ for item in pairs(set2._items) do
+ items[item] = true;
+ end
+
+ return set;
+end
+
+function difference(set1, set2)
+ local set = new();
+ local items = set._items;
+
+ for item in pairs(set1._items) do
+ items[item] = (not set2._items[item]) or nil;
+ end
+
+ return set;
+end
+
+function intersection(set1, set2)
+ local set = new();
+ local items = set._items;
+
+ set1, set2 = set1._items, set2._items;
+
+ for item in pairs(set1) do
+ items[item] = (not not set2[item]) or nil;
+ end
+
+ return set;
+end
+
+function xor(set1, set2)
+ return union(set1, set2) - intersection(set1, set2);
+end
+
+return _M;
diff --git a/util/sql.lua b/util/sql.lua
new file mode 100644
index 00000000..f360d6d0
--- /dev/null
+++ b/util/sql.lua
@@ -0,0 +1,340 @@
+
+local setmetatable, getmetatable = setmetatable, getmetatable;
+local ipairs, unpack, select = ipairs, unpack, select;
+local tonumber, tostring = tonumber, tostring;
+local assert, xpcall, debug_traceback = assert, xpcall, debug.traceback;
+local t_concat = table.concat;
+local s_char = string.char;
+local log = require "util.logger".init("sql");
+
+local DBI = require "DBI";
+-- This loads all available drivers while globals are unlocked
+-- LuaDBI should be fixed to not set globals.
+DBI.Drivers();
+local build_url = require "socket.url".build;
+
+module("sql")
+
+local column_mt = {};
+local table_mt = {};
+local query_mt = {};
+--local op_mt = {};
+local index_mt = {};
+
+function is_column(x) return getmetatable(x)==column_mt; end
+function is_index(x) return getmetatable(x)==index_mt; end
+function is_table(x) return getmetatable(x)==table_mt; end
+function is_query(x) return getmetatable(x)==query_mt; end
+--function is_op(x) return getmetatable(x)==op_mt; end
+--function expr(...) return setmetatable({...}, op_mt); end
+function Integer(n) return "Integer()" end
+function String(n) return "String()" end
+
+--[[local ops = {
+ __add = function(a, b) return "("..a.."+"..b..")" end;
+ __sub = function(a, b) return "("..a.."-"..b..")" end;
+ __mul = function(a, b) return "("..a.."*"..b..")" end;
+ __div = function(a, b) return "("..a.."/"..b..")" end;
+ __mod = function(a, b) return "("..a.."%"..b..")" end;
+ __pow = function(a, b) return "POW("..a..","..b..")" end;
+ __unm = function(a) return "NOT("..a..")" end;
+ __len = function(a) return "COUNT("..a..")" end;
+ __eq = function(a, b) return "("..a.."=="..b..")" end;
+ __lt = function(a, b) return "("..a.."<"..b..")" end;
+ __le = function(a, b) return "("..a.."<="..b..")" end;
+};
+
+local functions = {
+
+};
+
+local cmap = {
+ [Integer] = Integer();
+ [String] = String();
+};]]
+
+function Column(definition)
+ return setmetatable(definition, column_mt);
+end
+function Table(definition)
+ local c = {}
+ for i,col in ipairs(definition) do
+ if is_column(col) then
+ c[i], c[col.name] = col, col;
+ elseif is_index(col) then
+ col.table = definition.name;
+ end
+ end
+ return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt);
+end
+function Index(definition)
+ return setmetatable(definition, index_mt);
+end
+
+function table_mt:__tostring()
+ local s = { 'name="'..self.__table__.name..'"' }
+ for i,col in ipairs(self.__table__) do
+ s[#s+1] = tostring(col);
+ end
+ return 'Table{ '..t_concat(s, ", ")..' }'
+end
+table_mt.__index = {};
+function table_mt.__index:create(engine)
+ return engine:_create_table(self);
+end
+function table_mt:__call(...)
+ -- TODO
+end
+function column_mt:__tostring()
+ return 'Column{ name="'..self.name..'", type="'..self.type..'" }'
+end
+function index_mt:__tostring()
+ local s = 'Index{ name="'..self.name..'"';
+ for i=1,#self do s = s..', "'..self[i]:gsub("[\\\"]", "\\%1")..'"'; end
+ return s..' }';
+-- return 'Index{ name="'..self.name..'", type="'..self.type..'" }'
+end
+--
+
+local function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end)); end
+local function parse_url(url)
+ local scheme, secondpart, database = url:match("^([%w%+]+)://([^/]*)/?(.*)");
+ assert(scheme, "Invalid URL format");
+ local username, password, host, port;
+ local authpart, hostpart = secondpart:match("([^@]+)@([^@+])");
+ if not authpart then hostpart = secondpart; end
+ if authpart then
+ username, password = authpart:match("([^:]*):(.*)");
+ username = username or authpart;
+ password = password and urldecode(password);
+ end
+ if hostpart then
+ host, port = hostpart:match("([^:]*):(.*)");
+ host = host or hostpart;
+ port = port and assert(tonumber(port), "Invalid URL format");
+ end
+ return {
+ scheme = scheme:lower();
+ username = username; password = password;
+ host = host; port = port;
+ database = #database > 0 and database or nil;
+ };
+end
+
+--[[local session = {};
+
+function session.query(...)
+ local rets = {...};
+ local query = setmetatable({ __rets = rets, __filters }, query_mt);
+ return query;
+end
+--
+
+local function db2uri(params)
+ return build_url{
+ scheme = params.driver,
+ user = params.username,
+ password = params.password,
+ host = params.host,
+ port = params.port,
+ path = params.database,
+ };
+end]]
+
+local engine = {};
+function engine:connect()
+ if self.conn then return true; end
+
+ local params = self.params;
+ assert(params.driver, "no driver")
+ local dbh, err = DBI.Connect(
+ params.driver, params.database,
+ params.username, params.password,
+ params.host, params.port
+ );
+ if not dbh then return nil, err; end
+ dbh:autocommit(false); -- don't commit automatically
+ self.conn = dbh;
+ self.prepared = {};
+ return true;
+end
+function engine:execute(sql, ...)
+ local success, err = self:connect();
+ if not success then return success, err; end
+ local prepared = self.prepared;
+
+ local stmt = prepared[sql];
+ if not stmt then
+ local err;
+ stmt, err = self.conn:prepare(sql);
+ if not stmt then return stmt, err; end
+ prepared[sql] = stmt;
+ end
+
+ local success, err = stmt:execute(...);
+ if not success then return success, err; end
+ return stmt;
+end
+
+local result_mt = { __index = {
+ affected = function(self) return self.__affected; end;
+ rowcount = function(self) return self.__rowcount; end;
+} };
+
+function engine:execute_query(sql, ...)
+ if self.params.driver == "PostgreSQL" then
+ sql = sql:gsub("`", "\"");
+ end
+ local stmt = assert(self.conn:prepare(sql));
+ assert(stmt:execute(...));
+ return stmt:rows();
+end
+function engine:execute_update(sql, ...)
+ if self.params.driver == "PostgreSQL" then
+ sql = sql:gsub("`", "\"");
+ end
+ local prepared = self.prepared;
+ local stmt = prepared[sql];
+ if not stmt then
+ stmt = assert(self.conn:prepare(sql));
+ prepared[sql] = stmt;
+ end
+ assert(stmt:execute(...));
+ return setmetatable({ __affected = stmt:affected(), __rowcount = stmt:rowcount() }, result_mt);
+end
+engine.insert = engine.execute_update;
+engine.select = engine.execute_query;
+engine.delete = engine.execute_update;
+engine.update = engine.execute_update;
+function engine:_transaction(func, ...)
+ if not self.conn then
+ local a,b = self:connect();
+ if not a then return a,b; end
+ end
+ --assert(not self.__transaction, "Recursive transactions not allowed");
+ local args, n_args = {...}, select("#", ...);
+ local function f() return func(unpack(args, 1, n_args)); end
+ self.__transaction = true;
+ local success, a, b, c = xpcall(f, debug_traceback);
+ self.__transaction = nil;
+ if success then
+ log("debug", "SQL transaction success [%s]", tostring(func));
+ local ok, err = self.conn:commit();
+ if not ok then return ok, err; end -- commit failed
+ return success, a, b, c;
+ else
+ log("debug", "SQL transaction failure [%s]: %s", tostring(func), a);
+ if self.conn then self.conn:rollback(); end
+ return success, a;
+ end
+end
+function engine:transaction(...)
+ local a,b = self:_transaction(...);
+ if not a then
+ local conn = self.conn;
+ if not conn or not conn:ping() then
+ self.conn = nil;
+ a,b = self:_transaction(...);
+ end
+ end
+ return a,b;
+end
+function engine:_create_index(index)
+ local sql = "CREATE INDEX `"..index.name.."` ON `"..index.table.."` (";
+ for i=1,#index do
+ sql = sql.."`"..index[i].."`";
+ if i ~= #index then sql = sql..", "; end
+ end
+ sql = sql..");"
+ if self.params.driver == "PostgreSQL" then
+ sql = sql:gsub("`", "\"");
+ elseif self.params.driver == "MySQL" then
+ sql = sql:gsub("`([,)])", "`(20)%1");
+ end
+ --print(sql);
+ return self:execute(sql);
+end
+function engine:_create_table(table)
+ local sql = "CREATE TABLE `"..table.name.."` (";
+ for i,col in ipairs(table.c) do
+ sql = sql.."`"..col.name.."` "..col.type;
+ if col.nullable == false then sql = sql.." NOT NULL"; end
+ if i ~= #table.c then sql = sql..", "; end
+ end
+ sql = sql.. ");"
+ if self.params.driver == "PostgreSQL" then
+ sql = sql:gsub("`", "\"");
+ end
+ local success,err = self:execute(sql);
+ if not success then return success,err; end
+ for i,v in ipairs(table.__table__) do
+ if is_index(v) then
+ self:_create_index(v);
+ end
+ end
+ return success;
+end
+local engine_mt = { __index = engine };
+
+local function db2uri(params)
+ return build_url{
+ scheme = params.driver,
+ user = params.username,
+ password = params.password,
+ host = params.host,
+ port = params.port,
+ path = params.database,
+ };
+end
+local engine_cache = {}; -- TODO make weak valued
+function create_engine(self, params)
+ local url = db2uri(params);
+ if not engine_cache[url] then
+ local engine = setmetatable({ url = url, params = params }, engine_mt);
+ engine_cache[url] = engine;
+ end
+ return engine_cache[url];
+end
+
+
+--[[Users = Table {
+ name="users";
+ Column { name="user_id", type=String(), primary_key=true };
+};
+print(Users)
+print(Users.c.user_id)]]
+
+--local engine = create_engine('postgresql://scott:tiger@localhost:5432/mydatabase');
+--[[local engine = create_engine{ driver = "SQLite3", database = "./alchemy.sqlite" };
+
+local i = 0;
+for row in assert(engine:execute("select * from sqlite_master")):rows(true) do
+ i = i+1;
+ print(i);
+ for k,v in pairs(row) do
+ print("",k,v);
+ end
+end
+print("---")
+
+Prosody = Table {
+ name="prosody";
+ Column { name="host", type="TEXT", nullable=false };
+ Column { name="user", type="TEXT", nullable=false };
+ Column { name="store", type="TEXT", nullable=false };
+ Column { name="key", type="TEXT", nullable=false };
+ Column { name="type", type="TEXT", nullable=false };
+ Column { name="value", type="TEXT", nullable=false };
+ Index { name="prosody_index", "host", "user", "store", "key" };
+};
+--print(Prosody);
+assert(engine:transaction(function()
+ assert(Prosody:create(engine));
+end));
+
+for row in assert(engine:execute("select user from prosody")):rows(true) do
+ print("username:", row['username'])
+end
+--result.close();]]
+
+return _M;
diff --git a/util/stanza.lua b/util/stanza.lua
index 85595664..7c214210 100644
--- a/util/stanza.lua
+++ b/util/stanza.lua
@@ -1,117 +1,363 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+
local t_insert = table.insert;
local t_remove = table.remove;
+local t_concat = table.concat;
local s_format = string.format;
+local s_match = string.match;
local tostring = tostring;
local setmetatable = setmetatable;
local pairs = pairs;
local ipairs = ipairs;
local type = type;
local s_gsub = string.gsub;
+local s_sub = string.sub;
+local s_find = string.find;
+local os = os;
+
+local do_pretty_printing = not os.getenv("WINDIR");
+local getstyle, getstring;
+if do_pretty_printing then
+ local ok, termcolours = pcall(require, "util.termcolours");
+ if ok then
+ getstyle, getstring = termcolours.getstyle, termcolours.getstring;
+ else
+ do_pretty_printing = nil;
+ end
+end
+
+local xmlns_stanzas = "urn:ietf:params:xml:ns:xmpp-stanzas";
+
module "stanza"
-stanza_mt = {};
+stanza_mt = { __type = "stanza" };
stanza_mt.__index = stanza_mt;
+local stanza_mt = stanza_mt;
function stanza(name, attr)
- local stanza = { name = name, attr = attr or {}, tags = {}, last_add = {}};
+ local stanza = { name = name, attr = attr or {}, tags = {} };
return setmetatable(stanza, stanza_mt);
end
+local stanza = stanza;
function stanza_mt:query(xmlns)
return self:tag("query", { xmlns = xmlns });
end
+
+function stanza_mt:body(text, attr)
+ return self:tag("body", attr):text(text);
+end
+
function stanza_mt:tag(name, attrs)
local s = stanza(name, attrs);
- (self.last_add[#self.last_add] or self):add_child(s);
- t_insert(self.last_add, s);
+ local last_add = self.last_add;
+ if not last_add then last_add = {}; self.last_add = last_add; end
+ (last_add[#last_add] or self):add_direct_child(s);
+ t_insert(last_add, s);
return self;
end
function stanza_mt:text(text)
- (self.last_add[#self.last_add] or self):add_child(text);
- return self;
+ local last_add = self.last_add;
+ (last_add and last_add[#last_add] or self):add_direct_child(text);
+ return self;
end
function stanza_mt:up()
- t_remove(self.last_add);
+ local last_add = self.last_add;
+ if last_add then t_remove(last_add); end
return self;
end
-function stanza_mt:add_child(child)
+function stanza_mt:reset()
+ self.last_add = nil;
+ return self;
+end
+
+function stanza_mt:add_direct_child(child)
if type(child) == "table" then
t_insert(self.tags, child);
end
t_insert(self, child);
end
+function stanza_mt:add_child(child)
+ local last_add = self.last_add;
+ (last_add and last_add[#last_add] or self):add_direct_child(child);
+ return self;
+end
+
+function stanza_mt:get_child(name, xmlns)
+ for _, child in ipairs(self.tags) do
+ if (not name or child.name == name)
+ and ((not xmlns and self.attr.xmlns == child.attr.xmlns)
+ or child.attr.xmlns == xmlns) then
+
+ return child;
+ end
+ end
+end
+
+function stanza_mt:get_child_text(name, xmlns)
+ local tag = self:get_child(name, xmlns);
+ if tag then
+ return tag:get_text();
+ end
+ return nil;
+end
+
function stanza_mt:child_with_name(name)
- for _, child in ipairs(self) do
+ for _, child in ipairs(self.tags) do
if child.name == name then return child; end
end
end
+function stanza_mt:child_with_ns(ns)
+ for _, child in ipairs(self.tags) do
+ if child.attr.xmlns == ns then return child; end
+ end
+end
+
function stanza_mt:children()
local i = 0;
return function (a)
i = i + 1
- local v = a[i]
- if v then return v; end
+ return a[i];
end, self, i;
-
end
-function stanza_mt:childtags()
- local i = 0;
- return function (a)
- i = i + 1
- local v = self.tags[i]
- if v then return v; end
- end, self.tags[1], i;
-
+
+function stanza_mt:childtags(name, xmlns)
+ local tags = self.tags;
+ local start_i, max_i = 1, #tags;
+ return function ()
+ for i = start_i, max_i do
+ local v = tags[i];
+ if (not name or v.name == name)
+ and ((not xmlns and self.attr.xmlns == v.attr.xmlns)
+ or v.attr.xmlns == xmlns) then
+ start_i = i+1;
+ return v;
+ end
+ end
+ end;
end
-do
- local xml_entities = { ["'"] = "&apos;", ["\""] = "&quot;", ["<"] = "&lt;", [">"] = "&gt;", ["&"] = "&amp;" };
- function xml_escape(s) return s_gsub(s, "['&<>\"]", xml_entities); end
+function stanza_mt:maptags(callback)
+ local tags, curr_tag = self.tags, 1;
+ local n_children, n_tags = #self, #tags;
+
+ local i = 1;
+ while curr_tag <= n_tags and n_tags > 0 do
+ if self[i] == tags[curr_tag] then
+ local ret = callback(self[i]);
+ if ret == nil then
+ t_remove(self, i);
+ t_remove(tags, curr_tag);
+ n_children = n_children - 1;
+ n_tags = n_tags - 1;
+ i = i - 1;
+ curr_tag = curr_tag - 1;
+ else
+ self[i] = ret;
+ tags[curr_tag] = ret;
+ end
+ curr_tag = curr_tag + 1;
+ end
+ i = i + 1;
+ end
+ return self;
end
-local xml_escape = xml_escape;
+function stanza_mt:find(path)
+ local pos = 1;
+ local len = #path + 1;
-function stanza_mt.__tostring(t)
- local children_text = "";
- for n, child in ipairs(t) do
- if type(child) == "string" then
- children_text = children_text .. xml_escape(child);
- else
- children_text = children_text .. tostring(child);
+ repeat
+ local xmlns, name, text;
+ local char = s_sub(path, pos, pos);
+ if char == "@" then
+ return self.attr[s_sub(path, pos + 1)];
+ elseif char == "{" then
+ xmlns, pos = s_match(path, "^([^}]+)}()", pos + 1);
+ end
+ name, text, pos = s_match(path, "^([^@/#]*)([/#]?)()", pos);
+ name = name ~= "" and name or nil;
+ if pos == len then
+ if text == "#" then
+ return self:get_child_text(name, xmlns);
+ end
+ return self:get_child(name, xmlns);
+ end
+ self = self:get_child(name, xmlns);
+ until not self
+end
+
+
+local xml_escape
+do
+ local escape_table = { ["'"] = "&apos;", ["\""] = "&quot;", ["<"] = "&lt;", [">"] = "&gt;", ["&"] = "&amp;" };
+ function xml_escape(str) return (s_gsub(str, "['&<>\"]", escape_table)); end
+ _M.xml_escape = xml_escape;
+end
+
+local function _dostring(t, buf, self, xml_escape, parentns)
+ local nsid = 0;
+ local name = t.name
+ t_insert(buf, "<"..name);
+ for k, v in pairs(t.attr) do
+ if s_find(k, "\1", 1, true) then
+ local ns, attrk = s_match(k, "^([^\1]*)\1?(.*)$");
+ nsid = nsid + 1;
+ t_insert(buf, " xmlns:ns"..nsid.."='"..xml_escape(ns).."' ".."ns"..nsid..":"..attrk.."='"..xml_escape(v).."'");
+ elseif not(k == "xmlns" and v == parentns) then
+ t_insert(buf, " "..k.."='"..xml_escape(v).."'");
end
end
+ local len = #t;
+ if len == 0 then
+ t_insert(buf, "/>");
+ else
+ t_insert(buf, ">");
+ for n=1,len do
+ local child = t[n];
+ if child.name then
+ self(child, buf, self, xml_escape, t.attr.xmlns);
+ else
+ t_insert(buf, xml_escape(child));
+ end
+ end
+ t_insert(buf, "</"..name..">");
+ end
+end
+function stanza_mt.__tostring(t)
+ local buf = {};
+ _dostring(t, buf, _dostring, xml_escape, nil);
+ return t_concat(buf);
+end
+function stanza_mt.top_tag(t)
local attr_string = "";
if t.attr then
- for k, v in pairs(t.attr) do if type(k) == "string" then attr_string = attr_string .. s_format(" %s='%s'", k, tostring(v)); end end
+ for k, v in pairs(t.attr) do if type(k) == "string" then attr_string = attr_string .. s_format(" %s='%s'", k, xml_escape(tostring(v))); end end
end
-
- return s_format("<%s%s>%s</%s>", t.name, attr_string, children_text, t.name);
+ return s_format("<%s%s>", t.name, attr_string);
end
-function stanza_mt.__add(s1, s2)
- return s1:add_child(s2);
+function stanza_mt.get_text(t)
+ if #t.tags == 0 then
+ return t_concat(t);
+ end
end
+function stanza_mt.get_error(stanza)
+ local type, condition, text;
+
+ local error_tag = stanza:get_child("error");
+ if not error_tag then
+ return nil, nil, nil;
+ end
+ type = error_tag.attr.type;
+
+ for _, child in ipairs(error_tag.tags) do
+ if child.attr.xmlns == xmlns_stanzas then
+ if not text and child.name == "text" then
+ text = child:get_text();
+ elseif not condition then
+ condition = child.name;
+ end
+ if condition and text then
+ break;
+ end
+ end
+ end
+ return type, condition or "undefined-condition", text;
+end
do
- local id = 0;
- function new_id()
- id = id + 1;
- return "lx"..id;
- end
+ local id = 0;
+ function new_id()
+ id = id + 1;
+ return "lx"..id;
+ end
+end
+
+function preserialize(stanza)
+ local s = { name = stanza.name, attr = stanza.attr };
+ for _, child in ipairs(stanza) do
+ if type(child) == "table" then
+ t_insert(s, preserialize(child));
+ else
+ t_insert(s, child);
+ end
+ end
+ return s;
+end
+
+function deserialize(stanza)
+ -- Set metatable
+ if stanza then
+ local attr = stanza.attr;
+ for i=1,#attr do attr[i] = nil; end
+ local attrx = {};
+ for att in pairs(attr) do
+ if s_find(att, "|", 1, true) and not s_find(att, "\1", 1, true) then
+ local ns,na = s_match(att, "^([^|]+)|(.+)$");
+ attrx[ns.."\1"..na] = attr[att];
+ attr[att] = nil;
+ end
+ end
+ for a,v in pairs(attrx) do
+ attr[a] = v;
+ end
+ setmetatable(stanza, stanza_mt);
+ for _, child in ipairs(stanza) do
+ if type(child) == "table" then
+ deserialize(child);
+ end
+ end
+ if not stanza.tags then
+ -- Rebuild tags
+ local tags = {};
+ for _, child in ipairs(stanza) do
+ if type(child) == "table" then
+ t_insert(tags, child);
+ end
+ end
+ stanza.tags = tags;
+ end
+ end
+
+ return stanza;
+end
+
+local function _clone(stanza)
+ local attr, tags = {}, {};
+ for k,v in pairs(stanza.attr) do attr[k] = v; end
+ local new = { name = stanza.name, attr = attr, tags = tags };
+ for i=1,#stanza do
+ local child = stanza[i];
+ if child.name then
+ child = _clone(child);
+ t_insert(tags, child);
+ end
+ t_insert(new, child);
+ end
+ return setmetatable(new, stanza_mt);
end
+clone = _clone;
function message(attr, body)
if not body then
return stanza("message", attr);
else
- return stanza("message", attr):tag("body"):text(body);
+ return stanza("message", attr):tag("body"):text(body):up();
end
end
function iq(attr)
@@ -120,21 +366,63 @@ function iq(attr)
end
function reply(orig)
- return stanza(orig.name, orig.attr and { to = orig.attr.from, from = orig.attr.to, id = orig.attr.id, type = ((orig.name == "iq" and "result") or nil) });
+ return stanza(orig.name, orig.attr and { to = orig.attr.from, from = orig.attr.to, id = orig.attr.id, type = ((orig.name == "iq" and "result") or orig.attr.type) });
end
-function error_reply(orig, type, condition, message, clone)
- local r = reply(orig);
- t.attr.type = "error";
- -- TODO use clone
- t:tag("error", {type = type})
- :tag(condition, {xmlns = "urn:ietf:params:xml:ns:xmpp-stanzas"}):up();
- if (message) then t:tag("text"):text(message):up(); end
- return t; -- stanza ready for adding app-specific errors
+do
+ local xmpp_stanzas_attr = { xmlns = xmlns_stanzas };
+ function error_reply(orig, type, condition, message)
+ local t = reply(orig);
+ t.attr.type = "error";
+ t:tag("error", {type = type}) --COMPAT: Some day xmlns:stanzas goes here
+ :tag(condition, xmpp_stanzas_attr):up();
+ if (message) then t:tag("text", xmpp_stanzas_attr):text(message):up(); end
+ return t; -- stanza ready for adding app-specific errors
+ end
end
function presence(attr)
return stanza("presence", attr);
end
-return _M; \ No newline at end of file
+if do_pretty_printing then
+ local style_attrk = getstyle("yellow");
+ local style_attrv = getstyle("red");
+ local style_tagname = getstyle("red");
+ local style_punc = getstyle("magenta");
+
+ local attr_format = " "..getstring(style_attrk, "%s")..getstring(style_punc, "=")..getstring(style_attrv, "'%s'");
+ local top_tag_format = getstring(style_punc, "<")..getstring(style_tagname, "%s").."%s"..getstring(style_punc, ">");
+ --local tag_format = getstring(style_punc, "<")..getstring(style_tagname, "%s").."%s"..getstring(style_punc, ">").."%s"..getstring(style_punc, "</")..getstring(style_tagname, "%s")..getstring(style_punc, ">");
+ local tag_format = top_tag_format.."%s"..getstring(style_punc, "</")..getstring(style_tagname, "%s")..getstring(style_punc, ">");
+ function stanza_mt.pretty_print(t)
+ local children_text = "";
+ for n, child in ipairs(t) do
+ if type(child) == "string" then
+ children_text = children_text .. xml_escape(child);
+ else
+ children_text = children_text .. child:pretty_print();
+ end
+ end
+
+ local attr_string = "";
+ if t.attr then
+ for k, v in pairs(t.attr) do if type(k) == "string" then attr_string = attr_string .. s_format(attr_format, k, tostring(v)); end end
+ end
+ return s_format(tag_format, t.name, attr_string, children_text, t.name);
+ end
+
+ function stanza_mt.pretty_top_tag(t)
+ local attr_string = "";
+ if t.attr then
+ for k, v in pairs(t.attr) do if type(k) == "string" then attr_string = attr_string .. s_format(attr_format, k, tostring(v)); end end
+ end
+ return s_format(top_tag_format, t.name, attr_string);
+ end
+else
+ -- Sorry, fresh out of colours for you guys ;)
+ stanza_mt.pretty_print = stanza_mt.__tostring;
+ stanza_mt.pretty_top_tag = stanza_mt.top_tag;
+end
+
+return _M;
diff --git a/util/template.lua b/util/template.lua
new file mode 100644
index 00000000..66d4fca7
--- /dev/null
+++ b/util/template.lua
@@ -0,0 +1,97 @@
+
+local stanza_mt = require "util.stanza".stanza_mt;
+local setmetatable = setmetatable;
+local pairs = pairs;
+local ipairs = ipairs;
+local error = error;
+local loadstring = loadstring;
+local debug = debug;
+local t_remove = table.remove;
+local parse_xml = require "util.xml".parse;
+
+module("template")
+
+local function trim_xml(stanza)
+ for i=#stanza,1,-1 do
+ local child = stanza[i];
+ if child.name then
+ trim_xml(child);
+ else
+ child = child:gsub("^%s*", ""):gsub("%s*$", "");
+ stanza[i] = child;
+ if child == "" then t_remove(stanza, i); end
+ end
+ end
+end
+
+local function create_string_string(str)
+ str = ("%q"):format(str);
+ str = str:gsub("{([^}]*)}", function(s)
+ return '"..(data["'..s..'"]or"").."';
+ end);
+ return str;
+end
+local function create_attr_string(attr, xmlns)
+ local str = '{';
+ for name,value in pairs(attr) do
+ if name ~= "xmlns" or value ~= xmlns then
+ str = str..("[%q]=%s;"):format(name, create_string_string(value));
+ end
+ end
+ return str..'}';
+end
+local function create_clone_string(stanza, lookup, xmlns)
+ if not lookup[stanza] then
+ local s = ('setmetatable({name=%q,attr=%s,tags={'):format(stanza.name, create_attr_string(stanza.attr, xmlns));
+ -- add tags
+ for i,tag in ipairs(stanza.tags) do
+ s = s..create_clone_string(tag, lookup, stanza.attr.xmlns)..";";
+ end
+ s = s..'};';
+ -- add children
+ for i,child in ipairs(stanza) do
+ if child.name then
+ s = s..create_clone_string(child, lookup, stanza.attr.xmlns)..";";
+ else
+ s = s..create_string_string(child)..";"
+ end
+ end
+ s = s..'}, stanza_mt)';
+ s = s:gsub('%.%.""', ""):gsub('([=;])""%.%.', "%1"):gsub(';"";', ";"); -- strip empty strings
+ local n = #lookup + 1;
+ lookup[n] = s;
+ lookup[stanza] = "_"..n;
+ end
+ return lookup[stanza];
+end
+local function create_cloner(stanza, chunkname)
+ local lookup = {};
+ local name = create_clone_string(stanza, lookup, "");
+ local f = "local setmetatable,stanza_mt=...;return function(data)";
+ for i=1,#lookup do
+ f = f.."local _"..i.."="..lookup[i]..";";
+ end
+ f = f.."return "..name..";end";
+ local f,err = loadstring(f, chunkname);
+ if not f then error(err); end
+ return f(setmetatable, stanza_mt);
+end
+
+local template_mt = { __tostring = function(t) return t.name end };
+local function create_template(templates, text)
+ local stanza, err = parse_xml(text);
+ if not stanza then error(err); end
+ trim_xml(stanza);
+
+ local info = debug.getinfo(3, "Sl");
+ info = info and ("template(%s:%d)"):format(info.short_src:match("[^\\/]*$"), info.currentline) or "template(unknown)";
+
+ local template = setmetatable({ apply = create_cloner(stanza, info), name = info, text = text }, template_mt);
+ templates[text] = template;
+ return template;
+end
+
+local templates = setmetatable({}, { __mode = 'k', __index = create_template });
+return function(text)
+ return templates[text];
+end;
diff --git a/util/termcolours.lua b/util/termcolours.lua
new file mode 100644
index 00000000..6ef3b689
--- /dev/null
+++ b/util/termcolours.lua
@@ -0,0 +1,102 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+
+local t_concat, t_insert = table.concat, table.insert;
+local char, format = string.char, string.format;
+local tonumber = tonumber;
+local ipairs = ipairs;
+local io_write = io.write;
+
+local windows;
+if os.getenv("WINDIR") then
+ windows = require "util.windows";
+end
+local orig_color = windows and windows.get_consolecolor and windows.get_consolecolor();
+
+module "termcolours"
+
+local stylemap = {
+ reset = 0; bright = 1, dim = 2, underscore = 4, blink = 5, reverse = 7, hidden = 8;
+ black = 30; red = 31; green = 32; yellow = 33; blue = 34; magenta = 35; cyan = 36; white = 37;
+ ["black background"] = 40; ["red background"] = 41; ["green background"] = 42; ["yellow background"] = 43; ["blue background"] = 44; ["magenta background"] = 45; ["cyan background"] = 46; ["white background"] = 47;
+ bold = 1, dark = 2, underline = 4, underlined = 4, normal = 0;
+ }
+
+local winstylemap = {
+ ["0"] = orig_color, -- reset
+ ["1"] = 7+8, -- bold
+ ["1;33"] = 2+4+8, -- bold yellow
+ ["1;31"] = 4+8 -- bold red
+}
+
+local cssmap = {
+ [1] = "font-weight: bold", [2] = "opacity: 0.5", [4] = "text-decoration: underline", [8] = "visibility: hidden",
+ [30] = "color:black", [31] = "color:red", [32]="color:green", [33]="color:#FFD700",
+ [34] = "color:blue", [35] = "color: magenta", [36] = "color:cyan", [37] = "color: white",
+ [40] = "background-color:black", [41] = "background-color:red", [42]="background-color:green",
+ [43]="background-color:yellow", [44] = "background-color:blue", [45] = "background-color: magenta",
+ [46] = "background-color:cyan", [47] = "background-color: white";
+};
+
+local fmt_string = char(0x1B).."[%sm%s"..char(0x1B).."[0m";
+function getstring(style, text)
+ if style then
+ return format(fmt_string, style, text);
+ else
+ return text;
+ end
+end
+
+function getstyle(...)
+ local styles, result = { ... }, {};
+ for i, style in ipairs(styles) do
+ style = stylemap[style];
+ if style then
+ t_insert(result, style);
+ end
+ end
+ return t_concat(result, ";");
+end
+
+local last = "0";
+function setstyle(style)
+ style = style or "0";
+ if style ~= last then
+ io_write("\27["..style.."m");
+ last = style;
+ end
+end
+
+if windows then
+ function setstyle(style)
+ style = style or "0";
+ if style ~= last then
+ windows.set_consolecolor(winstylemap[style] or orig_color);
+ last = style;
+ end
+ end
+ if not orig_color then
+ function setstyle(style) end
+ end
+end
+
+local function ansi2css(ansi_codes)
+ if ansi_codes == "0" then return "</span>"; end
+ local css = {};
+ for code in ansi_codes:gmatch("[^;]+") do
+ t_insert(css, cssmap[tonumber(code)]);
+ end
+ return "</span><span style='"..t_concat(css, ";").."'>";
+end
+
+function tohtml(input)
+ return input:gsub("\027%[(.-)m", ansi2css);
+end
+
+return _M;
diff --git a/util/throttle.lua b/util/throttle.lua
new file mode 100644
index 00000000..55e1d07b
--- /dev/null
+++ b/util/throttle.lua
@@ -0,0 +1,46 @@
+
+local gettime = require "socket".gettime;
+local setmetatable = setmetatable;
+local floor = math.floor;
+
+module "throttle"
+
+local throttle = {};
+local throttle_mt = { __index = throttle };
+
+function throttle:update()
+ local newt = gettime();
+ local elapsed = newt - self.t;
+ self.t = newt;
+ local balance = floor(self.rate * elapsed) + self.balance;
+ if balance > self.max then
+ self.balance = self.max;
+ else
+ self.balance = balance;
+ end
+ return self.balance;
+end
+
+function throttle:peek(cost)
+ cost = cost or 1;
+ return self.balance >= cost or self:update() >= cost;
+end
+
+function throttle:poll(cost, split)
+ if self:peek(cost) then
+ self.balance = self.balance - cost;
+ return true;
+ else
+ local balance = self.balance;
+ if split then
+ self.balance = 0;
+ end
+ return false, balance, (cost-balance);
+ end
+end
+
+function create(max, period)
+ return setmetatable({ rate = max / period, max = max, t = 0, balance = max }, throttle_mt);
+end
+
+return _M;
diff --git a/util/timer.lua b/util/timer.lua
new file mode 100644
index 00000000..af1e57b6
--- /dev/null
+++ b/util/timer.lua
@@ -0,0 +1,83 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+local server = require "net.server";
+local math_min = math.min
+local math_huge = math.huge
+local get_time = require "socket".gettime;
+local t_insert = table.insert;
+local pairs = pairs;
+local type = type;
+
+local data = {};
+local new_data = {};
+
+module "timer"
+
+local _add_task;
+if not server.event then
+ function _add_task(delay, callback)
+ local current_time = get_time();
+ delay = delay + current_time;
+ if delay >= current_time then
+ t_insert(new_data, {delay, callback});
+ else
+ local r = callback(current_time);
+ if r and type(r) == "number" then
+ return _add_task(r, callback);
+ end
+ end
+ end
+
+ server._addtimer(function()
+ local current_time = get_time();
+ if #new_data > 0 then
+ for _, d in pairs(new_data) do
+ t_insert(data, d);
+ end
+ new_data = {};
+ end
+
+ local next_time = math_huge;
+ for i, d in pairs(data) do
+ local t, callback = d[1], d[2];
+ if t <= current_time then
+ data[i] = nil;
+ local r = callback(current_time);
+ if type(r) == "number" then
+ _add_task(r, callback);
+ next_time = math_min(next_time, r);
+ end
+ else
+ next_time = math_min(next_time, t - current_time);
+ end
+ end
+ return next_time;
+ end);
+else
+ local event = server.event;
+ local event_base = server.event_base;
+ local EVENT_LEAVE = (event.core and event.core.LEAVE) or -1;
+
+ function _add_task(delay, callback)
+ local event_handle;
+ event_handle = event_base:addevent(nil, 0, function ()
+ local ret = callback(get_time());
+ if ret then
+ return 0, ret;
+ elseif event_handle then
+ return EVENT_LEAVE;
+ end
+ end
+ , delay);
+ end
+end
+
+add_task = _add_task;
+
+return _M;
diff --git a/util/uuid.lua b/util/uuid.lua
index 489522aa..796c8ee4 100644
--- a/util/uuid.lua
+++ b/util/uuid.lua
@@ -1,9 +1,50 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
local m_random = math.random;
+local tostring = tostring;
+local os_time = os.time;
+local os_clock = os.clock;
+local sha1 = require "util.hashes".sha1;
+
module "uuid"
-function uuid_generate()
- return m_random(0, 99999999);
+local last_uniq_time = 0;
+local function uniq_time()
+ local new_uniq_time = os_time();
+ if last_uniq_time >= new_uniq_time then new_uniq_time = last_uniq_time + 1; end
+ last_uniq_time = new_uniq_time;
+ return new_uniq_time;
+end
+
+local function new_random(x)
+ return sha1(x..os_clock()..tostring({}), true);
+end
+
+local buffer = new_random(uniq_time());
+local function _seed(x)
+ buffer = new_random(buffer..x);
+end
+local function get_nibbles(n)
+ if #buffer < n then _seed(uniq_time()); end
+ local r = buffer:sub(0, n);
+ buffer = buffer:sub(n+1);
+ return r;
+end
+local function get_twobits()
+ return ("%x"):format(get_nibbles(1):byte() % 4 + 8);
+end
+
+function generate()
+ -- generate RFC 4122 complaint UUIDs (version 4 - random)
+ return get_nibbles(8).."-"..get_nibbles(4).."-4"..get_nibbles(3).."-"..(get_twobits())..get_nibbles(3).."-"..get_nibbles(12);
end
+seed = _seed;
-return _M; \ No newline at end of file
+return _M;
diff --git a/util/watchdog.lua b/util/watchdog.lua
new file mode 100644
index 00000000..bcb2e274
--- /dev/null
+++ b/util/watchdog.lua
@@ -0,0 +1,34 @@
+local timer = require "util.timer";
+local setmetatable = setmetatable;
+local os_time = os.time;
+
+module "watchdog"
+
+local watchdog_methods = {};
+local watchdog_mt = { __index = watchdog_methods };
+
+function new(timeout, callback)
+ local watchdog = setmetatable({ timeout = timeout, last_reset = os_time(), callback = callback }, watchdog_mt);
+ timer.add_task(timeout+1, function (current_time)
+ local last_reset = watchdog.last_reset;
+ if not last_reset then
+ return;
+ end
+ local time_left = (last_reset + timeout) - current_time;
+ if time_left < 0 then
+ return watchdog:callback();
+ end
+ return time_left + 1;
+ end);
+ return watchdog;
+end
+
+function watchdog_methods:reset()
+ self.last_reset = os_time();
+end
+
+function watchdog_methods:cancel()
+ self.last_reset = nil;
+end
+
+return _M;
diff --git a/util/x509.lua b/util/x509.lua
new file mode 100644
index 00000000..19d4ec6d
--- /dev/null
+++ b/util/x509.lua
@@ -0,0 +1,215 @@
+-- Prosody IM
+-- Copyright (C) 2010 Matthew Wild
+-- Copyright (C) 2010 Paul Aurich
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+-- TODO: I feel a fair amount of this logic should be integrated into Luasec,
+-- so that everyone isn't re-inventing the wheel. Dependencies on
+-- IDN libraries complicate that.
+
+
+-- [TLS-CERTS] - http://tools.ietf.org/html/rfc6125
+-- [XMPP-CORE] - http://tools.ietf.org/html/rfc6120
+-- [SRV-ID] - http://tools.ietf.org/html/rfc4985
+-- [IDNA] - http://tools.ietf.org/html/rfc5890
+-- [LDAP] - http://tools.ietf.org/html/rfc4519
+-- [PKIX] - http://tools.ietf.org/html/rfc5280
+
+local nameprep = require "util.encodings".stringprep.nameprep;
+local idna_to_ascii = require "util.encodings".idna.to_ascii;
+local log = require "util.logger".init("x509");
+local pairs, ipairs = pairs, ipairs;
+local s_format = string.format;
+local t_insert = table.insert;
+local t_concat = table.concat;
+
+module "x509"
+
+local oid_commonname = "2.5.4.3"; -- [LDAP] 2.3
+local oid_subjectaltname = "2.5.29.17"; -- [PKIX] 4.2.1.6
+local oid_xmppaddr = "1.3.6.1.5.5.7.8.5"; -- [XMPP-CORE]
+local oid_dnssrv = "1.3.6.1.5.5.7.8.7"; -- [SRV-ID]
+
+-- Compare a hostname (possibly international) with asserted names
+-- extracted from a certificate.
+-- This function follows the rules laid out in
+-- sections 6.4.1 and 6.4.2 of [TLS-CERTS]
+--
+-- A wildcard ("*") all by itself is allowed only as the left-most label
+local function compare_dnsname(host, asserted_names)
+ -- TODO: Sufficient normalization? Review relevant specs.
+ local norm_host = idna_to_ascii(host)
+ if norm_host == nil then
+ log("info", "Host %s failed IDNA ToASCII operation", host)
+ return false
+ end
+
+ norm_host = norm_host:lower()
+
+ local host_chopped = norm_host:gsub("^[^.]+%.", "") -- everything after the first label
+
+ for i=1,#asserted_names do
+ local name = asserted_names[i]
+ if norm_host == name:lower() then
+ log("debug", "Cert dNSName %s matched hostname", name);
+ return true
+ end
+
+ -- Allow the left most label to be a "*"
+ if name:match("^%*%.") then
+ local rest_name = name:gsub("^[^.]+%.", "")
+ if host_chopped == rest_name:lower() then
+ log("debug", "Cert dNSName %s matched hostname", name);
+ return true
+ end
+ end
+ end
+
+ return false
+end
+
+-- Compare an XMPP domain name with the asserted id-on-xmppAddr
+-- identities extracted from a certificate. Both are UTF8 strings.
+--
+-- Per [XMPP-CORE], matches against asserted identities don't include
+-- wildcards, so we just do a normalize on both and then a string comparison
+--
+-- TODO: Support for full JIDs?
+local function compare_xmppaddr(host, asserted_names)
+ local norm_host = nameprep(host)
+
+ for i=1,#asserted_names do
+ local name = asserted_names[i]
+
+ -- We only want to match against bare domains right now, not
+ -- those crazy full-er JIDs.
+ if name:match("[@/]") then
+ log("debug", "Ignoring xmppAddr %s because it's not a bare domain", name)
+ else
+ local norm_name = nameprep(name)
+ if norm_name == nil then
+ log("info", "Ignoring xmppAddr %s, failed nameprep!", name)
+ else
+ if norm_host == norm_name then
+ log("debug", "Cert xmppAddr %s matched hostname", name)
+ return true
+ end
+ end
+ end
+ end
+
+ return false
+end
+
+-- Compare a host + service against the asserted id-on-dnsSRV (SRV-ID)
+-- identities extracted from a certificate.
+--
+-- Per [SRV-ID], the asserted identities will be encoded in ASCII via ToASCII.
+-- Comparison is done case-insensitively, and a wildcard ("*") all by itself
+-- is allowed only as the left-most non-service label.
+local function compare_srvname(host, service, asserted_names)
+ local norm_host = idna_to_ascii(host)
+ if norm_host == nil then
+ log("info", "Host %s failed IDNA ToASCII operation", host);
+ return false
+ end
+
+ -- Service names start with a "_"
+ if service:match("^_") == nil then service = "_"..service end
+
+ norm_host = norm_host:lower();
+ local host_chopped = norm_host:gsub("^[^.]+%.", "") -- everything after the first label
+
+ for i=1,#asserted_names do
+ local asserted_service, name = asserted_names[i]:match("^(_[^.]+)%.(.*)");
+ if service == asserted_service then
+ if norm_host == name:lower() then
+ log("debug", "Cert SRVName %s matched hostname", name);
+ return true;
+ end
+
+ -- Allow the left most label to be a "*"
+ if name:match("^%*%.") then
+ local rest_name = name:gsub("^[^.]+%.", "")
+ if host_chopped == rest_name:lower() then
+ log("debug", "Cert SRVName %s matched hostname", name)
+ return true
+ end
+ end
+ if norm_host == name:lower() then
+ log("debug", "Cert SRVName %s matched hostname", name);
+ return true
+ end
+ end
+ end
+
+ return false
+end
+
+function verify_identity(host, service, cert)
+ local ext = cert:extensions()
+ if ext[oid_subjectaltname] then
+ local sans = ext[oid_subjectaltname];
+
+ -- Per [TLS-CERTS] 6.3, 6.4.4, "a client MUST NOT seek a match for a
+ -- reference identifier if the presented identifiers include a DNS-ID
+ -- SRV-ID, URI-ID, or any application-specific identifier types"
+ local had_supported_altnames = false
+
+ if sans[oid_xmppaddr] then
+ had_supported_altnames = true
+ if compare_xmppaddr(host, sans[oid_xmppaddr]) then return true end
+ end
+
+ if sans[oid_dnssrv] then
+ had_supported_altnames = true
+ -- Only check srvNames if the caller specified a service
+ if service and compare_srvname(host, service, sans[oid_dnssrv]) then return true end
+ end
+
+ if sans["dNSName"] then
+ had_supported_altnames = true
+ if compare_dnsname(host, sans["dNSName"]) then return true end
+ end
+
+ -- We don't need URIs, but [TLS-CERTS] is clear.
+ if sans["uniformResourceIdentifier"] then
+ had_supported_altnames = true
+ end
+
+ if had_supported_altnames then return false end
+ end
+
+ -- Extract a common name from the certificate, and check it as if it were
+ -- a dNSName subjectAltName (wildcards may apply for, and receive,
+ -- cat treats)
+ --
+ -- Per [TLS-CERTS] 1.8, a CN-ID is the Common Name from a cert subject
+ -- which has one and only one Common Name
+ local subject = cert:subject()
+ local cn = nil
+ for i=1,#subject do
+ local dn = subject[i]
+ if dn["oid"] == oid_commonname then
+ if cn then
+ log("info", "Certificate has multiple common names")
+ return false
+ end
+
+ cn = dn["value"];
+ end
+ end
+
+ if cn then
+ -- Per [TLS-CERTS] 6.4.4, follow the comparison rules for dNSName SANs.
+ return compare_dnsname(host, { cn })
+ end
+
+ -- If all else fails, well, why should we be any different?
+ return false
+end
+
+return _M;
diff --git a/util/xml.lua b/util/xml.lua
new file mode 100644
index 00000000..076490fa
--- /dev/null
+++ b/util/xml.lua
@@ -0,0 +1,57 @@
+
+local st = require "util.stanza";
+local lxp = require "lxp";
+
+module("xml")
+
+local parse_xml = (function()
+ local ns_prefixes = {
+ ["http://www.w3.org/XML/1998/namespace"] = "xml";
+ };
+ local ns_separator = "\1";
+ local ns_pattern = "^([^"..ns_separator.."]*)"..ns_separator.."?(.*)$";
+ return function(xml)
+ local handler = {};
+ local stanza = st.stanza("root");
+ function handler:StartElement(tagname, attr)
+ local curr_ns,name = tagname:match(ns_pattern);
+ if name == "" then
+ curr_ns, name = "", curr_ns;
+ end
+ if curr_ns ~= "" then
+ attr.xmlns = curr_ns;
+ end
+ for i=1,#attr do
+ local k = attr[i];
+ attr[i] = nil;
+ local ns, nm = k:match(ns_pattern);
+ if nm ~= "" then
+ ns = ns_prefixes[ns];
+ if ns then
+ attr[ns..":"..nm] = attr[k];
+ attr[k] = nil;
+ end
+ end
+ end
+ stanza:tag(name, attr);
+ end
+ function handler:CharacterData(data)
+ stanza:text(data);
+ end
+ function handler:EndElement(tagname)
+ stanza:up();
+ end
+ local parser = lxp.new(handler, "\1");
+ local ok, err, line, col = parser:parse(xml);
+ if ok then ok, err, line, col = parser:parse(); end
+ --parser:close();
+ if ok then
+ return stanza.tags[1];
+ else
+ return ok, err.." (line "..line..", col "..col..")";
+ end
+ end;
+end)();
+
+parse = parse_xml;
+return _M;
diff --git a/util/xmppstream.lua b/util/xmppstream.lua
new file mode 100644
index 00000000..4909678c
--- /dev/null
+++ b/util/xmppstream.lua
@@ -0,0 +1,191 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+
+local lxp = require "lxp";
+local st = require "util.stanza";
+local stanza_mt = st.stanza_mt;
+
+local error = error;
+local tostring = tostring;
+local t_insert = table.insert;
+local t_concat = table.concat;
+local t_remove = table.remove;
+local setmetatable = setmetatable;
+
+-- COMPAT: w/LuaExpat 1.1.0
+local lxp_supports_doctype = pcall(lxp.new, { StartDoctypeDecl = false });
+
+module "xmppstream"
+
+local new_parser = lxp.new;
+
+local xml_namespace = {
+ ["http://www.w3.org/XML/1998/namespace\1lang"] = "xml:lang";
+ ["http://www.w3.org/XML/1998/namespace\1space"] = "xml:space";
+ ["http://www.w3.org/XML/1998/namespace\1base"] = "xml:base";
+ ["http://www.w3.org/XML/1998/namespace\1id"] = "xml:id";
+};
+
+local xmlns_streams = "http://etherx.jabber.org/streams";
+
+local ns_separator = "\1";
+local ns_pattern = "^([^"..ns_separator.."]*)"..ns_separator.."?(.*)$";
+
+_M.ns_separator = ns_separator;
+_M.ns_pattern = ns_pattern;
+
+function new_sax_handlers(session, stream_callbacks)
+ local xml_handlers = {};
+
+ local cb_streamopened = stream_callbacks.streamopened;
+ local cb_streamclosed = stream_callbacks.streamclosed;
+ local cb_error = stream_callbacks.error or function(session, e, stanza) error("XML stream error: "..tostring(e)..(stanza and ": "..tostring(stanza) or ""),2); end;
+ local cb_handlestanza = stream_callbacks.handlestanza;
+
+ local stream_ns = stream_callbacks.stream_ns or xmlns_streams;
+ local stream_tag = stream_callbacks.stream_tag or "stream";
+ if stream_ns ~= "" then
+ stream_tag = stream_ns..ns_separator..stream_tag;
+ end
+ local stream_error_tag = stream_ns..ns_separator..(stream_callbacks.error_tag or "error");
+
+ local stream_default_ns = stream_callbacks.default_ns;
+
+ local stack = {};
+ local chardata, stanza = {};
+ local non_streamns_depth = 0;
+ function xml_handlers:StartElement(tagname, attr)
+ if stanza and #chardata > 0 then
+ -- We have some character data in the buffer
+ t_insert(stanza, t_concat(chardata));
+ chardata = {};
+ end
+ local curr_ns,name = tagname:match(ns_pattern);
+ if name == "" then
+ curr_ns, name = "", curr_ns;
+ end
+
+ if curr_ns ~= stream_default_ns or non_streamns_depth > 0 then
+ attr.xmlns = curr_ns;
+ non_streamns_depth = non_streamns_depth + 1;
+ end
+
+ for i=1,#attr do
+ local k = attr[i];
+ attr[i] = nil;
+ local xmlk = xml_namespace[k];
+ if xmlk then
+ attr[xmlk] = attr[k];
+ attr[k] = nil;
+ end
+ end
+
+ if not stanza then --if we are not currently inside a stanza
+ if session.notopen then
+ if tagname == stream_tag then
+ non_streamns_depth = 0;
+ if cb_streamopened then
+ cb_streamopened(session, attr);
+ end
+ else
+ -- Garbage before stream?
+ cb_error(session, "no-stream");
+ end
+ return;
+ end
+ if curr_ns == "jabber:client" and name ~= "iq" and name ~= "presence" and name ~= "message" then
+ cb_error(session, "invalid-top-level-element");
+ end
+
+ stanza = setmetatable({ name = name, attr = attr, tags = {} }, stanza_mt);
+ else -- we are inside a stanza, so add a tag
+ t_insert(stack, stanza);
+ local oldstanza = stanza;
+ stanza = setmetatable({ name = name, attr = attr, tags = {} }, stanza_mt);
+ t_insert(oldstanza, stanza);
+ t_insert(oldstanza.tags, stanza);
+ end
+ end
+ function xml_handlers:CharacterData(data)
+ if stanza then
+ t_insert(chardata, data);
+ end
+ end
+ function xml_handlers:EndElement(tagname)
+ if non_streamns_depth > 0 then
+ non_streamns_depth = non_streamns_depth - 1;
+ end
+ if stanza then
+ if #chardata > 0 then
+ -- We have some character data in the buffer
+ t_insert(stanza, t_concat(chardata));
+ chardata = {};
+ end
+ -- Complete stanza
+ if #stack == 0 then
+ if tagname ~= stream_error_tag then
+ cb_handlestanza(session, stanza);
+ else
+ cb_error(session, "stream-error", stanza);
+ end
+ stanza = nil;
+ else
+ stanza = t_remove(stack);
+ end
+ else
+ if cb_streamclosed then
+ cb_streamclosed(session);
+ end
+ end
+ end
+
+ local function restricted_handler(parser)
+ cb_error(session, "parse-error", "restricted-xml", "Restricted XML, see RFC 6120 section 11.1.");
+ if not parser.stop or not parser:stop() then
+ error("Failed to abort parsing");
+ end
+ end
+
+ if lxp_supports_doctype then
+ xml_handlers.StartDoctypeDecl = restricted_handler;
+ end
+ xml_handlers.Comment = restricted_handler;
+ xml_handlers.ProcessingInstruction = restricted_handler;
+
+ local function reset()
+ stanza, chardata = nil, {};
+ stack = {};
+ end
+
+ local function set_session(stream, new_session)
+ session = new_session;
+ end
+
+ return xml_handlers, { reset = reset, set_session = set_session };
+end
+
+function new(session, stream_callbacks)
+ local handlers, meta = new_sax_handlers(session, stream_callbacks);
+ local parser = new_parser(handlers, ns_separator);
+ local parse = parser.parse;
+
+ return {
+ reset = function ()
+ parser = new_parser(handlers, ns_separator);
+ parse = parser.parse;
+ meta.reset();
+ end,
+ feed = function (self, data)
+ return parse(parser, data);
+ end,
+ set_session = meta.set_session;
+ };
+end
+
+return _M;