aboutsummaryrefslogtreecommitdiffstats
path: root/util
diff options
context:
space:
mode:
Diffstat (limited to 'util')
-rw-r--r--util/array.lua65
-rw-r--r--util/async.lua158
-rw-r--r--util/cache.lua105
-rw-r--r--util/caps.lua10
-rw-r--r--util/dataforms.lua74
-rw-r--r--util/datamanager.lua107
-rw-r--r--util/datetime.lua22
-rw-r--r--util/debug.lua76
-rw-r--r--util/dependencies.lua52
-rw-r--r--util/events.lua92
-rw-r--r--util/filters.lua36
-rw-r--r--util/helpers.lua37
-rw-r--r--util/hex.lua26
-rw-r--r--util/hmac.lua2
-rw-r--r--util/import.lua2
-rw-r--r--util/interpolation.lua85
-rw-r--r--util/ip.lua39
-rw-r--r--util/iterators.lua46
-rw-r--r--util/jid.lua68
-rw-r--r--util/json.lua11
-rw-r--r--util/logger.lua20
-rw-r--r--util/mercurial.lua34
-rw-r--r--util/multitable.lua13
-rw-r--r--util/openssl.lua34
-rw-r--r--util/paths.lua44
-rw-r--r--util/pluginloader.lua16
-rw-r--r--util/prosodyctl.lua115
-rw-r--r--util/pubsub.lua98
-rw-r--r--util/queue.lua73
-rw-r--r--util/random.lua23
-rw-r--r--util/sasl.lua58
-rw-r--r--util/sasl/anonymous.lua8
-rw-r--r--util/sasl/digest-md5.lua8
-rw-r--r--util/sasl/external.lua27
-rw-r--r--util/sasl/plain.lua8
-rw-r--r--util/sasl/scram.lua156
-rw-r--r--util/sasl_cyrus.lua12
-rw-r--r--util/serialization.lua18
-rw-r--r--util/session.lua65
-rw-r--r--util/set.lua144
-rw-r--r--util/sql.lua221
-rw-r--r--util/sslconfig.lua120
-rw-r--r--util/stanza.lua91
-rw-r--r--util/statistics.lua160
-rw-r--r--util/template.lua2
-rw-r--r--util/termcolours.lua19
-rw-r--r--util/throttle.lua8
-rw-r--r--util/timer.lua12
-rw-r--r--util/uuid.lua36
-rw-r--r--util/watchdog.lua8
-rw-r--r--util/x509.lua40
-rw-r--r--util/xml.lua14
-rw-r--r--util/xmppstream.lua59
53 files changed, 2119 insertions, 758 deletions
diff --git a/util/array.lua b/util/array.lua
index 2d58e7fb..3ddc97f6 100644
--- a/util/array.lua
+++ b/util/array.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -11,8 +11,10 @@ local t_insert, t_sort, t_remove, t_concat
local setmetatable = setmetatable;
local math_random = math.random;
+local math_floor = math.floor;
local pairs, ipairs = pairs, ipairs;
local tostring = tostring;
+local type = type;
local array = {};
local array_base = {};
@@ -35,7 +37,7 @@ setmetatable(array, { __call = new_array });
-- Read-only methods
function array_methods:random()
- return self[math_random(1,#self)];
+ return self[math_random(1, #self)];
end
-- These methods can be called two ways:
@@ -43,7 +45,7 @@ end
-- existing_array:method([params, ...]) -- Transform existing array into result
--
function array_base.map(outa, ina, func)
- for k,v in ipairs(ina) do
+ for k, v in ipairs(ina) do
outa[k] = func(v);
end
return outa;
@@ -52,20 +54,20 @@ end
function array_base.filter(outa, ina, func)
local inplace, start_length = ina == outa, #ina;
local write = 1;
- for read=1,start_length do
+ for read = 1, start_length do
local v = ina[read];
if func(v) then
outa[write] = v;
write = write + 1;
end
end
-
+
if inplace and write <= start_length then
- for i=write,start_length do
+ for i = write, start_length do
outa[i] = nil;
end
end
-
+
return outa;
end
@@ -78,34 +80,44 @@ function array_base.sort(outa, ina, ...)
end
function array_base.pluck(outa, ina, key)
- for i=1,#ina do
+ for i = 1, #ina do
outa[i] = ina[i][key];
end
return outa;
end
+function array_base.reverse(outa, ina)
+ local len = #ina;
+ if ina == outa then
+ local middle = math_floor(len/2);
+ len = len + 1;
+ local o; -- opposite
+ for i = 1, middle do
+ o = len - i;
+ outa[i], outa[o] = outa[o], outa[i];
+ end
+ else
+ local off = len + 1;
+ for i = 1, len do
+ outa[i] = ina[off - i];
+ end
+ end
+ return outa;
+end
+
--- These methods only mutate the array
function array_methods:shuffle(outa, ina)
local len = #self;
- for i=1,#self do
- local r = math_random(i,len);
+ for i = 1, #self do
+ local r = math_random(i, len);
self[i], self[r] = self[r], self[i];
end
return self;
end
-function array_methods:reverse()
- local len = #self-1;
- for i=len,1,-1 do
- self:push(self[i]);
- self:pop(i);
- end
- return self;
-end
-
function array_methods:append(array)
- local len,len2 = #self, #array;
- for i=1,len2 do
+ local len, len2 = #self, #array;
+ for i = 1, len2 do
self[len+i] = array[i];
end
return self;
@@ -116,11 +128,7 @@ function array_methods:push(x)
return self;
end
-function array_methods:pop(x)
- local v = self[x];
- t_remove(self, x);
- return v;
-end
+array_methods.pop = t_remove;
function array_methods:concat(sep)
return t_concat(array.map(self, tostring), sep);
@@ -135,7 +143,7 @@ function array.collect(f, s, var)
local t = {};
while true do
var = f(s, var);
- if var == nil then break; end
+ if var == nil then break; end
t_insert(t, var);
end
return setmetatable(t, array_mt);
@@ -157,7 +165,4 @@ for method, f in pairs(array_base) do
end
end
-_G.array = array;
-module("array");
-
return array;
diff --git a/util/async.lua b/util/async.lua
new file mode 100644
index 00000000..968ec804
--- /dev/null
+++ b/util/async.lua
@@ -0,0 +1,158 @@
+local log = require "util.logger".init("util.async");
+
+local function runner_continue(thread)
+ -- ASSUMPTION: runner is in 'waiting' state (but we don't have the runner to know for sure)
+ if coroutine.status(thread) ~= "suspended" then -- This should suffice
+ return false;
+ end
+ local ok, state, runner = coroutine.resume(thread);
+ if not ok then
+ local level = 0;
+ while debug.getinfo(thread, level, "") do level = level + 1; end
+ ok, runner = debug.getlocal(thread, level-1, 1);
+ local error_handler = runner.watchers.error;
+ if error_handler then error_handler(runner, debug.traceback(thread, state)); end
+ elseif state == "ready" then
+ -- If state is 'ready', it is our responsibility to update runner.state from 'waiting'.
+ -- We also have to :run(), because the queue might have further items that will not be
+ -- processed otherwise. FIXME: It's probably best to do this in a nexttick (0 timer).
+ runner.state = "ready";
+ runner:run();
+ end
+ return true;
+end
+
+local function waiter(num)
+ local thread = coroutine.running();
+ if not thread then
+ error("Not running in an async context, see http://prosody.im/doc/developers/async");
+ end
+ num = num or 1;
+ local waiting;
+ return function ()
+ if num == 0 then return; end -- already done
+ waiting = true;
+ coroutine.yield("wait");
+ end, function ()
+ num = num - 1;
+ if num == 0 and waiting then
+ runner_continue(thread);
+ elseif num < 0 then
+ error("done() called too many times");
+ end
+ end;
+end
+
+local function guarder()
+ local guards = {};
+ return function (id, func)
+ local thread = coroutine.running();
+ if not thread then
+ error("Not running in an async context, see http://prosody.im/doc/developers/async");
+ end
+ local guard = guards[id];
+ if not guard then
+ guard = {};
+ guards[id] = guard;
+ log("debug", "New guard!");
+ else
+ table.insert(guard, thread);
+ log("debug", "Guarded. %d threads waiting.", #guard)
+ coroutine.yield("wait");
+ end
+ local function exit()
+ local next_waiting = table.remove(guard, 1);
+ if next_waiting then
+ log("debug", "guard: Executing next waiting thread (%d left)", #guard)
+ runner_continue(next_waiting);
+ else
+ log("debug", "Guard off duty.")
+ guards[id] = nil;
+ end
+ end
+ if func then
+ func();
+ exit();
+ return;
+ end
+ return exit;
+ end;
+end
+
+local runner_mt = {};
+runner_mt.__index = runner_mt;
+
+local function runner_create_thread(func, self)
+ local thread = coroutine.create(function (self)
+ while true do
+ func(coroutine.yield("ready", self));
+ end
+ end);
+ assert(coroutine.resume(thread, self)); -- Start it up, it will return instantly to wait for the first input
+ return thread;
+end
+
+local empty_watchers = {};
+local function runner(func, watchers, data)
+ return setmetatable({ func = func, thread = false, state = "ready", notified_state = "ready",
+ queue = {}, watchers = watchers or empty_watchers, data = data }
+ , runner_mt);
+end
+
+function runner_mt:run(input)
+ if input ~= nil then
+ table.insert(self.queue, input);
+ end
+ if self.state ~= "ready" then
+ return true, self.state, #self.queue;
+ end
+
+ local q, thread = self.queue, self.thread;
+ if not thread or coroutine.status(thread) == "dead" then
+ thread = runner_create_thread(self.func, self);
+ self.thread = thread;
+ end
+
+ local n, state, err = #q, self.state, nil;
+ self.state = "running";
+ while n > 0 and state == "ready" do
+ local consumed;
+ for i = 1,n do
+ local input = q[i];
+ local ok, new_state = coroutine.resume(thread, input);
+ if not ok then
+ consumed, state, err = i, "ready", debug.traceback(thread, new_state);
+ self.thread = nil;
+ break;
+ elseif new_state == "wait" then
+ consumed, state = i, "waiting";
+ break;
+ end
+ end
+ if not consumed then consumed = n; end
+ if q[n+1] ~= nil then
+ n = #q;
+ end
+ for i = 1, n do
+ q[i] = q[consumed+i];
+ end
+ n = #q;
+ end
+ self.state = state;
+ if err or state ~= self.notified_state then
+ if err then
+ state = "error"
+ else
+ self.notified_state = state;
+ end
+ local handler = self.watchers[state];
+ if handler then handler(self, err); end
+ end
+ return true, state, n;
+end
+
+function runner_mt:enqueue(input)
+ table.insert(self.queue, input);
+end
+
+return { waiter = waiter, guarder = guarder, runner = runner };
diff --git a/util/cache.lua b/util/cache.lua
new file mode 100644
index 00000000..d3639b3f
--- /dev/null
+++ b/util/cache.lua
@@ -0,0 +1,105 @@
+
+local function _remove(list, m)
+ if m.prev then
+ m.prev.next = m.next;
+ end
+ if m.next then
+ m.next.prev = m.prev;
+ end
+ if list._tail == m then
+ list._tail = m.prev;
+ end
+ if list._head == m then
+ list._head = m.next;
+ end
+ list._count = list._count - 1;
+end
+
+local function _insert(list, m)
+ if list._head then
+ list._head.prev = m;
+ end
+ m.prev, m.next = nil, list._head;
+ list._head = m;
+ if not list._tail then
+ list._tail = m;
+ end
+ list._count = list._count + 1;
+end
+
+local cache_methods = {};
+local cache_mt = { __index = cache_methods };
+
+function cache_methods:set(k, v)
+ local m = self._data[k];
+ if m then
+ -- Key already exists
+ if v ~= nil then
+ -- Bump to head of list
+ _remove(self, m);
+ _insert(self, m);
+ m.value = v;
+ else
+ -- Remove from list
+ _remove(self, m);
+ self._data[k] = nil;
+ end
+ return true;
+ end
+ -- New key
+ if v == nil then
+ return true;
+ end
+ -- Check whether we need to remove oldest k/v
+ local on_evict, evicted_key, evicted_value;
+ if self._count == self.size then
+ local tail = self._tail;
+ on_evict, evicted_key, evicted_value = self._on_evict, tail.key, tail.value;
+ _remove(self, tail);
+ self._data[evicted_key] = nil;
+ end
+
+ m = { key = k, value = v, prev = nil, next = nil };
+ self._data[k] = m;
+ _insert(self, m);
+ if on_evict and evicted_key then
+ on_evict(evicted_key, evicted_value, self);
+ end
+ return true;
+end
+
+function cache_methods:get(k)
+ local m = self._data[k];
+ if m then
+ return m.value;
+ end
+ return nil;
+end
+
+function cache_methods:items()
+ local m = self._head;
+ return function ()
+ if not m then
+ return;
+ end
+ local k, v = m.key, m.value;
+ m = m.next;
+ return k, v;
+ end
+end
+
+function cache_methods:count()
+ return self._count;
+end
+
+local function new(size, on_evict)
+ size = assert(tonumber(size), "cache size must be a number");
+ size = math.floor(size);
+ assert(size > 0, "cache size must be greater than zero");
+ local data = {};
+ return setmetatable({ _data = data, _count = 0, size = size, _head = nil, _tail = nil, _on_evict = on_evict }, cache_mt);
+end
+
+return {
+ new = new;
+}
diff --git a/util/caps.lua b/util/caps.lua
index a61e7403..cd5ff9c0 100644
--- a/util/caps.lua
+++ b/util/caps.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -12,9 +12,9 @@ local sha1 = require "util.hashes".sha1;
local t_insert, t_sort, t_concat = table.insert, table.sort, table.concat;
local ipairs = ipairs;
-module "caps"
+local _ENV = nil;
-function calculate_hash(disco_info)
+local function calculate_hash(disco_info)
local identities, features, extensions = {}, {}, {};
for _, tag in ipairs(disco_info) do
if tag.name == "identity" then
@@ -58,4 +58,6 @@ function calculate_hash(disco_info)
return ver, S;
end
-return _M;
+return {
+ calculate_hash = calculate_hash;
+};
diff --git a/util/dataforms.lua b/util/dataforms.lua
index ee37157a..79b4d1a4 100644
--- a/util/dataforms.lua
+++ b/util/dataforms.lua
@@ -1,26 +1,26 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
local setmetatable = setmetatable;
-local pairs, ipairs = pairs, ipairs;
+local ipairs = ipairs;
local tostring, type, next = tostring, type, next;
local t_concat = table.concat;
local st = require "util.stanza";
local jid_prep = require "util.jid".prep;
-module "dataforms"
+local _ENV = nil;
local xmlns_forms = 'jabber:x:data';
local form_t = {};
local form_mt = { __index = form_t };
-function new(layout)
+local function new(layout)
return setmetatable(layout, form_mt);
end
@@ -32,13 +32,13 @@ function form_t.form(layout, data, formtype)
if layout.instructions then
form:tag("instructions"):text(layout.instructions):up();
end
- for n, field in ipairs(layout) do
+ for _, field in ipairs(layout) do
local field_type = field.type or "text-single";
-- Add field tag
form:tag("field", { type = field_type, var = field.name, label = field.label });
local value = (data and data[field.name]) or field.value;
-
+
if value then
-- Add value, depending on type
if field_type == "hidden" then
@@ -102,11 +102,11 @@ function form_t.form(layout, data, formtype)
end
form:up();
end
-
+
if field.required then
form:tag("required"):up();
end
-
+
-- Jump back up to list of fields
form:up();
end
@@ -118,6 +118,7 @@ local field_readers = {};
function form_t.data(layout, stanza)
local data = {};
local errors = {};
+ local present = {};
for _, field in ipairs(layout) do
local tag;
@@ -133,6 +134,7 @@ function form_t.data(layout, stanza)
errors[field.name] = "Required value missing";
end
else
+ present[field.name] = true;
local reader = field_readers[field.type];
if reader then
data[field.name], errors[field.name] = reader(tag, field.required);
@@ -140,35 +142,34 @@ function form_t.data(layout, stanza)
end
end
if next(errors) then
- return data, errors;
+ return data, errors, present;
end
- return data;
+ return data, nil, present;
end
-field_readers["text-single"] =
- function (field_tag, required)
- local data = field_tag:get_child_text("value");
- if data and #data > 0 then
- return data
- elseif required then
- return nil, "Required value missing";
- end
+local function simple_text(field_tag, required)
+ local data = field_tag:get_child_text("value");
+ -- XEP-0004 does not say if an empty string is acceptable for a required value
+ -- so we will follow HTML5 which says that empty string means missing
+ if required and (data == nil or data == "") then
+ return nil, "Required value missing";
end
+ return data; -- Return whatever get_child_text returned, even if empty string
+end
-field_readers["text-private"] =
- field_readers["text-single"];
+field_readers["text-single"] = simple_text;
+
+field_readers["text-private"] = simple_text;
field_readers["jid-single"] =
function (field_tag, required)
- local raw_data = field_tag:get_child_text("value")
+ local raw_data, err = simple_text(field_tag, required);
+ if not raw_data then return raw_data, err; end
local data = jid_prep(raw_data);
- if data and #data > 0 then
- return data
- elseif raw_data then
+ if not data then
return nil, "Invalid JID: " .. raw_data;
- elseif required then
- return nil, "Required value missing";
end
+ return data;
end
field_readers["jid-multi"] =
@@ -212,8 +213,7 @@ field_readers["text-multi"] =
return data, err;
end
-field_readers["list-single"] =
- field_readers["text-single"];
+field_readers["list-single"] = simple_text;
local boolean_values = {
["1"] = true, ["true"] = true,
@@ -222,15 +222,13 @@ local boolean_values = {
field_readers["boolean"] =
function (field_tag, required)
- local raw_value = field_tag:get_child_text("value");
- local value = boolean_values[raw_value ~= nil and raw_value];
- if value ~= nil then
- return value;
- elseif raw_value then
- return nil, "Invalid boolean representation";
- elseif required then
- return nil, "Required value missing";
+ local raw_value, err = simple_text(field_tag, required);
+ if not raw_value then return raw_value, err; end
+ local value = boolean_values[raw_value];
+ if value == nil then
+ return nil, "Invalid boolean representation:" .. raw_value;
end
+ return value;
end
field_readers["hidden"] =
@@ -238,7 +236,9 @@ field_readers["hidden"] =
return field_tag:get_child_text("value");
end
-return _M;
+return {
+ new = new;
+};
--[=[
diff --git a/util/datamanager.lua b/util/datamanager.lua
index a107d95c..83f3dd13 100644
--- a/util/datamanager.lua
+++ b/util/datamanager.lua
@@ -43,7 +43,7 @@ pcall(function()
fallocate = pposix.fallocate or fallocate;
end);
-module "datamanager"
+local _ENV = nil;
---- utils -----
local encode, decode;
@@ -74,7 +74,7 @@ local callbacks = {};
------- API -------------
-function set_data_path(path)
+local function set_data_path(path)
log("debug", "Setting data path to: %s", path);
data_path = path;
end
@@ -87,14 +87,14 @@ local function callback(username, host, datastore, data)
return username, host, datastore, data;
end
-function add_callback(func)
+local function add_callback(func)
if not callbacks[func] then -- Would you really want to set the same callback more than once?
callbacks[func] = true;
callbacks[#callbacks+1] = func;
return true;
end
end
-function remove_callback(func)
+local function remove_callback(func)
if callbacks[func] then
for i, f in ipairs(callbacks) do
if f == func then
@@ -106,7 +106,7 @@ function remove_callback(func)
end
end
-function getpath(username, host, datastore, ext, create)
+local function getpath(username, host, datastore, ext, create)
ext = ext or "dat";
host = (host and encode(host)) or "_global";
username = username and encode(username);
@@ -119,7 +119,7 @@ function getpath(username, host, datastore, ext, create)
end
end
-function load(username, host, datastore)
+local function load(username, host, datastore)
local data, ret = envloadfile(getpath(username, host, datastore), {});
if not data then
local mode = lfs.attributes(getpath(username, host, datastore), "mode");
@@ -175,7 +175,7 @@ if prosody and prosody.platform ~= "posix" then
end
end
-function store(username, host, datastore, data)
+local function store(username, host, datastore, data)
if not data then
data = {};
end
@@ -209,33 +209,62 @@ function store(username, host, datastore, data)
return true;
end
-function list_append(username, host, datastore, data)
- if not data then return; end
- if callback(username, host, datastore) == false then return true; end
- -- save the datastore
- local f, msg = io_open(getpath(username, host, datastore, "list", true), "r+");
- if not f then
- f, msg = io_open(getpath(username, host, datastore, "list", true), "w");
- end
+-- Append a blob of data to a file
+local function append(username, host, datastore, ext, data)
+ if type(data) ~= "string" then return; end
+ local filename = getpath(username, host, datastore, ext, true);
+
+ local ok;
+ local f, msg = io_open(filename, "r+");
if not f then
- log("error", "Unable to write to %s storage ('%s') for user: %s@%s", datastore, msg, username or "nil", host or "nil");
- return;
+ -- File did probably not exist, let's create it
+ f, msg = io_open(filename, "w");
+ if not f then
+ return nil, msg, "open";
+ end
end
- local data = "item(" .. serialize(data) .. ");\n";
+
local pos = f:seek("end");
- local ok, msg = fallocate(f, pos, #data);
- f:seek("set", pos);
- if ok then
- f:write(data);
- else
+ ok, msg = fallocate(f, pos, #data);
+ if not ok then
+ log("warn", "fallocate() failed: %s", tostring(msg));
+ -- This doesn't work on every file system
+ end
+
+ if f:seek() ~= pos then
+ log("debug", "fallocate() changed file position");
+ f:seek("set", pos);
+ end
+
+ ok, msg = f:write(data);
+ if not ok then
+ f:close();
+ return ok, msg, "write";
+ end
+
+ ok, msg = f:close();
+ if not ok then
+ return ok, msg;
+ end
+
+ return true, pos;
+end
+
+local function list_append(username, host, datastore, data)
+ if not data then return; end
+ if callback(username, host, datastore) == false then return true; end
+ -- save the datastore
+
+ data = "item(" .. serialize(data) .. ");\n";
+ local ok, msg = append(username, host, datastore, "list", data);
+ if not ok then
log("error", "Unable to write to %s storage ('%s') for user: %s@%s", datastore, msg, username or "nil", host or "nil");
return ok, msg;
end
- f:close();
return true;
end
-function list_store(username, host, datastore, data)
+local function list_store(username, host, datastore, data)
if not data then
data = {};
end
@@ -259,7 +288,7 @@ function list_store(username, host, datastore, data)
return true;
end
-function list_load(username, host, datastore)
+local function list_load(username, host, datastore)
local items = {};
local data, ret = envloadfile(getpath(username, host, datastore, "list"), {item = function(i) t_insert(items, i); end});
if not data then
@@ -287,7 +316,7 @@ local type_map = {
list = "list";
}
-function users(host, store, typ)
+local function users(host, store, typ)
typ = type_map[typ or "keyval"];
local store_dir = format("%s/%s/%s", data_path, encode(host), store);
@@ -306,7 +335,7 @@ function users(host, store, typ)
end, state;
end
-function stores(username, host, typ)
+local function stores(username, host, typ)
typ = type_map[typ or "keyval"];
local store_dir = format("%s/%s/", data_path, encode(host));
@@ -346,7 +375,7 @@ local function do_remove(path)
return true
end
-function purge(username, host)
+local function purge(username, host)
local host_dir = format("%s/%s/", data_path, encode(host));
local ok, iter, state, var = pcall(lfs.dir, host_dir);
if not ok then
@@ -366,6 +395,20 @@ function purge(username, host)
return #errs == 0, t_concat(errs, ", ");
end
-_M.path_decode = decode;
-_M.path_encode = encode;
-return _M;
+return {
+ set_data_path = set_data_path;
+ add_callback = add_callback;
+ remove_callback = remove_callback;
+ getpath = getpath;
+ load = load;
+ store = store;
+ append_raw = append;
+ list_append = list_append;
+ list_store = list_store;
+ list_load = list_load;
+ users = users;
+ stores = stores;
+ purge = purge;
+ path_decode = decode;
+ path_encode = encode;
+};
diff --git a/util/datetime.lua b/util/datetime.lua
index a1f62a48..27f28ace 100644
--- a/util/datetime.lua
+++ b/util/datetime.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -15,25 +15,25 @@ local os_difftime = os.difftime;
local error = error;
local tonumber = tonumber;
-module "datetime"
+local _ENV = nil;
-function date(t)
+local function date(t)
return os_date("!%Y-%m-%d", t);
end
-function datetime(t)
+local function datetime(t)
return os_date("!%Y-%m-%dT%H:%M:%SZ", t);
end
-function time(t)
+local function time(t)
return os_date("!%H:%M:%S", t);
end
-function legacy(t)
+local function legacy(t)
return os_date("!%Y%m%dT%H:%M:%S", t);
end
-function parse(s)
+local function parse(s)
if s then
local year, month, day, hour, min, sec, tzd;
year, month, day, hour, min, sec, tzd = s:match("^(%d%d%d%d)%-?(%d%d)%-?(%d%d)T(%d%d):(%d%d):(%d%d)%.?%d*([Z+%-]?.*)$");
@@ -54,4 +54,10 @@ function parse(s)
end
end
-return _M;
+return {
+ date = date;
+ datetime = datetime;
+ time = time;
+ legacy = legacy;
+ parse = parse;
+};
diff --git a/util/debug.lua b/util/debug.lua
index bff0e347..a78524a9 100644
--- a/util/debug.lua
+++ b/util/debug.lua
@@ -13,7 +13,7 @@ local termcolours = require "util.termcolours";
local getstring = termcolours.getstring;
local styles;
do
- _ = termcolours.getstyle;
+ local _ = termcolours.getstyle;
styles = {
boundary_padding = _("bright");
filename = _("bright", "blue");
@@ -22,20 +22,23 @@ do
location = _("yellow");
};
end
-module("debugx", package.seeall);
-function get_locals_table(level)
- level = level + 1; -- Skip this function itself
+local function get_locals_table(thread, level)
local locals = {};
for local_num = 1, math.huge do
- local name, value = debug.getlocal(level, local_num);
+ local name, value;
+ if thread then
+ name, value = debug.getlocal(thread, level, local_num);
+ else
+ name, value = debug.getlocal(level+1, local_num);
+ end
if not name then break; end
table.insert(locals, { name = name, value = value });
end
return locals;
end
-function get_upvalues_table(func)
+local function get_upvalues_table(func)
local upvalues = {};
if func then
for upvalue_num = 1, math.huge do
@@ -47,7 +50,7 @@ function get_upvalues_table(func)
return upvalues;
end
-function string_from_var_table(var_table, max_line_len, indent_str)
+local function string_from_var_table(var_table, max_line_len, indent_str)
local var_string = {};
local col_pos = 0;
max_line_len = max_line_len or math.huge;
@@ -83,33 +86,25 @@ function string_from_var_table(var_table, max_line_len, indent_str)
end
end
-function get_traceback_table(thread, start_level)
+local function get_traceback_table(thread, start_level)
local levels = {};
for level = start_level, math.huge do
local info;
if thread then
- info = debug.getinfo(thread, level+1);
+ info = debug.getinfo(thread, level);
else
info = debug.getinfo(level+1);
end
if not info then break; end
-
+
levels[(level-start_level)+1] = {
level = level;
info = info;
- locals = get_locals_table(level+1);
+ locals = get_locals_table(thread, level+(thread and 0 or 1));
upvalues = get_upvalues_table(info.func);
};
- end
- return levels;
-end
-
-function traceback(...)
- local ok, ret = pcall(_traceback, ...);
- if not ok then
- return "Error in error handling: "..ret;
end
- return ret;
+ return levels;
end
local function build_source_boundary_marker(last_source_desc)
@@ -117,7 +112,7 @@ local function build_source_boundary_marker(last_source_desc)
return getstring(styles.boundary_padding, "v"..padding).." "..getstring(styles.filename, last_source_desc).." "..getstring(styles.boundary_padding, padding..(#last_source_desc%2==0 and "-v" or "v "));
end
-function _traceback(thread, message, level)
+local function _traceback(thread, message, level)
-- Lua manual says: debug.traceback ([thread,] [message [, level]])
-- I fathom this to mean one of:
@@ -134,15 +129,15 @@ function _traceback(thread, message, level)
return nil; -- debug.traceback() does this
end
- level = level or 1;
+ level = level or 0;
message = message and (message.."\n") or "";
-
- -- +3 counts for this function, and the pcall() and wrapper above us
- local levels = get_traceback_table(thread, level+3);
-
+
+ -- +3 counts for this function, and the pcall() and wrapper above us, the +1... I don't know.
+ local levels = get_traceback_table(thread, level+(thread == nil and 4 or 0));
+
local last_source_desc;
-
+
local lines = {};
for nlevel, level in ipairs(levels) do
local info = level.info;
@@ -171,9 +166,11 @@ function _traceback(thread, message, level)
nlevel = nlevel-1;
table.insert(lines, "\t"..(nlevel==0 and ">" or " ")..getstring(styles.level_num, "("..nlevel..") ")..line);
local npadding = (" "):rep(#tostring(nlevel));
- local locals_str = string_from_var_table(level.locals, optimal_line_length, "\t "..npadding);
- if locals_str then
- table.insert(lines, "\t "..npadding.."Locals: "..locals_str);
+ if level.locals then
+ local locals_str = string_from_var_table(level.locals, optimal_line_length, "\t "..npadding);
+ if locals_str then
+ table.insert(lines, "\t "..npadding.."Locals: "..locals_str);
+ end
end
local upvalues_str = string_from_var_table(level.upvalues, optimal_line_length, "\t "..npadding);
if upvalues_str then
@@ -186,8 +183,23 @@ function _traceback(thread, message, level)
return message.."stack traceback:\n"..table.concat(lines, "\n");
end
-function use()
+local function traceback(...)
+ local ok, ret = pcall(_traceback, ...);
+ if not ok then
+ return "Error in error handling: "..ret;
+ end
+ return ret;
+end
+
+local function use()
debug.traceback = traceback;
end
-return _M;
+return {
+ get_locals_table = get_locals_table;
+ get_upvalues_table = get_upvalues_table;
+ string_from_var_table = string_from_var_table;
+ get_traceback_table = get_traceback_table;
+ traceback = traceback;
+ use = use;
+};
diff --git a/util/dependencies.lua b/util/dependencies.lua
index 4d50cf63..9ab40765 100644
--- a/util/dependencies.lua
+++ b/util/dependencies.lua
@@ -1,21 +1,19 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
-module("dependencies", package.seeall)
-
-function softreq(...) local ok, lib = pcall(require, ...); if ok then return lib; else return nil, lib; end end
+local function softreq(...) local ok, lib = pcall(require, ...); if ok then return lib; else return nil, lib; end end
-- Required to be able to find packages installed with luarocks
if not softreq "luarocks.loader" then -- LuaRocks 2.x
softreq "luarocks.require"; -- LuaRocks <1.x
end
-function missingdep(name, sources, msg)
+local function missingdep(name, sources, msg)
print("");
print("**************************");
print("Prosody was unable to find "..tostring(name));
@@ -35,7 +33,7 @@ function missingdep(name, sources, msg)
print("");
end
--- COMPAT w/pre-0.8 Debian: The Debian config file used to use
+-- COMPAT w/pre-0.8 Debian: The Debian config file used to use
-- util.ztact, which has been removed from Prosody in 0.8. This
-- is to log an error for people who still use it, so they can
-- update their configs.
@@ -48,19 +46,19 @@ package.preload["util.ztact"] = function ()
end
end;
-function check_dependencies()
- if _VERSION ~= "Lua 5.1" then
+local function check_dependencies()
+ if _VERSION < "Lua 5.1" then
print "***********************************"
print("Unsupported Lua version: ".._VERSION);
- print("Only Lua 5.1 is supported.");
+ print("At least Lua 5.1 is required.");
print "***********************************"
return false;
end
local fatal;
-
+
local lxp = softreq "lxp"
-
+
if not lxp then
missingdep("luaexpat", {
["Debian/Ubuntu"] = "sudo apt-get install liblua5.1-expat0";
@@ -69,9 +67,9 @@ function check_dependencies()
});
fatal = true;
end
-
+
local socket = softreq "socket"
-
+
if not socket then
missingdep("luasocket", {
["Debian/Ubuntu"] = "sudo apt-get install liblua5.1-socket2";
@@ -80,7 +78,7 @@ function check_dependencies()
});
fatal = true;
end
-
+
local lfs, err = softreq "lfs"
if not lfs then
missingdep("luafilesystem", {
@@ -90,9 +88,9 @@ function check_dependencies()
});
fatal = true;
end
-
+
local ssl = softreq "ssl"
-
+
if not ssl then
missingdep("LuaSec", {
["Debian/Ubuntu"] = "http://prosody.im/download/start#debian_and_ubuntu";
@@ -100,7 +98,7 @@ function check_dependencies()
["Source"] = "http://www.inf.puc-rio.br/~brunoos/luasec/";
}, "SSL/TLS support will not be available");
end
-
+
local encodings, err = softreq "util.encodings"
if not encodings then
if err:match("not found") then
@@ -137,22 +135,27 @@ function check_dependencies()
return not fatal;
end
-function log_warnings()
+local function log_warnings()
+ if _VERSION > "Lua 5.1" then
+ prosody.log("warn", "Support for %s is experimental, please report any issues", _VERSION);
+ end
+ local ssl = softreq"ssl";
if ssl then
local major, minor, veryminor, patched = ssl._VERSION:match("(%d+)%.(%d+)%.?(%d*)(M?)");
if not major or ((tonumber(major) == 0 and (tonumber(minor) or 0) <= 3 and (tonumber(veryminor) or 0) <= 2) and patched ~= "M") then
- log("error", "This version of LuaSec contains a known bug that causes disconnects, see http://prosody.im/doc/depends");
+ prosody.log("error", "This version of LuaSec contains a known bug that causes disconnects, see http://prosody.im/doc/depends");
end
end
+ local lxp = softreq"lxp";
if lxp then
if not pcall(lxp.new, { StartDoctypeDecl = false }) then
- log("error", "The version of LuaExpat on your system leaves Prosody "
+ prosody.log("error", "The version of LuaExpat on your system leaves Prosody "
.."vulnerable to denial-of-service attacks. You should upgrade to "
.."LuaExpat 1.3.0 or higher as soon as possible. See "
.."http://prosody.im/doc/depends#luaexpat for more information.");
end
if not lxp.new({}).getcurrentbytecount then
- log("error", "The version of LuaExpat on your system does not support "
+ prosody.log("error", "The version of LuaExpat on your system does not support "
.."stanza size limits, which may leave servers on untrusted "
.."networks (e.g. the internet) vulnerable to denial-of-service "
.."attacks. You should upgrade to LuaExpat 1.3.0 or higher as "
@@ -162,4 +165,9 @@ function log_warnings()
end
end
-return _M;
+return {
+ softreq = softreq;
+ missingdep = missingdep;
+ check_dependencies = check_dependencies;
+ log_warnings = log_warnings;
+};
diff --git a/util/events.lua b/util/events.lua
index 412acccd..073d2a60 100644
--- a/util/events.lua
+++ b/util/events.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -9,15 +9,23 @@
local pairs = pairs;
local t_insert = table.insert;
+local t_remove = table.remove;
local t_sort = table.sort;
local setmetatable = setmetatable;
local next = next;
-module "events"
+local _ENV = nil;
-function new()
+local function new()
+ -- Map event name to ordered list of handlers (lazily built): handlers[event_name] = array_of_handler_functions
local handlers = {};
+ -- Array of wrapper functions that wrap all events (nil if empty)
+ local global_wrappers;
+ -- Per-event wrappers: wrappers[event_name] = wrapper_function
+ local wrappers = {};
+ -- Event map: event_map[handler_function] = priority_number
local event_map = {};
+ -- Called on-demand to build handlers entries
local function _rebuild_index(handlers, event)
local _handlers = event_map[event];
if not _handlers or next(_handlers) == nil then return; end
@@ -50,6 +58,9 @@ function new()
end
end
end;
+ local function get_handlers(event)
+ return handlers[event];
+ end;
local function add_handlers(handlers)
for event, handler in pairs(handlers) do
add_handler(event, handler);
@@ -60,24 +71,91 @@ function new()
remove_handler(event, handler);
end
end;
- local function fire_event(event, ...)
- local h = handlers[event];
+ local function _fire_event(event_name, event_data)
+ local h = handlers[event_name];
if h then
for i=1,#h do
- local ret = h[i](...);
+ local ret = h[i](event_data);
if ret ~= nil then return ret; end
end
end
end;
+ local function fire_event(event_name, event_data)
+ local w = wrappers[event_name] or global_wrappers;
+ if w then
+ local curr_wrapper = #w;
+ local function c(event_name, event_data)
+ curr_wrapper = curr_wrapper - 1;
+ if curr_wrapper == 0 then
+ if global_wrappers == nil or w == global_wrappers then
+ return _fire_event(event_name, event_data);
+ end
+ w, curr_wrapper = global_wrappers, #global_wrappers;
+ return w[curr_wrapper](c, event_name, event_data);
+ else
+ return w[curr_wrapper](c, event_name, event_data);
+ end
+ end
+ return w[curr_wrapper](c, event_name, event_data);
+ end
+ return _fire_event(event_name, event_data);
+ end
+ local function add_wrapper(event_name, wrapper)
+ local w;
+ if event_name == false then
+ w = global_wrappers;
+ if not w then
+ w = {};
+ global_wrappers = w;
+ end
+ else
+ w = wrappers[event_name];
+ if not w then
+ w = {};
+ wrappers[event_name] = w;
+ end
+ end
+ w[#w+1] = wrapper;
+ end
+ local function remove_wrapper(event_name, wrapper)
+ local w;
+ if event_name == false then
+ w = global_wrappers;
+ else
+ w = wrappers[event_name];
+ end
+ if not w then return; end
+ for i = #w, 1 do
+ if w[i] == wrapper then
+ t_remove(w, i);
+ end
+ end
+ if #w == 0 then
+ if event_name == nil then
+ global_wrappers = nil;
+ else
+ wrappers[event_name] = nil;
+ end
+ end
+ end
return {
add_handler = add_handler;
remove_handler = remove_handler;
add_handlers = add_handlers;
remove_handlers = remove_handlers;
+ get_handlers = get_handlers;
+ wrappers = {
+ add_handler = add_wrapper;
+ remove_handler = remove_wrapper;
+ };
+ add_wrapper = add_wrapper;
+ remove_wrapper = remove_wrapper;
fire_event = fire_event;
_handlers = handlers;
_event_map = event_map;
};
end
-return _M;
+return {
+ new = new;
+};
diff --git a/util/filters.lua b/util/filters.lua
index 6290e53b..f405c0bd 100644
--- a/util/filters.lua
+++ b/util/filters.lua
@@ -1,22 +1,22 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
local t_insert, t_remove = table.insert, table.remove;
-module "filters"
+local _ENV = nil;
local new_filter_hooks = {};
-function initialize(session)
+local function initialize(session)
if not session.filters then
local filters = {};
session.filters = filters;
-
+
function session.filter(type, data)
local filter_list = filters[type];
if filter_list then
@@ -28,19 +28,19 @@ function initialize(session)
return data;
end
end
-
+
for i=1,#new_filter_hooks do
new_filter_hooks[i](session);
end
-
+
return session.filter;
end
-function add_filter(session, type, callback, priority)
+local function add_filter(session, type, callback, priority)
if not session.filters then
initialize(session);
end
-
+
local filter_list = session.filters[type];
if not filter_list then
filter_list = {};
@@ -48,19 +48,19 @@ function add_filter(session, type, callback, priority)
elseif filter_list[callback] then
return; -- Filter already added
end
-
+
priority = priority or 0;
-
+
local i = 0;
repeat
i = i + 1;
until not filter_list[i] or filter_list[filter_list[i]] < priority;
-
+
t_insert(filter_list, i, callback);
filter_list[callback] = priority;
end
-function remove_filter(session, type, callback)
+local function remove_filter(session, type, callback)
if not session.filters then return; end
local filter_list = session.filters[type];
if filter_list and filter_list[callback] then
@@ -74,11 +74,11 @@ function remove_filter(session, type, callback)
end
end
-function add_filter_hook(callback)
+local function add_filter_hook(callback)
t_insert(new_filter_hooks, callback);
end
-function remove_filter_hook(callback)
+local function remove_filter_hook(callback)
for i=1,#new_filter_hooks do
if new_filter_hooks[i] == callback then
t_remove(new_filter_hooks, i);
@@ -86,4 +86,10 @@ function remove_filter_hook(callback)
end
end
-return _M;
+return {
+ initialize = initialize;
+ add_filter = add_filter;
+ remove_filter = remove_filter;
+ add_filter_hook = add_filter_hook;
+ remove_filter_hook = remove_filter_hook;
+};
diff --git a/util/helpers.lua b/util/helpers.lua
index 08b86a7c..bf76d258 100644
--- a/util/helpers.lua
+++ b/util/helpers.lua
@@ -1,28 +1,18 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
local debug = require "util.debug";
-module("helpers", package.seeall);
-
-- Helper functions for debugging
local log = require "util.logger".init("util.debug");
-function log_host_events(host)
- return log_events(prosody.hosts[host].events, host);
-end
-
-function revert_log_host_events(host)
- return revert_log_events(prosody.hosts[host].events);
-end
-
-function log_events(events, name, logger)
+local function log_events(events, name, logger)
local f = events.fire_event;
if not f then
error("Object does not appear to be a util.events object");
@@ -37,11 +27,19 @@ function log_events(events, name, logger)
return events;
end
-function revert_log_events(events)
+local function revert_log_events(events)
events.fire_event, events[events.fire_event] = events[events.fire_event], nil; -- :))
end
-function show_events(events, specific_event)
+local function log_host_events(host)
+ return log_events(prosody.hosts[host].events, host);
+end
+
+local function revert_log_host_events(host)
+ return revert_log_events(prosody.hosts[host].events);
+end
+
+local function show_events(events, specific_event)
local event_handlers = events._handlers;
local events_array = {};
local event_handler_arrays = {};
@@ -70,7 +68,7 @@ function show_events(events, specific_event)
return table.concat(events_array, "\n");
end
-function get_upvalue(f, get_name)
+local function get_upvalue(f, get_name)
local i, name, value = 0;
repeat
i = i + 1;
@@ -79,4 +77,11 @@ function get_upvalue(f, get_name)
return value;
end
-return _M;
+return {
+ log_host_events = log_host_events;
+ revert_log_host_events = revert_log_host_events;
+ log_events = log_events;
+ revert_log_events = revert_log_events;
+ show_events = show_events;
+ get_upvalue = get_upvalue;
+};
diff --git a/util/hex.lua b/util/hex.lua
new file mode 100644
index 00000000..4cc28d33
--- /dev/null
+++ b/util/hex.lua
@@ -0,0 +1,26 @@
+local s_char = string.char;
+local s_format = string.format;
+local s_gsub = string.gsub;
+local s_lower = string.lower;
+
+local char_to_hex = {};
+local hex_to_char = {};
+
+do
+ local char, hex;
+ for i = 0,255 do
+ char, hex = s_char(i), s_format("%02x", i);
+ char_to_hex[char] = hex;
+ hex_to_char[hex] = char;
+ end
+end
+
+local function to(s)
+ return (s_gsub(s, ".", char_to_hex));
+end
+
+local function from(s)
+ return (s_gsub(s_lower(s), "%X*(%x%x)%X*", hex_to_char));
+end
+
+return { to = to, from = from }
diff --git a/util/hmac.lua b/util/hmac.lua
index 51211c7a..2c4cc6ef 100644
--- a/util/hmac.lua
+++ b/util/hmac.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
diff --git a/util/import.lua b/util/import.lua
index 81401e8b..174da0ca 100644
--- a/util/import.lua
+++ b/util/import.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
diff --git a/util/interpolation.lua b/util/interpolation.lua
new file mode 100644
index 00000000..315cc203
--- /dev/null
+++ b/util/interpolation.lua
@@ -0,0 +1,85 @@
+-- Simple template language
+--
+-- The new() function takes a pattern and an escape function and returns
+-- a render() function. Both are required.
+--
+-- The function render() takes a string template and a table of values.
+-- Sequences like {name} in the template string are substituted
+-- with values from the table, optionally depending on a modifier
+-- symbol.
+--
+-- Variants are:
+-- {name} is substituted for values["name"] and is escaped using the
+-- second argument to new_render(). To disable the escaping, use {name!}.
+-- {name.item} can be used to access table items.
+-- To renter lists of items: {name# item number {idx} is {item} }
+-- Or key-value pairs: {name% t[ {idx} ] = {item} }
+-- To show a defaults for missing values {name? sub-template } can be used,
+-- which renders a sub-template if values["name"] is false-ish.
+-- {name& sub-template } does the opposite, the sub-template is rendered
+-- if the selected value is anything but false or nil.
+
+local type, tostring = type, tostring;
+local pairs, ipairs = pairs, ipairs;
+local s_sub, s_gsub, s_match = string.sub, string.gsub, string.match;
+local t_concat = table.concat;
+
+local function new_render(pat, escape, funcs)
+ -- assert(type(pat) == "string", "bad argument #1 to 'new_render' (string expected)");
+ -- assert(type(escape) == "function", "bad argument #2 to 'new_render' (function expected)");
+ local function render(template, values)
+ -- assert(type(template) == "string", "bad argument #1 to 'render' (string expected)");
+ -- assert(type(values) == "table", "bad argument #2 to 'render' (table expected)");
+ return (s_gsub(template, pat, function (block)
+ block = s_sub(block, 2, -2);
+ local name, opt, e = s_match(block, "^([%a_][%w_.]*)(%p?)()");
+ if not name then return end
+ local value = values[name];
+ if not value and name:find(".", 2, true) then
+ value = values;
+ for word in name:gmatch"[^.]+" do
+ value = value[word];
+ if not value then break; end
+ end
+ end
+ if funcs then
+ while value ~= nil and opt == '|' do
+ local f;
+ f, opt, e = s_match(block, "^([%a_][%w_.]*)(%p?)()", e);
+ f = funcs[f];
+ if f then value = f(value); end
+ end
+ end
+ if opt == '#' or opt == '%' then
+ if type(value) ~= "table" then return ""; end
+ local iter = opt == '#' and ipairs or pairs;
+ local out, i, subtpl = {}, 1, s_sub(block, e);
+ local subvalues = setmetatable({}, { __index = values });
+ for idx, item in iter(value) do
+ subvalues.idx = idx;
+ subvalues.item = item;
+ out[i], i = render(subtpl, subvalues), i+1;
+ end
+ return t_concat(out);
+ elseif opt == '&' then
+ if not value then return ""; end
+ return render(s_sub(block, e), values);
+ elseif opt == '?' and not value then
+ return render(s_sub(block, e), values);
+ elseif value ~= nil then
+ if type(value) ~= "string" then
+ value = tostring(value);
+ end
+ if opt ~= '!' then
+ return escape(value);
+ end
+ return value;
+ end
+ end));
+ end
+ return render;
+end
+
+return {
+ new = new_render;
+};
diff --git a/util/ip.lua b/util/ip.lua
index acfd7f24..ec3b4d7e 100644
--- a/util/ip.lua
+++ b/util/ip.lua
@@ -96,7 +96,7 @@ local function v6scope(ip)
if ip:match("^[0:]*1$") then
return 0x2;
-- Link-local unicast:
- elseif ip:match("^[Ff][Ee][89ABab]") then
+ elseif ip:match("^[Ff][Ee][89ABab]") then
return 0x2;
-- Site-local unicast:
elseif ip:match("^[Ff][Ee][CcDdEeFf]") then
@@ -206,5 +206,40 @@ function ip_methods:scope()
return value;
end
+function ip_methods:private()
+ local private = self.scope ~= 0xE;
+ if not private and self.proto == "IPv4" then
+ local ip = self.addr;
+ local fields = {};
+ ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end);
+ if fields[1] == 127 or fields[1] == 10 or (fields[1] == 192 and fields[2] == 168)
+ or (fields[1] == 172 and (fields[2] >= 16 or fields[2] <= 32)) then
+ private = true;
+ end
+ end
+ self.private = private;
+ return private;
+end
+
+local function parse_cidr(cidr)
+ local bits;
+ local ip_len = cidr:find("/", 1, true);
+ if ip_len then
+ bits = tonumber(cidr:sub(ip_len+1, -1));
+ cidr = cidr:sub(1, ip_len-1);
+ end
+ return new_ip(cidr), bits;
+end
+
+local function match(ipA, ipB, bits)
+ local common_bits = commonPrefixLength(ipA, ipB);
+ if bits and ipB.proto == "IPv4" then
+ common_bits = common_bits - 96; -- v6 mapped addresses always share these bits
+ end
+ return common_bits >= (bits or 128);
+end
+
return {new_ip = new_ip,
- commonPrefixLength = commonPrefixLength};
+ commonPrefixLength = commonPrefixLength,
+ parse_cidr = parse_cidr,
+ match=match};
diff --git a/util/iterators.lua b/util/iterators.lua
index 1f6aacb8..aa9c3ec0 100644
--- a/util/iterators.lua
+++ b/util/iterators.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -10,6 +10,10 @@
local it = {};
+local t_insert = table.insert;
+local select, unpack, next = select, unpack, next;
+local function pack(...) return { n = select("#", ...), ... }; end
+
-- Reverse an iterator
function it.reverse(f, s, var)
local results = {};
@@ -19,9 +23,9 @@ function it.reverse(f, s, var)
local ret = { f(s, var) };
var = ret[1];
if var == nil then break; end
- table.insert(results, 1, ret);
+ t_insert(results, 1, ret);
end
-
+
-- Then return our reverse one
local i,max = 0, #results;
return function (results)
@@ -52,15 +56,15 @@ end
-- Given an iterator, iterate only over unique items
function it.unique(f, s, var)
local set = {};
-
+
return function ()
while true do
- local ret = { f(s, var) };
+ local ret = pack(f(s, var));
var = ret[1];
if var == nil then break; end
if not set[var] then
set[var] = true;
- return var;
+ return unpack(ret, 1, ret.n);
end
end
end;
@@ -69,14 +73,13 @@ end
--[[ Return the number of items an iterator returns ]]--
function it.count(f, s, var)
local x = 0;
-
+
while true do
- local ret = { f(s, var) };
- var = ret[1];
+ var = f(s, var);
if var == nil then break; end
x = x + 1;
end
-
+
return x;
end
@@ -104,7 +107,7 @@ end
function it.tail(n, f, s, var)
local results, count = {}, 0;
while true do
- local ret = { f(s, var) };
+ local ret = pack(f(s, var));
var = ret[1];
if var == nil then break; end
results[(count%n)+1] = ret;
@@ -117,9 +120,24 @@ function it.tail(n, f, s, var)
return function ()
pos = pos + 1;
if pos > n then return nil; end
- return unpack(results[((count-1+pos)%n)+1]);
+ local ret = results[((count-1+pos)%n)+1];
+ return unpack(ret, 1, ret.n);
end
- --return reverse(head(n, reverse(f, s, var)));
+ --return reverse(head(n, reverse(f, s, var))); -- !
+end
+
+function it.filter(filter, f, s, var)
+ if type(filter) ~= "function" then
+ local filter_value = filter;
+ function filter(x) return x ~= filter_value; end
+ end
+ return function (s, var)
+ local ret;
+ repeat ret = pack(f(s, var));
+ var = ret[1];
+ until var == nil or filter(unpack(ret, 1, ret.n));
+ return unpack(ret, 1, ret.n);
+ end, s, var;
end
local function _ripairs_iter(t, key) if key > 1 then return key-1, t[key-1]; end end
@@ -139,7 +157,7 @@ function it.to_array(f, s, var)
while true do
var = f(s, var);
if var == nil then break; end
- table.insert(t, var);
+ t_insert(t, var);
end
return t;
end
diff --git a/util/jid.lua b/util/jid.lua
index 8e0a784c..76155ac7 100644
--- a/util/jid.lua
+++ b/util/jid.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -23,9 +23,9 @@ local escapes = {
local unescapes = {};
for k,v in pairs(escapes) do unescapes[v] = k; end
-module "jid"
+local _ENV = nil;
-local function _split(jid)
+local function split(jid)
if not jid then return; end
local node, nodepos = match(jid, "^([^@/]+)@()");
local host, hostpos = match(jid, "^([^@/]+)()", nodepos)
@@ -34,18 +34,17 @@ local function _split(jid)
if (not host) or ((not resource) and #jid >= hostpos) then return nil, nil, nil; end
return node, host, resource;
end
-split = _split;
-function bare(jid)
- local node, host = _split(jid);
+local function bare(jid)
+ local node, host = split(jid);
if node and host then
return node.."@"..host;
end
return host;
end
-local function _prepped_split(jid)
- local node, host, resource = _split(jid);
+local function prepped_split(jid)
+ local node, host, resource = split(jid);
if host then
if sub(host, -1, -1) == "." then -- Strip empty root label
host = sub(host, 1, -2);
@@ -63,39 +62,29 @@ local function _prepped_split(jid)
return node, host, resource;
end
end
-prepped_split = _prepped_split;
-
-function prep(jid)
- local node, host, resource = _prepped_split(jid);
- if host then
- if node then
- host = node .. "@" .. host;
- end
- if resource then
- host = host .. "/" .. resource;
- end
- end
- return host;
-end
-function join(node, host, resource)
- if node and host and resource then
+local function join(node, host, resource)
+ if not host then return end
+ if node and resource then
return node.."@"..host.."/"..resource;
- elseif node and host then
+ elseif node then
return node.."@"..host;
- elseif host and resource then
+ elseif resource then
return host.."/"..resource;
- elseif host then
- return host;
end
- return nil; -- Invalid JID
+ return host;
end
-function compare(jid, acl)
+local function prep(jid)
+ local node, host, resource = prepped_split(jid);
+ return join(node, host, resource);
+end
+
+local function compare(jid, acl)
-- compare jid to single acl rule
-- TODO compare to table of rules?
- local jid_node, jid_host, jid_resource = _split(jid);
- local acl_node, acl_host, acl_resource = _split(acl);
+ local jid_node, jid_host, jid_resource = split(jid);
+ local acl_node, acl_host, acl_resource = split(acl);
if ((acl_node ~= nil and acl_node == jid_node) or acl_node == nil) and
((acl_host ~= nil and acl_host == jid_host) or acl_host == nil) and
((acl_resource ~= nil and acl_resource == jid_resource) or acl_resource == nil) then
@@ -104,7 +93,16 @@ function compare(jid, acl)
return false
end
-function escape(s) return s and (s:gsub(".", escapes)); end
-function unescape(s) return s and (s:gsub("\\%x%x", unescapes)); end
+local function escape(s) return s and (s:gsub(".", escapes)); end
+local function unescape(s) return s and (s:gsub("\\%x%x", unescapes)); end
-return _M;
+return {
+ split = split;
+ bare = bare;
+ prepped_split = prepped_split;
+ join = join;
+ prep = prep;
+ compare = compare;
+ escape = escape;
+ unescape = unescape;
+};
diff --git a/util/json.lua b/util/json.lua
index 82ebcc43..becd295d 100644
--- a/util/json.lua
+++ b/util/json.lua
@@ -13,7 +13,7 @@ local tostring, tonumber = tostring, tonumber;
local pairs, ipairs = pairs, ipairs;
local next = next;
local error = error;
-local newproxy, getmetatable, setmetatable = newproxy, getmetatable, setmetatable;
+local getmetatable, setmetatable = getmetatable, setmetatable;
local print = print;
local has_array, array = pcall(require, "util.array");
@@ -22,10 +22,7 @@ local array_mt = has_array and getmetatable(array()) or {};
--module("json")
local json = {};
-local null = newproxy and newproxy(true) or {};
-if getmetatable and getmetatable(null) then
- getmetatable(null).__tostring = function() return "null"; end;
-end
+local null = setmetatable({}, { __tostring = function() return "null"; end; });
json.null = null;
local escapes = {
@@ -348,9 +345,9 @@ local first_escape = {
function json.decode(json)
json = json:gsub("\\.", first_escape) -- get rid of all escapes except \uXXXX, making string parsing much simpler
--:gsub("[\r\n]", "\t"); -- \r\n\t are equivalent, we care about none of them, and none of them can be in strings
-
+
-- TODO do encoding verification
-
+
local val, index = _readvalue(json, 1);
if val == nil then return val, index; end
if json:find("[^ \t\r\n]", index) then return nil, "garbage at eof"; end
diff --git a/util/logger.lua b/util/logger.lua
index 26206d4d..3d1f1c8b 100644
--- a/util/logger.lua
+++ b/util/logger.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -11,13 +11,13 @@ local pcall = pcall;
local find = string.find;
local ipairs, pairs, setmetatable = ipairs, pairs, setmetatable;
-module "logger"
+local _ENV = nil;
local level_sinks = {};
local make_logger;
-function init(name)
+local function init(name)
local log_debug = make_logger(name, "debug");
local log_info = make_logger(name, "info");
local log_warn = make_logger(name, "warn");
@@ -52,7 +52,7 @@ function make_logger(source_name, level)
return logger;
end
-function reset()
+local function reset()
for level, handler_list in pairs(level_sinks) do
-- Clear all handlers for this level
for i = 1, #handler_list do
@@ -61,7 +61,7 @@ function reset()
end
end
-function add_level_sink(level, sink_function)
+local function add_level_sink(level, sink_function)
if not level_sinks[level] then
level_sinks[level] = { sink_function };
else
@@ -69,6 +69,10 @@ function add_level_sink(level, sink_function)
end
end
-_M.new = make_logger;
-
-return _M;
+return {
+ init = init;
+ make_logger = make_logger;
+ reset = reset;
+ add_level_sink = add_level_sink;
+ new = make_logger;
+};
diff --git a/util/mercurial.lua b/util/mercurial.lua
new file mode 100644
index 00000000..3f75c4c1
--- /dev/null
+++ b/util/mercurial.lua
@@ -0,0 +1,34 @@
+
+local lfs = require"lfs";
+
+local hg = { };
+
+function hg.check_id(path)
+ if lfs.attributes(path, 'mode') ~= "directory" then
+ return nil, "not a directory";
+ end
+ local hg_dirstate = io.open(path.."/.hg/dirstate");
+ local hgid, hgrepo
+ if hg_dirstate then
+ hgid = ("%02x%02x%02x%02x%02x%02x"):format(hg_dirstate:read(6):byte(1, 6));
+ hg_dirstate:close();
+ local hg_changelog = io.open(path.."/.hg/store/00changelog.i");
+ if hg_changelog then
+ hg_changelog:seek("set", 0x20);
+ hgrepo = ("%02x%02x%02x%02x%02x%02x"):format(hg_changelog:read(6):byte(1, 6));
+ hg_changelog:close();
+ end
+ else
+ local hg_archival,e = io.open(path.."/.hg_archival.txt");
+ if hg_archival then
+ local repo = hg_archival:read("*l");
+ local node = hg_archival:read("*l");
+ hg_archival:close()
+ hgid = node and node:match("^node: (%x%x%x%x%x%x%x%x%x%x%x%x)")
+ hgrepo = repo and repo:match("^repo: (%x%x%x%x%x%x%x%x%x%x%x%x)")
+ end
+ end
+ return hgid, hgrepo;
+end
+
+return hg;
diff --git a/util/multitable.lua b/util/multitable.lua
index dbf34d28..7a2d2b2a 100644
--- a/util/multitable.lua
+++ b/util/multitable.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -10,7 +10,7 @@ local select = select;
local t_insert = table.insert;
local unpack, pairs, next, type = unpack, pairs, next, type;
-module "multitable"
+local _ENV = nil;
local function get(self, ...)
local t = self.data;
@@ -126,7 +126,7 @@ local function search_add(self, results, ...)
return results;
end
-function iter(self, ...)
+local function iter(self, ...)
local query = { ... };
local maxdepth = select("#", ...);
local stack = { self.data };
@@ -161,7 +161,7 @@ function iter(self, ...)
return it, self;
end
-function new()
+local function new()
return {
data = {};
get = get;
@@ -174,4 +174,7 @@ function new()
};
end
-return _M;
+return {
+ iter = iter;
+ new = new;
+};
diff --git a/util/openssl.lua b/util/openssl.lua
index 39fe99d6..12e49eac 100644
--- a/util/openssl.lua
+++ b/util/openssl.lua
@@ -12,7 +12,7 @@ local config = {};
_M.config = config;
local ssl_config = {};
-local ssl_config_mt = {__index=ssl_config};
+local ssl_config_mt = { __index = ssl_config };
function config.new()
return setmetatable({
@@ -65,12 +65,12 @@ function ssl_config:serialize()
s = s .. ("[%s]\n"):format(k);
if k == "subject_alternative_name" then
for san, n in pairs(t) do
- for i = 1,#n do
+ for i = 1, #n do
s = s .. s_format("%s.%d = %s\n", san, i -1, n[i]);
end
end
elseif k == "distinguished_name" then
- for i=1,#DN_order do
+ for i=1, #DN_order do
local k = DN_order[i]
local v = t[k];
if v then
@@ -107,7 +107,7 @@ end
function ssl_config:add_sRVName(host, service)
t_insert(self.subject_alternative_name.otherName,
- s_format("%s;%s", oid_dnssrv, ia5string("_" .. service .."." .. idna_to_ascii(host))));
+ s_format("%s;%s", oid_dnssrv, ia5string("_" .. service .. "." .. idna_to_ascii(host))));
end
function ssl_config:add_xmppAddr(host)
@@ -118,10 +118,10 @@ end
function ssl_config:from_prosody(hosts, config, certhosts)
-- TODO Decide if this should go elsewhere
local found_matching_hosts = false;
- for i = 1,#certhosts do
+ for i = 1, #certhosts do
local certhost = certhosts[i];
for name in pairs(hosts) do
- if name == certhost or name:sub(-1-#certhost) == "."..certhost then
+ if name == certhost or name:sub(-1-#certhost) == "." .. certhost then
found_matching_hosts = true;
self:add_dNSName(name);
--print(name .. "#component_module: " .. (config.get(name, "component_module") or "nil"));
@@ -144,30 +144,30 @@ end
do -- Lua to shell calls.
local function shell_escape(s)
- return s:gsub("'",[['\'']]);
+ return "'" .. tostring(s):gsub("'",[['\'']]) .. "'";
end
- local function serialize(f,o)
- local r = {"openssl", f};
- for k,v in pairs(o) do
+ local function serialize(command, args)
+ local commandline = { "openssl", command };
+ for k, v in pairs(args) do
if type(k) == "string" then
- t_insert(r, ("-%s"):format(k));
+ t_insert(commandline, ("-%s"):format(k));
if v ~= true then
- t_insert(r, ("'%s'"):format(shell_escape(tostring(v))));
+ t_insert(commandline, shell_escape(v));
end
end
end
- for _,v in ipairs(o) do
- t_insert(r, ("'%s'"):format(shell_escape(tostring(v))));
+ for _, v in ipairs(args) do
+ t_insert(commandline, shell_escape(v));
end
- return t_concat(r, " ");
+ return t_concat(commandline, " ");
end
local os_execute = os.execute;
setmetatable(_M, {
- __index=function(_,f)
+ __index = function(_, command)
return function(opts)
- return 0 == os_execute(serialize(f, type(opts) == "table" and opts or {}));
+ return 0 == os_execute(serialize(command, type(opts) == "table" and opts or {}));
end;
end;
});
diff --git a/util/paths.lua b/util/paths.lua
new file mode 100644
index 00000000..89f4cad9
--- /dev/null
+++ b/util/paths.lua
@@ -0,0 +1,44 @@
+local t_concat = table.concat;
+
+local path_sep = package.config:sub(1,1);
+
+local path_util = {}
+
+-- Helper function to resolve relative paths (needed by config)
+function path_util.resolve_relative_path(parent_path, path)
+ if path then
+ -- Some normalization
+ parent_path = parent_path:gsub("%"..path_sep.."+$", "");
+ path = path:gsub("^%.%"..path_sep.."+", "");
+
+ local is_relative;
+ if path_sep == "/" and path:sub(1,1) ~= "/" then
+ is_relative = true;
+ elseif path_sep == "\\" and (path:sub(1,1) ~= "/" and (path:sub(2,3) ~= ":\\" and path:sub(2,3) ~= ":/")) then
+ is_relative = true;
+ end
+ if is_relative then
+ return parent_path..path_sep..path;
+ end
+ end
+ return path;
+end
+
+-- Helper function to convert a glob to a Lua pattern
+function path_util.glob_to_pattern(glob)
+ return "^"..glob:gsub("[%p*?]", function (c)
+ if c == "*" then
+ return ".*";
+ elseif c == "?" then
+ return ".";
+ else
+ return "%"..c;
+ end
+ end).."$";
+end
+
+function path_util.join(...)
+ return t_concat({...}, path_sep);
+end
+
+return path_util;
diff --git a/util/pluginloader.lua b/util/pluginloader.lua
index 112c0d52..0d7eafa7 100644
--- a/util/pluginloader.lua
+++ b/util/pluginloader.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -17,9 +17,7 @@ end
local io_open = io.open;
local envload = require "util.envload".envload;
-module "pluginloader"
-
-function load_file(names)
+local function load_file(names)
local file, err, path;
for i=1,#plugin_dir do
for j=1,#names do
@@ -35,7 +33,7 @@ function load_file(names)
return file, err;
end
-function load_resource(plugin, resource)
+local function load_resource(plugin, resource)
resource = resource or "mod_"..plugin..".lua";
local names = {
@@ -48,7 +46,7 @@ function load_resource(plugin, resource)
return load_file(names);
end
-function load_code(plugin, resource, env)
+local function load_code(plugin, resource, env)
local content, err = load_resource(plugin, resource);
if not content then return content, err; end
local path = err;
@@ -57,4 +55,8 @@ function load_code(plugin, resource, env)
return f, path;
end
-return _M;
+return {
+ load_file = load_file;
+ load_resource = load_resource;
+ load_code = load_code;
+};
diff --git a/util/prosodyctl.lua b/util/prosodyctl.lua
index c6fe1986..f8f28644 100644
--- a/util/prosodyctl.lua
+++ b/util/prosodyctl.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -29,25 +29,19 @@ local CFG_SOURCEDIR = _G.CFG_SOURCEDIR;
local _G = _G;
local prosody = prosody;
-module "prosodyctl"
-
-- UI helpers
-function show_message(msg, ...)
- print(msg:format(...));
-end
-
-function show_warning(msg, ...)
+local function show_message(msg, ...)
print(msg:format(...));
end
-function show_usage(usage, desc)
+local function show_usage(usage, desc)
print("Usage: ".._G.arg[0].." "..usage);
if desc then
print(" "..desc);
end
end
-function getchar(n)
+local function getchar(n)
local stty_ret = os.execute("stty raw -echo 2>/dev/null");
local ok, char;
if stty_ret == 0 then
@@ -64,14 +58,14 @@ function getchar(n)
end
end
-function getline()
+local function getline()
local ok, line = pcall(io.read, "*l");
if ok then
return line;
end
end
-function getpass()
+local function getpass()
local stty_ret = os.execute("stty -echo 2>/dev/null");
if stty_ret ~= 0 then
io.write("\027[08m"); -- ANSI 'hidden' text attribute
@@ -88,7 +82,7 @@ function getpass()
end
end
-function show_yesno(prompt)
+local function show_yesno(prompt)
io.write(prompt, " ");
local choice = getchar():lower();
io.write("\n");
@@ -99,7 +93,7 @@ function show_yesno(prompt)
return (choice == "y");
end
-function read_password()
+local function read_password()
local password;
while true do
io.write("Enter new password: ");
@@ -120,7 +114,7 @@ function read_password()
return password;
end
-function show_prompt(prompt)
+local function show_prompt(prompt)
io.write(prompt, " ");
local line = getline();
line = line and line:gsub("\n$","");
@@ -128,7 +122,7 @@ function show_prompt(prompt)
end
-- Server control
-function adduser(params)
+local function adduser(params)
local user, host, password = nodeprep(params.user), nameprep(params.host), params.password;
if not user then
return false, "invalid-username";
@@ -146,15 +140,15 @@ function adduser(params)
if not(provider) or provider.name == "null" then
usermanager.initialize_host(host);
end
-
+
local ok, errmsg = usermanager.create_user(user, password, host);
if not ok then
- return false, errmsg;
+ return false, errmsg or "creating-user-failed";
end
return true;
end
-function user_exists(params)
+local function user_exists(params)
local user, host, password = nodeprep(params.user), nameprep(params.host), params.password;
storagemanager.initialize_host(host);
@@ -162,28 +156,28 @@ function user_exists(params)
if not(provider) or provider.name == "null" then
usermanager.initialize_host(host);
end
-
+
return usermanager.user_exists(user, host);
end
-function passwd(params)
- if not _M.user_exists(params) then
+local function passwd(params)
+ if not user_exists(params) then
return false, "no-such-user";
end
-
- return _M.adduser(params);
+
+ return adduser(params);
end
-function deluser(params)
- if not _M.user_exists(params) then
+local function deluser(params)
+ if not user_exists(params) then
return false, "no-such-user";
end
local user, host = nodeprep(params.user), nameprep(params.host);
-
+
return usermanager.delete_user(user, host);
end
-function getpid()
+local function getpid()
local pidfile = config.get("*", "pidfile");
if not pidfile then
return false, "no-pidfile";
@@ -192,35 +186,35 @@ function getpid()
if type(pidfile) ~= "string" then
return false, "invalid-pidfile";
end
-
- local modules_enabled = set.new(config.get("*", "modules_enabled"));
- if not modules_enabled:contains("posix") then
+
+ local modules_enabled = set.new(config.get("*", "modules_disabled"));
+ if prosody.platform ~= "posix" or modules_enabled:contains("posix") then
return false, "no-posix";
end
-
+
local file, err = io.open(pidfile, "r+");
if not file then
return false, "pidfile-read-failed", err;
end
-
+
local locked, err = lfs.lock(file, "w");
if locked then
file:close();
return false, "pidfile-not-locked";
end
-
+
local pid = tonumber(file:read("*a"));
file:close();
-
+
if not pid then
return false, "invalid-pid";
end
-
+
return true, pid;
end
-function isrunning()
- local ok, pid, err = _M.getpid();
+local function isrunning()
+ local ok, pid, err = getpid();
if not ok then
if pid == "pidfile-read-failed" or pid == "pidfile-not-locked" then
-- Report as not running, since we can't open the pidfile
@@ -232,8 +226,8 @@ function isrunning()
return true, signal.kill(pid, 0) == 0;
end
-function start()
- local ok, ret = _M.isrunning();
+local function start()
+ local ok, ret = isrunning();
if not ok then
return ok, ret;
end
@@ -248,36 +242,55 @@ function start()
return true;
end
-function stop()
- local ok, ret = _M.isrunning();
+local function stop()
+ local ok, ret = isrunning();
if not ok then
return ok, ret;
end
if not ret then
return false, "not-running";
end
-
- local ok, pid = _M.getpid()
+
+ local ok, pid = getpid()
if not ok then return false, pid; end
-
+
signal.kill(pid, signal.SIGTERM);
return true;
end
-function reload()
- local ok, ret = _M.isrunning();
+local function reload()
+ local ok, ret = isrunning();
if not ok then
return ok, ret;
end
if not ret then
return false, "not-running";
end
-
- local ok, pid = _M.getpid()
+
+ local ok, pid = getpid()
if not ok then return false, pid; end
-
+
signal.kill(pid, signal.SIGHUP);
return true;
end
-return _M;
+return {
+ show_message = show_message;
+ show_warning = show_message;
+ show_usage = show_usage;
+ getchar = getchar;
+ getline = getline;
+ getpass = getpass;
+ show_yesno = show_yesno;
+ read_password = read_password;
+ show_prompt = show_prompt;
+ adduser = adduser;
+ user_exists = user_exists;
+ passwd = passwd;
+ deluser = deluser;
+ getpid = getpid;
+ isrunning = isrunning;
+ start = start;
+ stop = stop;
+ reload = reload;
+};
diff --git a/util/pubsub.lua b/util/pubsub.lua
index e1418c62..6d12690a 100644
--- a/util/pubsub.lua
+++ b/util/pubsub.lua
@@ -1,23 +1,27 @@
local events = require "util.events";
-
-module("pubsub", package.seeall);
+local t_remove = table.remove;
local service = {};
local service_mt = { __index = service };
-local default_config = {
+local default_config = { __index = {
broadcaster = function () end;
get_affiliation = function () end;
capabilities = {};
-};
+} };
+local default_node_config = { __index = {
+ ["pubsub#max_items"] = "20";
+} };
-function new(config)
+local function new(config)
config = config or {};
return setmetatable({
- config = setmetatable(config, { __index = default_config });
+ config = setmetatable(config, default_config);
+ node_defaults = setmetatable(config.node_defaults or {}, default_node_config);
affiliations = {};
subscriptions = {};
nodes = {};
+ data = {};
events = events.new();
}, service_mt);
end
@@ -29,13 +33,13 @@ end
function service:may(node, actor, action)
if actor == true then return true; end
-
+
local node_obj = self.nodes[node];
local node_aff = node_obj and node_obj.affiliations[actor];
local service_aff = self.affiliations[actor]
or self.config.get_affiliation(actor, node, action)
or "none";
-
+
-- Check if node allows/forbids it
local node_capabilities = node_obj and node_obj.capabilities;
if node_capabilities then
@@ -47,7 +51,7 @@ function service:may(node, actor, action)
end
end
end
-
+
-- Check service-wide capabilities instead
local service_capabilities = self.config.capabilities;
local caps = service_capabilities[node_aff or service_aff];
@@ -57,7 +61,7 @@ function service:may(node, actor, action)
return can;
end
end
-
+
return false;
end
@@ -202,7 +206,7 @@ function service:get_subscription(node, actor, jid)
return true, node_obj.subscribers[jid];
end
-function service:create(node, actor)
+function service:create(node, actor, options)
-- Access checking
if not self:may(node, actor, "create") then
return false, "forbidden";
@@ -211,17 +215,20 @@ function service:create(node, actor)
if self.nodes[node] then
return false, "conflict";
end
-
+
+ self.data[node] = {};
self.nodes[node] = {
name = node;
subscribers = {};
- config = {};
- data = {};
+ config = setmetatable(options or {}, {__index=self.node_defaults});
affiliations = {};
};
+ setmetatable(self.nodes[node], { __index = { data = self.data[node] } }); -- COMPAT
+ self.events.fire_event("node-created", { node = node, actor = actor });
local ok, err = self:set_affiliation(node, true, actor, "owner");
if not ok then
self.nodes[node] = nil;
+ self.data[node] = nil;
end
return ok, err;
end
@@ -237,10 +244,31 @@ function service:delete(node, actor)
return false, "item-not-found";
end
self.nodes[node] = nil;
+ self.data[node] = nil;
+ self.events.fire_event("node-deleted", { node = node, actor = actor });
self.config.broadcaster("delete", node, node_obj.subscribers);
return true;
end
+local function remove_item_by_id(data, id)
+ if not data[id] then return end
+ data[id] = nil;
+ for i, _id in ipairs(data) do
+ if id == _id then
+ t_remove(data, i);
+ return i;
+ end
+ end
+end
+
+local function trim_items(data, max)
+ max = tonumber(max);
+ if not max or #data <= max then return end
+ repeat
+ data[t_remove(data, 1)] = nil;
+ until #data <= max
+end
+
function service:publish(node, actor, id, item)
-- Access checking
if not self:may(node, actor, "publish") then
@@ -258,9 +286,13 @@ function service:publish(node, actor, id, item)
end
node_obj = self.nodes[node];
end
- node_obj.data[id] = item;
+ local node_data = self.data[node];
+ remove_item_by_id(node_data, id);
+ node_data[#node_data + 1] = id;
+ node_data[id] = item;
+ trim_items(node_data, node_obj.config["pubsub#max_items"]);
self.events.fire_event("item-published", { node = node, actor = actor, id = id, item = item });
- self.config.broadcaster("items", node, node_obj.subscribers, item);
+ self.config.broadcaster("items", node, node_obj.subscribers, item, actor);
return true;
end
@@ -271,10 +303,11 @@ function service:retract(node, actor, id, retract)
end
--
local node_obj = self.nodes[node];
- if (not node_obj) or (not node_obj.data[id]) then
+ if (not node_obj) or (not self.data[node][id]) then
return false, "item-not-found";
end
- node_obj.data[id] = nil;
+ self.events.fire_event("item-retracted", { node = node, actor = actor, id = id });
+ remove_item_by_id(self.data[node], id);
if retract then
self.config.broadcaster("items", node, node_obj.subscribers, retract);
end
@@ -291,7 +324,8 @@ function service:purge(node, actor, notify)
if not node_obj then
return false, "item-not-found";
end
- node_obj.data = {}; -- Purge
+ self.data[node] = {}; -- Purge
+ self.events.fire_event("node-purged", { node = node, actor = actor });
if notify then
self.config.broadcaster("purge", node, node_obj.subscribers);
end
@@ -309,9 +343,9 @@ function service:get_items(node, actor, id)
return false, "item-not-found";
end
if id then -- Restrict results to a single specific item
- return true, { [id] = node_obj.data[id] };
+ return true, { id, [id] = self.data[node][id] };
else
- return true, node_obj.data;
+ return true, self.data[node];
end
end
@@ -388,4 +422,24 @@ function service:set_node_capabilities(node, actor, capabilities)
return true;
end
-return _M;
+function service:set_node_config(node, actor, new_config)
+ if not self:may(node, actor, "configure") then
+ return false, "forbidden";
+ end
+
+ local node_obj = self.nodes[node];
+ if not node_obj then
+ return false, "item-not-found";
+ end
+
+ for k,v in pairs(new_config) do
+ node_obj.config[k] = v;
+ end
+ trim_items(self.data[node], node_obj.config["pubsub#max_items"]);
+
+ return true;
+end
+
+return {
+ new = new;
+};
diff --git a/util/queue.lua b/util/queue.lua
new file mode 100644
index 00000000..728e905f
--- /dev/null
+++ b/util/queue.lua
@@ -0,0 +1,73 @@
+-- Prosody IM
+-- Copyright (C) 2008-2015 Matthew Wild
+-- Copyright (C) 2008-2015 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+-- Small ringbuffer library (i.e. an efficient FIFO queue with a size limit)
+-- (because unbounded dynamically-growing queues are a bad thing...)
+
+local have_utable, utable = pcall(require, "util.table"); -- For pre-allocation of table
+
+local function new(size, allow_wrapping)
+ -- Head is next insert, tail is next read
+ local head, tail = 1, 1;
+ local items = 0; -- Number of stored items
+ local t = have_utable and utable.create(size, 0) or {}; -- Table to hold items
+ --luacheck: ignore 212/self
+ return {
+ _items = t;
+ size = size;
+ count = function (self) return items; end;
+ push = function (self, item)
+ if items >= size then
+ if allow_wrapping then
+ tail = (tail%size)+1; -- Advance to next oldest item
+ items = items - 1;
+ else
+ return nil, "queue full";
+ end
+ end
+ t[head] = item;
+ items = items + 1;
+ head = (head%size)+1;
+ return true;
+ end;
+ pop = function (self)
+ if items == 0 then
+ return nil;
+ end
+ local item;
+ item, t[tail] = t[tail], 0;
+ tail = (tail%size)+1;
+ items = items - 1;
+ return item;
+ end;
+ peek = function (self)
+ if items == 0 then
+ return nil;
+ end
+ return t[tail];
+ end;
+ items = function (self)
+ --luacheck: ignore 431/t
+ return function (t, pos)
+ if pos >= t:count() then
+ return nil;
+ end
+ local read_pos = tail + pos;
+ if read_pos > t.size then
+ read_pos = (read_pos%size);
+ end
+ return pos+1, t._items[read_pos];
+ end, self, 0;
+ end;
+ };
+end
+
+return {
+ new = new;
+};
+
diff --git a/util/random.lua b/util/random.lua
new file mode 100644
index 00000000..4963e98c
--- /dev/null
+++ b/util/random.lua
@@ -0,0 +1,23 @@
+-- Prosody IM
+-- Copyright (C) 2008-2014 Matthew Wild
+-- Copyright (C) 2008-2014 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+local urandom = assert(io.open("/dev/urandom", "r+"));
+
+local function seed(x)
+ urandom:write(x);
+ urandom:flush();
+end
+
+local function bytes(n)
+ return urandom:read(n);
+end
+
+return {
+ seed = seed;
+ bytes = bytes;
+};
diff --git a/util/sasl.lua b/util/sasl.lua
index afb3861b..5845f34a 100644
--- a/util/sasl.lua
+++ b/util/sasl.lua
@@ -19,7 +19,7 @@ local setmetatable = setmetatable;
local assert = assert;
local require = require;
-module "sasl"
+local _ENV = nil;
--[[
Authentication Backend Prototypes:
@@ -27,19 +27,38 @@ Authentication Backend Prototypes:
state = false : disabled
state = true : enabled
state = nil : non-existant
+
+Channel Binding:
+
+To enable support of channel binding in some mechanisms you need to provide appropriate callbacks in a table
+at profile.cb.
+
+Example:
+ profile.cb["tls-unique"] = function(self)
+ return self.user
+ end
+
]]
local method = {};
method.__index = method;
local mechanisms = {};
local backend_mechanism = {};
+local mechanism_channelbindings = {};
-- register a new SASL mechanims
-function registerMechanism(name, backends, f)
+local function registerMechanism(name, backends, f, cb_backends)
assert(type(name) == "string", "Parameter name MUST be a string.");
assert(type(backends) == "string" or type(backends) == "table", "Parameter backends MUST be either a string or a table.");
assert(type(f) == "function", "Parameter f MUST be a function.");
+ if cb_backends then assert(type(cb_backends) == "table"); end
mechanisms[name] = f
+ if cb_backends then
+ mechanism_channelbindings[name] = {};
+ for _, cb_name in ipairs(cb_backends) do
+ mechanism_channelbindings[name][cb_name] = true;
+ end
+ end
for _, backend_name in ipairs(backends) do
if backend_mechanism[backend_name] == nil then backend_mechanism[backend_name] = {}; end
t_insert(backend_mechanism[backend_name], name);
@@ -47,7 +66,7 @@ function registerMechanism(name, backends, f)
end
-- create a new SASL object which can be used to authenticate clients
-function new(realm, profile)
+local function new(realm, profile)
local mechanisms = profile.mechanisms;
if not mechanisms then
mechanisms = {};
@@ -63,6 +82,15 @@ function new(realm, profile)
return setmetatable({ profile = profile, realm = realm, mechs = mechanisms }, method);
end
+-- add a channel binding handler
+function method:add_cb_handler(name, f)
+ if type(self.profile.cb) ~= "table" then
+ self.profile.cb = {};
+ end
+ self.profile.cb[name] = f;
+ return self;
+end
+
-- get a fresh clone with the same realm and profile
function method:clean_clone()
return new(self.realm, self.profile)
@@ -70,7 +98,23 @@ end
-- get a list of possible SASL mechanims to use
function method:mechanisms()
- return self.mechs;
+ local current_mechs = {};
+ for mech, _ in pairs(self.mechs) do
+ if mechanism_channelbindings[mech] then
+ if self.profile.cb then
+ local ok = false;
+ for cb_name, _ in pairs(self.profile.cb) do
+ if mechanism_channelbindings[mech][cb_name] then
+ ok = true;
+ end
+ end
+ if ok == true then current_mechs[mech] = true; end
+ end
+ else
+ current_mechs[mech] = true;
+ end
+ end
+ return current_mechs;
end
-- select a mechanism to use
@@ -92,5 +136,9 @@ require "util.sasl.plain" .init(registerMechanism);
require "util.sasl.digest-md5".init(registerMechanism);
require "util.sasl.anonymous" .init(registerMechanism);
require "util.sasl.scram" .init(registerMechanism);
+require "util.sasl.external" .init(registerMechanism);
-return _M;
+return {
+ registerMechanism = registerMechanism;
+ new = new;
+};
diff --git a/util/sasl/anonymous.lua b/util/sasl/anonymous.lua
index ca5fe404..af05c0e7 100644
--- a/util/sasl/anonymous.lua
+++ b/util/sasl/anonymous.lua
@@ -16,7 +16,7 @@ local s_match = string.match;
local log = require "util.logger".init("sasl");
local generate_uuid = require "util.uuid".generate;
-module "sasl.anonymous"
+local _ENV = nil;
--=========================
--SASL ANONYMOUS according to RFC 4505
@@ -39,8 +39,10 @@ local function anonymous(self, message)
return "success"
end
-function init(registerMechanism)
+local function init(registerMechanism)
registerMechanism("ANONYMOUS", {"anonymous"}, anonymous);
end
-return _M;
+return {
+ init = init;
+}
diff --git a/util/sasl/digest-md5.lua b/util/sasl/digest-md5.lua
index 591d8537..695dd2a3 100644
--- a/util/sasl/digest-md5.lua
+++ b/util/sasl/digest-md5.lua
@@ -25,7 +25,7 @@ local log = require "util.logger".init("sasl");
local generate_uuid = require "util.uuid".generate;
local nodeprep = require "util.encodings".stringprep.nodeprep;
-module "sasl.digest-md5"
+local _ENV = nil;
--=========================
--SASL DIGEST-MD5 according to RFC 2831
@@ -241,8 +241,10 @@ local function digest(self, message)
end
end
-function init(registerMechanism)
+local function init(registerMechanism)
registerMechanism("DIGEST-MD5", {"plain"}, digest);
end
-return _M;
+return {
+ init = init;
+}
diff --git a/util/sasl/external.lua b/util/sasl/external.lua
new file mode 100644
index 00000000..5ba90190
--- /dev/null
+++ b/util/sasl/external.lua
@@ -0,0 +1,27 @@
+local saslprep = require "util.encodings".stringprep.saslprep;
+
+local _ENV = nil;
+
+local function external(self, message)
+ message = saslprep(message);
+ local state
+ self.username, state = self.profile.external(message);
+
+ if state == false then
+ return "failure", "account-disabled";
+ elseif state == nil then
+ return "failure", "not-authorized";
+ elseif state == "expired" then
+ return "false", "credentials-expired";
+ end
+
+ return "success";
+end
+
+local function init(registerMechanism)
+ registerMechanism("EXTERNAL", {"external"}, external);
+end
+
+return {
+ init = init;
+}
diff --git a/util/sasl/plain.lua b/util/sasl/plain.lua
index c9ec2911..26e65335 100644
--- a/util/sasl/plain.lua
+++ b/util/sasl/plain.lua
@@ -16,7 +16,7 @@ local saslprep = require "util.encodings".stringprep.saslprep;
local nodeprep = require "util.encodings".stringprep.nodeprep;
local log = require "util.logger".init("sasl");
-module "sasl.plain"
+local _ENV = nil;
-- ================================
-- SASL PLAIN according to RFC 4616
@@ -82,8 +82,10 @@ local function plain(self, message)
return "success";
end
-function init(registerMechanism)
+local function init(registerMechanism)
registerMechanism("PLAIN", {"plain", "plain_test"}, plain);
end
-return _M;
+return {
+ init = init;
+}
diff --git a/util/sasl/scram.lua b/util/sasl/scram.lua
index cf2f0ede..a1b5117a 100644
--- a/util/sasl/scram.lua
+++ b/util/sasl/scram.lua
@@ -13,7 +13,6 @@
local s_match = string.match;
local type = type
-local string = string
local base64 = require "util.encodings".base64;
local hmac_sha1 = require "util.hashes".hmac_sha1;
local sha1 = require "util.hashes".sha1;
@@ -26,7 +25,7 @@ local t_concat = table.concat;
local char = string.char;
local byte = string.byte;
-module "sasl.scram"
+local _ENV = nil;
--=========================
--SASL SCRAM-SHA-1 according to RFC 5802
@@ -39,18 +38,14 @@ scram_{MECH}:
function(username, realm)
return stored_key, server_key, iteration_count, salt, state;
end
+
+Supported Channel Binding Backends
+
+'tls-unique' according to RFC 5929
]]
local default_i = 4096
-local function bp( b )
- local result = ""
- for i=1, b:len() do
- result = result.."\\"..b:byte(i)
- end
- return result
-end
-
local xor_map = {0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;1;0;3;2;5;4;7;6;9;8;11;10;13;12;15;14;2;3;0;1;6;7;4;5;10;11;8;9;14;15;12;13;3;2;1;0;7;6;5;4;11;10;9;8;15;14;13;12;4;5;6;7;0;1;2;3;12;13;14;15;8;9;10;11;5;4;7;6;1;0;3;2;13;12;15;14;9;8;11;10;6;7;4;5;2;3;0;1;14;15;12;13;10;11;8;9;7;6;5;4;3;2;1;0;15;14;13;12;11;10;9;8;8;9;10;11;12;13;14;15;0;1;2;3;4;5;6;7;9;8;11;10;13;12;15;14;1;0;3;2;5;4;7;6;10;11;8;9;14;15;12;13;2;3;0;1;6;7;4;5;11;10;9;8;15;14;13;12;3;2;1;0;7;6;5;4;12;13;14;15;8;9;10;11;4;5;6;7;0;1;2;3;13;12;15;14;9;8;11;10;5;4;7;6;1;0;3;2;14;15;12;13;10;11;8;9;6;7;4;5;2;3;0;1;15;14;13;12;11;10;9;8;7;6;5;4;3;2;1;0;};
local result = {};
@@ -73,11 +68,11 @@ local function validate_username(username, _nodeprep)
return false
end
end
-
+
-- replace =2C with , and =3D with =
username = username:gsub("=2C", ",");
username = username:gsub("=3D", "=");
-
+
-- apply SASLprep
username = saslprep(username);
@@ -92,7 +87,7 @@ local function hashprep(hashname)
return hashname:lower():gsub("-", "_");
end
-function getAuthenticationDatabaseSHA1(password, salt, iteration_count)
+local function getAuthenticationDatabaseSHA1(password, salt, iteration_count)
if type(password) ~= "string" or type(salt) ~= "string" or type(iteration_count) ~= "number" then
return false, "inappropriate argument types"
end
@@ -106,96 +101,131 @@ function getAuthenticationDatabaseSHA1(password, salt, iteration_count)
end
local function scram_gen(hash_name, H_f, HMAC_f)
+ local profile_name = "scram_" .. hashprep(hash_name);
local function scram_hash(self, message)
- if not self.state then self["state"] = {} end
-
+ local support_channel_binding = false;
+ if self.profile.cb then support_channel_binding = true; end
+
if type(message) ~= "string" or #message == 0 then return "failure", "malformed-request" end
- if not self.state.name then
+ local state = self.state;
+ if not state then
-- we are processing client_first_message
local client_first_message = message;
-
+
-- TODO: fail if authzid is provided, since we don't support them yet
- self.state["client_first_message"] = client_first_message;
- self.state["gs2_cbind_flag"], self.state["authzid"], self.state["name"], self.state["clientnonce"]
- = client_first_message:match("^(%a),(.*),n=(.*),r=([^,]*).*");
+ local gs2_header, gs2_cbind_flag, gs2_cbind_name, authzid, client_first_message_bare, username, clientnonce
+ = s_match(client_first_message, "^(([pny])=?([^,]*),([^,]*),)(m?=?[^,]*,?n=([^,]*),r=([^,]*),?.*)$");
- -- we don't do any channel binding yet
- if self.state.gs2_cbind_flag ~= "n" and self.state.gs2_cbind_flag ~= "y" then
+ if not gs2_cbind_flag then
return "failure", "malformed-request";
end
- if not self.state.name or not self.state.clientnonce then
- return "failure", "malformed-request", "Channel binding isn't support at this time.";
+ if support_channel_binding and gs2_cbind_flag == "y" then
+ -- "y" -> client does support channel binding
+ -- but thinks the server does not.
+ return "failure", "malformed-request";
+ end
+
+ if gs2_cbind_flag == "n" then
+ -- "n" -> client doesn't support channel binding.
+ support_channel_binding = false;
end
-
- self.state.name = validate_username(self.state.name, self.profile.nodeprep);
- if not self.state.name then
+
+ if support_channel_binding and gs2_cbind_flag == "p" then
+ -- check whether we support the proposed channel binding type
+ if not self.profile.cb[gs2_cbind_name] then
+ return "failure", "malformed-request", "Proposed channel binding type isn't supported.";
+ end
+ else
+ -- no channel binding,
+ gs2_cbind_name = nil;
+ end
+
+ username = validate_username(username, self.profile.nodeprep);
+ if not username then
log("debug", "Username violates either SASLprep or contains forbidden character sequences.")
return "failure", "malformed-request", "Invalid username.";
end
-
- self.state["servernonce"] = generate_uuid();
-
+
-- retreive credentials
+ local stored_key, server_key, salt, iteration_count;
if self.profile.plain then
- local password, state = self.profile.plain(self, self.state.name, self.realm)
+ local password, state = self.profile.plain(self, username, self.realm)
if state == nil then return "failure", "not-authorized"
elseif state == false then return "failure", "account-disabled" end
-
+
password = saslprep(password);
if not password then
log("debug", "Password violates SASLprep.");
return "failure", "not-authorized", "Invalid password."
end
- self.state.salt = generate_uuid();
- self.state.iteration_count = default_i;
+ salt = generate_uuid();
+ iteration_count = default_i;
local succ = false;
- succ, self.state.stored_key, self.state.server_key = getAuthenticationDatabaseSHA1(password, self.state.salt, default_i, self.state.iteration_count);
+ succ, stored_key, server_key = getAuthenticationDatabaseSHA1(password, salt, iteration_count);
if not succ then
- log("error", "Generating authentication database failed. Reason: %s", self.state.stored_key);
+ log("error", "Generating authentication database failed. Reason: %s", stored_key);
return "failure", "temporary-auth-failure";
end
- elseif self.profile["scram_"..hashprep(hash_name)] then
- local stored_key, server_key, iteration_count, salt, state = self.profile["scram_"..hashprep(hash_name)](self, self.state.name, self.realm);
+ elseif self.profile[profile_name] then
+ local state;
+ stored_key, server_key, iteration_count, salt, state = self.profile[profile_name](self, username, self.realm);
if state == nil then return "failure", "not-authorized"
elseif state == false then return "failure", "account-disabled" end
-
- self.state.stored_key = stored_key;
- self.state.server_key = server_key;
- self.state.iteration_count = iteration_count;
- self.state.salt = salt
end
-
- local server_first_message = "r="..self.state.clientnonce..self.state.servernonce..",s="..base64.encode(self.state.salt)..",i="..self.state.iteration_count;
- self.state["server_first_message"] = server_first_message;
+
+ local nonce = clientnonce .. generate_uuid();
+ local server_first_message = "r="..nonce..",s="..base64.encode(salt)..",i="..iteration_count;
+ self.state = {
+ gs2_header = gs2_header;
+ gs2_cbind_name = gs2_cbind_name;
+ username = username;
+ nonce = nonce;
+
+ server_key = server_key;
+ stored_key = stored_key;
+ client_first_message_bare = client_first_message_bare;
+ server_first_message = server_first_message;
+ }
return "challenge", server_first_message
else
-- we are processing client_final_message
local client_final_message = message;
-
- self.state["channelbinding"], self.state["nonce"], self.state["proof"] = client_final_message:match("^c=(.*),r=(.*),.*p=(.*)");
-
- if not self.state.proof or not self.state.nonce or not self.state.channelbinding then
+
+ local client_final_message_without_proof, channelbinding, nonce, proof
+ = s_match(client_final_message, "(c=([^,]*),r=([^,]*),?.-),p=(.*)$");
+
+ if not proof or not nonce or not channelbinding then
return "failure", "malformed-request", "Missing an attribute(p, r or c) in SASL message.";
end
- if self.state.nonce ~= self.state.clientnonce..self.state.servernonce then
+ local client_gs2_header = base64.decode(channelbinding)
+ local our_client_gs2_header = state["gs2_header"]
+ if state.gs2_cbind_name then
+ -- we support channelbinding, so check if the value is valid
+ our_client_gs2_header = our_client_gs2_header .. self.profile.cb[state.gs2_cbind_name](self);
+ end
+ if client_gs2_header ~= our_client_gs2_header then
+ return "failure", "malformed-request", "Invalid channel binding value.";
+ end
+
+ if nonce ~= state.nonce then
return "failure", "malformed-request", "Wrong nonce in client-final-message.";
end
-
- local ServerKey = self.state.server_key;
- local StoredKey = self.state.stored_key;
-
- local AuthMessage = "n=" .. s_match(self.state.client_first_message,"n=(.+)") .. "," .. self.state.server_first_message .. "," .. s_match(client_final_message, "(.+),p=.+")
+
+ local ServerKey = state.server_key;
+ local StoredKey = state.stored_key;
+
+ local AuthMessage = state.client_first_message_bare .. "," .. state.server_first_message .. "," .. client_final_message_without_proof
local ClientSignature = HMAC_f(StoredKey, AuthMessage)
- local ClientKey = binaryXOR(ClientSignature, base64.decode(self.state.proof))
+ local ClientKey = binaryXOR(ClientSignature, base64.decode(proof))
local ServerSignature = HMAC_f(ServerKey, AuthMessage)
if StoredKey == H_f(ClientKey) then
local server_final_message = "v="..base64.encode(ServerSignature);
- self["username"] = self.state.name;
+ self["username"] = state.username;
return "success", server_final_message;
else
return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated.";
@@ -205,12 +235,18 @@ local function scram_gen(hash_name, H_f, HMAC_f)
return scram_hash;
end
-function init(registerMechanism)
+local function init(registerMechanism)
local function registerSCRAMMechanism(hash_name, hash, hmac_hash)
registerMechanism("SCRAM-"..hash_name, {"plain", "scram_"..(hashprep(hash_name))}, scram_gen(hash_name:lower(), hash, hmac_hash));
+
+ -- register channel binding equivalent
+ registerMechanism("SCRAM-"..hash_name.."-PLUS", {"plain", "scram_"..(hashprep(hash_name))}, scram_gen(hash_name:lower(), hash, hmac_hash), {"tls-unique"});
end
registerSCRAMMechanism("SHA-1", sha1, hmac_sha1);
end
-return _M;
+return {
+ getAuthenticationDatabaseSHA1 = getAuthenticationDatabaseSHA1;
+ init = init;
+}
diff --git a/util/sasl_cyrus.lua b/util/sasl_cyrus.lua
index 19684587..4e9a4af5 100644
--- a/util/sasl_cyrus.lua
+++ b/util/sasl_cyrus.lua
@@ -60,7 +60,7 @@ local sasl_errstring = {
};
setmetatable(sasl_errstring, { __index = function() return "undefined error!" end });
-module "sasl_cyrus"
+local _ENV = nil;
local method = {};
method.__index = method;
@@ -78,11 +78,11 @@ local function init(service_name)
end
-- create a new SASL object which can be used to authenticate clients
--- host_fqdn may be nil in which case gethostname() gives the value.
+-- host_fqdn may be nil in which case gethostname() gives the value.
-- For GSSAPI, this determines the hostname in the service ticket (after
-- reverse DNS canonicalization, only if [libdefaults] rdns = true which
--- is the default).
-function new(realm, service_name, app_name, host_fqdn)
+-- is the default).
+local function new(realm, service_name, app_name, host_fqdn)
init(app_name or service_name);
@@ -163,4 +163,6 @@ function method:process(message)
end
end
-return _M;
+return {
+ new = new;
+};
diff --git a/util/serialization.lua b/util/serialization.lua
index 8a259184..206f5fbb 100644
--- a/util/serialization.lua
+++ b/util/serialization.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -11,18 +11,16 @@ local type = type;
local tostring = tostring;
local t_insert = table.insert;
local t_concat = table.concat;
-local error = error;
local pairs = pairs;
local next = next;
-local loadstring = loadstring;
local pcall = pcall;
local debug_traceback = debug.traceback;
local log = require "util.logger".init("serialization");
local envload = require"util.envload".envload;
-module "serialization"
+local _ENV = nil;
local indent = function(i)
return string_rep("\t", i);
@@ -73,16 +71,16 @@ local function _simplesave(o, ind, t, func)
end
end
-function append(t, o)
+local function append(t, o)
_simplesave(o, 1, t, t.write or t_insert);
return t;
end
-function serialize(o)
+local function serialize(o)
return t_concat(append({}, o));
end
-function deserialize(str)
+local function deserialize(str)
if type(str) ~= "string" then return nil; end
str = "return "..str;
local f, err = envload(str, "@data", {});
@@ -92,4 +90,8 @@ function deserialize(str)
return ret;
end
-return _M;
+return {
+ append = append;
+ serialize = serialize;
+ deserialize = deserialize;
+};
diff --git a/util/session.lua b/util/session.lua
new file mode 100644
index 00000000..fda7fccd
--- /dev/null
+++ b/util/session.lua
@@ -0,0 +1,65 @@
+local initialize_filters = require "util.filters".initialize;
+local logger = require "util.logger";
+
+local function new_session(typ)
+ local session = {
+ type = typ .. "_unauthed";
+ };
+ return session;
+end
+
+local function set_id(session)
+ local id = typ .. tostring(session):match("%x+$"):lower();
+ session.id = id;
+ return session;
+end
+
+local function set_logger(session)
+ local log = logger.init(id);
+ session.log = log;
+ return session;
+end
+
+local function set_conn(session, conn)
+ session.conn = conn;
+ session.ip = conn:ip();
+ return session;
+end
+
+local function set_send(session)
+ local conn = session.conn;
+ if not conn then
+ function session.send(data)
+ log("debug", "Discarding data sent to unconnected session: %s", tostring(data));
+ return false;
+ end
+ return session;
+ end
+ local filter = initialize_filters(session);
+ local w = conn.write;
+ session.send = function (t)
+ if t.name then
+ t = filter("stanzas/out", t);
+ end
+ if t then
+ t = filter("bytes/out", tostring(t));
+ if t then
+ local ret, err = w(conn, t);
+ if not ret then
+ session.log("debug", "Error writing to connection: %s", tostring(err));
+ return false, err;
+ end
+ end
+ end
+ return true;
+ end
+ return session;
+end
+
+return {
+ new = new_session;
+ set_id = set_id;
+ set_logger = set_logger;
+ set_conn = set_conn;
+ set_send = set_send;
+}
diff --git a/util/set.lua b/util/set.lua
index fa065a9c..c136a522 100644
--- a/util/set.lua
+++ b/util/set.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -10,86 +10,49 @@ local ipairs, pairs, setmetatable, next, tostring =
ipairs, pairs, setmetatable, next, tostring;
local t_concat = table.concat;
-module "set"
+local _ENV = nil;
local set_mt = {};
function set_mt.__call(set, _, k)
return next(set._items, k);
end
-function set_mt.__add(set1, set2)
- return _M.union(set1, set2);
-end
-function set_mt.__sub(set1, set2)
- return _M.difference(set1, set2);
-end
-function set_mt.__div(set, func)
- local new_set = _M.new();
- local items, new_items = set._items, new_set._items;
- for item in pairs(items) do
- local new_item = func(item);
- if new_item ~= nil then
- new_items[new_item] = true;
- end
- end
- return new_set;
-end
-function set_mt.__eq(set1, set2)
- local set1, set2 = set1._items, set2._items;
- for item in pairs(set1) do
- if not set2[item] then
- return false;
- end
- end
-
- for item in pairs(set2) do
- if not set1[item] then
- return false;
- end
- end
-
- return true;
-end
-function set_mt.__tostring(set)
- local s, items = { }, set._items;
- for item in pairs(items) do
- s[#s+1] = tostring(item);
- end
- return t_concat(s, ", ");
-end
local items_mt = {};
function items_mt.__call(items, _, k)
return next(items, k);
end
-function new(list)
+local function new(list)
local items = setmetatable({}, items_mt);
local set = { _items = items };
-
+
+ -- We access the set through an upvalue in these methods, so ignore 'self' being unused
+ --luacheck: ignore 212/self
+
function set:add(item)
items[item] = true;
end
-
+
function set:contains(item)
return items[item];
end
-
+
function set:items()
- return items;
+ return next, items;
end
-
+
function set:remove(item)
items[item] = nil;
end
-
- function set:add_list(list)
- if list then
- for _, item in ipairs(list) do
+
+ function set:add_list(item_list)
+ if item_list then
+ for _, item in ipairs(item_list) do
items[item] = true;
end
end
end
-
+
function set:include(otherset)
for item in otherset do
items[item] = true;
@@ -101,22 +64,22 @@ function new(list)
items[item] = nil;
end
end
-
+
function set:empty()
return not next(items);
end
-
+
if list then
set:add_list(list);
end
-
+
return setmetatable(set, set_mt);
end
-function union(set1, set2)
+local function union(set1, set2)
local set = new();
local items = set._items;
-
+
for item in pairs(set1._items) do
items[item] = true;
end
@@ -124,14 +87,14 @@ function union(set1, set2)
for item in pairs(set2._items) do
items[item] = true;
end
-
+
return set;
end
-function difference(set1, set2)
+local function difference(set1, set2)
local set = new();
local items = set._items;
-
+
for item in pairs(set1._items) do
items[item] = (not set2._items[item]) or nil;
end
@@ -139,21 +102,68 @@ function difference(set1, set2)
return set;
end
-function intersection(set1, set2)
+local function intersection(set1, set2)
local set = new();
local items = set._items;
-
+
set1, set2 = set1._items, set2._items;
-
+
for item in pairs(set1) do
items[item] = (not not set2[item]) or nil;
end
-
+
return set;
end
-function xor(set1, set2)
+local function xor(set1, set2)
return union(set1, set2) - intersection(set1, set2);
end
-return _M;
+function set_mt.__add(set1, set2)
+ return union(set1, set2);
+end
+function set_mt.__sub(set1, set2)
+ return difference(set1, set2);
+end
+function set_mt.__div(set, func)
+ local new_set = new();
+ local items, new_items = set._items, new_set._items;
+ for item in pairs(items) do
+ local new_item = func(item);
+ if new_item ~= nil then
+ new_items[new_item] = true;
+ end
+ end
+ return new_set;
+end
+function set_mt.__eq(set1, set2)
+ set1, set2 = set1._items, set2._items;
+ for item in pairs(set1) do
+ if not set2[item] then
+ return false;
+ end
+ end
+
+ for item in pairs(set2) do
+ if not set1[item] then
+ return false;
+ end
+ end
+
+ return true;
+end
+function set_mt.__tostring(set)
+ local s, items = { }, set._items;
+ for item in pairs(items) do
+ s[#s+1] = tostring(item);
+ end
+ return t_concat(s, ", ");
+end
+
+return {
+ new = new;
+ union = union;
+ difference = difference;
+ intersection = intersection;
+ xor = xor;
+};
diff --git a/util/sql.lua b/util/sql.lua
index f360d6d0..2d5e1774 100644
--- a/util/sql.lua
+++ b/util/sql.lua
@@ -13,7 +13,7 @@ local DBI = require "DBI";
DBI.Drivers();
local build_url = require "socket.url".build;
-module("sql")
+local _ENV = nil;
local column_mt = {};
local table_mt = {};
@@ -21,42 +21,17 @@ local query_mt = {};
--local op_mt = {};
local index_mt = {};
-function is_column(x) return getmetatable(x)==column_mt; end
-function is_index(x) return getmetatable(x)==index_mt; end
-function is_table(x) return getmetatable(x)==table_mt; end
-function is_query(x) return getmetatable(x)==query_mt; end
---function is_op(x) return getmetatable(x)==op_mt; end
---function expr(...) return setmetatable({...}, op_mt); end
-function Integer(n) return "Integer()" end
-function String(n) return "String()" end
+local function is_column(x) return getmetatable(x)==column_mt; end
+local function is_index(x) return getmetatable(x)==index_mt; end
+local function is_table(x) return getmetatable(x)==table_mt; end
+local function is_query(x) return getmetatable(x)==query_mt; end
+local function Integer(n) return "Integer()" end
+local function String(n) return "String()" end
---[[local ops = {
- __add = function(a, b) return "("..a.."+"..b..")" end;
- __sub = function(a, b) return "("..a.."-"..b..")" end;
- __mul = function(a, b) return "("..a.."*"..b..")" end;
- __div = function(a, b) return "("..a.."/"..b..")" end;
- __mod = function(a, b) return "("..a.."%"..b..")" end;
- __pow = function(a, b) return "POW("..a..","..b..")" end;
- __unm = function(a) return "NOT("..a..")" end;
- __len = function(a) return "COUNT("..a..")" end;
- __eq = function(a, b) return "("..a.."=="..b..")" end;
- __lt = function(a, b) return "("..a.."<"..b..")" end;
- __le = function(a, b) return "("..a.."<="..b..")" end;
-};
-
-local functions = {
-
-};
-
-local cmap = {
- [Integer] = Integer();
- [String] = String();
-};]]
-
-function Column(definition)
+local function Column(definition)
return setmetatable(definition, column_mt);
end
-function Table(definition)
+local function Table(definition)
local c = {}
for i,col in ipairs(definition) do
if is_column(col) then
@@ -67,7 +42,7 @@ function Table(definition)
end
return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt);
end
-function Index(definition)
+local function Index(definition)
return setmetatable(definition, index_mt);
end
@@ -94,7 +69,6 @@ function index_mt:__tostring()
return s..' }';
-- return 'Index{ name="'..self.name..'", type="'..self.type..'" }'
end
---
local function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end)); end
local function parse_url(url)
@@ -121,32 +95,13 @@ local function parse_url(url)
};
end
---[[local session = {};
-
-function session.query(...)
- local rets = {...};
- local query = setmetatable({ __rets = rets, __filters }, query_mt);
- return query;
-end
---
-
-local function db2uri(params)
- return build_url{
- scheme = params.driver,
- user = params.username,
- password = params.password,
- host = params.host,
- port = params.port,
- path = params.database,
- };
-end]]
-
local engine = {};
function engine:connect()
if self.conn then return true; end
local params = self.params;
assert(params.driver, "no driver")
+ log("debug", "Connecting to [%s] %s...", params.driver, params.database);
local dbh, err = DBI.Connect(
params.driver, params.database,
params.username, params.password,
@@ -156,8 +111,19 @@ function engine:connect()
dbh:autocommit(false); -- don't commit automatically
self.conn = dbh;
self.prepared = {};
+ local ok, err = self:set_encoding();
+ if not ok then
+ return ok, err;
+ end
+ local ok, err = self:onconnect();
+ if ok == false then
+ return ok, err;
+ end
return true;
end
+function engine:onconnect()
+ -- Override from create_engine()
+end
function engine:execute(sql, ...)
local success, err = self:connect();
if not success then return success, err; end
@@ -177,8 +143,8 @@ function engine:execute(sql, ...)
end
local result_mt = { __index = {
- affected = function(self) return self.__affected; end;
- rowcount = function(self) return self.__rowcount; end;
+ affected = function(self) return self.__stmt:affected(); end;
+ rowcount = function(self) return self.__stmt:rowcount(); end;
} };
function engine:execute_query(sql, ...)
@@ -200,7 +166,7 @@ function engine:execute_update(sql, ...)
prepared[sql] = stmt;
end
assert(stmt:execute(...));
- return setmetatable({ __affected = stmt:affected(), __rowcount = stmt:rowcount() }, result_mt);
+ return setmetatable({ __stmt = stmt }, result_mt);
end
engine.insert = engine.execute_update;
engine.select = engine.execute_query;
@@ -208,12 +174,13 @@ engine.delete = engine.execute_update;
engine.update = engine.execute_update;
function engine:_transaction(func, ...)
if not self.conn then
- local a,b = self:connect();
- if not a then return a,b; end
+ local ok, err = self:connect();
+ if not ok then return ok, err; end
end
--assert(not self.__transaction, "Recursive transactions not allowed");
local args, n_args = {...}, select("#", ...);
local function f() return func(unpack(args, 1, n_args)); end
+ log("debug", "SQL transaction begin [%s]", tostring(func));
self.__transaction = true;
local success, a, b, c = xpcall(f, debug_traceback);
self.__transaction = nil;
@@ -229,15 +196,15 @@ function engine:_transaction(func, ...)
end
end
function engine:transaction(...)
- local a,b = self:_transaction(...);
- if not a then
+ local ok, ret = self:_transaction(...);
+ if not ok then
local conn = self.conn;
if not conn or not conn:ping() then
self.conn = nil;
- a,b = self:_transaction(...);
+ ok, ret = self:_transaction(...);
end
end
- return a,b;
+ return ok, ret;
end
function engine:_create_index(index)
local sql = "CREATE INDEX `"..index.name.."` ON `"..index.table.."` (";
@@ -251,19 +218,39 @@ function engine:_create_index(index)
elseif self.params.driver == "MySQL" then
sql = sql:gsub("`([,)])", "`(20)%1");
end
+ if index.unique then
+ sql = sql:gsub("^CREATE", "CREATE UNIQUE");
+ end
--print(sql);
return self:execute(sql);
end
function engine:_create_table(table)
local sql = "CREATE TABLE `"..table.name.."` (";
for i,col in ipairs(table.c) do
- sql = sql.."`"..col.name.."` "..col.type;
+ local col_type = col.type;
+ if col_type == "MEDIUMTEXT" and self.params.driver ~= "MySQL" then
+ col_type = "TEXT"; -- MEDIUMTEXT is MySQL-specific
+ end
+ if col.auto_increment == true and self.params.driver == "PostgreSQL" then
+ col_type = "BIGSERIAL";
+ end
+ sql = sql.."`"..col.name.."` "..col_type;
if col.nullable == false then sql = sql.." NOT NULL"; end
+ if col.primary_key == true then sql = sql.." PRIMARY KEY"; end
+ if col.auto_increment == true then
+ if self.params.driver == "MySQL" then
+ sql = sql.." AUTO_INCREMENT";
+ elseif self.params.driver == "SQLite3" then
+ sql = sql.." AUTOINCREMENT";
+ end
+ end
if i ~= #table.c then sql = sql..", "; end
end
sql = sql.. ");"
if self.params.driver == "PostgreSQL" then
sql = sql:gsub("`", "\"");
+ elseif self.params.driver == "MySQL" then
+ sql = sql:gsub(";$", (" CHARACTER SET '%s' COLLATE '%s_bin';"):format(self.charset, self.charset));
end
local success,err = self:execute(sql);
if not success then return success,err; end
@@ -274,6 +261,46 @@ function engine:_create_table(table)
end
return success;
end
+function engine:set_encoding() -- to UTF-8
+ local driver = self.params.driver;
+ if driver == "SQLite3" then
+ return self:transaction(function()
+ if self:select"PRAGMA encoding;"()[1] == "UTF-8" then
+ self.charset = "utf8";
+ end
+ end);
+ end
+ local set_names_query = "SET NAMES '%s';"
+ local charset = "utf8";
+ if driver == "MySQL" then
+ local ok, charsets = self:transaction(function()
+ return self:select"SELECT `CHARACTER_SET_NAME` FROM `information_schema`.`CHARACTER_SETS` WHERE `CHARACTER_SET_NAME` LIKE 'utf8%' ORDER BY MAXLEN DESC LIMIT 1;";
+ end);
+ local row = ok and charsets();
+ charset = row and row[1] or charset;
+ set_names_query = set_names_query:gsub(";$", (" COLLATE '%s';"):format(charset.."_bin"));
+ end
+ self.charset = charset;
+ log("debug", "Using encoding '%s' for database connection", charset);
+ local ok, err = self:transaction(function() return self:execute(set_names_query:format(charset)); end);
+ if not ok then
+ return ok, err;
+ end
+
+ if driver == "MySQL" then
+ local ok, actual_charset = self:transaction(function ()
+ return self:select"SHOW SESSION VARIABLES LIKE 'character_set_client'";
+ end);
+ for row in actual_charset do
+ if row[2] ~= charset then
+ log("error", "MySQL %s is actually %q (expected %q)", row[1], row[2], charset);
+ return false, "Failed to set connection encoding";
+ end
+ end
+ end
+
+ return true;
+end
local engine_mt = { __index = engine };
local function db2uri(params)
@@ -286,55 +313,21 @@ local function db2uri(params)
path = params.database,
};
end
-local engine_cache = {}; -- TODO make weak valued
-function create_engine(self, params)
- local url = db2uri(params);
- if not engine_cache[url] then
- local engine = setmetatable({ url = url, params = params }, engine_mt);
- engine_cache[url] = engine;
- end
- return engine_cache[url];
-end
-
-
---[[Users = Table {
- name="users";
- Column { name="user_id", type=String(), primary_key=true };
-};
-print(Users)
-print(Users.c.user_id)]]
---local engine = create_engine('postgresql://scott:tiger@localhost:5432/mydatabase');
---[[local engine = create_engine{ driver = "SQLite3", database = "./alchemy.sqlite" };
-
-local i = 0;
-for row in assert(engine:execute("select * from sqlite_master")):rows(true) do
- i = i+1;
- print(i);
- for k,v in pairs(row) do
- print("",k,v);
- end
+local function create_engine(self, params, onconnect)
+ return setmetatable({ url = db2uri(params), params = params, onconnect = onconnect }, engine_mt);
end
-print("---")
-Prosody = Table {
- name="prosody";
- Column { name="host", type="TEXT", nullable=false };
- Column { name="user", type="TEXT", nullable=false };
- Column { name="store", type="TEXT", nullable=false };
- Column { name="key", type="TEXT", nullable=false };
- Column { name="type", type="TEXT", nullable=false };
- Column { name="value", type="TEXT", nullable=false };
- Index { name="prosody_index", "host", "user", "store", "key" };
+return {
+ is_column = is_column;
+ is_index = is_index;
+ is_table = is_table;
+ is_query = is_query;
+ Integer = Integer;
+ String = String;
+ Column = Column;
+ Table = Table;
+ Index = Index;
+ create_engine = create_engine;
+ db2uri = db2uri;
};
---print(Prosody);
-assert(engine:transaction(function()
- assert(Prosody:create(engine));
-end));
-
-for row in assert(engine:execute("select user from prosody")):rows(true) do
- print("username:", row['username'])
-end
---result.close();]]
-
-return _M;
diff --git a/util/sslconfig.lua b/util/sslconfig.lua
new file mode 100644
index 00000000..c849aa28
--- /dev/null
+++ b/util/sslconfig.lua
@@ -0,0 +1,120 @@
+-- util to easily merge multiple sets of LuaSec context options
+
+local type = type;
+local pairs = pairs;
+local rawset = rawset;
+local t_concat = table.concat;
+local t_insert = table.insert;
+local setmetatable = setmetatable;
+
+local _ENV = nil;
+
+local handlers = { };
+local finalisers = { };
+local id = function (v) return v end
+
+-- All "handlers" behave like extended rawset(table, key, value) with extra
+-- processing usually merging the new value with the old in some reasonable
+-- way
+-- If a field does not have a defined handler then a new value simply
+-- replaces the old.
+
+
+-- Convert either a list or a set into a special type of set where each
+-- item is either positive or negative in order for a later set of options
+-- to be able to remove options from this set by filtering out the negative ones
+function handlers.options(config, field, new)
+ local options = config[field] or { };
+ if type(new) ~= "table" then new = { new } end
+ for key, value in pairs(new) do
+ if value == true or value == false then
+ options[key] = value;
+ else -- list item
+ options[value] = true;
+ end
+ end
+ config[field] = options;
+end
+
+handlers.verify = handlers.options;
+handlers.verifyext = handlers.options;
+
+-- finalisers take something produced by handlers and return what luasec
+-- expects it to be
+
+-- Produce a list of "positive" options from the set
+function finalisers.options(options)
+ local output = {};
+ for opt, enable in pairs(options) do
+ if enable then
+ output[#output+1] = opt;
+ end
+ end
+ return output;
+end
+
+finalisers.verify = finalisers.options;
+finalisers.verifyext = finalisers.options;
+
+-- We allow ciphers to be a list
+
+function finalisers.ciphers(cipherlist)
+ if type(cipherlist) == "table" then
+ return t_concat(cipherlist, ":");
+ end
+ return cipherlist;
+end
+
+-- protocol = "x" should enable only that protocol
+-- protocol = "x+" should enable x and later versions
+
+local protocols = { "sslv2", "sslv3", "tlsv1", "tlsv1_1", "tlsv1_2" };
+for i = 1, #protocols do protocols[protocols[i] .. "+"] = i - 1; end
+
+-- this interacts with ssl.options as well to add no_x
+local function protocol(config)
+ local min_protocol = protocols[config.protocol];
+ if min_protocol then
+ config.protocol = "sslv23";
+ for i = 1, min_protocol do
+ t_insert(config.options, "no_"..protocols[i]);
+ end
+ end
+end
+
+-- Merge options from 'new' config into 'config'
+local function apply(config, new)
+ if type(new) == "table" then
+ for field, value in pairs(new) do
+ (handlers[field] or rawset)(config, field, value);
+ end
+ end
+end
+
+-- Finalize the config into the form LuaSec expects
+local function final(config)
+ local output = { };
+ for field, value in pairs(config) do
+ output[field] = (finalisers[field] or id)(value);
+ end
+ -- Need to handle protocols last because it adds to the options list
+ protocol(output);
+ return output;
+end
+
+local sslopts_mt = {
+ __index = {
+ apply = apply;
+ final = final;
+ };
+};
+
+local function new()
+ return setmetatable({options={}}, sslopts_mt);
+end
+
+return {
+ apply = apply;
+ final = final;
+ new = new;
+};
diff --git a/util/stanza.lua b/util/stanza.lua
index 7c214210..90422a06 100644
--- a/util/stanza.lua
+++ b/util/stanza.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -35,13 +35,12 @@ end
local xmlns_stanzas = "urn:ietf:params:xml:ns:xmpp-stanzas";
-module "stanza"
+local _ENV = nil;
-stanza_mt = { __type = "stanza" };
+local stanza_mt = { __type = "stanza" };
stanza_mt.__index = stanza_mt;
-local stanza_mt = stanza_mt;
-function stanza(name, attr)
+local function stanza(name, attr)
local stanza = { name = name, attr = attr or {}, tags = {} };
return setmetatable(stanza, stanza_mt);
end
@@ -99,7 +98,7 @@ function stanza_mt:get_child(name, xmlns)
if (not name or child.name == name)
and ((not xmlns and self.attr.xmlns == child.attr.xmlns)
or child.attr.xmlns == xmlns) then
-
+
return child;
end
end
@@ -152,7 +151,7 @@ end
function stanza_mt:maptags(callback)
local tags, curr_tag = self.tags, 1;
local n_children, n_tags = #self, #tags;
-
+
local i = 1;
while curr_tag <= n_tags and n_tags > 0 do
if self[i] == tags[curr_tag] then
@@ -200,12 +199,8 @@ function stanza_mt:find(path)
end
-local xml_escape
-do
- local escape_table = { ["'"] = "&apos;", ["\""] = "&quot;", ["<"] = "&lt;", [">"] = "&gt;", ["&"] = "&amp;" };
- function xml_escape(str) return (s_gsub(str, "['&<>\"]", escape_table)); end
- _M.xml_escape = xml_escape;
-end
+local escape_table = { ["'"] = "&apos;", ["\""] = "&quot;", ["<"] = "&lt;", [">"] = "&gt;", ["&"] = "&amp;" };
+local function xml_escape(str) return (s_gsub(str, "['&<>\"]", escape_table)); end
local function _dostring(t, buf, self, xml_escape, parentns)
local nsid = 0;
@@ -258,13 +253,13 @@ end
function stanza_mt.get_error(stanza)
local type, condition, text;
-
+
local error_tag = stanza:get_child("error");
if not error_tag then
return nil, nil, nil;
end
type = error_tag.attr.type;
-
+
for _, child in ipairs(error_tag.tags) do
if child.attr.xmlns == xmlns_stanzas then
if not text and child.name == "text" then
@@ -280,15 +275,13 @@ function stanza_mt.get_error(stanza)
return type, condition or "undefined-condition", text;
end
-do
- local id = 0;
- function new_id()
- id = id + 1;
- return "lx"..id;
- end
+local id = 0;
+local function new_id()
+ id = id + 1;
+ return "lx"..id;
end
-function preserialize(stanza)
+local function preserialize(stanza)
local s = { name = stanza.name, attr = stanza.attr };
for _, child in ipairs(stanza) do
if type(child) == "table" then
@@ -300,7 +293,7 @@ function preserialize(stanza)
return s;
end
-function deserialize(stanza)
+local function deserialize(stanza)
-- Set metatable
if stanza then
local attr = stanza.attr;
@@ -333,55 +326,52 @@ function deserialize(stanza)
stanza.tags = tags;
end
end
-
+
return stanza;
end
-local function _clone(stanza)
+local function clone(stanza)
local attr, tags = {}, {};
for k,v in pairs(stanza.attr) do attr[k] = v; end
local new = { name = stanza.name, attr = attr, tags = tags };
for i=1,#stanza do
local child = stanza[i];
if child.name then
- child = _clone(child);
+ child = clone(child);
t_insert(tags, child);
end
t_insert(new, child);
end
return setmetatable(new, stanza_mt);
end
-clone = _clone;
-function message(attr, body)
+local function message(attr, body)
if not body then
return stanza("message", attr);
else
return stanza("message", attr):tag("body"):text(body):up();
end
end
-function iq(attr)
+local function iq(attr)
if attr and not attr.id then attr.id = new_id(); end
return stanza("iq", attr or { id = new_id() });
end
-function reply(orig)
+local function reply(orig)
return stanza(orig.name, orig.attr and { to = orig.attr.from, from = orig.attr.to, id = orig.attr.id, type = ((orig.name == "iq" and "result") or orig.attr.type) });
end
-do
- local xmpp_stanzas_attr = { xmlns = xmlns_stanzas };
- function error_reply(orig, type, condition, message)
- local t = reply(orig);
- t.attr.type = "error";
- t:tag("error", {type = type}) --COMPAT: Some day xmlns:stanzas goes here
- :tag(condition, xmpp_stanzas_attr):up();
- if (message) then t:tag("text", xmpp_stanzas_attr):text(message):up(); end
- return t; -- stanza ready for adding app-specific errors
- end
+local xmpp_stanzas_attr = { xmlns = xmlns_stanzas };
+local function error_reply(orig, type, condition, message)
+ local t = reply(orig);
+ t.attr.type = "error";
+ t:tag("error", {type = type}) --COMPAT: Some day xmlns:stanzas goes here
+ :tag(condition, xmpp_stanzas_attr):up();
+ if (message) then t:tag("text", xmpp_stanzas_attr):text(message):up(); end
+ return t; -- stanza ready for adding app-specific errors
end
-function presence(attr)
+local function presence(attr)
return stanza("presence", attr);
end
@@ -390,7 +380,7 @@ if do_pretty_printing then
local style_attrv = getstyle("red");
local style_tagname = getstyle("red");
local style_punc = getstyle("magenta");
-
+
local attr_format = " "..getstring(style_attrk, "%s")..getstring(style_punc, "=")..getstring(style_attrv, "'%s'");
local top_tag_format = getstring(style_punc, "<")..getstring(style_tagname, "%s").."%s"..getstring(style_punc, ">");
--local tag_format = getstring(style_punc, "<")..getstring(style_tagname, "%s").."%s"..getstring(style_punc, ">").."%s"..getstring(style_punc, "</")..getstring(style_tagname, "%s")..getstring(style_punc, ">");
@@ -411,7 +401,7 @@ if do_pretty_printing then
end
return s_format(tag_format, t.name, attr_string, children_text, t.name);
end
-
+
function stanza_mt.pretty_top_tag(t)
local attr_string = "";
if t.attr then
@@ -425,4 +415,17 @@ else
stanza_mt.pretty_top_tag = stanza_mt.top_tag;
end
-return _M;
+return {
+ stanza_mt = stanza_mt;
+ stanza = stanza;
+ new_id = new_id;
+ preserialize = preserialize;
+ deserialize = deserialize;
+ clone = clone;
+ message = message;
+ iq = iq;
+ reply = reply;
+ error_reply = error_reply;
+ presence = presence;
+ xml_escape = xml_escape;
+};
diff --git a/util/statistics.lua b/util/statistics.lua
new file mode 100644
index 00000000..26355026
--- /dev/null
+++ b/util/statistics.lua
@@ -0,0 +1,160 @@
+local t_sort = table.sort
+local m_floor = math.floor;
+local time = require "socket".gettime;
+
+local function nop_function() end
+
+local function percentile(arr, length, pc)
+ local n = pc/100 * (length + 1);
+ local k, d = m_floor(n), n%1;
+ if k == 0 then
+ return arr[1] or 0;
+ elseif k >= length then
+ return arr[length];
+ end
+ return arr[k] + d*(arr[k+1] - arr[k]);
+end
+
+local function new_registry(config)
+ config = config or {};
+ local duration_sample_interval = config.duration_sample_interval or 5;
+ local duration_max_samples = config.duration_max_stored_samples or 5000;
+
+ local function get_distribution_stats(events, n_actual_events, since, new_time, units)
+ local n_stored_events = #events;
+ t_sort(events);
+ local sum = 0;
+ for i = 1, n_stored_events do
+ sum = sum + events[i];
+ end
+
+ return {
+ samples = events;
+ sample_count = n_stored_events;
+ count = n_actual_events,
+ rate = n_actual_events/(new_time-since);
+ average = n_stored_events > 0 and sum/n_stored_events or 0,
+ min = events[1] or 0,
+ max = events[n_stored_events] or 0,
+ units = units,
+ };
+ end
+
+
+ local registry = {};
+ local methods;
+ methods = {
+ amount = function (name, initial)
+ local v = initial or 0;
+ registry[name..":amount"] = function () return "amount", v; end
+ return function (new_v) v = new_v; end
+ end;
+ counter = function (name, initial)
+ local v = initial or 0;
+ registry[name..":amount"] = function () return "amount", v; end
+ return function (delta)
+ v = v + delta;
+ end;
+ end;
+ rate = function (name)
+ local since, n = time(), 0;
+ registry[name..":rate"] = function ()
+ local t = time();
+ local stats = {
+ rate = n/(t-since);
+ count = n;
+ };
+ since, n = t, 0;
+ return "rate", stats.rate, stats;
+ end;
+ return function ()
+ n = n + 1;
+ end;
+ end;
+ distribution = function (name, unit, type)
+ type = type or "distribution";
+ local events, last_event = {}, 0;
+ local n_actual_events = 0;
+ local since = time();
+
+ registry[name..":"..type] = function ()
+ local new_time = time();
+ local stats = get_distribution_stats(events, n_actual_events, since, new_time, unit);
+ events, last_event = {}, 0;
+ n_actual_events = 0;
+ since = new_time;
+ return type, stats.average, stats;
+ end;
+
+ return function (value)
+ n_actual_events = n_actual_events + 1;
+ if n_actual_events%duration_sample_interval == 1 then
+ last_event = (last_event%duration_max_samples) + 1;
+ events[last_event] = value;
+ end
+ end;
+ end;
+ sizes = function (name)
+ return methods.distribution(name, "bytes", "size");
+ end;
+ times = function (name)
+ local events, last_event = {}, 0;
+ local n_actual_events = 0;
+ local since = time();
+
+ registry[name..":duration"] = function ()
+ local new_time = time();
+ local stats = get_distribution_stats(events, n_actual_events, since, new_time, "seconds");
+ events, last_event = {}, 0;
+ n_actual_events = 0;
+ since = new_time;
+ return "duration", stats.average, stats;
+ end;
+
+ return function ()
+ n_actual_events = n_actual_events + 1;
+ if n_actual_events%duration_sample_interval ~= 1 then
+ return nop_function;
+ end
+
+ local start_time = time();
+ return function ()
+ local end_time = time();
+ local duration = end_time - start_time;
+ last_event = (last_event%duration_max_samples) + 1;
+ events[last_event] = duration;
+ end
+ end;
+ end;
+
+ get_stats = function ()
+ return registry;
+ end;
+ };
+ return methods;
+end
+
+return {
+ new = new_registry;
+ get_histogram = function (duration, n_buckets)
+ n_buckets = n_buckets or 100;
+ local events, n_events = duration.samples, duration.sample_count;
+ if not (events and n_events) then
+ return nil, "not a valid distribution stat";
+ end
+ local histogram = {};
+
+ for i = 1, 100, 100/n_buckets do
+ histogram[i] = percentile(events, n_events, i);
+ end
+ return histogram;
+ end;
+
+ get_percentile = function (duration, pc)
+ local events, n_events = duration.samples, duration.sample_count;
+ if not (events and n_events) then
+ return nil, "not a valid distribution stat";
+ end
+ return percentile(events, n_events, pc);
+ end;
+}
diff --git a/util/template.lua b/util/template.lua
index 66d4fca7..a26dd7ca 100644
--- a/util/template.lua
+++ b/util/template.lua
@@ -9,7 +9,7 @@ local debug = debug;
local t_remove = table.remove;
local parse_xml = require "util.xml".parse;
-module("template")
+local _ENV = nil;
local function trim_xml(stanza)
for i=#stanza,1,-1 do
diff --git a/util/termcolours.lua b/util/termcolours.lua
index 6ef3b689..a1c01aa5 100644
--- a/util/termcolours.lua
+++ b/util/termcolours.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -19,7 +19,7 @@ if os.getenv("WINDIR") then
end
local orig_color = windows and windows.get_consolecolor and windows.get_consolecolor();
-module "termcolours"
+local _ENV = nil;
local stylemap = {
reset = 0; bright = 1, dim = 2, underscore = 4, blink = 5, reverse = 7, hidden = 8;
@@ -45,7 +45,7 @@ local cssmap = {
};
local fmt_string = char(0x1B).."[%sm%s"..char(0x1B).."[0m";
-function getstring(style, text)
+local function getstring(style, text)
if style then
return format(fmt_string, style, text);
else
@@ -53,7 +53,7 @@ function getstring(style, text)
end
end
-function getstyle(...)
+local function getstyle(...)
local styles, result = { ... }, {};
for i, style in ipairs(styles) do
style = stylemap[style];
@@ -65,7 +65,7 @@ function getstyle(...)
end
local last = "0";
-function setstyle(style)
+local function setstyle(style)
style = style or "0";
if style ~= last then
io_write("\27["..style.."m");
@@ -95,8 +95,13 @@ local function ansi2css(ansi_codes)
return "</span><span style='"..t_concat(css, ";").."'>";
end
-function tohtml(input)
+local function tohtml(input)
return input:gsub("\027%[(.-)m", ansi2css);
end
-return _M;
+return {
+ getstring = getstring;
+ getstyle = getstyle;
+ setstyle = setstyle;
+ tohtml = tohtml;
+};
diff --git a/util/throttle.lua b/util/throttle.lua
index 55e1d07b..3d3f5d2d 100644
--- a/util/throttle.lua
+++ b/util/throttle.lua
@@ -3,7 +3,7 @@ local gettime = require "socket".gettime;
local setmetatable = setmetatable;
local floor = math.floor;
-module "throttle"
+local _ENV = nil;
local throttle = {};
local throttle_mt = { __index = throttle };
@@ -39,8 +39,10 @@ function throttle:poll(cost, split)
end
end
-function create(max, period)
+local function create(max, period)
return setmetatable({ rate = max / period, max = max, t = 0, balance = max }, throttle_mt);
end
-return _M;
+return {
+ create = create;
+};
diff --git a/util/timer.lua b/util/timer.lua
index af1e57b6..3713625d 100644
--- a/util/timer.lua
+++ b/util/timer.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -17,7 +17,7 @@ local type = type;
local data = {};
local new_data = {};
-module "timer"
+local _ENV = nil;
local _add_task;
if not server.event then
@@ -42,7 +42,7 @@ if not server.event then
end
new_data = {};
end
-
+
local next_time = math_huge;
for i, d in pairs(data) do
local t, callback = d[1], d[2];
@@ -78,6 +78,6 @@ else
end
end
-add_task = _add_task;
-
-return _M;
+return {
+ add_task = _add_task;
+};
diff --git a/util/uuid.lua b/util/uuid.lua
index 58f792fd..f4fd21f6 100644
--- a/util/uuid.lua
+++ b/util/uuid.lua
@@ -1,38 +1,32 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
-local error = error;
-local round_up = math.ceil;
-local urandom, urandom_err = io.open("/dev/urandom", "r+");
-
-module "uuid"
+local random = require "util.random";
+local random_bytes = random.bytes;
+local hex = require "util.hex".to;
+local m_ceil = math.ceil;
local function get_nibbles(n)
- local binary_random = urandom:read(round_up(n/2));
- local hex_random = binary_random:gsub(".",
- function (x) return ("%02x"):format(x:byte()) end);
- return hex_random:sub(1, n);
+ return hex(random_bytes(m_ceil(n/2))):sub(1, n);
end
+
local function get_twobits()
- return ("%x"):format(urandom:read(1):byte() % 4 + 8);
+ return ("%x"):format(random_bytes(1):byte() % 4 + 8);
end
-function generate()
- if not urandom then
- error("Unable to obtain a secure random number generator, please see https://prosody.im/doc/random ("..urandom_err..")");
- end
+local function generate()
-- generate RFC 4122 complaint UUIDs (version 4 - random)
return get_nibbles(8).."-"..get_nibbles(4).."-4"..get_nibbles(3).."-"..(get_twobits())..get_nibbles(3).."-"..get_nibbles(12);
end
-function seed(x)
- urandom:write(x);
- urandom:flush();
-end
-
-return _M;
+return {
+ get_nibbles=get_nibbles;
+ generate = generate ;
+ -- COMPAT
+ seed = random.seed;
+};
diff --git a/util/watchdog.lua b/util/watchdog.lua
index bcb2e274..aa8c6486 100644
--- a/util/watchdog.lua
+++ b/util/watchdog.lua
@@ -2,12 +2,12 @@ local timer = require "util.timer";
local setmetatable = setmetatable;
local os_time = os.time;
-module "watchdog"
+local _ENV = nil;
local watchdog_methods = {};
local watchdog_mt = { __index = watchdog_methods };
-function new(timeout, callback)
+local function new(timeout, callback)
local watchdog = setmetatable({ timeout = timeout, last_reset = os_time(), callback = callback }, watchdog_mt);
timer.add_task(timeout+1, function (current_time)
local last_reset = watchdog.last_reset;
@@ -31,4 +31,6 @@ function watchdog_methods:cancel()
self.last_reset = nil;
end
-return _M;
+return {
+ new = new;
+};
diff --git a/util/x509.lua b/util/x509.lua
index 19d4ec6d..f228b201 100644
--- a/util/x509.lua
+++ b/util/x509.lua
@@ -20,13 +20,11 @@
local nameprep = require "util.encodings".stringprep.nameprep;
local idna_to_ascii = require "util.encodings".idna.to_ascii;
+local base64 = require "util.encodings".base64;
local log = require "util.logger".init("x509");
-local pairs, ipairs = pairs, ipairs;
local s_format = string.format;
-local t_insert = table.insert;
-local t_concat = table.concat;
-module "x509"
+local _ENV = nil;
local oid_commonname = "2.5.4.3"; -- [LDAP] 2.3
local oid_subjectaltname = "2.5.29.17"; -- [PKIX] 4.2.1.6
@@ -149,7 +147,10 @@ local function compare_srvname(host, service, asserted_names)
return false
end
-function verify_identity(host, service, cert)
+local function verify_identity(host, service, cert)
+ if cert.setencode then
+ cert:setencode("utf8");
+ end
local ext = cert:extensions()
if ext[oid_subjectaltname] then
local sans = ext[oid_subjectaltname];
@@ -161,7 +162,9 @@ function verify_identity(host, service, cert)
if sans[oid_xmppaddr] then
had_supported_altnames = true
- if compare_xmppaddr(host, sans[oid_xmppaddr]) then return true end
+ if service == "_xmpp-client" or service == "_xmpp-server" then
+ if compare_xmppaddr(host, sans[oid_xmppaddr]) then return true end
+ end
end
if sans[oid_dnssrv] then
@@ -212,4 +215,27 @@ function verify_identity(host, service, cert)
return false
end
-return _M;
+local pat = "%-%-%-%-%-BEGIN ([A-Z ]+)%-%-%-%-%-\r?\n"..
+"([0-9A-Za-z+/=\r\n]*)\r?\n%-%-%-%-%-END %1%-%-%-%-%-";
+
+local function pem2der(pem)
+ local typ, data = pem:match(pat);
+ if typ and data then
+ return base64.decode(data), typ;
+ end
+end
+
+local wrap = ('.'):rep(64);
+local envelope = "-----BEGIN %s-----\n%s\n-----END %s-----\n"
+
+local function der2pem(data, typ)
+ typ = typ and typ:upper() or "CERTIFICATE";
+ data = base64.encode(data);
+ return s_format(envelope, typ, data:gsub(wrap, '%0\n', (#data-1)/64), typ);
+end
+
+return {
+ verify_identity = verify_identity;
+ pem2der = pem2der;
+ der2pem = der2pem;
+};
diff --git a/util/xml.lua b/util/xml.lua
index 076490fa..733d821a 100644
--- a/util/xml.lua
+++ b/util/xml.lua
@@ -2,7 +2,7 @@
local st = require "util.stanza";
local lxp = require "lxp";
-module("xml")
+local _ENV = nil;
local parse_xml = (function()
local ns_prefixes = {
@@ -11,6 +11,7 @@ local parse_xml = (function()
local ns_separator = "\1";
local ns_pattern = "^([^"..ns_separator.."]*)"..ns_separator.."?(.*)$";
return function(xml)
+ --luacheck: ignore 212/self
local handler = {};
local stanza = st.stanza("root");
function handler:StartElement(tagname, attr)
@@ -26,8 +27,8 @@ local parse_xml = (function()
attr[i] = nil;
local ns, nm = k:match(ns_pattern);
if nm ~= "" then
- ns = ns_prefixes[ns];
- if ns then
+ ns = ns_prefixes[ns];
+ if ns then
attr[ns..":"..nm] = attr[k];
attr[k] = nil;
end
@@ -38,7 +39,7 @@ local parse_xml = (function()
function handler:CharacterData(data)
stanza:text(data);
end
- function handler:EndElement(tagname)
+ function handler:EndElement()
stanza:up();
end
local parser = lxp.new(handler, "\1");
@@ -53,5 +54,6 @@ local parse_xml = (function()
end;
end)();
-parse = parse_xml;
-return _M;
+return {
+ parse = parse_xml;
+};
diff --git a/util/xmppstream.lua b/util/xmppstream.lua
index 138c86b7..7be63285 100644
--- a/util/xmppstream.lua
+++ b/util/xmppstream.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -24,7 +24,7 @@ local lxp_supports_bytecount = not not lxp.new({}).getcurrentbytecount;
local default_stanza_size_limit = 1024*1024*10; -- 10MB
-module "xmppstream"
+local _ENV = nil;
local new_parser = lxp.new;
@@ -40,29 +40,26 @@ local xmlns_streams = "http://etherx.jabber.org/streams";
local ns_separator = "\1";
local ns_pattern = "^([^"..ns_separator.."]*)"..ns_separator.."?(.*)$";
-_M.ns_separator = ns_separator;
-_M.ns_pattern = ns_pattern;
-
local function dummy_cb() end
-function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
+local function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
local xml_handlers = {};
-
+
local cb_streamopened = stream_callbacks.streamopened;
local cb_streamclosed = stream_callbacks.streamclosed;
local cb_error = stream_callbacks.error or function(session, e, stanza) error("XML stream error: "..tostring(e)..(stanza and ": "..tostring(stanza) or ""),2); end;
local cb_handlestanza = stream_callbacks.handlestanza;
cb_handleprogress = cb_handleprogress or dummy_cb;
-
+
local stream_ns = stream_callbacks.stream_ns or xmlns_streams;
local stream_tag = stream_callbacks.stream_tag or "stream";
if stream_ns ~= "" then
stream_tag = stream_ns..ns_separator..stream_tag;
end
local stream_error_tag = stream_ns..ns_separator..(stream_callbacks.error_tag or "error");
-
+
local stream_default_ns = stream_callbacks.default_ns;
-
+
local stack = {};
local chardata, stanza = {};
local stanza_size = 0;
@@ -82,7 +79,7 @@ function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
attr.xmlns = curr_ns;
non_streamns_depth = non_streamns_depth + 1;
end
-
+
for i=1,#attr do
local k = attr[i];
attr[i] = nil;
@@ -92,7 +89,7 @@ function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
attr[k] = nil;
end
end
-
+
if not stanza then --if we are not currently inside a stanza
if lxp_supports_bytecount then
stanza_size = self:getcurrentbytecount();
@@ -116,7 +113,7 @@ function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
if curr_ns == "jabber:client" and name ~= "iq" and name ~= "presence" and name ~= "message" then
cb_error(session, "invalid-top-level-element");
end
-
+
stanza = setmetatable({ name = name, attr = attr, tags = {} }, stanza_mt);
else -- we are inside a stanza, so add a tag
if lxp_supports_bytecount then
@@ -205,26 +202,26 @@ function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
error("Failed to abort parsing");
end
end
-
+
if lxp_supports_doctype then
xml_handlers.StartDoctypeDecl = restricted_handler;
end
xml_handlers.Comment = restricted_handler;
xml_handlers.ProcessingInstruction = restricted_handler;
-
+
local function reset()
stanza, chardata, stanza_size = nil, {}, 0;
stack = {};
end
-
+
local function set_session(stream, new_session)
session = new_session;
end
-
+
return xml_handlers, { reset = reset, set_session = set_session };
end
-function new(session, stream_callbacks, stanza_size_limit)
+local function new(session, stream_callbacks, stanza_size_limit)
-- Used to track parser progress (e.g. to enforce size limits)
local n_outstanding_bytes = 0;
local handle_progress;
@@ -241,6 +238,25 @@ function new(session, stream_callbacks, stanza_size_limit)
local parser = new_parser(handlers, ns_separator, false);
local parse = parser.parse;
+ function session.open_stream(session, from, to)
+ local send = session.sends2s or session.send;
+
+ local attr = {
+ ["xmlns:stream"] = "http://etherx.jabber.org/streams",
+ ["xml:lang"] = "en",
+ xmlns = stream_callbacks.default_ns,
+ version = session.version and (session.version > 0 and "1.0" or nil),
+ id = session.streamid,
+ from = from or session.host, to = to,
+ };
+ if session.stream_attrs then
+ session:stream_attrs(from, to, attr)
+ end
+ send("<?xml version='1.0'?>");
+ send(st.stanza("stream:stream", attr):top_tag());
+ return true;
+ end
+
return {
reset = function ()
parser = new_parser(handlers, ns_separator, false);
@@ -262,4 +278,9 @@ function new(session, stream_callbacks, stanza_size_limit)
};
end
-return _M;
+return {
+ ns_separator = ns_separator;
+ ns_pattern = ns_pattern;
+ new_sax_handlers = new_sax_handlers;
+ new = new;
+};