aboutsummaryrefslogtreecommitdiffstats
path: root/util
diff options
context:
space:
mode:
Diffstat (limited to 'util')
-rw-r--r--util/array.lua65
-rw-r--r--util/cache.lua151
-rw-r--r--util/caps.lua10
-rw-r--r--util/dataforms.lua90
-rw-r--r--util/datamanager.lua139
-rw-r--r--util/datetime.lua23
-rw-r--r--util/debug.lua79
-rw-r--r--util/dependencies.lua56
-rw-r--r--util/events.lua92
-rw-r--r--util/filters.lua36
-rw-r--r--util/helpers.lua37
-rw-r--r--util/hex.lua26
-rw-r--r--util/hmac.lua2
-rw-r--r--util/import.lua3
-rw-r--r--util/interpolation.lua85
-rw-r--r--util/ip.lua47
-rw-r--r--util/iterators.lua96
-rw-r--r--util/jid.lua84
-rw-r--r--util/json.lua44
-rw-r--r--util/logger.lua26
-rw-r--r--util/mercurial.lua34
-rw-r--r--util/multitable.lua16
-rw-r--r--util/openssl.lua44
-rw-r--r--util/paths.lua44
-rw-r--r--util/pluginloader.lua31
-rw-r--r--util/presence.lua38
-rw-r--r--util/prosodyctl.lua121
-rw-r--r--util/pubsub.lua98
-rw-r--r--util/queue.lua73
-rw-r--r--util/random.lua30
-rw-r--r--util/rfc6724.lua1
-rw-r--r--util/sasl.lua58
-rw-r--r--util/sasl/anonymous.lua10
-rw-r--r--util/sasl/digest-md5.lua8
-rw-r--r--util/sasl/external.lua27
-rw-r--r--util/sasl/plain.lua8
-rw-r--r--util/sasl/scram.lua166
-rw-r--r--util/sasl_cyrus.lua12
-rw-r--r--util/serialization.lua18
-rw-r--r--util/session.lua65
-rw-r--r--util/set.lua144
-rw-r--r--util/sql.lua293
-rw-r--r--util/sslconfig.lua120
-rw-r--r--util/stanza.lua122
-rw-r--r--util/statistics.lua160
-rw-r--r--util/statsd.lua84
-rw-r--r--util/template.lua12
-rw-r--r--util/termcolours.lua65
-rw-r--r--util/throttle.lua8
-rw-r--r--util/time.lua8
-rw-r--r--util/timer.lua12
-rw-r--r--util/uuid.lua34
-rw-r--r--util/watchdog.lua8
-rw-r--r--util/x509.lua40
-rw-r--r--util/xml.lua14
-rw-r--r--util/xmppstream.lua59
56 files changed, 2403 insertions, 873 deletions
diff --git a/util/array.lua b/util/array.lua
index 2d58e7fb..3ddc97f6 100644
--- a/util/array.lua
+++ b/util/array.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -11,8 +11,10 @@ local t_insert, t_sort, t_remove, t_concat
local setmetatable = setmetatable;
local math_random = math.random;
+local math_floor = math.floor;
local pairs, ipairs = pairs, ipairs;
local tostring = tostring;
+local type = type;
local array = {};
local array_base = {};
@@ -35,7 +37,7 @@ setmetatable(array, { __call = new_array });
-- Read-only methods
function array_methods:random()
- return self[math_random(1,#self)];
+ return self[math_random(1, #self)];
end
-- These methods can be called two ways:
@@ -43,7 +45,7 @@ end
-- existing_array:method([params, ...]) -- Transform existing array into result
--
function array_base.map(outa, ina, func)
- for k,v in ipairs(ina) do
+ for k, v in ipairs(ina) do
outa[k] = func(v);
end
return outa;
@@ -52,20 +54,20 @@ end
function array_base.filter(outa, ina, func)
local inplace, start_length = ina == outa, #ina;
local write = 1;
- for read=1,start_length do
+ for read = 1, start_length do
local v = ina[read];
if func(v) then
outa[write] = v;
write = write + 1;
end
end
-
+
if inplace and write <= start_length then
- for i=write,start_length do
+ for i = write, start_length do
outa[i] = nil;
end
end
-
+
return outa;
end
@@ -78,34 +80,44 @@ function array_base.sort(outa, ina, ...)
end
function array_base.pluck(outa, ina, key)
- for i=1,#ina do
+ for i = 1, #ina do
outa[i] = ina[i][key];
end
return outa;
end
+function array_base.reverse(outa, ina)
+ local len = #ina;
+ if ina == outa then
+ local middle = math_floor(len/2);
+ len = len + 1;
+ local o; -- opposite
+ for i = 1, middle do
+ o = len - i;
+ outa[i], outa[o] = outa[o], outa[i];
+ end
+ else
+ local off = len + 1;
+ for i = 1, len do
+ outa[i] = ina[off - i];
+ end
+ end
+ return outa;
+end
+
--- These methods only mutate the array
function array_methods:shuffle(outa, ina)
local len = #self;
- for i=1,#self do
- local r = math_random(i,len);
+ for i = 1, #self do
+ local r = math_random(i, len);
self[i], self[r] = self[r], self[i];
end
return self;
end
-function array_methods:reverse()
- local len = #self-1;
- for i=len,1,-1 do
- self:push(self[i]);
- self:pop(i);
- end
- return self;
-end
-
function array_methods:append(array)
- local len,len2 = #self, #array;
- for i=1,len2 do
+ local len, len2 = #self, #array;
+ for i = 1, len2 do
self[len+i] = array[i];
end
return self;
@@ -116,11 +128,7 @@ function array_methods:push(x)
return self;
end
-function array_methods:pop(x)
- local v = self[x];
- t_remove(self, x);
- return v;
-end
+array_methods.pop = t_remove;
function array_methods:concat(sep)
return t_concat(array.map(self, tostring), sep);
@@ -135,7 +143,7 @@ function array.collect(f, s, var)
local t = {};
while true do
var = f(s, var);
- if var == nil then break; end
+ if var == nil then break; end
t_insert(t, var);
end
return setmetatable(t, array_mt);
@@ -157,7 +165,4 @@ for method, f in pairs(array_base) do
end
end
-_G.array = array;
-module("array");
-
return array;
diff --git a/util/cache.lua b/util/cache.lua
new file mode 100644
index 00000000..44bbfe30
--- /dev/null
+++ b/util/cache.lua
@@ -0,0 +1,151 @@
+
+local function _remove(list, m)
+ if m.prev then
+ m.prev.next = m.next;
+ end
+ if m.next then
+ m.next.prev = m.prev;
+ end
+ if list._tail == m then
+ list._tail = m.prev;
+ end
+ if list._head == m then
+ list._head = m.next;
+ end
+ list._count = list._count - 1;
+end
+
+local function _insert(list, m)
+ if list._head then
+ list._head.prev = m;
+ end
+ m.prev, m.next = nil, list._head;
+ list._head = m;
+ if not list._tail then
+ list._tail = m;
+ end
+ list._count = list._count + 1;
+end
+
+local cache_methods = {};
+local cache_mt = { __index = cache_methods };
+
+function cache_methods:set(k, v)
+ local m = self._data[k];
+ if m then
+ -- Key already exists
+ if v ~= nil then
+ -- Bump to head of list
+ _remove(self, m);
+ _insert(self, m);
+ m.value = v;
+ else
+ -- Remove from list
+ _remove(self, m);
+ self._data[k] = nil;
+ end
+ return true;
+ end
+ -- New key
+ if v == nil then
+ return true;
+ end
+ -- Check whether we need to remove oldest k/v
+ if self._count == self.size then
+ local tail = self._tail;
+ local on_evict, evicted_key, evicted_value = self._on_evict, tail.key, tail.value;
+ if on_evict ~= nil and (on_evict == false or on_evict(evicted_key, evicted_value) == false) then
+ -- Cache is full, and we're not allowed to evict
+ return false;
+ end
+ _remove(self, tail);
+ self._data[evicted_key] = nil;
+ end
+
+ m = { key = k, value = v, prev = nil, next = nil };
+ self._data[k] = m;
+ _insert(self, m);
+ return true;
+end
+
+function cache_methods:get(k)
+ local m = self._data[k];
+ if m then
+ return m.value;
+ end
+ return nil;
+end
+
+function cache_methods:items()
+ local m = self._head;
+ return function ()
+ if not m then
+ return;
+ end
+ local k, v = m.key, m.value;
+ m = m.next;
+ return k, v;
+ end
+end
+
+function cache_methods:values()
+ local m = self._head;
+ return function ()
+ if not m then
+ return;
+ end
+ local v = m.value;
+ m = m.next;
+ return v;
+ end
+end
+
+function cache_methods:count()
+ return self._count;
+end
+
+function cache_methods:head()
+ local head = self._head;
+ if not head then return nil, nil; end
+ return head.key, head.value;
+end
+
+function cache_methods:tail()
+ local tail = self._tail;
+ if not tail then return nil, nil; end
+ return tail.key, tail.value;
+end
+
+function cache_methods:table()
+ if not self.proxy_table then
+ self.proxy_table = setmetatable({}, {
+ __index = function (t, k)
+ return self:get(k);
+ end;
+ __newindex = function (t, k, v)
+ if not self:set(k, v) then
+ error("failed to insert key into cache - full");
+ end
+ end;
+ __pairs = function (t)
+ return self:items();
+ end;
+ __len = function (t)
+ return self:count();
+ end;
+ });
+ end
+ return self.proxy_table;
+end
+
+local function new(size, on_evict)
+ size = assert(tonumber(size), "cache size must be a number");
+ size = math.floor(size);
+ assert(size > 0, "cache size must be greater than zero");
+ local data = {};
+ return setmetatable({ _data = data, _count = 0, size = size, _head = nil, _tail = nil, _on_evict = on_evict }, cache_mt);
+end
+
+return {
+ new = new;
+}
diff --git a/util/caps.lua b/util/caps.lua
index a61e7403..cd5ff9c0 100644
--- a/util/caps.lua
+++ b/util/caps.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -12,9 +12,9 @@ local sha1 = require "util.hashes".sha1;
local t_insert, t_sort, t_concat = table.insert, table.sort, table.concat;
local ipairs = ipairs;
-module "caps"
+local _ENV = nil;
-function calculate_hash(disco_info)
+local function calculate_hash(disco_info)
local identities, features, extensions = {}, {}, {};
for _, tag in ipairs(disco_info) do
if tag.name == "identity" then
@@ -58,4 +58,6 @@ function calculate_hash(disco_info)
return ver, S;
end
-return _M;
+return {
+ calculate_hash = calculate_hash;
+};
diff --git a/util/dataforms.lua b/util/dataforms.lua
index ee37157a..756f35a7 100644
--- a/util/dataforms.lua
+++ b/util/dataforms.lua
@@ -1,26 +1,26 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
local setmetatable = setmetatable;
-local pairs, ipairs = pairs, ipairs;
+local ipairs = ipairs;
local tostring, type, next = tostring, type, next;
local t_concat = table.concat;
local st = require "util.stanza";
local jid_prep = require "util.jid".prep;
-module "dataforms"
+local _ENV = nil;
local xmlns_forms = 'jabber:x:data';
local form_t = {};
local form_mt = { __index = form_t };
-function new(layout)
+local function new(layout)
return setmetatable(layout, form_mt);
end
@@ -32,13 +32,13 @@ function form_t.form(layout, data, formtype)
if layout.instructions then
form:tag("instructions"):text(layout.instructions):up();
end
- for n, field in ipairs(layout) do
+ for _, field in ipairs(layout) do
local field_type = field.type or "text-single";
-- Add field tag
form:tag("field", { type = field_type, var = field.name, label = field.label });
local value = (data and data[field.name]) or field.value;
-
+
if value then
-- Add value, depending on type
if field_type == "hidden" then
@@ -69,10 +69,10 @@ function form_t.form(layout, data, formtype)
end
elseif field_type == "list-single" then
local has_default = false;
- for _, val in ipairs(value) do
+ for _, val in ipairs(field.options or value) do
if type(val) == "table" then
form:tag("option", { label = val.label }):tag("value"):text(val.value):up():up();
- if val.default and (not has_default) then
+ if value == val.value or val.default and (not has_default) then
form:tag("value"):text(val.value):up();
has_default = true;
end
@@ -80,17 +80,25 @@ function form_t.form(layout, data, formtype)
form:tag("option", { label= val }):tag("value"):text(tostring(val)):up():up();
end
end
+ if field.options and value then
+ form:tag("value"):text(value):up();
+ end
elseif field_type == "list-multi" then
- for _, val in ipairs(value) do
+ for _, val in ipairs(field.options or value) do
if type(val) == "table" then
form:tag("option", { label = val.label }):tag("value"):text(val.value):up():up();
- if val.default then
+ if not field.options and val.default then
form:tag("value"):text(val.value):up();
end
else
form:tag("option", { label= val }):tag("value"):text(tostring(val)):up():up();
end
end
+ if field.options and value then
+ for _, val in ipairs(value) do
+ form:tag("value"):text(val):up();
+ end
+ end
end
end
@@ -102,11 +110,11 @@ function form_t.form(layout, data, formtype)
end
form:up();
end
-
+
if field.required then
form:tag("required"):up();
end
-
+
-- Jump back up to list of fields
form:up();
end
@@ -118,6 +126,7 @@ local field_readers = {};
function form_t.data(layout, stanza)
local data = {};
local errors = {};
+ local present = {};
for _, field in ipairs(layout) do
local tag;
@@ -133,6 +142,7 @@ function form_t.data(layout, stanza)
errors[field.name] = "Required value missing";
end
else
+ present[field.name] = true;
local reader = field_readers[field.type];
if reader then
data[field.name], errors[field.name] = reader(tag, field.required);
@@ -140,35 +150,34 @@ function form_t.data(layout, stanza)
end
end
if next(errors) then
- return data, errors;
+ return data, errors, present;
end
- return data;
+ return data, nil, present;
end
-field_readers["text-single"] =
- function (field_tag, required)
- local data = field_tag:get_child_text("value");
- if data and #data > 0 then
- return data
- elseif required then
- return nil, "Required value missing";
- end
+local function simple_text(field_tag, required)
+ local data = field_tag:get_child_text("value");
+ -- XEP-0004 does not say if an empty string is acceptable for a required value
+ -- so we will follow HTML5 which says that empty string means missing
+ if required and (data == nil or data == "") then
+ return nil, "Required value missing";
end
+ return data; -- Return whatever get_child_text returned, even if empty string
+end
-field_readers["text-private"] =
- field_readers["text-single"];
+field_readers["text-single"] = simple_text;
+
+field_readers["text-private"] = simple_text;
field_readers["jid-single"] =
function (field_tag, required)
- local raw_data = field_tag:get_child_text("value")
+ local raw_data, err = simple_text(field_tag, required);
+ if not raw_data then return raw_data, err; end
local data = jid_prep(raw_data);
- if data and #data > 0 then
- return data
- elseif raw_data then
+ if not data then
return nil, "Invalid JID: " .. raw_data;
- elseif required then
- return nil, "Required value missing";
end
+ return data;
end
field_readers["jid-multi"] =
@@ -212,8 +221,7 @@ field_readers["text-multi"] =
return data, err;
end
-field_readers["list-single"] =
- field_readers["text-single"];
+field_readers["list-single"] = simple_text;
local boolean_values = {
["1"] = true, ["true"] = true,
@@ -222,15 +230,13 @@ local boolean_values = {
field_readers["boolean"] =
function (field_tag, required)
- local raw_value = field_tag:get_child_text("value");
- local value = boolean_values[raw_value ~= nil and raw_value];
- if value ~= nil then
- return value;
- elseif raw_value then
- return nil, "Invalid boolean representation";
- elseif required then
- return nil, "Required value missing";
+ local raw_value, err = simple_text(field_tag, required);
+ if not raw_value then return raw_value, err; end
+ local value = boolean_values[raw_value];
+ if value == nil then
+ return nil, "Invalid boolean representation:" .. raw_value;
end
+ return value;
end
field_readers["hidden"] =
@@ -238,7 +244,9 @@ field_readers["hidden"] =
return field_tag:get_child_text("value");
end
-return _M;
+return {
+ new = new;
+};
--[=[
diff --git a/util/datamanager.lua b/util/datamanager.lua
index c69ecd25..fb9ba3a4 100644
--- a/util/datamanager.lua
+++ b/util/datamanager.lua
@@ -17,7 +17,9 @@ local io_open = io.open;
local os_remove = os.remove;
local os_rename = os.rename;
local tonumber = tonumber;
+local tostring = tostring;
local next = next;
+local type = type;
local t_insert = table.insert;
local t_concat = table.concat;
local envloadfile = require"util.envload".envloadfile;
@@ -43,7 +45,7 @@ pcall(function()
fallocate = pposix.fallocate or fallocate;
end);
-module "datamanager"
+local _ENV = nil;
---- utils -----
local encode, decode;
@@ -74,7 +76,7 @@ local callbacks = {};
------- API -------------
-function set_data_path(path)
+local function set_data_path(path)
log("debug", "Setting data path to: %s", path);
data_path = path;
end
@@ -87,14 +89,14 @@ local function callback(username, host, datastore, data)
return username, host, datastore, data;
end
-function add_callback(func)
+local function add_callback(func)
if not callbacks[func] then -- Would you really want to set the same callback more than once?
callbacks[func] = true;
callbacks[#callbacks+1] = func;
return true;
end
end
-function remove_callback(func)
+local function remove_callback(func)
if callbacks[func] then
for i, f in ipairs(callbacks) do
if f == func then
@@ -106,7 +108,7 @@ function remove_callback(func)
end
end
-function getpath(username, host, datastore, ext, create)
+local function getpath(username, host, datastore, ext, create)
ext = ext or "dat";
host = (host and encode(host)) or "_global";
username = username and encode(username);
@@ -119,7 +121,7 @@ function getpath(username, host, datastore, ext, create)
end
end
-function load(username, host, datastore)
+local function load(username, host, datastore)
local data, ret = envloadfile(getpath(username, host, datastore), {});
if not data then
local mode = lfs.attributes(getpath(username, host, datastore), "mode");
@@ -144,24 +146,26 @@ end
local function atomic_store(filename, data)
local scratch = filename.."~";
local f, ok, msg;
- repeat
- f, msg = io_open(scratch, "w");
- if not f then break end
- ok, msg = f:write(data);
- if not ok then break end
+ f, msg = io_open(scratch, "w");
+ if not f then
+ return nil, msg;
+ end
- ok, msg = f:close();
- f = nil; -- no longer valid
- if not ok then break end
+ ok, msg = f:write(data);
+ if not ok then
+ f:close();
+ os_remove(scratch);
+ return nil, msg;
+ end
- return os_rename(scratch, filename);
- until false;
+ ok, msg = f:close();
+ if not ok then
+ os_remove(scratch);
+ return nil, msg;
+ end
- -- Cleanup
- if f then f:close(); end
- os_remove(scratch);
- return nil, msg;
+ return os_rename(scratch, filename);
end
if prosody and prosody.platform ~= "posix" then
@@ -176,7 +180,7 @@ if prosody and prosody.platform ~= "posix" then
end
end
-function store(username, host, datastore, data)
+local function store(username, host, datastore, data)
if not data then
data = {};
end
@@ -210,33 +214,62 @@ function store(username, host, datastore, data)
return true;
end
-function list_append(username, host, datastore, data)
- if not data then return; end
- if callback(username, host, datastore) == false then return true; end
- -- save the datastore
- local f, msg = io_open(getpath(username, host, datastore, "list", true), "r+");
- if not f then
- f, msg = io_open(getpath(username, host, datastore, "list", true), "w");
- end
+-- Append a blob of data to a file
+local function append(username, host, datastore, ext, data)
+ if type(data) ~= "string" then return; end
+ local filename = getpath(username, host, datastore, ext, true);
+
+ local ok;
+ local f, msg = io_open(filename, "r+");
if not f then
- log("error", "Unable to write to %s storage ('%s') for user: %s@%s", datastore, msg, username or "nil", host or "nil");
- return;
+ -- File did probably not exist, let's create it
+ f, msg = io_open(filename, "w");
+ if not f then
+ return nil, msg, "open";
+ end
end
- local data = "item(" .. serialize(data) .. ");\n";
+
local pos = f:seek("end");
- local ok, msg = fallocate(f, pos, #data);
- f:seek("set", pos);
- if ok then
- f:write(data);
- else
+ ok, msg = fallocate(f, pos, #data);
+ if not ok then
+ log("warn", "fallocate() failed: %s", tostring(msg));
+ -- This doesn't work on every file system
+ end
+
+ if f:seek() ~= pos then
+ log("debug", "fallocate() changed file position");
+ f:seek("set", pos);
+ end
+
+ ok, msg = f:write(data);
+ if not ok then
+ f:close();
+ return ok, msg, "write";
+ end
+
+ ok, msg = f:close();
+ if not ok then
+ return ok, msg;
+ end
+
+ return true, pos;
+end
+
+local function list_append(username, host, datastore, data)
+ if not data then return; end
+ if callback(username, host, datastore) == false then return true; end
+ -- save the datastore
+
+ data = "item(" .. serialize(data) .. ");\n";
+ local ok, msg = append(username, host, datastore, "list", data);
+ if not ok then
log("error", "Unable to write to %s storage ('%s') for user: %s@%s", datastore, msg, username or "nil", host or "nil");
return ok, msg;
end
- f:close();
return true;
end
-function list_store(username, host, datastore, data)
+local function list_store(username, host, datastore, data)
if not data then
data = {};
end
@@ -260,7 +293,7 @@ function list_store(username, host, datastore, data)
return true;
end
-function list_load(username, host, datastore)
+local function list_load(username, host, datastore)
local items = {};
local data, ret = envloadfile(getpath(username, host, datastore, "list"), {item = function(i) t_insert(items, i); end});
if not data then
@@ -288,7 +321,7 @@ local type_map = {
list = "list";
}
-function users(host, store, typ)
+local function users(host, store, typ)
typ = type_map[typ or "keyval"];
local store_dir = format("%s/%s/%s", data_path, encode(host), store);
@@ -307,7 +340,7 @@ function users(host, store, typ)
end, state;
end
-function stores(username, host, typ)
+local function stores(username, host, typ)
typ = type_map[typ or "keyval"];
local store_dir = format("%s/%s/", data_path, encode(host));
@@ -347,7 +380,7 @@ local function do_remove(path)
return true
end
-function purge(username, host)
+local function purge(username, host)
local host_dir = format("%s/%s/", data_path, encode(host));
local ok, iter, state, var = pcall(lfs.dir, host_dir);
if not ok then
@@ -367,6 +400,20 @@ function purge(username, host)
return #errs == 0, t_concat(errs, ", ");
end
-_M.path_decode = decode;
-_M.path_encode = encode;
-return _M;
+return {
+ set_data_path = set_data_path;
+ add_callback = add_callback;
+ remove_callback = remove_callback;
+ getpath = getpath;
+ load = load;
+ store = store;
+ append_raw = append;
+ list_append = list_append;
+ list_store = list_store;
+ list_load = list_load;
+ users = users;
+ stores = stores;
+ purge = purge;
+ path_decode = decode;
+ path_encode = encode;
+};
diff --git a/util/datetime.lua b/util/datetime.lua
index a1f62a48..abb4e867 100644
--- a/util/datetime.lua
+++ b/util/datetime.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -12,28 +12,27 @@
local os_date = os.date;
local os_time = os.time;
local os_difftime = os.difftime;
-local error = error;
local tonumber = tonumber;
-module "datetime"
+local _ENV = nil;
-function date(t)
+local function date(t)
return os_date("!%Y-%m-%d", t);
end
-function datetime(t)
+local function datetime(t)
return os_date("!%Y-%m-%dT%H:%M:%SZ", t);
end
-function time(t)
+local function time(t)
return os_date("!%H:%M:%S", t);
end
-function legacy(t)
+local function legacy(t)
return os_date("!%Y%m%dT%H:%M:%S", t);
end
-function parse(s)
+local function parse(s)
if s then
local year, month, day, hour, min, sec, tzd;
year, month, day, hour, min, sec, tzd = s:match("^(%d%d%d%d)%-?(%d%d)%-?(%d%d)T(%d%d):(%d%d):(%d%d)%.?%d*([Z+%-]?.*)$");
@@ -54,4 +53,10 @@ function parse(s)
end
end
-return _M;
+return {
+ date = date;
+ datetime = datetime;
+ time = time;
+ legacy = legacy;
+ parse = parse;
+};
diff --git a/util/debug.lua b/util/debug.lua
index bff0e347..00f476d0 100644
--- a/util/debug.lua
+++ b/util/debug.lua
@@ -1,6 +1,9 @@
-- Variables ending with these names will not
-- have their values printed ('password' includes
-- 'new_password', etc.)
+--
+-- luacheck: ignore 122/debug
+
local censored_names = {
password = true;
passwd = true;
@@ -13,7 +16,7 @@ local termcolours = require "util.termcolours";
local getstring = termcolours.getstring;
local styles;
do
- _ = termcolours.getstyle;
+ local _ = termcolours.getstyle;
styles = {
boundary_padding = _("bright");
filename = _("bright", "blue");
@@ -22,20 +25,23 @@ do
location = _("yellow");
};
end
-module("debugx", package.seeall);
-function get_locals_table(level)
- level = level + 1; -- Skip this function itself
+local function get_locals_table(thread, level)
local locals = {};
for local_num = 1, math.huge do
- local name, value = debug.getlocal(level, local_num);
+ local name, value;
+ if thread then
+ name, value = debug.getlocal(thread, level, local_num);
+ else
+ name, value = debug.getlocal(level+1, local_num);
+ end
if not name then break; end
table.insert(locals, { name = name, value = value });
end
return locals;
end
-function get_upvalues_table(func)
+local function get_upvalues_table(func)
local upvalues = {};
if func then
for upvalue_num = 1, math.huge do
@@ -47,7 +53,7 @@ function get_upvalues_table(func)
return upvalues;
end
-function string_from_var_table(var_table, max_line_len, indent_str)
+local function string_from_var_table(var_table, max_line_len, indent_str)
local var_string = {};
local col_pos = 0;
max_line_len = max_line_len or math.huge;
@@ -83,33 +89,25 @@ function string_from_var_table(var_table, max_line_len, indent_str)
end
end
-function get_traceback_table(thread, start_level)
+local function get_traceback_table(thread, start_level)
local levels = {};
for level = start_level, math.huge do
local info;
if thread then
- info = debug.getinfo(thread, level+1);
+ info = debug.getinfo(thread, level);
else
info = debug.getinfo(level+1);
end
if not info then break; end
-
+
levels[(level-start_level)+1] = {
level = level;
info = info;
- locals = get_locals_table(level+1);
+ locals = get_locals_table(thread, level+(thread and 0 or 1));
upvalues = get_upvalues_table(info.func);
};
- end
- return levels;
-end
-
-function traceback(...)
- local ok, ret = pcall(_traceback, ...);
- if not ok then
- return "Error in error handling: "..ret;
end
- return ret;
+ return levels;
end
local function build_source_boundary_marker(last_source_desc)
@@ -117,7 +115,7 @@ local function build_source_boundary_marker(last_source_desc)
return getstring(styles.boundary_padding, "v"..padding).." "..getstring(styles.filename, last_source_desc).." "..getstring(styles.boundary_padding, padding..(#last_source_desc%2==0 and "-v" or "v "));
end
-function _traceback(thread, message, level)
+local function _traceback(thread, message, level)
-- Lua manual says: debug.traceback ([thread,] [message [, level]])
-- I fathom this to mean one of:
@@ -134,15 +132,15 @@ function _traceback(thread, message, level)
return nil; -- debug.traceback() does this
end
- level = level or 1;
+ level = level or 0;
message = message and (message.."\n") or "";
-
- -- +3 counts for this function, and the pcall() and wrapper above us
- local levels = get_traceback_table(thread, level+3);
-
+
+ -- +3 counts for this function, and the pcall() and wrapper above us, the +1... I don't know.
+ local levels = get_traceback_table(thread, level+(thread == nil and 4 or 0));
+
local last_source_desc;
-
+
local lines = {};
for nlevel, level in ipairs(levels) do
local info = level.info;
@@ -171,9 +169,11 @@ function _traceback(thread, message, level)
nlevel = nlevel-1;
table.insert(lines, "\t"..(nlevel==0 and ">" or " ")..getstring(styles.level_num, "("..nlevel..") ")..line);
local npadding = (" "):rep(#tostring(nlevel));
- local locals_str = string_from_var_table(level.locals, optimal_line_length, "\t "..npadding);
- if locals_str then
- table.insert(lines, "\t "..npadding.."Locals: "..locals_str);
+ if level.locals then
+ local locals_str = string_from_var_table(level.locals, optimal_line_length, "\t "..npadding);
+ if locals_str then
+ table.insert(lines, "\t "..npadding.."Locals: "..locals_str);
+ end
end
local upvalues_str = string_from_var_table(level.upvalues, optimal_line_length, "\t "..npadding);
if upvalues_str then
@@ -186,8 +186,23 @@ function _traceback(thread, message, level)
return message.."stack traceback:\n"..table.concat(lines, "\n");
end
-function use()
+local function traceback(...)
+ local ok, ret = pcall(_traceback, ...);
+ if not ok then
+ return "Error in error handling: "..ret;
+ end
+ return ret;
+end
+
+local function use()
debug.traceback = traceback;
end
-return _M;
+return {
+ get_locals_table = get_locals_table;
+ get_upvalues_table = get_upvalues_table;
+ string_from_var_table = string_from_var_table;
+ get_traceback_table = get_traceback_table;
+ traceback = traceback;
+ use = use;
+};
diff --git a/util/dependencies.lua b/util/dependencies.lua
index 4d50cf63..b3f07257 100644
--- a/util/dependencies.lua
+++ b/util/dependencies.lua
@@ -1,21 +1,19 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
-module("dependencies", package.seeall)
-
-function softreq(...) local ok, lib = pcall(require, ...); if ok then return lib; else return nil, lib; end end
+local function softreq(...) local ok, lib = pcall(require, ...); if ok then return lib; else return nil, lib; end end
-- Required to be able to find packages installed with luarocks
if not softreq "luarocks.loader" then -- LuaRocks 2.x
softreq "luarocks.require"; -- LuaRocks <1.x
end
-function missingdep(name, sources, msg)
+local function missingdep(name, sources, msg)
print("");
print("**************************");
print("Prosody was unable to find "..tostring(name));
@@ -35,7 +33,7 @@ function missingdep(name, sources, msg)
print("");
end
--- COMPAT w/pre-0.8 Debian: The Debian config file used to use
+-- COMPAT w/pre-0.8 Debian: The Debian config file used to use
-- util.ztact, which has been removed from Prosody in 0.8. This
-- is to log an error for people who still use it, so they can
-- update their configs.
@@ -48,19 +46,19 @@ package.preload["util.ztact"] = function ()
end
end;
-function check_dependencies()
- if _VERSION ~= "Lua 5.1" then
+local function check_dependencies()
+ if _VERSION < "Lua 5.1" then
print "***********************************"
print("Unsupported Lua version: ".._VERSION);
- print("Only Lua 5.1 is supported.");
+ print("At least Lua 5.1 is required.");
print "***********************************"
return false;
end
local fatal;
-
+
local lxp = softreq "lxp"
-
+
if not lxp then
missingdep("luaexpat", {
["Debian/Ubuntu"] = "sudo apt-get install liblua5.1-expat0";
@@ -69,9 +67,9 @@ function check_dependencies()
});
fatal = true;
end
-
+
local socket = softreq "socket"
-
+
if not socket then
missingdep("luasocket", {
["Debian/Ubuntu"] = "sudo apt-get install liblua5.1-socket2";
@@ -80,7 +78,7 @@ function check_dependencies()
});
fatal = true;
end
-
+
local lfs, err = softreq "lfs"
if not lfs then
missingdep("luafilesystem", {
@@ -90,9 +88,9 @@ function check_dependencies()
});
fatal = true;
end
-
+
local ssl = softreq "ssl"
-
+
if not ssl then
missingdep("LuaSec", {
["Debian/Ubuntu"] = "http://prosody.im/download/start#debian_and_ubuntu";
@@ -100,10 +98,10 @@ function check_dependencies()
["Source"] = "http://www.inf.puc-rio.br/~brunoos/luasec/";
}, "SSL/TLS support will not be available");
end
-
+
local encodings, err = softreq "util.encodings"
if not encodings then
- if err:match("not found") then
+ if err:match("module '[^']*' not found") then
missingdep("util.encodings", { ["Windows"] = "Make sure you have encodings.dll from the Prosody distribution in util/";
["GNU/Linux"] = "Run './configure' and 'make' in the Prosody source directory to build util/encodings.so";
});
@@ -120,7 +118,7 @@ function check_dependencies()
local hashes, err = softreq "util.hashes"
if not hashes then
- if err:match("not found") then
+ if err:match("module '[^']*' not found") then
missingdep("util.hashes", { ["Windows"] = "Make sure you have hashes.dll from the Prosody distribution in util/";
["GNU/Linux"] = "Run './configure' and 'make' in the Prosody source directory to build util/hashes.so";
});
@@ -137,22 +135,27 @@ function check_dependencies()
return not fatal;
end
-function log_warnings()
+local function log_warnings()
+ if _VERSION > "Lua 5.1" then
+ prosody.log("warn", "Support for %s is experimental, please report any issues", _VERSION);
+ end
+ local ssl = softreq"ssl";
if ssl then
local major, minor, veryminor, patched = ssl._VERSION:match("(%d+)%.(%d+)%.?(%d*)(M?)");
if not major or ((tonumber(major) == 0 and (tonumber(minor) or 0) <= 3 and (tonumber(veryminor) or 0) <= 2) and patched ~= "M") then
- log("error", "This version of LuaSec contains a known bug that causes disconnects, see http://prosody.im/doc/depends");
+ prosody.log("error", "This version of LuaSec contains a known bug that causes disconnects, see http://prosody.im/doc/depends");
end
end
+ local lxp = softreq"lxp";
if lxp then
if not pcall(lxp.new, { StartDoctypeDecl = false }) then
- log("error", "The version of LuaExpat on your system leaves Prosody "
+ prosody.log("error", "The version of LuaExpat on your system leaves Prosody "
.."vulnerable to denial-of-service attacks. You should upgrade to "
.."LuaExpat 1.3.0 or higher as soon as possible. See "
.."http://prosody.im/doc/depends#luaexpat for more information.");
end
if not lxp.new({}).getcurrentbytecount then
- log("error", "The version of LuaExpat on your system does not support "
+ prosody.log("error", "The version of LuaExpat on your system does not support "
.."stanza size limits, which may leave servers on untrusted "
.."networks (e.g. the internet) vulnerable to denial-of-service "
.."attacks. You should upgrade to LuaExpat 1.3.0 or higher as "
@@ -162,4 +165,9 @@ function log_warnings()
end
end
-return _M;
+return {
+ softreq = softreq;
+ missingdep = missingdep;
+ check_dependencies = check_dependencies;
+ log_warnings = log_warnings;
+};
diff --git a/util/events.lua b/util/events.lua
index 412acccd..e2943e44 100644
--- a/util/events.lua
+++ b/util/events.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -9,15 +9,23 @@
local pairs = pairs;
local t_insert = table.insert;
+local t_remove = table.remove;
local t_sort = table.sort;
local setmetatable = setmetatable;
local next = next;
-module "events"
+local _ENV = nil;
-function new()
+local function new()
+ -- Map event name to ordered list of handlers (lazily built): handlers[event_name] = array_of_handler_functions
local handlers = {};
+ -- Array of wrapper functions that wrap all events (nil if empty)
+ local global_wrappers;
+ -- Per-event wrappers: wrappers[event_name] = wrapper_function
+ local wrappers = {};
+ -- Event map: event_map[handler_function] = priority_number
local event_map = {};
+ -- Called on-demand to build handlers entries
local function _rebuild_index(handlers, event)
local _handlers = event_map[event];
if not _handlers or next(_handlers) == nil then return; end
@@ -50,6 +58,9 @@ function new()
end
end
end;
+ local function get_handlers(event)
+ return handlers[event];
+ end;
local function add_handlers(handlers)
for event, handler in pairs(handlers) do
add_handler(event, handler);
@@ -60,24 +71,91 @@ function new()
remove_handler(event, handler);
end
end;
- local function fire_event(event, ...)
- local h = handlers[event];
+ local function _fire_event(event_name, event_data)
+ local h = handlers[event_name];
if h then
for i=1,#h do
- local ret = h[i](...);
+ local ret = h[i](event_data);
if ret ~= nil then return ret; end
end
end
end;
+ local function fire_event(event_name, event_data)
+ local w = wrappers[event_name] or global_wrappers;
+ if w then
+ local curr_wrapper = #w;
+ local function c(event_name, event_data)
+ curr_wrapper = curr_wrapper - 1;
+ if curr_wrapper == 0 then
+ if global_wrappers == nil or w == global_wrappers then
+ return _fire_event(event_name, event_data);
+ end
+ w, curr_wrapper = global_wrappers, #global_wrappers;
+ return w[curr_wrapper](c, event_name, event_data);
+ else
+ return w[curr_wrapper](c, event_name, event_data);
+ end
+ end
+ return w[curr_wrapper](c, event_name, event_data);
+ end
+ return _fire_event(event_name, event_data);
+ end
+ local function add_wrapper(event_name, wrapper)
+ local w;
+ if event_name == false then
+ w = global_wrappers;
+ if not w then
+ w = {};
+ global_wrappers = w;
+ end
+ else
+ w = wrappers[event_name];
+ if not w then
+ w = {};
+ wrappers[event_name] = w;
+ end
+ end
+ w[#w+1] = wrapper;
+ end
+ local function remove_wrapper(event_name, wrapper)
+ local w;
+ if event_name == false then
+ w = global_wrappers;
+ else
+ w = wrappers[event_name];
+ end
+ if not w then return; end
+ for i = #w, 1 do
+ if w[i] == wrapper then
+ t_remove(w, i);
+ end
+ end
+ if #w == 0 then
+ if event_name == false then
+ global_wrappers = nil;
+ else
+ wrappers[event_name] = nil;
+ end
+ end
+ end
return {
add_handler = add_handler;
remove_handler = remove_handler;
add_handlers = add_handlers;
remove_handlers = remove_handlers;
+ get_handlers = get_handlers;
+ wrappers = {
+ add_handler = add_wrapper;
+ remove_handler = remove_wrapper;
+ };
+ add_wrapper = add_wrapper;
+ remove_wrapper = remove_wrapper;
fire_event = fire_event;
_handlers = handlers;
_event_map = event_map;
};
end
-return _M;
+return {
+ new = new;
+};
diff --git a/util/filters.lua b/util/filters.lua
index 6290e53b..f405c0bd 100644
--- a/util/filters.lua
+++ b/util/filters.lua
@@ -1,22 +1,22 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
local t_insert, t_remove = table.insert, table.remove;
-module "filters"
+local _ENV = nil;
local new_filter_hooks = {};
-function initialize(session)
+local function initialize(session)
if not session.filters then
local filters = {};
session.filters = filters;
-
+
function session.filter(type, data)
local filter_list = filters[type];
if filter_list then
@@ -28,19 +28,19 @@ function initialize(session)
return data;
end
end
-
+
for i=1,#new_filter_hooks do
new_filter_hooks[i](session);
end
-
+
return session.filter;
end
-function add_filter(session, type, callback, priority)
+local function add_filter(session, type, callback, priority)
if not session.filters then
initialize(session);
end
-
+
local filter_list = session.filters[type];
if not filter_list then
filter_list = {};
@@ -48,19 +48,19 @@ function add_filter(session, type, callback, priority)
elseif filter_list[callback] then
return; -- Filter already added
end
-
+
priority = priority or 0;
-
+
local i = 0;
repeat
i = i + 1;
until not filter_list[i] or filter_list[filter_list[i]] < priority;
-
+
t_insert(filter_list, i, callback);
filter_list[callback] = priority;
end
-function remove_filter(session, type, callback)
+local function remove_filter(session, type, callback)
if not session.filters then return; end
local filter_list = session.filters[type];
if filter_list and filter_list[callback] then
@@ -74,11 +74,11 @@ function remove_filter(session, type, callback)
end
end
-function add_filter_hook(callback)
+local function add_filter_hook(callback)
t_insert(new_filter_hooks, callback);
end
-function remove_filter_hook(callback)
+local function remove_filter_hook(callback)
for i=1,#new_filter_hooks do
if new_filter_hooks[i] == callback then
t_remove(new_filter_hooks, i);
@@ -86,4 +86,10 @@ function remove_filter_hook(callback)
end
end
-return _M;
+return {
+ initialize = initialize;
+ add_filter = add_filter;
+ remove_filter = remove_filter;
+ add_filter_hook = add_filter_hook;
+ remove_filter_hook = remove_filter_hook;
+};
diff --git a/util/helpers.lua b/util/helpers.lua
index 08b86a7c..bf76d258 100644
--- a/util/helpers.lua
+++ b/util/helpers.lua
@@ -1,28 +1,18 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
local debug = require "util.debug";
-module("helpers", package.seeall);
-
-- Helper functions for debugging
local log = require "util.logger".init("util.debug");
-function log_host_events(host)
- return log_events(prosody.hosts[host].events, host);
-end
-
-function revert_log_host_events(host)
- return revert_log_events(prosody.hosts[host].events);
-end
-
-function log_events(events, name, logger)
+local function log_events(events, name, logger)
local f = events.fire_event;
if not f then
error("Object does not appear to be a util.events object");
@@ -37,11 +27,19 @@ function log_events(events, name, logger)
return events;
end
-function revert_log_events(events)
+local function revert_log_events(events)
events.fire_event, events[events.fire_event] = events[events.fire_event], nil; -- :))
end
-function show_events(events, specific_event)
+local function log_host_events(host)
+ return log_events(prosody.hosts[host].events, host);
+end
+
+local function revert_log_host_events(host)
+ return revert_log_events(prosody.hosts[host].events);
+end
+
+local function show_events(events, specific_event)
local event_handlers = events._handlers;
local events_array = {};
local event_handler_arrays = {};
@@ -70,7 +68,7 @@ function show_events(events, specific_event)
return table.concat(events_array, "\n");
end
-function get_upvalue(f, get_name)
+local function get_upvalue(f, get_name)
local i, name, value = 0;
repeat
i = i + 1;
@@ -79,4 +77,11 @@ function get_upvalue(f, get_name)
return value;
end
-return _M;
+return {
+ log_host_events = log_host_events;
+ revert_log_host_events = revert_log_host_events;
+ log_events = log_events;
+ revert_log_events = revert_log_events;
+ show_events = show_events;
+ get_upvalue = get_upvalue;
+};
diff --git a/util/hex.lua b/util/hex.lua
new file mode 100644
index 00000000..4cc28d33
--- /dev/null
+++ b/util/hex.lua
@@ -0,0 +1,26 @@
+local s_char = string.char;
+local s_format = string.format;
+local s_gsub = string.gsub;
+local s_lower = string.lower;
+
+local char_to_hex = {};
+local hex_to_char = {};
+
+do
+ local char, hex;
+ for i = 0,255 do
+ char, hex = s_char(i), s_format("%02x", i);
+ char_to_hex[char] = hex;
+ hex_to_char[hex] = char;
+ end
+end
+
+local function to(s)
+ return (s_gsub(s, ".", char_to_hex));
+end
+
+local function from(s)
+ return (s_gsub(s_lower(s), "%X*(%x%x)%X*", hex_to_char));
+end
+
+return { to = to, from = from }
diff --git a/util/hmac.lua b/util/hmac.lua
index 51211c7a..2c4cc6ef 100644
--- a/util/hmac.lua
+++ b/util/hmac.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
diff --git a/util/import.lua b/util/import.lua
index 81401e8b..c2b9dce1 100644
--- a/util/import.lua
+++ b/util/import.lua
@@ -1,13 +1,14 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
+local unpack = table.unpack or unpack; --luacheck: ignore 113
local t_insert = table.insert;
function import(module, ...)
local m = package.loaded[module] or require(module);
diff --git a/util/interpolation.lua b/util/interpolation.lua
new file mode 100644
index 00000000..315cc203
--- /dev/null
+++ b/util/interpolation.lua
@@ -0,0 +1,85 @@
+-- Simple template language
+--
+-- The new() function takes a pattern and an escape function and returns
+-- a render() function. Both are required.
+--
+-- The function render() takes a string template and a table of values.
+-- Sequences like {name} in the template string are substituted
+-- with values from the table, optionally depending on a modifier
+-- symbol.
+--
+-- Variants are:
+-- {name} is substituted for values["name"] and is escaped using the
+-- second argument to new_render(). To disable the escaping, use {name!}.
+-- {name.item} can be used to access table items.
+-- To renter lists of items: {name# item number {idx} is {item} }
+-- Or key-value pairs: {name% t[ {idx} ] = {item} }
+-- To show a defaults for missing values {name? sub-template } can be used,
+-- which renders a sub-template if values["name"] is false-ish.
+-- {name& sub-template } does the opposite, the sub-template is rendered
+-- if the selected value is anything but false or nil.
+
+local type, tostring = type, tostring;
+local pairs, ipairs = pairs, ipairs;
+local s_sub, s_gsub, s_match = string.sub, string.gsub, string.match;
+local t_concat = table.concat;
+
+local function new_render(pat, escape, funcs)
+ -- assert(type(pat) == "string", "bad argument #1 to 'new_render' (string expected)");
+ -- assert(type(escape) == "function", "bad argument #2 to 'new_render' (function expected)");
+ local function render(template, values)
+ -- assert(type(template) == "string", "bad argument #1 to 'render' (string expected)");
+ -- assert(type(values) == "table", "bad argument #2 to 'render' (table expected)");
+ return (s_gsub(template, pat, function (block)
+ block = s_sub(block, 2, -2);
+ local name, opt, e = s_match(block, "^([%a_][%w_.]*)(%p?)()");
+ if not name then return end
+ local value = values[name];
+ if not value and name:find(".", 2, true) then
+ value = values;
+ for word in name:gmatch"[^.]+" do
+ value = value[word];
+ if not value then break; end
+ end
+ end
+ if funcs then
+ while value ~= nil and opt == '|' do
+ local f;
+ f, opt, e = s_match(block, "^([%a_][%w_.]*)(%p?)()", e);
+ f = funcs[f];
+ if f then value = f(value); end
+ end
+ end
+ if opt == '#' or opt == '%' then
+ if type(value) ~= "table" then return ""; end
+ local iter = opt == '#' and ipairs or pairs;
+ local out, i, subtpl = {}, 1, s_sub(block, e);
+ local subvalues = setmetatable({}, { __index = values });
+ for idx, item in iter(value) do
+ subvalues.idx = idx;
+ subvalues.item = item;
+ out[i], i = render(subtpl, subvalues), i+1;
+ end
+ return t_concat(out);
+ elseif opt == '&' then
+ if not value then return ""; end
+ return render(s_sub(block, e), values);
+ elseif opt == '?' and not value then
+ return render(s_sub(block, e), values);
+ elseif value ~= nil then
+ if type(value) ~= "string" then
+ value = tostring(value);
+ end
+ if opt ~= '!' then
+ return escape(value);
+ end
+ return value;
+ end
+ end));
+ end
+ return render;
+end
+
+return {
+ new = new_render;
+};
diff --git a/util/ip.lua b/util/ip.lua
index acfd7f24..81a98ef7 100644
--- a/util/ip.lua
+++ b/util/ip.lua
@@ -51,15 +51,15 @@ local function toBits(ip)
if not ip:match(":$") then fields[#fields] = nil; end
for i, field in ipairs(fields) do
if field:len() == 0 and i ~= 1 and i ~= #fields then
- for i = 1, 16 * (9 - #fields) do
+ for _ = 1, 16 * (9 - #fields) do
result = result .. "0";
end
else
- for i = 1, 4 - field:len() do
+ for _ = 1, 4 - field:len() do
result = result .. "0000";
end
- for i = 1, field:len() do
- result = result .. hex2bits[field:sub(i,i)];
+ for j = 1, field:len() do
+ result = result .. hex2bits[field:sub(j, j)];
end
end
end
@@ -96,7 +96,7 @@ local function v6scope(ip)
if ip:match("^[0:]*1$") then
return 0x2;
-- Link-local unicast:
- elseif ip:match("^[Ff][Ee][89ABab]") then
+ elseif ip:match("^[Ff][Ee][89ABab]") then
return 0x2;
-- Site-local unicast:
elseif ip:match("^[Ff][Ee][CcDdEeFf]") then
@@ -206,5 +206,40 @@ function ip_methods:scope()
return value;
end
+function ip_methods:private()
+ local private = self.scope ~= 0xE;
+ if not private and self.proto == "IPv4" then
+ local ip = self.addr;
+ local fields = {};
+ ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end);
+ if fields[1] == 127 or fields[1] == 10 or (fields[1] == 192 and fields[2] == 168)
+ or (fields[1] == 172 and (fields[2] >= 16 or fields[2] <= 32)) then
+ private = true;
+ end
+ end
+ self.private = private;
+ return private;
+end
+
+local function parse_cidr(cidr)
+ local bits;
+ local ip_len = cidr:find("/", 1, true);
+ if ip_len then
+ bits = tonumber(cidr:sub(ip_len+1, -1));
+ cidr = cidr:sub(1, ip_len-1);
+ end
+ return new_ip(cidr), bits;
+end
+
+local function match(ipA, ipB, bits)
+ local common_bits = commonPrefixLength(ipA, ipB);
+ if bits and ipB.proto == "IPv4" then
+ common_bits = common_bits - 96; -- v6 mapped addresses always share these bits
+ end
+ return common_bits >= (bits or 128);
+end
+
return {new_ip = new_ip,
- commonPrefixLength = commonPrefixLength};
+ commonPrefixLength = commonPrefixLength,
+ parse_cidr = parse_cidr,
+ match=match};
diff --git a/util/iterators.lua b/util/iterators.lua
index 1f6aacb8..bd150ff2 100644
--- a/util/iterators.lua
+++ b/util/iterators.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -10,6 +10,11 @@
local it = {};
+local t_insert = table.insert;
+local select, next = select, next;
+local unpack = table.unpack or unpack; --luacheck: ignore 113
+local pack = table.pack or function (...) return { n = select("#", ...), ... }; end
+
-- Reverse an iterator
function it.reverse(f, s, var)
local results = {};
@@ -18,18 +23,18 @@ function it.reverse(f, s, var)
while true do
local ret = { f(s, var) };
var = ret[1];
- if var == nil then break; end
- table.insert(results, 1, ret);
+ if var == nil then break; end
+ t_insert(results, 1, ret);
end
-
+
-- Then return our reverse one
local i,max = 0, #results;
- return function (results)
- if i<max then
- i = i + 1;
- return unpack(results[i]);
- end
- end, results;
+ return function (_results)
+ if i<max then
+ i = i + 1;
+ return unpack(_results[i]);
+ end
+ end, results;
end
-- Iterate only over keys in a table
@@ -43,24 +48,33 @@ end
-- Iterate only over values in a table
function it.values(t)
local key, val;
- return function (t)
- key, val = next(t, key);
+ return function (_t)
+ key, val = next(_t, key);
return val;
end, t;
end
+-- Iterate over the n:th return value
+function it.select(n, f, s, var)
+ return function (_s)
+ local ret = pack(f(_s, var));
+ var = ret[1];
+ return ret[n];
+ end, s, var;
+end
+
-- Given an iterator, iterate only over unique items
function it.unique(f, s, var)
local set = {};
-
+
return function ()
while true do
- local ret = { f(s, var) };
+ local ret = pack(f(s, var));
var = ret[1];
- if var == nil then break; end
- if not set[var] then
+ if var == nil then break; end
+ if not set[var] then
set[var] = true;
- return var;
+ return unpack(ret, 1, ret.n);
end
end
end;
@@ -69,32 +83,31 @@ end
--[[ Return the number of items an iterator returns ]]--
function it.count(f, s, var)
local x = 0;
-
+
while true do
- local ret = { f(s, var) };
- var = ret[1];
- if var == nil then break; end
+ var = f(s, var);
+ if var == nil then break; end
x = x + 1;
end
-
+
return x;
end
-- Return the first n items an iterator returns
function it.head(n, f, s, var)
local c = 0;
- return function (s, var)
+ return function (_s, _var)
if c >= n then
return nil;
end
c = c + 1;
- return f(s, var);
- end, s;
+ return f(_s, _var);
+ end, s, var;
end
-- Skip the first n items an iterator returns
function it.skip(n, f, s, var)
- for i=1,n do
+ for _ = 1, n do
var = f(s, var);
end
return f, s, var;
@@ -104,9 +117,9 @@ end
function it.tail(n, f, s, var)
local results, count = {}, 0;
while true do
- local ret = { f(s, var) };
+ local ret = pack(f(s, var));
var = ret[1];
- if var == nil then break; end
+ if var == nil then break; end
results[(count%n)+1] = ret;
count = count + 1;
end
@@ -117,9 +130,24 @@ function it.tail(n, f, s, var)
return function ()
pos = pos + 1;
if pos > n then return nil; end
- return unpack(results[((count-1+pos)%n)+1]);
+ local ret = results[((count-1+pos)%n)+1];
+ return unpack(ret, 1, ret.n);
+ end
+ --return reverse(head(n, reverse(f, s, var))); -- !
+end
+
+function it.filter(filter, f, s, var)
+ if type(filter) ~= "function" then
+ local filter_value = filter;
+ function filter(x) return x ~= filter_value; end
end
- --return reverse(head(n, reverse(f, s, var)));
+ return function (_s, _var)
+ local ret;
+ repeat ret = pack(f(_s, _var));
+ _var = ret[1];
+ until _var == nil or filter(unpack(ret, 1, ret.n));
+ return unpack(ret, 1, ret.n);
+ end, s, var;
end
local function _ripairs_iter(t, key) if key > 1 then return key-1, t[key-1]; end end
@@ -135,11 +163,11 @@ end
-- Convert the values returned by an iterator to an array
function it.to_array(f, s, var)
- local t, var = {};
+ local t = {};
while true do
var = f(s, var);
- if var == nil then break; end
- table.insert(t, var);
+ if var == nil then break; end
+ t_insert(t, var);
end
return t;
end
@@ -150,7 +178,7 @@ function it.to_table(f, s, var)
local t, var2 = {};
while true do
var, var2 = f(s, var);
- if var == nil then break; end
+ if var == nil then break; end
t[var] = var2;
end
return t;
diff --git a/util/jid.lua b/util/jid.lua
index 8e0a784c..522fb126 100644
--- a/util/jid.lua
+++ b/util/jid.lua
@@ -1,13 +1,14 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
+local select = select;
local match, sub = string.match, string.sub;
local nodeprep = require "util.encodings".stringprep.nodeprep;
local nameprep = require "util.encodings".stringprep.nameprep;
@@ -23,9 +24,9 @@ local escapes = {
local unescapes = {};
for k,v in pairs(escapes) do unescapes[v] = k; end
-module "jid"
+local _ENV = nil;
-local function _split(jid)
+local function split(jid)
if not jid then return; end
local node, nodepos = match(jid, "^([^@/]+)@()");
local host, hostpos = match(jid, "^([^@/]+)()", nodepos)
@@ -34,18 +35,17 @@ local function _split(jid)
if (not host) or ((not resource) and #jid >= hostpos) then return nil, nil, nil; end
return node, host, resource;
end
-split = _split;
-function bare(jid)
- local node, host = _split(jid);
+local function bare(jid)
+ local node, host = split(jid);
if node and host then
return node.."@"..host;
end
return host;
end
-local function _prepped_split(jid)
- local node, host, resource = _split(jid);
+local function prepped_split(jid)
+ local node, host, resource = split(jid);
if host then
if sub(host, -1, -1) == "." then -- Strip empty root label
host = sub(host, 1, -2);
@@ -63,39 +63,29 @@ local function _prepped_split(jid)
return node, host, resource;
end
end
-prepped_split = _prepped_split;
-
-function prep(jid)
- local node, host, resource = _prepped_split(jid);
- if host then
- if node then
- host = node .. "@" .. host;
- end
- if resource then
- host = host .. "/" .. resource;
- end
- end
- return host;
-end
-function join(node, host, resource)
- if node and host and resource then
+local function join(node, host, resource)
+ if not host then return end
+ if node and resource then
return node.."@"..host.."/"..resource;
- elseif node and host then
+ elseif node then
return node.."@"..host;
- elseif host and resource then
+ elseif resource then
return host.."/"..resource;
- elseif host then
- return host;
end
- return nil; -- Invalid JID
+ return host;
+end
+
+local function prep(jid)
+ local node, host, resource = prepped_split(jid);
+ return join(node, host, resource);
end
-function compare(jid, acl)
+local function compare(jid, acl)
-- compare jid to single acl rule
-- TODO compare to table of rules?
- local jid_node, jid_host, jid_resource = _split(jid);
- local acl_node, acl_host, acl_resource = _split(acl);
+ local jid_node, jid_host, jid_resource = split(jid);
+ local acl_node, acl_host, acl_resource = split(acl);
if ((acl_node ~= nil and acl_node == jid_node) or acl_node == nil) and
((acl_host ~= nil and acl_host == jid_host) or acl_host == nil) and
((acl_resource ~= nil and acl_resource == jid_resource) or acl_resource == nil) then
@@ -104,7 +94,31 @@ function compare(jid, acl)
return false
end
-function escape(s) return s and (s:gsub(".", escapes)); end
-function unescape(s) return s and (s:gsub("\\%x%x", unescapes)); end
+local function node(jid)
+ return (select(1, split(jid)));
+end
+
+local function host(jid)
+ return (select(2, split(jid)));
+end
+
+local function resource(jid)
+ return (select(3, split(jid)));
+end
-return _M;
+local function escape(s) return s and (s:gsub(".", escapes)); end
+local function unescape(s) return s and (s:gsub("\\%x%x", unescapes)); end
+
+return {
+ split = split;
+ bare = bare;
+ prepped_split = prepped_split;
+ join = join;
+ prep = prep;
+ compare = compare;
+ node = node;
+ host = host;
+ resource = resource;
+ escape = escape;
+ unescape = unescape;
+};
diff --git a/util/json.lua b/util/json.lua
index 82ebcc43..cba54e8e 100644
--- a/util/json.lua
+++ b/util/json.lua
@@ -12,21 +12,17 @@ local s_char = string.char;
local tostring, tonumber = tostring, tonumber;
local pairs, ipairs = pairs, ipairs;
local next = next;
-local error = error;
-local newproxy, getmetatable, setmetatable = newproxy, getmetatable, setmetatable;
+local getmetatable, setmetatable = getmetatable, setmetatable;
local print = print;
local has_array, array = pcall(require, "util.array");
local array_mt = has_array and getmetatable(array()) or {};
--module("json")
-local json = {};
+local module = {};
-local null = newproxy and newproxy(true) or {};
-if getmetatable and getmetatable(null) then
- getmetatable(null).__tostring = function() return "null"; end;
-end
-json.null = null;
+local null = setmetatable({}, { __tostring = function() return "null"; end; });
+module.null = null;
local escapes = {
["\""] = "\\\"", ["\\"] = "\\\\", ["\b"] = "\\b",
@@ -73,7 +69,7 @@ end
function arraysave(o, buffer)
t_insert(buffer, "[");
if next(o) then
- for i,v in ipairs(o) do
+ for _, v in ipairs(o) do
simplesave(v, buffer);
t_insert(buffer, ",");
end
@@ -148,7 +144,9 @@ end
function simplesave(o, buffer)
local t = type(o);
- if t == "number" then
+ if o == null then
+ t_insert(buffer, "null");
+ elseif t == "number" then
t_insert(buffer, tostring(o));
elseif t == "string" then
stringsave(o, buffer);
@@ -166,17 +164,17 @@ function simplesave(o, buffer)
end
end
-function json.encode(obj)
+function module.encode(obj)
local t = {};
simplesave(obj, t);
return t_concat(t);
end
-function json.encode_ordered(obj)
+function module.encode_ordered(obj)
local t = { ordered = true };
simplesave(obj, t);
return t_concat(t);
end
-function json.encode_array(obj)
+function module.encode_array(obj)
local t = {};
arraysave(obj, t);
return t_concat(t);
@@ -192,7 +190,7 @@ local function _fixobject(obj)
local __array = obj.__array;
if __array then
obj.__array = nil;
- for i,v in ipairs(__array) do
+ for _, v in ipairs(__array) do
t_insert(obj, v);
end
end
@@ -200,7 +198,7 @@ local function _fixobject(obj)
if __hash then
obj.__hash = nil;
local k;
- for i,v in ipairs(__hash) do
+ for _, v in ipairs(__hash) do
if k ~= nil then
obj[k] = v; k = nil;
else
@@ -345,12 +343,12 @@ local first_escape = {
["\\u" ] = "\\u";
};
-function json.decode(json)
+function module.decode(json)
json = json:gsub("\\.", first_escape) -- get rid of all escapes except \uXXXX, making string parsing much simpler
--:gsub("[\r\n]", "\t"); -- \r\n\t are equivalent, we care about none of them, and none of them can be in strings
-
+
-- TODO do encoding verification
-
+
local val, index = _readvalue(json, 1);
if val == nil then return val, index; end
if json:find("[^ \t\r\n]", index) then return nil, "garbage at eof"; end
@@ -358,10 +356,10 @@ function json.decode(json)
return val;
end
-function json.test(object)
- local encoded = json.encode(object);
- local decoded = json.decode(encoded);
- local recoded = json.encode(decoded);
+function module.test(object)
+ local encoded = module.encode(object);
+ local decoded = module.decode(encoded);
+ local recoded = module.encode(decoded);
if encoded ~= recoded then
print("FAILED");
print("encoded:", encoded);
@@ -372,4 +370,4 @@ function json.test(object)
return encoded == recoded;
end
-return json;
+return module;
diff --git a/util/logger.lua b/util/logger.lua
index 26206d4d..e72b29bc 100644
--- a/util/logger.lua
+++ b/util/logger.lua
@@ -1,23 +1,21 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
+-- luacheck: ignore 213/level
-local pcall = pcall;
-
-local find = string.find;
-local ipairs, pairs, setmetatable = ipairs, pairs, setmetatable;
+local pairs = pairs;
-module "logger"
+local _ENV = nil;
local level_sinks = {};
local make_logger;
-function init(name)
+local function init(name)
local log_debug = make_logger(name, "debug");
local log_info = make_logger(name, "info");
local log_warn = make_logger(name, "warn");
@@ -52,7 +50,7 @@ function make_logger(source_name, level)
return logger;
end
-function reset()
+local function reset()
for level, handler_list in pairs(level_sinks) do
-- Clear all handlers for this level
for i = 1, #handler_list do
@@ -61,7 +59,7 @@ function reset()
end
end
-function add_level_sink(level, sink_function)
+local function add_level_sink(level, sink_function)
if not level_sinks[level] then
level_sinks[level] = { sink_function };
else
@@ -69,6 +67,10 @@ function add_level_sink(level, sink_function)
end
end
-_M.new = make_logger;
-
-return _M;
+return {
+ init = init;
+ make_logger = make_logger;
+ reset = reset;
+ add_level_sink = add_level_sink;
+ new = make_logger;
+};
diff --git a/util/mercurial.lua b/util/mercurial.lua
new file mode 100644
index 00000000..3f75c4c1
--- /dev/null
+++ b/util/mercurial.lua
@@ -0,0 +1,34 @@
+
+local lfs = require"lfs";
+
+local hg = { };
+
+function hg.check_id(path)
+ if lfs.attributes(path, 'mode') ~= "directory" then
+ return nil, "not a directory";
+ end
+ local hg_dirstate = io.open(path.."/.hg/dirstate");
+ local hgid, hgrepo
+ if hg_dirstate then
+ hgid = ("%02x%02x%02x%02x%02x%02x"):format(hg_dirstate:read(6):byte(1, 6));
+ hg_dirstate:close();
+ local hg_changelog = io.open(path.."/.hg/store/00changelog.i");
+ if hg_changelog then
+ hg_changelog:seek("set", 0x20);
+ hgrepo = ("%02x%02x%02x%02x%02x%02x"):format(hg_changelog:read(6):byte(1, 6));
+ hg_changelog:close();
+ end
+ else
+ local hg_archival,e = io.open(path.."/.hg_archival.txt");
+ if hg_archival then
+ local repo = hg_archival:read("*l");
+ local node = hg_archival:read("*l");
+ hg_archival:close()
+ hgid = node and node:match("^node: (%x%x%x%x%x%x%x%x%x%x%x%x)")
+ hgrepo = repo and repo:match("^repo: (%x%x%x%x%x%x%x%x%x%x%x%x)")
+ end
+ end
+ return hgid, hgrepo;
+end
+
+return hg;
diff --git a/util/multitable.lua b/util/multitable.lua
index dbf34d28..e4321d3d 100644
--- a/util/multitable.lua
+++ b/util/multitable.lua
@@ -1,16 +1,17 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
local select = select;
local t_insert = table.insert;
-local unpack, pairs, next, type = unpack, pairs, next, type;
+local pairs, next, type = pairs, next, type;
+local unpack = table.unpack or unpack; --luacheck: ignore 113
-module "multitable"
+local _ENV = nil;
local function get(self, ...)
local t = self.data;
@@ -126,7 +127,7 @@ local function search_add(self, results, ...)
return results;
end
-function iter(self, ...)
+local function iter(self, ...)
local query = { ... };
local maxdepth = select("#", ...);
local stack = { self.data };
@@ -161,7 +162,7 @@ function iter(self, ...)
return it, self;
end
-function new()
+local function new()
return {
data = {};
get = get;
@@ -174,4 +175,7 @@ function new()
};
end
-return _M;
+return {
+ iter = iter;
+ new = new;
+};
diff --git a/util/openssl.lua b/util/openssl.lua
index 39fe99d6..703c6d15 100644
--- a/util/openssl.lua
+++ b/util/openssl.lua
@@ -12,7 +12,7 @@ local config = {};
_M.config = config;
local ssl_config = {};
-local ssl_config_mt = {__index=ssl_config};
+local ssl_config_mt = { __index = ssl_config };
function config.new()
return setmetatable({
@@ -61,17 +61,16 @@ local DN_order = {
_M._DN_order = DN_order;
function ssl_config:serialize()
local s = "";
- for k, t in pairs(self) do
- s = s .. ("[%s]\n"):format(k);
- if k == "subject_alternative_name" then
+ for section, t in pairs(self) do
+ s = s .. ("[%s]\n"):format(section);
+ if section == "subject_alternative_name" then
for san, n in pairs(t) do
- for i = 1,#n do
+ for i = 1, #n do
s = s .. s_format("%s.%d = %s\n", san, i -1, n[i]);
end
end
- elseif k == "distinguished_name" then
- for i=1,#DN_order do
- local k = DN_order[i]
+ elseif section == "distinguished_name" then
+ for _, k in ipairs(t[1] and t or DN_order) do
local v = t[k];
if v then
s = s .. ("%s = %s\n"):format(k, v);
@@ -107,7 +106,7 @@ end
function ssl_config:add_sRVName(host, service)
t_insert(self.subject_alternative_name.otherName,
- s_format("%s;%s", oid_dnssrv, ia5string("_" .. service .."." .. idna_to_ascii(host))));
+ s_format("%s;%s", oid_dnssrv, ia5string("_" .. service .. "." .. idna_to_ascii(host))));
end
function ssl_config:add_xmppAddr(host)
@@ -118,10 +117,10 @@ end
function ssl_config:from_prosody(hosts, config, certhosts)
-- TODO Decide if this should go elsewhere
local found_matching_hosts = false;
- for i = 1,#certhosts do
+ for i = 1, #certhosts do
local certhost = certhosts[i];
for name in pairs(hosts) do
- if name == certhost or name:sub(-1-#certhost) == "."..certhost then
+ if name == certhost or name:sub(-1-#certhost) == "." .. certhost then
found_matching_hosts = true;
self:add_dNSName(name);
--print(name .. "#component_module: " .. (config.get(name, "component_module") or "nil"));
@@ -144,30 +143,31 @@ end
do -- Lua to shell calls.
local function shell_escape(s)
- return s:gsub("'",[['\'']]);
+ return "'" .. tostring(s):gsub("'",[['\'']]) .. "'";
end
- local function serialize(f,o)
- local r = {"openssl", f};
- for k,v in pairs(o) do
+ local function serialize(command, args)
+ local commandline = { "openssl", command };
+ for k, v in pairs(args) do
if type(k) == "string" then
- t_insert(r, ("-%s"):format(k));
+ t_insert(commandline, ("-%s"):format(k));
if v ~= true then
- t_insert(r, ("'%s'"):format(shell_escape(tostring(v))));
+ t_insert(commandline, shell_escape(v));
end
end
end
- for _,v in ipairs(o) do
- t_insert(r, ("'%s'"):format(shell_escape(tostring(v))));
+ for _, v in ipairs(args) do
+ t_insert(commandline, shell_escape(v));
end
- return t_concat(r, " ");
+ return t_concat(commandline, " ");
end
local os_execute = os.execute;
setmetatable(_M, {
- __index=function(_,f)
+ __index = function(_, command)
return function(opts)
- return 0 == os_execute(serialize(f, type(opts) == "table" and opts or {}));
+ local ret = os_execute(serialize(command, type(opts) == "table" and opts or {}));
+ return ret == true or ret == 0;
end;
end;
});
diff --git a/util/paths.lua b/util/paths.lua
new file mode 100644
index 00000000..89f4cad9
--- /dev/null
+++ b/util/paths.lua
@@ -0,0 +1,44 @@
+local t_concat = table.concat;
+
+local path_sep = package.config:sub(1,1);
+
+local path_util = {}
+
+-- Helper function to resolve relative paths (needed by config)
+function path_util.resolve_relative_path(parent_path, path)
+ if path then
+ -- Some normalization
+ parent_path = parent_path:gsub("%"..path_sep.."+$", "");
+ path = path:gsub("^%.%"..path_sep.."+", "");
+
+ local is_relative;
+ if path_sep == "/" and path:sub(1,1) ~= "/" then
+ is_relative = true;
+ elseif path_sep == "\\" and (path:sub(1,1) ~= "/" and (path:sub(2,3) ~= ":\\" and path:sub(2,3) ~= ":/")) then
+ is_relative = true;
+ end
+ if is_relative then
+ return parent_path..path_sep..path;
+ end
+ end
+ return path;
+end
+
+-- Helper function to convert a glob to a Lua pattern
+function path_util.glob_to_pattern(glob)
+ return "^"..glob:gsub("[%p*?]", function (c)
+ if c == "*" then
+ return ".*";
+ elseif c == "?" then
+ return ".";
+ else
+ return "%"..c;
+ end
+ end).."$";
+end
+
+function path_util.join(...)
+ return t_concat({...}, path_sep);
+end
+
+return path_util;
diff --git a/util/pluginloader.lua b/util/pluginloader.lua
index 112c0d52..004855f0 100644
--- a/util/pluginloader.lua
+++ b/util/pluginloader.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -17,9 +17,7 @@ end
local io_open = io.open;
local envload = require "util.envload".envload;
-module "pluginloader"
-
-function load_file(names)
+local function load_file(names)
local file, err, path;
for i=1,#plugin_dir do
for j=1,#names do
@@ -35,7 +33,7 @@ function load_file(names)
return file, err;
end
-function load_resource(plugin, resource)
+local function load_resource(plugin, resource)
resource = resource or "mod_"..plugin..".lua";
local names = {
@@ -48,7 +46,7 @@ function load_resource(plugin, resource)
return load_file(names);
end
-function load_code(plugin, resource, env)
+local function load_code(plugin, resource, env)
local content, err = load_resource(plugin, resource);
if not content then return content, err; end
local path = err;
@@ -57,4 +55,23 @@ function load_code(plugin, resource, env)
return f, path;
end
-return _M;
+local function load_code_ext(plugin, resource, extension, env)
+ local content, err = load_resource(plugin, resource.."."..extension);
+ if not content then
+ content, err = load_resource(resource, resource.."."..extension);
+ if not content then
+ return content, err;
+ end
+ end
+ local path = err;
+ local f, err = envload(content, "@"..path, env);
+ if not f then return f, err; end
+ return f, path;
+end
+
+return {
+ load_file = load_file;
+ load_resource = load_resource;
+ load_code = load_code;
+ load_code_ext = load_code_ext;
+};
diff --git a/util/presence.lua b/util/presence.lua
new file mode 100644
index 00000000..f6370354
--- /dev/null
+++ b/util/presence.lua
@@ -0,0 +1,38 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+local t_insert = table.insert;
+
+local function select_top_resources(user)
+ local priority = 0;
+ local recipients = {};
+ for _, session in pairs(user.sessions) do -- find resource with greatest priority
+ if session.presence then
+ -- TODO check active privacy list for session
+ local p = session.priority;
+ if p > priority then
+ priority = p;
+ recipients = {session};
+ elseif p == priority then
+ t_insert(recipients, session);
+ end
+ end
+ end
+ return recipients;
+end
+local function recalc_resource_map(user)
+ if user then
+ user.top_resources = select_top_resources(user);
+ if #user.top_resources == 0 then user.top_resources = nil; end
+ end
+end
+
+return {
+ select_top_resources = select_top_resources;
+ recalc_resource_map = recalc_resource_map;
+}
diff --git a/util/prosodyctl.lua b/util/prosodyctl.lua
index c6fe1986..7c9a3c19 100644
--- a/util/prosodyctl.lua
+++ b/util/prosodyctl.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -22,35 +22,29 @@ local nodeprep, nameprep = stringprep.nodeprep, stringprep.nameprep;
local io, os = io, os;
local print = print;
-local tostring, tonumber = tostring, tonumber;
+local tonumber = tonumber;
local CFG_SOURCEDIR = _G.CFG_SOURCEDIR;
local _G = _G;
local prosody = prosody;
-module "prosodyctl"
-
-- UI helpers
-function show_message(msg, ...)
- print(msg:format(...));
-end
-
-function show_warning(msg, ...)
+local function show_message(msg, ...)
print(msg:format(...));
end
-function show_usage(usage, desc)
+local function show_usage(usage, desc)
print("Usage: ".._G.arg[0].." "..usage);
if desc then
print(" "..desc);
end
end
-function getchar(n)
+local function getchar(n)
local stty_ret = os.execute("stty raw -echo 2>/dev/null");
local ok, char;
- if stty_ret == 0 then
+ if stty_ret == true or stty_ret == 0 then
ok, char = pcall(io.read, n or 1);
os.execute("stty sane");
else
@@ -64,14 +58,14 @@ function getchar(n)
end
end
-function getline()
+local function getline()
local ok, line = pcall(io.read, "*l");
if ok then
return line;
end
end
-function getpass()
+local function getpass()
local stty_ret = os.execute("stty -echo 2>/dev/null");
if stty_ret ~= 0 then
io.write("\027[08m"); -- ANSI 'hidden' text attribute
@@ -88,7 +82,7 @@ function getpass()
end
end
-function show_yesno(prompt)
+local function show_yesno(prompt)
io.write(prompt, " ");
local choice = getchar():lower();
io.write("\n");
@@ -99,7 +93,7 @@ function show_yesno(prompt)
return (choice == "y");
end
-function read_password()
+local function read_password()
local password;
while true do
io.write("Enter new password: ");
@@ -120,7 +114,7 @@ function read_password()
return password;
end
-function show_prompt(prompt)
+local function show_prompt(prompt)
io.write(prompt, " ");
local line = getline();
line = line and line:gsub("\n$","");
@@ -128,7 +122,7 @@ function show_prompt(prompt)
end
-- Server control
-function adduser(params)
+local function adduser(params)
local user, host, password = nodeprep(params.user), nameprep(params.host), params.password;
if not user then
return false, "invalid-username";
@@ -146,44 +140,44 @@ function adduser(params)
if not(provider) or provider.name == "null" then
usermanager.initialize_host(host);
end
-
+
local ok, errmsg = usermanager.create_user(user, password, host);
if not ok then
- return false, errmsg;
+ return false, errmsg or "creating-user-failed";
end
return true;
end
-function user_exists(params)
- local user, host, password = nodeprep(params.user), nameprep(params.host), params.password;
+local function user_exists(params)
+ local user, host = nodeprep(params.user), nameprep(params.host);
storagemanager.initialize_host(host);
local provider = prosody.hosts[host].users;
if not(provider) or provider.name == "null" then
usermanager.initialize_host(host);
end
-
+
return usermanager.user_exists(user, host);
end
-function passwd(params)
- if not _M.user_exists(params) then
+local function passwd(params)
+ if not user_exists(params) then
return false, "no-such-user";
end
-
- return _M.adduser(params);
+
+ return adduser(params);
end
-function deluser(params)
- if not _M.user_exists(params) then
+local function deluser(params)
+ if not user_exists(params) then
return false, "no-such-user";
end
local user, host = nodeprep(params.user), nameprep(params.host);
-
+
return usermanager.delete_user(user, host);
end
-function getpid()
+local function getpid()
local pidfile = config.get("*", "pidfile");
if not pidfile then
return false, "no-pidfile";
@@ -192,35 +186,35 @@ function getpid()
if type(pidfile) ~= "string" then
return false, "invalid-pidfile";
end
-
- local modules_enabled = set.new(config.get("*", "modules_enabled"));
- if not modules_enabled:contains("posix") then
+
+ local modules_enabled = set.new(config.get("*", "modules_disabled"));
+ if prosody.platform ~= "posix" or modules_enabled:contains("posix") then
return false, "no-posix";
end
-
+
local file, err = io.open(pidfile, "r+");
if not file then
return false, "pidfile-read-failed", err;
end
-
+
local locked, err = lfs.lock(file, "w");
if locked then
file:close();
return false, "pidfile-not-locked";
end
-
+
local pid = tonumber(file:read("*a"));
file:close();
-
+
if not pid then
return false, "invalid-pid";
end
-
+
return true, pid;
end
-function isrunning()
- local ok, pid, err = _M.getpid();
+local function isrunning()
+ local ok, pid, err = getpid();
if not ok then
if pid == "pidfile-read-failed" or pid == "pidfile-not-locked" then
-- Report as not running, since we can't open the pidfile
@@ -232,8 +226,8 @@ function isrunning()
return true, signal.kill(pid, 0) == 0;
end
-function start()
- local ok, ret = _M.isrunning();
+local function start()
+ local ok, ret = isrunning();
if not ok then
return ok, ret;
end
@@ -248,36 +242,55 @@ function start()
return true;
end
-function stop()
- local ok, ret = _M.isrunning();
+local function stop()
+ local ok, ret = isrunning();
if not ok then
return ok, ret;
end
if not ret then
return false, "not-running";
end
-
- local ok, pid = _M.getpid()
+
+ local ok, pid = getpid()
if not ok then return false, pid; end
-
+
signal.kill(pid, signal.SIGTERM);
return true;
end
-function reload()
- local ok, ret = _M.isrunning();
+local function reload()
+ local ok, ret = isrunning();
if not ok then
return ok, ret;
end
if not ret then
return false, "not-running";
end
-
- local ok, pid = _M.getpid()
+
+ local ok, pid = getpid()
if not ok then return false, pid; end
-
+
signal.kill(pid, signal.SIGHUP);
return true;
end
-return _M;
+return {
+ show_message = show_message;
+ show_warning = show_message;
+ show_usage = show_usage;
+ getchar = getchar;
+ getline = getline;
+ getpass = getpass;
+ show_yesno = show_yesno;
+ read_password = read_password;
+ show_prompt = show_prompt;
+ adduser = adduser;
+ user_exists = user_exists;
+ passwd = passwd;
+ deluser = deluser;
+ getpid = getpid;
+ isrunning = isrunning;
+ start = start;
+ stop = stop;
+ reload = reload;
+};
diff --git a/util/pubsub.lua b/util/pubsub.lua
index e1418c62..6d12690a 100644
--- a/util/pubsub.lua
+++ b/util/pubsub.lua
@@ -1,23 +1,27 @@
local events = require "util.events";
-
-module("pubsub", package.seeall);
+local t_remove = table.remove;
local service = {};
local service_mt = { __index = service };
-local default_config = {
+local default_config = { __index = {
broadcaster = function () end;
get_affiliation = function () end;
capabilities = {};
-};
+} };
+local default_node_config = { __index = {
+ ["pubsub#max_items"] = "20";
+} };
-function new(config)
+local function new(config)
config = config or {};
return setmetatable({
- config = setmetatable(config, { __index = default_config });
+ config = setmetatable(config, default_config);
+ node_defaults = setmetatable(config.node_defaults or {}, default_node_config);
affiliations = {};
subscriptions = {};
nodes = {};
+ data = {};
events = events.new();
}, service_mt);
end
@@ -29,13 +33,13 @@ end
function service:may(node, actor, action)
if actor == true then return true; end
-
+
local node_obj = self.nodes[node];
local node_aff = node_obj and node_obj.affiliations[actor];
local service_aff = self.affiliations[actor]
or self.config.get_affiliation(actor, node, action)
or "none";
-
+
-- Check if node allows/forbids it
local node_capabilities = node_obj and node_obj.capabilities;
if node_capabilities then
@@ -47,7 +51,7 @@ function service:may(node, actor, action)
end
end
end
-
+
-- Check service-wide capabilities instead
local service_capabilities = self.config.capabilities;
local caps = service_capabilities[node_aff or service_aff];
@@ -57,7 +61,7 @@ function service:may(node, actor, action)
return can;
end
end
-
+
return false;
end
@@ -202,7 +206,7 @@ function service:get_subscription(node, actor, jid)
return true, node_obj.subscribers[jid];
end
-function service:create(node, actor)
+function service:create(node, actor, options)
-- Access checking
if not self:may(node, actor, "create") then
return false, "forbidden";
@@ -211,17 +215,20 @@ function service:create(node, actor)
if self.nodes[node] then
return false, "conflict";
end
-
+
+ self.data[node] = {};
self.nodes[node] = {
name = node;
subscribers = {};
- config = {};
- data = {};
+ config = setmetatable(options or {}, {__index=self.node_defaults});
affiliations = {};
};
+ setmetatable(self.nodes[node], { __index = { data = self.data[node] } }); -- COMPAT
+ self.events.fire_event("node-created", { node = node, actor = actor });
local ok, err = self:set_affiliation(node, true, actor, "owner");
if not ok then
self.nodes[node] = nil;
+ self.data[node] = nil;
end
return ok, err;
end
@@ -237,10 +244,31 @@ function service:delete(node, actor)
return false, "item-not-found";
end
self.nodes[node] = nil;
+ self.data[node] = nil;
+ self.events.fire_event("node-deleted", { node = node, actor = actor });
self.config.broadcaster("delete", node, node_obj.subscribers);
return true;
end
+local function remove_item_by_id(data, id)
+ if not data[id] then return end
+ data[id] = nil;
+ for i, _id in ipairs(data) do
+ if id == _id then
+ t_remove(data, i);
+ return i;
+ end
+ end
+end
+
+local function trim_items(data, max)
+ max = tonumber(max);
+ if not max or #data <= max then return end
+ repeat
+ data[t_remove(data, 1)] = nil;
+ until #data <= max
+end
+
function service:publish(node, actor, id, item)
-- Access checking
if not self:may(node, actor, "publish") then
@@ -258,9 +286,13 @@ function service:publish(node, actor, id, item)
end
node_obj = self.nodes[node];
end
- node_obj.data[id] = item;
+ local node_data = self.data[node];
+ remove_item_by_id(node_data, id);
+ node_data[#node_data + 1] = id;
+ node_data[id] = item;
+ trim_items(node_data, node_obj.config["pubsub#max_items"]);
self.events.fire_event("item-published", { node = node, actor = actor, id = id, item = item });
- self.config.broadcaster("items", node, node_obj.subscribers, item);
+ self.config.broadcaster("items", node, node_obj.subscribers, item, actor);
return true;
end
@@ -271,10 +303,11 @@ function service:retract(node, actor, id, retract)
end
--
local node_obj = self.nodes[node];
- if (not node_obj) or (not node_obj.data[id]) then
+ if (not node_obj) or (not self.data[node][id]) then
return false, "item-not-found";
end
- node_obj.data[id] = nil;
+ self.events.fire_event("item-retracted", { node = node, actor = actor, id = id });
+ remove_item_by_id(self.data[node], id);
if retract then
self.config.broadcaster("items", node, node_obj.subscribers, retract);
end
@@ -291,7 +324,8 @@ function service:purge(node, actor, notify)
if not node_obj then
return false, "item-not-found";
end
- node_obj.data = {}; -- Purge
+ self.data[node] = {}; -- Purge
+ self.events.fire_event("node-purged", { node = node, actor = actor });
if notify then
self.config.broadcaster("purge", node, node_obj.subscribers);
end
@@ -309,9 +343,9 @@ function service:get_items(node, actor, id)
return false, "item-not-found";
end
if id then -- Restrict results to a single specific item
- return true, { [id] = node_obj.data[id] };
+ return true, { id, [id] = self.data[node][id] };
else
- return true, node_obj.data;
+ return true, self.data[node];
end
end
@@ -388,4 +422,24 @@ function service:set_node_capabilities(node, actor, capabilities)
return true;
end
-return _M;
+function service:set_node_config(node, actor, new_config)
+ if not self:may(node, actor, "configure") then
+ return false, "forbidden";
+ end
+
+ local node_obj = self.nodes[node];
+ if not node_obj then
+ return false, "item-not-found";
+ end
+
+ for k,v in pairs(new_config) do
+ node_obj.config[k] = v;
+ end
+ trim_items(self.data[node], node_obj.config["pubsub#max_items"]);
+
+ return true;
+end
+
+return {
+ new = new;
+};
diff --git a/util/queue.lua b/util/queue.lua
new file mode 100644
index 00000000..728e905f
--- /dev/null
+++ b/util/queue.lua
@@ -0,0 +1,73 @@
+-- Prosody IM
+-- Copyright (C) 2008-2015 Matthew Wild
+-- Copyright (C) 2008-2015 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+-- Small ringbuffer library (i.e. an efficient FIFO queue with a size limit)
+-- (because unbounded dynamically-growing queues are a bad thing...)
+
+local have_utable, utable = pcall(require, "util.table"); -- For pre-allocation of table
+
+local function new(size, allow_wrapping)
+ -- Head is next insert, tail is next read
+ local head, tail = 1, 1;
+ local items = 0; -- Number of stored items
+ local t = have_utable and utable.create(size, 0) or {}; -- Table to hold items
+ --luacheck: ignore 212/self
+ return {
+ _items = t;
+ size = size;
+ count = function (self) return items; end;
+ push = function (self, item)
+ if items >= size then
+ if allow_wrapping then
+ tail = (tail%size)+1; -- Advance to next oldest item
+ items = items - 1;
+ else
+ return nil, "queue full";
+ end
+ end
+ t[head] = item;
+ items = items + 1;
+ head = (head%size)+1;
+ return true;
+ end;
+ pop = function (self)
+ if items == 0 then
+ return nil;
+ end
+ local item;
+ item, t[tail] = t[tail], 0;
+ tail = (tail%size)+1;
+ items = items - 1;
+ return item;
+ end;
+ peek = function (self)
+ if items == 0 then
+ return nil;
+ end
+ return t[tail];
+ end;
+ items = function (self)
+ --luacheck: ignore 431/t
+ return function (t, pos)
+ if pos >= t:count() then
+ return nil;
+ end
+ local read_pos = tail + pos;
+ if read_pos > t.size then
+ read_pos = (read_pos%size);
+ end
+ return pos+1, t._items[read_pos];
+ end, self, 0;
+ end;
+ };
+end
+
+return {
+ new = new;
+};
+
diff --git a/util/random.lua b/util/random.lua
new file mode 100644
index 00000000..574e2e1c
--- /dev/null
+++ b/util/random.lua
@@ -0,0 +1,30 @@
+-- Prosody IM
+-- Copyright (C) 2008-2014 Matthew Wild
+-- Copyright (C) 2008-2014 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+local ok, crand = pcall(require, "util.crand");
+if ok then return crand; end
+
+local urandom, urandom_err = io.open("/dev/urandom", "r");
+
+local function seed()
+end
+
+local function bytes(n)
+ return urandom:read(n);
+end
+
+if not urandom then
+ function bytes()
+ error("Unable to obtain a secure random number generator, please see https://prosody.im/doc/random ("..urandom_err..")");
+ end
+end
+
+return {
+ seed = seed;
+ bytes = bytes;
+};
diff --git a/util/rfc6724.lua b/util/rfc6724.lua
index c8aec631..81f78d55 100644
--- a/util/rfc6724.lua
+++ b/util/rfc6724.lua
@@ -10,7 +10,6 @@
-- We can't hand this off to getaddrinfo, since it blocks
local ip_commonPrefixLength = require"util.ip".commonPrefixLength
-local new_ip = require"util.ip".new_ip;
local function commonPrefixLength(ipA, ipB)
local len = ip_commonPrefixLength(ipA, ipB);
diff --git a/util/sasl.lua b/util/sasl.lua
index afb3861b..5845f34a 100644
--- a/util/sasl.lua
+++ b/util/sasl.lua
@@ -19,7 +19,7 @@ local setmetatable = setmetatable;
local assert = assert;
local require = require;
-module "sasl"
+local _ENV = nil;
--[[
Authentication Backend Prototypes:
@@ -27,19 +27,38 @@ Authentication Backend Prototypes:
state = false : disabled
state = true : enabled
state = nil : non-existant
+
+Channel Binding:
+
+To enable support of channel binding in some mechanisms you need to provide appropriate callbacks in a table
+at profile.cb.
+
+Example:
+ profile.cb["tls-unique"] = function(self)
+ return self.user
+ end
+
]]
local method = {};
method.__index = method;
local mechanisms = {};
local backend_mechanism = {};
+local mechanism_channelbindings = {};
-- register a new SASL mechanims
-function registerMechanism(name, backends, f)
+local function registerMechanism(name, backends, f, cb_backends)
assert(type(name) == "string", "Parameter name MUST be a string.");
assert(type(backends) == "string" or type(backends) == "table", "Parameter backends MUST be either a string or a table.");
assert(type(f) == "function", "Parameter f MUST be a function.");
+ if cb_backends then assert(type(cb_backends) == "table"); end
mechanisms[name] = f
+ if cb_backends then
+ mechanism_channelbindings[name] = {};
+ for _, cb_name in ipairs(cb_backends) do
+ mechanism_channelbindings[name][cb_name] = true;
+ end
+ end
for _, backend_name in ipairs(backends) do
if backend_mechanism[backend_name] == nil then backend_mechanism[backend_name] = {}; end
t_insert(backend_mechanism[backend_name], name);
@@ -47,7 +66,7 @@ function registerMechanism(name, backends, f)
end
-- create a new SASL object which can be used to authenticate clients
-function new(realm, profile)
+local function new(realm, profile)
local mechanisms = profile.mechanisms;
if not mechanisms then
mechanisms = {};
@@ -63,6 +82,15 @@ function new(realm, profile)
return setmetatable({ profile = profile, realm = realm, mechs = mechanisms }, method);
end
+-- add a channel binding handler
+function method:add_cb_handler(name, f)
+ if type(self.profile.cb) ~= "table" then
+ self.profile.cb = {};
+ end
+ self.profile.cb[name] = f;
+ return self;
+end
+
-- get a fresh clone with the same realm and profile
function method:clean_clone()
return new(self.realm, self.profile)
@@ -70,7 +98,23 @@ end
-- get a list of possible SASL mechanims to use
function method:mechanisms()
- return self.mechs;
+ local current_mechs = {};
+ for mech, _ in pairs(self.mechs) do
+ if mechanism_channelbindings[mech] then
+ if self.profile.cb then
+ local ok = false;
+ for cb_name, _ in pairs(self.profile.cb) do
+ if mechanism_channelbindings[mech][cb_name] then
+ ok = true;
+ end
+ end
+ if ok == true then current_mechs[mech] = true; end
+ end
+ else
+ current_mechs[mech] = true;
+ end
+ end
+ return current_mechs;
end
-- select a mechanism to use
@@ -92,5 +136,9 @@ require "util.sasl.plain" .init(registerMechanism);
require "util.sasl.digest-md5".init(registerMechanism);
require "util.sasl.anonymous" .init(registerMechanism);
require "util.sasl.scram" .init(registerMechanism);
+require "util.sasl.external" .init(registerMechanism);
-return _M;
+return {
+ registerMechanism = registerMechanism;
+ new = new;
+};
diff --git a/util/sasl/anonymous.lua b/util/sasl/anonymous.lua
index ca5fe404..6201db32 100644
--- a/util/sasl/anonymous.lua
+++ b/util/sasl/anonymous.lua
@@ -11,12 +11,10 @@
--
-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-local s_match = string.match;
-local log = require "util.logger".init("sasl");
local generate_uuid = require "util.uuid".generate;
-module "sasl.anonymous"
+local _ENV = nil;
--=========================
--SASL ANONYMOUS according to RFC 4505
@@ -39,8 +37,10 @@ local function anonymous(self, message)
return "success"
end
-function init(registerMechanism)
+local function init(registerMechanism)
registerMechanism("ANONYMOUS", {"anonymous"}, anonymous);
end
-return _M;
+return {
+ init = init;
+}
diff --git a/util/sasl/digest-md5.lua b/util/sasl/digest-md5.lua
index 591d8537..695dd2a3 100644
--- a/util/sasl/digest-md5.lua
+++ b/util/sasl/digest-md5.lua
@@ -25,7 +25,7 @@ local log = require "util.logger".init("sasl");
local generate_uuid = require "util.uuid".generate;
local nodeprep = require "util.encodings".stringprep.nodeprep;
-module "sasl.digest-md5"
+local _ENV = nil;
--=========================
--SASL DIGEST-MD5 according to RFC 2831
@@ -241,8 +241,10 @@ local function digest(self, message)
end
end
-function init(registerMechanism)
+local function init(registerMechanism)
registerMechanism("DIGEST-MD5", {"plain"}, digest);
end
-return _M;
+return {
+ init = init;
+}
diff --git a/util/sasl/external.lua b/util/sasl/external.lua
new file mode 100644
index 00000000..5ba90190
--- /dev/null
+++ b/util/sasl/external.lua
@@ -0,0 +1,27 @@
+local saslprep = require "util.encodings".stringprep.saslprep;
+
+local _ENV = nil;
+
+local function external(self, message)
+ message = saslprep(message);
+ local state
+ self.username, state = self.profile.external(message);
+
+ if state == false then
+ return "failure", "account-disabled";
+ elseif state == nil then
+ return "failure", "not-authorized";
+ elseif state == "expired" then
+ return "false", "credentials-expired";
+ end
+
+ return "success";
+end
+
+local function init(registerMechanism)
+ registerMechanism("EXTERNAL", {"external"}, external);
+end
+
+return {
+ init = init;
+}
diff --git a/util/sasl/plain.lua b/util/sasl/plain.lua
index c9ec2911..26e65335 100644
--- a/util/sasl/plain.lua
+++ b/util/sasl/plain.lua
@@ -16,7 +16,7 @@ local saslprep = require "util.encodings".stringprep.saslprep;
local nodeprep = require "util.encodings".stringprep.nodeprep;
local log = require "util.logger".init("sasl");
-module "sasl.plain"
+local _ENV = nil;
-- ================================
-- SASL PLAIN according to RFC 4616
@@ -82,8 +82,10 @@ local function plain(self, message)
return "success";
end
-function init(registerMechanism)
+local function init(registerMechanism)
registerMechanism("PLAIN", {"plain", "plain_test"}, plain);
end
-return _M;
+return {
+ init = init;
+}
diff --git a/util/sasl/scram.lua b/util/sasl/scram.lua
index cf2f0ede..d2b2abde 100644
--- a/util/sasl/scram.lua
+++ b/util/sasl/scram.lua
@@ -13,7 +13,6 @@
local s_match = string.match;
local type = type
-local string = string
local base64 = require "util.encodings".base64;
local hmac_sha1 = require "util.hashes".hmac_sha1;
local sha1 = require "util.hashes".sha1;
@@ -26,7 +25,7 @@ local t_concat = table.concat;
local char = string.char;
local byte = string.byte;
-module "sasl.scram"
+local _ENV = nil;
--=========================
--SASL SCRAM-SHA-1 according to RFC 5802
@@ -39,18 +38,14 @@ scram_{MECH}:
function(username, realm)
return stored_key, server_key, iteration_count, salt, state;
end
+
+Supported Channel Binding Backends
+
+'tls-unique' according to RFC 5929
]]
local default_i = 4096
-local function bp( b )
- local result = ""
- for i=1, b:len() do
- result = result.."\\"..b:byte(i)
- end
- return result
-end
-
local xor_map = {0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;1;0;3;2;5;4;7;6;9;8;11;10;13;12;15;14;2;3;0;1;6;7;4;5;10;11;8;9;14;15;12;13;3;2;1;0;7;6;5;4;11;10;9;8;15;14;13;12;4;5;6;7;0;1;2;3;12;13;14;15;8;9;10;11;5;4;7;6;1;0;3;2;13;12;15;14;9;8;11;10;6;7;4;5;2;3;0;1;14;15;12;13;10;11;8;9;7;6;5;4;3;2;1;0;15;14;13;12;11;10;9;8;8;9;10;11;12;13;14;15;0;1;2;3;4;5;6;7;9;8;11;10;13;12;15;14;1;0;3;2;5;4;7;6;10;11;8;9;14;15;12;13;2;3;0;1;6;7;4;5;11;10;9;8;15;14;13;12;3;2;1;0;7;6;5;4;12;13;14;15;8;9;10;11;4;5;6;7;0;1;2;3;13;12;15;14;9;8;11;10;5;4;7;6;1;0;3;2;14;15;12;13;10;11;8;9;6;7;4;5;2;3;0;1;15;14;13;12;11;10;9;8;7;6;5;4;3;2;1;0;};
local result = {};
@@ -73,11 +68,11 @@ local function validate_username(username, _nodeprep)
return false
end
end
-
+
-- replace =2C with , and =3D with =
username = username:gsub("=2C", ",");
username = username:gsub("=3D", "=");
-
+
-- apply SASLprep
username = saslprep(username);
@@ -92,7 +87,7 @@ local function hashprep(hashname)
return hashname:lower():gsub("-", "_");
end
-function getAuthenticationDatabaseSHA1(password, salt, iteration_count)
+local function getAuthenticationDatabaseSHA1(password, salt, iteration_count)
if type(password) ~= "string" or type(salt) ~= "string" or type(iteration_count) ~= "number" then
return false, "inappropriate argument types"
end
@@ -106,96 +101,131 @@ function getAuthenticationDatabaseSHA1(password, salt, iteration_count)
end
local function scram_gen(hash_name, H_f, HMAC_f)
+ local profile_name = "scram_" .. hashprep(hash_name);
local function scram_hash(self, message)
- if not self.state then self["state"] = {} end
-
+ local support_channel_binding = false;
+ if self.profile.cb then support_channel_binding = true; end
+
if type(message) ~= "string" or #message == 0 then return "failure", "malformed-request" end
- if not self.state.name then
+ local state = self.state;
+ if not state then
-- we are processing client_first_message
local client_first_message = message;
-
+
-- TODO: fail if authzid is provided, since we don't support them yet
- self.state["client_first_message"] = client_first_message;
- self.state["gs2_cbind_flag"], self.state["authzid"], self.state["name"], self.state["clientnonce"]
- = client_first_message:match("^(%a),(.*),n=(.*),r=([^,]*).*");
+ local gs2_header, gs2_cbind_flag, gs2_cbind_name, authzid, client_first_message_bare, username, clientnonce
+ = s_match(client_first_message, "^(([pny])=?([^,]*),([^,]*),)(m?=?[^,]*,?n=([^,]*),r=([^,]*),?.*)$");
- -- we don't do any channel binding yet
- if self.state.gs2_cbind_flag ~= "n" and self.state.gs2_cbind_flag ~= "y" then
+ if not gs2_cbind_flag then
return "failure", "malformed-request";
end
- if not self.state.name or not self.state.clientnonce then
- return "failure", "malformed-request", "Channel binding isn't support at this time.";
+ if support_channel_binding and gs2_cbind_flag == "y" then
+ -- "y" -> client does support channel binding
+ -- but thinks the server does not.
+ return "failure", "malformed-request";
+ end
+
+ if gs2_cbind_flag == "n" then
+ -- "n" -> client doesn't support channel binding.
+ support_channel_binding = false;
end
-
- self.state.name = validate_username(self.state.name, self.profile.nodeprep);
- if not self.state.name then
+
+ if support_channel_binding and gs2_cbind_flag == "p" then
+ -- check whether we support the proposed channel binding type
+ if not self.profile.cb[gs2_cbind_name] then
+ return "failure", "malformed-request", "Proposed channel binding type isn't supported.";
+ end
+ else
+ -- no channel binding,
+ gs2_cbind_name = nil;
+ end
+
+ username = validate_username(username, self.profile.nodeprep);
+ if not username then
log("debug", "Username violates either SASLprep or contains forbidden character sequences.")
return "failure", "malformed-request", "Invalid username.";
end
-
- self.state["servernonce"] = generate_uuid();
-
+
-- retreive credentials
+ local stored_key, server_key, salt, iteration_count;
if self.profile.plain then
- local password, state = self.profile.plain(self, self.state.name, self.realm)
- if state == nil then return "failure", "not-authorized"
- elseif state == false then return "failure", "account-disabled" end
-
+ local password, status = self.profile.plain(self, username, self.realm)
+ if status == nil then return "failure", "not-authorized"
+ elseif status == false then return "failure", "account-disabled" end
+
password = saslprep(password);
if not password then
log("debug", "Password violates SASLprep.");
return "failure", "not-authorized", "Invalid password."
end
- self.state.salt = generate_uuid();
- self.state.iteration_count = default_i;
+ salt = generate_uuid();
+ iteration_count = default_i;
- local succ = false;
- succ, self.state.stored_key, self.state.server_key = getAuthenticationDatabaseSHA1(password, self.state.salt, default_i, self.state.iteration_count);
+ local succ;
+ succ, stored_key, server_key = getAuthenticationDatabaseSHA1(password, salt, iteration_count);
if not succ then
- log("error", "Generating authentication database failed. Reason: %s", self.state.stored_key);
+ log("error", "Generating authentication database failed. Reason: %s", stored_key);
return "failure", "temporary-auth-failure";
end
- elseif self.profile["scram_"..hashprep(hash_name)] then
- local stored_key, server_key, iteration_count, salt, state = self.profile["scram_"..hashprep(hash_name)](self, self.state.name, self.realm);
- if state == nil then return "failure", "not-authorized"
- elseif state == false then return "failure", "account-disabled" end
-
- self.state.stored_key = stored_key;
- self.state.server_key = server_key;
- self.state.iteration_count = iteration_count;
- self.state.salt = salt
+ elseif self.profile[profile_name] then
+ local status;
+ stored_key, server_key, iteration_count, salt, status = self.profile[profile_name](self, username, self.realm);
+ if status == nil then return "failure", "not-authorized"
+ elseif status == false then return "failure", "account-disabled" end
end
-
- local server_first_message = "r="..self.state.clientnonce..self.state.servernonce..",s="..base64.encode(self.state.salt)..",i="..self.state.iteration_count;
- self.state["server_first_message"] = server_first_message;
+
+ local nonce = clientnonce .. generate_uuid();
+ local server_first_message = "r="..nonce..",s="..base64.encode(salt)..",i="..iteration_count;
+ self.state = {
+ gs2_header = gs2_header;
+ gs2_cbind_name = gs2_cbind_name;
+ username = username;
+ nonce = nonce;
+
+ server_key = server_key;
+ stored_key = stored_key;
+ client_first_message_bare = client_first_message_bare;
+ server_first_message = server_first_message;
+ }
return "challenge", server_first_message
else
-- we are processing client_final_message
local client_final_message = message;
-
- self.state["channelbinding"], self.state["nonce"], self.state["proof"] = client_final_message:match("^c=(.*),r=(.*),.*p=(.*)");
-
- if not self.state.proof or not self.state.nonce or not self.state.channelbinding then
+
+ local client_final_message_without_proof, channelbinding, nonce, proof
+ = s_match(client_final_message, "(c=([^,]*),r=([^,]*),?.-),p=(.*)$");
+
+ if not proof or not nonce or not channelbinding then
return "failure", "malformed-request", "Missing an attribute(p, r or c) in SASL message.";
end
- if self.state.nonce ~= self.state.clientnonce..self.state.servernonce then
+ local client_gs2_header = base64.decode(channelbinding)
+ local our_client_gs2_header = state["gs2_header"]
+ if state.gs2_cbind_name then
+ -- we support channelbinding, so check if the value is valid
+ our_client_gs2_header = our_client_gs2_header .. self.profile.cb[state.gs2_cbind_name](self);
+ end
+ if client_gs2_header ~= our_client_gs2_header then
+ return "failure", "malformed-request", "Invalid channel binding value.";
+ end
+
+ if nonce ~= state.nonce then
return "failure", "malformed-request", "Wrong nonce in client-final-message.";
end
-
- local ServerKey = self.state.server_key;
- local StoredKey = self.state.stored_key;
-
- local AuthMessage = "n=" .. s_match(self.state.client_first_message,"n=(.+)") .. "," .. self.state.server_first_message .. "," .. s_match(client_final_message, "(.+),p=.+")
+
+ local ServerKey = state.server_key;
+ local StoredKey = state.stored_key;
+
+ local AuthMessage = state.client_first_message_bare .. "," .. state.server_first_message .. "," .. client_final_message_without_proof
local ClientSignature = HMAC_f(StoredKey, AuthMessage)
- local ClientKey = binaryXOR(ClientSignature, base64.decode(self.state.proof))
+ local ClientKey = binaryXOR(ClientSignature, base64.decode(proof))
local ServerSignature = HMAC_f(ServerKey, AuthMessage)
if StoredKey == H_f(ClientKey) then
local server_final_message = "v="..base64.encode(ServerSignature);
- self["username"] = self.state.name;
+ self["username"] = state.username;
return "success", server_final_message;
else
return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated.";
@@ -205,12 +235,18 @@ local function scram_gen(hash_name, H_f, HMAC_f)
return scram_hash;
end
-function init(registerMechanism)
+local function init(registerMechanism)
local function registerSCRAMMechanism(hash_name, hash, hmac_hash)
registerMechanism("SCRAM-"..hash_name, {"plain", "scram_"..(hashprep(hash_name))}, scram_gen(hash_name:lower(), hash, hmac_hash));
+
+ -- register channel binding equivalent
+ registerMechanism("SCRAM-"..hash_name.."-PLUS", {"plain", "scram_"..(hashprep(hash_name))}, scram_gen(hash_name:lower(), hash, hmac_hash), {"tls-unique"});
end
registerSCRAMMechanism("SHA-1", sha1, hmac_sha1);
end
-return _M;
+return {
+ getAuthenticationDatabaseSHA1 = getAuthenticationDatabaseSHA1;
+ init = init;
+}
diff --git a/util/sasl_cyrus.lua b/util/sasl_cyrus.lua
index 19684587..4e9a4af5 100644
--- a/util/sasl_cyrus.lua
+++ b/util/sasl_cyrus.lua
@@ -60,7 +60,7 @@ local sasl_errstring = {
};
setmetatable(sasl_errstring, { __index = function() return "undefined error!" end });
-module "sasl_cyrus"
+local _ENV = nil;
local method = {};
method.__index = method;
@@ -78,11 +78,11 @@ local function init(service_name)
end
-- create a new SASL object which can be used to authenticate clients
--- host_fqdn may be nil in which case gethostname() gives the value.
+-- host_fqdn may be nil in which case gethostname() gives the value.
-- For GSSAPI, this determines the hostname in the service ticket (after
-- reverse DNS canonicalization, only if [libdefaults] rdns = true which
--- is the default).
-function new(realm, service_name, app_name, host_fqdn)
+-- is the default).
+local function new(realm, service_name, app_name, host_fqdn)
init(app_name or service_name);
@@ -163,4 +163,6 @@ function method:process(message)
end
end
-return _M;
+return {
+ new = new;
+};
diff --git a/util/serialization.lua b/util/serialization.lua
index 8a259184..206f5fbb 100644
--- a/util/serialization.lua
+++ b/util/serialization.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -11,18 +11,16 @@ local type = type;
local tostring = tostring;
local t_insert = table.insert;
local t_concat = table.concat;
-local error = error;
local pairs = pairs;
local next = next;
-local loadstring = loadstring;
local pcall = pcall;
local debug_traceback = debug.traceback;
local log = require "util.logger".init("serialization");
local envload = require"util.envload".envload;
-module "serialization"
+local _ENV = nil;
local indent = function(i)
return string_rep("\t", i);
@@ -73,16 +71,16 @@ local function _simplesave(o, ind, t, func)
end
end
-function append(t, o)
+local function append(t, o)
_simplesave(o, 1, t, t.write or t_insert);
return t;
end
-function serialize(o)
+local function serialize(o)
return t_concat(append({}, o));
end
-function deserialize(str)
+local function deserialize(str)
if type(str) ~= "string" then return nil; end
str = "return "..str;
local f, err = envload(str, "@data", {});
@@ -92,4 +90,8 @@ function deserialize(str)
return ret;
end
-return _M;
+return {
+ append = append;
+ serialize = serialize;
+ deserialize = deserialize;
+};
diff --git a/util/session.lua b/util/session.lua
new file mode 100644
index 00000000..b2a726ce
--- /dev/null
+++ b/util/session.lua
@@ -0,0 +1,65 @@
+local initialize_filters = require "util.filters".initialize;
+local logger = require "util.logger";
+
+local function new_session(typ)
+ local session = {
+ type = typ .. "_unauthed";
+ };
+ return session;
+end
+
+local function set_id(session)
+ local id = session.type .. tostring(session):match("%x+$"):lower();
+ session.id = id;
+ return session;
+end
+
+local function set_logger(session)
+ local log = logger.init(session.id);
+ session.log = log;
+ return session;
+end
+
+local function set_conn(session, conn)
+ session.conn = conn;
+ session.ip = conn:ip();
+ return session;
+end
+
+local function set_send(session)
+ local conn = session.conn;
+ if not conn then
+ function session.send(data)
+ session.log("debug", "Discarding data sent to unconnected session: %s", tostring(data));
+ return false;
+ end
+ return session;
+ end
+ local filter = initialize_filters(session);
+ local w = conn.write;
+ session.send = function (t)
+ if t.name then
+ t = filter("stanzas/out", t);
+ end
+ if t then
+ t = filter("bytes/out", tostring(t));
+ if t then
+ local ret, err = w(conn, t);
+ if not ret then
+ session.log("debug", "Error writing to connection: %s", tostring(err));
+ return false, err;
+ end
+ end
+ end
+ return true;
+ end
+ return session;
+end
+
+return {
+ new = new_session;
+ set_id = set_id;
+ set_logger = set_logger;
+ set_conn = set_conn;
+ set_send = set_send;
+}
diff --git a/util/set.lua b/util/set.lua
index fa065a9c..c136a522 100644
--- a/util/set.lua
+++ b/util/set.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -10,86 +10,49 @@ local ipairs, pairs, setmetatable, next, tostring =
ipairs, pairs, setmetatable, next, tostring;
local t_concat = table.concat;
-module "set"
+local _ENV = nil;
local set_mt = {};
function set_mt.__call(set, _, k)
return next(set._items, k);
end
-function set_mt.__add(set1, set2)
- return _M.union(set1, set2);
-end
-function set_mt.__sub(set1, set2)
- return _M.difference(set1, set2);
-end
-function set_mt.__div(set, func)
- local new_set = _M.new();
- local items, new_items = set._items, new_set._items;
- for item in pairs(items) do
- local new_item = func(item);
- if new_item ~= nil then
- new_items[new_item] = true;
- end
- end
- return new_set;
-end
-function set_mt.__eq(set1, set2)
- local set1, set2 = set1._items, set2._items;
- for item in pairs(set1) do
- if not set2[item] then
- return false;
- end
- end
-
- for item in pairs(set2) do
- if not set1[item] then
- return false;
- end
- end
-
- return true;
-end
-function set_mt.__tostring(set)
- local s, items = { }, set._items;
- for item in pairs(items) do
- s[#s+1] = tostring(item);
- end
- return t_concat(s, ", ");
-end
local items_mt = {};
function items_mt.__call(items, _, k)
return next(items, k);
end
-function new(list)
+local function new(list)
local items = setmetatable({}, items_mt);
local set = { _items = items };
-
+
+ -- We access the set through an upvalue in these methods, so ignore 'self' being unused
+ --luacheck: ignore 212/self
+
function set:add(item)
items[item] = true;
end
-
+
function set:contains(item)
return items[item];
end
-
+
function set:items()
- return items;
+ return next, items;
end
-
+
function set:remove(item)
items[item] = nil;
end
-
- function set:add_list(list)
- if list then
- for _, item in ipairs(list) do
+
+ function set:add_list(item_list)
+ if item_list then
+ for _, item in ipairs(item_list) do
items[item] = true;
end
end
end
-
+
function set:include(otherset)
for item in otherset do
items[item] = true;
@@ -101,22 +64,22 @@ function new(list)
items[item] = nil;
end
end
-
+
function set:empty()
return not next(items);
end
-
+
if list then
set:add_list(list);
end
-
+
return setmetatable(set, set_mt);
end
-function union(set1, set2)
+local function union(set1, set2)
local set = new();
local items = set._items;
-
+
for item in pairs(set1._items) do
items[item] = true;
end
@@ -124,14 +87,14 @@ function union(set1, set2)
for item in pairs(set2._items) do
items[item] = true;
end
-
+
return set;
end
-function difference(set1, set2)
+local function difference(set1, set2)
local set = new();
local items = set._items;
-
+
for item in pairs(set1._items) do
items[item] = (not set2._items[item]) or nil;
end
@@ -139,21 +102,68 @@ function difference(set1, set2)
return set;
end
-function intersection(set1, set2)
+local function intersection(set1, set2)
local set = new();
local items = set._items;
-
+
set1, set2 = set1._items, set2._items;
-
+
for item in pairs(set1) do
items[item] = (not not set2[item]) or nil;
end
-
+
return set;
end
-function xor(set1, set2)
+local function xor(set1, set2)
return union(set1, set2) - intersection(set1, set2);
end
-return _M;
+function set_mt.__add(set1, set2)
+ return union(set1, set2);
+end
+function set_mt.__sub(set1, set2)
+ return difference(set1, set2);
+end
+function set_mt.__div(set, func)
+ local new_set = new();
+ local items, new_items = set._items, new_set._items;
+ for item in pairs(items) do
+ local new_item = func(item);
+ if new_item ~= nil then
+ new_items[new_item] = true;
+ end
+ end
+ return new_set;
+end
+function set_mt.__eq(set1, set2)
+ set1, set2 = set1._items, set2._items;
+ for item in pairs(set1) do
+ if not set2[item] then
+ return false;
+ end
+ end
+
+ for item in pairs(set2) do
+ if not set1[item] then
+ return false;
+ end
+ end
+
+ return true;
+end
+function set_mt.__tostring(set)
+ local s, items = { }, set._items;
+ for item in pairs(items) do
+ s[#s+1] = tostring(item);
+ end
+ return t_concat(s, ", ");
+end
+
+return {
+ new = new;
+ union = union;
+ difference = difference;
+ intersection = intersection;
+ xor = xor;
+};
diff --git a/util/sql.lua b/util/sql.lua
index f360d6d0..eb562eb2 100644
--- a/util/sql.lua
+++ b/util/sql.lua
@@ -1,8 +1,9 @@
local setmetatable, getmetatable = setmetatable, getmetatable;
-local ipairs, unpack, select = ipairs, unpack, select;
+local ipairs, unpack, select = ipairs, table.unpack or unpack, select; --luacheck: ignore 113
local tonumber, tostring = tonumber, tostring;
-local assert, xpcall, debug_traceback = assert, xpcall, debug.traceback;
+local type = type;
+local assert, pcall, xpcall, debug_traceback = assert, pcall, xpcall, debug.traceback;
local t_concat = table.concat;
local s_char = string.char;
local log = require "util.logger".init("sql");
@@ -13,7 +14,7 @@ local DBI = require "DBI";
DBI.Drivers();
local build_url = require "socket.url".build;
-module("sql")
+local _ENV = nil;
local column_mt = {};
local table_mt = {};
@@ -21,42 +22,17 @@ local query_mt = {};
--local op_mt = {};
local index_mt = {};
-function is_column(x) return getmetatable(x)==column_mt; end
-function is_index(x) return getmetatable(x)==index_mt; end
-function is_table(x) return getmetatable(x)==table_mt; end
-function is_query(x) return getmetatable(x)==query_mt; end
---function is_op(x) return getmetatable(x)==op_mt; end
---function expr(...) return setmetatable({...}, op_mt); end
-function Integer(n) return "Integer()" end
-function String(n) return "String()" end
+local function is_column(x) return getmetatable(x)==column_mt; end
+local function is_index(x) return getmetatable(x)==index_mt; end
+local function is_table(x) return getmetatable(x)==table_mt; end
+local function is_query(x) return getmetatable(x)==query_mt; end
+local function Integer() return "Integer()" end
+local function String() return "String()" end
---[[local ops = {
- __add = function(a, b) return "("..a.."+"..b..")" end;
- __sub = function(a, b) return "("..a.."-"..b..")" end;
- __mul = function(a, b) return "("..a.."*"..b..")" end;
- __div = function(a, b) return "("..a.."/"..b..")" end;
- __mod = function(a, b) return "("..a.."%"..b..")" end;
- __pow = function(a, b) return "POW("..a..","..b..")" end;
- __unm = function(a) return "NOT("..a..")" end;
- __len = function(a) return "COUNT("..a..")" end;
- __eq = function(a, b) return "("..a.."=="..b..")" end;
- __lt = function(a, b) return "("..a.."<"..b..")" end;
- __le = function(a, b) return "("..a.."<="..b..")" end;
-};
-
-local functions = {
-
-};
-
-local cmap = {
- [Integer] = Integer();
- [String] = String();
-};]]
-
-function Column(definition)
+local function Column(definition)
return setmetatable(definition, column_mt);
end
-function Table(definition)
+local function Table(definition)
local c = {}
for i,col in ipairs(definition) do
if is_column(col) then
@@ -67,13 +43,13 @@ function Table(definition)
end
return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt);
end
-function Index(definition)
+local function Index(definition)
return setmetatable(definition, index_mt);
end
function table_mt:__tostring()
local s = { 'name="'..self.__table__.name..'"' }
- for i,col in ipairs(self.__table__) do
+ for _, col in ipairs(self.__table__) do
s[#s+1] = tostring(col);
end
return 'Table{ '..t_concat(s, ", ")..' }'
@@ -94,7 +70,6 @@ function index_mt:__tostring()
return s..' }';
-- return 'Index{ name="'..self.name..'", type="'..self.type..'" }'
end
---
local function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end)); end
local function parse_url(url)
@@ -121,43 +96,44 @@ local function parse_url(url)
};
end
---[[local session = {};
-
-function session.query(...)
- local rets = {...};
- local query = setmetatable({ __rets = rets, __filters }, query_mt);
- return query;
-end
---
-
-local function db2uri(params)
- return build_url{
- scheme = params.driver,
- user = params.username,
- password = params.password,
- host = params.host,
- port = params.port,
- path = params.database,
- };
-end]]
-
local engine = {};
function engine:connect()
if self.conn then return true; end
local params = self.params;
assert(params.driver, "no driver")
- local dbh, err = DBI.Connect(
+ log("debug", "Connecting to [%s] %s...", params.driver, params.database);
+ local ok, dbh, err = pcall(DBI.Connect,
params.driver, params.database,
params.username, params.password,
params.host, params.port
);
+ if not ok then return ok, dbh; end
if not dbh then return nil, err; end
dbh:autocommit(false); -- don't commit automatically
self.conn = dbh;
self.prepared = {};
+ local ok, err = self:set_encoding();
+ if not ok then
+ return ok, err;
+ end
+ local ok, err = self:onconnect();
+ if ok == false then
+ return ok, err;
+ end
return true;
end
+function engine:onconnect()
+ -- Override from create_engine()
+end
+
+function engine:prepquery(sql)
+ if self.params.driver == "PostgreSQL" then
+ sql = sql:gsub("`", "\"");
+ end
+ return sql;
+end
+
function engine:execute(sql, ...)
local success, err = self:connect();
if not success then return success, err; end
@@ -177,22 +153,23 @@ function engine:execute(sql, ...)
end
local result_mt = { __index = {
- affected = function(self) return self.__affected; end;
- rowcount = function(self) return self.__rowcount; end;
+ affected = function(self) return self.__stmt:affected(); end;
+ rowcount = function(self) return self.__stmt:rowcount(); end;
} };
+local function debugquery(where, sql, ...)
+ local i = 0; local a = {...}
+ log("debug", "[%s] %s", where, sql:gsub("%?", function () i = i + 1; local v = a[i]; if type(v) == "string" then v = ("%q"):format(v); end return tostring(v); end));
+end
+
function engine:execute_query(sql, ...)
- if self.params.driver == "PostgreSQL" then
- sql = sql:gsub("`", "\"");
- end
+ sql = self:prepquery(sql);
local stmt = assert(self.conn:prepare(sql));
assert(stmt:execute(...));
return stmt:rows();
end
function engine:execute_update(sql, ...)
- if self.params.driver == "PostgreSQL" then
- sql = sql:gsub("`", "\"");
- end
+ sql = self:prepquery(sql);
local prepared = self.prepared;
local stmt = prepared[sql];
if not stmt then
@@ -200,22 +177,47 @@ function engine:execute_update(sql, ...)
prepared[sql] = stmt;
end
assert(stmt:execute(...));
- return setmetatable({ __affected = stmt:affected(), __rowcount = stmt:rowcount() }, result_mt);
+ return setmetatable({ __stmt = stmt }, result_mt);
end
engine.insert = engine.execute_update;
engine.select = engine.execute_query;
engine.delete = engine.execute_update;
engine.update = engine.execute_update;
+local function debugwrap(name, f)
+ return function (self, sql, ...)
+ debugquery(name, sql, ...)
+ return f(self, sql, ...)
+ end
+end
+function engine:debug(enable)
+ self._debug = enable;
+ if enable then
+ engine.insert = debugwrap("insert", engine.execute_update);
+ engine.select = debugwrap("select", engine.execute_query);
+ engine.delete = debugwrap("delete", engine.execute_update);
+ engine.update = debugwrap("update", engine.execute_update);
+ else
+ engine.insert = engine.execute_update;
+ engine.select = engine.execute_query;
+ engine.delete = engine.execute_update;
+ engine.update = engine.execute_update;
+ end
+end
+local function handleerr(err)
+ log("error", "Error in SQL transaction: %s", debug_traceback(err, 3));
+ return err;
+end
function engine:_transaction(func, ...)
if not self.conn then
- local a,b = self:connect();
- if not a then return a,b; end
+ local ok, err = self:connect();
+ if not ok then return ok, err; end
end
--assert(not self.__transaction, "Recursive transactions not allowed");
local args, n_args = {...}, select("#", ...);
local function f() return func(unpack(args, 1, n_args)); end
+ log("debug", "SQL transaction begin [%s]", tostring(func));
self.__transaction = true;
- local success, a, b, c = xpcall(f, debug_traceback);
+ local success, a, b, c = xpcall(f, handleerr);
self.__transaction = nil;
if success then
log("debug", "SQL transaction success [%s]", tostring(func));
@@ -229,15 +231,15 @@ function engine:_transaction(func, ...)
end
end
function engine:transaction(...)
- local a,b = self:_transaction(...);
- if not a then
+ local ok, ret = self:_transaction(...);
+ if not ok then
local conn = self.conn;
if not conn or not conn:ping() then
self.conn = nil;
- a,b = self:_transaction(...);
+ ok, ret = self:_transaction(...);
end
end
- return a,b;
+ return ok, ret;
end
function engine:_create_index(index)
local sql = "CREATE INDEX `"..index.name.."` ON `"..index.table.."` (";
@@ -251,29 +253,100 @@ function engine:_create_index(index)
elseif self.params.driver == "MySQL" then
sql = sql:gsub("`([,)])", "`(20)%1");
end
- --print(sql);
+ if index.unique then
+ sql = sql:gsub("^CREATE", "CREATE UNIQUE");
+ end
+ if self._debug then
+ debugquery("create", sql);
+ end
return self:execute(sql);
end
function engine:_create_table(table)
local sql = "CREATE TABLE `"..table.name.."` (";
for i,col in ipairs(table.c) do
- sql = sql.."`"..col.name.."` "..col.type;
+ local col_type = col.type;
+ if col_type == "MEDIUMTEXT" and self.params.driver ~= "MySQL" then
+ col_type = "TEXT"; -- MEDIUMTEXT is MySQL-specific
+ end
+ if col.auto_increment == true and self.params.driver == "PostgreSQL" then
+ col_type = "BIGSERIAL";
+ end
+ sql = sql.."`"..col.name.."` "..col_type;
if col.nullable == false then sql = sql.." NOT NULL"; end
+ if col.primary_key == true then sql = sql.." PRIMARY KEY"; end
+ if col.auto_increment == true then
+ if self.params.driver == "MySQL" then
+ sql = sql.." AUTO_INCREMENT";
+ elseif self.params.driver == "SQLite3" then
+ sql = sql.." AUTOINCREMENT";
+ end
+ end
if i ~= #table.c then sql = sql..", "; end
end
sql = sql.. ");"
if self.params.driver == "PostgreSQL" then
sql = sql:gsub("`", "\"");
+ elseif self.params.driver == "MySQL" then
+ sql = sql:gsub(";$", (" CHARACTER SET '%s' COLLATE '%s_bin';"):format(self.charset, self.charset));
+ end
+ if self._debug then
+ debugquery("create", sql);
end
local success,err = self:execute(sql);
if not success then return success,err; end
- for i,v in ipairs(table.__table__) do
+ for _, v in ipairs(table.__table__) do
if is_index(v) then
self:_create_index(v);
end
end
return success;
end
+function engine:set_encoding() -- to UTF-8
+ local driver = self.params.driver;
+ if driver == "SQLite3" then
+ return self:transaction(function()
+ for encoding in self:select"PRAGMA encoding;" do
+ if encoding[1] == "UTF-8" then
+ self.charset = "utf8";
+ end
+ end
+ end);
+ end
+ local set_names_query = "SET NAMES '%s';"
+ local charset = "utf8";
+ if driver == "MySQL" then
+ self:transaction(function()
+ for row in self:select"SELECT `CHARACTER_SET_NAME` FROM `information_schema`.`CHARACTER_SETS` WHERE `CHARACTER_SET_NAME` LIKE 'utf8%' ORDER BY MAXLEN DESC LIMIT 1;" do
+ charset = row and row[1] or charset;
+ end
+ end);
+ set_names_query = set_names_query:gsub(";$", (" COLLATE '%s';"):format(charset.."_bin"));
+ end
+ self.charset = charset;
+ log("debug", "Using encoding '%s' for database connection", charset);
+ local ok, err = self:transaction(function() return self:execute(set_names_query:format(charset)); end);
+ if not ok then
+ return ok, err;
+ end
+
+ if driver == "MySQL" then
+ local ok, actual_charset = self:transaction(function ()
+ return self:select"SHOW SESSION VARIABLES LIKE 'character_set_client'";
+ end);
+ local charset_ok = true;
+ for row in actual_charset do
+ if row[2] ~= charset then
+ log("error", "MySQL %s is actually %q (expected %q)", row[1], row[2], charset);
+ charset_ok = false;
+ end
+ end
+ if not charset_ok then
+ return false, "Failed to set connection encoding";
+ end
+ end
+
+ return true;
+end
local engine_mt = { __index = engine };
local function db2uri(params)
@@ -286,55 +359,21 @@ local function db2uri(params)
path = params.database,
};
end
-local engine_cache = {}; -- TODO make weak valued
-function create_engine(self, params)
- local url = db2uri(params);
- if not engine_cache[url] then
- local engine = setmetatable({ url = url, params = params }, engine_mt);
- engine_cache[url] = engine;
- end
- return engine_cache[url];
-end
-
---[[Users = Table {
- name="users";
- Column { name="user_id", type=String(), primary_key=true };
-};
-print(Users)
-print(Users.c.user_id)]]
-
---local engine = create_engine('postgresql://scott:tiger@localhost:5432/mydatabase');
---[[local engine = create_engine{ driver = "SQLite3", database = "./alchemy.sqlite" };
-
-local i = 0;
-for row in assert(engine:execute("select * from sqlite_master")):rows(true) do
- i = i+1;
- print(i);
- for k,v in pairs(row) do
- print("",k,v);
- end
+local function create_engine(self, params, onconnect)
+ return setmetatable({ url = db2uri(params), params = params, onconnect = onconnect }, engine_mt);
end
-print("---")
-Prosody = Table {
- name="prosody";
- Column { name="host", type="TEXT", nullable=false };
- Column { name="user", type="TEXT", nullable=false };
- Column { name="store", type="TEXT", nullable=false };
- Column { name="key", type="TEXT", nullable=false };
- Column { name="type", type="TEXT", nullable=false };
- Column { name="value", type="TEXT", nullable=false };
- Index { name="prosody_index", "host", "user", "store", "key" };
+return {
+ is_column = is_column;
+ is_index = is_index;
+ is_table = is_table;
+ is_query = is_query;
+ Integer = Integer;
+ String = String;
+ Column = Column;
+ Table = Table;
+ Index = Index;
+ create_engine = create_engine;
+ db2uri = db2uri;
};
---print(Prosody);
-assert(engine:transaction(function()
- assert(Prosody:create(engine));
-end));
-
-for row in assert(engine:execute("select user from prosody")):rows(true) do
- print("username:", row['username'])
-end
---result.close();]]
-
-return _M;
diff --git a/util/sslconfig.lua b/util/sslconfig.lua
new file mode 100644
index 00000000..c849aa28
--- /dev/null
+++ b/util/sslconfig.lua
@@ -0,0 +1,120 @@
+-- util to easily merge multiple sets of LuaSec context options
+
+local type = type;
+local pairs = pairs;
+local rawset = rawset;
+local t_concat = table.concat;
+local t_insert = table.insert;
+local setmetatable = setmetatable;
+
+local _ENV = nil;
+
+local handlers = { };
+local finalisers = { };
+local id = function (v) return v end
+
+-- All "handlers" behave like extended rawset(table, key, value) with extra
+-- processing usually merging the new value with the old in some reasonable
+-- way
+-- If a field does not have a defined handler then a new value simply
+-- replaces the old.
+
+
+-- Convert either a list or a set into a special type of set where each
+-- item is either positive or negative in order for a later set of options
+-- to be able to remove options from this set by filtering out the negative ones
+function handlers.options(config, field, new)
+ local options = config[field] or { };
+ if type(new) ~= "table" then new = { new } end
+ for key, value in pairs(new) do
+ if value == true or value == false then
+ options[key] = value;
+ else -- list item
+ options[value] = true;
+ end
+ end
+ config[field] = options;
+end
+
+handlers.verify = handlers.options;
+handlers.verifyext = handlers.options;
+
+-- finalisers take something produced by handlers and return what luasec
+-- expects it to be
+
+-- Produce a list of "positive" options from the set
+function finalisers.options(options)
+ local output = {};
+ for opt, enable in pairs(options) do
+ if enable then
+ output[#output+1] = opt;
+ end
+ end
+ return output;
+end
+
+finalisers.verify = finalisers.options;
+finalisers.verifyext = finalisers.options;
+
+-- We allow ciphers to be a list
+
+function finalisers.ciphers(cipherlist)
+ if type(cipherlist) == "table" then
+ return t_concat(cipherlist, ":");
+ end
+ return cipherlist;
+end
+
+-- protocol = "x" should enable only that protocol
+-- protocol = "x+" should enable x and later versions
+
+local protocols = { "sslv2", "sslv3", "tlsv1", "tlsv1_1", "tlsv1_2" };
+for i = 1, #protocols do protocols[protocols[i] .. "+"] = i - 1; end
+
+-- this interacts with ssl.options as well to add no_x
+local function protocol(config)
+ local min_protocol = protocols[config.protocol];
+ if min_protocol then
+ config.protocol = "sslv23";
+ for i = 1, min_protocol do
+ t_insert(config.options, "no_"..protocols[i]);
+ end
+ end
+end
+
+-- Merge options from 'new' config into 'config'
+local function apply(config, new)
+ if type(new) == "table" then
+ for field, value in pairs(new) do
+ (handlers[field] or rawset)(config, field, value);
+ end
+ end
+end
+
+-- Finalize the config into the form LuaSec expects
+local function final(config)
+ local output = { };
+ for field, value in pairs(config) do
+ output[field] = (finalisers[field] or id)(value);
+ end
+ -- Need to handle protocols last because it adds to the options list
+ protocol(output);
+ return output;
+end
+
+local sslopts_mt = {
+ __index = {
+ apply = apply;
+ final = final;
+ };
+};
+
+local function new()
+ return setmetatable({options={}}, sslopts_mt);
+end
+
+return {
+ apply = apply;
+ final = final;
+ new = new;
+};
diff --git a/util/stanza.lua b/util/stanza.lua
index 7c214210..8bb9ba05 100644
--- a/util/stanza.lua
+++ b/util/stanza.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -35,17 +35,15 @@ end
local xmlns_stanzas = "urn:ietf:params:xml:ns:xmpp-stanzas";
-module "stanza"
+local _ENV = nil;
-stanza_mt = { __type = "stanza" };
+local stanza_mt = { __type = "stanza" };
stanza_mt.__index = stanza_mt;
-local stanza_mt = stanza_mt;
-function stanza(name, attr)
+local function new_stanza(name, attr)
local stanza = { name = name, attr = attr or {}, tags = {} };
return setmetatable(stanza, stanza_mt);
end
-local stanza = stanza;
function stanza_mt:query(xmlns)
return self:tag("query", { xmlns = xmlns });
@@ -56,7 +54,7 @@ function stanza_mt:body(text, attr)
end
function stanza_mt:tag(name, attrs)
- local s = stanza(name, attrs);
+ local s = new_stanza(name, attrs);
local last_add = self.last_add;
if not last_add then last_add = {}; self.last_add = last_add; end
(last_add[#last_add] or self):add_direct_child(s);
@@ -99,7 +97,7 @@ function stanza_mt:get_child(name, xmlns)
if (not name or child.name == name)
and ((not xmlns and self.attr.xmlns == child.attr.xmlns)
or child.attr.xmlns == xmlns) then
-
+
return child;
end
end
@@ -152,7 +150,7 @@ end
function stanza_mt:maptags(callback)
local tags, curr_tag = self.tags, 1;
local n_children, n_tags = #self, #tags;
-
+
local i = 1;
while curr_tag <= n_tags and n_tags > 0 do
if self[i] == tags[curr_tag] then
@@ -200,14 +198,10 @@ function stanza_mt:find(path)
end
-local xml_escape
-do
- local escape_table = { ["'"] = "&apos;", ["\""] = "&quot;", ["<"] = "&lt;", [">"] = "&gt;", ["&"] = "&amp;" };
- function xml_escape(str) return (s_gsub(str, "['&<>\"]", escape_table)); end
- _M.xml_escape = xml_escape;
-end
+local escape_table = { ["'"] = "&apos;", ["\""] = "&quot;", ["<"] = "&lt;", [">"] = "&gt;", ["&"] = "&amp;" };
+local function xml_escape(str) return (s_gsub(str, "['&<>\"]", escape_table)); end
-local function _dostring(t, buf, self, xml_escape, parentns)
+local function _dostring(t, buf, self, _xml_escape, parentns)
local nsid = 0;
local name = t.name
t_insert(buf, "<"..name);
@@ -215,9 +209,9 @@ local function _dostring(t, buf, self, xml_escape, parentns)
if s_find(k, "\1", 1, true) then
local ns, attrk = s_match(k, "^([^\1]*)\1?(.*)$");
nsid = nsid + 1;
- t_insert(buf, " xmlns:ns"..nsid.."='"..xml_escape(ns).."' ".."ns"..nsid..":"..attrk.."='"..xml_escape(v).."'");
+ t_insert(buf, " xmlns:ns"..nsid.."='".._xml_escape(ns).."' ".."ns"..nsid..":"..attrk.."='".._xml_escape(v).."'");
elseif not(k == "xmlns" and v == parentns) then
- t_insert(buf, " "..k.."='"..xml_escape(v).."'");
+ t_insert(buf, " "..k.."='".._xml_escape(v).."'");
end
end
local len = #t;
@@ -228,9 +222,9 @@ local function _dostring(t, buf, self, xml_escape, parentns)
for n=1,len do
local child = t[n];
if child.name then
- self(child, buf, self, xml_escape, t.attr.xmlns);
+ self(child, buf, self, _xml_escape, t.attr.xmlns);
else
- t_insert(buf, xml_escape(child));
+ t_insert(buf, _xml_escape(child));
end
end
t_insert(buf, "</"..name..">");
@@ -257,14 +251,14 @@ function stanza_mt.get_text(t)
end
function stanza_mt.get_error(stanza)
- local type, condition, text;
-
+ local error_type, condition, text;
+
local error_tag = stanza:get_child("error");
if not error_tag then
return nil, nil, nil;
end
- type = error_tag.attr.type;
-
+ error_type = error_tag.attr.type;
+
for _, child in ipairs(error_tag.tags) do
if child.attr.xmlns == xmlns_stanzas then
if not text and child.name == "text" then
@@ -277,18 +271,16 @@ function stanza_mt.get_error(stanza)
end
end
end
- return type, condition or "undefined-condition", text;
+ return error_type, condition or "undefined-condition", text;
end
-do
- local id = 0;
- function new_id()
- id = id + 1;
- return "lx"..id;
- end
+local id = 0;
+local function new_id()
+ id = id + 1;
+ return "lx"..id;
end
-function preserialize(stanza)
+local function preserialize(stanza)
local s = { name = stanza.name, attr = stanza.attr };
for _, child in ipairs(stanza) do
if type(child) == "table" then
@@ -300,7 +292,7 @@ function preserialize(stanza)
return s;
end
-function deserialize(stanza)
+local function deserialize(stanza)
-- Set metatable
if stanza then
local attr = stanza.attr;
@@ -333,56 +325,53 @@ function deserialize(stanza)
stanza.tags = tags;
end
end
-
+
return stanza;
end
-local function _clone(stanza)
+local function clone(stanza)
local attr, tags = {}, {};
for k,v in pairs(stanza.attr) do attr[k] = v; end
local new = { name = stanza.name, attr = attr, tags = tags };
for i=1,#stanza do
local child = stanza[i];
if child.name then
- child = _clone(child);
+ child = clone(child);
t_insert(tags, child);
end
t_insert(new, child);
end
return setmetatable(new, stanza_mt);
end
-clone = _clone;
-function message(attr, body)
+local function message(attr, body)
if not body then
- return stanza("message", attr);
+ return new_stanza("message", attr);
else
- return stanza("message", attr):tag("body"):text(body):up();
+ return new_stanza("message", attr):tag("body"):text(body):up();
end
end
-function iq(attr)
+local function iq(attr)
if attr and not attr.id then attr.id = new_id(); end
- return stanza("iq", attr or { id = new_id() });
+ return new_stanza("iq", attr or { id = new_id() });
end
-function reply(orig)
- return stanza(orig.name, orig.attr and { to = orig.attr.from, from = orig.attr.to, id = orig.attr.id, type = ((orig.name == "iq" and "result") or orig.attr.type) });
+local function reply(orig)
+ return new_stanza(orig.name, orig.attr and { to = orig.attr.from, from = orig.attr.to, id = orig.attr.id, type = ((orig.name == "iq" and "result") or orig.attr.type) });
end
-do
- local xmpp_stanzas_attr = { xmlns = xmlns_stanzas };
- function error_reply(orig, type, condition, message)
- local t = reply(orig);
- t.attr.type = "error";
- t:tag("error", {type = type}) --COMPAT: Some day xmlns:stanzas goes here
- :tag(condition, xmpp_stanzas_attr):up();
- if (message) then t:tag("text", xmpp_stanzas_attr):text(message):up(); end
- return t; -- stanza ready for adding app-specific errors
- end
+local xmpp_stanzas_attr = { xmlns = xmlns_stanzas };
+local function error_reply(orig, error_type, condition, error_message)
+ local t = reply(orig);
+ t.attr.type = "error";
+ t:tag("error", {type = error_type}) --COMPAT: Some day xmlns:stanzas goes here
+ :tag(condition, xmpp_stanzas_attr):up();
+ if error_message then t:tag("text", xmpp_stanzas_attr):text(error_message):up(); end
+ return t; -- stanza ready for adding app-specific errors
end
-function presence(attr)
- return stanza("presence", attr);
+local function presence(attr)
+ return new_stanza("presence", attr);
end
if do_pretty_printing then
@@ -390,14 +379,14 @@ if do_pretty_printing then
local style_attrv = getstyle("red");
local style_tagname = getstyle("red");
local style_punc = getstyle("magenta");
-
+
local attr_format = " "..getstring(style_attrk, "%s")..getstring(style_punc, "=")..getstring(style_attrv, "'%s'");
local top_tag_format = getstring(style_punc, "<")..getstring(style_tagname, "%s").."%s"..getstring(style_punc, ">");
--local tag_format = getstring(style_punc, "<")..getstring(style_tagname, "%s").."%s"..getstring(style_punc, ">").."%s"..getstring(style_punc, "</")..getstring(style_tagname, "%s")..getstring(style_punc, ">");
local tag_format = top_tag_format.."%s"..getstring(style_punc, "</")..getstring(style_tagname, "%s")..getstring(style_punc, ">");
function stanza_mt.pretty_print(t)
local children_text = "";
- for n, child in ipairs(t) do
+ for _, child in ipairs(t) do
if type(child) == "string" then
children_text = children_text .. xml_escape(child);
else
@@ -411,7 +400,7 @@ if do_pretty_printing then
end
return s_format(tag_format, t.name, attr_string, children_text, t.name);
end
-
+
function stanza_mt.pretty_top_tag(t)
local attr_string = "";
if t.attr then
@@ -425,4 +414,17 @@ else
stanza_mt.pretty_top_tag = stanza_mt.top_tag;
end
-return _M;
+return {
+ stanza_mt = stanza_mt;
+ stanza = new_stanza;
+ new_id = new_id;
+ preserialize = preserialize;
+ deserialize = deserialize;
+ clone = clone;
+ message = message;
+ iq = iq;
+ reply = reply;
+ error_reply = error_reply;
+ presence = presence;
+ xml_escape = xml_escape;
+};
diff --git a/util/statistics.lua b/util/statistics.lua
new file mode 100644
index 00000000..26355026
--- /dev/null
+++ b/util/statistics.lua
@@ -0,0 +1,160 @@
+local t_sort = table.sort
+local m_floor = math.floor;
+local time = require "socket".gettime;
+
+local function nop_function() end
+
+local function percentile(arr, length, pc)
+ local n = pc/100 * (length + 1);
+ local k, d = m_floor(n), n%1;
+ if k == 0 then
+ return arr[1] or 0;
+ elseif k >= length then
+ return arr[length];
+ end
+ return arr[k] + d*(arr[k+1] - arr[k]);
+end
+
+local function new_registry(config)
+ config = config or {};
+ local duration_sample_interval = config.duration_sample_interval or 5;
+ local duration_max_samples = config.duration_max_stored_samples or 5000;
+
+ local function get_distribution_stats(events, n_actual_events, since, new_time, units)
+ local n_stored_events = #events;
+ t_sort(events);
+ local sum = 0;
+ for i = 1, n_stored_events do
+ sum = sum + events[i];
+ end
+
+ return {
+ samples = events;
+ sample_count = n_stored_events;
+ count = n_actual_events,
+ rate = n_actual_events/(new_time-since);
+ average = n_stored_events > 0 and sum/n_stored_events or 0,
+ min = events[1] or 0,
+ max = events[n_stored_events] or 0,
+ units = units,
+ };
+ end
+
+
+ local registry = {};
+ local methods;
+ methods = {
+ amount = function (name, initial)
+ local v = initial or 0;
+ registry[name..":amount"] = function () return "amount", v; end
+ return function (new_v) v = new_v; end
+ end;
+ counter = function (name, initial)
+ local v = initial or 0;
+ registry[name..":amount"] = function () return "amount", v; end
+ return function (delta)
+ v = v + delta;
+ end;
+ end;
+ rate = function (name)
+ local since, n = time(), 0;
+ registry[name..":rate"] = function ()
+ local t = time();
+ local stats = {
+ rate = n/(t-since);
+ count = n;
+ };
+ since, n = t, 0;
+ return "rate", stats.rate, stats;
+ end;
+ return function ()
+ n = n + 1;
+ end;
+ end;
+ distribution = function (name, unit, type)
+ type = type or "distribution";
+ local events, last_event = {}, 0;
+ local n_actual_events = 0;
+ local since = time();
+
+ registry[name..":"..type] = function ()
+ local new_time = time();
+ local stats = get_distribution_stats(events, n_actual_events, since, new_time, unit);
+ events, last_event = {}, 0;
+ n_actual_events = 0;
+ since = new_time;
+ return type, stats.average, stats;
+ end;
+
+ return function (value)
+ n_actual_events = n_actual_events + 1;
+ if n_actual_events%duration_sample_interval == 1 then
+ last_event = (last_event%duration_max_samples) + 1;
+ events[last_event] = value;
+ end
+ end;
+ end;
+ sizes = function (name)
+ return methods.distribution(name, "bytes", "size");
+ end;
+ times = function (name)
+ local events, last_event = {}, 0;
+ local n_actual_events = 0;
+ local since = time();
+
+ registry[name..":duration"] = function ()
+ local new_time = time();
+ local stats = get_distribution_stats(events, n_actual_events, since, new_time, "seconds");
+ events, last_event = {}, 0;
+ n_actual_events = 0;
+ since = new_time;
+ return "duration", stats.average, stats;
+ end;
+
+ return function ()
+ n_actual_events = n_actual_events + 1;
+ if n_actual_events%duration_sample_interval ~= 1 then
+ return nop_function;
+ end
+
+ local start_time = time();
+ return function ()
+ local end_time = time();
+ local duration = end_time - start_time;
+ last_event = (last_event%duration_max_samples) + 1;
+ events[last_event] = duration;
+ end
+ end;
+ end;
+
+ get_stats = function ()
+ return registry;
+ end;
+ };
+ return methods;
+end
+
+return {
+ new = new_registry;
+ get_histogram = function (duration, n_buckets)
+ n_buckets = n_buckets or 100;
+ local events, n_events = duration.samples, duration.sample_count;
+ if not (events and n_events) then
+ return nil, "not a valid distribution stat";
+ end
+ local histogram = {};
+
+ for i = 1, 100, 100/n_buckets do
+ histogram[i] = percentile(events, n_events, i);
+ end
+ return histogram;
+ end;
+
+ get_percentile = function (duration, pc)
+ local events, n_events = duration.samples, duration.sample_count;
+ if not (events and n_events) then
+ return nil, "not a valid distribution stat";
+ end
+ return percentile(events, n_events, pc);
+ end;
+}
diff --git a/util/statsd.lua b/util/statsd.lua
new file mode 100644
index 00000000..2874e8a8
--- /dev/null
+++ b/util/statsd.lua
@@ -0,0 +1,84 @@
+local socket = require "socket";
+
+local time = require "socket".gettime;
+
+local function new(config)
+ if not config or not config.statsd_server then
+ return nil, "No statsd server specified in the config, please see https://prosody.im/doc/statistics";
+ end
+
+ local sock = socket.udp();
+ sock:setpeername(config.statsd_server, config.statsd_port or 8125);
+
+ local prefix = (config.prefix or "prosody")..".";
+
+ local function send_metric(s)
+ return sock:send(prefix..s);
+ end
+
+ local function send_gauge(name, amount, relative)
+ local s_amount = tostring(amount);
+ if relative and amount > 0 then
+ s_amount = "+"..s_amount;
+ end
+ return send_metric(name..":"..s_amount.."|g");
+ end
+
+ local function send_counter(name, amount)
+ return send_metric(name..":"..tostring(amount).."|c");
+ end
+
+ local function send_duration(name, duration)
+ return send_metric(name..":"..tostring(duration).."|ms");
+ end
+
+ local function send_histogram_sample(name, sample)
+ return send_metric(name..":"..tostring(sample).."|h");
+ end
+
+ local methods;
+ methods = {
+ amount = function (name, initial)
+ if initial then
+ send_gauge(name, initial);
+ end
+ return function (new_v) send_gauge(name, new_v); end
+ end;
+ counter = function (name, initial)
+ return function (delta)
+ send_gauge(name, delta, true);
+ end;
+ end;
+ rate = function (name)
+ return function ()
+ send_counter(name, 1);
+ end;
+ end;
+ distribution = function (name, unit, type) --luacheck: ignore 212/unit 212/type
+ return function (value)
+ send_histogram_sample(name, value);
+ end;
+ end;
+ sizes = function (name)
+ name = name.."_size";
+ return function (value)
+ send_histogram_sample(name, value);
+ end;
+ end;
+ times = function (name)
+ return function ()
+ local start_time = time();
+ return function ()
+ local end_time = time();
+ local duration = end_time - start_time;
+ send_duration(name, duration*1000);
+ end
+ end;
+ end;
+ };
+ return methods;
+end
+
+return {
+ new = new;
+}
diff --git a/util/template.lua b/util/template.lua
index 66d4fca7..04ebb93d 100644
--- a/util/template.lua
+++ b/util/template.lua
@@ -1,4 +1,4 @@
-
+-- luacheck: ignore 213/i
local stanza_mt = require "util.stanza".stanza_mt;
local setmetatable = setmetatable;
local pairs = pairs;
@@ -9,7 +9,7 @@ local debug = debug;
local t_remove = table.remove;
local parse_xml = require "util.xml".parse;
-module("template")
+local _ENV = nil;
local function trim_xml(stanza)
for i=#stanza,1,-1 do
@@ -67,12 +67,12 @@ end
local function create_cloner(stanza, chunkname)
local lookup = {};
local name = create_clone_string(stanza, lookup, "");
- local f = "local setmetatable,stanza_mt=...;return function(data)";
+ local src = "local setmetatable,stanza_mt=...;return function(data)";
for i=1,#lookup do
- f = f.."local _"..i.."="..lookup[i]..";";
+ src = src.."local _"..i.."="..lookup[i]..";";
end
- f = f.."return "..name..";end";
- local f,err = loadstring(f, chunkname);
+ src = src.."return "..name..";end";
+ local f,err = loadstring(src, chunkname);
if not f then error(err); end
return f(setmetatable, stanza_mt);
end
diff --git a/util/termcolours.lua b/util/termcolours.lua
index 6ef3b689..53633b45 100644
--- a/util/termcolours.lua
+++ b/util/termcolours.lua
@@ -1,10 +1,12 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
+--
+-- luacheck: ignore 213/i
local t_concat, t_insert = table.concat, table.insert;
@@ -12,6 +14,10 @@ local char, format = string.char, string.format;
local tonumber = tonumber;
local ipairs = ipairs;
local io_write = io.write;
+local m_floor = math.floor;
+local type = type;
+local setmetatable = setmetatable;
+local pairs = pairs;
local windows;
if os.getenv("WINDIR") then
@@ -19,7 +25,7 @@ if os.getenv("WINDIR") then
end
local orig_color = windows and windows.get_consolecolor and windows.get_consolecolor();
-module "termcolours"
+local _ENV = nil;
local stylemap = {
reset = 0; bright = 1, dim = 2, underscore = 4, blink = 5, reverse = 7, hidden = 8;
@@ -45,7 +51,7 @@ local cssmap = {
};
local fmt_string = char(0x1B).."[%sm%s"..char(0x1B).."[0m";
-function getstring(style, text)
+local function getstring(style, text)
if style then
return format(fmt_string, style, text);
else
@@ -53,7 +59,45 @@ function getstring(style, text)
end
end
-function getstyle(...)
+local function gray(n)
+ return m_floor(n*3/32)+0xe8;
+end
+local function color(r,g,b)
+ if r == g and g == b then
+ return gray(r);
+ end
+ r = m_floor(r*3/128);
+ g = m_floor(g*3/128);
+ b = m_floor(b*3/128);
+ return 0x10 + ( r * 36 ) + ( g * 6 ) + ( b );
+end
+local function hex2rgb(hex)
+ local r = tonumber(hex:sub(1,2),16);
+ local g = tonumber(hex:sub(3,4),16);
+ local b = tonumber(hex:sub(5,6),16);
+ return r,g,b;
+end
+
+setmetatable(stylemap, { __index = function(_, style)
+ if type(style) == "string" and style:find("%x%x%x%x%x%x") == 1 then
+ local g = style:sub(7) == " background" and "48;5;" or "38;5;";
+ return g .. color(hex2rgb(style));
+ end
+end } );
+
+local csscolors = {
+ red = "ff0000"; fuchsia = "ff00ff"; green = "008000"; white = "ffffff";
+ lime = "00ff00"; yellow = "ffff00"; purple = "800080"; blue = "0000ff";
+ aqua = "00ffff"; olive = "808000"; black = "000000"; navy = "000080";
+ teal = "008080"; silver = "c0c0c0"; maroon = "800000"; gray = "808080";
+}
+for colorname, rgb in pairs(csscolors) do
+ stylemap[colorname] = stylemap[colorname] or stylemap[rgb];
+ colorname, rgb = colorname .. " background", rgb .. " background"
+ stylemap[colorname] = stylemap[colorname] or stylemap[rgb];
+end
+
+local function getstyle(...)
local styles, result = { ... }, {};
for i, style in ipairs(styles) do
style = stylemap[style];
@@ -65,7 +109,7 @@ function getstyle(...)
end
local last = "0";
-function setstyle(style)
+local function setstyle(style)
style = style or "0";
if style ~= last then
io_write("\27["..style.."m");
@@ -82,7 +126,7 @@ if windows then
end
end
if not orig_color then
- function setstyle(style) end
+ function setstyle() end
end
end
@@ -95,8 +139,13 @@ local function ansi2css(ansi_codes)
return "</span><span style='"..t_concat(css, ";").."'>";
end
-function tohtml(input)
+local function tohtml(input)
return input:gsub("\027%[(.-)m", ansi2css);
end
-return _M;
+return {
+ getstring = getstring;
+ getstyle = getstyle;
+ setstyle = setstyle;
+ tohtml = tohtml;
+};
diff --git a/util/throttle.lua b/util/throttle.lua
index 55e1d07b..3d3f5d2d 100644
--- a/util/throttle.lua
+++ b/util/throttle.lua
@@ -3,7 +3,7 @@ local gettime = require "socket".gettime;
local setmetatable = setmetatable;
local floor = math.floor;
-module "throttle"
+local _ENV = nil;
local throttle = {};
local throttle_mt = { __index = throttle };
@@ -39,8 +39,10 @@ function throttle:poll(cost, split)
end
end
-function create(max, period)
+local function create(max, period)
return setmetatable({ rate = max / period, max = max, t = 0, balance = max }, throttle_mt);
end
-return _M;
+return {
+ create = create;
+};
diff --git a/util/time.lua b/util/time.lua
new file mode 100644
index 00000000..84cff877
--- /dev/null
+++ b/util/time.lua
@@ -0,0 +1,8 @@
+-- Import gettime() from LuaSocket, as a way to access high-resolution time
+-- in a platform-independent way
+
+local socket_gettime = require "socket".gettime;
+
+return {
+ now = socket_gettime;
+}
diff --git a/util/timer.lua b/util/timer.lua
index af1e57b6..3713625d 100644
--- a/util/timer.lua
+++ b/util/timer.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -17,7 +17,7 @@ local type = type;
local data = {};
local new_data = {};
-module "timer"
+local _ENV = nil;
local _add_task;
if not server.event then
@@ -42,7 +42,7 @@ if not server.event then
end
new_data = {};
end
-
+
local next_time = math_huge;
for i, d in pairs(data) do
local t, callback = d[1], d[2];
@@ -78,6 +78,6 @@ else
end
end
-add_task = _add_task;
-
-return _M;
+return {
+ add_task = _add_task;
+};
diff --git a/util/uuid.lua b/util/uuid.lua
index 3576be8f..f4fd21f6 100644
--- a/util/uuid.lua
+++ b/util/uuid.lua
@@ -1,36 +1,32 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
-local error = error;
-local round_up = math.ceil;
-local urandom, urandom_err = io.open("/dev/urandom", "r");
-
-module "uuid"
+local random = require "util.random";
+local random_bytes = random.bytes;
+local hex = require "util.hex".to;
+local m_ceil = math.ceil;
local function get_nibbles(n)
- local binary_random = urandom:read(round_up(n/2));
- local hex_random = binary_random:gsub(".",
- function (x) return ("%02x"):format(x:byte()) end);
- return hex_random:sub(1, n);
+ return hex(random_bytes(m_ceil(n/2))):sub(1, n);
end
+
local function get_twobits()
- return ("%x"):format(urandom:read(1):byte() % 4 + 8);
+ return ("%x"):format(random_bytes(1):byte() % 4 + 8);
end
-function generate()
- if not urandom then
- error("Unable to obtain a secure random number generator, please see https://prosody.im/doc/random ("..urandom_err..")");
- end
+local function generate()
-- generate RFC 4122 complaint UUIDs (version 4 - random)
return get_nibbles(8).."-"..get_nibbles(4).."-4"..get_nibbles(3).."-"..(get_twobits())..get_nibbles(3).."-"..get_nibbles(12);
end
-function seed()
-end
-
-return _M;
+return {
+ get_nibbles=get_nibbles;
+ generate = generate ;
+ -- COMPAT
+ seed = random.seed;
+};
diff --git a/util/watchdog.lua b/util/watchdog.lua
index bcb2e274..aa8c6486 100644
--- a/util/watchdog.lua
+++ b/util/watchdog.lua
@@ -2,12 +2,12 @@ local timer = require "util.timer";
local setmetatable = setmetatable;
local os_time = os.time;
-module "watchdog"
+local _ENV = nil;
local watchdog_methods = {};
local watchdog_mt = { __index = watchdog_methods };
-function new(timeout, callback)
+local function new(timeout, callback)
local watchdog = setmetatable({ timeout = timeout, last_reset = os_time(), callback = callback }, watchdog_mt);
timer.add_task(timeout+1, function (current_time)
local last_reset = watchdog.last_reset;
@@ -31,4 +31,6 @@ function watchdog_methods:cancel()
self.last_reset = nil;
end
-return _M;
+return {
+ new = new;
+};
diff --git a/util/x509.lua b/util/x509.lua
index 19d4ec6d..f228b201 100644
--- a/util/x509.lua
+++ b/util/x509.lua
@@ -20,13 +20,11 @@
local nameprep = require "util.encodings".stringprep.nameprep;
local idna_to_ascii = require "util.encodings".idna.to_ascii;
+local base64 = require "util.encodings".base64;
local log = require "util.logger".init("x509");
-local pairs, ipairs = pairs, ipairs;
local s_format = string.format;
-local t_insert = table.insert;
-local t_concat = table.concat;
-module "x509"
+local _ENV = nil;
local oid_commonname = "2.5.4.3"; -- [LDAP] 2.3
local oid_subjectaltname = "2.5.29.17"; -- [PKIX] 4.2.1.6
@@ -149,7 +147,10 @@ local function compare_srvname(host, service, asserted_names)
return false
end
-function verify_identity(host, service, cert)
+local function verify_identity(host, service, cert)
+ if cert.setencode then
+ cert:setencode("utf8");
+ end
local ext = cert:extensions()
if ext[oid_subjectaltname] then
local sans = ext[oid_subjectaltname];
@@ -161,7 +162,9 @@ function verify_identity(host, service, cert)
if sans[oid_xmppaddr] then
had_supported_altnames = true
- if compare_xmppaddr(host, sans[oid_xmppaddr]) then return true end
+ if service == "_xmpp-client" or service == "_xmpp-server" then
+ if compare_xmppaddr(host, sans[oid_xmppaddr]) then return true end
+ end
end
if sans[oid_dnssrv] then
@@ -212,4 +215,27 @@ function verify_identity(host, service, cert)
return false
end
-return _M;
+local pat = "%-%-%-%-%-BEGIN ([A-Z ]+)%-%-%-%-%-\r?\n"..
+"([0-9A-Za-z+/=\r\n]*)\r?\n%-%-%-%-%-END %1%-%-%-%-%-";
+
+local function pem2der(pem)
+ local typ, data = pem:match(pat);
+ if typ and data then
+ return base64.decode(data), typ;
+ end
+end
+
+local wrap = ('.'):rep(64);
+local envelope = "-----BEGIN %s-----\n%s\n-----END %s-----\n"
+
+local function der2pem(data, typ)
+ typ = typ and typ:upper() or "CERTIFICATE";
+ data = base64.encode(data);
+ return s_format(envelope, typ, data:gsub(wrap, '%0\n', (#data-1)/64), typ);
+end
+
+return {
+ verify_identity = verify_identity;
+ pem2der = pem2der;
+ der2pem = der2pem;
+};
diff --git a/util/xml.lua b/util/xml.lua
index 076490fa..733d821a 100644
--- a/util/xml.lua
+++ b/util/xml.lua
@@ -2,7 +2,7 @@
local st = require "util.stanza";
local lxp = require "lxp";
-module("xml")
+local _ENV = nil;
local parse_xml = (function()
local ns_prefixes = {
@@ -11,6 +11,7 @@ local parse_xml = (function()
local ns_separator = "\1";
local ns_pattern = "^([^"..ns_separator.."]*)"..ns_separator.."?(.*)$";
return function(xml)
+ --luacheck: ignore 212/self
local handler = {};
local stanza = st.stanza("root");
function handler:StartElement(tagname, attr)
@@ -26,8 +27,8 @@ local parse_xml = (function()
attr[i] = nil;
local ns, nm = k:match(ns_pattern);
if nm ~= "" then
- ns = ns_prefixes[ns];
- if ns then
+ ns = ns_prefixes[ns];
+ if ns then
attr[ns..":"..nm] = attr[k];
attr[k] = nil;
end
@@ -38,7 +39,7 @@ local parse_xml = (function()
function handler:CharacterData(data)
stanza:text(data);
end
- function handler:EndElement(tagname)
+ function handler:EndElement()
stanza:up();
end
local parser = lxp.new(handler, "\1");
@@ -53,5 +54,6 @@ local parse_xml = (function()
end;
end)();
-parse = parse_xml;
-return _M;
+return {
+ parse = parse_xml;
+};
diff --git a/util/xmppstream.lua b/util/xmppstream.lua
index 138c86b7..7be63285 100644
--- a/util/xmppstream.lua
+++ b/util/xmppstream.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -24,7 +24,7 @@ local lxp_supports_bytecount = not not lxp.new({}).getcurrentbytecount;
local default_stanza_size_limit = 1024*1024*10; -- 10MB
-module "xmppstream"
+local _ENV = nil;
local new_parser = lxp.new;
@@ -40,29 +40,26 @@ local xmlns_streams = "http://etherx.jabber.org/streams";
local ns_separator = "\1";
local ns_pattern = "^([^"..ns_separator.."]*)"..ns_separator.."?(.*)$";
-_M.ns_separator = ns_separator;
-_M.ns_pattern = ns_pattern;
-
local function dummy_cb() end
-function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
+local function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
local xml_handlers = {};
-
+
local cb_streamopened = stream_callbacks.streamopened;
local cb_streamclosed = stream_callbacks.streamclosed;
local cb_error = stream_callbacks.error or function(session, e, stanza) error("XML stream error: "..tostring(e)..(stanza and ": "..tostring(stanza) or ""),2); end;
local cb_handlestanza = stream_callbacks.handlestanza;
cb_handleprogress = cb_handleprogress or dummy_cb;
-
+
local stream_ns = stream_callbacks.stream_ns or xmlns_streams;
local stream_tag = stream_callbacks.stream_tag or "stream";
if stream_ns ~= "" then
stream_tag = stream_ns..ns_separator..stream_tag;
end
local stream_error_tag = stream_ns..ns_separator..(stream_callbacks.error_tag or "error");
-
+
local stream_default_ns = stream_callbacks.default_ns;
-
+
local stack = {};
local chardata, stanza = {};
local stanza_size = 0;
@@ -82,7 +79,7 @@ function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
attr.xmlns = curr_ns;
non_streamns_depth = non_streamns_depth + 1;
end
-
+
for i=1,#attr do
local k = attr[i];
attr[i] = nil;
@@ -92,7 +89,7 @@ function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
attr[k] = nil;
end
end
-
+
if not stanza then --if we are not currently inside a stanza
if lxp_supports_bytecount then
stanza_size = self:getcurrentbytecount();
@@ -116,7 +113,7 @@ function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
if curr_ns == "jabber:client" and name ~= "iq" and name ~= "presence" and name ~= "message" then
cb_error(session, "invalid-top-level-element");
end
-
+
stanza = setmetatable({ name = name, attr = attr, tags = {} }, stanza_mt);
else -- we are inside a stanza, so add a tag
if lxp_supports_bytecount then
@@ -205,26 +202,26 @@ function new_sax_handlers(session, stream_callbacks, cb_handleprogress)
error("Failed to abort parsing");
end
end
-
+
if lxp_supports_doctype then
xml_handlers.StartDoctypeDecl = restricted_handler;
end
xml_handlers.Comment = restricted_handler;
xml_handlers.ProcessingInstruction = restricted_handler;
-
+
local function reset()
stanza, chardata, stanza_size = nil, {}, 0;
stack = {};
end
-
+
local function set_session(stream, new_session)
session = new_session;
end
-
+
return xml_handlers, { reset = reset, set_session = set_session };
end
-function new(session, stream_callbacks, stanza_size_limit)
+local function new(session, stream_callbacks, stanza_size_limit)
-- Used to track parser progress (e.g. to enforce size limits)
local n_outstanding_bytes = 0;
local handle_progress;
@@ -241,6 +238,25 @@ function new(session, stream_callbacks, stanza_size_limit)
local parser = new_parser(handlers, ns_separator, false);
local parse = parser.parse;
+ function session.open_stream(session, from, to)
+ local send = session.sends2s or session.send;
+
+ local attr = {
+ ["xmlns:stream"] = "http://etherx.jabber.org/streams",
+ ["xml:lang"] = "en",
+ xmlns = stream_callbacks.default_ns,
+ version = session.version and (session.version > 0 and "1.0" or nil),
+ id = session.streamid,
+ from = from or session.host, to = to,
+ };
+ if session.stream_attrs then
+ session:stream_attrs(from, to, attr)
+ end
+ send("<?xml version='1.0'?>");
+ send(st.stanza("stream:stream", attr):top_tag());
+ return true;
+ end
+
return {
reset = function ()
parser = new_parser(handlers, ns_separator, false);
@@ -262,4 +278,9 @@ function new(session, stream_callbacks, stanza_size_limit)
};
end
-return _M;
+return {
+ ns_separator = ns_separator;
+ ns_pattern = ns_pattern;
+ new_sax_handlers = new_sax_handlers;
+ new = new;
+};