aboutsummaryrefslogtreecommitdiffstats
path: root/util
diff options
context:
space:
mode:
Diffstat (limited to 'util')
-rw-r--r--util/events.lua6
-rw-r--r--util/ip.lua52
-rw-r--r--util/iterators.lua36
-rw-r--r--util/sasl.lua1
-rw-r--r--util/sasl/external.lua25
-rw-r--r--util/sql.lua8
6 files changed, 111 insertions, 17 deletions
diff --git a/util/events.lua b/util/events.lua
index 412acccd..4ace026f 100644
--- a/util/events.lua
+++ b/util/events.lua
@@ -60,11 +60,11 @@ function new()
remove_handler(event, handler);
end
end;
- local function fire_event(event, ...)
- local h = handlers[event];
+ local function fire_event(event_name, event_data)
+ local h = handlers[event_name];
if h then
for i=1,#h do
- local ret = h[i](...);
+ local ret = h[i](event_data);
if ret ~= nil then return ret; end
end
end
diff --git a/util/ip.lua b/util/ip.lua
index 856bf034..62649c9b 100644
--- a/util/ip.lua
+++ b/util/ip.lua
@@ -12,7 +12,17 @@ local ip_mt = { __index = function (ip, key) return (ip_methods[key])(ip); end,
local hex2bits = { ["0"] = "0000", ["1"] = "0001", ["2"] = "0010", ["3"] = "0011", ["4"] = "0100", ["5"] = "0101", ["6"] = "0110", ["7"] = "0111", ["8"] = "1000", ["9"] = "1001", ["A"] = "1010", ["B"] = "1011", ["C"] = "1100", ["D"] = "1101", ["E"] = "1110", ["F"] = "1111" };
local function new_ip(ipStr, proto)
- if proto ~= "IPv4" and proto ~= "IPv6" then
+ if not proto then
+ local sep = ipStr:match("^%x+(.)");
+ if sep == ":" or (not(sep) and ipStr:sub(1,1) == ":") then
+ proto = "IPv6"
+ elseif sep == "." then
+ proto = "IPv4"
+ end
+ if not proto then
+ return nil, "invalid address";
+ end
+ elseif proto ~= "IPv4" and proto ~= "IPv6" then
return nil, "invalid protocol";
end
if proto == "IPv6" and ipStr:find('.', 1, true) then
@@ -192,5 +202,43 @@ function ip_methods:scope()
return value;
end
+function ip_methods:private()
+ local private = self.scope ~= 0xE;
+ if not private and self.proto == "IPv4" then
+ local ip = self.addr;
+ local fields = {};
+ ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end);
+ if fields[1] == 127 or fields[1] == 10 or (fields[1] == 192 and fields[2] == 168)
+ or (fields[1] == 172 and (fields[2] >= 16 or fields[2] <= 32)) then
+ private = true;
+ end
+ end
+ self.private = private;
+ return private;
+end
+
+local function parse_cidr(cidr)
+ local bits;
+ local ip_len = cidr:find("/", 1, true);
+ if ip_len then
+ bits = tonumber(cidr:sub(ip_len+1, -1));
+ cidr = cidr:sub(1, ip_len-1);
+ end
+ return new_ip(cidr), bits;
+end
+
+local function match(ipA, ipB, bits)
+ local common_bits = commonPrefixLength(ipA, ipB);
+ if not bits then
+ return ipA == ipB;
+ end
+ if bits and ipB.proto == "IPv4" then
+ common_bits = common_bits - 96; -- v6 mapped addresses always share these bits
+ end
+ return common_bits >= bits;
+end
+
return {new_ip = new_ip,
- commonPrefixLength = commonPrefixLength};
+ commonPrefixLength = commonPrefixLength,
+ parse_cidr = parse_cidr,
+ match=match};
diff --git a/util/iterators.lua b/util/iterators.lua
index 1f6aacb8..4b429163 100644
--- a/util/iterators.lua
+++ b/util/iterators.lua
@@ -10,6 +10,10 @@
local it = {};
+local t_insert = table.insert;
+local select, unpack, next = select, unpack, next;
+local function pack(...) return { n = select("#", ...), ... }; end
+
-- Reverse an iterator
function it.reverse(f, s, var)
local results = {};
@@ -19,7 +23,7 @@ function it.reverse(f, s, var)
local ret = { f(s, var) };
var = ret[1];
if var == nil then break; end
- table.insert(results, 1, ret);
+ t_insert(results, 1, ret);
end
-- Then return our reverse one
@@ -55,12 +59,12 @@ function it.unique(f, s, var)
return function ()
while true do
- local ret = { f(s, var) };
+ local ret = pack(f(s, var));
var = ret[1];
if var == nil then break; end
if not set[var] then
set[var] = true;
- return var;
+ return unpack(ret, 1, ret.n);
end
end
end;
@@ -71,8 +75,7 @@ function it.count(f, s, var)
local x = 0;
while true do
- local ret = { f(s, var) };
- var = ret[1];
+ var = f(s, var);
if var == nil then break; end
x = x + 1;
end
@@ -104,7 +107,7 @@ end
function it.tail(n, f, s, var)
local results, count = {}, 0;
while true do
- local ret = { f(s, var) };
+ local ret = pack(f(s, var));
var = ret[1];
if var == nil then break; end
results[(count%n)+1] = ret;
@@ -117,9 +120,24 @@ function it.tail(n, f, s, var)
return function ()
pos = pos + 1;
if pos > n then return nil; end
- return unpack(results[((count-1+pos)%n)+1]);
+ local ret = results[((count-1+pos)%n)+1];
+ return unpack(ret, 1, ret.n);
end
- --return reverse(head(n, reverse(f, s, var)));
+ --return reverse(head(n, reverse(f, s, var))); -- !
+end
+
+function it.filter(filter, f, s, var)
+ if type(filter) ~= "function" then
+ local filter_value = filter;
+ function filter(x) return x ~= filter_value; end
+ end
+ return function (s, var)
+ local ret;
+ repeat ret = pack(f(s, var));
+ var = ret[1];
+ until var == nil or filter(unpack(ret, 1, ret.n));
+ return unpack(ret, 1, ret.n);
+ end, s, var;
end
local function _ripairs_iter(t, key) if key > 1 then return key-1, t[key-1]; end end
@@ -139,7 +157,7 @@ function it.to_array(f, s, var)
while true do
var = f(s, var);
if var == nil then break; end
- table.insert(t, var);
+ t_insert(t, var);
end
return t;
end
diff --git a/util/sasl.lua b/util/sasl.lua
index afb3861b..d0da9435 100644
--- a/util/sasl.lua
+++ b/util/sasl.lua
@@ -92,5 +92,6 @@ require "util.sasl.plain" .init(registerMechanism);
require "util.sasl.digest-md5".init(registerMechanism);
require "util.sasl.anonymous" .init(registerMechanism);
require "util.sasl.scram" .init(registerMechanism);
+require "util.sasl.external" .init(registerMechanism);
return _M;
diff --git a/util/sasl/external.lua b/util/sasl/external.lua
new file mode 100644
index 00000000..4c5c4343
--- /dev/null
+++ b/util/sasl/external.lua
@@ -0,0 +1,25 @@
+local saslprep = require "util.encodings".stringprep.saslprep;
+
+module "sasl.external"
+
+local function external(self, message)
+ message = saslprep(message);
+ local state
+ self.username, state = self.profile.external(message);
+
+ if state == false then
+ return "failure", "account-disabled";
+ elseif state == nil then
+ return "failure", "not-authorized";
+ elseif state == "expired" then
+ return "false", "credentials-expired";
+ end
+
+ return "success";
+end
+
+function init(registerMechanism)
+ registerMechanism("EXTERNAL", {"external"}, external);
+end
+
+return _M;
diff --git a/util/sql.lua b/util/sql.lua
index f360d6d0..63c399ff 100644
--- a/util/sql.lua
+++ b/util/sql.lua
@@ -177,8 +177,8 @@ function engine:execute(sql, ...)
end
local result_mt = { __index = {
- affected = function(self) return self.__affected; end;
- rowcount = function(self) return self.__rowcount; end;
+ affected = function(self) return self.__stmt:affected(); end;
+ rowcount = function(self) return self.__stmt:rowcount(); end;
} };
function engine:execute_query(sql, ...)
@@ -200,7 +200,7 @@ function engine:execute_update(sql, ...)
prepared[sql] = stmt;
end
assert(stmt:execute(...));
- return setmetatable({ __affected = stmt:affected(), __rowcount = stmt:rowcount() }, result_mt);
+ return setmetatable({ __stmt = stmt }, result_mt);
end
engine.insert = engine.execute_update;
engine.select = engine.execute_query;
@@ -264,6 +264,8 @@ function engine:_create_table(table)
sql = sql.. ");"
if self.params.driver == "PostgreSQL" then
sql = sql:gsub("`", "\"");
+ elseif self.params.driver == "MySQL" then
+ sql = sql:gsub(";$", " CHARACTER SET 'utf8' COLLATE 'utf8_bin';");
end
local success,err = self:execute(sql);
if not success then return success,err; end