aboutsummaryrefslogtreecommitdiffstats
path: root/util
diff options
context:
space:
mode:
Diffstat (limited to 'util')
-rw-r--r--util/array.lua35
-rw-r--r--util/async.lua158
-rw-r--r--util/caps.lua2
-rw-r--r--util/dataforms.lua8
-rw-r--r--util/datetime.lua2
-rw-r--r--util/debug.lua38
-rw-r--r--util/dependencies.lua20
-rw-r--r--util/events.lua8
-rw-r--r--util/filters.lua16
-rw-r--r--util/helpers.lua2
-rw-r--r--util/hmac.lua2
-rw-r--r--util/import.lua2
-rw-r--r--util/ip.lua54
-rw-r--r--util/iterators.lua46
-rw-r--r--util/jid.lua2
-rw-r--r--util/json.lua4
-rw-r--r--util/logger.lua2
-rw-r--r--util/multitable.lua2
-rw-r--r--util/pluginloader.lua2
-rw-r--r--util/prosodyctl.lua30
-rw-r--r--util/pubsub.lua47
-rw-r--r--util/sasl.lua47
-rw-r--r--util/sasl/external.lua25
-rw-r--r--util/sasl/scram.lua145
-rw-r--r--util/sasl_cyrus.lua4
-rw-r--r--util/serialization.lua2
-rw-r--r--util/set.lua38
-rw-r--r--util/sql.lua52
-rw-r--r--util/stanza.lua16
-rw-r--r--util/termcolours.lua2
-rw-r--r--util/timer.lua4
-rw-r--r--util/uuid.lua2
-rw-r--r--util/x509.lua4
-rw-r--r--util/xml.lua4
-rw-r--r--util/xmppstream.lua24
35 files changed, 630 insertions, 221 deletions
diff --git a/util/array.lua b/util/array.lua
index 2d58e7fb..6f2abe04 100644
--- a/util/array.lua
+++ b/util/array.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -11,6 +11,7 @@ local t_insert, t_sort, t_remove, t_concat
local setmetatable = setmetatable;
local math_random = math.random;
+local math_floor = math.floor;
local pairs, ipairs = pairs, ipairs;
local tostring = tostring;
@@ -59,13 +60,13 @@ function array_base.filter(outa, ina, func)
write = write + 1;
end
end
-
+
if inplace and write <= start_length then
for i=write,start_length do
outa[i] = nil;
end
end
-
+
return outa;
end
@@ -84,6 +85,25 @@ function array_base.pluck(outa, ina, key)
return outa;
end
+function array_base.reverse(outa, ina)
+ local len = #ina;
+ if ina == outa then
+ local middle = math_floor(len/2);
+ len = len + 1;
+ local o; -- opposite
+ for i = 1, middle do
+ o = len - i;
+ outa[i], outa[o] = outa[o], outa[i];
+ end
+ else
+ local off = len + 1;
+ for i = 1, len do
+ outa[i] = ina[off - i];
+ end
+ end
+ return outa;
+end
+
--- These methods only mutate the array
function array_methods:shuffle(outa, ina)
local len = #self;
@@ -94,15 +114,6 @@ function array_methods:shuffle(outa, ina)
return self;
end
-function array_methods:reverse()
- local len = #self-1;
- for i=len,1,-1 do
- self:push(self[i]);
- self:pop(i);
- end
- return self;
-end
-
function array_methods:append(array)
local len,len2 = #self, #array;
for i=1,len2 do
diff --git a/util/async.lua b/util/async.lua
new file mode 100644
index 00000000..968ec804
--- /dev/null
+++ b/util/async.lua
@@ -0,0 +1,158 @@
+local log = require "util.logger".init("util.async");
+
+local function runner_continue(thread)
+ -- ASSUMPTION: runner is in 'waiting' state (but we don't have the runner to know for sure)
+ if coroutine.status(thread) ~= "suspended" then -- This should suffice
+ return false;
+ end
+ local ok, state, runner = coroutine.resume(thread);
+ if not ok then
+ local level = 0;
+ while debug.getinfo(thread, level, "") do level = level + 1; end
+ ok, runner = debug.getlocal(thread, level-1, 1);
+ local error_handler = runner.watchers.error;
+ if error_handler then error_handler(runner, debug.traceback(thread, state)); end
+ elseif state == "ready" then
+ -- If state is 'ready', it is our responsibility to update runner.state from 'waiting'.
+ -- We also have to :run(), because the queue might have further items that will not be
+ -- processed otherwise. FIXME: It's probably best to do this in a nexttick (0 timer).
+ runner.state = "ready";
+ runner:run();
+ end
+ return true;
+end
+
+local function waiter(num)
+ local thread = coroutine.running();
+ if not thread then
+ error("Not running in an async context, see http://prosody.im/doc/developers/async");
+ end
+ num = num or 1;
+ local waiting;
+ return function ()
+ if num == 0 then return; end -- already done
+ waiting = true;
+ coroutine.yield("wait");
+ end, function ()
+ num = num - 1;
+ if num == 0 and waiting then
+ runner_continue(thread);
+ elseif num < 0 then
+ error("done() called too many times");
+ end
+ end;
+end
+
+local function guarder()
+ local guards = {};
+ return function (id, func)
+ local thread = coroutine.running();
+ if not thread then
+ error("Not running in an async context, see http://prosody.im/doc/developers/async");
+ end
+ local guard = guards[id];
+ if not guard then
+ guard = {};
+ guards[id] = guard;
+ log("debug", "New guard!");
+ else
+ table.insert(guard, thread);
+ log("debug", "Guarded. %d threads waiting.", #guard)
+ coroutine.yield("wait");
+ end
+ local function exit()
+ local next_waiting = table.remove(guard, 1);
+ if next_waiting then
+ log("debug", "guard: Executing next waiting thread (%d left)", #guard)
+ runner_continue(next_waiting);
+ else
+ log("debug", "Guard off duty.")
+ guards[id] = nil;
+ end
+ end
+ if func then
+ func();
+ exit();
+ return;
+ end
+ return exit;
+ end;
+end
+
+local runner_mt = {};
+runner_mt.__index = runner_mt;
+
+local function runner_create_thread(func, self)
+ local thread = coroutine.create(function (self)
+ while true do
+ func(coroutine.yield("ready", self));
+ end
+ end);
+ assert(coroutine.resume(thread, self)); -- Start it up, it will return instantly to wait for the first input
+ return thread;
+end
+
+local empty_watchers = {};
+local function runner(func, watchers, data)
+ return setmetatable({ func = func, thread = false, state = "ready", notified_state = "ready",
+ queue = {}, watchers = watchers or empty_watchers, data = data }
+ , runner_mt);
+end
+
+function runner_mt:run(input)
+ if input ~= nil then
+ table.insert(self.queue, input);
+ end
+ if self.state ~= "ready" then
+ return true, self.state, #self.queue;
+ end
+
+ local q, thread = self.queue, self.thread;
+ if not thread or coroutine.status(thread) == "dead" then
+ thread = runner_create_thread(self.func, self);
+ self.thread = thread;
+ end
+
+ local n, state, err = #q, self.state, nil;
+ self.state = "running";
+ while n > 0 and state == "ready" do
+ local consumed;
+ for i = 1,n do
+ local input = q[i];
+ local ok, new_state = coroutine.resume(thread, input);
+ if not ok then
+ consumed, state, err = i, "ready", debug.traceback(thread, new_state);
+ self.thread = nil;
+ break;
+ elseif new_state == "wait" then
+ consumed, state = i, "waiting";
+ break;
+ end
+ end
+ if not consumed then consumed = n; end
+ if q[n+1] ~= nil then
+ n = #q;
+ end
+ for i = 1, n do
+ q[i] = q[consumed+i];
+ end
+ n = #q;
+ end
+ self.state = state;
+ if err or state ~= self.notified_state then
+ if err then
+ state = "error"
+ else
+ self.notified_state = state;
+ end
+ local handler = self.watchers[state];
+ if handler then handler(self, err); end
+ end
+ return true, state, n;
+end
+
+function runner_mt:enqueue(input)
+ table.insert(self.queue, input);
+end
+
+return { waiter = waiter, guarder = guarder, runner = runner };
diff --git a/util/caps.lua b/util/caps.lua
index a61e7403..4723b912 100644
--- a/util/caps.lua
+++ b/util/caps.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
diff --git a/util/dataforms.lua b/util/dataforms.lua
index 52924841..b38d0e27 100644
--- a/util/dataforms.lua
+++ b/util/dataforms.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -38,7 +38,7 @@ function form_t.form(layout, data, formtype)
form:tag("field", { type = field_type, var = field.name, label = field.label });
local value = (data and data[field.name]) or field.value;
-
+
if value then
-- Add value, depending on type
if field_type == "hidden" then
@@ -93,11 +93,11 @@ function form_t.form(layout, data, formtype)
end
end
end
-
+
if field.required then
form:tag("required"):up();
end
-
+
-- Jump back up to list of fields
form:up();
end
diff --git a/util/datetime.lua b/util/datetime.lua
index a1f62a48..dd596527 100644
--- a/util/datetime.lua
+++ b/util/datetime.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
diff --git a/util/debug.lua b/util/debug.lua
index bff0e347..91f691e1 100644
--- a/util/debug.lua
+++ b/util/debug.lua
@@ -24,11 +24,15 @@ do
end
module("debugx", package.seeall);
-function get_locals_table(level)
- level = level + 1; -- Skip this function itself
+function get_locals_table(thread, level)
local locals = {};
for local_num = 1, math.huge do
- local name, value = debug.getlocal(level, local_num);
+ local name, value;
+ if thread then
+ name, value = debug.getlocal(thread, level, local_num);
+ else
+ name, value = debug.getlocal(level+1, local_num);
+ end
if not name then break; end
table.insert(locals, { name = name, value = value });
end
@@ -88,19 +92,19 @@ function get_traceback_table(thread, start_level)
for level = start_level, math.huge do
local info;
if thread then
- info = debug.getinfo(thread, level+1);
+ info = debug.getinfo(thread, level);
else
info = debug.getinfo(level+1);
end
if not info then break; end
-
+
levels[(level-start_level)+1] = {
level = level;
info = info;
- locals = get_locals_table(level+1);
+ locals = get_locals_table(thread, level+(thread and 0 or 1));
upvalues = get_upvalues_table(info.func);
};
- end
+ end
return levels;
end
@@ -134,15 +138,15 @@ function _traceback(thread, message, level)
return nil; -- debug.traceback() does this
end
- level = level or 1;
+ level = level or 0;
message = message and (message.."\n") or "";
-
- -- +3 counts for this function, and the pcall() and wrapper above us
- local levels = get_traceback_table(thread, level+3);
-
+
+ -- +3 counts for this function, and the pcall() and wrapper above us, the +1... I don't know.
+ local levels = get_traceback_table(thread, level+(thread == nil and 4 or 0));
+
local last_source_desc;
-
+
local lines = {};
for nlevel, level in ipairs(levels) do
local info = level.info;
@@ -171,9 +175,11 @@ function _traceback(thread, message, level)
nlevel = nlevel-1;
table.insert(lines, "\t"..(nlevel==0 and ">" or " ")..getstring(styles.level_num, "("..nlevel..") ")..line);
local npadding = (" "):rep(#tostring(nlevel));
- local locals_str = string_from_var_table(level.locals, optimal_line_length, "\t "..npadding);
- if locals_str then
- table.insert(lines, "\t "..npadding.."Locals: "..locals_str);
+ if level.locals then
+ local locals_str = string_from_var_table(level.locals, optimal_line_length, "\t "..npadding);
+ if locals_str then
+ table.insert(lines, "\t "..npadding.."Locals: "..locals_str);
+ end
end
local upvalues_str = string_from_var_table(level.upvalues, optimal_line_length, "\t "..npadding);
if upvalues_str then
diff --git a/util/dependencies.lua b/util/dependencies.lua
index 53d2719d..109a3332 100644
--- a/util/dependencies.lua
+++ b/util/dependencies.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -35,7 +35,7 @@ function missingdep(name, sources, msg)
print("");
end
--- COMPAT w/pre-0.8 Debian: The Debian config file used to use
+-- COMPAT w/pre-0.8 Debian: The Debian config file used to use
-- util.ztact, which has been removed from Prosody in 0.8. This
-- is to log an error for people who still use it, so they can
-- update their configs.
@@ -50,9 +50,9 @@ end;
function check_dependencies()
local fatal;
-
+
local lxp = softreq "lxp"
-
+
if not lxp then
missingdep("luaexpat", {
["Debian/Ubuntu"] = "sudo apt-get install liblua5.1-expat0";
@@ -61,9 +61,9 @@ function check_dependencies()
});
fatal = true;
end
-
+
local socket = softreq "socket"
-
+
if not socket then
missingdep("luasocket", {
["Debian/Ubuntu"] = "sudo apt-get install liblua5.1-socket2";
@@ -72,7 +72,7 @@ function check_dependencies()
});
fatal = true;
end
-
+
local lfs, err = softreq "lfs"
if not lfs then
missingdep("luafilesystem", {
@@ -82,9 +82,9 @@ function check_dependencies()
});
fatal = true;
end
-
+
local ssl = softreq "ssl"
-
+
if not ssl then
missingdep("LuaSec", {
["Debian/Ubuntu"] = "http://prosody.im/download/start#debian_and_ubuntu";
@@ -92,7 +92,7 @@ function check_dependencies()
["Source"] = "http://www.inf.puc-rio.br/~brunoos/luasec/";
}, "SSL/TLS support will not be available");
end
-
+
local encodings, err = softreq "util.encodings"
if not encodings then
if err:match("not found") then
diff --git a/util/events.lua b/util/events.lua
index 412acccd..40ca3913 100644
--- a/util/events.lua
+++ b/util/events.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -60,11 +60,11 @@ function new()
remove_handler(event, handler);
end
end;
- local function fire_event(event, ...)
- local h = handlers[event];
+ local function fire_event(event_name, event_data)
+ local h = handlers[event_name];
if h then
for i=1,#h do
- local ret = h[i](...);
+ local ret = h[i](event_data);
if ret ~= nil then return ret; end
end
end
diff --git a/util/filters.lua b/util/filters.lua
index d24bd33e..c2bdca07 100644
--- a/util/filters.lua
+++ b/util/filters.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -16,7 +16,7 @@ function initialize(session)
if not session.filters then
local filters = {};
session.filters = filters;
-
+
function session.filter(type, data)
local filter_list = filters[type];
if filter_list then
@@ -28,11 +28,11 @@ function initialize(session)
return data;
end
end
-
+
for i=1,#new_filter_hooks do
new_filter_hooks[i](session);
end
-
+
return session.filter;
end
@@ -40,20 +40,20 @@ function add_filter(session, type, callback, priority)
if not session.filters then
initialize(session);
end
-
+
local filter_list = session.filters[type];
if not filter_list then
filter_list = {};
session.filters[type] = filter_list;
end
-
+
priority = priority or 0;
-
+
local i = 0;
repeat
i = i + 1;
until not filter_list[i] or filter_list[filter_list[i]] < priority;
-
+
t_insert(filter_list, i, callback);
filter_list[callback] = priority;
end
diff --git a/util/helpers.lua b/util/helpers.lua
index 08b86a7c..437a920c 100644
--- a/util/helpers.lua
+++ b/util/helpers.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
diff --git a/util/hmac.lua b/util/hmac.lua
index 51211c7a..2c4cc6ef 100644
--- a/util/hmac.lua
+++ b/util/hmac.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
diff --git a/util/import.lua b/util/import.lua
index 81401e8b..174da0ca 100644
--- a/util/import.lua
+++ b/util/import.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
diff --git a/util/ip.lua b/util/ip.lua
index 856bf034..d0ae07eb 100644
--- a/util/ip.lua
+++ b/util/ip.lua
@@ -12,7 +12,17 @@ local ip_mt = { __index = function (ip, key) return (ip_methods[key])(ip); end,
local hex2bits = { ["0"] = "0000", ["1"] = "0001", ["2"] = "0010", ["3"] = "0011", ["4"] = "0100", ["5"] = "0101", ["6"] = "0110", ["7"] = "0111", ["8"] = "1000", ["9"] = "1001", ["A"] = "1010", ["B"] = "1011", ["C"] = "1100", ["D"] = "1101", ["E"] = "1110", ["F"] = "1111" };
local function new_ip(ipStr, proto)
- if proto ~= "IPv4" and proto ~= "IPv6" then
+ if not proto then
+ local sep = ipStr:match("^%x+(.)");
+ if sep == ":" or (not(sep) and ipStr:sub(1,1) == ":") then
+ proto = "IPv6"
+ elseif sep == "." then
+ proto = "IPv4"
+ end
+ if not proto then
+ return nil, "invalid address";
+ end
+ elseif proto ~= "IPv4" and proto ~= "IPv6" then
return nil, "invalid protocol";
end
if proto == "IPv6" and ipStr:find('.', 1, true) then
@@ -82,7 +92,7 @@ local function v6scope(ip)
if ip:match("^[0:]*1$") then
return 0x2;
-- Link-local unicast:
- elseif ip:match("^[Ff][Ee][89ABab]") then
+ elseif ip:match("^[Ff][Ee][89ABab]") then
return 0x2;
-- Site-local unicast:
elseif ip:match("^[Ff][Ee][CcDdEeFf]") then
@@ -192,5 +202,43 @@ function ip_methods:scope()
return value;
end
+function ip_methods:private()
+ local private = self.scope ~= 0xE;
+ if not private and self.proto == "IPv4" then
+ local ip = self.addr;
+ local fields = {};
+ ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end);
+ if fields[1] == 127 or fields[1] == 10 or (fields[1] == 192 and fields[2] == 168)
+ or (fields[1] == 172 and (fields[2] >= 16 or fields[2] <= 32)) then
+ private = true;
+ end
+ end
+ self.private = private;
+ return private;
+end
+
+local function parse_cidr(cidr)
+ local bits;
+ local ip_len = cidr:find("/", 1, true);
+ if ip_len then
+ bits = tonumber(cidr:sub(ip_len+1, -1));
+ cidr = cidr:sub(1, ip_len-1);
+ end
+ return new_ip(cidr), bits;
+end
+
+local function match(ipA, ipB, bits)
+ local common_bits = commonPrefixLength(ipA, ipB);
+ if not bits then
+ return ipA == ipB;
+ end
+ if bits and ipB.proto == "IPv4" then
+ common_bits = common_bits - 96; -- v6 mapped addresses always share these bits
+ end
+ return common_bits >= bits;
+end
+
return {new_ip = new_ip,
- commonPrefixLength = commonPrefixLength};
+ commonPrefixLength = commonPrefixLength,
+ parse_cidr = parse_cidr,
+ match=match};
diff --git a/util/iterators.lua b/util/iterators.lua
index 1f6aacb8..aa9c3ec0 100644
--- a/util/iterators.lua
+++ b/util/iterators.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -10,6 +10,10 @@
local it = {};
+local t_insert = table.insert;
+local select, unpack, next = select, unpack, next;
+local function pack(...) return { n = select("#", ...), ... }; end
+
-- Reverse an iterator
function it.reverse(f, s, var)
local results = {};
@@ -19,9 +23,9 @@ function it.reverse(f, s, var)
local ret = { f(s, var) };
var = ret[1];
if var == nil then break; end
- table.insert(results, 1, ret);
+ t_insert(results, 1, ret);
end
-
+
-- Then return our reverse one
local i,max = 0, #results;
return function (results)
@@ -52,15 +56,15 @@ end
-- Given an iterator, iterate only over unique items
function it.unique(f, s, var)
local set = {};
-
+
return function ()
while true do
- local ret = { f(s, var) };
+ local ret = pack(f(s, var));
var = ret[1];
if var == nil then break; end
if not set[var] then
set[var] = true;
- return var;
+ return unpack(ret, 1, ret.n);
end
end
end;
@@ -69,14 +73,13 @@ end
--[[ Return the number of items an iterator returns ]]--
function it.count(f, s, var)
local x = 0;
-
+
while true do
- local ret = { f(s, var) };
- var = ret[1];
+ var = f(s, var);
if var == nil then break; end
x = x + 1;
end
-
+
return x;
end
@@ -104,7 +107,7 @@ end
function it.tail(n, f, s, var)
local results, count = {}, 0;
while true do
- local ret = { f(s, var) };
+ local ret = pack(f(s, var));
var = ret[1];
if var == nil then break; end
results[(count%n)+1] = ret;
@@ -117,9 +120,24 @@ function it.tail(n, f, s, var)
return function ()
pos = pos + 1;
if pos > n then return nil; end
- return unpack(results[((count-1+pos)%n)+1]);
+ local ret = results[((count-1+pos)%n)+1];
+ return unpack(ret, 1, ret.n);
end
- --return reverse(head(n, reverse(f, s, var)));
+ --return reverse(head(n, reverse(f, s, var))); -- !
+end
+
+function it.filter(filter, f, s, var)
+ if type(filter) ~= "function" then
+ local filter_value = filter;
+ function filter(x) return x ~= filter_value; end
+ end
+ return function (s, var)
+ local ret;
+ repeat ret = pack(f(s, var));
+ var = ret[1];
+ until var == nil or filter(unpack(ret, 1, ret.n));
+ return unpack(ret, 1, ret.n);
+ end, s, var;
end
local function _ripairs_iter(t, key) if key > 1 then return key-1, t[key-1]; end end
@@ -139,7 +157,7 @@ function it.to_array(f, s, var)
while true do
var = f(s, var);
if var == nil then break; end
- table.insert(t, var);
+ t_insert(t, var);
end
return t;
end
diff --git a/util/jid.lua b/util/jid.lua
index 8e0a784c..08e63335 100644
--- a/util/jid.lua
+++ b/util/jid.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
diff --git a/util/json.lua b/util/json.lua
index 82ebcc43..a8a58afc 100644
--- a/util/json.lua
+++ b/util/json.lua
@@ -348,9 +348,9 @@ local first_escape = {
function json.decode(json)
json = json:gsub("\\.", first_escape) -- get rid of all escapes except \uXXXX, making string parsing much simpler
--:gsub("[\r\n]", "\t"); -- \r\n\t are equivalent, we care about none of them, and none of them can be in strings
-
+
-- TODO do encoding verification
-
+
local val, index = _readvalue(json, 1);
if val == nil then return val, index; end
if json:find("[^ \t\r\n]", index) then return nil, "garbage at eof"; end
diff --git a/util/logger.lua b/util/logger.lua
index 26206d4d..cd0769f9 100644
--- a/util/logger.lua
+++ b/util/logger.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
diff --git a/util/multitable.lua b/util/multitable.lua
index dbf34d28..caf25118 100644
--- a/util/multitable.lua
+++ b/util/multitable.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
diff --git a/util/pluginloader.lua b/util/pluginloader.lua
index 112c0d52..b9b3e207 100644
--- a/util/pluginloader.lua
+++ b/util/pluginloader.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
diff --git a/util/prosodyctl.lua b/util/prosodyctl.lua
index b80a69f2..fe862114 100644
--- a/util/prosodyctl.lua
+++ b/util/prosodyctl.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -146,7 +146,7 @@ function adduser(params)
if not(provider) or provider.name == "null" then
usermanager.initialize_host(host);
end
-
+
local ok, errmsg = usermanager.create_user(user, password, host);
if not ok then
return false, errmsg;
@@ -162,7 +162,7 @@ function user_exists(params)
if not(provider) or provider.name == "null" then
usermanager.initialize_host(host);
end
-
+
return usermanager.user_exists(user, host);
end
@@ -170,7 +170,7 @@ function passwd(params)
if not _M.user_exists(params) then
return false, "no-such-user";
end
-
+
return _M.adduser(params);
end
@@ -179,7 +179,7 @@ function deluser(params)
return false, "no-such-user";
end
local user, host = nodeprep(params.user), nameprep(params.host);
-
+
return usermanager.delete_user(user, host);
end
@@ -188,30 +188,30 @@ function getpid()
if not pidfile then
return false, "no-pidfile";
end
-
+
local modules_enabled = set.new(config.get("*", "modules_enabled"));
if not modules_enabled:contains("posix") then
return false, "no-posix";
end
-
+
local file, err = io.open(pidfile, "r+");
if not file then
return false, "pidfile-read-failed", err;
end
-
+
local locked, err = lfs.lock(file, "w");
if locked then
file:close();
return false, "pidfile-not-locked";
end
-
+
local pid = tonumber(file:read("*a"));
file:close();
-
+
if not pid then
return false, "invalid-pid";
end
-
+
return true, pid;
end
@@ -252,10 +252,10 @@ function stop()
if not ret then
return false, "not-running";
end
-
+
local ok, pid = _M.getpid()
if not ok then return false, pid; end
-
+
signal.kill(pid, signal.SIGTERM);
return true;
end
@@ -268,10 +268,10 @@ function reload()
if not ret then
return false, "not-running";
end
-
+
local ok, pid = _M.getpid()
if not ok then return false, pid; end
-
+
signal.kill(pid, signal.SIGHUP);
return true;
end
diff --git a/util/pubsub.lua b/util/pubsub.lua
index e1418c62..fc67cb1f 100644
--- a/util/pubsub.lua
+++ b/util/pubsub.lua
@@ -1,4 +1,5 @@
local events = require "util.events";
+local t_remove = table.remove;
module("pubsub", package.seeall);
@@ -18,6 +19,7 @@ function new(config)
affiliations = {};
subscriptions = {};
nodes = {};
+ data = {};
events = events.new();
}, service_mt);
end
@@ -29,13 +31,13 @@ end
function service:may(node, actor, action)
if actor == true then return true; end
-
+
local node_obj = self.nodes[node];
local node_aff = node_obj and node_obj.affiliations[actor];
local service_aff = self.affiliations[actor]
or self.config.get_affiliation(actor, node, action)
or "none";
-
+
-- Check if node allows/forbids it
local node_capabilities = node_obj and node_obj.capabilities;
if node_capabilities then
@@ -47,7 +49,7 @@ function service:may(node, actor, action)
end
end
end
-
+
-- Check service-wide capabilities instead
local service_capabilities = self.config.capabilities;
local caps = service_capabilities[node_aff or service_aff];
@@ -57,7 +59,7 @@ function service:may(node, actor, action)
return can;
end
end
-
+
return false;
end
@@ -211,17 +213,20 @@ function service:create(node, actor)
if self.nodes[node] then
return false, "conflict";
end
-
+
+ self.data[node] = {};
self.nodes[node] = {
name = node;
subscribers = {};
config = {};
- data = {};
affiliations = {};
};
+ setmetatable(self.nodes[node], { __index = { data = self.data[node] } }); -- COMPAT
+ self.events.fire_event("node-created", { node = node, actor = actor });
local ok, err = self:set_affiliation(node, true, actor, "owner");
if not ok then
self.nodes[node] = nil;
+ self.data[node] = nil;
end
return ok, err;
end
@@ -237,10 +242,23 @@ function service:delete(node, actor)
return false, "item-not-found";
end
self.nodes[node] = nil;
+ self.data[node] = nil;
+ self.events.fire_event("node-deleted", { node = node, actor = actor });
self.config.broadcaster("delete", node, node_obj.subscribers);
return true;
end
+local function remove_item_by_id(data, id)
+ if not data[id] then return end
+ data[id] = nil;
+ for i, _id in ipairs(data) do
+ if id == _id then
+ t_remove(data, i);
+ return i;
+ end
+ end
+end
+
function service:publish(node, actor, id, item)
-- Access checking
if not self:may(node, actor, "publish") then
@@ -258,7 +276,10 @@ function service:publish(node, actor, id, item)
end
node_obj = self.nodes[node];
end
- node_obj.data[id] = item;
+ local node_data = self.data[node];
+ remove_item_by_id(node_data, id);
+ node_data[#self.data[node] + 1] = id;
+ node_data[id] = item;
self.events.fire_event("item-published", { node = node, actor = actor, id = id, item = item });
self.config.broadcaster("items", node, node_obj.subscribers, item);
return true;
@@ -271,10 +292,11 @@ function service:retract(node, actor, id, retract)
end
--
local node_obj = self.nodes[node];
- if (not node_obj) or (not node_obj.data[id]) then
+ if (not node_obj) or (not self.data[node][id]) then
return false, "item-not-found";
end
- node_obj.data[id] = nil;
+ self.events.fire_event("item-retracted", { node = node, actor = actor, id = id });
+ remove_item_by_id(self.data[node], id);
if retract then
self.config.broadcaster("items", node, node_obj.subscribers, retract);
end
@@ -291,7 +313,8 @@ function service:purge(node, actor, notify)
if not node_obj then
return false, "item-not-found";
end
- node_obj.data = {}; -- Purge
+ self.data[node] = {}; -- Purge
+ self.events.fire_event("node-purged", { node = node, actor = actor });
if notify then
self.config.broadcaster("purge", node, node_obj.subscribers);
end
@@ -309,9 +332,9 @@ function service:get_items(node, actor, id)
return false, "item-not-found";
end
if id then -- Restrict results to a single specific item
- return true, { [id] = node_obj.data[id] };
+ return true, { id, [id] = self.data[node][id] };
else
- return true, node_obj.data;
+ return true, self.data[node];
end
end
diff --git a/util/sasl.lua b/util/sasl.lua
index afb3861b..c8490842 100644
--- a/util/sasl.lua
+++ b/util/sasl.lua
@@ -27,19 +27,38 @@ Authentication Backend Prototypes:
state = false : disabled
state = true : enabled
state = nil : non-existant
+
+Channel Binding:
+
+To enable support of channel binding in some mechanisms you need to provide appropriate callbacks in a table
+at profile.cb.
+
+Example:
+ profile.cb["tls-unique"] = function(self)
+ return self.user
+ end
+
]]
local method = {};
method.__index = method;
local mechanisms = {};
local backend_mechanism = {};
+local mechanism_channelbindings = {};
-- register a new SASL mechanims
-function registerMechanism(name, backends, f)
+function registerMechanism(name, backends, f, cb_backends)
assert(type(name) == "string", "Parameter name MUST be a string.");
assert(type(backends) == "string" or type(backends) == "table", "Parameter backends MUST be either a string or a table.");
assert(type(f) == "function", "Parameter f MUST be a function.");
+ if cb_backends then assert(type(cb_backends) == "table"); end
mechanisms[name] = f
+ if cb_backends then
+ mechanism_channelbindings[name] = {};
+ for _, cb_name in ipairs(cb_backends) do
+ mechanism_channelbindings[name][cb_name] = true;
+ end
+ end
for _, backend_name in ipairs(backends) do
if backend_mechanism[backend_name] == nil then backend_mechanism[backend_name] = {}; end
t_insert(backend_mechanism[backend_name], name);
@@ -63,6 +82,15 @@ function new(realm, profile)
return setmetatable({ profile = profile, realm = realm, mechs = mechanisms }, method);
end
+-- add a channel binding handler
+function method:add_cb_handler(name, f)
+ if type(self.profile.cb) ~= "table" then
+ self.profile.cb = {};
+ end
+ self.profile.cb[name] = f;
+ return self;
+end
+
-- get a fresh clone with the same realm and profile
function method:clean_clone()
return new(self.realm, self.profile)
@@ -70,7 +98,21 @@ end
-- get a list of possible SASL mechanims to use
function method:mechanisms()
- return self.mechs;
+ local current_mechs = {};
+ for mech, _ in pairs(self.mechs) do
+ if mechanism_channelbindings[mech] and self.profile.cb then
+ local ok = false;
+ for cb_name, _ in pairs(self.profile.cb) do
+ if mechanism_channelbindings[mech][cb_name] then
+ ok = true;
+ end
+ end
+ if ok == true then current_mechs[mech] = true; end
+ else
+ current_mechs[mech] = true;
+ end
+ end
+ return current_mechs;
end
-- select a mechanism to use
@@ -92,5 +134,6 @@ require "util.sasl.plain" .init(registerMechanism);
require "util.sasl.digest-md5".init(registerMechanism);
require "util.sasl.anonymous" .init(registerMechanism);
require "util.sasl.scram" .init(registerMechanism);
+require "util.sasl.external" .init(registerMechanism);
return _M;
diff --git a/util/sasl/external.lua b/util/sasl/external.lua
new file mode 100644
index 00000000..4c5c4343
--- /dev/null
+++ b/util/sasl/external.lua
@@ -0,0 +1,25 @@
+local saslprep = require "util.encodings".stringprep.saslprep;
+
+module "sasl.external"
+
+local function external(self, message)
+ message = saslprep(message);
+ local state
+ self.username, state = self.profile.external(message);
+
+ if state == false then
+ return "failure", "account-disabled";
+ elseif state == nil then
+ return "failure", "not-authorized";
+ elseif state == "expired" then
+ return "false", "credentials-expired";
+ end
+
+ return "success";
+end
+
+function init(registerMechanism)
+ registerMechanism("EXTERNAL", {"external"}, external);
+end
+
+return _M;
diff --git a/util/sasl/scram.lua b/util/sasl/scram.lua
index cf2f0ede..0d2852bf 100644
--- a/util/sasl/scram.lua
+++ b/util/sasl/scram.lua
@@ -13,7 +13,6 @@
local s_match = string.match;
local type = type
-local string = string
local base64 = require "util.encodings".base64;
local hmac_sha1 = require "util.hashes".hmac_sha1;
local sha1 = require "util.hashes".sha1;
@@ -39,18 +38,14 @@ scram_{MECH}:
function(username, realm)
return stored_key, server_key, iteration_count, salt, state;
end
+
+Supported Channel Binding Backends
+
+'tls-unique' according to RFC 5929
]]
local default_i = 4096
-local function bp( b )
- local result = ""
- for i=1, b:len() do
- result = result.."\\"..b:byte(i)
- end
- return result
-end
-
local xor_map = {0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;1;0;3;2;5;4;7;6;9;8;11;10;13;12;15;14;2;3;0;1;6;7;4;5;10;11;8;9;14;15;12;13;3;2;1;0;7;6;5;4;11;10;9;8;15;14;13;12;4;5;6;7;0;1;2;3;12;13;14;15;8;9;10;11;5;4;7;6;1;0;3;2;13;12;15;14;9;8;11;10;6;7;4;5;2;3;0;1;14;15;12;13;10;11;8;9;7;6;5;4;3;2;1;0;15;14;13;12;11;10;9;8;8;9;10;11;12;13;14;15;0;1;2;3;4;5;6;7;9;8;11;10;13;12;15;14;1;0;3;2;5;4;7;6;10;11;8;9;14;15;12;13;2;3;0;1;6;7;4;5;11;10;9;8;15;14;13;12;3;2;1;0;7;6;5;4;12;13;14;15;8;9;10;11;4;5;6;7;0;1;2;3;13;12;15;14;9;8;11;10;5;4;7;6;1;0;3;2;14;15;12;13;10;11;8;9;6;7;4;5;2;3;0;1;15;14;13;12;11;10;9;8;7;6;5;4;3;2;1;0;};
local result = {};
@@ -73,11 +68,11 @@ local function validate_username(username, _nodeprep)
return false
end
end
-
+
-- replace =2C with , and =3D with =
username = username:gsub("=2C", ",");
username = username:gsub("=3D", "=");
-
+
-- apply SASLprep
username = saslprep(username);
@@ -106,96 +101,131 @@ function getAuthenticationDatabaseSHA1(password, salt, iteration_count)
end
local function scram_gen(hash_name, H_f, HMAC_f)
+ local profile_name = "scram_" .. hashprep(hash_name);
local function scram_hash(self, message)
- if not self.state then self["state"] = {} end
-
+ local support_channel_binding = false;
+ if self.profile.cb then support_channel_binding = true; end
+
if type(message) ~= "string" or #message == 0 then return "failure", "malformed-request" end
- if not self.state.name then
+ local state = self.state;
+ if not state then
-- we are processing client_first_message
local client_first_message = message;
-
+
-- TODO: fail if authzid is provided, since we don't support them yet
- self.state["client_first_message"] = client_first_message;
- self.state["gs2_cbind_flag"], self.state["authzid"], self.state["name"], self.state["clientnonce"]
- = client_first_message:match("^(%a),(.*),n=(.*),r=([^,]*).*");
+ local gs2_header, gs2_cbind_flag, gs2_cbind_name, authzid, client_first_message_bare, username, clientnonce
+ = s_match(client_first_message, "^(([pny])=?([^,]*),([^,]*),)(m?=?[^,]*,?n=([^,]*),r=([^,]*),?.*)$");
- -- we don't do any channel binding yet
- if self.state.gs2_cbind_flag ~= "n" and self.state.gs2_cbind_flag ~= "y" then
+ if not gs2_cbind_flag then
return "failure", "malformed-request";
end
- if not self.state.name or not self.state.clientnonce then
- return "failure", "malformed-request", "Channel binding isn't support at this time.";
+ if support_channel_binding and gs2_cbind_flag == "y" then
+ -- "y" -> client does support channel binding
+ -- but thinks the server does not.
+ return "failure", "malformed-request";
+ end
+
+ if gs2_cbind_flag == "n" then
+ -- "n" -> client doesn't support channel binding.
+ support_channel_binding = false;
end
-
- self.state.name = validate_username(self.state.name, self.profile.nodeprep);
- if not self.state.name then
+
+ if support_channel_binding and gs2_cbind_flag == "p" then
+ -- check whether we support the proposed channel binding type
+ if not self.profile.cb[gs2_cbind_name] then
+ return "failure", "malformed-request", "Proposed channel binding type isn't supported.";
+ end
+ else
+ -- no channel binding,
+ gs2_cbind_name = nil;
+ end
+
+ username = validate_username(username, self.profile.nodeprep);
+ if not username then
log("debug", "Username violates either SASLprep or contains forbidden character sequences.")
return "failure", "malformed-request", "Invalid username.";
end
-
- self.state["servernonce"] = generate_uuid();
-
+
-- retreive credentials
+ local stored_key, server_key, salt, iteration_count;
if self.profile.plain then
- local password, state = self.profile.plain(self, self.state.name, self.realm)
+ local password, state = self.profile.plain(self, username, self.realm)
if state == nil then return "failure", "not-authorized"
elseif state == false then return "failure", "account-disabled" end
-
+
password = saslprep(password);
if not password then
log("debug", "Password violates SASLprep.");
return "failure", "not-authorized", "Invalid password."
end
- self.state.salt = generate_uuid();
- self.state.iteration_count = default_i;
+ salt = generate_uuid();
+ iteration_count = default_i;
local succ = false;
- succ, self.state.stored_key, self.state.server_key = getAuthenticationDatabaseSHA1(password, self.state.salt, default_i, self.state.iteration_count);
+ succ, stored_key, server_key = getAuthenticationDatabaseSHA1(password, salt, iteration_count);
if not succ then
- log("error", "Generating authentication database failed. Reason: %s", self.state.stored_key);
+ log("error", "Generating authentication database failed. Reason: %s", stored_key);
return "failure", "temporary-auth-failure";
end
- elseif self.profile["scram_"..hashprep(hash_name)] then
- local stored_key, server_key, iteration_count, salt, state = self.profile["scram_"..hashprep(hash_name)](self, self.state.name, self.realm);
+ elseif self.profile[profile_name] then
+ local state;
+ stored_key, server_key, iteration_count, salt, state = self.profile[profile_name](self, username, self.realm);
if state == nil then return "failure", "not-authorized"
elseif state == false then return "failure", "account-disabled" end
-
- self.state.stored_key = stored_key;
- self.state.server_key = server_key;
- self.state.iteration_count = iteration_count;
- self.state.salt = salt
end
-
- local server_first_message = "r="..self.state.clientnonce..self.state.servernonce..",s="..base64.encode(self.state.salt)..",i="..self.state.iteration_count;
- self.state["server_first_message"] = server_first_message;
+
+ local nonce = clientnonce .. generate_uuid();
+ local server_first_message = "r="..nonce..",s="..base64.encode(salt)..",i="..iteration_count;
+ self.state = {
+ gs2_header = gs2_header;
+ gs2_cbind_name = gs2_cbind_name;
+ username = username;
+ nonce = nonce;
+
+ server_key = server_key;
+ stored_key = stored_key;
+ client_first_message_bare = client_first_message_bare;
+ server_first_message = server_first_message;
+ }
return "challenge", server_first_message
else
-- we are processing client_final_message
local client_final_message = message;
-
- self.state["channelbinding"], self.state["nonce"], self.state["proof"] = client_final_message:match("^c=(.*),r=(.*),.*p=(.*)");
-
- if not self.state.proof or not self.state.nonce or not self.state.channelbinding then
+
+ local client_final_message_without_proof, channelbinding, nonce, proof
+ = s_match(client_final_message, "(c=([^,]*),r=([^,]*),?.-),p=(.*)$");
+
+ if not proof or not nonce or not channelbinding then
return "failure", "malformed-request", "Missing an attribute(p, r or c) in SASL message.";
end
- if self.state.nonce ~= self.state.clientnonce..self.state.servernonce then
+ local client_gs2_header = base64.decode(channelbinding)
+ local our_client_gs2_header = state["gs2_header"]
+ if state.gs2_cbind_name then
+ -- we support channelbinding, so check if the value is valid
+ our_client_gs2_header = our_client_gs2_header .. self.profile.cb[state.gs2_cbind_name](self);
+ end
+ if client_gs2_header ~= our_client_gs2_header then
+ return "failure", "malformed-request", "Invalid channel binding value.";
+ end
+
+ if nonce ~= state.nonce then
return "failure", "malformed-request", "Wrong nonce in client-final-message.";
end
-
- local ServerKey = self.state.server_key;
- local StoredKey = self.state.stored_key;
-
- local AuthMessage = "n=" .. s_match(self.state.client_first_message,"n=(.+)") .. "," .. self.state.server_first_message .. "," .. s_match(client_final_message, "(.+),p=.+")
+
+ local ServerKey = state.server_key;
+ local StoredKey = state.stored_key;
+
+ local AuthMessage = state.client_first_message_bare .. "," .. state.server_first_message .. "," .. client_final_message_without_proof
local ClientSignature = HMAC_f(StoredKey, AuthMessage)
- local ClientKey = binaryXOR(ClientSignature, base64.decode(self.state.proof))
+ local ClientKey = binaryXOR(ClientSignature, base64.decode(proof))
local ServerSignature = HMAC_f(ServerKey, AuthMessage)
if StoredKey == H_f(ClientKey) then
local server_final_message = "v="..base64.encode(ServerSignature);
- self["username"] = self.state.name;
+ self["username"] = state.username;
return "success", server_final_message;
else
return "failure", "not-authorized", "The response provided by the client doesn't match the one we calculated.";
@@ -208,6 +238,9 @@ end
function init(registerMechanism)
local function registerSCRAMMechanism(hash_name, hash, hmac_hash)
registerMechanism("SCRAM-"..hash_name, {"plain", "scram_"..(hashprep(hash_name))}, scram_gen(hash_name:lower(), hash, hmac_hash));
+
+ -- register channel binding equivalent
+ registerMechanism("SCRAM-"..hash_name.."-PLUS", {"plain", "scram_"..(hashprep(hash_name))}, scram_gen(hash_name:lower(), hash, hmac_hash), {"tls-unique"});
end
registerSCRAMMechanism("SHA-1", sha1, hmac_sha1);
diff --git a/util/sasl_cyrus.lua b/util/sasl_cyrus.lua
index 19684587..a0e8bd69 100644
--- a/util/sasl_cyrus.lua
+++ b/util/sasl_cyrus.lua
@@ -78,10 +78,10 @@ local function init(service_name)
end
-- create a new SASL object which can be used to authenticate clients
--- host_fqdn may be nil in which case gethostname() gives the value.
+-- host_fqdn may be nil in which case gethostname() gives the value.
-- For GSSAPI, this determines the hostname in the service ticket (after
-- reverse DNS canonicalization, only if [libdefaults] rdns = true which
--- is the default).
+-- is the default).
function new(realm, service_name, app_name, host_fqdn)
init(app_name or service_name);
diff --git a/util/serialization.lua b/util/serialization.lua
index 8a259184..06e45054 100644
--- a/util/serialization.lua
+++ b/util/serialization.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
diff --git a/util/set.lua b/util/set.lua
index fa065a9c..04f5f0f4 100644
--- a/util/set.lua
+++ b/util/set.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -40,13 +40,13 @@ function set_mt.__eq(set1, set2)
return false;
end
end
-
+
for item in pairs(set2) do
if not set1[item] then
return false;
end
end
-
+
return true;
end
function set_mt.__tostring(set)
@@ -65,23 +65,23 @@ end
function new(list)
local items = setmetatable({}, items_mt);
local set = { _items = items };
-
+
function set:add(item)
items[item] = true;
end
-
+
function set:contains(item)
return items[item];
end
-
+
function set:items()
- return items;
+ return next, items;
end
-
+
function set:remove(item)
items[item] = nil;
end
-
+
function set:add_list(list)
if list then
for _, item in ipairs(list) do
@@ -89,7 +89,7 @@ function new(list)
end
end
end
-
+
function set:include(otherset)
for item in otherset do
items[item] = true;
@@ -101,22 +101,22 @@ function new(list)
items[item] = nil;
end
end
-
+
function set:empty()
return not next(items);
end
-
+
if list then
set:add_list(list);
end
-
+
return setmetatable(set, set_mt);
end
function union(set1, set2)
local set = new();
local items = set._items;
-
+
for item in pairs(set1._items) do
items[item] = true;
end
@@ -124,14 +124,14 @@ function union(set1, set2)
for item in pairs(set2._items) do
items[item] = true;
end
-
+
return set;
end
function difference(set1, set2)
local set = new();
local items = set._items;
-
+
for item in pairs(set1._items) do
items[item] = (not set2._items[item]) or nil;
end
@@ -142,13 +142,13 @@ end
function intersection(set1, set2)
local set = new();
local items = set._items;
-
+
set1, set2 = set1._items, set2._items;
-
+
for item in pairs(set1) do
items[item] = (not not set2[item]) or nil;
end
-
+
return set;
end
diff --git a/util/sql.lua b/util/sql.lua
index f360d6d0..5a1dda5d 100644
--- a/util/sql.lua
+++ b/util/sql.lua
@@ -45,7 +45,7 @@ function String(n) return "String()" end
};
local functions = {
-
+
};
local cmap = {
@@ -177,8 +177,8 @@ function engine:execute(sql, ...)
end
local result_mt = { __index = {
- affected = function(self) return self.__affected; end;
- rowcount = function(self) return self.__rowcount; end;
+ affected = function(self) return self.__stmt:affected(); end;
+ rowcount = function(self) return self.__stmt:rowcount(); end;
} };
function engine:execute_query(sql, ...)
@@ -200,7 +200,7 @@ function engine:execute_update(sql, ...)
prepared[sql] = stmt;
end
assert(stmt:execute(...));
- return setmetatable({ __affected = stmt:affected(), __rowcount = stmt:rowcount() }, result_mt);
+ return setmetatable({ __stmt = stmt }, result_mt);
end
engine.insert = engine.execute_update;
engine.select = engine.execute_query;
@@ -251,19 +251,39 @@ function engine:_create_index(index)
elseif self.params.driver == "MySQL" then
sql = sql:gsub("`([,)])", "`(20)%1");
end
+ if index.unique then
+ sql = sql:gsub("^CREATE", "CREATE UNIQUE");
+ end
--print(sql);
return self:execute(sql);
end
function engine:_create_table(table)
local sql = "CREATE TABLE `"..table.name.."` (";
for i,col in ipairs(table.c) do
- sql = sql.."`"..col.name.."` "..col.type;
+ local col_type = col.type;
+ if col_type == "MEDIUMTEXT" and self.params.driver ~= "MySQL" then
+ col_type = "TEXT"; -- MEDIUMTEXT is MySQL-specific
+ end
+ if col.auto_increment == true and self.params.driver == "PostgreSQL" then
+ col_type = "BIGSERIAL";
+ end
+ sql = sql.."`"..col.name.."` "..col_type;
if col.nullable == false then sql = sql.." NOT NULL"; end
+ if col.primary_key == true then sql = sql.." PRIMARY KEY"; end
+ if col.auto_increment == true then
+ if self.params.driver == "MySQL" then
+ sql = sql.." AUTO_INCREMENT";
+ elseif self.params.driver == "SQLite3" then
+ sql = sql.." AUTOINCREMENT";
+ end
+ end
if i ~= #table.c then sql = sql..", "; end
end
sql = sql.. ");"
if self.params.driver == "PostgreSQL" then
sql = sql:gsub("`", "\"");
+ elseif self.params.driver == "MySQL" then
+ sql = sql:gsub(";$", " CHARACTER SET 'utf8' COLLATE 'utf8_bin';");
end
local success,err = self:execute(sql);
if not success then return success,err; end
@@ -274,6 +294,28 @@ function engine:_create_table(table)
end
return success;
end
+function engine:set_encoding() -- to UTF-8
+ local driver = self.params.driver;
+ if driver == "SQLite3" then
+ return self:transaction(function()
+ if self:select"PRAGMA encoding;"()[1] == "UTF-8" then
+ self.charset = "utf8";
+ end
+ end);
+ end
+ local set_names_query = "SET NAMES '%s';"
+ local charset = "utf8";
+ if driver == "MySQL" then
+ set_names_query = set_names_query:gsub(";$", " COLLATE 'utf8_bin';");
+ local ok, charsets = self:transaction(function()
+ return self:select"SELECT `CHARACTER_SET_NAME` FROM `information_schema`.`CHARACTER_SETS` WHERE `CHARACTER_SET_NAME` LIKE 'utf8%' ORDER BY MAXLEN DESC LIMIT 1;";
+ end);
+ local row = ok and charsets();
+ charset = row and row[1] or charset;
+ end
+ self.charset = charset;
+ return self:transaction(function() return self:execute(set_names_query:format(charset)); end);
+end
local engine_mt = { __index = engine };
local function db2uri(params)
diff --git a/util/stanza.lua b/util/stanza.lua
index 7c214210..82601e63 100644
--- a/util/stanza.lua
+++ b/util/stanza.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -99,7 +99,7 @@ function stanza_mt:get_child(name, xmlns)
if (not name or child.name == name)
and ((not xmlns and self.attr.xmlns == child.attr.xmlns)
or child.attr.xmlns == xmlns) then
-
+
return child;
end
end
@@ -152,7 +152,7 @@ end
function stanza_mt:maptags(callback)
local tags, curr_tag = self.tags, 1;
local n_children, n_tags = #self, #tags;
-
+
local i = 1;
while curr_tag <= n_tags and n_tags > 0 do
if self[i] == tags[curr_tag] then
@@ -258,13 +258,13 @@ end
function stanza_mt.get_error(stanza)
local type, condition, text;
-
+
local error_tag = stanza:get_child("error");
if not error_tag then
return nil, nil, nil;
end
type = error_tag.attr.type;
-
+
for _, child in ipairs(error_tag.tags) do
if child.attr.xmlns == xmlns_stanzas then
if not text and child.name == "text" then
@@ -333,7 +333,7 @@ function deserialize(stanza)
stanza.tags = tags;
end
end
-
+
return stanza;
end
@@ -390,7 +390,7 @@ if do_pretty_printing then
local style_attrv = getstyle("red");
local style_tagname = getstyle("red");
local style_punc = getstyle("magenta");
-
+
local attr_format = " "..getstring(style_attrk, "%s")..getstring(style_punc, "=")..getstring(style_attrv, "'%s'");
local top_tag_format = getstring(style_punc, "<")..getstring(style_tagname, "%s").."%s"..getstring(style_punc, ">");
--local tag_format = getstring(style_punc, "<")..getstring(style_tagname, "%s").."%s"..getstring(style_punc, ">").."%s"..getstring(style_punc, "</")..getstring(style_tagname, "%s")..getstring(style_punc, ">");
@@ -411,7 +411,7 @@ if do_pretty_printing then
end
return s_format(tag_format, t.name, attr_string, children_text, t.name);
end
-
+
function stanza_mt.pretty_top_tag(t)
local attr_string = "";
if t.attr then
diff --git a/util/termcolours.lua b/util/termcolours.lua
index 6ef3b689..ef978364 100644
--- a/util/termcolours.lua
+++ b/util/termcolours.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
diff --git a/util/timer.lua b/util/timer.lua
index af1e57b6..0e10e144 100644
--- a/util/timer.lua
+++ b/util/timer.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -42,7 +42,7 @@ if not server.event then
end
new_data = {};
end
-
+
local next_time = math_huge;
for i, d in pairs(data) do
local t, callback = d[1], d[2];
diff --git a/util/uuid.lua b/util/uuid.lua
index 796c8ee4..fc487c72 100644
--- a/util/uuid.lua
+++ b/util/uuid.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
diff --git a/util/x509.lua b/util/x509.lua
index 19d4ec6d..857f02a4 100644
--- a/util/x509.lua
+++ b/util/x509.lua
@@ -161,7 +161,9 @@ function verify_identity(host, service, cert)
if sans[oid_xmppaddr] then
had_supported_altnames = true
- if compare_xmppaddr(host, sans[oid_xmppaddr]) then return true end
+ if service == "_xmpp-client" or service == "_xmpp-server" then
+ if compare_xmppaddr(host, sans[oid_xmppaddr]) then return true end
+ end
end
if sans[oid_dnssrv] then
diff --git a/util/xml.lua b/util/xml.lua
index 076490fa..6dbed65d 100644
--- a/util/xml.lua
+++ b/util/xml.lua
@@ -26,8 +26,8 @@ local parse_xml = (function()
attr[i] = nil;
local ns, nm = k:match(ns_pattern);
if nm ~= "" then
- ns = ns_prefixes[ns];
- if ns then
+ ns = ns_prefixes[ns];
+ if ns then
attr[ns..":"..nm] = attr[k];
attr[k] = nil;
end
diff --git a/util/xmppstream.lua b/util/xmppstream.lua
index 4909678c..550170c9 100644
--- a/util/xmppstream.lua
+++ b/util/xmppstream.lua
@@ -1,7 +1,7 @@
-- Prosody IM
-- Copyright (C) 2008-2010 Matthew Wild
-- Copyright (C) 2008-2010 Waqas Hussain
---
+--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
@@ -42,21 +42,21 @@ _M.ns_pattern = ns_pattern;
function new_sax_handlers(session, stream_callbacks)
local xml_handlers = {};
-
+
local cb_streamopened = stream_callbacks.streamopened;
local cb_streamclosed = stream_callbacks.streamclosed;
local cb_error = stream_callbacks.error or function(session, e, stanza) error("XML stream error: "..tostring(e)..(stanza and ": "..tostring(stanza) or ""),2); end;
local cb_handlestanza = stream_callbacks.handlestanza;
-
+
local stream_ns = stream_callbacks.stream_ns or xmlns_streams;
local stream_tag = stream_callbacks.stream_tag or "stream";
if stream_ns ~= "" then
stream_tag = stream_ns..ns_separator..stream_tag;
end
local stream_error_tag = stream_ns..ns_separator..(stream_callbacks.error_tag or "error");
-
+
local stream_default_ns = stream_callbacks.default_ns;
-
+
local stack = {};
local chardata, stanza = {};
local non_streamns_depth = 0;
@@ -75,7 +75,7 @@ function new_sax_handlers(session, stream_callbacks)
attr.xmlns = curr_ns;
non_streamns_depth = non_streamns_depth + 1;
end
-
+
for i=1,#attr do
local k = attr[i];
attr[i] = nil;
@@ -85,7 +85,7 @@ function new_sax_handlers(session, stream_callbacks)
attr[k] = nil;
end
end
-
+
if not stanza then --if we are not currently inside a stanza
if session.notopen then
if tagname == stream_tag then
@@ -102,7 +102,7 @@ function new_sax_handlers(session, stream_callbacks)
if curr_ns == "jabber:client" and name ~= "iq" and name ~= "presence" and name ~= "message" then
cb_error(session, "invalid-top-level-element");
end
-
+
stanza = setmetatable({ name = name, attr = attr, tags = {} }, stanza_mt);
else -- we are inside a stanza, so add a tag
t_insert(stack, stanza);
@@ -151,22 +151,22 @@ function new_sax_handlers(session, stream_callbacks)
error("Failed to abort parsing");
end
end
-
+
if lxp_supports_doctype then
xml_handlers.StartDoctypeDecl = restricted_handler;
end
xml_handlers.Comment = restricted_handler;
xml_handlers.ProcessingInstruction = restricted_handler;
-
+
local function reset()
stanza, chardata = nil, {};
stack = {};
end
-
+
local function set_session(stream, new_session)
session = new_session;
end
-
+
return xml_handlers, { reset = reset, set_session = set_session };
end