aboutsummaryrefslogtreecommitdiffstats
path: root/util
diff options
context:
space:
mode:
Diffstat (limited to 'util')
-rw-r--r--util/adhoc.lua31
-rw-r--r--util/datamanager.lua32
-rw-r--r--util/helpers.lua8
-rw-r--r--util/hmac.lua64
-rw-r--r--util/http.lua55
-rw-r--r--util/httpstream.lua134
-rw-r--r--util/ip.lua23
-rw-r--r--util/iterators.lua5
-rw-r--r--util/json.lua39
-rw-r--r--util/openssl.lua32
-rw-r--r--util/prosodyctl.lua10
-rw-r--r--util/rfc6724.lua (renamed from util/rfc3484.lua)15
-rw-r--r--util/sasl/scram.lua17
-rw-r--r--util/sql.lua340
-rw-r--r--util/stanza.lua34
15 files changed, 594 insertions, 245 deletions
diff --git a/util/adhoc.lua b/util/adhoc.lua
new file mode 100644
index 00000000..671e85cf
--- /dev/null
+++ b/util/adhoc.lua
@@ -0,0 +1,31 @@
+local function new_simple_form(form, result_handler)
+ return function(self, data, state)
+ if state then
+ if data.action == "cancel" then
+ return { status = "canceled" };
+ end
+ local fields, err = form:data(data.form);
+ return result_handler(fields, err, data);
+ else
+ return { status = "executing", actions = {"next", "complete", default = "complete"}, form = form }, "executing";
+ end
+ end
+end
+
+local function new_initial_data_form(form, initial_data, result_handler)
+ return function(self, data, state)
+ if state then
+ if data.action == "cancel" then
+ return { status = "canceled" };
+ end
+ local fields, err = form:data(data.form);
+ return result_handler(fields, err, data);
+ else
+ return { status = "executing", actions = {"next", "complete", default = "complete"},
+ form = { layout = form, values = initial_data() } }, "executing";
+ end
+ end
+end
+
+return { new_simple_form = new_simple_form,
+ new_initial_data_form = new_initial_data_form };
diff --git a/util/datamanager.lua b/util/datamanager.lua
index 383e738f..4a4d62b3 100644
--- a/util/datamanager.lua
+++ b/util/datamanager.lua
@@ -187,17 +187,25 @@ function store(username, host, datastore, data)
-- save the datastore
local d = "return " .. serialize(data) .. ";\n";
- local ok, msg = atomic_store(getpath(username, host, datastore, nil, true), d);
- if not ok then
- log("error", "Unable to write to %s storage ('%s') for user: %s@%s", datastore, msg, username or "nil", host or "nil");
- return nil, "Error saving to storage";
- end
- if next(data) == nil then -- try to delete empty datastore
- log("debug", "Removing empty %s datastore for user %s@%s", datastore, username or "nil", host or "nil");
- os_remove(getpath(username, host, datastore));
- end
- -- we write data even when we are deleting because lua doesn't have a
- -- platform independent way of checking for non-exisitng files
+ local mkdir_cache_cleared;
+ repeat
+ local ok, msg = atomic_store(getpath(username, host, datastore, nil, true), d);
+ if not ok then
+ if not mkdir_cache_cleared then -- We may need to recreate a removed directory
+ _mkdir = {};
+ mkdir_cache_cleared = true;
+ else
+ log("error", "Unable to write to %s storage ('%s') for user: %s@%s", datastore, msg, username or "nil", host or "nil");
+ return nil, "Error saving to storage";
+ end
+ end
+ if next(data) == nil then -- try to delete empty datastore
+ log("debug", "Removing empty %s datastore for user %s@%s", datastore, username or "nil", host or "nil");
+ os_remove(getpath(username, host, datastore));
+ end
+ -- we write data even when we are deleting because lua doesn't have a
+ -- platform independent way of checking for non-exisitng files
+ until ok;
return true;
end
@@ -354,4 +362,6 @@ function purge(username, host)
return #errs == 0, t_concat(errs, ", ");
end
+_M.path_decode = decode;
+_M.path_encode = encode;
return _M;
diff --git a/util/helpers.lua b/util/helpers.lua
index 6103a319..08b86a7c 100644
--- a/util/helpers.lua
+++ b/util/helpers.lua
@@ -14,6 +14,14 @@ module("helpers", package.seeall);
local log = require "util.logger".init("util.debug");
+function log_host_events(host)
+ return log_events(prosody.hosts[host].events, host);
+end
+
+function revert_log_host_events(host)
+ return revert_log_events(prosody.hosts[host].events);
+end
+
function log_events(events, name, logger)
local f = events.fire_event;
if not f then
diff --git a/util/hmac.lua b/util/hmac.lua
index 6df6986e..51211c7a 100644
--- a/util/hmac.lua
+++ b/util/hmac.lua
@@ -6,64 +6,10 @@
-- COPYING file in the source package for more information.
--
-local hashes = require "util.hashes"
-
-local s_char = string.char;
-local s_gsub = string.gsub;
-local s_rep = string.rep;
-
-module "hmac"
-
-local xor_map = {0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;1;0;3;2;5;4;7;6;9;8;11;10;13;12;15;14;2;3;0;1;6;7;4;5;10;11;8;9;14;15;12;13;3;2;1;0;7;6;5;4;11;10;9;8;15;14;13;12;4;5;6;7;0;1;2;3;12;13;14;15;8;9;10;11;5;4;7;6;1;0;3;2;13;12;15;14;9;8;11;10;6;7;4;5;2;3;0;1;14;15;12;13;10;11;8;9;7;6;5;4;3;2;1;0;15;14;13;12;11;10;9;8;8;9;10;11;12;13;14;15;0;1;2;3;4;5;6;7;9;8;11;10;13;12;15;14;1;0;3;2;5;4;7;6;10;11;8;9;14;15;12;13;2;3;0;1;6;7;4;5;11;10;9;8;15;14;13;12;3;2;1;0;7;6;5;4;12;13;14;15;8;9;10;11;4;5;6;7;0;1;2;3;13;12;15;14;9;8;11;10;5;4;7;6;1;0;3;2;14;15;12;13;10;11;8;9;6;7;4;5;2;3;0;1;15;14;13;12;11;10;9;8;7;6;5;4;3;2;1;0;};
-local function xor(x, y)
- local lowx, lowy = x % 16, y % 16;
- local hix, hiy = (x - lowx) / 16, (y - lowy) / 16;
- local lowr, hir = xor_map[lowx * 16 + lowy + 1], xor_map[hix * 16 + hiy + 1];
- local r = hir * 16 + lowr;
- return r;
-end
-local opadc, ipadc = s_char(0x5c), s_char(0x36);
-local ipad_map = {};
-local opad_map = {};
-for i=0,255 do
- ipad_map[s_char(i)] = s_char(xor(0x36, i));
- opad_map[s_char(i)] = s_char(xor(0x5c, i));
-end
-
---[[
-key
- the key to use in the hash
-message
- the message to hash
-hash
- the hash function
-blocksize
- the blocksize for the hash function in bytes
-hex
- return raw hash or hexadecimal string
---]]
-function hmac(key, message, hash, blocksize, hex)
- if #key > blocksize then
- key = hash(key)
- end
+-- COMPAT: Only for external pre-0.9 modules
- local padding = blocksize - #key;
- local ipad = s_gsub(key, ".", ipad_map)..s_rep(ipadc, padding);
- local opad = s_gsub(key, ".", opad_map)..s_rep(opadc, padding);
-
- return hash(opad..hash(ipad..message), hex)
-end
-
-function md5(key, message, hex)
- return hmac(key, message, hashes.md5, 64, hex)
-end
-
-function sha1(key, message, hex)
- return hmac(key, message, hashes.sha1, 64, hex)
-end
-
-function sha256(key, message, hex)
- return hmac(key, message, hashes.sha256, 64, hex)
-end
+local hashes = require "util.hashes"
-return _M
+return { md5 = hashes.hmac_md5,
+ sha1 = hashes.hmac_sha1,
+ sha256 = hashes.hmac_sha256 };
diff --git a/util/http.lua b/util/http.lua
index 5b49d1d0..f7259920 100644
--- a/util/http.lua
+++ b/util/http.lua
@@ -5,11 +5,60 @@
-- COPYING file in the source package for more information.
--
-local http = {};
+local format, char = string.format, string.char;
+local pairs, ipairs, tonumber = pairs, ipairs, tonumber;
+local t_insert, t_concat = table.insert, table.concat;
-function http.contains_token(field, token)
+local function urlencode(s)
+ return s and (s:gsub("[^a-zA-Z0-9.~_-]", function (c) return format("%%%02x", c:byte()); end));
+end
+local function urldecode(s)
+ return s and (s:gsub("%%(%x%x)", function (c) return char(tonumber(c,16)); end));
+end
+
+local function _formencodepart(s)
+ return s and (s:gsub("%W", function (c)
+ if c ~= " " then
+ return format("%%%02x", c:byte());
+ else
+ return "+";
+ end
+ end));
+end
+
+local function formencode(form)
+ local result = {};
+ if form[1] then -- Array of ordered { name, value }
+ for _, field in ipairs(form) do
+ t_insert(result, _formencodepart(field.name).."=".._formencodepart(field.value));
+ end
+ else -- Unordered map of name -> value
+ for name, value in pairs(form) do
+ t_insert(result, _formencodepart(name).."=".._formencodepart(value));
+ end
+ end
+ return t_concat(result, "&");
+end
+
+local function formdecode(s)
+ if not s:match("=") then return urldecode(s); end
+ local r = {};
+ for k, v in s:gmatch("([^=&]*)=([^&]*)") do
+ k, v = k:gsub("%+", "%%20"), v:gsub("%+", "%%20");
+ k, v = urldecode(k), urldecode(v);
+ t_insert(r, { name = k, value = v });
+ r[k] = v;
+ end
+ return r;
+end
+
+local function contains_token(field, token)
field = ","..field:gsub("[ \t]", ""):lower()..",";
return field:find(","..token:lower()..",", 1, true) ~= nil;
end
-return http;
+return {
+ urlencode = urlencode, urldecode = urldecode;
+ formencode = formencode, formdecode = formdecode;
+ contains_token = contains_token;
+};
diff --git a/util/httpstream.lua b/util/httpstream.lua
deleted file mode 100644
index 190b3ed6..00000000
--- a/util/httpstream.lua
+++ /dev/null
@@ -1,134 +0,0 @@
-
-local coroutine = coroutine;
-local tonumber = tonumber;
-
-local deadroutine = coroutine.create(function() end);
-coroutine.resume(deadroutine);
-
-module("httpstream")
-
-local function parser(success_cb, parser_type, options_cb)
- local data = coroutine.yield();
- local function readline()
- local pos = data:find("\r\n", nil, true);
- while not pos do
- data = data..coroutine.yield();
- pos = data:find("\r\n", nil, true);
- end
- local r = data:sub(1, pos-1);
- data = data:sub(pos+2);
- return r;
- end
- local function readlength(n)
- while #data < n do
- data = data..coroutine.yield();
- end
- local r = data:sub(1, n);
- data = data:sub(n + 1);
- return r;
- end
- local function readheaders()
- local headers = {}; -- read headers
- while true do
- local line = readline();
- if line == "" then break; end -- headers done
- local key, val = line:match("^([^%s:]+): *(.*)$");
- if not key then coroutine.yield("invalid-header-line"); end -- TODO handle multi-line and invalid headers
- key = key:lower();
- headers[key] = headers[key] and headers[key]..","..val or val;
- end
- return headers;
- end
-
- if not parser_type or parser_type == "server" then
- while true do
- -- read status line
- local status_line = readline();
- local method, path, httpversion = status_line:match("^(%S+)%s+(%S+)%s+HTTP/(%S+)$");
- if not method then coroutine.yield("invalid-status-line"); end
- path = path:gsub("^//+", "/"); -- TODO parse url more
- local headers = readheaders();
-
- -- read body
- local len = tonumber(headers["content-length"]);
- len = len or 0; -- TODO check for invalid len
- local body = readlength(len);
-
- success_cb({
- method = method;
- path = path;
- httpversion = httpversion;
- headers = headers;
- body = body;
- });
- end
- elseif parser_type == "client" then
- while true do
- -- read status line
- local status_line = readline();
- local httpversion, status_code, reason_phrase = status_line:match("^HTTP/(%S+)%s+(%d%d%d)%s+(.*)$");
- status_code = tonumber(status_code);
- if not status_code then coroutine.yield("invalid-status-line"); end
- local headers = readheaders();
-
- -- read body
- local have_body = not
- ( (options_cb and options_cb().method == "HEAD")
- or (status_code == 204 or status_code == 304 or status_code == 301)
- or (status_code >= 100 and status_code < 200) );
-
- local body;
- if have_body then
- local len = tonumber(headers["content-length"]);
- if headers["transfer-encoding"] == "chunked" then
- body = "";
- while true do
- local chunk_size = readline():match("^%x+");
- if not chunk_size then coroutine.yield("invalid-chunk-size"); end
- chunk_size = tonumber(chunk_size, 16)
- if chunk_size == 0 then break; end
- body = body..readlength(chunk_size);
- if readline() ~= "" then coroutine.yield("invalid-chunk-ending"); end
- end
- local trailers = readheaders();
- elseif len then -- TODO check for invalid len
- body = readlength(len);
- else -- read to end
- repeat
- local newdata = coroutine.yield();
- data = data..newdata;
- until newdata == "";
- body, data = data, "";
- end
- end
-
- success_cb({
- code = status_code;
- httpversion = httpversion;
- headers = headers;
- body = body;
- });
- end
- else coroutine.yield("unknown-parser-type"); end
-end
-
-function new(success_cb, error_cb, parser_type, options_cb)
- local co = coroutine.create(parser);
- coroutine.resume(co, success_cb, parser_type, options_cb)
- return {
- feed = function(self, data)
- if not data then
- if parser_type == "client" then coroutine.resume(co, ""); end
- co = deadroutine;
- return error_cb();
- end
- local success, result = coroutine.resume(co, data);
- if result then
- co = deadroutine;
- return error_cb(result);
- end
- end;
- };
-end
-
-return _M;
diff --git a/util/ip.lua b/util/ip.lua
index 2f09c034..de287b16 100644
--- a/util/ip.lua
+++ b/util/ip.lua
@@ -64,9 +64,6 @@ local function v4scope(ip)
-- Link-local unicast:
elseif fields[1] == 169 and fields[2] == 254 then
return 0x2;
- -- Site-local unicast:
- elseif (fields[1] == 10) or (fields[1] == 192 and fields[2] == 168) or (fields[1] == 172 and (fields[2] >= 16 and fields[2] < 32)) then
- return 0x5;
-- Global unicast:
else
return 0xE;
@@ -97,6 +94,14 @@ local function label(ip)
return 0;
elseif commonPrefixLength(ip, new_ip("2002::", "IPv6")) >= 16 then
return 2;
+ elseif commonPrefixLength(ip, new_ip("2001::", "IPv6")) >= 32 then
+ return 5;
+ elseif commonPrefixLength(ip, new_ip("fc00::", "IPv6")) >= 7 then
+ return 13;
+ elseif commonPrefixLength(ip, new_ip("fec0::", "IPv6")) >= 10 then
+ return 11;
+ elseif commonPrefixLength(ip, new_ip("3ffe::", "IPv6")) >= 16 then
+ return 12;
elseif commonPrefixLength(ip, new_ip("::", "IPv6")) >= 96 then
return 3;
elseif commonPrefixLength(ip, new_ip("::ffff:0:0", "IPv6")) >= 96 then
@@ -111,10 +116,18 @@ local function precedence(ip)
return 50;
elseif commonPrefixLength(ip, new_ip("2002::", "IPv6")) >= 16 then
return 30;
+ elseif commonPrefixLength(ip, new_ip("2001::", "IPv6")) >= 32 then
+ return 5;
+ elseif commonPrefixLength(ip, new_ip("fc00::", "IPv6")) >= 7 then
+ return 3;
+ elseif commonPrefixLength(ip, new_ip("fec0::", "IPv6")) >= 10 then
+ return 1;
+ elseif commonPrefixLength(ip, new_ip("3ffe::", "IPv6")) >= 16 then
+ return 1;
elseif commonPrefixLength(ip, new_ip("::", "IPv6")) >= 96 then
- return 20;
+ return 1;
elseif commonPrefixLength(ip, new_ip("::ffff:0:0", "IPv6")) >= 96 then
- return 10;
+ return 35;
else
return 40;
end
diff --git a/util/iterators.lua b/util/iterators.lua
index fb89f4a5..1f6aacb8 100644
--- a/util/iterators.lua
+++ b/util/iterators.lua
@@ -122,6 +122,11 @@ function it.tail(n, f, s, var)
--return reverse(head(n, reverse(f, s, var)));
end
+local function _ripairs_iter(t, key) if key > 1 then return key-1, t[key-1]; end end
+function it.ripairs(t)
+ return _ripairs_iter, t, #t+1;
+end
+
local function _range_iter(max, curr) if curr < max then return curr + 1; end end
function it.range(x, y)
if not y then x, y = 1, x; end -- Default to 1..x if y not given
diff --git a/util/json.lua b/util/json.lua
index efc602f0..9c2dd2c6 100644
--- a/util/json.lua
+++ b/util/json.lua
@@ -1,3 +1,10 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
local type = type;
local t_insert, t_concat, t_remove, t_sort = table.insert, table.concat, table.remove, table.sort;
@@ -9,6 +16,9 @@ local error = error;
local newproxy, getmetatable = newproxy, getmetatable;
local print = print;
+local has_array, array = pcall(require, "util.array");
+local array_mt = hasarray and getmetatable(array()) or {};
+
--module("json")
local json = {};
@@ -29,6 +39,19 @@ for i=0,31 do
if not escapes[ch] then escapes[ch] = ("\\u%.4X"):format(i); end
end
+local function codepoint_to_utf8(code)
+ if code < 0x80 then return s_char(code); end
+ local bits0_6 = code % 64;
+ if code < 0x800 then
+ local bits6_5 = (code - bits0_6) / 64;
+ return s_char(0x80 + 0x40 + bits6_5, 0x80 + bits0_6);
+ end
+ local bits0_12 = code % 4096;
+ local bits6_6 = (bits0_12 - bits0_6) / 64;
+ local bits12_4 = (code - bits0_12) / 4096;
+ return s_char(0x80 + 0x40 + 0x20 + bits12_4, 0x80 + bits6_6, 0x80 + bits0_6);
+end
+
local valid_types = {
number = true,
string = true,
@@ -130,7 +153,12 @@ function simplesave(o, buffer)
elseif t == "string" then
stringsave(o, buffer);
elseif t == "table" then
- tablesave(o, buffer);
+ local mt = getmetatable(o);
+ if mt == array_mt then
+ arraysave(o, buffer);
+ else
+ tablesave(o, buffer);
+ end
elseif t == "boolean" then
t_insert(buffer, (o and "true" or "false"));
else
@@ -148,6 +176,11 @@ function json.encode_ordered(obj)
simplesave(obj, t);
return t_concat(t);
end
+function json.encode_array(obj)
+ local t = {};
+ arraysave(obj, t);
+ return t_concat(t);
+end
-----------------------------------
@@ -197,7 +230,7 @@ function json.decode(json)
local readvalue;
local function readarray()
- local t = {};
+ local t = setmetatable({}, array_mt);
next(); -- skip '['
skipstuff();
if ch == "]" then next(); return t; end
@@ -244,7 +277,7 @@ function json.decode(json)
if not ch:match("[0-9a-fA-F]") then error("invalid unicode escape sequence in string"); end
seq = seq..ch;
end
- s = s..s.char(tonumber(seq, 16)); -- FIXME do proper utf-8
+ s = s..codepoint_to_utf8(tonumber(seq, 16));
next();
else error("invalid escape sequence in string"); end
end
diff --git a/util/openssl.lua b/util/openssl.lua
index b3dc2943..ef3fba96 100644
--- a/util/openssl.lua
+++ b/util/openssl.lua
@@ -23,11 +23,12 @@ function config.new()
prompt = "no",
},
distinguished_name = {
- commonName = "example.com",
countryName = "GB",
+ -- stateOrProvinceName = "",
localityName = "The Internet",
organizationName = "Your Organisation",
organizationalUnitName = "XMPP Department",
+ commonName = "example.com",
emailAddress = "xmpp@example.com",
},
v3_extensions = {
@@ -43,6 +44,17 @@ function config.new()
}, ssl_config_mt);
end
+local DN_order = {
+ "countryName";
+ "stateOrProvinceName";
+ "localityName";
+ "streetAddress";
+ "organizationName";
+ "organizationalUnitName";
+ "commonName";
+ "emailAddress";
+}
+_M._DN_order = DN_order;
function ssl_config:serialize()
local s = "";
for k, t in pairs(self) do
@@ -53,6 +65,14 @@ function ssl_config:serialize()
s = s .. s_format("%s.%d = %s\n", san, i -1, n[i]);
end
end
+ elseif k == "distinguished_name" then
+ for i=1,#DN_order do
+ local k = DN_order[i]
+ local v = t[k];
+ if v then
+ s = s .. ("%s = %s\n"):format(k, v);
+ end
+ end
else
for k, v in pairs(t) do
s = s .. ("%s = %s\n"):format(k, v);
@@ -100,13 +120,13 @@ function ssl_config:from_prosody(hosts, config, certhosts)
if name == certhost or name:sub(-1-#certhost) == "."..certhost then
found_matching_hosts = true;
self:add_dNSName(name);
- --print(name .. "#component_module: " .. (config.get(name, "core", "component_module") or "nil"));
- if config.get(name, "core", "component_module") == nil then
+ --print(name .. "#component_module: " .. (config.get(name, "component_module") or "nil"));
+ if config.get(name, "component_module") == nil then
self:add_sRVName(name, "xmpp-client");
end
- --print(name .. "#anonymous_login: " .. tostring(config.get(name, "core", "anonymous_login")));
- if not (config.get(name, "core", "anonymous_login") or
- config.get(name, "core", "authentication") == "anonymous") then
+ --print(name .. "#anonymous_login: " .. tostring(config.get(name, "anonymous_login")));
+ if not (config.get(name, "anonymous_login") or
+ config.get(name, "authentication") == "anonymous") then
self:add_sRVName(name, "xmpp-server");
end
self:add_xmppAddr(name);
diff --git a/util/prosodyctl.lua b/util/prosodyctl.lua
index e38f85d4..b80a69f2 100644
--- a/util/prosodyctl.lua
+++ b/util/prosodyctl.lua
@@ -140,11 +140,12 @@ function adduser(params)
if not host_session then
return false, "no-such-host";
end
+
+ storagemanager.initialize_host(host);
local provider = host_session.users;
if not(provider) or provider.name == "null" then
usermanager.initialize_host(host);
end
- storagemanager.initialize_host(host);
local ok, errmsg = usermanager.create_user(user, password, host);
if not ok then
@@ -155,11 +156,12 @@ end
function user_exists(params)
local user, host, password = nodeprep(params.user), nameprep(params.host), params.password;
+
+ storagemanager.initialize_host(host);
local provider = prosody.hosts[host].users;
if not(provider) or provider.name == "null" then
usermanager.initialize_host(host);
end
- storagemanager.initialize_host(host);
return usermanager.user_exists(user, host);
end
@@ -182,12 +184,12 @@ function deluser(params)
end
function getpid()
- local pidfile = config.get("*", "core", "pidfile");
+ local pidfile = config.get("*", "pidfile");
if not pidfile then
return false, "no-pidfile";
end
- local modules_enabled = set.new(config.get("*", "core", "modules_enabled"));
+ local modules_enabled = set.new(config.get("*", "modules_enabled"));
if not modules_enabled:contains("posix") then
return false, "no-posix";
end
diff --git a/util/rfc3484.lua b/util/rfc6724.lua
index 5ee572a0..c8aec631 100644
--- a/util/rfc3484.lua
+++ b/util/rfc6724.lua
@@ -1,13 +1,22 @@
-- Prosody IM
--- Copyright (C) 2008-2011 Florian Zeitz
+-- Copyright (C) 2011-2013 Florian Zeitz
--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
-local commonPrefixLength = require"util.ip".commonPrefixLength
+-- This is used to sort destination addresses by preference
+-- during S2S connections.
+-- We can't hand this off to getaddrinfo, since it blocks
+
+local ip_commonPrefixLength = require"util.ip".commonPrefixLength
local new_ip = require"util.ip".new_ip;
+local function commonPrefixLength(ipA, ipB)
+ local len = ip_commonPrefixLength(ipA, ipB);
+ return len < 64 and len or 64;
+end
+
local function t_sort(t, comp)
for i = 1, (#t - 1) do
for j = (i + 1), #t do
@@ -56,7 +65,7 @@ local function source(dest, candidates)
return false;
end
- -- Rule 7: Prefer public addresses (over temporary ones)
+ -- Rule 7: Prefer temporary addresses (over public ones)
-- XXX: No way to determine this
-- Rule 8: Use longest matching prefix
if commonPrefixLength(ipA, dest) > commonPrefixLength(ipB, dest) then
diff --git a/util/sasl/scram.lua b/util/sasl/scram.lua
index d0e8987c..cf2f0ede 100644
--- a/util/sasl/scram.lua
+++ b/util/sasl/scram.lua
@@ -15,8 +15,9 @@ local s_match = string.match;
local type = type
local string = string
local base64 = require "util.encodings".base64;
-local hmac_sha1 = require "util.hmac".sha1;
+local hmac_sha1 = require "util.hashes".hmac_sha1;
local sha1 = require "util.hashes".sha1;
+local Hi = require "util.hashes".scram_Hi_sha1;
local generate_uuid = require "util.uuid".generate;
local saslprep = require "util.encodings".stringprep.saslprep;
local nodeprep = require "util.encodings".stringprep.nodeprep;
@@ -65,18 +66,6 @@ local function binaryXOR( a, b )
return t_concat(result);
end
--- hash algorithm independent Hi(PBKDF2) implementation
-function Hi(hmac, str, salt, i)
- local Ust = hmac(str, salt.."\0\0\0\1");
- local res = Ust;
- for n=1,i-1 do
- local Und = hmac(str, Ust)
- res = binaryXOR(res, Und)
- Ust = Und
- end
- return res
-end
-
local function validate_username(username, _nodeprep)
-- check for forbidden char sequences
for eq in username:gmatch("=(.?.?)") do
@@ -110,7 +99,7 @@ function getAuthenticationDatabaseSHA1(password, salt, iteration_count)
if iteration_count < 4096 then
log("warn", "Iteration count < 4096 which is the suggested minimum according to RFC 5802.")
end
- local salted_password = Hi(hmac_sha1, password, salt, iteration_count);
+ local salted_password = Hi(password, salt, iteration_count);
local stored_key = sha1(hmac_sha1(salted_password, "Client Key"))
local server_key = hmac_sha1(salted_password, "Server Key");
return true, stored_key, server_key
diff --git a/util/sql.lua b/util/sql.lua
new file mode 100644
index 00000000..f360d6d0
--- /dev/null
+++ b/util/sql.lua
@@ -0,0 +1,340 @@
+
+local setmetatable, getmetatable = setmetatable, getmetatable;
+local ipairs, unpack, select = ipairs, unpack, select;
+local tonumber, tostring = tonumber, tostring;
+local assert, xpcall, debug_traceback = assert, xpcall, debug.traceback;
+local t_concat = table.concat;
+local s_char = string.char;
+local log = require "util.logger".init("sql");
+
+local DBI = require "DBI";
+-- This loads all available drivers while globals are unlocked
+-- LuaDBI should be fixed to not set globals.
+DBI.Drivers();
+local build_url = require "socket.url".build;
+
+module("sql")
+
+local column_mt = {};
+local table_mt = {};
+local query_mt = {};
+--local op_mt = {};
+local index_mt = {};
+
+function is_column(x) return getmetatable(x)==column_mt; end
+function is_index(x) return getmetatable(x)==index_mt; end
+function is_table(x) return getmetatable(x)==table_mt; end
+function is_query(x) return getmetatable(x)==query_mt; end
+--function is_op(x) return getmetatable(x)==op_mt; end
+--function expr(...) return setmetatable({...}, op_mt); end
+function Integer(n) return "Integer()" end
+function String(n) return "String()" end
+
+--[[local ops = {
+ __add = function(a, b) return "("..a.."+"..b..")" end;
+ __sub = function(a, b) return "("..a.."-"..b..")" end;
+ __mul = function(a, b) return "("..a.."*"..b..")" end;
+ __div = function(a, b) return "("..a.."/"..b..")" end;
+ __mod = function(a, b) return "("..a.."%"..b..")" end;
+ __pow = function(a, b) return "POW("..a..","..b..")" end;
+ __unm = function(a) return "NOT("..a..")" end;
+ __len = function(a) return "COUNT("..a..")" end;
+ __eq = function(a, b) return "("..a.."=="..b..")" end;
+ __lt = function(a, b) return "("..a.."<"..b..")" end;
+ __le = function(a, b) return "("..a.."<="..b..")" end;
+};
+
+local functions = {
+
+};
+
+local cmap = {
+ [Integer] = Integer();
+ [String] = String();
+};]]
+
+function Column(definition)
+ return setmetatable(definition, column_mt);
+end
+function Table(definition)
+ local c = {}
+ for i,col in ipairs(definition) do
+ if is_column(col) then
+ c[i], c[col.name] = col, col;
+ elseif is_index(col) then
+ col.table = definition.name;
+ end
+ end
+ return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt);
+end
+function Index(definition)
+ return setmetatable(definition, index_mt);
+end
+
+function table_mt:__tostring()
+ local s = { 'name="'..self.__table__.name..'"' }
+ for i,col in ipairs(self.__table__) do
+ s[#s+1] = tostring(col);
+ end
+ return 'Table{ '..t_concat(s, ", ")..' }'
+end
+table_mt.__index = {};
+function table_mt.__index:create(engine)
+ return engine:_create_table(self);
+end
+function table_mt:__call(...)
+ -- TODO
+end
+function column_mt:__tostring()
+ return 'Column{ name="'..self.name..'", type="'..self.type..'" }'
+end
+function index_mt:__tostring()
+ local s = 'Index{ name="'..self.name..'"';
+ for i=1,#self do s = s..', "'..self[i]:gsub("[\\\"]", "\\%1")..'"'; end
+ return s..' }';
+-- return 'Index{ name="'..self.name..'", type="'..self.type..'" }'
+end
+--
+
+local function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end)); end
+local function parse_url(url)
+ local scheme, secondpart, database = url:match("^([%w%+]+)://([^/]*)/?(.*)");
+ assert(scheme, "Invalid URL format");
+ local username, password, host, port;
+ local authpart, hostpart = secondpart:match("([^@]+)@([^@+])");
+ if not authpart then hostpart = secondpart; end
+ if authpart then
+ username, password = authpart:match("([^:]*):(.*)");
+ username = username or authpart;
+ password = password and urldecode(password);
+ end
+ if hostpart then
+ host, port = hostpart:match("([^:]*):(.*)");
+ host = host or hostpart;
+ port = port and assert(tonumber(port), "Invalid URL format");
+ end
+ return {
+ scheme = scheme:lower();
+ username = username; password = password;
+ host = host; port = port;
+ database = #database > 0 and database or nil;
+ };
+end
+
+--[[local session = {};
+
+function session.query(...)
+ local rets = {...};
+ local query = setmetatable({ __rets = rets, __filters }, query_mt);
+ return query;
+end
+--
+
+local function db2uri(params)
+ return build_url{
+ scheme = params.driver,
+ user = params.username,
+ password = params.password,
+ host = params.host,
+ port = params.port,
+ path = params.database,
+ };
+end]]
+
+local engine = {};
+function engine:connect()
+ if self.conn then return true; end
+
+ local params = self.params;
+ assert(params.driver, "no driver")
+ local dbh, err = DBI.Connect(
+ params.driver, params.database,
+ params.username, params.password,
+ params.host, params.port
+ );
+ if not dbh then return nil, err; end
+ dbh:autocommit(false); -- don't commit automatically
+ self.conn = dbh;
+ self.prepared = {};
+ return true;
+end
+function engine:execute(sql, ...)
+ local success, err = self:connect();
+ if not success then return success, err; end
+ local prepared = self.prepared;
+
+ local stmt = prepared[sql];
+ if not stmt then
+ local err;
+ stmt, err = self.conn:prepare(sql);
+ if not stmt then return stmt, err; end
+ prepared[sql] = stmt;
+ end
+
+ local success, err = stmt:execute(...);
+ if not success then return success, err; end
+ return stmt;
+end
+
+local result_mt = { __index = {
+ affected = function(self) return self.__affected; end;
+ rowcount = function(self) return self.__rowcount; end;
+} };
+
+function engine:execute_query(sql, ...)
+ if self.params.driver == "PostgreSQL" then
+ sql = sql:gsub("`", "\"");
+ end
+ local stmt = assert(self.conn:prepare(sql));
+ assert(stmt:execute(...));
+ return stmt:rows();
+end
+function engine:execute_update(sql, ...)
+ if self.params.driver == "PostgreSQL" then
+ sql = sql:gsub("`", "\"");
+ end
+ local prepared = self.prepared;
+ local stmt = prepared[sql];
+ if not stmt then
+ stmt = assert(self.conn:prepare(sql));
+ prepared[sql] = stmt;
+ end
+ assert(stmt:execute(...));
+ return setmetatable({ __affected = stmt:affected(), __rowcount = stmt:rowcount() }, result_mt);
+end
+engine.insert = engine.execute_update;
+engine.select = engine.execute_query;
+engine.delete = engine.execute_update;
+engine.update = engine.execute_update;
+function engine:_transaction(func, ...)
+ if not self.conn then
+ local a,b = self:connect();
+ if not a then return a,b; end
+ end
+ --assert(not self.__transaction, "Recursive transactions not allowed");
+ local args, n_args = {...}, select("#", ...);
+ local function f() return func(unpack(args, 1, n_args)); end
+ self.__transaction = true;
+ local success, a, b, c = xpcall(f, debug_traceback);
+ self.__transaction = nil;
+ if success then
+ log("debug", "SQL transaction success [%s]", tostring(func));
+ local ok, err = self.conn:commit();
+ if not ok then return ok, err; end -- commit failed
+ return success, a, b, c;
+ else
+ log("debug", "SQL transaction failure [%s]: %s", tostring(func), a);
+ if self.conn then self.conn:rollback(); end
+ return success, a;
+ end
+end
+function engine:transaction(...)
+ local a,b = self:_transaction(...);
+ if not a then
+ local conn = self.conn;
+ if not conn or not conn:ping() then
+ self.conn = nil;
+ a,b = self:_transaction(...);
+ end
+ end
+ return a,b;
+end
+function engine:_create_index(index)
+ local sql = "CREATE INDEX `"..index.name.."` ON `"..index.table.."` (";
+ for i=1,#index do
+ sql = sql.."`"..index[i].."`";
+ if i ~= #index then sql = sql..", "; end
+ end
+ sql = sql..");"
+ if self.params.driver == "PostgreSQL" then
+ sql = sql:gsub("`", "\"");
+ elseif self.params.driver == "MySQL" then
+ sql = sql:gsub("`([,)])", "`(20)%1");
+ end
+ --print(sql);
+ return self:execute(sql);
+end
+function engine:_create_table(table)
+ local sql = "CREATE TABLE `"..table.name.."` (";
+ for i,col in ipairs(table.c) do
+ sql = sql.."`"..col.name.."` "..col.type;
+ if col.nullable == false then sql = sql.." NOT NULL"; end
+ if i ~= #table.c then sql = sql..", "; end
+ end
+ sql = sql.. ");"
+ if self.params.driver == "PostgreSQL" then
+ sql = sql:gsub("`", "\"");
+ end
+ local success,err = self:execute(sql);
+ if not success then return success,err; end
+ for i,v in ipairs(table.__table__) do
+ if is_index(v) then
+ self:_create_index(v);
+ end
+ end
+ return success;
+end
+local engine_mt = { __index = engine };
+
+local function db2uri(params)
+ return build_url{
+ scheme = params.driver,
+ user = params.username,
+ password = params.password,
+ host = params.host,
+ port = params.port,
+ path = params.database,
+ };
+end
+local engine_cache = {}; -- TODO make weak valued
+function create_engine(self, params)
+ local url = db2uri(params);
+ if not engine_cache[url] then
+ local engine = setmetatable({ url = url, params = params }, engine_mt);
+ engine_cache[url] = engine;
+ end
+ return engine_cache[url];
+end
+
+
+--[[Users = Table {
+ name="users";
+ Column { name="user_id", type=String(), primary_key=true };
+};
+print(Users)
+print(Users.c.user_id)]]
+
+--local engine = create_engine('postgresql://scott:tiger@localhost:5432/mydatabase');
+--[[local engine = create_engine{ driver = "SQLite3", database = "./alchemy.sqlite" };
+
+local i = 0;
+for row in assert(engine:execute("select * from sqlite_master")):rows(true) do
+ i = i+1;
+ print(i);
+ for k,v in pairs(row) do
+ print("",k,v);
+ end
+end
+print("---")
+
+Prosody = Table {
+ name="prosody";
+ Column { name="host", type="TEXT", nullable=false };
+ Column { name="user", type="TEXT", nullable=false };
+ Column { name="store", type="TEXT", nullable=false };
+ Column { name="key", type="TEXT", nullable=false };
+ Column { name="type", type="TEXT", nullable=false };
+ Column { name="value", type="TEXT", nullable=false };
+ Index { name="prosody_index", "host", "user", "store", "key" };
+};
+--print(Prosody);
+assert(engine:transaction(function()
+ assert(Prosody:create(engine));
+end));
+
+for row in assert(engine:execute("select user from prosody")):rows(true) do
+ print("username:", row['username'])
+end
+--result.close();]]
+
+return _M;
diff --git a/util/stanza.lua b/util/stanza.lua
index a0ab2a5a..7c214210 100644
--- a/util/stanza.lua
+++ b/util/stanza.lua
@@ -18,6 +18,7 @@ local pairs = pairs;
local ipairs = ipairs;
local type = type;
local s_gsub = string.gsub;
+local s_sub = string.sub;
local s_find = string.find;
local os = os;
@@ -153,7 +154,7 @@ function stanza_mt:maptags(callback)
local n_children, n_tags = #self, #tags;
local i = 1;
- while curr_tag <= n_tags do
+ while curr_tag <= n_tags and n_tags > 0 do
if self[i] == tags[curr_tag] then
local ret = callback(self[i]);
if ret == nil then
@@ -161,17 +162,44 @@ function stanza_mt:maptags(callback)
t_remove(tags, curr_tag);
n_children = n_children - 1;
n_tags = n_tags - 1;
+ i = i - 1;
+ curr_tag = curr_tag - 1;
else
self[i] = ret;
- tags[i] = ret;
+ tags[curr_tag] = ret;
end
- i = i + 1;
curr_tag = curr_tag + 1;
end
+ i = i + 1;
end
return self;
end
+function stanza_mt:find(path)
+ local pos = 1;
+ local len = #path + 1;
+
+ repeat
+ local xmlns, name, text;
+ local char = s_sub(path, pos, pos);
+ if char == "@" then
+ return self.attr[s_sub(path, pos + 1)];
+ elseif char == "{" then
+ xmlns, pos = s_match(path, "^([^}]+)}()", pos + 1);
+ end
+ name, text, pos = s_match(path, "^([^@/#]*)([/#]?)()", pos);
+ name = name ~= "" and name or nil;
+ if pos == len then
+ if text == "#" then
+ return self:get_child_text(name, xmlns);
+ end
+ return self:get_child(name, xmlns);
+ end
+ self = self:get_child(name, xmlns);
+ until not self
+end
+
+
local xml_escape
do
local escape_table = { ["'"] = "&apos;", ["\""] = "&quot;", ["<"] = "&lt;", [">"] = "&gt;", ["&"] = "&amp;" };