aboutsummaryrefslogtreecommitdiffstats
path: root/util
diff options
context:
space:
mode:
Diffstat (limited to 'util')
-rw-r--r--util/adhoc.lua31
-rw-r--r--util/array.lua64
-rw-r--r--util/async.lua158
-rw-r--r--util/broadcast.lua68
-rw-r--r--util/caps.lua2
-rw-r--r--util/dataforms.lua145
-rw-r--r--util/datamanager.lua258
-rw-r--r--util/datetime.lua6
-rw-r--r--util/debug.lua199
-rw-r--r--util/dependencies.lua39
-rw-r--r--util/envload.lua34
-rw-r--r--util/events.lua8
-rw-r--r--util/filters.lua16
-rw-r--r--util/helpers.lua43
-rw-r--r--util/hmac.lua66
-rw-r--r--util/http.lua64
-rw-r--r--util/httpstream.lua137
-rw-r--r--util/import.lua2
-rw-r--r--util/ip.lua244
-rw-r--r--util/iterators.lua83
-rw-r--r--util/jid.lua15
-rw-r--r--util/json.lua411
-rw-r--r--util/logger.lua35
-rw-r--r--util/multitable.lua43
-rw-r--r--util/openssl.lua172
-rw-r--r--util/pluginloader.lua76
-rw-r--r--util/prosodyctl.lua165
-rw-r--r--util/pubsub.lua89
-rw-r--r--util/rfc6724.lua142
-rw-r--r--util/sasl.lua15
-rw-r--r--util/sasl/anonymous.lua4
-rw-r--r--util/sasl/digest-md5.lua15
-rw-r--r--util/sasl/external.lua25
-rw-r--r--util/sasl/plain.lua21
-rw-r--r--util/sasl/scram.lua63
-rw-r--r--util/sasl_cyrus.lua8
-rw-r--r--util/serialization.lua8
-rw-r--r--util/set.lua53
-rw-r--r--util/sql.lua342
-rw-r--r--util/stanza.lua129
-rw-r--r--util/template.lua64
-rw-r--r--util/termcolours.lua25
-rw-r--r--util/throttle.lua46
-rw-r--r--util/timer.lua39
-rw-r--r--util/uuid.lua2
-rw-r--r--util/watchdog.lua34
-rw-r--r--util/x509.lua16
-rw-r--r--util/xml.lua57
-rw-r--r--util/xmlrpc.lua182
-rw-r--r--util/xmppstream.lua100
50 files changed, 2857 insertions, 1206 deletions
diff --git a/util/adhoc.lua b/util/adhoc.lua
new file mode 100644
index 00000000..671e85cf
--- /dev/null
+++ b/util/adhoc.lua
@@ -0,0 +1,31 @@
+local function new_simple_form(form, result_handler)
+ return function(self, data, state)
+ if state then
+ if data.action == "cancel" then
+ return { status = "canceled" };
+ end
+ local fields, err = form:data(data.form);
+ return result_handler(fields, err, data);
+ else
+ return { status = "executing", actions = {"next", "complete", default = "complete"}, form = form }, "executing";
+ end
+ end
+end
+
+local function new_initial_data_form(form, initial_data, result_handler)
+ return function(self, data, state)
+ if state then
+ if data.action == "cancel" then
+ return { status = "canceled" };
+ end
+ local fields, err = form:data(data.form);
+ return result_handler(fields, err, data);
+ else
+ return { status = "executing", actions = {"next", "complete", default = "complete"},
+ form = { layout = form, values = initial_data() } }, "executing";
+ end
+ end
+end
+
+return { new_simple_form = new_simple_form,
+ new_initial_data_form = new_initial_data_form };
diff --git a/util/array.lua b/util/array.lua
index 6c1f0460..9bf215af 100644
--- a/util/array.lua
+++ b/util/array.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -9,12 +9,20 @@
local t_insert, t_sort, t_remove, t_concat
= table.insert, table.sort, table.remove, table.concat;
+local setmetatable = setmetatable;
+local math_random = math.random;
+local pairs, ipairs = pairs, ipairs;
+local tostring = tostring;
+
local array = {};
local array_base = {};
local array_methods = {};
-local array_mt = { __index = array_methods, __tostring = function (array) return array:concat(", "); end };
+local array_mt = { __index = array_methods, __tostring = function (array) return "{"..array:concat(", ").."}"; end };
-local function new_array(_, t)
+local function new_array(self, t, _s, _var)
+ if type(t) == "function" then -- Assume iterator
+ t = self.collect(t, _s, _var);
+ end
return setmetatable(t or {}, array_mt);
end
@@ -25,6 +33,15 @@ end
setmetatable(array, { __call = new_array });
+-- Read-only methods
+function array_methods:random()
+ return self[math_random(1,#self)];
+end
+
+-- These methods can be called two ways:
+-- array.method(existing_array, [params [, ...]]) -- Create new array for result
+-- existing_array:method([params, ...]) -- Transform existing array into result
+--
function array_base.map(outa, ina, func)
for k,v in ipairs(ina) do
outa[k] = func(v);
@@ -42,13 +59,13 @@ function array_base.filter(outa, ina, func)
write = write + 1;
end
end
-
+
if inplace and write <= start_length then
for i=write,start_length do
outa[i] = nil;
end
end
-
+
return outa;
end
@@ -60,15 +77,18 @@ function array_base.sort(outa, ina, ...)
return outa;
end
---- These methods only mutate
-function array_methods:random()
- return self[math.random(1,#self)];
+function array_base.pluck(outa, ina, key)
+ for i=1,#ina do
+ outa[i] = ina[i][key];
+ end
+ return outa;
end
+--- These methods only mutate the array
function array_methods:shuffle(outa, ina)
local len = #self;
for i=1,#self do
- local r = math.random(i,len);
+ local r = math_random(i,len);
self[i], self[r] = self[r], self[i];
end
return self;
@@ -91,18 +111,32 @@ function array_methods:append(array)
return self;
end
-array_methods.push = table.insert;
-array_methods.pop = table.remove;
-array_methods.concat = table.concat;
-array_methods.length = function (t) return #t; end
+function array_methods:push(x)
+ t_insert(self, x);
+ return self;
+end
+
+function array_methods:pop(x)
+ local v = self[x];
+ t_remove(self, x);
+ return v;
+end
+
+function array_methods:concat(sep)
+ return t_concat(array.map(self, tostring), sep);
+end
+
+function array_methods:length()
+ return #self;
+end
--- These methods always create a new array
function array.collect(f, s, var)
- local t, var = {};
+ local t = {};
while true do
var = f(s, var);
if var == nil then break; end
- table.insert(t, var);
+ t_insert(t, var);
end
return setmetatable(t, array_mt);
end
diff --git a/util/async.lua b/util/async.lua
new file mode 100644
index 00000000..968ec804
--- /dev/null
+++ b/util/async.lua
@@ -0,0 +1,158 @@
+local log = require "util.logger".init("util.async");
+
+local function runner_continue(thread)
+ -- ASSUMPTION: runner is in 'waiting' state (but we don't have the runner to know for sure)
+ if coroutine.status(thread) ~= "suspended" then -- This should suffice
+ return false;
+ end
+ local ok, state, runner = coroutine.resume(thread);
+ if not ok then
+ local level = 0;
+ while debug.getinfo(thread, level, "") do level = level + 1; end
+ ok, runner = debug.getlocal(thread, level-1, 1);
+ local error_handler = runner.watchers.error;
+ if error_handler then error_handler(runner, debug.traceback(thread, state)); end
+ elseif state == "ready" then
+ -- If state is 'ready', it is our responsibility to update runner.state from 'waiting'.
+ -- We also have to :run(), because the queue might have further items that will not be
+ -- processed otherwise. FIXME: It's probably best to do this in a nexttick (0 timer).
+ runner.state = "ready";
+ runner:run();
+ end
+ return true;
+end
+
+local function waiter(num)
+ local thread = coroutine.running();
+ if not thread then
+ error("Not running in an async context, see http://prosody.im/doc/developers/async");
+ end
+ num = num or 1;
+ local waiting;
+ return function ()
+ if num == 0 then return; end -- already done
+ waiting = true;
+ coroutine.yield("wait");
+ end, function ()
+ num = num - 1;
+ if num == 0 and waiting then
+ runner_continue(thread);
+ elseif num < 0 then
+ error("done() called too many times");
+ end
+ end;
+end
+
+local function guarder()
+ local guards = {};
+ return function (id, func)
+ local thread = coroutine.running();
+ if not thread then
+ error("Not running in an async context, see http://prosody.im/doc/developers/async");
+ end
+ local guard = guards[id];
+ if not guard then
+ guard = {};
+ guards[id] = guard;
+ log("debug", "New guard!");
+ else
+ table.insert(guard, thread);
+ log("debug", "Guarded. %d threads waiting.", #guard)
+ coroutine.yield("wait");
+ end
+ local function exit()
+ local next_waiting = table.remove(guard, 1);
+ if next_waiting then
+ log("debug", "guard: Executing next waiting thread (%d left)", #guard)
+ runner_continue(next_waiting);
+ else
+ log("debug", "Guard off duty.")
+ guards[id] = nil;
+ end
+ end
+ if func then
+ func();
+ exit();
+ return;
+ end
+ return exit;
+ end;
+end
+
+local runner_mt = {};
+runner_mt.__index = runner_mt;
+
+local function runner_create_thread(func, self)
+ local thread = coroutine.create(function (self)
+ while true do
+ func(coroutine.yield("ready", self));
+ end
+ end);
+ assert(coroutine.resume(thread, self)); -- Start it up, it will return instantly to wait for the first input
+ return thread;
+end
+
+local empty_watchers = {};
+local function runner(func, watchers, data)
+ return setmetatable({ func = func, thread = false, state = "ready", notified_state = "ready",
+ queue = {}, watchers = watchers or empty_watchers, data = data }
+ , runner_mt);
+end
+
+function runner_mt:run(input)
+ if input ~= nil then
+ table.insert(self.queue, input);
+ end
+ if self.state ~= "ready" then
+ return true, self.state, #self.queue;
+ end
+
+ local q, thread = self.queue, self.thread;
+ if not thread or coroutine.status(thread) == "dead" then
+ thread = runner_create_thread(self.func, self);
+ self.thread = thread;
+ end
+
+ local n, state, err = #q, self.state, nil;
+ self.state = "running";
+ while n > 0 and state == "ready" do
+ local consumed;
+ for i = 1,n do
+ local input = q[i];
+ local ok, new_state = coroutine.resume(thread, input);
+ if not ok then
+ consumed, state, err = i, "ready", debug.traceback(thread, new_state);
+ self.thread = nil;
+ break;
+ elseif new_state == "wait" then
+ consumed, state = i, "waiting";
+ break;
+ end
+ end
+ if not consumed then consumed = n; end
+ if q[n+1] ~= nil then
+ n = #q;
+ end
+ for i = 1, n do
+ q[i] = q[consumed+i];
+ end
+ n = #q;
+ end
+ self.state = state;
+ if err or state ~= self.notified_state then
+ if err then
+ state = "error"
+ else
+ self.notified_state = state;
+ end
+ local handler = self.watchers[state];
+ if handler then handler(self, err); end
+ end
+ return true, state, n;
+end
+
+function runner_mt:enqueue(input)
+ table.insert(self.queue, input);
+end
+
+return { waiter = waiter, guarder = guarder, runner = runner };
diff --git a/util/broadcast.lua b/util/broadcast.lua
deleted file mode 100644
index be17461d..00000000
--- a/util/broadcast.lua
+++ /dev/null
@@ -1,68 +0,0 @@
--- Prosody IM
--- Copyright (C) 2008-2010 Matthew Wild
--- Copyright (C) 2008-2010 Waqas Hussain
---
--- This project is MIT/X11 licensed. Please see the
--- COPYING file in the source package for more information.
---
-
-
-local ipairs, pairs, setmetatable, type =
- ipairs, pairs, setmetatable, type;
-
-module "pubsub"
-
-local pubsub_node_mt = { __index = _M };
-
-function new_node(name)
- return setmetatable({ name = name, subscribers = {} }, pubsub_node_mt);
-end
-
-function set_subscribers(node, subscribers_list, list_type)
- local subscribers = node.subscribers;
-
- if list_type == "array" then
- for _, jid in ipairs(subscribers_list) do
- if not subscribers[jid] then
- node:add_subscriber(jid);
- end
- end
- elseif (not list_type) or list_type == "set" then
- for jid in pairs(subscribers_list) do
- if type(jid) == "string" then
- node:add_subscriber(jid);
- end
- end
- end
-end
-
-function get_subscribers(node)
- return node.subscribers;
-end
-
-function publish(node, item, dispatcher, data)
- local subscribers = node.subscribers;
- for i = 1,#subscribers do
- item.attr.to = subscribers[i];
- dispatcher(data, item);
- end
-end
-
-function add_subscriber(node, jid)
- local subscribers = node.subscribers;
- if not subscribers[jid] then
- local space = #subscribers;
- subscribers[space] = jid;
- subscribers[jid] = space;
- end
-end
-
-function remove_subscriber(node, jid)
- local subscribers = node.subscribers;
- if subscribers[jid] then
- subscribers[subscribers[jid]] = nil;
- subscribers[jid] = nil;
- end
-end
-
-return _M;
diff --git a/util/caps.lua b/util/caps.lua
index a61e7403..4723b912 100644
--- a/util/caps.lua
+++ b/util/caps.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
diff --git a/util/dataforms.lua b/util/dataforms.lua
index ae745e03..b38d0e27 100644
--- a/util/dataforms.lua
+++ b/util/dataforms.lua
@@ -1,16 +1,17 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
local setmetatable = setmetatable;
local pairs, ipairs = pairs, ipairs;
-local tostring, type = tostring, type;
+local tostring, type, next = tostring, type, next;
local t_concat = table.concat;
local st = require "util.stanza";
+local jid_prep = require "util.jid".prep;
module "dataforms"
@@ -37,7 +38,7 @@ function form_t.form(layout, data, formtype)
form:tag("field", { type = field_type, var = field.name, label = field.label });
local value = (data and data[field.name]) or field.value;
-
+
if value then
-- Add value, depending on type
if field_type == "hidden" then
@@ -52,7 +53,7 @@ function form_t.form(layout, data, formtype)
elseif field_type == "boolean" then
form:tag("value"):text((value and "1") or "0"):up();
elseif field_type == "fixed" then
-
+ form:tag("value"):text(value):up();
elseif field_type == "jid-multi" then
for _, jid in ipairs(value) do
form:tag("value"):text(jid):up();
@@ -92,11 +93,11 @@ function form_t.form(layout, data, formtype)
end
end
end
-
+
if field.required then
form:tag("required"):up();
end
-
+
-- Jump back up to list of fields
form:up();
end
@@ -107,30 +108,41 @@ local field_readers = {};
function form_t.data(layout, stanza)
local data = {};
-
- for field_tag in stanza:childtags() do
- local field_type;
- for n, field in ipairs(layout) do
+ local errors = {};
+
+ for _, field in ipairs(layout) do
+ local tag;
+ for field_tag in stanza:childtags() do
if field.name == field_tag.attr.var then
- field_type = field.type;
+ tag = field_tag;
break;
end
end
-
- local reader = field_readers[field_type];
- if reader then
- data[field_tag.attr.var] = reader(field_tag);
+
+ if not tag then
+ if field.required then
+ errors[field.name] = "Required value missing";
+ end
+ else
+ local reader = field_readers[field.type];
+ if reader then
+ data[field.name], errors[field.name] = reader(tag, field.required);
+ end
end
-
+ end
+ if next(errors) then
+ return data, errors;
end
return data;
end
field_readers["text-single"] =
- function (field_tag)
- local value = field_tag:child_with_name("value");
- if value then
- return value[1];
+ function (field_tag, required)
+ local data = field_tag:get_child_text("value");
+ if data and #data > 0 then
+ return data
+ elseif required then
+ return nil, "Required value missing";
end
end
@@ -138,64 +150,85 @@ field_readers["text-private"] =
field_readers["text-single"];
field_readers["jid-single"] =
- field_readers["text-single"];
+ function (field_tag, required)
+ local raw_data = field_tag:get_child_text("value")
+ local data = jid_prep(raw_data);
+ if data and #data > 0 then
+ return data
+ elseif raw_data then
+ return nil, "Invalid JID: " .. raw_data;
+ elseif required then
+ return nil, "Required value missing";
+ end
+ end
field_readers["jid-multi"] =
- function (field_tag)
+ function (field_tag, required)
local result = {};
- for value_tag in field_tag:childtags() do
- if value_tag.name == "value" then
- result[#result+1] = value_tag[1];
+ local err = {};
+ for value_tag in field_tag:childtags("value") do
+ local raw_value = value_tag:get_text();
+ local value = jid_prep(raw_value);
+ result[#result+1] = value;
+ if raw_value and not value then
+ err[#err+1] = ("Invalid JID: " .. raw_value);
end
end
- return result;
+ if #result > 0 then
+ return result, (#err > 0 and t_concat(err, "\n") or nil);
+ elseif required then
+ return nil, "Required value missing";
+ end
end
-field_readers["text-multi"] =
- function (field_tag)
+field_readers["list-multi"] =
+ function (field_tag, required)
local result = {};
- for value_tag in field_tag:childtags() do
- if value_tag.name == "value" then
- result[#result+1] = value_tag[1];
- end
+ for value in field_tag:childtags("value") do
+ result[#result+1] = value:get_text();
+ end
+ if #result > 0 then
+ return result;
+ elseif required then
+ return nil, "Required value missing";
+ end
+ end
+
+field_readers["text-multi"] =
+ function (field_tag, required)
+ local data, err = field_readers["list-multi"](field_tag, required);
+ if data then
+ data = t_concat(data, "\n");
end
- return t_concat(result, "\n");
+ return data, err;
end
field_readers["list-single"] =
field_readers["text-single"];
-field_readers["list-multi"] =
- function (field_tag)
- local result = {};
- for value_tag in field_tag:childtags() do
- if value_tag.name == "value" then
- result[#result+1] = value_tag[1];
- end
- end
- return result;
- end
+local boolean_values = {
+ ["1"] = true, ["true"] = true,
+ ["0"] = false, ["false"] = false,
+};
field_readers["boolean"] =
- function (field_tag)
- local value = field_tag:child_with_name("value");
- if value then
- if value[1] == "1" or value[1] == "true" then
- return true;
- else
- return false;
- end
+ function (field_tag, required)
+ local raw_value = field_tag:get_child_text("value");
+ local value = boolean_values[raw_value ~= nil and raw_value];
+ if value ~= nil then
+ return value;
+ elseif raw_value then
+ return nil, "Invalid boolean representation";
+ elseif required then
+ return nil, "Required value missing";
end
end
field_readers["hidden"] =
function (field_tag)
- local value = field_tag:child_with_name("value");
- if value then
- return value[1];
- end
+ return field_tag:get_child_text("value");
end
-
+
return _M;
diff --git a/util/datamanager.lua b/util/datamanager.lua
index fbdfb581..4a4d62b3 100644
--- a/util/datamanager.lua
+++ b/util/datamanager.lua
@@ -1,34 +1,47 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
local format = string.format;
-local setmetatable, type = setmetatable, type;
-local pairs, ipairs = pairs, ipairs;
+local setmetatable = setmetatable;
+local ipairs = ipairs;
local char = string.char;
-local loadfile, setfenv, pcall = loadfile, setfenv, pcall;
+local pcall = pcall;
local log = require "util.logger".init("datamanager");
local io_open = io.open;
local os_remove = os.remove;
-local tostring, tonumber = tostring, tonumber;
-local error = error;
+local os_rename = os.rename;
+local tonumber = tonumber;
local next = next;
local t_insert = table.insert;
-local append = require "util.serialization".append;
-local path_separator = "/"; if os.getenv("WINDIR") then path_separator = "\\" end
+local t_concat = table.concat;
+local envloadfile = require"util.envload".envloadfile;
+local serialize = require "util.serialization".serialize;
+local path_separator = assert ( package.config:match ( "^([^\n]+)" ) , "package.config not in standard form" ) -- Extract directory seperator from package.config (an undocumented string that comes with lua)
local lfs = require "lfs";
-local raw_mkdir;
+local prosody = prosody;
-if prosody.platform == "posix" then
- raw_mkdir = require "util.pposix".mkdir; -- Doesn't trample on umask
-else
- raw_mkdir = lfs.mkdir;
-end
+local raw_mkdir = lfs.mkdir;
+local function fallocate(f, offset, len)
+ -- This assumes that current position == offset
+ local fake_data = (" "):rep(len);
+ local ok, msg = f:write(fake_data);
+ if not ok then
+ return ok, msg;
+ end
+ f:seek("set", offset);
+ return true;
+end;
+pcall(function()
+ local pposix = require "util.pposix";
+ raw_mkdir = pposix.mkdir or raw_mkdir; -- Doesn't trample on umask
+ fallocate = pposix.fallocate or fallocate;
+end);
module "datamanager"
@@ -56,7 +69,7 @@ local function mkdir(path)
return path;
end
-local data_path = "data";
+local data_path = (prosody and prosody.paths and prosody.paths.data) or ".";
local callbacks = {};
------- API -------------
@@ -71,7 +84,7 @@ local function callback(username, host, datastore, data)
username, host, datastore, data = f(username, host, datastore, data);
if username == false then break; end
end
-
+
return username, host, datastore, data;
end
function add_callback(func)
@@ -100,37 +113,68 @@ function getpath(username, host, datastore, ext, create)
if username then
if create then mkdir(mkdir(mkdir(data_path).."/"..host).."/"..datastore); end
return format("%s/%s/%s/%s.%s", data_path, host, datastore, username, ext);
- elseif host then
+ else
if create then mkdir(mkdir(data_path).."/"..host); end
return format("%s/%s/%s.%s", data_path, host, datastore, ext);
- else
- if create then mkdir(data_path); end
- return format("%s/%s.%s", data_path, datastore, ext);
end
end
function load(username, host, datastore)
- local data, ret = loadfile(getpath(username, host, datastore));
+ local data, ret = envloadfile(getpath(username, host, datastore), {});
if not data then
local mode = lfs.attributes(getpath(username, host, datastore), "mode");
if not mode then
- log("debug", "Failed to load "..datastore.." storage ('"..ret.."') for user: "..(username or "nil").."@"..(host or "nil"));
+ log("debug", "Assuming empty %s storage ('%s') for user: %s@%s", datastore, ret, username or "nil", host or "nil");
return nil;
else -- file exists, but can't be read
-- TODO more detailed error checking and logging?
- log("error", "Failed to load "..datastore.." storage ('"..ret.."') for user: "..(username or "nil").."@"..(host or "nil"));
+ log("error", "Failed to load %s storage ('%s') for user: %s@%s", datastore, ret, username or "nil", host or "nil");
return nil, "Error reading storage";
end
end
- setfenv(data, {});
+
local success, ret = pcall(data);
if not success then
- log("error", "Unable to load "..datastore.." storage ('"..ret.."') for user: "..(username or "nil").."@"..(host or "nil"));
+ log("error", "Unable to load %s storage ('%s') for user: %s@%s", datastore, ret, username or "nil", host or "nil");
return nil, "Error reading storage";
end
return ret;
end
+local function atomic_store(filename, data)
+ local scratch = filename.."~";
+ local f, ok, msg;
+ repeat
+ f, msg = io_open(scratch, "w");
+ if not f then break end
+
+ ok, msg = f:write(data);
+ if not ok then break end
+
+ ok, msg = f:close();
+ if not ok then break end
+
+ return os_rename(scratch, filename);
+ until false;
+
+ -- Cleanup
+ if f then f:close(); end
+ os_remove(scratch);
+ return nil, msg;
+end
+
+if prosody.platform ~= "posix" then
+ -- os.rename does not overwrite existing files on Windows
+ -- TODO We could use Transactional NTFS on Vista and above
+ function atomic_store(filename, data)
+ local f, err = io_open(filename, "w");
+ if not f then return f, err; end
+ local ok, msg = f:write(data);
+ if not ok then f:close(); return ok, msg; end
+ return f:close();
+ end
+end
+
function store(username, host, datastore, data)
if not data then
data = {};
@@ -142,20 +186,26 @@ function store(username, host, datastore, data)
end
-- save the datastore
- local f, msg = io_open(getpath(username, host, datastore, nil, true), "w+");
- if not f then
- log("error", "Unable to write to "..datastore.." storage ('"..msg.."') for user: "..(username or "nil").."@"..(host or "nil"));
- return nil, "Error saving to storage";
- end
- f:write("return ");
- append(f, data);
- f:close();
- if next(data) == nil then -- try to delete empty datastore
- log("debug", "Removing empty %s datastore for user %s@%s", datastore, username or "nil", host or "nil");
- os_remove(getpath(username, host, datastore));
- end
- -- we write data even when we are deleting because lua doesn't have a
- -- platform independent way of checking for non-exisitng files
+ local d = "return " .. serialize(data) .. ";\n";
+ local mkdir_cache_cleared;
+ repeat
+ local ok, msg = atomic_store(getpath(username, host, datastore, nil, true), d);
+ if not ok then
+ if not mkdir_cache_cleared then -- We may need to recreate a removed directory
+ _mkdir = {};
+ mkdir_cache_cleared = true;
+ else
+ log("error", "Unable to write to %s storage ('%s') for user: %s@%s", datastore, msg, username or "nil", host or "nil");
+ return nil, "Error saving to storage";
+ end
+ end
+ if next(data) == nil then -- try to delete empty datastore
+ log("debug", "Removing empty %s datastore for user %s@%s", datastore, username or "nil", host or "nil");
+ os_remove(getpath(username, host, datastore));
+ end
+ -- we write data even when we are deleting because lua doesn't have a
+ -- platform independent way of checking for non-exisitng files
+ until ok;
return true;
end
@@ -163,14 +213,24 @@ function list_append(username, host, datastore, data)
if not data then return; end
if callback(username, host, datastore) == false then return true; end
-- save the datastore
- local f, msg = io_open(getpath(username, host, datastore, "list", true), "a+");
+ local f, msg = io_open(getpath(username, host, datastore, "list", true), "r+");
+ if not f then
+ f, msg = io_open(getpath(username, host, datastore, "list", true), "w");
+ end
if not f then
- log("error", "Unable to write to "..datastore.." storage ('"..msg.."') for user: "..(username or "nil").."@"..(host or "nil"));
+ log("error", "Unable to write to %s storage ('%s') for user: %s@%s", datastore, msg, username or "nil", host or "nil");
return;
end
- f:write("item(");
- append(f, data);
- f:write(");\n");
+ local data = "item(" .. serialize(data) .. ");\n";
+ local pos = f:seek("end");
+ local ok, msg = fallocate(f, pos, #data);
+ f:seek("set", pos);
+ if ok then
+ f:write(data);
+ else
+ log("error", "Unable to write to %s storage ('%s') for user: %s@%s", datastore, msg, username or "nil", host or "nil");
+ return ok, msg;
+ end
f:close();
return true;
end
@@ -181,17 +241,15 @@ function list_store(username, host, datastore, data)
end
if callback(username, host, datastore) == false then return true; end
-- save the datastore
- local f, msg = io_open(getpath(username, host, datastore, "list", true), "w+");
- if not f then
- log("error", "Unable to write to "..datastore.." storage ('"..msg.."') for user: "..(username or "nil").."@"..(host or "nil"));
- return;
+ local d = {};
+ for _, item in ipairs(data) do
+ d[#d+1] = "item(" .. serialize(item) .. ");\n";
end
- for _, d in ipairs(data) do
- f:write("item(");
- append(f, d);
- f:write(");\n");
+ local ok, msg = atomic_store(getpath(username, host, datastore, "list", true), t_concat(d));
+ if not ok then
+ log("error", "Unable to write to %s storage ('%s') for user: %s@%s", datastore, msg, username or "nil", host or "nil");
+ return;
end
- f:close();
if next(data) == nil then -- try to delete empty datastore
log("debug", "Removing empty %s datastore for user %s@%s", datastore, username or "nil", host or "nil");
os_remove(getpath(username, host, datastore, "list"));
@@ -202,26 +260,108 @@ function list_store(username, host, datastore, data)
end
function list_load(username, host, datastore)
- local data, ret = loadfile(getpath(username, host, datastore, "list"));
+ local items = {};
+ local data, ret = envloadfile(getpath(username, host, datastore, "list"), {item = function(i) t_insert(items, i); end});
if not data then
local mode = lfs.attributes(getpath(username, host, datastore, "list"), "mode");
if not mode then
- log("debug", "Failed to load "..datastore.." storage ('"..ret.."') for user: "..(username or "nil").."@"..(host or "nil"));
+ log("debug", "Assuming empty %s storage ('%s') for user: %s@%s", datastore, ret, username or "nil", host or "nil");
return nil;
else -- file exists, but can't be read
-- TODO more detailed error checking and logging?
- log("error", "Failed to load "..datastore.." storage ('"..ret.."') for user: "..(username or "nil").."@"..(host or "nil"));
+ log("error", "Failed to load %s storage ('%s') for user: %s@%s", datastore, ret, username or "nil", host or "nil");
return nil, "Error reading storage";
end
end
- local items = {};
- setfenv(data, {item = function(i) t_insert(items, i); end});
+
local success, ret = pcall(data);
if not success then
- log("error", "Unable to load "..datastore.." storage ('"..ret.."') for user: "..(username or "nil").."@"..(host or "nil"));
+ log("error", "Unable to load %s storage ('%s') for user: %s@%s", datastore, ret, username or "nil", host or "nil");
return nil, "Error reading storage";
end
return items;
end
+local type_map = {
+ keyval = "dat";
+ list = "list";
+}
+
+function users(host, store, typ)
+ typ = type_map[typ or "keyval"];
+ local store_dir = format("%s/%s/%s", data_path, encode(host), store);
+
+ local mode, err = lfs.attributes(store_dir, "mode");
+ if not mode then
+ return function() log("debug", err or (store_dir .. " does not exist")) end
+ end
+ local next, state = lfs.dir(store_dir);
+ return function(state)
+ for node in next, state do
+ local file, ext = node:match("^(.*)%.([dalist]+)$");
+ if file and ext == typ then
+ return decode(file);
+ end
+ end
+ end, state;
+end
+
+function stores(username, host, typ)
+ typ = type_map[typ or "keyval"];
+ local store_dir = format("%s/%s/", data_path, encode(host));
+
+ local mode, err = lfs.attributes(store_dir, "mode");
+ if not mode then
+ return function() log("debug", err or (store_dir .. " does not exist")) end
+ end
+ local next, state = lfs.dir(store_dir);
+ return function(state)
+ for node in next, state do
+ if not node:match"^%." then
+ if username == true then
+ if lfs.attributes(store_dir..node, "mode") == "directory" then
+ return decode(node);
+ end
+ elseif username then
+ local store = decode(node)
+ if lfs.attributes(getpath(username, host, store, typ), "mode") then
+ return store;
+ end
+ elseif lfs.attributes(node, "mode") == "file" then
+ local file, ext = node:match("^(.*)%.([dalist]+)$");
+ if ext == typ then
+ return decode(file)
+ end
+ end
+ end
+ end
+ end, state;
+end
+
+local function do_remove(path)
+ local ok, err = os_remove(path);
+ if not ok and lfs.attributes(path, "mode") then
+ return ok, err;
+ end
+ return true
+end
+
+function purge(username, host)
+ local host_dir = format("%s/%s/", data_path, encode(host));
+ local errs = {};
+ for file in lfs.dir(host_dir) do
+ if lfs.attributes(host_dir..file, "mode") == "directory" then
+ local store = decode(file);
+ local ok, err = do_remove(getpath(username, host, store));
+ if not ok then errs[#errs+1] = err; end
+
+ local ok, err = do_remove(getpath(username, host, store, "list"));
+ if not ok then errs[#errs+1] = err; end
+ end
+ end
+ return #errs == 0, t_concat(errs, ", ");
+end
+
+_M.path_decode = decode;
+_M.path_encode = encode;
return _M;
diff --git a/util/datetime.lua b/util/datetime.lua
index 301a49a5..dd596527 100644
--- a/util/datetime.lua
+++ b/util/datetime.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -36,7 +36,7 @@ end
function parse(s)
if s then
local year, month, day, hour, min, sec, tzd;
- year, month, day, hour, min, sec, tzd = s:match("^(%d%d%d%d)-?(%d%d)-?(%d%d)T(%d%d):(%d%d):(%d%d)%.?%d*([Z+%-].*)$");
+ year, month, day, hour, min, sec, tzd = s:match("^(%d%d%d%d)%-?(%d%d)%-?(%d%d)T(%d%d):(%d%d):(%d%d)%.?%d*([Z+%-]?.*)$");
if year then
local time_offset = os_difftime(os_time(os_date("*t")), os_time(os_date("!*t"))); -- to deal with local timezone
local tzd_offset = 0;
@@ -49,7 +49,7 @@ function parse(s)
if sign == "-" then tzd_offset = -tzd_offset; end
end
sec = (sec + time_offset) - tzd_offset;
- return os_time({year=year, month=month, day=day, hour=hour, min=min, sec=sec});
+ return os_time({year=year, month=month, day=day, hour=hour, min=min, sec=sec, isdst=false});
end
end
end
diff --git a/util/debug.lua b/util/debug.lua
new file mode 100644
index 00000000..91f691e1
--- /dev/null
+++ b/util/debug.lua
@@ -0,0 +1,199 @@
+-- Variables ending with these names will not
+-- have their values printed ('password' includes
+-- 'new_password', etc.)
+local censored_names = {
+ password = true;
+ passwd = true;
+ pass = true;
+ pwd = true;
+};
+local optimal_line_length = 65;
+
+local termcolours = require "util.termcolours";
+local getstring = termcolours.getstring;
+local styles;
+do
+ _ = termcolours.getstyle;
+ styles = {
+ boundary_padding = _("bright");
+ filename = _("bright", "blue");
+ level_num = _("green");
+ funcname = _("yellow");
+ location = _("yellow");
+ };
+end
+module("debugx", package.seeall);
+
+function get_locals_table(thread, level)
+ local locals = {};
+ for local_num = 1, math.huge do
+ local name, value;
+ if thread then
+ name, value = debug.getlocal(thread, level, local_num);
+ else
+ name, value = debug.getlocal(level+1, local_num);
+ end
+ if not name then break; end
+ table.insert(locals, { name = name, value = value });
+ end
+ return locals;
+end
+
+function get_upvalues_table(func)
+ local upvalues = {};
+ if func then
+ for upvalue_num = 1, math.huge do
+ local name, value = debug.getupvalue(func, upvalue_num);
+ if not name then break; end
+ table.insert(upvalues, { name = name, value = value });
+ end
+ end
+ return upvalues;
+end
+
+function string_from_var_table(var_table, max_line_len, indent_str)
+ local var_string = {};
+ local col_pos = 0;
+ max_line_len = max_line_len or math.huge;
+ indent_str = "\n"..(indent_str or "");
+ for _, var in ipairs(var_table) do
+ local name, value = var.name, var.value;
+ if name:sub(1,1) ~= "(" then
+ if type(value) == "string" then
+ if censored_names[name:match("%a+$")] then
+ value = "<hidden>";
+ else
+ value = ("%q"):format(value);
+ end
+ else
+ value = tostring(value);
+ end
+ if #value > max_line_len then
+ value = value:sub(1, max_line_len-3).."…";
+ end
+ local str = ("%s = %s"):format(name, tostring(value));
+ col_pos = col_pos + #str;
+ if col_pos > max_line_len then
+ table.insert(var_string, indent_str);
+ col_pos = 0;
+ end
+ table.insert(var_string, str);
+ end
+ end
+ if #var_string == 0 then
+ return nil;
+ else
+ return "{ "..table.concat(var_string, ", "):gsub(indent_str..", ", indent_str).." }";
+ end
+end
+
+function get_traceback_table(thread, start_level)
+ local levels = {};
+ for level = start_level, math.huge do
+ local info;
+ if thread then
+ info = debug.getinfo(thread, level);
+ else
+ info = debug.getinfo(level+1);
+ end
+ if not info then break; end
+
+ levels[(level-start_level)+1] = {
+ level = level;
+ info = info;
+ locals = get_locals_table(thread, level+(thread and 0 or 1));
+ upvalues = get_upvalues_table(info.func);
+ };
+ end
+ return levels;
+end
+
+function traceback(...)
+ local ok, ret = pcall(_traceback, ...);
+ if not ok then
+ return "Error in error handling: "..ret;
+ end
+ return ret;
+end
+
+local function build_source_boundary_marker(last_source_desc)
+ local padding = string.rep("-", math.floor(((optimal_line_length - 6) - #last_source_desc)/2));
+ return getstring(styles.boundary_padding, "v"..padding).." "..getstring(styles.filename, last_source_desc).." "..getstring(styles.boundary_padding, padding..(#last_source_desc%2==0 and "-v" or "v "));
+end
+
+function _traceback(thread, message, level)
+
+ -- Lua manual says: debug.traceback ([thread,] [message [, level]])
+ -- I fathom this to mean one of:
+ -- ()
+ -- (thread)
+ -- (message, level)
+ -- (thread, message, level)
+
+ if thread == nil then -- Defaults
+ thread, message, level = coroutine.running(), message, level;
+ elseif type(thread) == "string" then
+ thread, message, level = coroutine.running(), thread, message;
+ elseif type(thread) ~= "thread" then
+ return nil; -- debug.traceback() does this
+ end
+
+ level = level or 0;
+
+ message = message and (message.."\n") or "";
+
+ -- +3 counts for this function, and the pcall() and wrapper above us, the +1... I don't know.
+ local levels = get_traceback_table(thread, level+(thread == nil and 4 or 0));
+
+ local last_source_desc;
+
+ local lines = {};
+ for nlevel, level in ipairs(levels) do
+ local info = level.info;
+ local line = "...";
+ local func_type = info.namewhat.." ";
+ local source_desc = (info.short_src == "[C]" and "C code") or info.short_src or "Unknown";
+ if func_type == " " then func_type = ""; end;
+ if info.short_src == "[C]" then
+ line = "[ C ] "..func_type.."C function "..getstring(styles.location, (info.name and ("%q"):format(info.name) or "(unknown name)"));
+ elseif info.what == "main" then
+ line = "[Lua] "..getstring(styles.location, info.short_src.." line "..info.currentline);
+ else
+ local name = info.name or " ";
+ if name ~= " " then
+ name = ("%q"):format(name);
+ end
+ if func_type == "global " or func_type == "local " then
+ func_type = func_type.."function ";
+ end
+ line = "[Lua] "..getstring(styles.location, info.short_src.." line "..info.currentline).." in "..func_type..getstring(styles.funcname, name).." (defined on line "..info.linedefined..")";
+ end
+ if source_desc ~= last_source_desc then -- Venturing into a new source, add marker for previous
+ last_source_desc = source_desc;
+ table.insert(lines, "\t "..build_source_boundary_marker(last_source_desc));
+ end
+ nlevel = nlevel-1;
+ table.insert(lines, "\t"..(nlevel==0 and ">" or " ")..getstring(styles.level_num, "("..nlevel..") ")..line);
+ local npadding = (" "):rep(#tostring(nlevel));
+ if level.locals then
+ local locals_str = string_from_var_table(level.locals, optimal_line_length, "\t "..npadding);
+ if locals_str then
+ table.insert(lines, "\t "..npadding.."Locals: "..locals_str);
+ end
+ end
+ local upvalues_str = string_from_var_table(level.upvalues, optimal_line_length, "\t "..npadding);
+ if upvalues_str then
+ table.insert(lines, "\t "..npadding.."Upvals: "..upvalues_str);
+ end
+ end
+
+-- table.insert(lines, "\t "..build_source_boundary_marker(last_source_desc));
+
+ return message.."stack traceback:\n"..table.concat(lines, "\n");
+end
+
+function use()
+ debug.traceback = traceback;
+end
+
+return _M;
diff --git a/util/dependencies.lua b/util/dependencies.lua
index 9371521c..109a3332 100644
--- a/util/dependencies.lua
+++ b/util/dependencies.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -35,11 +35,24 @@ function missingdep(name, sources, msg)
print("");
end
+-- COMPAT w/pre-0.8 Debian: The Debian config file used to use
+-- util.ztact, which has been removed from Prosody in 0.8. This
+-- is to log an error for people who still use it, so they can
+-- update their configs.
+package.preload["util.ztact"] = function ()
+ if not package.loaded["core.loggingmanager"] then
+ error("util.ztact has been removed from Prosody and you need to fix your config "
+ .."file. More information can be found at http://prosody.im/doc/packagers#ztact", 0);
+ else
+ error("module 'util.ztact' has been deprecated in Prosody 0.8.");
+ end
+end;
+
function check_dependencies()
local fatal;
-
+
local lxp = softreq "lxp"
-
+
if not lxp then
missingdep("luaexpat", {
["Debian/Ubuntu"] = "sudo apt-get install liblua5.1-expat0";
@@ -48,9 +61,9 @@ function check_dependencies()
});
fatal = true;
end
-
+
local socket = softreq "socket"
-
+
if not socket then
missingdep("luasocket", {
["Debian/Ubuntu"] = "sudo apt-get install liblua5.1-socket2";
@@ -59,7 +72,7 @@ function check_dependencies()
});
fatal = true;
end
-
+
local lfs, err = softreq "lfs"
if not lfs then
missingdep("luafilesystem", {
@@ -69,9 +82,9 @@ function check_dependencies()
});
fatal = true;
end
-
+
local ssl = softreq "ssl"
-
+
if not ssl then
missingdep("LuaSec", {
["Debian/Ubuntu"] = "http://prosody.im/download/start#debian_and_ubuntu";
@@ -79,7 +92,7 @@ function check_dependencies()
["Source"] = "http://www.inf.puc-rio.br/~brunoos/luasec/";
}, "SSL/TLS support will not be available");
end
-
+
local encodings, err = softreq "util.encodings"
if not encodings then
if err:match("not found") then
@@ -123,6 +136,14 @@ function log_warnings()
log("error", "This version of LuaSec contains a known bug that causes disconnects, see http://prosody.im/doc/depends");
end
end
+ if lxp then
+ if not pcall(lxp.new, { StartDoctypeDecl = false }) then
+ log("error", "The version of LuaExpat on your system leaves Prosody "
+ .."vulnerable to denial-of-service attacks. You should upgrade to "
+ .."LuaExpat 1.1.1 or higher as soon as possible. See "
+ .."http://prosody.im/doc/depends#luaexpat for more information.");
+ end
+ end
end
return _M;
diff --git a/util/envload.lua b/util/envload.lua
new file mode 100644
index 00000000..53e28348
--- /dev/null
+++ b/util/envload.lua
@@ -0,0 +1,34 @@
+-- Prosody IM
+-- Copyright (C) 2008-2011 Florian Zeitz
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+local load, loadstring, loadfile, setfenv = load, loadstring, loadfile, setfenv;
+local envload;
+local envloadfile;
+
+if setfenv then
+ function envload(code, source, env)
+ local f, err = loadstring(code, source);
+ if f and env then setfenv(f, env); end
+ return f, err;
+ end
+
+ function envloadfile(file, env)
+ local f, err = loadfile(file);
+ if f and env then setfenv(f, env); end
+ return f, err;
+ end
+else
+ function envload(code, source, env)
+ return load(code, source, nil, env);
+ end
+
+ function envloadfile(file, env)
+ return loadfile(file, nil, env);
+ end
+end
+
+return { envload = envload, envloadfile = envloadfile };
diff --git a/util/events.lua b/util/events.lua
index 412acccd..40ca3913 100644
--- a/util/events.lua
+++ b/util/events.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -60,11 +60,11 @@ function new()
remove_handler(event, handler);
end
end;
- local function fire_event(event, ...)
- local h = handlers[event];
+ local function fire_event(event_name, event_data)
+ local h = handlers[event_name];
if h then
for i=1,#h do
- local ret = h[i](...);
+ local ret = h[i](event_data);
if ret ~= nil then return ret; end
end
end
diff --git a/util/filters.lua b/util/filters.lua
index d143666b..8a470011 100644
--- a/util/filters.lua
+++ b/util/filters.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -16,7 +16,7 @@ function initialize(session)
if not session.filters then
local filters = {};
session.filters = filters;
-
+
function session.filter(type, data)
local filter_list = filters[type];
if filter_list then
@@ -28,11 +28,11 @@ function initialize(session)
return data;
end
end
-
+
for i=1,#new_filter_hooks do
new_filter_hooks[i](session);
end
-
+
return session.filter;
end
@@ -40,20 +40,20 @@ function add_filter(session, type, callback, priority)
if not session.filters then
initialize(session);
end
-
+
local filter_list = session.filters[type];
if not filter_list then
filter_list = {};
session.filters[type] = filter_list;
end
-
+
priority = priority or 0;
-
+
local i = 0;
repeat
i = i + 1;
until not filter_list[i] or filter_list[filter_list[i]] >= priority;
-
+
t_insert(filter_list, i, callback);
filter_list[callback] = priority;
end
diff --git a/util/helpers.lua b/util/helpers.lua
index 11356176..437a920c 100644
--- a/util/helpers.lua
+++ b/util/helpers.lua
@@ -1,17 +1,27 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
+local debug = require "util.debug";
+
module("helpers", package.seeall);
-- Helper functions for debugging
local log = require "util.logger".init("util.debug");
+function log_host_events(host)
+ return log_events(prosody.hosts[host].events, host);
+end
+
+function revert_log_host_events(host)
+ return revert_log_events(prosody.hosts[host].events);
+end
+
function log_events(events, name, logger)
local f = events.fire_event;
if not f then
@@ -28,7 +38,36 @@ function log_events(events, name, logger)
end
function revert_log_events(events)
- events.fire_event, events[events.fire_event] = events[events.fire_event], nil; -- :)
+ events.fire_event, events[events.fire_event] = events[events.fire_event], nil; -- :))
+end
+
+function show_events(events, specific_event)
+ local event_handlers = events._handlers;
+ local events_array = {};
+ local event_handler_arrays = {};
+ for event in pairs(events._event_map) do
+ local handlers = event_handlers[event];
+ if handlers and (event == specific_event or not specific_event) then
+ table.insert(events_array, event);
+ local handler_strings = {};
+ for i, handler in ipairs(handlers) do
+ local upvals = debug.string_from_var_table(debug.get_upvalues_table(handler));
+ handler_strings[i] = " "..i..": "..tostring(handler)..(upvals and ("\n "..upvals) or "");
+ end
+ event_handler_arrays[event] = handler_strings;
+ end
+ end
+ table.sort(events_array);
+ local i = 1;
+ while i <= #events_array do
+ local handlers = event_handler_arrays[events_array[i]];
+ for j=#handlers, 1, -1 do
+ table.insert(events_array, i+1, handlers[j]);
+ end
+ if i > 1 then events_array[i] = "\n"..events_array[i]; end
+ i = i + #handlers + 1
+ end
+ return table.concat(events_array, "\n");
end
function get_upvalue(f, get_name)
diff --git a/util/hmac.lua b/util/hmac.lua
index 6df6986e..2c4cc6ef 100644
--- a/util/hmac.lua
+++ b/util/hmac.lua
@@ -1,69 +1,15 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
-local hashes = require "util.hashes"
-
-local s_char = string.char;
-local s_gsub = string.gsub;
-local s_rep = string.rep;
-
-module "hmac"
-
-local xor_map = {0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;1;0;3;2;5;4;7;6;9;8;11;10;13;12;15;14;2;3;0;1;6;7;4;5;10;11;8;9;14;15;12;13;3;2;1;0;7;6;5;4;11;10;9;8;15;14;13;12;4;5;6;7;0;1;2;3;12;13;14;15;8;9;10;11;5;4;7;6;1;0;3;2;13;12;15;14;9;8;11;10;6;7;4;5;2;3;0;1;14;15;12;13;10;11;8;9;7;6;5;4;3;2;1;0;15;14;13;12;11;10;9;8;8;9;10;11;12;13;14;15;0;1;2;3;4;5;6;7;9;8;11;10;13;12;15;14;1;0;3;2;5;4;7;6;10;11;8;9;14;15;12;13;2;3;0;1;6;7;4;5;11;10;9;8;15;14;13;12;3;2;1;0;7;6;5;4;12;13;14;15;8;9;10;11;4;5;6;7;0;1;2;3;13;12;15;14;9;8;11;10;5;4;7;6;1;0;3;2;14;15;12;13;10;11;8;9;6;7;4;5;2;3;0;1;15;14;13;12;11;10;9;8;7;6;5;4;3;2;1;0;};
-local function xor(x, y)
- local lowx, lowy = x % 16, y % 16;
- local hix, hiy = (x - lowx) / 16, (y - lowy) / 16;
- local lowr, hir = xor_map[lowx * 16 + lowy + 1], xor_map[hix * 16 + hiy + 1];
- local r = hir * 16 + lowr;
- return r;
-end
-local opadc, ipadc = s_char(0x5c), s_char(0x36);
-local ipad_map = {};
-local opad_map = {};
-for i=0,255 do
- ipad_map[s_char(i)] = s_char(xor(0x36, i));
- opad_map[s_char(i)] = s_char(xor(0x5c, i));
-end
-
---[[
-key
- the key to use in the hash
-message
- the message to hash
-hash
- the hash function
-blocksize
- the blocksize for the hash function in bytes
-hex
- return raw hash or hexadecimal string
---]]
-function hmac(key, message, hash, blocksize, hex)
- if #key > blocksize then
- key = hash(key)
- end
+-- COMPAT: Only for external pre-0.9 modules
- local padding = blocksize - #key;
- local ipad = s_gsub(key, ".", ipad_map)..s_rep(ipadc, padding);
- local opad = s_gsub(key, ".", opad_map)..s_rep(opadc, padding);
-
- return hash(opad..hash(ipad..message), hex)
-end
-
-function md5(key, message, hex)
- return hmac(key, message, hashes.md5, 64, hex)
-end
-
-function sha1(key, message, hex)
- return hmac(key, message, hashes.sha1, 64, hex)
-end
-
-function sha256(key, message, hex)
- return hmac(key, message, hashes.sha256, 64, hex)
-end
+local hashes = require "util.hashes"
-return _M
+return { md5 = hashes.hmac_md5,
+ sha1 = hashes.hmac_sha1,
+ sha256 = hashes.hmac_sha256 };
diff --git a/util/http.lua b/util/http.lua
new file mode 100644
index 00000000..f7259920
--- /dev/null
+++ b/util/http.lua
@@ -0,0 +1,64 @@
+-- Prosody IM
+-- Copyright (C) 2013 Florian Zeitz
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+local format, char = string.format, string.char;
+local pairs, ipairs, tonumber = pairs, ipairs, tonumber;
+local t_insert, t_concat = table.insert, table.concat;
+
+local function urlencode(s)
+ return s and (s:gsub("[^a-zA-Z0-9.~_-]", function (c) return format("%%%02x", c:byte()); end));
+end
+local function urldecode(s)
+ return s and (s:gsub("%%(%x%x)", function (c) return char(tonumber(c,16)); end));
+end
+
+local function _formencodepart(s)
+ return s and (s:gsub("%W", function (c)
+ if c ~= " " then
+ return format("%%%02x", c:byte());
+ else
+ return "+";
+ end
+ end));
+end
+
+local function formencode(form)
+ local result = {};
+ if form[1] then -- Array of ordered { name, value }
+ for _, field in ipairs(form) do
+ t_insert(result, _formencodepart(field.name).."=".._formencodepart(field.value));
+ end
+ else -- Unordered map of name -> value
+ for name, value in pairs(form) do
+ t_insert(result, _formencodepart(name).."=".._formencodepart(value));
+ end
+ end
+ return t_concat(result, "&");
+end
+
+local function formdecode(s)
+ if not s:match("=") then return urldecode(s); end
+ local r = {};
+ for k, v in s:gmatch("([^=&]*)=([^&]*)") do
+ k, v = k:gsub("%+", "%%20"), v:gsub("%+", "%%20");
+ k, v = urldecode(k), urldecode(v);
+ t_insert(r, { name = k, value = v });
+ r[k] = v;
+ end
+ return r;
+end
+
+local function contains_token(field, token)
+ field = ","..field:gsub("[ \t]", ""):lower()..",";
+ return field:find(","..token:lower()..",", 1, true) ~= nil;
+end
+
+return {
+ urlencode = urlencode, urldecode = urldecode;
+ formencode = formencode, formdecode = formdecode;
+ contains_token = contains_token;
+};
diff --git a/util/httpstream.lua b/util/httpstream.lua
deleted file mode 100644
index bdc3fce7..00000000
--- a/util/httpstream.lua
+++ /dev/null
@@ -1,137 +0,0 @@
-
-local coroutine = coroutine;
-local tonumber = tonumber;
-
-local deadroutine = coroutine.create(function() end);
-coroutine.resume(deadroutine);
-
-module("httpstream")
-
-local function parser(success_cb, parser_type, options_cb)
- local data = coroutine.yield();
- local function readline()
- local pos = data:find("\r\n", nil, true);
- while not pos do
- data = data..coroutine.yield();
- pos = data:find("\r\n", nil, true);
- end
- local r = data:sub(1, pos-1);
- data = data:sub(pos+2);
- return r;
- end
- local function readlength(n)
- while #data < n do
- data = data..coroutine.yield();
- end
- local r = data:sub(1, n);
- data = data:sub(n + 1);
- return r;
- end
- local function readheaders()
- local headers = {}; -- read headers
- while true do
- local line = readline();
- if line == "" then break; end -- headers done
- local key, val = line:match("^([^%s:]+): *(.*)$");
- if not key then coroutine.yield("invalid-header-line"); end -- TODO handle multi-line and invalid headers
- key = key:lower();
- headers[key] = headers[key] and headers[key]..","..val or val;
- end
- return headers;
- end
-
- if not parser_type or parser_type == "server" then
- while true do
- -- read status line
- local status_line = readline();
- local method, path, httpversion = status_line:match("^(%S+)%s+(%S+)%s+HTTP/(%S+)$");
- if not method then coroutine.yield("invalid-status-line"); end
- path = path:gsub("^//+", "/"); -- TODO parse url more
- local headers = readheaders();
-
- -- read body
- local len = tonumber(headers["content-length"]);
- len = len or 0; -- TODO check for invalid len
- local body = readlength(len);
-
- success_cb({
- method = method;
- path = path;
- httpversion = httpversion;
- headers = headers;
- body = body;
- });
- end
- elseif parser_type == "client" then
- while true do
- -- read status line
- local status_line = readline();
- local httpversion, status_code, reason_phrase = status_line:match("^HTTP/(%S+)%s+(%d%d%d)%s+(.*)$");
- status_code = tonumber(status_code);
- if not status_code then coroutine.yield("invalid-status-line"); end
- local headers = readheaders();
-
- -- read body
- local have_body = not
- ( (options_cb and options_cb().method == "HEAD")
- or (status_code == 204 or status_code == 304 or status_code == 301)
- or (status_code >= 100 and status_code < 200) );
-
- local body;
- if have_body then
- local len = tonumber(headers["content-length"]);
- if headers["transfer-encoding"] == "chunked" then
- body = "";
- while true do
- local chunk_size = readline():match("^%x+");
- if not chunk_size then coroutine.yield("invalid-chunk-size"); end
- chunk_size = tonumber(chunk_size, 16)
- if chunk_size == 0 then break; end
- body = body..readlength(chunk_size);
- if readline() ~= "" then coroutine.yield("invalid-chunk-ending"); end
- end
- local trailers = readheaders();
- elseif len then -- TODO check for invalid len
- body = readlength(len);
- else -- read to end
- repeat
- local newdata = coroutine.yield();
- data = data..newdata;
- until newdata == "";
- body, data = data, "";
- end
- end
-
- success_cb({
- code = status_code;
- httpversion = httpversion;
- headers = headers;
- body = body;
- -- COMPAT the properties below are deprecated
- responseversion = httpversion;
- responseheaders = headers;
- });
- end
- else coroutine.yield("unknown-parser-type"); end
-end
-
-function new(success_cb, error_cb, parser_type, options_cb)
- local co = coroutine.create(parser);
- coroutine.resume(co, success_cb, parser_type, options_cb)
- return {
- feed = function(self, data)
- if not data then
- if parser_type == "client" then coroutine.resume(co, ""); end
- co = deadroutine;
- return error_cb();
- end
- local success, result = coroutine.resume(co, data);
- if result then
- co = deadroutine;
- return error_cb(result);
- end
- end;
- };
-end
-
-return _M;
diff --git a/util/import.lua b/util/import.lua
index 81401e8b..174da0ca 100644
--- a/util/import.lua
+++ b/util/import.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
diff --git a/util/ip.lua b/util/ip.lua
new file mode 100644
index 00000000..d0ae07eb
--- /dev/null
+++ b/util/ip.lua
@@ -0,0 +1,244 @@
+-- Prosody IM
+-- Copyright (C) 2008-2011 Florian Zeitz
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+local ip_methods = {};
+local ip_mt = { __index = function (ip, key) return (ip_methods[key])(ip); end,
+ __tostring = function (ip) return ip.addr; end,
+ __eq = function (ipA, ipB) return ipA.addr == ipB.addr; end};
+local hex2bits = { ["0"] = "0000", ["1"] = "0001", ["2"] = "0010", ["3"] = "0011", ["4"] = "0100", ["5"] = "0101", ["6"] = "0110", ["7"] = "0111", ["8"] = "1000", ["9"] = "1001", ["A"] = "1010", ["B"] = "1011", ["C"] = "1100", ["D"] = "1101", ["E"] = "1110", ["F"] = "1111" };
+
+local function new_ip(ipStr, proto)
+ if not proto then
+ local sep = ipStr:match("^%x+(.)");
+ if sep == ":" or (not(sep) and ipStr:sub(1,1) == ":") then
+ proto = "IPv6"
+ elseif sep == "." then
+ proto = "IPv4"
+ end
+ if not proto then
+ return nil, "invalid address";
+ end
+ elseif proto ~= "IPv4" and proto ~= "IPv6" then
+ return nil, "invalid protocol";
+ end
+ if proto == "IPv6" and ipStr:find('.', 1, true) then
+ local changed;
+ ipStr, changed = ipStr:gsub(":(%d+)%.(%d+)%.(%d+)%.(%d+)$", function(a,b,c,d)
+ return (":%04X:%04X"):format(a*256+b,c*256+d);
+ end);
+ if changed ~= 1 then return nil, "invalid-address"; end
+ end
+
+ return setmetatable({ addr = ipStr, proto = proto }, ip_mt);
+end
+
+local function toBits(ip)
+ local result = "";
+ local fields = {};
+ if ip.proto == "IPv4" then
+ ip = ip.toV4mapped;
+ end
+ ip = (ip.addr):upper();
+ ip:gsub("([^:]*):?", function (c) fields[#fields + 1] = c end);
+ if not ip:match(":$") then fields[#fields] = nil; end
+ for i, field in ipairs(fields) do
+ if field:len() == 0 and i ~= 1 and i ~= #fields then
+ for i = 1, 16 * (9 - #fields) do
+ result = result .. "0";
+ end
+ else
+ for i = 1, 4 - field:len() do
+ result = result .. "0000";
+ end
+ for i = 1, field:len() do
+ result = result .. hex2bits[field:sub(i,i)];
+ end
+ end
+ end
+ return result;
+end
+
+local function commonPrefixLength(ipA, ipB)
+ ipA, ipB = toBits(ipA), toBits(ipB);
+ for i = 1, 128 do
+ if ipA:sub(i,i) ~= ipB:sub(i,i) then
+ return i-1;
+ end
+ end
+ return 128;
+end
+
+local function v4scope(ip)
+ local fields = {};
+ ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end);
+ -- Loopback:
+ if fields[1] == 127 then
+ return 0x2;
+ -- Link-local unicast:
+ elseif fields[1] == 169 and fields[2] == 254 then
+ return 0x2;
+ -- Global unicast:
+ else
+ return 0xE;
+ end
+end
+
+local function v6scope(ip)
+ -- Loopback:
+ if ip:match("^[0:]*1$") then
+ return 0x2;
+ -- Link-local unicast:
+ elseif ip:match("^[Ff][Ee][89ABab]") then
+ return 0x2;
+ -- Site-local unicast:
+ elseif ip:match("^[Ff][Ee][CcDdEeFf]") then
+ return 0x5;
+ -- Multicast:
+ elseif ip:match("^[Ff][Ff]") then
+ return tonumber("0x"..ip:sub(4,4));
+ -- Global unicast:
+ else
+ return 0xE;
+ end
+end
+
+local function label(ip)
+ if commonPrefixLength(ip, new_ip("::1", "IPv6")) == 128 then
+ return 0;
+ elseif commonPrefixLength(ip, new_ip("2002::", "IPv6")) >= 16 then
+ return 2;
+ elseif commonPrefixLength(ip, new_ip("2001::", "IPv6")) >= 32 then
+ return 5;
+ elseif commonPrefixLength(ip, new_ip("fc00::", "IPv6")) >= 7 then
+ return 13;
+ elseif commonPrefixLength(ip, new_ip("fec0::", "IPv6")) >= 10 then
+ return 11;
+ elseif commonPrefixLength(ip, new_ip("3ffe::", "IPv6")) >= 16 then
+ return 12;
+ elseif commonPrefixLength(ip, new_ip("::", "IPv6")) >= 96 then
+ return 3;
+ elseif commonPrefixLength(ip, new_ip("::ffff:0:0", "IPv6")) >= 96 then
+ return 4;
+ else
+ return 1;
+ end
+end
+
+local function precedence(ip)
+ if commonPrefixLength(ip, new_ip("::1", "IPv6")) == 128 then
+ return 50;
+ elseif commonPrefixLength(ip, new_ip("2002::", "IPv6")) >= 16 then
+ return 30;
+ elseif commonPrefixLength(ip, new_ip("2001::", "IPv6")) >= 32 then
+ return 5;
+ elseif commonPrefixLength(ip, new_ip("fc00::", "IPv6")) >= 7 then
+ return 3;
+ elseif commonPrefixLength(ip, new_ip("fec0::", "IPv6")) >= 10 then
+ return 1;
+ elseif commonPrefixLength(ip, new_ip("3ffe::", "IPv6")) >= 16 then
+ return 1;
+ elseif commonPrefixLength(ip, new_ip("::", "IPv6")) >= 96 then
+ return 1;
+ elseif commonPrefixLength(ip, new_ip("::ffff:0:0", "IPv6")) >= 96 then
+ return 35;
+ else
+ return 40;
+ end
+end
+
+local function toV4mapped(ip)
+ local fields = {};
+ local ret = "::ffff:";
+ ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end);
+ ret = ret .. ("%02x"):format(fields[1]);
+ ret = ret .. ("%02x"):format(fields[2]);
+ ret = ret .. ":"
+ ret = ret .. ("%02x"):format(fields[3]);
+ ret = ret .. ("%02x"):format(fields[4]);
+ return new_ip(ret, "IPv6");
+end
+
+function ip_methods:toV4mapped()
+ if self.proto ~= "IPv4" then return nil, "No IPv4 address" end
+ local value = toV4mapped(self.addr);
+ self.toV4mapped = value;
+ return value;
+end
+
+function ip_methods:label()
+ local value;
+ if self.proto == "IPv4" then
+ value = label(self.toV4mapped);
+ else
+ value = label(self);
+ end
+ self.label = value;
+ return value;
+end
+
+function ip_methods:precedence()
+ local value;
+ if self.proto == "IPv4" then
+ value = precedence(self.toV4mapped);
+ else
+ value = precedence(self);
+ end
+ self.precedence = value;
+ return value;
+end
+
+function ip_methods:scope()
+ local value;
+ if self.proto == "IPv4" then
+ value = v4scope(self.addr);
+ else
+ value = v6scope(self.addr);
+ end
+ self.scope = value;
+ return value;
+end
+
+function ip_methods:private()
+ local private = self.scope ~= 0xE;
+ if not private and self.proto == "IPv4" then
+ local ip = self.addr;
+ local fields = {};
+ ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end);
+ if fields[1] == 127 or fields[1] == 10 or (fields[1] == 192 and fields[2] == 168)
+ or (fields[1] == 172 and (fields[2] >= 16 or fields[2] <= 32)) then
+ private = true;
+ end
+ end
+ self.private = private;
+ return private;
+end
+
+local function parse_cidr(cidr)
+ local bits;
+ local ip_len = cidr:find("/", 1, true);
+ if ip_len then
+ bits = tonumber(cidr:sub(ip_len+1, -1));
+ cidr = cidr:sub(1, ip_len-1);
+ end
+ return new_ip(cidr), bits;
+end
+
+local function match(ipA, ipB, bits)
+ local common_bits = commonPrefixLength(ipA, ipB);
+ if not bits then
+ return ipA == ipB;
+ end
+ if bits and ipB.proto == "IPv4" then
+ common_bits = common_bits - 96; -- v6 mapped addresses always share these bits
+ end
+ return common_bits >= bits;
+end
+
+return {new_ip = new_ip,
+ commonPrefixLength = commonPrefixLength,
+ parse_cidr = parse_cidr,
+ match=match};
diff --git a/util/iterators.lua b/util/iterators.lua
index dc692d64..aa9c3ec0 100644
--- a/util/iterators.lua
+++ b/util/iterators.lua
@@ -1,15 +1,21 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
--[[ Iterators ]]--
+local it = {};
+
+local t_insert = table.insert;
+local select, unpack, next = select, unpack, next;
+local function pack(...) return { n = select("#", ...), ... }; end
+
-- Reverse an iterator
-function reverse(f, s, var)
+function it.reverse(f, s, var)
local results = {};
-- First call the normal iterator
@@ -17,9 +23,9 @@ function reverse(f, s, var)
local ret = { f(s, var) };
var = ret[1];
if var == nil then break; end
- table.insert(results, 1, ret);
+ t_insert(results, 1, ret);
end
-
+
-- Then return our reverse one
local i,max = 0, #results;
return function (results)
@@ -34,12 +40,12 @@ end
local function _keys_it(t, key)
return (next(t, key));
end
-function keys(t)
+function it.keys(t)
return _keys_it, t;
end
-- Iterate only over values in a table
-function values(t)
+function it.values(t)
local key, val;
return function (t)
key, val = next(t, key);
@@ -48,38 +54,37 @@ function values(t)
end
-- Given an iterator, iterate only over unique items
-function unique(f, s, var)
+function it.unique(f, s, var)
local set = {};
-
+
return function ()
while true do
- local ret = { f(s, var) };
+ local ret = pack(f(s, var));
var = ret[1];
if var == nil then break; end
if not set[var] then
set[var] = true;
- return var;
+ return unpack(ret, 1, ret.n);
end
end
end;
end
--[[ Return the number of items an iterator returns ]]--
-function count(f, s, var)
+function it.count(f, s, var)
local x = 0;
-
+
while true do
- local ret = { f(s, var) };
- var = ret[1];
+ var = f(s, var);
if var == nil then break; end
x = x + 1;
end
-
+
return x;
end
-- Return the first n items an iterator returns
-function head(n, f, s, var)
+function it.head(n, f, s, var)
local c = 0;
return function (s, var)
if c >= n then
@@ -91,7 +96,7 @@ function head(n, f, s, var)
end
-- Skip the first n items an iterator returns
-function skip(n, f, s, var)
+function it.skip(n, f, s, var)
for i=1,n do
var = f(s, var);
end
@@ -99,10 +104,10 @@ function skip(n, f, s, var)
end
-- Return the last n items an iterator returns
-function tail(n, f, s, var)
+function it.tail(n, f, s, var)
local results, count = {}, 0;
while true do
- local ret = { f(s, var) };
+ local ret = pack(f(s, var));
var = ret[1];
if var == nil then break; end
results[(count%n)+1] = ret;
@@ -115,26 +120,52 @@ function tail(n, f, s, var)
return function ()
pos = pos + 1;
if pos > n then return nil; end
- return unpack(results[((count-1+pos)%n)+1]);
+ local ret = results[((count-1+pos)%n)+1];
+ return unpack(ret, 1, ret.n);
end
- --return reverse(head(n, reverse(f, s, var)));
+ --return reverse(head(n, reverse(f, s, var))); -- !
+end
+
+function it.filter(filter, f, s, var)
+ if type(filter) ~= "function" then
+ local filter_value = filter;
+ function filter(x) return x ~= filter_value; end
+ end
+ return function (s, var)
+ local ret;
+ repeat ret = pack(f(s, var));
+ var = ret[1];
+ until var == nil or filter(unpack(ret, 1, ret.n));
+ return unpack(ret, 1, ret.n);
+ end, s, var;
+end
+
+local function _ripairs_iter(t, key) if key > 1 then return key-1, t[key-1]; end end
+function it.ripairs(t)
+ return _ripairs_iter, t, #t+1;
+end
+
+local function _range_iter(max, curr) if curr < max then return curr + 1; end end
+function it.range(x, y)
+ if not y then x, y = 1, x; end -- Default to 1..x if y not given
+ return _range_iter, y, x-1;
end
-- Convert the values returned by an iterator to an array
-function it2array(f, s, var)
+function it.to_array(f, s, var)
local t, var = {};
while true do
var = f(s, var);
if var == nil then break; end
- table.insert(t, var);
+ t_insert(t, var);
end
return t;
end
-- Treat the return of an iterator as key,value pairs,
-- and build a table
-function it2table(f, s, var)
- local t, var = {};
+function it.to_table(f, s, var)
+ local t, var2 = {};
while true do
var, var2 = f(s, var);
if var == nil then break; end
@@ -142,3 +173,5 @@ function it2table(f, s, var)
end
return t;
end
+
+return it;
diff --git a/util/jid.lua b/util/jid.lua
index 069817c6..0d9a864f 100644
--- a/util/jid.lua
+++ b/util/jid.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -13,6 +13,16 @@ local nodeprep = require "util.encodings".stringprep.nodeprep;
local nameprep = require "util.encodings".stringprep.nameprep;
local resourceprep = require "util.encodings".stringprep.resourceprep;
+local escapes = {
+ [" "] = "\\20"; ['"'] = "\\22";
+ ["&"] = "\\26"; ["'"] = "\\27";
+ ["/"] = "\\2f"; [":"] = "\\3a";
+ ["<"] = "\\3c"; [">"] = "\\3e";
+ ["@"] = "\\40"; ["\\"] = "\\5c";
+};
+local unescapes = {};
+for k,v in pairs(escapes) do unescapes[v] = k; end
+
module "jid"
local function _split(jid)
@@ -91,4 +101,7 @@ function compare(jid, acl)
return false
end
+function escape(s) return s and (s:gsub(".", escapes)); end
+function unescape(s) return s and (s:gsub("\\%x%x", unescapes)); end
+
return _M;
diff --git a/util/json.lua b/util/json.lua
index 40939bb4..a8a58afc 100644
--- a/util/json.lua
+++ b/util/json.lua
@@ -1,14 +1,24 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
local type = type;
-local t_insert, t_concat, t_remove = table.insert, table.concat, table.remove;
+local t_insert, t_concat, t_remove, t_sort = table.insert, table.concat, table.remove, table.sort;
local s_char = string.char;
local tostring, tonumber = tostring, tonumber;
local pairs, ipairs = pairs, ipairs;
local next = next;
local error = error;
-local newproxy, getmetatable = newproxy, getmetatable;
+local newproxy, getmetatable, setmetatable = newproxy, getmetatable, setmetatable;
local print = print;
+local has_array, array = pcall(require, "util.array");
+local array_mt = has_array and getmetatable(array()) or {};
+
--module("json")
local json = {};
@@ -29,6 +39,19 @@ for i=0,31 do
if not escapes[ch] then escapes[ch] = ("\\u%.4X"):format(i); end
end
+local function codepoint_to_utf8(code)
+ if code < 0x80 then return s_char(code); end
+ local bits0_6 = code % 64;
+ if code < 0x800 then
+ local bits6_5 = (code - bits0_6) / 64;
+ return s_char(0x80 + 0x40 + bits6_5, 0x80 + bits0_6);
+ end
+ local bits0_12 = code % 4096;
+ local bits6_6 = (bits0_12 - bits0_6) / 64;
+ local bits12_4 = (code - bits0_12) / 4096;
+ return s_char(0x80 + 0x40 + 0x20 + bits12_4, 0x80 + bits6_6, 0x80 + bits0_6);
+end
+
local valid_types = {
number = true,
string = true,
@@ -79,11 +102,25 @@ function tablesave(o, buffer)
if next(__hash) ~= nil or next(hash) ~= nil or next(__array) == nil then
t_insert(buffer, "{");
local mark = #buffer;
- for k,v in pairs(hash) do
- stringsave(k, buffer);
- t_insert(buffer, ":");
- simplesave(v, buffer);
- t_insert(buffer, ",");
+ if buffer.ordered then
+ local keys = {};
+ for k in pairs(hash) do
+ t_insert(keys, k);
+ end
+ t_sort(keys);
+ for _,k in ipairs(keys) do
+ stringsave(k, buffer);
+ t_insert(buffer, ":");
+ simplesave(hash[k], buffer);
+ t_insert(buffer, ",");
+ end
+ else
+ for k,v in pairs(hash) do
+ stringsave(k, buffer);
+ t_insert(buffer, ":");
+ simplesave(v, buffer);
+ t_insert(buffer, ",");
+ end
end
if next(__hash) ~= nil then
t_insert(buffer, "\"__hash\":[");
@@ -116,7 +153,12 @@ function simplesave(o, buffer)
elseif t == "string" then
stringsave(o, buffer);
elseif t == "table" then
- tablesave(o, buffer);
+ local mt = getmetatable(o);
+ if mt == array_mt then
+ arraysave(o, buffer);
+ else
+ tablesave(o, buffer);
+ end
elseif t == "boolean" then
t_insert(buffer, (o and "true" or "false"));
else
@@ -129,214 +171,191 @@ function json.encode(obj)
simplesave(obj, t);
return t_concat(t);
end
+function json.encode_ordered(obj)
+ local t = { ordered = true };
+ simplesave(obj, t);
+ return t_concat(t);
+end
+function json.encode_array(obj)
+ local t = {};
+ arraysave(obj, t);
+ return t_concat(t);
+end
-----------------------------------
-function json.decode(json)
- local pos = 1;
- local current = {};
- local stack = {};
- local ch, peek;
- local function next()
- ch = json:sub(pos, pos);
- pos = pos+1;
- peek = json:sub(pos, pos);
- return ch;
- end
-
- local function skipwhitespace()
- while ch and (ch == "\r" or ch == "\n" or ch == "\t" or ch == " ") do
- next();
+local function _skip_whitespace(json, index)
+ return json:find("[^ \t\r\n]", index) or index; -- no need to check \r\n, we converted those to \t
+end
+local function _fixobject(obj)
+ local __array = obj.__array;
+ if __array then
+ obj.__array = nil;
+ for i,v in ipairs(__array) do
+ t_insert(obj, v);
end
end
- local function skiplinecomment()
- repeat next(); until not(ch) or ch == "\r" or ch == "\n";
- skipwhitespace();
- end
- local function skipstarcomment()
- next(); next(); -- skip '/', '*'
- while peek and ch ~= "*" and peek ~= "/" do next(); end
- if not peek then error("eof in star comment") end
- next(); next(); -- skip '*', '/'
- skipwhitespace();
- end
- local function skipstuff()
- while true do
- skipwhitespace();
- if ch == "/" and peek == "*" then
- skipstarcomment();
- elseif ch == "/" and peek == "*" then
- skiplinecomment();
+ local __hash = obj.__hash;
+ if __hash then
+ obj.__hash = nil;
+ local k;
+ for i,v in ipairs(__hash) do
+ if k ~= nil then
+ obj[k] = v; k = nil;
else
- return;
+ k = v;
end
end
end
-
- local readvalue;
- local function readarray()
- local t = {};
- next(); -- skip '['
- skipstuff();
- if ch == "]" then next(); return t; end
- t_insert(t, readvalue());
- while true do
- skipstuff();
- if ch == "]" then next(); return t; end
- if not ch then error("eof while reading array");
- elseif ch == "," then next();
- elseif ch then error("unexpected character in array, comma expected"); end
- if not ch then error("eof while reading array"); end
- t_insert(t, readvalue());
+ return obj;
+end
+local _readvalue, _readstring;
+local function _readobject(json, index)
+ local o = {};
+ while true do
+ local key, val;
+ index = _skip_whitespace(json, index + 1);
+ if json:byte(index) ~= 0x22 then -- "\""
+ if json:byte(index) == 0x7d then return o, index + 1; end -- "}"
+ return nil, "key expected";
end
+ key, index = _readstring(json, index);
+ if key == nil then return nil, index; end
+ index = _skip_whitespace(json, index);
+ if json:byte(index) ~= 0x3a then return nil, "colon expected"; end -- ":"
+ val, index = _readvalue(json, index + 1);
+ if val == nil then return nil, index; end
+ o[key] = val;
+ index = _skip_whitespace(json, index);
+ local b = json:byte(index);
+ if b == 0x7d then return _fixobject(o), index + 1; end -- "}"
+ if b ~= 0x2c then return nil, "object eof"; end -- ","
end
-
- local function checkandskip(c)
- local x = ch or "eof";
- if x ~= c then error("unexpected "..x..", '"..c.."' expected"); end
- next();
- end
- local function readliteral(lit, val)
- for c in lit:gmatch(".") do
- checkandskip(c);
+end
+local function _readarray(json, index)
+ local a = {};
+ local oindex = index;
+ while true do
+ local val;
+ val, index = _readvalue(json, index + 1);
+ if val == nil then
+ if json:byte(oindex + 1) == 0x5d then return setmetatable(a, array_mt), oindex + 2; end -- "]"
+ return val, index;
end
- return val;
+ t_insert(a, val);
+ index = _skip_whitespace(json, index);
+ local b = json:byte(index);
+ if b == 0x5d then return setmetatable(a, array_mt), index + 1; end -- "]"
+ if b ~= 0x2c then return nil, "array eof"; end -- ","
end
- local function readstring()
- local s = "";
- checkandskip("\"");
- while ch do
- while ch and ch ~= "\\" and ch ~= "\"" do
- s = s..ch; next();
- end
- if ch == "\\" then
- next();
- if unescapes[ch] then
- s = s..unescapes[ch];
- next();
- elseif ch == "u" then
- local seq = "";
- for i=1,4 do
- next();
- if not ch then error("unexpected eof in string"); end
- if not ch:match("[0-9a-fA-F]") then error("invalid unicode escape sequence in string"); end
- seq = seq..ch;
- end
- s = s..s.char(tonumber(seq, 16)); -- FIXME do proper utf-8
- next();
- else error("invalid escape sequence in string"); end
- end
- if ch == "\"" then
- next();
- return s;
- end
- end
- error("eof while reading string");
+end
+local _unescape_error;
+local function _unescape_surrogate_func(x)
+ local lead, trail = tonumber(x:sub(3, 6), 16), tonumber(x:sub(9, 12), 16);
+ local codepoint = lead * 0x400 + trail - 0x35FDC00;
+ local a = codepoint % 64;
+ codepoint = (codepoint - a) / 64;
+ local b = codepoint % 64;
+ codepoint = (codepoint - b) / 64;
+ local c = codepoint % 64;
+ codepoint = (codepoint - c) / 64;
+ return s_char(0xF0 + codepoint, 0x80 + c, 0x80 + b, 0x80 + a);
+end
+local function _unescape_func(x)
+ x = x:match("%x%x%x%x", 3);
+ if x then
+ --if x >= 0xD800 and x <= 0xDFFF then _unescape_error = true; end -- bad surrogate pair
+ return codepoint_to_utf8(tonumber(x, 16));
end
- local function readnumber()
- local s = "";
- if ch == "-" then
- s = s..ch; next();
- if not ch:match("[0-9]") then error("number format error"); end
- end
- if ch == "0" then
- s = s..ch; next();
- if ch:match("[0-9]") then error("number format error"); end
- else
- while ch and ch:match("[0-9]") do
- s = s..ch; next();
- end
- end
- if ch == "." then
- s = s..ch; next();
- if not ch:match("[0-9]") then error("number format error"); end
- while ch and ch:match("[0-9]") do
- s = s..ch; next();
- end
- if ch == "e" or ch == "E" then
- s = s..ch; next();
- if ch == "+" or ch == "-" then
- s = s..ch; next();
- if not ch:match("[0-9]") then error("number format error"); end
- while ch and ch:match("[0-9]") do
- s = s..ch; next();
- end
- end
- end
- end
- return tonumber(s);
+ _unescape_error = true;
+end
+function _readstring(json, index)
+ index = index + 1;
+ local endindex = json:find("\"", index, true);
+ if endindex then
+ local s = json:sub(index, endindex - 1);
+ --if s:find("[%z-\31]") then return nil, "control char in string"; end
+ -- FIXME handle control characters
+ _unescape_error = nil;
+ --s = s:gsub("\\u[dD][89abAB]%x%x\\u[dD][cdefCDEF]%x%x", _unescape_surrogate_func);
+ -- FIXME handle escapes beyond BMP
+ s = s:gsub("\\u.?.?.?.?", _unescape_func);
+ if _unescape_error then return nil, "invalid escape"; end
+ return s, endindex + 1;
end
- local function readmember(t)
- local k = readstring();
- checkandskip(":");
- t[k] = readvalue();
+ return nil, "string eof";
+end
+local function _readnumber(json, index)
+ local m = json:match("[0-9%.%-eE%+]+", index); -- FIXME do strict checking
+ return tonumber(m), index + #m;
+end
+local function _readnull(json, index)
+ local a, b, c = json:byte(index + 1, index + 3);
+ if a == 0x75 and b == 0x6c and c == 0x6c then
+ return null, index + 4;
end
- local function fixobject(obj)
- local __array = obj.__array;
- if __array then
- obj.__array = nil;
- for i,v in ipairs(__array) do
- t_insert(obj, v);
- end
- end
- local __hash = obj.__hash;
- if __hash then
- obj.__hash = nil;
- local k;
- for i,v in ipairs(__hash) do
- if k ~= nil then
- obj[k] = v; k = nil;
- else
- k = v;
- end
- end
- end
- return obj;
+ return nil, "null parse failed";
+end
+local function _readtrue(json, index)
+ local a, b, c = json:byte(index + 1, index + 3);
+ if a == 0x72 and b == 0x75 and c == 0x65 then
+ return true, index + 4;
end
- local function readobject()
- local t = {};
- next(); -- skip '{'
- skipstuff();
- if ch == "}" then next(); return t; end
- if not ch then error("eof while reading object"); end
- readmember(t);
- while true do
- skipstuff();
- if ch == "}" then next(); return fixobject(t); end
- if not ch then error("eof while reading object");
- elseif ch == "," then next();
- elseif ch then error("unexpected character in object, comma expected"); end
- if not ch then error("eof while reading object"); end
- readmember(t);
- end
+ return nil, "true parse failed";
+end
+local function _readfalse(json, index)
+ local a, b, c, d = json:byte(index + 1, index + 4);
+ if a == 0x61 and b == 0x6c and c == 0x73 and d == 0x65 then
+ return false, index + 5;
end
-
- function readvalue()
- skipstuff();
- while ch do
- if ch == "{" then
- return readobject();
- elseif ch == "[" then
- return readarray();
- elseif ch == "\"" then
- return readstring();
- elseif ch:match("[%-0-9%.]") then
- return readnumber();
- elseif ch == "n" then
- return readliteral("null", null);
- elseif ch == "t" then
- return readliteral("true", true);
- elseif ch == "f" then
- return readliteral("false", false);
- else
- error("invalid character at value start: "..ch);
- end
- end
- error("eof while reading value");
+ return nil, "false parse failed";
+end
+function _readvalue(json, index)
+ index = _skip_whitespace(json, index);
+ local b = json:byte(index);
+ -- TODO try table lookup instead of if-else?
+ if b == 0x7B then -- "{"
+ return _readobject(json, index);
+ elseif b == 0x5B then -- "["
+ return _readarray(json, index);
+ elseif b == 0x22 then -- "\""
+ return _readstring(json, index);
+ elseif b ~= nil and b >= 0x30 and b <= 0x39 or b == 0x2d then -- "0"-"9" or "-"
+ return _readnumber(json, index);
+ elseif b == 0x6e then -- "n"
+ return _readnull(json, index);
+ elseif b == 0x74 then -- "t"
+ return _readtrue(json, index);
+ elseif b == 0x66 then -- "f"
+ return _readfalse(json, index);
+ else
+ return nil, "value expected";
end
- next();
- return readvalue();
+end
+local first_escape = {
+ ["\\\""] = "\\u0022";
+ ["\\\\"] = "\\u005c";
+ ["\\/" ] = "\\u002f";
+ ["\\b" ] = "\\u0008";
+ ["\\f" ] = "\\u000C";
+ ["\\n" ] = "\\u000A";
+ ["\\r" ] = "\\u000D";
+ ["\\t" ] = "\\u0009";
+ ["\\u" ] = "\\u";
+};
+
+function json.decode(json)
+ json = json:gsub("\\.", first_escape) -- get rid of all escapes except \uXXXX, making string parsing much simpler
+ --:gsub("[\r\n]", "\t"); -- \r\n\t are equivalent, we care about none of them, and none of them can be in strings
+
+ -- TODO do encoding verification
+
+ local val, index = _readvalue(json, 1);
+ if val == nil then return val, index; end
+ if json:find("[^ \t\r\n]", index) then return nil, "garbage at eof"; end
+
+ return val;
end
function json.test(object)
diff --git a/util/logger.lua b/util/logger.lua
index c3bf3992..cd0769f9 100644
--- a/util/logger.lua
+++ b/util/logger.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -13,8 +13,7 @@ local ipairs, pairs, setmetatable = ipairs, pairs, setmetatable;
module "logger"
-local name_sinks, level_sinks = {}, {};
-local name_patterns = {};
+local level_sinks = {};
local make_logger;
@@ -24,8 +23,6 @@ function init(name)
local log_warn = make_logger(name, "warn");
local log_error = make_logger(name, "error");
- --name = nil; -- While this line is not commented, will automatically fill in file/line number info
- local namelen = #name;
return function (level, message, ...)
if level == "debug" then
return log_debug(message, ...);
@@ -46,17 +43,7 @@ function make_logger(source_name, level)
level_sinks[level] = level_handlers;
end
- local source_handlers = name_sinks[source_name];
-
local logger = function (message, ...)
- if source_handlers then
- for i = 1,#source_handlers do
- if source_handlers[i](source_name, level, message, ...) == false then
- return;
- end
- end
- end
-
for i = 1,#level_handlers do
level_handlers[i](source_name, level, message, ...);
end
@@ -66,14 +53,12 @@ function make_logger(source_name, level)
end
function reset()
- for k in pairs(name_sinks) do name_sinks[k] = nil; end
for level, handler_list in pairs(level_sinks) do
-- Clear all handlers for this level
for i = 1, #handler_list do
handler_list[i] = nil;
end
end
- for k in pairs(name_patterns) do name_patterns[k] = nil; end
end
function add_level_sink(level, sink_function)
@@ -84,22 +69,6 @@ function add_level_sink(level, sink_function)
end
end
-function add_name_sink(name, sink_function, exclusive)
- if not name_sinks[name] then
- name_sinks[name] = { sink_function };
- else
- name_sinks[name][#name_sinks[name] + 1] = sink_function;
- end
-end
-
-function add_name_pattern_sink(name_pattern, sink_function, exclusive)
- if not name_patterns[name_pattern] then
- name_patterns[name_pattern] = { sink_function };
- else
- name_patterns[name_pattern][#name_patterns[name_pattern] + 1] = sink_function;
- end
-end
-
_M.new = make_logger;
return _M;
diff --git a/util/multitable.lua b/util/multitable.lua
index 66b9bd8a..caf25118 100644
--- a/util/multitable.lua
+++ b/util/multitable.lua
@@ -1,17 +1,14 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
-
-
local select = select;
local t_insert = table.insert;
-local pairs = pairs;
-local next = next;
+local unpack, pairs, next, type = unpack, pairs, next, type;
module "multitable"
@@ -129,6 +126,41 @@ local function search_add(self, results, ...)
return results;
end
+function iter(self, ...)
+ local query = { ... };
+ local maxdepth = select("#", ...);
+ local stack = { self.data };
+ local keys = { };
+ local function it(self)
+ local depth = #stack;
+ local key = next(stack[depth], keys[depth]);
+ if key == nil then -- Go up the stack
+ stack[depth], keys[depth] = nil, nil;
+ if depth > 1 then
+ return it(self);
+ end
+ return; -- The end
+ else
+ keys[depth] = key;
+ end
+ local value = stack[depth][key];
+ if query[depth] == nil or key == query[depth] then
+ if depth == maxdepth then -- Result
+ local result = {}; -- Collect keys forming path to result
+ for i = 1, depth do
+ result[i] = keys[i];
+ end
+ result[depth+1] = value;
+ return unpack(result, 1, depth+1);
+ elseif type(value) == "table" then
+ t_insert(stack, value); -- Descend
+ end
+ end
+ return it(self);
+ end;
+ return it, self;
+end
+
function new()
return {
data = {};
@@ -138,6 +170,7 @@ function new()
remove = remove;
search = search;
search_add = search_add;
+ iter = iter;
};
end
diff --git a/util/openssl.lua b/util/openssl.lua
new file mode 100644
index 00000000..ef3fba96
--- /dev/null
+++ b/util/openssl.lua
@@ -0,0 +1,172 @@
+local type, tostring, pairs, ipairs = type, tostring, pairs, ipairs;
+local t_insert, t_concat = table.insert, table.concat;
+local s_format = string.format;
+
+local oid_xmppaddr = "1.3.6.1.5.5.7.8.5"; -- [XMPP-CORE]
+local oid_dnssrv = "1.3.6.1.5.5.7.8.7"; -- [SRV-ID]
+
+local idna_to_ascii = require "util.encodings".idna.to_ascii;
+
+local _M = {};
+local config = {};
+_M.config = config;
+
+local ssl_config = {};
+local ssl_config_mt = {__index=ssl_config};
+
+function config.new()
+ return setmetatable({
+ req = {
+ distinguished_name = "distinguished_name",
+ req_extensions = "v3_extensions",
+ x509_extensions = "v3_extensions",
+ prompt = "no",
+ },
+ distinguished_name = {
+ countryName = "GB",
+ -- stateOrProvinceName = "",
+ localityName = "The Internet",
+ organizationName = "Your Organisation",
+ organizationalUnitName = "XMPP Department",
+ commonName = "example.com",
+ emailAddress = "xmpp@example.com",
+ },
+ v3_extensions = {
+ basicConstraints = "CA:FALSE",
+ keyUsage = "digitalSignature,keyEncipherment",
+ extendedKeyUsage = "serverAuth,clientAuth",
+ subjectAltName = "@subject_alternative_name",
+ },
+ subject_alternative_name = {
+ DNS = {},
+ otherName = {},
+ },
+ }, ssl_config_mt);
+end
+
+local DN_order = {
+ "countryName";
+ "stateOrProvinceName";
+ "localityName";
+ "streetAddress";
+ "organizationName";
+ "organizationalUnitName";
+ "commonName";
+ "emailAddress";
+}
+_M._DN_order = DN_order;
+function ssl_config:serialize()
+ local s = "";
+ for k, t in pairs(self) do
+ s = s .. ("[%s]\n"):format(k);
+ if k == "subject_alternative_name" then
+ for san, n in pairs(t) do
+ for i = 1,#n do
+ s = s .. s_format("%s.%d = %s\n", san, i -1, n[i]);
+ end
+ end
+ elseif k == "distinguished_name" then
+ for i=1,#DN_order do
+ local k = DN_order[i]
+ local v = t[k];
+ if v then
+ s = s .. ("%s = %s\n"):format(k, v);
+ end
+ end
+ else
+ for k, v in pairs(t) do
+ s = s .. ("%s = %s\n"):format(k, v);
+ end
+ end
+ s = s .. "\n";
+ end
+ return s;
+end
+
+local function utf8string(s)
+ -- This is how we tell openssl not to encode UTF-8 strings as fake Latin1
+ return s_format("FORMAT:UTF8,UTF8:%s", s);
+end
+
+local function ia5string(s)
+ return s_format("IA5STRING:%s", s);
+end
+
+_M.util = {
+ utf8string = utf8string,
+ ia5string = ia5string,
+};
+
+function ssl_config:add_dNSName(host)
+ t_insert(self.subject_alternative_name.DNS, idna_to_ascii(host));
+end
+
+function ssl_config:add_sRVName(host, service)
+ t_insert(self.subject_alternative_name.otherName,
+ s_format("%s;%s", oid_dnssrv, ia5string("_" .. service .."." .. idna_to_ascii(host))));
+end
+
+function ssl_config:add_xmppAddr(host)
+ t_insert(self.subject_alternative_name.otherName,
+ s_format("%s;%s", oid_xmppaddr, utf8string(host)));
+end
+
+function ssl_config:from_prosody(hosts, config, certhosts)
+ -- TODO Decide if this should go elsewhere
+ local found_matching_hosts = false;
+ for i = 1,#certhosts do
+ local certhost = certhosts[i];
+ for name in pairs(hosts) do
+ if name == certhost or name:sub(-1-#certhost) == "."..certhost then
+ found_matching_hosts = true;
+ self:add_dNSName(name);
+ --print(name .. "#component_module: " .. (config.get(name, "component_module") or "nil"));
+ if config.get(name, "component_module") == nil then
+ self:add_sRVName(name, "xmpp-client");
+ end
+ --print(name .. "#anonymous_login: " .. tostring(config.get(name, "anonymous_login")));
+ if not (config.get(name, "anonymous_login") or
+ config.get(name, "authentication") == "anonymous") then
+ self:add_sRVName(name, "xmpp-server");
+ end
+ self:add_xmppAddr(name);
+ end
+ end
+ end
+ if not found_matching_hosts then
+ return nil, "no-matching-hosts";
+ end
+end
+
+do -- Lua to shell calls.
+ local function shell_escape(s)
+ return s:gsub("'",[['\'']]);
+ end
+
+ local function serialize(f,o)
+ local r = {"openssl", f};
+ for k,v in pairs(o) do
+ if type(k) == "string" then
+ t_insert(r, ("-%s"):format(k));
+ if v ~= true then
+ t_insert(r, ("'%s'"):format(shell_escape(tostring(v))));
+ end
+ end
+ end
+ for _,v in ipairs(o) do
+ t_insert(r, ("'%s'"):format(shell_escape(tostring(v))));
+ end
+ return t_concat(r, " ");
+ end
+
+ local os_execute = os.execute;
+ setmetatable(_M, {
+ __index=function(_,f)
+ return function(opts)
+ return 0 == os_execute(serialize(f, type(opts) == "table" and opts or {}));
+ end;
+ end;
+ });
+end
+
+return _M;
diff --git a/util/pluginloader.lua b/util/pluginloader.lua
index 31ab1e88..b894f527 100644
--- a/util/pluginloader.lua
+++ b/util/pluginloader.lua
@@ -1,58 +1,60 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
+local dir_sep, path_sep = package.config:match("^(%S+)%s(%S+)");
+local plugin_dir = {};
+for path in (CFG_PLUGINDIR or "./plugins/"):gsub("[/\\]", dir_sep):gmatch("[^"..path_sep.."]+") do
+ path = path..dir_sep; -- add path separator to path end
+ path = path:gsub(dir_sep..dir_sep.."+", dir_sep); -- coalesce multiple separaters
+ plugin_dir[#plugin_dir + 1] = path;
+end
-local plugin_dir = CFG_PLUGINDIR or "./plugins/";
-
-local io_open, os_time = io.open, os.time;
-local loadstring, pairs = loadstring, pairs;
+local io_open = io.open;
+local envload = require "util.envload".envload;
module "pluginloader"
-local function load_file(name)
- local file, err = io_open(plugin_dir..name);
- if not file then return file, err; end
- local content = file:read("*a");
- file:close();
- return content, name;
+function load_file(names)
+ local file, err, path;
+ for i=1,#plugin_dir do
+ for j=1,#names do
+ path = plugin_dir[i]..names[j];
+ file, err = io_open(path);
+ if file then
+ local content = file:read("*a");
+ file:close();
+ return content, path;
+ end
+ end
+ end
+ return file, err;
end
-function load_resource(plugin, resource, loader)
- local path, name = plugin:match("([^/]*)/?(.*)");
- if name == "" then
- if not resource then
- resource = "mod_"..plugin..".lua";
- end
- loader = loader or load_file;
-
- local content, err = loader(plugin.."/"..resource);
- if not content then content, err = loader(resource); end
- -- TODO add support for packed plugins
-
- return content, err;
- else
- if not resource then
- resource = "mod_"..name..".lua";
- end
- loader = loader or load_file;
+function load_resource(plugin, resource)
+ resource = resource or "mod_"..plugin..".lua";
- local content, err = loader(plugin.."/"..resource);
- if not content then content, err = loader(path.."/"..resource); end
- -- TODO add support for packed plugins
-
- return content, err;
- end
+ local names = {
+ "mod_"..plugin.."/"..plugin.."/"..resource; -- mod_hello/hello/mod_hello.lua
+ "mod_"..plugin.."/"..resource; -- mod_hello/mod_hello.lua
+ plugin.."/"..resource; -- hello/mod_hello.lua
+ resource; -- mod_hello.lua
+ };
+
+ return load_file(names);
end
-function load_code(plugin, resource)
+function load_code(plugin, resource, env)
local content, err = load_resource(plugin, resource);
if not content then return content, err; end
- return loadstring(content, "@"..err);
+ local path = err;
+ local f, err = envload(content, "@"..path, env);
+ if not f then return f, err; end
+ return f, path;
end
return _M;
diff --git a/util/prosodyctl.lua b/util/prosodyctl.lua
index 40d21be8..fe862114 100644
--- a/util/prosodyctl.lua
+++ b/util/prosodyctl.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -15,18 +15,119 @@ local usermanager = require "core.usermanager";
local signal = require "util.signal";
local set = require "util.set";
local lfs = require "lfs";
+local pcall = pcall;
+local type = type;
local nodeprep, nameprep = stringprep.nodeprep, stringprep.nameprep;
local io, os = io, os;
+local print = print;
local tostring, tonumber = tostring, tonumber;
local CFG_SOURCEDIR = _G.CFG_SOURCEDIR;
+local _G = _G;
local prosody = prosody;
module "prosodyctl"
+-- UI helpers
+function show_message(msg, ...)
+ print(msg:format(...));
+end
+
+function show_warning(msg, ...)
+ print(msg:format(...));
+end
+
+function show_usage(usage, desc)
+ print("Usage: ".._G.arg[0].." "..usage);
+ if desc then
+ print(" "..desc);
+ end
+end
+
+function getchar(n)
+ local stty_ret = os.execute("stty raw -echo 2>/dev/null");
+ local ok, char;
+ if stty_ret == 0 then
+ ok, char = pcall(io.read, n or 1);
+ os.execute("stty sane");
+ else
+ ok, char = pcall(io.read, "*l");
+ if ok then
+ char = char:sub(1, n or 1);
+ end
+ end
+ if ok then
+ return char;
+ end
+end
+
+function getline()
+ local ok, line = pcall(io.read, "*l");
+ if ok then
+ return line;
+ end
+end
+
+function getpass()
+ local stty_ret = os.execute("stty -echo 2>/dev/null");
+ if stty_ret ~= 0 then
+ io.write("\027[08m"); -- ANSI 'hidden' text attribute
+ end
+ local ok, pass = pcall(io.read, "*l");
+ if stty_ret == 0 then
+ os.execute("stty sane");
+ else
+ io.write("\027[00m");
+ end
+ io.write("\n");
+ if ok then
+ return pass;
+ end
+end
+
+function show_yesno(prompt)
+ io.write(prompt, " ");
+ local choice = getchar():lower();
+ io.write("\n");
+ if not choice:match("%a") then
+ choice = prompt:match("%[.-(%U).-%]$");
+ if not choice then return nil; end
+ end
+ return (choice == "y");
+end
+
+function read_password()
+ local password;
+ while true do
+ io.write("Enter new password: ");
+ password = getpass();
+ if not password then
+ show_message("No password - cancelled");
+ return;
+ end
+ io.write("Retype new password: ");
+ if getpass() ~= password then
+ if not show_yesno [=[Passwords did not match, try again? [Y/n]]=] then
+ return;
+ end
+ else
+ break;
+ end
+ end
+ return password;
+end
+
+function show_prompt(prompt)
+ io.write(prompt, " ");
+ local line = getline();
+ line = line and line:gsub("\n$","");
+ return (line and #line > 0) and line or nil;
+end
+
+-- Server control
function adduser(params)
local user, host, password = nodeprep(params.user), nameprep(params.host), params.password;
if not user then
@@ -35,12 +136,17 @@ function adduser(params)
return false, "invalid-hostname";
end
- local provider = prosody.hosts[host].users;
+ local host_session = prosody.hosts[host];
+ if not host_session then
+ return false, "no-such-host";
+ end
+
+ storagemanager.initialize_host(host);
+ local provider = host_session.users;
if not(provider) or provider.name == "null" then
usermanager.initialize_host(host);
end
- storagemanager.initialize_host(host);
-
+
local ok, errmsg = usermanager.create_user(user, password, host);
if not ok then
return false, errmsg;
@@ -50,12 +156,13 @@ end
function user_exists(params)
local user, host, password = nodeprep(params.user), nameprep(params.host), params.password;
+
+ storagemanager.initialize_host(host);
local provider = prosody.hosts[host].users;
if not(provider) or provider.name == "null" then
usermanager.initialize_host(host);
end
- storagemanager.initialize_host(host);
-
+
return usermanager.user_exists(user, host);
end
@@ -63,7 +170,7 @@ function passwd(params)
if not _M.user_exists(params) then
return false, "no-such-user";
end
-
+
return _M.adduser(params);
end
@@ -71,40 +178,40 @@ function deluser(params)
if not _M.user_exists(params) then
return false, "no-such-user";
end
- params.password = nil;
-
- return _M.adduser(params);
+ local user, host = nodeprep(params.user), nameprep(params.host);
+
+ return usermanager.delete_user(user, host);
end
function getpid()
- local pidfile = config.get("*", "core", "pidfile");
+ local pidfile = config.get("*", "pidfile");
if not pidfile then
return false, "no-pidfile";
end
-
- local modules_enabled = set.new(config.get("*", "core", "modules_enabled"));
+
+ local modules_enabled = set.new(config.get("*", "modules_enabled"));
if not modules_enabled:contains("posix") then
return false, "no-posix";
end
-
+
local file, err = io.open(pidfile, "r+");
if not file then
return false, "pidfile-read-failed", err;
end
-
+
local locked, err = lfs.lock(file, "w");
if locked then
file:close();
return false, "pidfile-not-locked";
end
-
+
local pid = tonumber(file:read("*a"));
file:close();
-
+
if not pid then
return false, "invalid-pid";
end
-
+
return true, pid;
end
@@ -145,10 +252,28 @@ function stop()
if not ret then
return false, "not-running";
end
-
+
local ok, pid = _M.getpid()
if not ok then return false, pid; end
-
+
signal.kill(pid, signal.SIGTERM);
return true;
end
+
+function reload()
+ local ok, ret = _M.isrunning();
+ if not ok then
+ return ok, ret;
+ end
+ if not ret then
+ return false, "not-running";
+ end
+
+ local ok, pid = _M.getpid()
+ if not ok then return false, pid; end
+
+ signal.kill(pid, signal.SIGHUP);
+ return true;
+end
+
+return _M;
diff --git a/util/pubsub.lua b/util/pubsub.lua
index 3beafab5..0dfd196b 100644
--- a/util/pubsub.lua
+++ b/util/pubsub.lua
@@ -1,3 +1,5 @@
+local events = require "util.events";
+
module("pubsub", package.seeall);
local service = {};
@@ -16,6 +18,7 @@ function new(config)
affiliations = {};
subscriptions = {};
nodes = {};
+ events = events.new();
}, service_mt);
end
@@ -26,18 +29,15 @@ end
function service:may(node, actor, action)
if actor == true then return true; end
-
-
+
local node_obj = self.nodes[node];
local node_aff = node_obj and node_obj.affiliations[actor];
local service_aff = self.affiliations[actor]
or self.config.get_affiliation(actor, node, action)
or "none";
-
+
+ -- Check if node allows/forbids it
local node_capabilities = node_obj and node_obj.capabilities;
- local service_capabilities = self.config.capabilities;
-
- -- Check if node allows/forbids it
if node_capabilities then
local caps = node_capabilities[node_aff or service_aff];
if caps then
@@ -47,7 +47,9 @@ function service:may(node, actor, action)
end
end
end
+
-- Check service-wide capabilities instead
+ local service_capabilities = self.config.capabilities;
local caps = service_capabilities[node_aff or service_aff];
if caps then
local can = caps[action];
@@ -55,7 +57,7 @@ function service:may(node, actor, action)
return can;
end
end
-
+
return false;
end
@@ -70,14 +72,14 @@ function service:set_affiliation(node, actor, jid, affiliation)
return false, "item-not-found";
end
node_obj.affiliations[jid] = affiliation;
- local _, jid_sub = self:get_subscription(node, nil, jid);
+ local _, jid_sub = self:get_subscription(node, true, jid);
if not jid_sub and not self:may(node, jid, "be_unsubscribed") then
- local ok, err = self:add_subscription(node, nil, jid);
+ local ok, err = self:add_subscription(node, true, jid);
if not ok then
return ok, err;
end
elseif jid_sub and not self:may(node, jid, "be_subscribed") then
- local ok, err = self:add_subscription(node, nil, jid);
+ local ok, err = self:add_subscription(node, true, jid);
if not ok then
return ok, err;
end
@@ -88,7 +90,7 @@ end
function service:add_subscription(node, actor, jid, options)
-- Access checking
local cap;
- if jid == actor or self:jids_equal(actor, jid) then
+ if actor == true or jid == actor or self:jids_equal(actor, jid) then
cap = "subscribe";
else
cap = "subscribe_other";
@@ -105,7 +107,7 @@ function service:add_subscription(node, actor, jid, options)
if not self.config.autocreate_on_subscribe then
return false, "item-not-found";
else
- local ok, err = self:create(node, actor);
+ local ok, err = self:create(node, true);
if not ok then
return ok, err;
end
@@ -124,13 +126,14 @@ function service:add_subscription(node, actor, jid, options)
else
self.subscriptions[normal_jid] = { [jid] = { [node] = true } };
end
+ self.events.fire_event("subscription-added", { node = node, jid = jid, normalized_jid = normal_jid, options = options });
return true;
end
function service:remove_subscription(node, actor, jid)
-- Access checking
local cap;
- if jid == actor or self:jids_equal(actor, jid) then
+ if actor == true or jid == actor or self:jids_equal(actor, jid) then
cap = "unsubscribe";
else
cap = "unsubscribe_other";
@@ -164,13 +167,26 @@ function service:remove_subscription(node, actor, jid)
self.subscriptions[normal_jid] = nil;
end
end
+ self.events.fire_event("subscription-removed", { node = node, jid = jid, normalized_jid = normal_jid });
+ return true;
+end
+
+function service:remove_all_subscriptions(actor, jid)
+ local normal_jid = self.config.normalize_jid(jid);
+ local subs = self.subscriptions[normal_jid]
+ subs = subs and subs[jid];
+ if subs then
+ for node in pairs(subs) do
+ self:remove_subscription(node, true, jid);
+ end
+ end
return true;
end
function service:get_subscription(node, actor, jid)
-- Access checking
local cap;
- if jid == actor or self:jids_equal(actor, jid) then
+ if actor == true or jid == actor or self:jids_equal(actor, jid) then
cap = "get_subscription";
else
cap = "get_subscription_other";
@@ -195,7 +211,7 @@ function service:create(node, actor)
if self.nodes[node] then
return false, "conflict";
end
-
+
self.nodes[node] = {
name = node;
subscribers = {};
@@ -210,6 +226,21 @@ function service:create(node, actor)
return ok, err;
end
+function service:delete(node, actor)
+ -- Access checking
+ if not self:may(node, actor, "delete") then
+ return false, "forbidden";
+ end
+ --
+ local node_obj = self.nodes[node];
+ if not node_obj then
+ return false, "item-not-found";
+ end
+ self.nodes[node] = nil;
+ self.config.broadcaster("delete", node, node_obj.subscribers);
+ return true;
+end
+
function service:publish(node, actor, id, item)
-- Access checking
if not self:may(node, actor, "publish") then
@@ -221,14 +252,15 @@ function service:publish(node, actor, id, item)
if not self.config.autocreate_on_publish then
return false, "item-not-found";
end
- local ok, err = self:create(node, actor);
+ local ok, err = self:create(node, true);
if not ok then
return ok, err;
end
node_obj = self.nodes[node];
end
node_obj.data[id] = item;
- self.config.broadcaster(node, node_obj.subscribers, item);
+ self.events.fire_event("item-published", { node = node, actor = actor, id = id, item = item });
+ self.config.broadcaster("items", node, node_obj.subscribers, item);
return true;
end
@@ -244,7 +276,24 @@ function service:retract(node, actor, id, retract)
end
node_obj.data[id] = nil;
if retract then
- self.config.broadcaster(node, node_obj.subscribers, retract);
+ self.config.broadcaster("items", node, node_obj.subscribers, retract);
+ end
+ return true
+end
+
+function service:purge(node, actor, notify)
+ -- Access checking
+ if not self:may(node, actor, "retract") then
+ return false, "forbidden";
+ end
+ --
+ local node_obj = self.nodes[node];
+ if not node_obj then
+ return false, "item-not-found";
+ end
+ node_obj.data = {}; -- Purge
+ if notify then
+ self.config.broadcaster("purge", node, node_obj.subscribers);
end
return true
end
@@ -278,7 +327,7 @@ end
function service:get_subscriptions(node, actor, jid)
-- Access checking
local cap;
- if jid == actor or self:jids_equal(actor, jid) then
+ if actor == true or jid == actor or self:jids_equal(actor, jid) then
cap = "get_subscriptions";
else
cap = "get_subscriptions_other";
@@ -304,7 +353,7 @@ function service:get_subscriptions(node, actor, jid)
if node then -- Return only subscriptions to this node
if subscribed_nodes[node] then
ret[#ret+1] = {
- node = subscribed_node;
+ node = node;
jid = jid;
subscription = node_obj.subscribers[jid];
};
diff --git a/util/rfc6724.lua b/util/rfc6724.lua
new file mode 100644
index 00000000..c8aec631
--- /dev/null
+++ b/util/rfc6724.lua
@@ -0,0 +1,142 @@
+-- Prosody IM
+-- Copyright (C) 2011-2013 Florian Zeitz
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+-- This is used to sort destination addresses by preference
+-- during S2S connections.
+-- We can't hand this off to getaddrinfo, since it blocks
+
+local ip_commonPrefixLength = require"util.ip".commonPrefixLength
+local new_ip = require"util.ip".new_ip;
+
+local function commonPrefixLength(ipA, ipB)
+ local len = ip_commonPrefixLength(ipA, ipB);
+ return len < 64 and len or 64;
+end
+
+local function t_sort(t, comp)
+ for i = 1, (#t - 1) do
+ for j = (i + 1), #t do
+ local a, b = t[i], t[j];
+ if not comp(a,b) then
+ t[i], t[j] = b, a;
+ end
+ end
+ end
+end
+
+local function source(dest, candidates)
+ local function comp(ipA, ipB)
+ -- Rule 1: Prefer same address
+ if dest == ipA then
+ return true;
+ elseif dest == ipB then
+ return false;
+ end
+
+ -- Rule 2: Prefer appropriate scope
+ if ipA.scope < ipB.scope then
+ if ipA.scope < dest.scope then
+ return false;
+ else
+ return true;
+ end
+ elseif ipA.scope > ipB.scope then
+ if ipB.scope < dest.scope then
+ return true;
+ else
+ return false;
+ end
+ end
+
+ -- Rule 3: Avoid deprecated addresses
+ -- XXX: No way to determine this
+ -- Rule 4: Prefer home addresses
+ -- XXX: Mobility Address related, no way to determine this
+ -- Rule 5: Prefer outgoing interface
+ -- XXX: Interface to address relation. No way to determine this
+ -- Rule 6: Prefer matching label
+ if ipA.label == dest.label and ipB.label ~= dest.label then
+ return true;
+ elseif ipB.label == dest.label and ipA.label ~= dest.label then
+ return false;
+ end
+
+ -- Rule 7: Prefer temporary addresses (over public ones)
+ -- XXX: No way to determine this
+ -- Rule 8: Use longest matching prefix
+ if commonPrefixLength(ipA, dest) > commonPrefixLength(ipB, dest) then
+ return true;
+ else
+ return false;
+ end
+ end
+
+ t_sort(candidates, comp);
+ return candidates[1];
+end
+
+local function destination(candidates, sources)
+ local sourceAddrs = {};
+ local function comp(ipA, ipB)
+ local ipAsource = sourceAddrs[ipA];
+ local ipBsource = sourceAddrs[ipB];
+ -- Rule 1: Avoid unusable destinations
+ -- XXX: No such information
+ -- Rule 2: Prefer matching scope
+ if ipA.scope == ipAsource.scope and ipB.scope ~= ipBsource.scope then
+ return true;
+ elseif ipA.scope ~= ipAsource.scope and ipB.scope == ipBsource.scope then
+ return false;
+ end
+
+ -- Rule 3: Avoid deprecated addresses
+ -- XXX: No way to determine this
+ -- Rule 4: Prefer home addresses
+ -- XXX: Mobility Address related, no way to determine this
+ -- Rule 5: Prefer matching label
+ if ipAsource.label == ipA.label and ipBsource.label ~= ipB.label then
+ return true;
+ elseif ipBsource.label == ipB.label and ipAsource.label ~= ipA.label then
+ return false;
+ end
+
+ -- Rule 6: Prefer higher precedence
+ if ipA.precedence > ipB.precedence then
+ return true;
+ elseif ipA.precedence < ipB.precedence then
+ return false;
+ end
+
+ -- Rule 7: Prefer native transport
+ -- XXX: No way to determine this
+ -- Rule 8: Prefer smaller scope
+ if ipA.scope < ipB.scope then
+ return true;
+ elseif ipA.scope > ipB.scope then
+ return false;
+ end
+
+ -- Rule 9: Use longest matching prefix
+ if commonPrefixLength(ipA, ipAsource) > commonPrefixLength(ipB, ipBsource) then
+ return true;
+ elseif commonPrefixLength(ipA, ipAsource) < commonPrefixLength(ipB, ipBsource) then
+ return false;
+ end
+
+ -- Rule 10: Otherwise, leave order unchanged
+ return true;
+ end
+ for _, ip in ipairs(candidates) do
+ sourceAddrs[ip] = source(ip, sources);
+ end
+
+ t_sort(candidates, comp);
+ return candidates;
+end
+
+return {source = source,
+ destination = destination};
diff --git a/util/sasl.lua b/util/sasl.lua
index 393a0919..0d90880d 100644
--- a/util/sasl.lua
+++ b/util/sasl.lua
@@ -68,13 +68,17 @@ end
-- create a new SASL object which can be used to authenticate clients
function new(realm, profile)
- local mechanisms = {};
- for backend, f in pairs(profile) do
- if backend_mechanism[backend] then
- for _, mechanism in ipairs(backend_mechanism[backend]) do
- mechanisms[mechanism] = true;
+ local mechanisms = profile.mechanisms;
+ if not mechanisms then
+ mechanisms = {};
+ for backend, f in pairs(profile) do
+ if backend_mechanism[backend] then
+ for _, mechanism in ipairs(backend_mechanism[backend]) do
+ mechanisms[mechanism] = true;
+ end
end
end
+ profile.mechanisms = mechanisms;
end
return setmetatable({ profile = profile, realm = realm, mechs = mechanisms }, method);
end
@@ -131,5 +135,6 @@ require "util.sasl.plain" .init(registerMechanism);
require "util.sasl.digest-md5".init(registerMechanism);
require "util.sasl.anonymous" .init(registerMechanism);
require "util.sasl.scram" .init(registerMechanism);
+require "util.sasl.external" .init(registerMechanism);
return _M;
diff --git a/util/sasl/anonymous.lua b/util/sasl/anonymous.lua
index b9af17fe..ca5fe404 100644
--- a/util/sasl/anonymous.lua
+++ b/util/sasl/anonymous.lua
@@ -16,7 +16,7 @@ local s_match = string.match;
local log = require "util.logger".init("sasl");
local generate_uuid = require "util.uuid".generate;
-module "anonymous"
+module "sasl.anonymous"
--=========================
--SASL ANONYMOUS according to RFC 4505
@@ -43,4 +43,4 @@ function init(registerMechanism)
registerMechanism("ANONYMOUS", {"anonymous"}, anonymous);
end
-return _M; \ No newline at end of file
+return _M;
diff --git a/util/sasl/digest-md5.lua b/util/sasl/digest-md5.lua
index 6f2c765e..591d8537 100644
--- a/util/sasl/digest-md5.lua
+++ b/util/sasl/digest-md5.lua
@@ -23,8 +23,9 @@ local to_byte, to_char = string.byte, string.char;
local md5 = require "util.hashes".md5;
local log = require "util.logger".init("sasl");
local generate_uuid = require "util.uuid".generate;
+local nodeprep = require "util.encodings".stringprep.nodeprep;
-module "digest-md5"
+module "sasl.digest-md5"
--=========================
--SASL DIGEST-MD5 according to RFC 2831
@@ -139,10 +140,15 @@ local function digest(self, message)
end
-- check for username, it's REQUIRED by RFC 2831
- if not response["username"] then
+ local username = response["username"];
+ local _nodeprep = self.profile.nodeprep;
+ if username and _nodeprep ~= false then
+ username = (_nodeprep or nodeprep)(username); -- FIXME charset
+ end
+ if not username or username == "" then
return "failure", "malformed-request";
end
- self["username"] = response["username"];
+ self.username = username;
-- check for nonce, ...
if not response["nonce"] then
@@ -178,7 +184,6 @@ local function digest(self, message)
end
--TODO maybe realm support
- self.username = response["username"];
local Y, state;
if self.profile.plain then
local password, state = self.profile.plain(self, response["username"], self.realm)
@@ -240,4 +245,4 @@ function init(registerMechanism)
registerMechanism("DIGEST-MD5", {"plain"}, digest);
end
-return _M; \ No newline at end of file
+return _M;
diff --git a/util/sasl/external.lua b/util/sasl/external.lua
new file mode 100644
index 00000000..4c5c4343
--- /dev/null
+++ b/util/sasl/external.lua
@@ -0,0 +1,25 @@
+local saslprep = require "util.encodings".stringprep.saslprep;
+
+module "sasl.external"
+
+local function external(self, message)
+ message = saslprep(message);
+ local state
+ self.username, state = self.profile.external(message);
+
+ if state == false then
+ return "failure", "account-disabled";
+ elseif state == nil then
+ return "failure", "not-authorized";
+ elseif state == "expired" then
+ return "false", "credentials-expired";
+ end
+
+ return "success";
+end
+
+function init(registerMechanism)
+ registerMechanism("EXTERNAL", {"external"}, external);
+end
+
+return _M;
diff --git a/util/sasl/plain.lua b/util/sasl/plain.lua
index d6ebe304..c9ec2911 100644
--- a/util/sasl/plain.lua
+++ b/util/sasl/plain.lua
@@ -13,9 +13,10 @@
local s_match = string.match;
local saslprep = require "util.encodings".stringprep.saslprep;
+local nodeprep = require "util.encodings".stringprep.nodeprep;
local log = require "util.logger".init("sasl");
-module "plain"
+module "sasl.plain"
-- ================================
-- SASL PLAIN according to RFC 4616
@@ -54,6 +55,14 @@ local function plain(self, message)
return "failure", "malformed-request", "Invalid username or password.";
end
+ local _nodeprep = self.profile.nodeprep;
+ if _nodeprep ~= false then
+ authentication = (_nodeprep or nodeprep)(authentication);
+ if not authentication or authentication == "" then
+ return "failure", "malformed-request", "Invalid username or password."
+ end
+ end
+
local correct, state = false, false;
if self.profile.plain then
local correct_password;
@@ -64,15 +73,13 @@ local function plain(self, message)
end
self.username = authentication
- if not state then
+ if state == false then
return "failure", "account-disabled";
- end
-
- if correct then
- return "success";
- else
+ elseif state == nil or not correct then
return "failure", "not-authorized", "Unable to authorize you with the authentication credentials you've sent.";
end
+
+ return "success";
end
function init(registerMechanism)
diff --git a/util/sasl/scram.lua b/util/sasl/scram.lua
index 071de505..cf938dba 100644
--- a/util/sasl/scram.lua
+++ b/util/sasl/scram.lua
@@ -16,16 +16,18 @@ local type = type
local string = string
local tostring = tostring;
local base64 = require "util.encodings".base64;
-local hmac_sha1 = require "util.hmac".sha1;
+local hmac_sha1 = require "util.hashes".hmac_sha1;
local sha1 = require "util.hashes".sha1;
+local Hi = require "util.hashes".scram_Hi_sha1;
local generate_uuid = require "util.uuid".generate;
local saslprep = require "util.encodings".stringprep.saslprep;
+local nodeprep = require "util.encodings".stringprep.nodeprep;
local log = require "util.logger".init("sasl");
local t_concat = table.concat;
local char = string.char;
local byte = string.byte;
-module "scram"
+module "sasl.scram"
--=========================
--SASL SCRAM-SHA-1 according to RFC 5802
@@ -69,33 +71,26 @@ local function binaryXOR( a, b )
return t_concat(result);
end
--- hash algorithm independent Hi(PBKDF2) implementation
-function Hi(hmac, str, salt, i)
- local Ust = hmac(str, salt.."\0\0\0\1");
- local res = Ust;
- for n=1,i-1 do
- local Und = hmac(str, Ust)
- res = binaryXOR(res, Und)
- Ust = Und
- end
- return res
-end
-
-local function validate_username(username)
+local function validate_username(username, _nodeprep)
-- check for forbidden char sequences
for eq in username:gmatch("=(.?.?)") do
- if eq ~= "2D" and eq ~= "3D" then
+ if eq ~= "2C" and eq ~= "3D" then
return false
end
end
-
- -- replace =2D with , and =3D with =
- username = username:gsub("=2D", ",");
+
+ -- replace =2C with , and =3D with =
+ username = username:gsub("=2C", ",");
username = username:gsub("=3D", "=");
-
+
-- apply SASLprep
username = saslprep(username);
- return username;
+
+ if username and _nodeprep ~= false then
+ username = (_nodeprep or nodeprep)(username);
+ end
+
+ return username and #username>0 and username;
end
local function hashprep(hashname)
@@ -109,7 +104,7 @@ function getAuthenticationDatabaseSHA1(password, salt, iteration_count)
if iteration_count < 4096 then
log("warn", "Iteration count < 4096 which is the suggested minimum according to RFC 5802.")
end
- local salted_password = Hi(hmac_sha1, password, salt, iteration_count);
+ local salted_password = Hi(password, salt, iteration_count);
local stored_key = sha1(hmac_sha1(salted_password, "Client Key"))
local server_key = hmac_sha1(salted_password, "Server Key");
return true, stored_key, server_key
@@ -120,12 +115,12 @@ local function scram_gen(hash_name, H_f, HMAC_f)
if not self.state then self["state"] = {} end
local support_channel_binding = false;
if self.profile.cb then support_channel_binding = true; end
-
+
if type(message) ~= "string" or #message == 0 then return "failure", "malformed-request" end
if not self.state.name then
-- we are processing client_first_message
local client_first_message = message;
- log("debug", client_first_message);
+
-- TODO: fail if authzid is provided, since we don't support them yet
self.state["client_first_message"] = client_first_message;
self.state["gs2_cbind_flag"], self.state["gs2_cbind_name"], self.state["authzid"], self.state["name"], self.state["clientnonce"]
@@ -156,21 +151,21 @@ local function scram_gen(hash_name, H_f, HMAC_f)
if not self.state.name or not self.state.clientnonce then
return "failure", "malformed-request", "Channel binding isn't support at this time.";
end
-
- self.state.name = validate_username(self.state.name);
+
+ self.state.name = validate_username(self.state.name, self.profile.nodeprep);
if not self.state.name then
log("debug", "Username violates either SASLprep or contains forbidden character sequences.")
return "failure", "malformed-request", "Invalid username.";
end
-
+
self.state["servernonce"] = generate_uuid();
-
+
-- retreive credentials
if self.profile.plain then
local password, state = self.profile.plain(self, self.state.name, self.realm)
if state == nil then return "failure", "not-authorized"
elseif state == false then return "failure", "account-disabled" end
-
+
password = saslprep(password);
if not password then
log("debug", "Password violates SASLprep.");
@@ -190,20 +185,20 @@ local function scram_gen(hash_name, H_f, HMAC_f)
local stored_key, server_key, iteration_count, salt, state = self.profile["scram_"..hashprep(hash_name)](self, self.state.name, self.realm);
if state == nil then return "failure", "not-authorized"
elseif state == false then return "failure", "account-disabled" end
-
+
self.state.stored_key = stored_key;
self.state.server_key = server_key;
self.state.iteration_count = iteration_count;
self.state.salt = salt
end
-
+
local server_first_message = "r="..self.state.clientnonce..self.state.servernonce..",s="..base64.encode(self.state.salt)..",i="..self.state.iteration_count;
self.state["server_first_message"] = server_first_message;
return "challenge", server_first_message
else
-- we are processing client_final_message
local client_final_message = message;
- log("debug", "client_final_message: %s", client_final_message);
+
self.state["channelbinding"], self.state["nonce"], self.state["proof"] = client_final_message:match("^c=(.*),r=(.*),.*p=(.*)");
if not self.state.proof or not self.state.nonce or not self.state.channelbinding then
@@ -223,10 +218,10 @@ local function scram_gen(hash_name, H_f, HMAC_f)
if self.state.nonce ~= self.state.clientnonce..self.state.servernonce then
return "failure", "malformed-request", "Wrong nonce in client-final-message.";
end
-
+
local ServerKey = self.state.server_key;
local StoredKey = self.state.stored_key;
-
+
local AuthMessage = "n=" .. s_match(self.state.client_first_message,"n=(.+)") .. "," .. self.state.server_first_message .. "," .. s_match(client_final_message, "(.+),p=.+")
local ClientSignature = HMAC_f(StoredKey, AuthMessage)
local ClientKey = binaryXOR(ClientSignature, base64.decode(self.state.proof))
diff --git a/util/sasl_cyrus.lua b/util/sasl_cyrus.lua
index 002118fd..a0e8bd69 100644
--- a/util/sasl_cyrus.lua
+++ b/util/sasl_cyrus.lua
@@ -78,11 +78,15 @@ local function init(service_name)
end
-- create a new SASL object which can be used to authenticate clients
-function new(realm, service_name, app_name)
+-- host_fqdn may be nil in which case gethostname() gives the value.
+-- For GSSAPI, this determines the hostname in the service ticket (after
+-- reverse DNS canonicalization, only if [libdefaults] rdns = true which
+-- is the default).
+function new(realm, service_name, app_name, host_fqdn)
init(app_name or service_name);
- local st, ret = pcall(cyrussasl.server_new, service_name, nil, realm, nil, nil)
+ local st, ret = pcall(cyrussasl.server_new, service_name, host_fqdn, realm, nil, nil)
if not st then
log("error", "Creating SASL server connection failed: %s", ret);
return nil;
diff --git a/util/serialization.lua b/util/serialization.lua
index e193b64f..06e45054 100644
--- a/util/serialization.lua
+++ b/util/serialization.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -16,11 +16,12 @@ local pairs = pairs;
local next = next;
local loadstring = loadstring;
-local setfenv = setfenv;
local pcall = pcall;
local debug_traceback = debug.traceback;
local log = require "util.logger".init("serialization");
+local envload = require"util.envload".envload;
+
module "serialization"
local indent = function(i)
@@ -84,9 +85,8 @@ end
function deserialize(str)
if type(str) ~= "string" then return nil; end
str = "return "..str;
- local f, err = loadstring(str, "@data");
+ local f, err = envload(str, "@data", {});
if not f then return nil, err; end
- setfenv(f, {});
local success, ret = pcall(f);
if not success then return nil, ret; end
return ret;
diff --git a/util/set.lua b/util/set.lua
index e4cc2dff..89cd7cf3 100644
--- a/util/set.lua
+++ b/util/set.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -26,8 +26,9 @@ function set_mt.__div(set, func)
local new_set, new_items = _M.new();
local items, new_items = set._items, new_set._items;
for item in pairs(items) do
- if func(item) then
- new_items[item] = true;
+ local new_item = func(item);
+ if new_item ~= nil then
+ new_items[new_item] = true;
end
end
return new_set;
@@ -39,13 +40,13 @@ function set_mt.__eq(set1, set2)
return false;
end
end
-
+
for item in pairs(set2) do
if not set1[item] then
return false;
end
end
-
+
return true;
end
function set_mt.__tostring(set)
@@ -64,56 +65,58 @@ end
function new(list)
local items = setmetatable({}, items_mt);
local set = { _items = items };
-
+
function set:add(item)
items[item] = true;
end
-
+
function set:contains(item)
return items[item];
end
-
+
function set:items()
- return items;
+ return next, items;
end
-
+
function set:remove(item)
items[item] = nil;
end
-
+
function set:add_list(list)
- for _, item in ipairs(list) do
- items[item] = true;
+ if list then
+ for _, item in ipairs(list) do
+ items[item] = true;
+ end
end
end
-
+
function set:include(otherset)
- for item in pairs(otherset) do
+ for item in otherset do
items[item] = true;
end
end
function set:exclude(otherset)
- for item in pairs(otherset) do
+ for item in otherset do
items[item] = nil;
end
end
-
+
function set:empty()
return not next(items);
end
-
+
if list then
set:add_list(list);
end
-
+
return setmetatable(set, set_mt);
end
function union(set1, set2)
local set = new();
local items = set._items;
-
+
for item in pairs(set1._items) do
items[item] = true;
end
@@ -121,14 +124,14 @@ function union(set1, set2)
for item in pairs(set2._items) do
items[item] = true;
end
-
+
return set;
end
function difference(set1, set2)
local set = new();
local items = set._items;
-
+
for item in pairs(set1._items) do
items[item] = (not set2._items[item]) or nil;
end
@@ -139,13 +142,13 @@ end
function intersection(set1, set2)
local set = new();
local items = set._items;
-
+
set1, set2 = set1._items, set2._items;
-
+
for item in pairs(set1) do
items[item] = (not not set2[item]) or nil;
end
-
+
return set;
end
diff --git a/util/sql.lua b/util/sql.lua
new file mode 100644
index 00000000..b8c16e27
--- /dev/null
+++ b/util/sql.lua
@@ -0,0 +1,342 @@
+
+local setmetatable, getmetatable = setmetatable, getmetatable;
+local ipairs, unpack, select = ipairs, unpack, select;
+local tonumber, tostring = tonumber, tostring;
+local assert, xpcall, debug_traceback = assert, xpcall, debug.traceback;
+local t_concat = table.concat;
+local s_char = string.char;
+local log = require "util.logger".init("sql");
+
+local DBI = require "DBI";
+-- This loads all available drivers while globals are unlocked
+-- LuaDBI should be fixed to not set globals.
+DBI.Drivers();
+local build_url = require "socket.url".build;
+
+module("sql")
+
+local column_mt = {};
+local table_mt = {};
+local query_mt = {};
+--local op_mt = {};
+local index_mt = {};
+
+function is_column(x) return getmetatable(x)==column_mt; end
+function is_index(x) return getmetatable(x)==index_mt; end
+function is_table(x) return getmetatable(x)==table_mt; end
+function is_query(x) return getmetatable(x)==query_mt; end
+--function is_op(x) return getmetatable(x)==op_mt; end
+--function expr(...) return setmetatable({...}, op_mt); end
+function Integer(n) return "Integer()" end
+function String(n) return "String()" end
+
+--[[local ops = {
+ __add = function(a, b) return "("..a.."+"..b..")" end;
+ __sub = function(a, b) return "("..a.."-"..b..")" end;
+ __mul = function(a, b) return "("..a.."*"..b..")" end;
+ __div = function(a, b) return "("..a.."/"..b..")" end;
+ __mod = function(a, b) return "("..a.."%"..b..")" end;
+ __pow = function(a, b) return "POW("..a..","..b..")" end;
+ __unm = function(a) return "NOT("..a..")" end;
+ __len = function(a) return "COUNT("..a..")" end;
+ __eq = function(a, b) return "("..a.."=="..b..")" end;
+ __lt = function(a, b) return "("..a.."<"..b..")" end;
+ __le = function(a, b) return "("..a.."<="..b..")" end;
+};
+
+local functions = {
+
+};
+
+local cmap = {
+ [Integer] = Integer();
+ [String] = String();
+};]]
+
+function Column(definition)
+ return setmetatable(definition, column_mt);
+end
+function Table(definition)
+ local c = {}
+ for i,col in ipairs(definition) do
+ if is_column(col) then
+ c[i], c[col.name] = col, col;
+ elseif is_index(col) then
+ col.table = definition.name;
+ end
+ end
+ return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt);
+end
+function Index(definition)
+ return setmetatable(definition, index_mt);
+end
+
+function table_mt:__tostring()
+ local s = { 'name="'..self.__table__.name..'"' }
+ for i,col in ipairs(self.__table__) do
+ s[#s+1] = tostring(col);
+ end
+ return 'Table{ '..t_concat(s, ", ")..' }'
+end
+table_mt.__index = {};
+function table_mt.__index:create(engine)
+ return engine:_create_table(self);
+end
+function table_mt:__call(...)
+ -- TODO
+end
+function column_mt:__tostring()
+ return 'Column{ name="'..self.name..'", type="'..self.type..'" }'
+end
+function index_mt:__tostring()
+ local s = 'Index{ name="'..self.name..'"';
+ for i=1,#self do s = s..', "'..self[i]:gsub("[\\\"]", "\\%1")..'"'; end
+ return s..' }';
+-- return 'Index{ name="'..self.name..'", type="'..self.type..'" }'
+end
+--
+
+local function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end)); end
+local function parse_url(url)
+ local scheme, secondpart, database = url:match("^([%w%+]+)://([^/]*)/?(.*)");
+ assert(scheme, "Invalid URL format");
+ local username, password, host, port;
+ local authpart, hostpart = secondpart:match("([^@]+)@([^@+])");
+ if not authpart then hostpart = secondpart; end
+ if authpart then
+ username, password = authpart:match("([^:]*):(.*)");
+ username = username or authpart;
+ password = password and urldecode(password);
+ end
+ if hostpart then
+ host, port = hostpart:match("([^:]*):(.*)");
+ host = host or hostpart;
+ port = port and assert(tonumber(port), "Invalid URL format");
+ end
+ return {
+ scheme = scheme:lower();
+ username = username; password = password;
+ host = host; port = port;
+ database = #database > 0 and database or nil;
+ };
+end
+
+--[[local session = {};
+
+function session.query(...)
+ local rets = {...};
+ local query = setmetatable({ __rets = rets, __filters }, query_mt);
+ return query;
+end
+--
+
+local function db2uri(params)
+ return build_url{
+ scheme = params.driver,
+ user = params.username,
+ password = params.password,
+ host = params.host,
+ port = params.port,
+ path = params.database,
+ };
+end]]
+
+local engine = {};
+function engine:connect()
+ if self.conn then return true; end
+
+ local params = self.params;
+ assert(params.driver, "no driver")
+ local dbh, err = DBI.Connect(
+ params.driver, params.database,
+ params.username, params.password,
+ params.host, params.port
+ );
+ if not dbh then return nil, err; end
+ dbh:autocommit(false); -- don't commit automatically
+ self.conn = dbh;
+ self.prepared = {};
+ return true;
+end
+function engine:execute(sql, ...)
+ local success, err = self:connect();
+ if not success then return success, err; end
+ local prepared = self.prepared;
+
+ local stmt = prepared[sql];
+ if not stmt then
+ local err;
+ stmt, err = self.conn:prepare(sql);
+ if not stmt then return stmt, err; end
+ prepared[sql] = stmt;
+ end
+
+ local success, err = stmt:execute(...);
+ if not success then return success, err; end
+ return stmt;
+end
+
+local result_mt = { __index = {
+ affected = function(self) return self.__stmt:affected(); end;
+ rowcount = function(self) return self.__stmt:rowcount(); end;
+} };
+
+function engine:execute_query(sql, ...)
+ if self.params.driver == "PostgreSQL" then
+ sql = sql:gsub("`", "\"");
+ end
+ local stmt = assert(self.conn:prepare(sql));
+ assert(stmt:execute(...));
+ return stmt:rows();
+end
+function engine:execute_update(sql, ...)
+ if self.params.driver == "PostgreSQL" then
+ sql = sql:gsub("`", "\"");
+ end
+ local prepared = self.prepared;
+ local stmt = prepared[sql];
+ if not stmt then
+ stmt = assert(self.conn:prepare(sql));
+ prepared[sql] = stmt;
+ end
+ assert(stmt:execute(...));
+ return setmetatable({ __stmt = stmt }, result_mt);
+end
+engine.insert = engine.execute_update;
+engine.select = engine.execute_query;
+engine.delete = engine.execute_update;
+engine.update = engine.execute_update;
+function engine:_transaction(func, ...)
+ if not self.conn then
+ local a,b = self:connect();
+ if not a then return a,b; end
+ end
+ --assert(not self.__transaction, "Recursive transactions not allowed");
+ local args, n_args = {...}, select("#", ...);
+ local function f() return func(unpack(args, 1, n_args)); end
+ self.__transaction = true;
+ local success, a, b, c = xpcall(f, debug_traceback);
+ self.__transaction = nil;
+ if success then
+ log("debug", "SQL transaction success [%s]", tostring(func));
+ local ok, err = self.conn:commit();
+ if not ok then return ok, err; end -- commit failed
+ return success, a, b, c;
+ else
+ log("debug", "SQL transaction failure [%s]: %s", tostring(func), a);
+ if self.conn then self.conn:rollback(); end
+ return success, a;
+ end
+end
+function engine:transaction(...)
+ local a,b = self:_transaction(...);
+ if not a then
+ local conn = self.conn;
+ if not conn or not conn:ping() then
+ self.conn = nil;
+ a,b = self:_transaction(...);
+ end
+ end
+ return a,b;
+end
+function engine:_create_index(index)
+ local sql = "CREATE INDEX `"..index.name.."` ON `"..index.table.."` (";
+ for i=1,#index do
+ sql = sql.."`"..index[i].."`";
+ if i ~= #index then sql = sql..", "; end
+ end
+ sql = sql..");"
+ if self.params.driver == "PostgreSQL" then
+ sql = sql:gsub("`", "\"");
+ elseif self.params.driver == "MySQL" then
+ sql = sql:gsub("`([,)])", "`(20)%1");
+ end
+ --print(sql);
+ return self:execute(sql);
+end
+function engine:_create_table(table)
+ local sql = "CREATE TABLE `"..table.name.."` (";
+ for i,col in ipairs(table.c) do
+ sql = sql.."`"..col.name.."` "..col.type;
+ if col.nullable == false then sql = sql.." NOT NULL"; end
+ if i ~= #table.c then sql = sql..", "; end
+ end
+ sql = sql.. ");"
+ if self.params.driver == "PostgreSQL" then
+ sql = sql:gsub("`", "\"");
+ elseif self.params.driver == "MySQL" then
+ sql = sql:gsub(";$", " CHARACTER SET 'utf8' COLLATE 'utf8_bin';");
+ end
+ local success,err = self:execute(sql);
+ if not success then return success,err; end
+ for i,v in ipairs(table.__table__) do
+ if is_index(v) then
+ self:_create_index(v);
+ end
+ end
+ return success;
+end
+local engine_mt = { __index = engine };
+
+local function db2uri(params)
+ return build_url{
+ scheme = params.driver,
+ user = params.username,
+ password = params.password,
+ host = params.host,
+ port = params.port,
+ path = params.database,
+ };
+end
+local engine_cache = {}; -- TODO make weak valued
+function create_engine(self, params)
+ local url = db2uri(params);
+ if not engine_cache[url] then
+ local engine = setmetatable({ url = url, params = params }, engine_mt);
+ engine_cache[url] = engine;
+ end
+ return engine_cache[url];
+end
+
+
+--[[Users = Table {
+ name="users";
+ Column { name="user_id", type=String(), primary_key=true };
+};
+print(Users)
+print(Users.c.user_id)]]
+
+--local engine = create_engine('postgresql://scott:tiger@localhost:5432/mydatabase');
+--[[local engine = create_engine{ driver = "SQLite3", database = "./alchemy.sqlite" };
+
+local i = 0;
+for row in assert(engine:execute("select * from sqlite_master")):rows(true) do
+ i = i+1;
+ print(i);
+ for k,v in pairs(row) do
+ print("",k,v);
+ end
+end
+print("---")
+
+Prosody = Table {
+ name="prosody";
+ Column { name="host", type="TEXT", nullable=false };
+ Column { name="user", type="TEXT", nullable=false };
+ Column { name="store", type="TEXT", nullable=false };
+ Column { name="key", type="TEXT", nullable=false };
+ Column { name="type", type="TEXT", nullable=false };
+ Column { name="value", type="TEXT", nullable=false };
+ Index { name="prosody_index", "host", "user", "store", "key" };
+};
+--print(Prosody);
+assert(engine:transaction(function()
+ assert(Prosody:create(engine));
+end));
+
+for row in assert(engine:execute("select user from prosody")):rows(true) do
+ print("username:", row['username'])
+end
+--result.close();]]
+
+return _M;
diff --git a/util/stanza.lua b/util/stanza.lua
index afaf9ce9..82601e63 100644
--- a/util/stanza.lua
+++ b/util/stanza.lua
@@ -1,29 +1,24 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
local t_insert = table.insert;
-local t_concat = table.concat;
local t_remove = table.remove;
local t_concat = table.concat;
local s_format = string.format;
local s_match = string.match;
local tostring = tostring;
local setmetatable = setmetatable;
-local getmetatable = getmetatable;
local pairs = pairs;
local ipairs = ipairs;
local type = type;
-local next = next;
-local print = print;
-local unpack = unpack;
local s_gsub = string.gsub;
-local s_char = string.char;
+local s_sub = string.sub;
local s_find = string.find;
local os = os;
@@ -44,11 +39,13 @@ module "stanza"
stanza_mt = { __type = "stanza" };
stanza_mt.__index = stanza_mt;
+local stanza_mt = stanza_mt;
function stanza(name, attr)
local stanza = { name = name, attr = attr or {}, tags = {} };
return setmetatable(stanza, stanza_mt);
end
+local stanza = stanza;
function stanza_mt:query(xmlns)
return self:tag("query", { xmlns = xmlns });
@@ -102,12 +99,20 @@ function stanza_mt:get_child(name, xmlns)
if (not name or child.name == name)
and ((not xmlns and self.attr.xmlns == child.attr.xmlns)
or child.attr.xmlns == xmlns) then
-
+
return child;
end
end
end
+function stanza_mt:get_child_text(name, xmlns)
+ local tag = self:get_child(name, xmlns);
+ if tag then
+ return tag:get_text();
+ end
+ return nil;
+end
+
function stanza_mt:child_with_name(name)
for _, child in ipairs(self.tags) do
if child.name == name then return child; end
@@ -128,37 +133,28 @@ function stanza_mt:children()
end, self, i;
end
-function stanza_mt:matching_tags(name, xmlns)
- xmlns = xmlns or self.attr.xmlns;
+function stanza_mt:childtags(name, xmlns)
local tags = self.tags;
local start_i, max_i = 1, #tags;
return function ()
- for i=start_i,max_i do
- v = tags[i];
- if (not name or v.name == name)
- and (not xmlns or xmlns == v.attr.xmlns) then
- start_i = i+1;
- return v;
- end
+ for i = start_i, max_i do
+ local v = tags[i];
+ if (not name or v.name == name)
+ and ((not xmlns and self.attr.xmlns == v.attr.xmlns)
+ or v.attr.xmlns == xmlns) then
+ start_i = i+1;
+ return v;
end
- end, tags, i;
-end
-
-function stanza_mt:childtags()
- local i = 0;
- return function (a)
- i = i + 1
- local v = self.tags[i]
- if v then return v; end
- end, self.tags[1], i;
+ end
+ end;
end
function stanza_mt:maptags(callback)
local tags, curr_tag = self.tags, 1;
local n_children, n_tags = #self, #tags;
-
+
local i = 1;
- while curr_tag <= n_tags do
+ while curr_tag <= n_tags and n_tags > 0 do
if self[i] == tags[curr_tag] then
local ret = callback(self[i]);
if ret == nil then
@@ -166,17 +162,44 @@ function stanza_mt:maptags(callback)
t_remove(tags, curr_tag);
n_children = n_children - 1;
n_tags = n_tags - 1;
+ i = i - 1;
+ curr_tag = curr_tag - 1;
else
self[i] = ret;
- tags[i] = ret;
+ tags[curr_tag] = ret;
end
- i = i + 1;
curr_tag = curr_tag + 1;
end
+ i = i + 1;
end
return self;
end
+function stanza_mt:find(path)
+ local pos = 1;
+ local len = #path + 1;
+
+ repeat
+ local xmlns, name, text;
+ local char = s_sub(path, pos, pos);
+ if char == "@" then
+ return self.attr[s_sub(path, pos + 1)];
+ elseif char == "{" then
+ xmlns, pos = s_match(path, "^([^}]+)}()", pos + 1);
+ end
+ name, text, pos = s_match(path, "^([^@/#]*)([/#]?)()", pos);
+ name = name ~= "" and name or nil;
+ if pos == len then
+ if text == "#" then
+ return self:get_child_text(name, xmlns);
+ end
+ return self:get_child(name, xmlns);
+ end
+ self = self:get_child(name, xmlns);
+ until not self
+end
+
+
local xml_escape
do
local escape_table = { ["'"] = "&apos;", ["\""] = "&quot;", ["<"] = "&lt;", [">"] = "&gt;", ["&"] = "&amp;" };
@@ -235,14 +258,14 @@ end
function stanza_mt.get_error(stanza)
local type, condition, text;
-
+
local error_tag = stanza:get_child("error");
if not error_tag then
return nil, nil, nil;
end
type = error_tag.attr.type;
-
- for child in error_tag:childtags() do
+
+ for _, child in ipairs(error_tag.tags) do
if child.attr.xmlns == xmlns_stanzas then
if not text and child.name == "text" then
text = child:get_text();
@@ -257,11 +280,6 @@ function stanza_mt.get_error(stanza)
return type, condition or "undefined-condition", text;
end
-function stanza_mt.__add(s1, s2)
- return s1:add_direct_child(s2);
-end
-
-
do
local id = 0;
function new_id()
@@ -315,28 +333,25 @@ function deserialize(stanza)
stanza.tags = tags;
end
end
-
+
return stanza;
end
-function clone(stanza)
- local lookup_table = {};
- local function _copy(object)
- if type(object) ~= "table" then
- return object;
- elseif lookup_table[object] then
- return lookup_table[object];
+local function _clone(stanza)
+ local attr, tags = {}, {};
+ for k,v in pairs(stanza.attr) do attr[k] = v; end
+ local new = { name = stanza.name, attr = attr, tags = tags };
+ for i=1,#stanza do
+ local child = stanza[i];
+ if child.name then
+ child = _clone(child);
+ t_insert(tags, child);
end
- local new_table = {};
- lookup_table[object] = new_table;
- for index, value in pairs(object) do
- new_table[_copy(index)] = _copy(value);
- end
- return setmetatable(new_table, getmetatable(object));
+ t_insert(new, child);
end
-
- return _copy(stanza)
+ return setmetatable(new, stanza_mt);
end
+clone = _clone;
function message(attr, body)
if not body then
@@ -375,7 +390,7 @@ if do_pretty_printing then
local style_attrv = getstyle("red");
local style_tagname = getstyle("red");
local style_punc = getstyle("magenta");
-
+
local attr_format = " "..getstring(style_attrk, "%s")..getstring(style_punc, "=")..getstring(style_attrv, "'%s'");
local top_tag_format = getstring(style_punc, "<")..getstring(style_tagname, "%s").."%s"..getstring(style_punc, ">");
--local tag_format = getstring(style_punc, "<")..getstring(style_tagname, "%s").."%s"..getstring(style_punc, ">").."%s"..getstring(style_punc, "</")..getstring(style_tagname, "%s")..getstring(style_punc, ">");
@@ -396,7 +411,7 @@ if do_pretty_printing then
end
return s_format(tag_format, t.name, attr_string, children_text, t.name);
end
-
+
function stanza_mt.pretty_top_tag(t)
local attr_string = "";
if t.attr then
diff --git a/util/template.lua b/util/template.lua
index ebd8be14..66d4fca7 100644
--- a/util/template.lua
+++ b/util/template.lua
@@ -1,64 +1,28 @@
-local st = require "util.stanza";
-local lxp = require "lxp";
+local stanza_mt = require "util.stanza".stanza_mt;
local setmetatable = setmetatable;
local pairs = pairs;
local ipairs = ipairs;
local error = error;
local loadstring = loadstring;
local debug = debug;
+local t_remove = table.remove;
+local parse_xml = require "util.xml".parse;
module("template")
-local parse_xml = (function()
- local ns_prefixes = {
- ["http://www.w3.org/XML/1998/namespace"] = "xml";
- };
- local ns_separator = "\1";
- local ns_pattern = "^([^"..ns_separator.."]*)"..ns_separator.."?(.*)$";
- return function(xml)
- local handler = {};
- local stanza = st.stanza("root");
- function handler:StartElement(tagname, attr)
- local curr_ns,name = tagname:match(ns_pattern);
- if name == "" then
- curr_ns, name = "", curr_ns;
- end
- if curr_ns ~= "" then
- attr.xmlns = curr_ns;
- end
- for i=1,#attr do
- local k = attr[i];
- attr[i] = nil;
- local ns, nm = k:match(ns_pattern);
- if nm ~= "" then
- ns = ns_prefixes[ns];
- if ns then
- attr[ns..":"..nm] = attr[k];
- attr[k] = nil;
- end
- end
- end
- stanza:tag(name, attr);
- end
- function handler:CharacterData(data)
- data = data:gsub("^%s*", ""):gsub("%s*$", "");
- stanza:text(data);
- end
- function handler:EndElement(tagname)
- stanza:up();
- end
- local parser = lxp.new(handler, "\1");
- local ok, err, line, col = parser:parse(xml);
- if ok then ok, err, line, col = parser:parse(); end
- --parser:close();
- if ok then
- return stanza.tags[1];
+local function trim_xml(stanza)
+ for i=#stanza,1,-1 do
+ local child = stanza[i];
+ if child.name then
+ trim_xml(child);
else
- return ok, err.." (line "..line..", col "..col..")";
+ child = child:gsub("^%s*", ""):gsub("%s*$", "");
+ stanza[i] = child;
+ if child == "" then t_remove(stanza, i); end
end
- end;
-end)();
+ end
+end
local function create_string_string(str)
str = ("%q"):format(str);
@@ -100,7 +64,6 @@ local function create_clone_string(stanza, lookup, xmlns)
end
return lookup[stanza];
end
-local stanza_mt = st.stanza_mt;
local function create_cloner(stanza, chunkname)
local lookup = {};
local name = create_clone_string(stanza, lookup, "");
@@ -118,6 +81,7 @@ local template_mt = { __tostring = function(t) return t.name end };
local function create_template(templates, text)
local stanza, err = parse_xml(text);
if not stanza then error(err); end
+ trim_xml(stanza);
local info = debug.getinfo(3, "Sl");
info = info and ("template(%s:%d)"):format(info.short_src:match("[^\\/]*$"), info.currentline) or "template(unknown)";
diff --git a/util/termcolours.lua b/util/termcolours.lua
index df204688..ef978364 100644
--- a/util/termcolours.lua
+++ b/util/termcolours.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -9,6 +9,7 @@
local t_concat, t_insert = table.concat, table.insert;
local char, format = string.char, string.format;
+local tonumber = tonumber;
local ipairs = ipairs;
local io_write = io.write;
@@ -34,6 +35,15 @@ local winstylemap = {
["1;31"] = 4+8 -- bold red
}
+local cssmap = {
+ [1] = "font-weight: bold", [2] = "opacity: 0.5", [4] = "text-decoration: underline", [8] = "visibility: hidden",
+ [30] = "color:black", [31] = "color:red", [32]="color:green", [33]="color:#FFD700",
+ [34] = "color:blue", [35] = "color: magenta", [36] = "color:cyan", [37] = "color: white",
+ [40] = "background-color:black", [41] = "background-color:red", [42]="background-color:green",
+ [43]="background-color:yellow", [44] = "background-color:blue", [45] = "background-color: magenta",
+ [46] = "background-color:cyan", [47] = "background-color: white";
+};
+
local fmt_string = char(0x1B).."[%sm%s"..char(0x1B).."[0m";
function getstring(style, text)
if style then
@@ -76,4 +86,17 @@ if windows then
end
end
+local function ansi2css(ansi_codes)
+ if ansi_codes == "0" then return "</span>"; end
+ local css = {};
+ for code in ansi_codes:gmatch("[^;]+") do
+ t_insert(css, cssmap[tonumber(code)]);
+ end
+ return "</span><span style='"..t_concat(css, ";").."'>";
+end
+
+function tohtml(input)
+ return input:gsub("\027%[(.-)m", ansi2css);
+end
+
return _M;
diff --git a/util/throttle.lua b/util/throttle.lua
new file mode 100644
index 00000000..55e1d07b
--- /dev/null
+++ b/util/throttle.lua
@@ -0,0 +1,46 @@
+
+local gettime = require "socket".gettime;
+local setmetatable = setmetatable;
+local floor = math.floor;
+
+module "throttle"
+
+local throttle = {};
+local throttle_mt = { __index = throttle };
+
+function throttle:update()
+ local newt = gettime();
+ local elapsed = newt - self.t;
+ self.t = newt;
+ local balance = floor(self.rate * elapsed) + self.balance;
+ if balance > self.max then
+ self.balance = self.max;
+ else
+ self.balance = balance;
+ end
+ return self.balance;
+end
+
+function throttle:peek(cost)
+ cost = cost or 1;
+ return self.balance >= cost or self:update() >= cost;
+end
+
+function throttle:poll(cost, split)
+ if self:peek(cost) then
+ self.balance = self.balance - cost;
+ return true;
+ else
+ local balance = self.balance;
+ if split then
+ self.balance = 0;
+ end
+ return false, balance, (cost-balance);
+ end
+end
+
+function create(max, period)
+ return setmetatable({ rate = max / period, max = max, t = 0, balance = max }, throttle_mt);
+end
+
+return _M;
diff --git a/util/timer.lua b/util/timer.lua
index 3061da72..0e10e144 100644
--- a/util/timer.lua
+++ b/util/timer.lua
@@ -1,22 +1,17 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
-
-local ns_addtimer = require "net.server".addtimer;
-local event = require "net.server".event;
-local event_base = require "net.server".event_base;
-
+local server = require "net.server";
local math_min = math.min
local math_huge = math.huge
local get_time = require "socket".gettime;
local t_insert = table.insert;
-local t_remove = table.remove;
-local ipairs, pairs = ipairs, pairs;
+local pairs = pairs;
local type = type;
local data = {};
@@ -25,18 +20,21 @@ local new_data = {};
module "timer"
local _add_task;
-if not event then
- function _add_task(delay, func)
+if not server.event then
+ function _add_task(delay, callback)
local current_time = get_time();
delay = delay + current_time;
if delay >= current_time then
- t_insert(new_data, {delay, func});
+ t_insert(new_data, {delay, callback});
else
- func();
+ local r = callback(current_time);
+ if r and type(r) == "number" then
+ return _add_task(r, callback);
+ end
end
end
- ns_addtimer(function()
+ server._addtimer(function()
local current_time = get_time();
if #new_data > 0 then
for _, d in pairs(new_data) do
@@ -44,15 +42,15 @@ if not event then
end
new_data = {};
end
-
+
local next_time = math_huge;
for i, d in pairs(data) do
- local t, func = d[1], d[2];
+ local t, callback = d[1], d[2];
if t <= current_time then
data[i] = nil;
- local r = func(current_time);
+ local r = callback(current_time);
if type(r) == "number" then
- _add_task(r, func);
+ _add_task(r, callback);
next_time = math_min(next_time, r);
end
else
@@ -62,11 +60,14 @@ if not event then
return next_time;
end);
else
+ local event = server.event;
+ local event_base = server.event_base;
local EVENT_LEAVE = (event.core and event.core.LEAVE) or -1;
- function _add_task(delay, func)
+
+ function _add_task(delay, callback)
local event_handle;
event_handle = event_base:addevent(nil, 0, function ()
- local ret = func();
+ local ret = callback(get_time());
if ret then
return 0, ret;
elseif event_handle then
diff --git a/util/uuid.lua b/util/uuid.lua
index 796c8ee4..fc487c72 100644
--- a/util/uuid.lua
+++ b/util/uuid.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
diff --git a/util/watchdog.lua b/util/watchdog.lua
new file mode 100644
index 00000000..bcb2e274
--- /dev/null
+++ b/util/watchdog.lua
@@ -0,0 +1,34 @@
+local timer = require "util.timer";
+local setmetatable = setmetatable;
+local os_time = os.time;
+
+module "watchdog"
+
+local watchdog_methods = {};
+local watchdog_mt = { __index = watchdog_methods };
+
+function new(timeout, callback)
+ local watchdog = setmetatable({ timeout = timeout, last_reset = os_time(), callback = callback }, watchdog_mt);
+ timer.add_task(timeout+1, function (current_time)
+ local last_reset = watchdog.last_reset;
+ if not last_reset then
+ return;
+ end
+ local time_left = (last_reset + timeout) - current_time;
+ if time_left < 0 then
+ return watchdog:callback();
+ end
+ return time_left + 1;
+ end);
+ return watchdog;
+end
+
+function watchdog_methods:reset()
+ self.last_reset = os_time();
+end
+
+function watchdog_methods:cancel()
+ self.last_reset = nil;
+end
+
+return _M;
diff --git a/util/x509.lua b/util/x509.lua
index 11f231a0..19d4ec6d 100644
--- a/util/x509.lua
+++ b/util/x509.lua
@@ -11,8 +11,8 @@
-- IDN libraries complicate that.
--- [TLS-CERTS] - http://tools.ietf.org/html/draft-saintandre-tls-server-id-check-10
--- [XMPP-CORE] - http://tools.ietf.org/html/draft-ietf-xmpp-3920bis-18
+-- [TLS-CERTS] - http://tools.ietf.org/html/rfc6125
+-- [XMPP-CORE] - http://tools.ietf.org/html/rfc6120
-- [SRV-ID] - http://tools.ietf.org/html/rfc4985
-- [IDNA] - http://tools.ietf.org/html/rfc5890
-- [LDAP] - http://tools.ietf.org/html/rfc4519
@@ -21,6 +21,10 @@
local nameprep = require "util.encodings".stringprep.nameprep;
local idna_to_ascii = require "util.encodings".idna.to_ascii;
local log = require "util.logger".init("x509");
+local pairs, ipairs = pairs, ipairs;
+local s_format = string.format;
+local t_insert = table.insert;
+local t_concat = table.concat;
module "x509"
@@ -32,7 +36,7 @@ local oid_dnssrv = "1.3.6.1.5.5.7.8.7"; -- [SRV-ID]
-- Compare a hostname (possibly international) with asserted names
-- extracted from a certificate.
-- This function follows the rules laid out in
--- sections 4.4.1 and 4.4.2 of [TLS-CERTS]
+-- sections 6.4.1 and 6.4.2 of [TLS-CERTS]
--
-- A wildcard ("*") all by itself is allowed only as the left-most label
local function compare_dnsname(host, asserted_names)
@@ -150,7 +154,7 @@ function verify_identity(host, service, cert)
if ext[oid_subjectaltname] then
local sans = ext[oid_subjectaltname];
- -- Per [TLS-CERTS] 4.3, 4.4.4, "a client MUST NOT seek a match for a
+ -- Per [TLS-CERTS] 6.3, 6.4.4, "a client MUST NOT seek a match for a
-- reference identifier if the presented identifiers include a DNS-ID
-- SRV-ID, URI-ID, or any application-specific identifier types"
local had_supported_altnames = false
@@ -183,7 +187,7 @@ function verify_identity(host, service, cert)
-- a dNSName subjectAltName (wildcards may apply for, and receive,
-- cat treats)
--
- -- Per [TLS-CERTS] 1.5, a CN-ID is the Common Name from a cert subject
+ -- Per [TLS-CERTS] 1.8, a CN-ID is the Common Name from a cert subject
-- which has one and only one Common Name
local subject = cert:subject()
local cn = nil
@@ -200,7 +204,7 @@ function verify_identity(host, service, cert)
end
if cn then
- -- Per [TLS-CERTS] 4.4.4, follow the comparison rules for dNSName SANs.
+ -- Per [TLS-CERTS] 6.4.4, follow the comparison rules for dNSName SANs.
return compare_dnsname(host, { cn })
end
diff --git a/util/xml.lua b/util/xml.lua
new file mode 100644
index 00000000..6dbed65d
--- /dev/null
+++ b/util/xml.lua
@@ -0,0 +1,57 @@
+
+local st = require "util.stanza";
+local lxp = require "lxp";
+
+module("xml")
+
+local parse_xml = (function()
+ local ns_prefixes = {
+ ["http://www.w3.org/XML/1998/namespace"] = "xml";
+ };
+ local ns_separator = "\1";
+ local ns_pattern = "^([^"..ns_separator.."]*)"..ns_separator.."?(.*)$";
+ return function(xml)
+ local handler = {};
+ local stanza = st.stanza("root");
+ function handler:StartElement(tagname, attr)
+ local curr_ns,name = tagname:match(ns_pattern);
+ if name == "" then
+ curr_ns, name = "", curr_ns;
+ end
+ if curr_ns ~= "" then
+ attr.xmlns = curr_ns;
+ end
+ for i=1,#attr do
+ local k = attr[i];
+ attr[i] = nil;
+ local ns, nm = k:match(ns_pattern);
+ if nm ~= "" then
+ ns = ns_prefixes[ns];
+ if ns then
+ attr[ns..":"..nm] = attr[k];
+ attr[k] = nil;
+ end
+ end
+ end
+ stanza:tag(name, attr);
+ end
+ function handler:CharacterData(data)
+ stanza:text(data);
+ end
+ function handler:EndElement(tagname)
+ stanza:up();
+ end
+ local parser = lxp.new(handler, "\1");
+ local ok, err, line, col = parser:parse(xml);
+ if ok then ok, err, line, col = parser:parse(); end
+ --parser:close();
+ if ok then
+ return stanza.tags[1];
+ else
+ return ok, err.." (line "..line..", col "..col..")";
+ end
+ end;
+end)();
+
+parse = parse_xml;
+return _M;
diff --git a/util/xmlrpc.lua b/util/xmlrpc.lua
deleted file mode 100644
index 29815b0d..00000000
--- a/util/xmlrpc.lua
+++ /dev/null
@@ -1,182 +0,0 @@
--- Prosody IM
--- Copyright (C) 2008-2010 Matthew Wild
--- Copyright (C) 2008-2010 Waqas Hussain
---
--- This project is MIT/X11 licensed. Please see the
--- COPYING file in the source package for more information.
---
-
-
-local pairs = pairs;
-local type = type;
-local error = error;
-local t_concat = table.concat;
-local t_insert = table.insert;
-local tostring = tostring;
-local tonumber = tonumber;
-local select = select;
-local st = require "util.stanza";
-
-module "xmlrpc"
-
-local _lua_to_xmlrpc;
-local map = {
- table=function(stanza, object)
- stanza:tag("struct");
- for name, value in pairs(object) do
- stanza:tag("member");
- stanza:tag("name"):text(tostring(name)):up();
- stanza:tag("value");
- _lua_to_xmlrpc(stanza, value);
- stanza:up();
- stanza:up();
- end
- stanza:up();
- end;
- boolean=function(stanza, object)
- stanza:tag("boolean"):text(object and "1" or "0"):up();
- end;
- string=function(stanza, object)
- stanza:tag("string"):text(object):up();
- end;
- number=function(stanza, object)
- stanza:tag("int"):text(tostring(object)):up();
- end;
- ["nil"]=function(stanza, object) -- nil extension
- stanza:tag("nil"):up();
- end;
-};
-_lua_to_xmlrpc = function(stanza, object)
- local h = map[type(object)];
- if h then
- h(stanza, object);
- else
- error("Type not supported by XML-RPC: " .. type(object));
- end
-end
-function create_response(object)
- local stanza = st.stanza("methodResponse"):tag("params"):tag("param"):tag("value");
- _lua_to_xmlrpc(stanza, object);
- stanza:up():up():up();
- return stanza;
-end
-function create_error_response(faultCode, faultString)
- local stanza = st.stanza("methodResponse"):tag("fault"):tag("value");
- _lua_to_xmlrpc(stanza, {faultCode=faultCode, faultString=faultString});
- stanza:up():up();
- return stanza;
-end
-
-function create_request(method_name, ...)
- local stanza = st.stanza("methodCall")
- :tag("methodName"):text(method_name):up()
- :tag("params");
- for i=1,select('#', ...) do
- stanza:tag("param"):tag("value");
- _lua_to_xmlrpc(stanza, select(i, ...));
- stanza:up():up();
- end
- stanza:up():up():up();
- return stanza;
-end
-
-local _xmlrpc_to_lua;
-local int_parse = function(stanza)
- if #stanza.tags ~= 0 or #stanza == 0 then error("<"..stanza.name.."> must have a single text child"); end
- local n = tonumber(t_concat(stanza));
- if n then return n; end
- error("Failed to parse content of <"..stanza.name..">");
-end
-local rmap = {
- methodCall=function(stanza)
- if #stanza.tags ~= 2 then error("<methodCall> must have exactly two subtags"); end -- FIXME <params> is optional
- if stanza.tags[1].name ~= "methodName" then error("First <methodCall> child tag must be <methodName>") end
- if stanza.tags[2].name ~= "params" then error("Second <methodCall> child tag must be <params>") end
- return _xmlrpc_to_lua(stanza.tags[1]), _xmlrpc_to_lua(stanza.tags[2]);
- end;
- methodName=function(stanza)
- if #stanza.tags ~= 0 then error("<methodName> must not have any subtags"); end
- if #stanza == 0 then error("<methodName> must have text content"); end
- return t_concat(stanza);
- end;
- params=function(stanza)
- local t = {};
- for _, child in pairs(stanza.tags) do
- if child.name ~= "param" then error("<params> can only have <param> children"); end;
- t_insert(t, _xmlrpc_to_lua(child));
- end
- return t;
- end;
- param=function(stanza)
- if not(#stanza.tags == 1 and stanza.tags[1].name == "value") then error("<param> must have exactly one <value> child"); end
- return _xmlrpc_to_lua(stanza.tags[1]);
- end;
- value=function(stanza)
- if #stanza.tags == 0 then return t_concat(stanza); end
- if #stanza.tags ~= 1 then error("<value> must have a single child"); end
- return _xmlrpc_to_lua(stanza.tags[1]);
- end;
- int=int_parse;
- i4=int_parse;
- double=int_parse;
- boolean=function(stanza)
- if #stanza.tags ~= 0 or #stanza == 0 then error("<boolean> must have a single text child"); end
- local b = t_concat(stanza);
- if b ~= "1" and b ~= "0" then error("Failed to parse content of <boolean>"); end
- return b == "1" and true or false;
- end;
- string=function(stanza)
- if #stanza.tags ~= 0 then error("<string> must have a single text child"); end
- return t_concat(stanza);
- end;
- array=function(stanza)
- if #stanza.tags ~= 1 then error("<array> must have a single <data> child"); end
- return _xmlrpc_to_lua(stanza.tags[1]);
- end;
- data=function(stanza)
- local t = {};
- for _,child in pairs(stanza.tags) do
- if child.name ~= "value" then error("<data> can only have <value> children"); end
- t_insert(t, _xmlrpc_to_lua(child));
- end
- return t;
- end;
- struct=function(stanza)
- local t = {};
- for _,child in pairs(stanza.tags) do
- if child.name ~= "member" then error("<struct> can only have <member> children"); end
- local name, value = _xmlrpc_to_lua(child);
- t[name] = value;
- end
- return t;
- end;
- member=function(stanza)
- if #stanza.tags ~= 2 then error("<member> must have exactly two subtags"); end -- FIXME <params> is optional
- if stanza.tags[1].name ~= "name" then error("First <member> child tag must be <name>") end
- if stanza.tags[2].name ~= "value" then error("Second <member> child tag must be <value>") end
- return _xmlrpc_to_lua(stanza.tags[1]), _xmlrpc_to_lua(stanza.tags[2]);
- end;
- name=function(stanza)
- if #stanza.tags ~= 0 then error("<name> must have a single text child"); end
- local n = t_concat(stanza)
- if tostring(tonumber(n)) == n then n = tonumber(n); end
- return n;
- end;
- ["nil"]=function(stanza) -- nil extension
- return nil;
- end;
-}
-_xmlrpc_to_lua = function(stanza)
- local h = rmap[stanza.name];
- if h then
- return h(stanza);
- else
- error("Unknown element: "..stanza.name);
- end
-end
-function translate_request(stanza)
- if stanza.name ~= "methodCall" then error("XML-RPC requests must have <methodCall> as root element"); end
- return _xmlrpc_to_lua(stanza);
-end
-
-return _M;
diff --git a/util/xmppstream.lua b/util/xmppstream.lua
index cbdadd9b..550170c9 100644
--- a/util/xmppstream.lua
+++ b/util/xmppstream.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -9,21 +9,27 @@
local lxp = require "lxp";
local st = require "util.stanza";
+local stanza_mt = st.stanza_mt;
+local error = error;
local tostring = tostring;
local t_insert = table.insert;
local t_concat = table.concat;
+local t_remove = table.remove;
+local setmetatable = setmetatable;
-local default_log = require "util.logger".init("xmppstream");
-
-local error = error;
+-- COMPAT: w/LuaExpat 1.1.0
+local lxp_supports_doctype = pcall(lxp.new, { StartDoctypeDecl = false });
module "xmppstream"
local new_parser = lxp.new;
-local ns_prefixes = {
- ["http://www.w3.org/XML/1998/namespace"] = "xml";
+local xml_namespace = {
+ ["http://www.w3.org/XML/1998/namespace\1lang"] = "xml:lang";
+ ["http://www.w3.org/XML/1998/namespace\1space"] = "xml:space";
+ ["http://www.w3.org/XML/1998/namespace\1base"] = "xml:base";
+ ["http://www.w3.org/XML/1998/namespace\1id"] = "xml:id";
};
local xmlns_streams = "http://etherx.jabber.org/streams";
@@ -36,29 +42,28 @@ _M.ns_pattern = ns_pattern;
function new_sax_handlers(session, stream_callbacks)
local xml_handlers = {};
-
- local log = session.log or default_log;
-
+
local cb_streamopened = stream_callbacks.streamopened;
local cb_streamclosed = stream_callbacks.streamclosed;
- local cb_error = stream_callbacks.error or function(session, e) error("XML stream error: "..tostring(e)); end;
+ local cb_error = stream_callbacks.error or function(session, e, stanza) error("XML stream error: "..tostring(e)..(stanza and ": "..tostring(stanza) or ""),2); end;
local cb_handlestanza = stream_callbacks.handlestanza;
-
+
local stream_ns = stream_callbacks.stream_ns or xmlns_streams;
local stream_tag = stream_callbacks.stream_tag or "stream";
if stream_ns ~= "" then
stream_tag = stream_ns..ns_separator..stream_tag;
end
local stream_error_tag = stream_ns..ns_separator..(stream_callbacks.error_tag or "error");
-
+
local stream_default_ns = stream_callbacks.default_ns;
-
+
+ local stack = {};
local chardata, stanza = {};
local non_streamns_depth = 0;
function xml_handlers:StartElement(tagname, attr)
if stanza and #chardata > 0 then
-- We have some character data in the buffer
- stanza:text(t_concat(chardata));
+ t_insert(stanza, t_concat(chardata));
chardata = {};
end
local curr_ns,name = tagname:match(ns_pattern);
@@ -70,21 +75,17 @@ function new_sax_handlers(session, stream_callbacks)
attr.xmlns = curr_ns;
non_streamns_depth = non_streamns_depth + 1;
end
-
- -- FIXME !!!!!
+
for i=1,#attr do
local k = attr[i];
attr[i] = nil;
- local ns, nm = k:match(ns_pattern);
- if nm ~= "" then
- ns = ns_prefixes[ns];
- if ns then
- attr[ns..":"..nm] = attr[k];
- attr[k] = nil;
- end
+ local xmlk = xml_namespace[k];
+ if xmlk then
+ attr[xmlk] = attr[k];
+ attr[k] = nil;
end
end
-
+
if not stanza then --if we are not currently inside a stanza
if session.notopen then
if tagname == stream_tag then
@@ -101,10 +102,14 @@ function new_sax_handlers(session, stream_callbacks)
if curr_ns == "jabber:client" and name ~= "iq" and name ~= "presence" and name ~= "message" then
cb_error(session, "invalid-top-level-element");
end
-
- stanza = st.stanza(name, attr);
+
+ stanza = setmetatable({ name = name, attr = attr, tags = {} }, stanza_mt);
else -- we are inside a stanza, so add a tag
- stanza:tag(name, attr);
+ t_insert(stack, stanza);
+ local oldstanza = stanza;
+ stanza = setmetatable({ name = name, attr = attr, tags = {} }, stanza_mt);
+ t_insert(oldstanza, stanza);
+ t_insert(oldstanza.tags, stanza);
end
end
function xml_handlers:CharacterData(data)
@@ -119,12 +124,11 @@ function new_sax_handlers(session, stream_callbacks)
if stanza then
if #chardata > 0 then
-- We have some character data in the buffer
- stanza:text(t_concat(chardata));
+ t_insert(stanza, t_concat(chardata));
chardata = {};
end
-- Complete stanza
- local last_add = stanza.last_add;
- if not last_add or #last_add == 0 then
+ if #stack == 0 then
if tagname ~= stream_error_tag then
cb_handlestanza(session, stanza);
else
@@ -132,33 +136,37 @@ function new_sax_handlers(session, stream_callbacks)
end
stanza = nil;
else
- stanza:up();
+ stanza = t_remove(stack);
end
else
- if tagname == stream_tag then
- if cb_streamclosed then
- cb_streamclosed(session);
- end
- else
- local curr_ns,name = tagname:match(ns_pattern);
- if name == "" then
- curr_ns, name = "", curr_ns;
- end
- cb_error(session, "parse-error", "unexpected-element-close", name);
+ if cb_streamclosed then
+ cb_streamclosed(session);
end
- stanza, chardata = nil, {};
end
end
-
+
+ local function restricted_handler(parser)
+ cb_error(session, "parse-error", "restricted-xml", "Restricted XML, see RFC 6120 section 11.1.");
+ if not parser.stop or not parser:stop() then
+ error("Failed to abort parsing");
+ end
+ end
+
+ if lxp_supports_doctype then
+ xml_handlers.StartDoctypeDecl = restricted_handler;
+ end
+ xml_handlers.Comment = restricted_handler;
+ xml_handlers.ProcessingInstruction = restricted_handler;
+
local function reset()
stanza, chardata = nil, {};
+ stack = {};
end
-
+
local function set_session(stream, new_session)
session = new_session;
- log = new_session.log or default_log;
end
-
+
return xml_handlers, { reset = reset, set_session = set_session };
end