diff options
Diffstat (limited to 'util')
-rw-r--r-- | util/array.lua | 6 | ||||
-rw-r--r-- | util/async.lua | 158 | ||||
-rw-r--r-- | util/caps.lua | 2 | ||||
-rw-r--r-- | util/dataforms.lua | 8 | ||||
-rw-r--r-- | util/datetime.lua | 2 | ||||
-rw-r--r-- | util/debug.lua | 38 | ||||
-rw-r--r-- | util/dependencies.lua | 20 | ||||
-rw-r--r-- | util/events.lua | 8 | ||||
-rw-r--r-- | util/filters.lua | 16 | ||||
-rw-r--r-- | util/helpers.lua | 2 | ||||
-rw-r--r-- | util/hmac.lua | 2 | ||||
-rw-r--r-- | util/import.lua | 2 | ||||
-rw-r--r-- | util/ip.lua | 54 | ||||
-rw-r--r-- | util/iterators.lua | 46 | ||||
-rw-r--r-- | util/jid.lua | 2 | ||||
-rw-r--r-- | util/json.lua | 4 | ||||
-rw-r--r-- | util/logger.lua | 2 | ||||
-rw-r--r-- | util/multitable.lua | 2 | ||||
-rw-r--r-- | util/pluginloader.lua | 2 | ||||
-rw-r--r-- | util/prosodyctl.lua | 30 | ||||
-rw-r--r-- | util/pubsub.lua | 10 | ||||
-rw-r--r-- | util/sasl.lua | 1 | ||||
-rw-r--r-- | util/sasl/external.lua | 25 | ||||
-rw-r--r-- | util/sasl/scram.lua | 28 | ||||
-rw-r--r-- | util/sasl_cyrus.lua | 4 | ||||
-rw-r--r-- | util/serialization.lua | 2 | ||||
-rw-r--r-- | util/set.lua | 38 | ||||
-rw-r--r-- | util/sql.lua | 10 | ||||
-rw-r--r-- | util/stanza.lua | 16 | ||||
-rw-r--r-- | util/termcolours.lua | 2 | ||||
-rw-r--r-- | util/timer.lua | 4 | ||||
-rw-r--r-- | util/uuid.lua | 2 | ||||
-rw-r--r-- | util/xml.lua | 4 | ||||
-rw-r--r-- | util/xmppstream.lua | 24 |
34 files changed, 417 insertions, 159 deletions
diff --git a/util/array.lua b/util/array.lua index 2d58e7fb..9bf215af 100644 --- a/util/array.lua +++ b/util/array.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -59,13 +59,13 @@ function array_base.filter(outa, ina, func) write = write + 1; end end - + if inplace and write <= start_length then for i=write,start_length do outa[i] = nil; end end - + return outa; end diff --git a/util/async.lua b/util/async.lua new file mode 100644 index 00000000..968ec804 --- /dev/null +++ b/util/async.lua @@ -0,0 +1,158 @@ +local log = require "util.logger".init("util.async"); + +local function runner_continue(thread) + -- ASSUMPTION: runner is in 'waiting' state (but we don't have the runner to know for sure) + if coroutine.status(thread) ~= "suspended" then -- This should suffice + return false; + end + local ok, state, runner = coroutine.resume(thread); + if not ok then + local level = 0; + while debug.getinfo(thread, level, "") do level = level + 1; end + ok, runner = debug.getlocal(thread, level-1, 1); + local error_handler = runner.watchers.error; + if error_handler then error_handler(runner, debug.traceback(thread, state)); end + elseif state == "ready" then + -- If state is 'ready', it is our responsibility to update runner.state from 'waiting'. + -- We also have to :run(), because the queue might have further items that will not be + -- processed otherwise. FIXME: It's probably best to do this in a nexttick (0 timer). + runner.state = "ready"; + runner:run(); + end + return true; +end + +local function waiter(num) + local thread = coroutine.running(); + if not thread then + error("Not running in an async context, see http://prosody.im/doc/developers/async"); + end + num = num or 1; + local waiting; + return function () + if num == 0 then return; end -- already done + waiting = true; + coroutine.yield("wait"); + end, function () + num = num - 1; + if num == 0 and waiting then + runner_continue(thread); + elseif num < 0 then + error("done() called too many times"); + end + end; +end + +local function guarder() + local guards = {}; + return function (id, func) + local thread = coroutine.running(); + if not thread then + error("Not running in an async context, see http://prosody.im/doc/developers/async"); + end + local guard = guards[id]; + if not guard then + guard = {}; + guards[id] = guard; + log("debug", "New guard!"); + else + table.insert(guard, thread); + log("debug", "Guarded. %d threads waiting.", #guard) + coroutine.yield("wait"); + end + local function exit() + local next_waiting = table.remove(guard, 1); + if next_waiting then + log("debug", "guard: Executing next waiting thread (%d left)", #guard) + runner_continue(next_waiting); + else + log("debug", "Guard off duty.") + guards[id] = nil; + end + end + if func then + func(); + exit(); + return; + end + return exit; + end; +end + +local runner_mt = {}; +runner_mt.__index = runner_mt; + +local function runner_create_thread(func, self) + local thread = coroutine.create(function (self) + while true do + func(coroutine.yield("ready", self)); + end + end); + assert(coroutine.resume(thread, self)); -- Start it up, it will return instantly to wait for the first input + return thread; +end + +local empty_watchers = {}; +local function runner(func, watchers, data) + return setmetatable({ func = func, thread = false, state = "ready", notified_state = "ready", + queue = {}, watchers = watchers or empty_watchers, data = data } + , runner_mt); +end + +function runner_mt:run(input) + if input ~= nil then + table.insert(self.queue, input); + end + if self.state ~= "ready" then + return true, self.state, #self.queue; + end + + local q, thread = self.queue, self.thread; + if not thread or coroutine.status(thread) == "dead" then + thread = runner_create_thread(self.func, self); + self.thread = thread; + end + + local n, state, err = #q, self.state, nil; + self.state = "running"; + while n > 0 and state == "ready" do + local consumed; + for i = 1,n do + local input = q[i]; + local ok, new_state = coroutine.resume(thread, input); + if not ok then + consumed, state, err = i, "ready", debug.traceback(thread, new_state); + self.thread = nil; + break; + elseif new_state == "wait" then + consumed, state = i, "waiting"; + break; + end + end + if not consumed then consumed = n; end + if q[n+1] ~= nil then + n = #q; + end + for i = 1, n do + q[i] = q[consumed+i]; + end + n = #q; + end + self.state = state; + if err or state ~= self.notified_state then + if err then + state = "error" + else + self.notified_state = state; + end + local handler = self.watchers[state]; + if handler then handler(self, err); end + end + return true, state, n; +end + +function runner_mt:enqueue(input) + table.insert(self.queue, input); +end + +return { waiter = waiter, guarder = guarder, runner = runner }; diff --git a/util/caps.lua b/util/caps.lua index a61e7403..4723b912 100644 --- a/util/caps.lua +++ b/util/caps.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- diff --git a/util/dataforms.lua b/util/dataforms.lua index 52924841..b38d0e27 100644 --- a/util/dataforms.lua +++ b/util/dataforms.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -38,7 +38,7 @@ function form_t.form(layout, data, formtype) form:tag("field", { type = field_type, var = field.name, label = field.label }); local value = (data and data[field.name]) or field.value; - + if value then -- Add value, depending on type if field_type == "hidden" then @@ -93,11 +93,11 @@ function form_t.form(layout, data, formtype) end end end - + if field.required then form:tag("required"):up(); end - + -- Jump back up to list of fields form:up(); end diff --git a/util/datetime.lua b/util/datetime.lua index a1f62a48..dd596527 100644 --- a/util/datetime.lua +++ b/util/datetime.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- diff --git a/util/debug.lua b/util/debug.lua index bff0e347..91f691e1 100644 --- a/util/debug.lua +++ b/util/debug.lua @@ -24,11 +24,15 @@ do end module("debugx", package.seeall); -function get_locals_table(level) - level = level + 1; -- Skip this function itself +function get_locals_table(thread, level) local locals = {}; for local_num = 1, math.huge do - local name, value = debug.getlocal(level, local_num); + local name, value; + if thread then + name, value = debug.getlocal(thread, level, local_num); + else + name, value = debug.getlocal(level+1, local_num); + end if not name then break; end table.insert(locals, { name = name, value = value }); end @@ -88,19 +92,19 @@ function get_traceback_table(thread, start_level) for level = start_level, math.huge do local info; if thread then - info = debug.getinfo(thread, level+1); + info = debug.getinfo(thread, level); else info = debug.getinfo(level+1); end if not info then break; end - + levels[(level-start_level)+1] = { level = level; info = info; - locals = get_locals_table(level+1); + locals = get_locals_table(thread, level+(thread and 0 or 1)); upvalues = get_upvalues_table(info.func); }; - end + end return levels; end @@ -134,15 +138,15 @@ function _traceback(thread, message, level) return nil; -- debug.traceback() does this end - level = level or 1; + level = level or 0; message = message and (message.."\n") or ""; - - -- +3 counts for this function, and the pcall() and wrapper above us - local levels = get_traceback_table(thread, level+3); - + + -- +3 counts for this function, and the pcall() and wrapper above us, the +1... I don't know. + local levels = get_traceback_table(thread, level+(thread == nil and 4 or 0)); + local last_source_desc; - + local lines = {}; for nlevel, level in ipairs(levels) do local info = level.info; @@ -171,9 +175,11 @@ function _traceback(thread, message, level) nlevel = nlevel-1; table.insert(lines, "\t"..(nlevel==0 and ">" or " ")..getstring(styles.level_num, "("..nlevel..") ")..line); local npadding = (" "):rep(#tostring(nlevel)); - local locals_str = string_from_var_table(level.locals, optimal_line_length, "\t "..npadding); - if locals_str then - table.insert(lines, "\t "..npadding.."Locals: "..locals_str); + if level.locals then + local locals_str = string_from_var_table(level.locals, optimal_line_length, "\t "..npadding); + if locals_str then + table.insert(lines, "\t "..npadding.."Locals: "..locals_str); + end end local upvalues_str = string_from_var_table(level.upvalues, optimal_line_length, "\t "..npadding); if upvalues_str then diff --git a/util/dependencies.lua b/util/dependencies.lua index 53d2719d..109a3332 100644 --- a/util/dependencies.lua +++ b/util/dependencies.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -35,7 +35,7 @@ function missingdep(name, sources, msg) print(""); end --- COMPAT w/pre-0.8 Debian: The Debian config file used to use +-- COMPAT w/pre-0.8 Debian: The Debian config file used to use -- util.ztact, which has been removed from Prosody in 0.8. This -- is to log an error for people who still use it, so they can -- update their configs. @@ -50,9 +50,9 @@ end; function check_dependencies() local fatal; - + local lxp = softreq "lxp" - + if not lxp then missingdep("luaexpat", { ["Debian/Ubuntu"] = "sudo apt-get install liblua5.1-expat0"; @@ -61,9 +61,9 @@ function check_dependencies() }); fatal = true; end - + local socket = softreq "socket" - + if not socket then missingdep("luasocket", { ["Debian/Ubuntu"] = "sudo apt-get install liblua5.1-socket2"; @@ -72,7 +72,7 @@ function check_dependencies() }); fatal = true; end - + local lfs, err = softreq "lfs" if not lfs then missingdep("luafilesystem", { @@ -82,9 +82,9 @@ function check_dependencies() }); fatal = true; end - + local ssl = softreq "ssl" - + if not ssl then missingdep("LuaSec", { ["Debian/Ubuntu"] = "http://prosody.im/download/start#debian_and_ubuntu"; @@ -92,7 +92,7 @@ function check_dependencies() ["Source"] = "http://www.inf.puc-rio.br/~brunoos/luasec/"; }, "SSL/TLS support will not be available"); end - + local encodings, err = softreq "util.encodings" if not encodings then if err:match("not found") then diff --git a/util/events.lua b/util/events.lua index 412acccd..40ca3913 100644 --- a/util/events.lua +++ b/util/events.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -60,11 +60,11 @@ function new() remove_handler(event, handler); end end; - local function fire_event(event, ...) - local h = handlers[event]; + local function fire_event(event_name, event_data) + local h = handlers[event_name]; if h then for i=1,#h do - local ret = h[i](...); + local ret = h[i](event_data); if ret ~= nil then return ret; end end end diff --git a/util/filters.lua b/util/filters.lua index d143666b..8a470011 100644 --- a/util/filters.lua +++ b/util/filters.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -16,7 +16,7 @@ function initialize(session) if not session.filters then local filters = {}; session.filters = filters; - + function session.filter(type, data) local filter_list = filters[type]; if filter_list then @@ -28,11 +28,11 @@ function initialize(session) return data; end end - + for i=1,#new_filter_hooks do new_filter_hooks[i](session); end - + return session.filter; end @@ -40,20 +40,20 @@ function add_filter(session, type, callback, priority) if not session.filters then initialize(session); end - + local filter_list = session.filters[type]; if not filter_list then filter_list = {}; session.filters[type] = filter_list; end - + priority = priority or 0; - + local i = 0; repeat i = i + 1; until not filter_list[i] or filter_list[filter_list[i]] >= priority; - + t_insert(filter_list, i, callback); filter_list[callback] = priority; end diff --git a/util/helpers.lua b/util/helpers.lua index 08b86a7c..437a920c 100644 --- a/util/helpers.lua +++ b/util/helpers.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- diff --git a/util/hmac.lua b/util/hmac.lua index 51211c7a..2c4cc6ef 100644 --- a/util/hmac.lua +++ b/util/hmac.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- diff --git a/util/import.lua b/util/import.lua index 81401e8b..174da0ca 100644 --- a/util/import.lua +++ b/util/import.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- diff --git a/util/ip.lua b/util/ip.lua index 856bf034..d0ae07eb 100644 --- a/util/ip.lua +++ b/util/ip.lua @@ -12,7 +12,17 @@ local ip_mt = { __index = function (ip, key) return (ip_methods[key])(ip); end, local hex2bits = { ["0"] = "0000", ["1"] = "0001", ["2"] = "0010", ["3"] = "0011", ["4"] = "0100", ["5"] = "0101", ["6"] = "0110", ["7"] = "0111", ["8"] = "1000", ["9"] = "1001", ["A"] = "1010", ["B"] = "1011", ["C"] = "1100", ["D"] = "1101", ["E"] = "1110", ["F"] = "1111" }; local function new_ip(ipStr, proto) - if proto ~= "IPv4" and proto ~= "IPv6" then + if not proto then + local sep = ipStr:match("^%x+(.)"); + if sep == ":" or (not(sep) and ipStr:sub(1,1) == ":") then + proto = "IPv6" + elseif sep == "." then + proto = "IPv4" + end + if not proto then + return nil, "invalid address"; + end + elseif proto ~= "IPv4" and proto ~= "IPv6" then return nil, "invalid protocol"; end if proto == "IPv6" and ipStr:find('.', 1, true) then @@ -82,7 +92,7 @@ local function v6scope(ip) if ip:match("^[0:]*1$") then return 0x2; -- Link-local unicast: - elseif ip:match("^[Ff][Ee][89ABab]") then + elseif ip:match("^[Ff][Ee][89ABab]") then return 0x2; -- Site-local unicast: elseif ip:match("^[Ff][Ee][CcDdEeFf]") then @@ -192,5 +202,43 @@ function ip_methods:scope() return value; end +function ip_methods:private() + local private = self.scope ~= 0xE; + if not private and self.proto == "IPv4" then + local ip = self.addr; + local fields = {}; + ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end); + if fields[1] == 127 or fields[1] == 10 or (fields[1] == 192 and fields[2] == 168) + or (fields[1] == 172 and (fields[2] >= 16 or fields[2] <= 32)) then + private = true; + end + end + self.private = private; + return private; +end + +local function parse_cidr(cidr) + local bits; + local ip_len = cidr:find("/", 1, true); + if ip_len then + bits = tonumber(cidr:sub(ip_len+1, -1)); + cidr = cidr:sub(1, ip_len-1); + end + return new_ip(cidr), bits; +end + +local function match(ipA, ipB, bits) + local common_bits = commonPrefixLength(ipA, ipB); + if not bits then + return ipA == ipB; + end + if bits and ipB.proto == "IPv4" then + common_bits = common_bits - 96; -- v6 mapped addresses always share these bits + end + return common_bits >= bits; +end + return {new_ip = new_ip, - commonPrefixLength = commonPrefixLength}; + commonPrefixLength = commonPrefixLength, + parse_cidr = parse_cidr, + match=match}; diff --git a/util/iterators.lua b/util/iterators.lua index 1f6aacb8..aa9c3ec0 100644 --- a/util/iterators.lua +++ b/util/iterators.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -10,6 +10,10 @@ local it = {}; +local t_insert = table.insert; +local select, unpack, next = select, unpack, next; +local function pack(...) return { n = select("#", ...), ... }; end + -- Reverse an iterator function it.reverse(f, s, var) local results = {}; @@ -19,9 +23,9 @@ function it.reverse(f, s, var) local ret = { f(s, var) }; var = ret[1]; if var == nil then break; end - table.insert(results, 1, ret); + t_insert(results, 1, ret); end - + -- Then return our reverse one local i,max = 0, #results; return function (results) @@ -52,15 +56,15 @@ end -- Given an iterator, iterate only over unique items function it.unique(f, s, var) local set = {}; - + return function () while true do - local ret = { f(s, var) }; + local ret = pack(f(s, var)); var = ret[1]; if var == nil then break; end if not set[var] then set[var] = true; - return var; + return unpack(ret, 1, ret.n); end end end; @@ -69,14 +73,13 @@ end --[[ Return the number of items an iterator returns ]]-- function it.count(f, s, var) local x = 0; - + while true do - local ret = { f(s, var) }; - var = ret[1]; + var = f(s, var); if var == nil then break; end x = x + 1; end - + return x; end @@ -104,7 +107,7 @@ end function it.tail(n, f, s, var) local results, count = {}, 0; while true do - local ret = { f(s, var) }; + local ret = pack(f(s, var)); var = ret[1]; if var == nil then break; end results[(count%n)+1] = ret; @@ -117,9 +120,24 @@ function it.tail(n, f, s, var) return function () pos = pos + 1; if pos > n then return nil; end - return unpack(results[((count-1+pos)%n)+1]); + local ret = results[((count-1+pos)%n)+1]; + return unpack(ret, 1, ret.n); end - --return reverse(head(n, reverse(f, s, var))); + --return reverse(head(n, reverse(f, s, var))); -- ! +end + +function it.filter(filter, f, s, var) + if type(filter) ~= "function" then + local filter_value = filter; + function filter(x) return x ~= filter_value; end + end + return function (s, var) + local ret; + repeat ret = pack(f(s, var)); + var = ret[1]; + until var == nil or filter(unpack(ret, 1, ret.n)); + return unpack(ret, 1, ret.n); + end, s, var; end local function _ripairs_iter(t, key) if key > 1 then return key-1, t[key-1]; end end @@ -139,7 +157,7 @@ function it.to_array(f, s, var) while true do var = f(s, var); if var == nil then break; end - table.insert(t, var); + t_insert(t, var); end return t; end diff --git a/util/jid.lua b/util/jid.lua index 4c4371d8..0d9a864f 100644 --- a/util/jid.lua +++ b/util/jid.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- diff --git a/util/json.lua b/util/json.lua index 82ebcc43..a8a58afc 100644 --- a/util/json.lua +++ b/util/json.lua @@ -348,9 +348,9 @@ local first_escape = { function json.decode(json) json = json:gsub("\\.", first_escape) -- get rid of all escapes except \uXXXX, making string parsing much simpler --:gsub("[\r\n]", "\t"); -- \r\n\t are equivalent, we care about none of them, and none of them can be in strings - + -- TODO do encoding verification - + local val, index = _readvalue(json, 1); if val == nil then return val, index; end if json:find("[^ \t\r\n]", index) then return nil, "garbage at eof"; end diff --git a/util/logger.lua b/util/logger.lua index 26206d4d..cd0769f9 100644 --- a/util/logger.lua +++ b/util/logger.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- diff --git a/util/multitable.lua b/util/multitable.lua index dbf34d28..caf25118 100644 --- a/util/multitable.lua +++ b/util/multitable.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- diff --git a/util/pluginloader.lua b/util/pluginloader.lua index c10fdf65..b894f527 100644 --- a/util/pluginloader.lua +++ b/util/pluginloader.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- diff --git a/util/prosodyctl.lua b/util/prosodyctl.lua index b80a69f2..fe862114 100644 --- a/util/prosodyctl.lua +++ b/util/prosodyctl.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -146,7 +146,7 @@ function adduser(params) if not(provider) or provider.name == "null" then usermanager.initialize_host(host); end - + local ok, errmsg = usermanager.create_user(user, password, host); if not ok then return false, errmsg; @@ -162,7 +162,7 @@ function user_exists(params) if not(provider) or provider.name == "null" then usermanager.initialize_host(host); end - + return usermanager.user_exists(user, host); end @@ -170,7 +170,7 @@ function passwd(params) if not _M.user_exists(params) then return false, "no-such-user"; end - + return _M.adduser(params); end @@ -179,7 +179,7 @@ function deluser(params) return false, "no-such-user"; end local user, host = nodeprep(params.user), nameprep(params.host); - + return usermanager.delete_user(user, host); end @@ -188,30 +188,30 @@ function getpid() if not pidfile then return false, "no-pidfile"; end - + local modules_enabled = set.new(config.get("*", "modules_enabled")); if not modules_enabled:contains("posix") then return false, "no-posix"; end - + local file, err = io.open(pidfile, "r+"); if not file then return false, "pidfile-read-failed", err; end - + local locked, err = lfs.lock(file, "w"); if locked then file:close(); return false, "pidfile-not-locked"; end - + local pid = tonumber(file:read("*a")); file:close(); - + if not pid then return false, "invalid-pid"; end - + return true, pid; end @@ -252,10 +252,10 @@ function stop() if not ret then return false, "not-running"; end - + local ok, pid = _M.getpid() if not ok then return false, pid; end - + signal.kill(pid, signal.SIGTERM); return true; end @@ -268,10 +268,10 @@ function reload() if not ret then return false, "not-running"; end - + local ok, pid = _M.getpid() if not ok then return false, pid; end - + signal.kill(pid, signal.SIGHUP); return true; end diff --git a/util/pubsub.lua b/util/pubsub.lua index e1418c62..0dfd196b 100644 --- a/util/pubsub.lua +++ b/util/pubsub.lua @@ -29,13 +29,13 @@ end function service:may(node, actor, action) if actor == true then return true; end - + local node_obj = self.nodes[node]; local node_aff = node_obj and node_obj.affiliations[actor]; local service_aff = self.affiliations[actor] or self.config.get_affiliation(actor, node, action) or "none"; - + -- Check if node allows/forbids it local node_capabilities = node_obj and node_obj.capabilities; if node_capabilities then @@ -47,7 +47,7 @@ function service:may(node, actor, action) end end end - + -- Check service-wide capabilities instead local service_capabilities = self.config.capabilities; local caps = service_capabilities[node_aff or service_aff]; @@ -57,7 +57,7 @@ function service:may(node, actor, action) return can; end end - + return false; end @@ -211,7 +211,7 @@ function service:create(node, actor) if self.nodes[node] then return false, "conflict"; end - + self.nodes[node] = { name = node; subscribers = {}; diff --git a/util/sasl.lua b/util/sasl.lua index afb3861b..d0da9435 100644 --- a/util/sasl.lua +++ b/util/sasl.lua @@ -92,5 +92,6 @@ require "util.sasl.plain" .init(registerMechanism); require "util.sasl.digest-md5".init(registerMechanism); require "util.sasl.anonymous" .init(registerMechanism); require "util.sasl.scram" .init(registerMechanism); +require "util.sasl.external" .init(registerMechanism); return _M; diff --git a/util/sasl/external.lua b/util/sasl/external.lua new file mode 100644 index 00000000..4c5c4343 --- /dev/null +++ b/util/sasl/external.lua @@ -0,0 +1,25 @@ +local saslprep = require "util.encodings".stringprep.saslprep; + +module "sasl.external" + +local function external(self, message) + message = saslprep(message); + local state + self.username, state = self.profile.external(message); + + if state == false then + return "failure", "account-disabled"; + elseif state == nil then + return "failure", "not-authorized"; + elseif state == "expired" then + return "false", "credentials-expired"; + end + + return "success"; +end + +function init(registerMechanism) + registerMechanism("EXTERNAL", {"external"}, external); +end + +return _M; diff --git a/util/sasl/scram.lua b/util/sasl/scram.lua index cf2f0ede..31c078a0 100644 --- a/util/sasl/scram.lua +++ b/util/sasl/scram.lua @@ -73,11 +73,11 @@ local function validate_username(username, _nodeprep) return false end end - + -- replace =2C with , and =3D with = username = username:gsub("=2C", ","); username = username:gsub("=3D", "="); - + -- apply SASLprep username = saslprep(username); @@ -108,12 +108,12 @@ end local function scram_gen(hash_name, H_f, HMAC_f) local function scram_hash(self, message) if not self.state then self["state"] = {} end - + if type(message) ~= "string" or #message == 0 then return "failure", "malformed-request" end if not self.state.name then -- we are processing client_first_message local client_first_message = message; - + -- TODO: fail if authzid is provided, since we don't support them yet self.state["client_first_message"] = client_first_message; self.state["gs2_cbind_flag"], self.state["authzid"], self.state["name"], self.state["clientnonce"] @@ -127,21 +127,21 @@ local function scram_gen(hash_name, H_f, HMAC_f) if not self.state.name or not self.state.clientnonce then return "failure", "malformed-request", "Channel binding isn't support at this time."; end - + self.state.name = validate_username(self.state.name, self.profile.nodeprep); if not self.state.name then log("debug", "Username violates either SASLprep or contains forbidden character sequences.") return "failure", "malformed-request", "Invalid username."; end - + self.state["servernonce"] = generate_uuid(); - + -- retreive credentials if self.profile.plain then local password, state = self.profile.plain(self, self.state.name, self.realm) if state == nil then return "failure", "not-authorized" elseif state == false then return "failure", "account-disabled" end - + password = saslprep(password); if not password then log("debug", "Password violates SASLprep."); @@ -161,22 +161,22 @@ local function scram_gen(hash_name, H_f, HMAC_f) local stored_key, server_key, iteration_count, salt, state = self.profile["scram_"..hashprep(hash_name)](self, self.state.name, self.realm); if state == nil then return "failure", "not-authorized" elseif state == false then return "failure", "account-disabled" end - + self.state.stored_key = stored_key; self.state.server_key = server_key; self.state.iteration_count = iteration_count; self.state.salt = salt end - + local server_first_message = "r="..self.state.clientnonce..self.state.servernonce..",s="..base64.encode(self.state.salt)..",i="..self.state.iteration_count; self.state["server_first_message"] = server_first_message; return "challenge", server_first_message else -- we are processing client_final_message local client_final_message = message; - + self.state["channelbinding"], self.state["nonce"], self.state["proof"] = client_final_message:match("^c=(.*),r=(.*),.*p=(.*)"); - + if not self.state.proof or not self.state.nonce or not self.state.channelbinding then return "failure", "malformed-request", "Missing an attribute(p, r or c) in SASL message."; end @@ -184,10 +184,10 @@ local function scram_gen(hash_name, H_f, HMAC_f) if self.state.nonce ~= self.state.clientnonce..self.state.servernonce then return "failure", "malformed-request", "Wrong nonce in client-final-message."; end - + local ServerKey = self.state.server_key; local StoredKey = self.state.stored_key; - + local AuthMessage = "n=" .. s_match(self.state.client_first_message,"n=(.+)") .. "," .. self.state.server_first_message .. "," .. s_match(client_final_message, "(.+),p=.+") local ClientSignature = HMAC_f(StoredKey, AuthMessage) local ClientKey = binaryXOR(ClientSignature, base64.decode(self.state.proof)) diff --git a/util/sasl_cyrus.lua b/util/sasl_cyrus.lua index 19684587..a0e8bd69 100644 --- a/util/sasl_cyrus.lua +++ b/util/sasl_cyrus.lua @@ -78,10 +78,10 @@ local function init(service_name) end -- create a new SASL object which can be used to authenticate clients --- host_fqdn may be nil in which case gethostname() gives the value. +-- host_fqdn may be nil in which case gethostname() gives the value. -- For GSSAPI, this determines the hostname in the service ticket (after -- reverse DNS canonicalization, only if [libdefaults] rdns = true which --- is the default). +-- is the default). function new(realm, service_name, app_name, host_fqdn) init(app_name or service_name); diff --git a/util/serialization.lua b/util/serialization.lua index 8a259184..06e45054 100644 --- a/util/serialization.lua +++ b/util/serialization.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- diff --git a/util/set.lua b/util/set.lua index b9e9ef21..89cd7cf3 100644 --- a/util/set.lua +++ b/util/set.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -40,13 +40,13 @@ function set_mt.__eq(set1, set2) return false; end end - + for item in pairs(set2) do if not set1[item] then return false; end end - + return true; end function set_mt.__tostring(set) @@ -65,23 +65,23 @@ end function new(list) local items = setmetatable({}, items_mt); local set = { _items = items }; - + function set:add(item) items[item] = true; end - + function set:contains(item) return items[item]; end - + function set:items() - return items; + return next, items; end - + function set:remove(item) items[item] = nil; end - + function set:add_list(list) if list then for _, item in ipairs(list) do @@ -89,7 +89,7 @@ function new(list) end end end - + function set:include(otherset) for item in otherset do items[item] = true; @@ -101,22 +101,22 @@ function new(list) items[item] = nil; end end - + function set:empty() return not next(items); end - + if list then set:add_list(list); end - + return setmetatable(set, set_mt); end function union(set1, set2) local set = new(); local items = set._items; - + for item in pairs(set1._items) do items[item] = true; end @@ -124,14 +124,14 @@ function union(set1, set2) for item in pairs(set2._items) do items[item] = true; end - + return set; end function difference(set1, set2) local set = new(); local items = set._items; - + for item in pairs(set1._items) do items[item] = (not set2._items[item]) or nil; end @@ -142,13 +142,13 @@ end function intersection(set1, set2) local set = new(); local items = set._items; - + set1, set2 = set1._items, set2._items; - + for item in pairs(set1) do items[item] = (not not set2[item]) or nil; end - + return set; end diff --git a/util/sql.lua b/util/sql.lua index f360d6d0..b8c16e27 100644 --- a/util/sql.lua +++ b/util/sql.lua @@ -45,7 +45,7 @@ function String(n) return "String()" end }; local functions = { - + }; local cmap = { @@ -177,8 +177,8 @@ function engine:execute(sql, ...) end local result_mt = { __index = { - affected = function(self) return self.__affected; end; - rowcount = function(self) return self.__rowcount; end; + affected = function(self) return self.__stmt:affected(); end; + rowcount = function(self) return self.__stmt:rowcount(); end; } }; function engine:execute_query(sql, ...) @@ -200,7 +200,7 @@ function engine:execute_update(sql, ...) prepared[sql] = stmt; end assert(stmt:execute(...)); - return setmetatable({ __affected = stmt:affected(), __rowcount = stmt:rowcount() }, result_mt); + return setmetatable({ __stmt = stmt }, result_mt); end engine.insert = engine.execute_update; engine.select = engine.execute_query; @@ -264,6 +264,8 @@ function engine:_create_table(table) sql = sql.. ");" if self.params.driver == "PostgreSQL" then sql = sql:gsub("`", "\""); + elseif self.params.driver == "MySQL" then + sql = sql:gsub(";$", " CHARACTER SET 'utf8' COLLATE 'utf8_bin';"); end local success,err = self:execute(sql); if not success then return success,err; end diff --git a/util/stanza.lua b/util/stanza.lua index 7c214210..82601e63 100644 --- a/util/stanza.lua +++ b/util/stanza.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -99,7 +99,7 @@ function stanza_mt:get_child(name, xmlns) if (not name or child.name == name) and ((not xmlns and self.attr.xmlns == child.attr.xmlns) or child.attr.xmlns == xmlns) then - + return child; end end @@ -152,7 +152,7 @@ end function stanza_mt:maptags(callback) local tags, curr_tag = self.tags, 1; local n_children, n_tags = #self, #tags; - + local i = 1; while curr_tag <= n_tags and n_tags > 0 do if self[i] == tags[curr_tag] then @@ -258,13 +258,13 @@ end function stanza_mt.get_error(stanza) local type, condition, text; - + local error_tag = stanza:get_child("error"); if not error_tag then return nil, nil, nil; end type = error_tag.attr.type; - + for _, child in ipairs(error_tag.tags) do if child.attr.xmlns == xmlns_stanzas then if not text and child.name == "text" then @@ -333,7 +333,7 @@ function deserialize(stanza) stanza.tags = tags; end end - + return stanza; end @@ -390,7 +390,7 @@ if do_pretty_printing then local style_attrv = getstyle("red"); local style_tagname = getstyle("red"); local style_punc = getstyle("magenta"); - + local attr_format = " "..getstring(style_attrk, "%s")..getstring(style_punc, "=")..getstring(style_attrv, "'%s'"); local top_tag_format = getstring(style_punc, "<")..getstring(style_tagname, "%s").."%s"..getstring(style_punc, ">"); --local tag_format = getstring(style_punc, "<")..getstring(style_tagname, "%s").."%s"..getstring(style_punc, ">").."%s"..getstring(style_punc, "</")..getstring(style_tagname, "%s")..getstring(style_punc, ">"); @@ -411,7 +411,7 @@ if do_pretty_printing then end return s_format(tag_format, t.name, attr_string, children_text, t.name); end - + function stanza_mt.pretty_top_tag(t) local attr_string = ""; if t.attr then diff --git a/util/termcolours.lua b/util/termcolours.lua index 6ef3b689..ef978364 100644 --- a/util/termcolours.lua +++ b/util/termcolours.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- diff --git a/util/timer.lua b/util/timer.lua index af1e57b6..0e10e144 100644 --- a/util/timer.lua +++ b/util/timer.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -42,7 +42,7 @@ if not server.event then end new_data = {}; end - + local next_time = math_huge; for i, d in pairs(data) do local t, callback = d[1], d[2]; diff --git a/util/uuid.lua b/util/uuid.lua index 796c8ee4..fc487c72 100644 --- a/util/uuid.lua +++ b/util/uuid.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- diff --git a/util/xml.lua b/util/xml.lua index 076490fa..6dbed65d 100644 --- a/util/xml.lua +++ b/util/xml.lua @@ -26,8 +26,8 @@ local parse_xml = (function() attr[i] = nil; local ns, nm = k:match(ns_pattern); if nm ~= "" then - ns = ns_prefixes[ns]; - if ns then + ns = ns_prefixes[ns]; + if ns then attr[ns..":"..nm] = attr[k]; attr[k] = nil; end diff --git a/util/xmppstream.lua b/util/xmppstream.lua index 4909678c..550170c9 100644 --- a/util/xmppstream.lua +++ b/util/xmppstream.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -42,21 +42,21 @@ _M.ns_pattern = ns_pattern; function new_sax_handlers(session, stream_callbacks) local xml_handlers = {}; - + local cb_streamopened = stream_callbacks.streamopened; local cb_streamclosed = stream_callbacks.streamclosed; local cb_error = stream_callbacks.error or function(session, e, stanza) error("XML stream error: "..tostring(e)..(stanza and ": "..tostring(stanza) or ""),2); end; local cb_handlestanza = stream_callbacks.handlestanza; - + local stream_ns = stream_callbacks.stream_ns or xmlns_streams; local stream_tag = stream_callbacks.stream_tag or "stream"; if stream_ns ~= "" then stream_tag = stream_ns..ns_separator..stream_tag; end local stream_error_tag = stream_ns..ns_separator..(stream_callbacks.error_tag or "error"); - + local stream_default_ns = stream_callbacks.default_ns; - + local stack = {}; local chardata, stanza = {}; local non_streamns_depth = 0; @@ -75,7 +75,7 @@ function new_sax_handlers(session, stream_callbacks) attr.xmlns = curr_ns; non_streamns_depth = non_streamns_depth + 1; end - + for i=1,#attr do local k = attr[i]; attr[i] = nil; @@ -85,7 +85,7 @@ function new_sax_handlers(session, stream_callbacks) attr[k] = nil; end end - + if not stanza then --if we are not currently inside a stanza if session.notopen then if tagname == stream_tag then @@ -102,7 +102,7 @@ function new_sax_handlers(session, stream_callbacks) if curr_ns == "jabber:client" and name ~= "iq" and name ~= "presence" and name ~= "message" then cb_error(session, "invalid-top-level-element"); end - + stanza = setmetatable({ name = name, attr = attr, tags = {} }, stanza_mt); else -- we are inside a stanza, so add a tag t_insert(stack, stanza); @@ -151,22 +151,22 @@ function new_sax_handlers(session, stream_callbacks) error("Failed to abort parsing"); end end - + if lxp_supports_doctype then xml_handlers.StartDoctypeDecl = restricted_handler; end xml_handlers.Comment = restricted_handler; xml_handlers.ProcessingInstruction = restricted_handler; - + local function reset() stanza, chardata = nil, {}; stack = {}; end - + local function set_session(stream, new_session) session = new_session; end - + return xml_handlers, { reset = reset, set_session = set_session }; end |