diff options
Diffstat (limited to 'util')
50 files changed, 2857 insertions, 1206 deletions
diff --git a/util/adhoc.lua b/util/adhoc.lua new file mode 100644 index 00000000..671e85cf --- /dev/null +++ b/util/adhoc.lua @@ -0,0 +1,31 @@ +local function new_simple_form(form, result_handler) + return function(self, data, state) + if state then + if data.action == "cancel" then + return { status = "canceled" }; + end + local fields, err = form:data(data.form); + return result_handler(fields, err, data); + else + return { status = "executing", actions = {"next", "complete", default = "complete"}, form = form }, "executing"; + end + end +end + +local function new_initial_data_form(form, initial_data, result_handler) + return function(self, data, state) + if state then + if data.action == "cancel" then + return { status = "canceled" }; + end + local fields, err = form:data(data.form); + return result_handler(fields, err, data); + else + return { status = "executing", actions = {"next", "complete", default = "complete"}, + form = { layout = form, values = initial_data() } }, "executing"; + end + end +end + +return { new_simple_form = new_simple_form, + new_initial_data_form = new_initial_data_form }; diff --git a/util/array.lua b/util/array.lua index 6c1f0460..9bf215af 100644 --- a/util/array.lua +++ b/util/array.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -9,12 +9,20 @@ local t_insert, t_sort, t_remove, t_concat = table.insert, table.sort, table.remove, table.concat; +local setmetatable = setmetatable; +local math_random = math.random; +local pairs, ipairs = pairs, ipairs; +local tostring = tostring; + local array = {}; local array_base = {}; local array_methods = {}; -local array_mt = { __index = array_methods, __tostring = function (array) return array:concat(", "); end }; +local array_mt = { __index = array_methods, __tostring = function (array) return "{"..array:concat(", ").."}"; end }; -local function new_array(_, t) +local function new_array(self, t, _s, _var) + if type(t) == "function" then -- Assume iterator + t = self.collect(t, _s, _var); + end return setmetatable(t or {}, array_mt); end @@ -25,6 +33,15 @@ end setmetatable(array, { __call = new_array }); +-- Read-only methods +function array_methods:random() + return self[math_random(1,#self)]; +end + +-- These methods can be called two ways: +-- array.method(existing_array, [params [, ...]]) -- Create new array for result +-- existing_array:method([params, ...]) -- Transform existing array into result +-- function array_base.map(outa, ina, func) for k,v in ipairs(ina) do outa[k] = func(v); @@ -42,13 +59,13 @@ function array_base.filter(outa, ina, func) write = write + 1; end end - + if inplace and write <= start_length then for i=write,start_length do outa[i] = nil; end end - + return outa; end @@ -60,15 +77,18 @@ function array_base.sort(outa, ina, ...) return outa; end ---- These methods only mutate -function array_methods:random() - return self[math.random(1,#self)]; +function array_base.pluck(outa, ina, key) + for i=1,#ina do + outa[i] = ina[i][key]; + end + return outa; end +--- These methods only mutate the array function array_methods:shuffle(outa, ina) local len = #self; for i=1,#self do - local r = math.random(i,len); + local r = math_random(i,len); self[i], self[r] = self[r], self[i]; end return self; @@ -91,18 +111,32 @@ function array_methods:append(array) return self; end -array_methods.push = table.insert; -array_methods.pop = table.remove; -array_methods.concat = table.concat; -array_methods.length = function (t) return #t; end +function array_methods:push(x) + t_insert(self, x); + return self; +end + +function array_methods:pop(x) + local v = self[x]; + t_remove(self, x); + return v; +end + +function array_methods:concat(sep) + return t_concat(array.map(self, tostring), sep); +end + +function array_methods:length() + return #self; +end --- These methods always create a new array function array.collect(f, s, var) - local t, var = {}; + local t = {}; while true do var = f(s, var); if var == nil then break; end - table.insert(t, var); + t_insert(t, var); end return setmetatable(t, array_mt); end diff --git a/util/async.lua b/util/async.lua new file mode 100644 index 00000000..968ec804 --- /dev/null +++ b/util/async.lua @@ -0,0 +1,158 @@ +local log = require "util.logger".init("util.async"); + +local function runner_continue(thread) + -- ASSUMPTION: runner is in 'waiting' state (but we don't have the runner to know for sure) + if coroutine.status(thread) ~= "suspended" then -- This should suffice + return false; + end + local ok, state, runner = coroutine.resume(thread); + if not ok then + local level = 0; + while debug.getinfo(thread, level, "") do level = level + 1; end + ok, runner = debug.getlocal(thread, level-1, 1); + local error_handler = runner.watchers.error; + if error_handler then error_handler(runner, debug.traceback(thread, state)); end + elseif state == "ready" then + -- If state is 'ready', it is our responsibility to update runner.state from 'waiting'. + -- We also have to :run(), because the queue might have further items that will not be + -- processed otherwise. FIXME: It's probably best to do this in a nexttick (0 timer). + runner.state = "ready"; + runner:run(); + end + return true; +end + +local function waiter(num) + local thread = coroutine.running(); + if not thread then + error("Not running in an async context, see http://prosody.im/doc/developers/async"); + end + num = num or 1; + local waiting; + return function () + if num == 0 then return; end -- already done + waiting = true; + coroutine.yield("wait"); + end, function () + num = num - 1; + if num == 0 and waiting then + runner_continue(thread); + elseif num < 0 then + error("done() called too many times"); + end + end; +end + +local function guarder() + local guards = {}; + return function (id, func) + local thread = coroutine.running(); + if not thread then + error("Not running in an async context, see http://prosody.im/doc/developers/async"); + end + local guard = guards[id]; + if not guard then + guard = {}; + guards[id] = guard; + log("debug", "New guard!"); + else + table.insert(guard, thread); + log("debug", "Guarded. %d threads waiting.", #guard) + coroutine.yield("wait"); + end + local function exit() + local next_waiting = table.remove(guard, 1); + if next_waiting then + log("debug", "guard: Executing next waiting thread (%d left)", #guard) + runner_continue(next_waiting); + else + log("debug", "Guard off duty.") + guards[id] = nil; + end + end + if func then + func(); + exit(); + return; + end + return exit; + end; +end + +local runner_mt = {}; +runner_mt.__index = runner_mt; + +local function runner_create_thread(func, self) + local thread = coroutine.create(function (self) + while true do + func(coroutine.yield("ready", self)); + end + end); + assert(coroutine.resume(thread, self)); -- Start it up, it will return instantly to wait for the first input + return thread; +end + +local empty_watchers = {}; +local function runner(func, watchers, data) + return setmetatable({ func = func, thread = false, state = "ready", notified_state = "ready", + queue = {}, watchers = watchers or empty_watchers, data = data } + , runner_mt); +end + +function runner_mt:run(input) + if input ~= nil then + table.insert(self.queue, input); + end + if self.state ~= "ready" then + return true, self.state, #self.queue; + end + + local q, thread = self.queue, self.thread; + if not thread or coroutine.status(thread) == "dead" then + thread = runner_create_thread(self.func, self); + self.thread = thread; + end + + local n, state, err = #q, self.state, nil; + self.state = "running"; + while n > 0 and state == "ready" do + local consumed; + for i = 1,n do + local input = q[i]; + local ok, new_state = coroutine.resume(thread, input); + if not ok then + consumed, state, err = i, "ready", debug.traceback(thread, new_state); + self.thread = nil; + break; + elseif new_state == "wait" then + consumed, state = i, "waiting"; + break; + end + end + if not consumed then consumed = n; end + if q[n+1] ~= nil then + n = #q; + end + for i = 1, n do + q[i] = q[consumed+i]; + end + n = #q; + end + self.state = state; + if err or state ~= self.notified_state then + if err then + state = "error" + else + self.notified_state = state; + end + local handler = self.watchers[state]; + if handler then handler(self, err); end + end + return true, state, n; +end + +function runner_mt:enqueue(input) + table.insert(self.queue, input); +end + +return { waiter = waiter, guarder = guarder, runner = runner }; diff --git a/util/broadcast.lua b/util/broadcast.lua deleted file mode 100644 index be17461d..00000000 --- a/util/broadcast.lua +++ /dev/null @@ -1,68 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - - -local ipairs, pairs, setmetatable, type = - ipairs, pairs, setmetatable, type; - -module "pubsub" - -local pubsub_node_mt = { __index = _M }; - -function new_node(name) - return setmetatable({ name = name, subscribers = {} }, pubsub_node_mt); -end - -function set_subscribers(node, subscribers_list, list_type) - local subscribers = node.subscribers; - - if list_type == "array" then - for _, jid in ipairs(subscribers_list) do - if not subscribers[jid] then - node:add_subscriber(jid); - end - end - elseif (not list_type) or list_type == "set" then - for jid in pairs(subscribers_list) do - if type(jid) == "string" then - node:add_subscriber(jid); - end - end - end -end - -function get_subscribers(node) - return node.subscribers; -end - -function publish(node, item, dispatcher, data) - local subscribers = node.subscribers; - for i = 1,#subscribers do - item.attr.to = subscribers[i]; - dispatcher(data, item); - end -end - -function add_subscriber(node, jid) - local subscribers = node.subscribers; - if not subscribers[jid] then - local space = #subscribers; - subscribers[space] = jid; - subscribers[jid] = space; - end -end - -function remove_subscriber(node, jid) - local subscribers = node.subscribers; - if subscribers[jid] then - subscribers[subscribers[jid]] = nil; - subscribers[jid] = nil; - end -end - -return _M; diff --git a/util/caps.lua b/util/caps.lua index a61e7403..4723b912 100644 --- a/util/caps.lua +++ b/util/caps.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- diff --git a/util/dataforms.lua b/util/dataforms.lua index ae745e03..b38d0e27 100644 --- a/util/dataforms.lua +++ b/util/dataforms.lua @@ -1,16 +1,17 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- local setmetatable = setmetatable; local pairs, ipairs = pairs, ipairs; -local tostring, type = tostring, type; +local tostring, type, next = tostring, type, next; local t_concat = table.concat; local st = require "util.stanza"; +local jid_prep = require "util.jid".prep; module "dataforms" @@ -37,7 +38,7 @@ function form_t.form(layout, data, formtype) form:tag("field", { type = field_type, var = field.name, label = field.label }); local value = (data and data[field.name]) or field.value; - + if value then -- Add value, depending on type if field_type == "hidden" then @@ -52,7 +53,7 @@ function form_t.form(layout, data, formtype) elseif field_type == "boolean" then form:tag("value"):text((value and "1") or "0"):up(); elseif field_type == "fixed" then - + form:tag("value"):text(value):up(); elseif field_type == "jid-multi" then for _, jid in ipairs(value) do form:tag("value"):text(jid):up(); @@ -92,11 +93,11 @@ function form_t.form(layout, data, formtype) end end end - + if field.required then form:tag("required"):up(); end - + -- Jump back up to list of fields form:up(); end @@ -107,30 +108,41 @@ local field_readers = {}; function form_t.data(layout, stanza) local data = {}; - - for field_tag in stanza:childtags() do - local field_type; - for n, field in ipairs(layout) do + local errors = {}; + + for _, field in ipairs(layout) do + local tag; + for field_tag in stanza:childtags() do if field.name == field_tag.attr.var then - field_type = field.type; + tag = field_tag; break; end end - - local reader = field_readers[field_type]; - if reader then - data[field_tag.attr.var] = reader(field_tag); + + if not tag then + if field.required then + errors[field.name] = "Required value missing"; + end + else + local reader = field_readers[field.type]; + if reader then + data[field.name], errors[field.name] = reader(tag, field.required); + end end - + end + if next(errors) then + return data, errors; end return data; end field_readers["text-single"] = - function (field_tag) - local value = field_tag:child_with_name("value"); - if value then - return value[1]; + function (field_tag, required) + local data = field_tag:get_child_text("value"); + if data and #data > 0 then + return data + elseif required then + return nil, "Required value missing"; end end @@ -138,64 +150,85 @@ field_readers["text-private"] = field_readers["text-single"]; field_readers["jid-single"] = - field_readers["text-single"]; + function (field_tag, required) + local raw_data = field_tag:get_child_text("value") + local data = jid_prep(raw_data); + if data and #data > 0 then + return data + elseif raw_data then + return nil, "Invalid JID: " .. raw_data; + elseif required then + return nil, "Required value missing"; + end + end field_readers["jid-multi"] = - function (field_tag) + function (field_tag, required) local result = {}; - for value_tag in field_tag:childtags() do - if value_tag.name == "value" then - result[#result+1] = value_tag[1]; + local err = {}; + for value_tag in field_tag:childtags("value") do + local raw_value = value_tag:get_text(); + local value = jid_prep(raw_value); + result[#result+1] = value; + if raw_value and not value then + err[#err+1] = ("Invalid JID: " .. raw_value); end end - return result; + if #result > 0 then + return result, (#err > 0 and t_concat(err, "\n") or nil); + elseif required then + return nil, "Required value missing"; + end end -field_readers["text-multi"] = - function (field_tag) +field_readers["list-multi"] = + function (field_tag, required) local result = {}; - for value_tag in field_tag:childtags() do - if value_tag.name == "value" then - result[#result+1] = value_tag[1]; - end + for value in field_tag:childtags("value") do + result[#result+1] = value:get_text(); + end + if #result > 0 then + return result; + elseif required then + return nil, "Required value missing"; + end + end + +field_readers["text-multi"] = + function (field_tag, required) + local data, err = field_readers["list-multi"](field_tag, required); + if data then + data = t_concat(data, "\n"); end - return t_concat(result, "\n"); + return data, err; end field_readers["list-single"] = field_readers["text-single"]; -field_readers["list-multi"] = - function (field_tag) - local result = {}; - for value_tag in field_tag:childtags() do - if value_tag.name == "value" then - result[#result+1] = value_tag[1]; - end - end - return result; - end +local boolean_values = { + ["1"] = true, ["true"] = true, + ["0"] = false, ["false"] = false, +}; field_readers["boolean"] = - function (field_tag) - local value = field_tag:child_with_name("value"); - if value then - if value[1] == "1" or value[1] == "true" then - return true; - else - return false; - end + function (field_tag, required) + local raw_value = field_tag:get_child_text("value"); + local value = boolean_values[raw_value ~= nil and raw_value]; + if value ~= nil then + return value; + elseif raw_value then + return nil, "Invalid boolean representation"; + elseif required then + return nil, "Required value missing"; end end field_readers["hidden"] = function (field_tag) - local value = field_tag:child_with_name("value"); - if value then - return value[1]; - end + return field_tag:get_child_text("value"); end - + return _M; diff --git a/util/datamanager.lua b/util/datamanager.lua index fbdfb581..4a4d62b3 100644 --- a/util/datamanager.lua +++ b/util/datamanager.lua @@ -1,34 +1,47 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- local format = string.format; -local setmetatable, type = setmetatable, type; -local pairs, ipairs = pairs, ipairs; +local setmetatable = setmetatable; +local ipairs = ipairs; local char = string.char; -local loadfile, setfenv, pcall = loadfile, setfenv, pcall; +local pcall = pcall; local log = require "util.logger".init("datamanager"); local io_open = io.open; local os_remove = os.remove; -local tostring, tonumber = tostring, tonumber; -local error = error; +local os_rename = os.rename; +local tonumber = tonumber; local next = next; local t_insert = table.insert; -local append = require "util.serialization".append; -local path_separator = "/"; if os.getenv("WINDIR") then path_separator = "\\" end +local t_concat = table.concat; +local envloadfile = require"util.envload".envloadfile; +local serialize = require "util.serialization".serialize; +local path_separator = assert ( package.config:match ( "^([^\n]+)" ) , "package.config not in standard form" ) -- Extract directory seperator from package.config (an undocumented string that comes with lua) local lfs = require "lfs"; -local raw_mkdir; +local prosody = prosody; -if prosody.platform == "posix" then - raw_mkdir = require "util.pposix".mkdir; -- Doesn't trample on umask -else - raw_mkdir = lfs.mkdir; -end +local raw_mkdir = lfs.mkdir; +local function fallocate(f, offset, len) + -- This assumes that current position == offset + local fake_data = (" "):rep(len); + local ok, msg = f:write(fake_data); + if not ok then + return ok, msg; + end + f:seek("set", offset); + return true; +end; +pcall(function() + local pposix = require "util.pposix"; + raw_mkdir = pposix.mkdir or raw_mkdir; -- Doesn't trample on umask + fallocate = pposix.fallocate or fallocate; +end); module "datamanager" @@ -56,7 +69,7 @@ local function mkdir(path) return path; end -local data_path = "data"; +local data_path = (prosody and prosody.paths and prosody.paths.data) or "."; local callbacks = {}; ------- API ------------- @@ -71,7 +84,7 @@ local function callback(username, host, datastore, data) username, host, datastore, data = f(username, host, datastore, data); if username == false then break; end end - + return username, host, datastore, data; end function add_callback(func) @@ -100,37 +113,68 @@ function getpath(username, host, datastore, ext, create) if username then if create then mkdir(mkdir(mkdir(data_path).."/"..host).."/"..datastore); end return format("%s/%s/%s/%s.%s", data_path, host, datastore, username, ext); - elseif host then + else if create then mkdir(mkdir(data_path).."/"..host); end return format("%s/%s/%s.%s", data_path, host, datastore, ext); - else - if create then mkdir(data_path); end - return format("%s/%s.%s", data_path, datastore, ext); end end function load(username, host, datastore) - local data, ret = loadfile(getpath(username, host, datastore)); + local data, ret = envloadfile(getpath(username, host, datastore), {}); if not data then local mode = lfs.attributes(getpath(username, host, datastore), "mode"); if not mode then - log("debug", "Failed to load "..datastore.." storage ('"..ret.."') for user: "..(username or "nil").."@"..(host or "nil")); + log("debug", "Assuming empty %s storage ('%s') for user: %s@%s", datastore, ret, username or "nil", host or "nil"); return nil; else -- file exists, but can't be read -- TODO more detailed error checking and logging? - log("error", "Failed to load "..datastore.." storage ('"..ret.."') for user: "..(username or "nil").."@"..(host or "nil")); + log("error", "Failed to load %s storage ('%s') for user: %s@%s", datastore, ret, username or "nil", host or "nil"); return nil, "Error reading storage"; end end - setfenv(data, {}); + local success, ret = pcall(data); if not success then - log("error", "Unable to load "..datastore.." storage ('"..ret.."') for user: "..(username or "nil").."@"..(host or "nil")); + log("error", "Unable to load %s storage ('%s') for user: %s@%s", datastore, ret, username or "nil", host or "nil"); return nil, "Error reading storage"; end return ret; end +local function atomic_store(filename, data) + local scratch = filename.."~"; + local f, ok, msg; + repeat + f, msg = io_open(scratch, "w"); + if not f then break end + + ok, msg = f:write(data); + if not ok then break end + + ok, msg = f:close(); + if not ok then break end + + return os_rename(scratch, filename); + until false; + + -- Cleanup + if f then f:close(); end + os_remove(scratch); + return nil, msg; +end + +if prosody.platform ~= "posix" then + -- os.rename does not overwrite existing files on Windows + -- TODO We could use Transactional NTFS on Vista and above + function atomic_store(filename, data) + local f, err = io_open(filename, "w"); + if not f then return f, err; end + local ok, msg = f:write(data); + if not ok then f:close(); return ok, msg; end + return f:close(); + end +end + function store(username, host, datastore, data) if not data then data = {}; @@ -142,20 +186,26 @@ function store(username, host, datastore, data) end -- save the datastore - local f, msg = io_open(getpath(username, host, datastore, nil, true), "w+"); - if not f then - log("error", "Unable to write to "..datastore.." storage ('"..msg.."') for user: "..(username or "nil").."@"..(host or "nil")); - return nil, "Error saving to storage"; - end - f:write("return "); - append(f, data); - f:close(); - if next(data) == nil then -- try to delete empty datastore - log("debug", "Removing empty %s datastore for user %s@%s", datastore, username or "nil", host or "nil"); - os_remove(getpath(username, host, datastore)); - end - -- we write data even when we are deleting because lua doesn't have a - -- platform independent way of checking for non-exisitng files + local d = "return " .. serialize(data) .. ";\n"; + local mkdir_cache_cleared; + repeat + local ok, msg = atomic_store(getpath(username, host, datastore, nil, true), d); + if not ok then + if not mkdir_cache_cleared then -- We may need to recreate a removed directory + _mkdir = {}; + mkdir_cache_cleared = true; + else + log("error", "Unable to write to %s storage ('%s') for user: %s@%s", datastore, msg, username or "nil", host or "nil"); + return nil, "Error saving to storage"; + end + end + if next(data) == nil then -- try to delete empty datastore + log("debug", "Removing empty %s datastore for user %s@%s", datastore, username or "nil", host or "nil"); + os_remove(getpath(username, host, datastore)); + end + -- we write data even when we are deleting because lua doesn't have a + -- platform independent way of checking for non-exisitng files + until ok; return true; end @@ -163,14 +213,24 @@ function list_append(username, host, datastore, data) if not data then return; end if callback(username, host, datastore) == false then return true; end -- save the datastore - local f, msg = io_open(getpath(username, host, datastore, "list", true), "a+"); + local f, msg = io_open(getpath(username, host, datastore, "list", true), "r+"); + if not f then + f, msg = io_open(getpath(username, host, datastore, "list", true), "w"); + end if not f then - log("error", "Unable to write to "..datastore.." storage ('"..msg.."') for user: "..(username or "nil").."@"..(host or "nil")); + log("error", "Unable to write to %s storage ('%s') for user: %s@%s", datastore, msg, username or "nil", host or "nil"); return; end - f:write("item("); - append(f, data); - f:write(");\n"); + local data = "item(" .. serialize(data) .. ");\n"; + local pos = f:seek("end"); + local ok, msg = fallocate(f, pos, #data); + f:seek("set", pos); + if ok then + f:write(data); + else + log("error", "Unable to write to %s storage ('%s') for user: %s@%s", datastore, msg, username or "nil", host or "nil"); + return ok, msg; + end f:close(); return true; end @@ -181,17 +241,15 @@ function list_store(username, host, datastore, data) end if callback(username, host, datastore) == false then return true; end -- save the datastore - local f, msg = io_open(getpath(username, host, datastore, "list", true), "w+"); - if not f then - log("error", "Unable to write to "..datastore.." storage ('"..msg.."') for user: "..(username or "nil").."@"..(host or "nil")); - return; + local d = {}; + for _, item in ipairs(data) do + d[#d+1] = "item(" .. serialize(item) .. ");\n"; end - for _, d in ipairs(data) do - f:write("item("); - append(f, d); - f:write(");\n"); + local ok, msg = atomic_store(getpath(username, host, datastore, "list", true), t_concat(d)); + if not ok then + log("error", "Unable to write to %s storage ('%s') for user: %s@%s", datastore, msg, username or "nil", host or "nil"); + return; end - f:close(); if next(data) == nil then -- try to delete empty datastore log("debug", "Removing empty %s datastore for user %s@%s", datastore, username or "nil", host or "nil"); os_remove(getpath(username, host, datastore, "list")); @@ -202,26 +260,108 @@ function list_store(username, host, datastore, data) end function list_load(username, host, datastore) - local data, ret = loadfile(getpath(username, host, datastore, "list")); + local items = {}; + local data, ret = envloadfile(getpath(username, host, datastore, "list"), {item = function(i) t_insert(items, i); end}); if not data then local mode = lfs.attributes(getpath(username, host, datastore, "list"), "mode"); if not mode then - log("debug", "Failed to load "..datastore.." storage ('"..ret.."') for user: "..(username or "nil").."@"..(host or "nil")); + log("debug", "Assuming empty %s storage ('%s') for user: %s@%s", datastore, ret, username or "nil", host or "nil"); return nil; else -- file exists, but can't be read -- TODO more detailed error checking and logging? - log("error", "Failed to load "..datastore.." storage ('"..ret.."') for user: "..(username or "nil").."@"..(host or "nil")); + log("error", "Failed to load %s storage ('%s') for user: %s@%s", datastore, ret, username or "nil", host or "nil"); return nil, "Error reading storage"; end end - local items = {}; - setfenv(data, {item = function(i) t_insert(items, i); end}); + local success, ret = pcall(data); if not success then - log("error", "Unable to load "..datastore.." storage ('"..ret.."') for user: "..(username or "nil").."@"..(host or "nil")); + log("error", "Unable to load %s storage ('%s') for user: %s@%s", datastore, ret, username or "nil", host or "nil"); return nil, "Error reading storage"; end return items; end +local type_map = { + keyval = "dat"; + list = "list"; +} + +function users(host, store, typ) + typ = type_map[typ or "keyval"]; + local store_dir = format("%s/%s/%s", data_path, encode(host), store); + + local mode, err = lfs.attributes(store_dir, "mode"); + if not mode then + return function() log("debug", err or (store_dir .. " does not exist")) end + end + local next, state = lfs.dir(store_dir); + return function(state) + for node in next, state do + local file, ext = node:match("^(.*)%.([dalist]+)$"); + if file and ext == typ then + return decode(file); + end + end + end, state; +end + +function stores(username, host, typ) + typ = type_map[typ or "keyval"]; + local store_dir = format("%s/%s/", data_path, encode(host)); + + local mode, err = lfs.attributes(store_dir, "mode"); + if not mode then + return function() log("debug", err or (store_dir .. " does not exist")) end + end + local next, state = lfs.dir(store_dir); + return function(state) + for node in next, state do + if not node:match"^%." then + if username == true then + if lfs.attributes(store_dir..node, "mode") == "directory" then + return decode(node); + end + elseif username then + local store = decode(node) + if lfs.attributes(getpath(username, host, store, typ), "mode") then + return store; + end + elseif lfs.attributes(node, "mode") == "file" then + local file, ext = node:match("^(.*)%.([dalist]+)$"); + if ext == typ then + return decode(file) + end + end + end + end + end, state; +end + +local function do_remove(path) + local ok, err = os_remove(path); + if not ok and lfs.attributes(path, "mode") then + return ok, err; + end + return true +end + +function purge(username, host) + local host_dir = format("%s/%s/", data_path, encode(host)); + local errs = {}; + for file in lfs.dir(host_dir) do + if lfs.attributes(host_dir..file, "mode") == "directory" then + local store = decode(file); + local ok, err = do_remove(getpath(username, host, store)); + if not ok then errs[#errs+1] = err; end + + local ok, err = do_remove(getpath(username, host, store, "list")); + if not ok then errs[#errs+1] = err; end + end + end + return #errs == 0, t_concat(errs, ", "); +end + +_M.path_decode = decode; +_M.path_encode = encode; return _M; diff --git a/util/datetime.lua b/util/datetime.lua index 301a49a5..dd596527 100644 --- a/util/datetime.lua +++ b/util/datetime.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -36,7 +36,7 @@ end function parse(s) if s then local year, month, day, hour, min, sec, tzd; - year, month, day, hour, min, sec, tzd = s:match("^(%d%d%d%d)-?(%d%d)-?(%d%d)T(%d%d):(%d%d):(%d%d)%.?%d*([Z+%-].*)$"); + year, month, day, hour, min, sec, tzd = s:match("^(%d%d%d%d)%-?(%d%d)%-?(%d%d)T(%d%d):(%d%d):(%d%d)%.?%d*([Z+%-]?.*)$"); if year then local time_offset = os_difftime(os_time(os_date("*t")), os_time(os_date("!*t"))); -- to deal with local timezone local tzd_offset = 0; @@ -49,7 +49,7 @@ function parse(s) if sign == "-" then tzd_offset = -tzd_offset; end end sec = (sec + time_offset) - tzd_offset; - return os_time({year=year, month=month, day=day, hour=hour, min=min, sec=sec}); + return os_time({year=year, month=month, day=day, hour=hour, min=min, sec=sec, isdst=false}); end end end diff --git a/util/debug.lua b/util/debug.lua new file mode 100644 index 00000000..91f691e1 --- /dev/null +++ b/util/debug.lua @@ -0,0 +1,199 @@ +-- Variables ending with these names will not +-- have their values printed ('password' includes +-- 'new_password', etc.) +local censored_names = { + password = true; + passwd = true; + pass = true; + pwd = true; +}; +local optimal_line_length = 65; + +local termcolours = require "util.termcolours"; +local getstring = termcolours.getstring; +local styles; +do + _ = termcolours.getstyle; + styles = { + boundary_padding = _("bright"); + filename = _("bright", "blue"); + level_num = _("green"); + funcname = _("yellow"); + location = _("yellow"); + }; +end +module("debugx", package.seeall); + +function get_locals_table(thread, level) + local locals = {}; + for local_num = 1, math.huge do + local name, value; + if thread then + name, value = debug.getlocal(thread, level, local_num); + else + name, value = debug.getlocal(level+1, local_num); + end + if not name then break; end + table.insert(locals, { name = name, value = value }); + end + return locals; +end + +function get_upvalues_table(func) + local upvalues = {}; + if func then + for upvalue_num = 1, math.huge do + local name, value = debug.getupvalue(func, upvalue_num); + if not name then break; end + table.insert(upvalues, { name = name, value = value }); + end + end + return upvalues; +end + +function string_from_var_table(var_table, max_line_len, indent_str) + local var_string = {}; + local col_pos = 0; + max_line_len = max_line_len or math.huge; + indent_str = "\n"..(indent_str or ""); + for _, var in ipairs(var_table) do + local name, value = var.name, var.value; + if name:sub(1,1) ~= "(" then + if type(value) == "string" then + if censored_names[name:match("%a+$")] then + value = "<hidden>"; + else + value = ("%q"):format(value); + end + else + value = tostring(value); + end + if #value > max_line_len then + value = value:sub(1, max_line_len-3).."…"; + end + local str = ("%s = %s"):format(name, tostring(value)); + col_pos = col_pos + #str; + if col_pos > max_line_len then + table.insert(var_string, indent_str); + col_pos = 0; + end + table.insert(var_string, str); + end + end + if #var_string == 0 then + return nil; + else + return "{ "..table.concat(var_string, ", "):gsub(indent_str..", ", indent_str).." }"; + end +end + +function get_traceback_table(thread, start_level) + local levels = {}; + for level = start_level, math.huge do + local info; + if thread then + info = debug.getinfo(thread, level); + else + info = debug.getinfo(level+1); + end + if not info then break; end + + levels[(level-start_level)+1] = { + level = level; + info = info; + locals = get_locals_table(thread, level+(thread and 0 or 1)); + upvalues = get_upvalues_table(info.func); + }; + end + return levels; +end + +function traceback(...) + local ok, ret = pcall(_traceback, ...); + if not ok then + return "Error in error handling: "..ret; + end + return ret; +end + +local function build_source_boundary_marker(last_source_desc) + local padding = string.rep("-", math.floor(((optimal_line_length - 6) - #last_source_desc)/2)); + return getstring(styles.boundary_padding, "v"..padding).." "..getstring(styles.filename, last_source_desc).." "..getstring(styles.boundary_padding, padding..(#last_source_desc%2==0 and "-v" or "v ")); +end + +function _traceback(thread, message, level) + + -- Lua manual says: debug.traceback ([thread,] [message [, level]]) + -- I fathom this to mean one of: + -- () + -- (thread) + -- (message, level) + -- (thread, message, level) + + if thread == nil then -- Defaults + thread, message, level = coroutine.running(), message, level; + elseif type(thread) == "string" then + thread, message, level = coroutine.running(), thread, message; + elseif type(thread) ~= "thread" then + return nil; -- debug.traceback() does this + end + + level = level or 0; + + message = message and (message.."\n") or ""; + + -- +3 counts for this function, and the pcall() and wrapper above us, the +1... I don't know. + local levels = get_traceback_table(thread, level+(thread == nil and 4 or 0)); + + local last_source_desc; + + local lines = {}; + for nlevel, level in ipairs(levels) do + local info = level.info; + local line = "..."; + local func_type = info.namewhat.." "; + local source_desc = (info.short_src == "[C]" and "C code") or info.short_src or "Unknown"; + if func_type == " " then func_type = ""; end; + if info.short_src == "[C]" then + line = "[ C ] "..func_type.."C function "..getstring(styles.location, (info.name and ("%q"):format(info.name) or "(unknown name)")); + elseif info.what == "main" then + line = "[Lua] "..getstring(styles.location, info.short_src.." line "..info.currentline); + else + local name = info.name or " "; + if name ~= " " then + name = ("%q"):format(name); + end + if func_type == "global " or func_type == "local " then + func_type = func_type.."function "; + end + line = "[Lua] "..getstring(styles.location, info.short_src.." line "..info.currentline).." in "..func_type..getstring(styles.funcname, name).." (defined on line "..info.linedefined..")"; + end + if source_desc ~= last_source_desc then -- Venturing into a new source, add marker for previous + last_source_desc = source_desc; + table.insert(lines, "\t "..build_source_boundary_marker(last_source_desc)); + end + nlevel = nlevel-1; + table.insert(lines, "\t"..(nlevel==0 and ">" or " ")..getstring(styles.level_num, "("..nlevel..") ")..line); + local npadding = (" "):rep(#tostring(nlevel)); + if level.locals then + local locals_str = string_from_var_table(level.locals, optimal_line_length, "\t "..npadding); + if locals_str then + table.insert(lines, "\t "..npadding.."Locals: "..locals_str); + end + end + local upvalues_str = string_from_var_table(level.upvalues, optimal_line_length, "\t "..npadding); + if upvalues_str then + table.insert(lines, "\t "..npadding.."Upvals: "..upvalues_str); + end + end + +-- table.insert(lines, "\t "..build_source_boundary_marker(last_source_desc)); + + return message.."stack traceback:\n"..table.concat(lines, "\n"); +end + +function use() + debug.traceback = traceback; +end + +return _M; diff --git a/util/dependencies.lua b/util/dependencies.lua index 9371521c..109a3332 100644 --- a/util/dependencies.lua +++ b/util/dependencies.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -35,11 +35,24 @@ function missingdep(name, sources, msg) print(""); end +-- COMPAT w/pre-0.8 Debian: The Debian config file used to use +-- util.ztact, which has been removed from Prosody in 0.8. This +-- is to log an error for people who still use it, so they can +-- update their configs. +package.preload["util.ztact"] = function () + if not package.loaded["core.loggingmanager"] then + error("util.ztact has been removed from Prosody and you need to fix your config " + .."file. More information can be found at http://prosody.im/doc/packagers#ztact", 0); + else + error("module 'util.ztact' has been deprecated in Prosody 0.8."); + end +end; + function check_dependencies() local fatal; - + local lxp = softreq "lxp" - + if not lxp then missingdep("luaexpat", { ["Debian/Ubuntu"] = "sudo apt-get install liblua5.1-expat0"; @@ -48,9 +61,9 @@ function check_dependencies() }); fatal = true; end - + local socket = softreq "socket" - + if not socket then missingdep("luasocket", { ["Debian/Ubuntu"] = "sudo apt-get install liblua5.1-socket2"; @@ -59,7 +72,7 @@ function check_dependencies() }); fatal = true; end - + local lfs, err = softreq "lfs" if not lfs then missingdep("luafilesystem", { @@ -69,9 +82,9 @@ function check_dependencies() }); fatal = true; end - + local ssl = softreq "ssl" - + if not ssl then missingdep("LuaSec", { ["Debian/Ubuntu"] = "http://prosody.im/download/start#debian_and_ubuntu"; @@ -79,7 +92,7 @@ function check_dependencies() ["Source"] = "http://www.inf.puc-rio.br/~brunoos/luasec/"; }, "SSL/TLS support will not be available"); end - + local encodings, err = softreq "util.encodings" if not encodings then if err:match("not found") then @@ -123,6 +136,14 @@ function log_warnings() log("error", "This version of LuaSec contains a known bug that causes disconnects, see http://prosody.im/doc/depends"); end end + if lxp then + if not pcall(lxp.new, { StartDoctypeDecl = false }) then + log("error", "The version of LuaExpat on your system leaves Prosody " + .."vulnerable to denial-of-service attacks. You should upgrade to " + .."LuaExpat 1.1.1 or higher as soon as possible. See " + .."http://prosody.im/doc/depends#luaexpat for more information."); + end + end end return _M; diff --git a/util/envload.lua b/util/envload.lua new file mode 100644 index 00000000..53e28348 --- /dev/null +++ b/util/envload.lua @@ -0,0 +1,34 @@ +-- Prosody IM +-- Copyright (C) 2008-2011 Florian Zeitz +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local load, loadstring, loadfile, setfenv = load, loadstring, loadfile, setfenv; +local envload; +local envloadfile; + +if setfenv then + function envload(code, source, env) + local f, err = loadstring(code, source); + if f and env then setfenv(f, env); end + return f, err; + end + + function envloadfile(file, env) + local f, err = loadfile(file); + if f and env then setfenv(f, env); end + return f, err; + end +else + function envload(code, source, env) + return load(code, source, nil, env); + end + + function envloadfile(file, env) + return loadfile(file, nil, env); + end +end + +return { envload = envload, envloadfile = envloadfile }; diff --git a/util/events.lua b/util/events.lua index 412acccd..40ca3913 100644 --- a/util/events.lua +++ b/util/events.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -60,11 +60,11 @@ function new() remove_handler(event, handler); end end; - local function fire_event(event, ...) - local h = handlers[event]; + local function fire_event(event_name, event_data) + local h = handlers[event_name]; if h then for i=1,#h do - local ret = h[i](...); + local ret = h[i](event_data); if ret ~= nil then return ret; end end end diff --git a/util/filters.lua b/util/filters.lua index d143666b..8a470011 100644 --- a/util/filters.lua +++ b/util/filters.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -16,7 +16,7 @@ function initialize(session) if not session.filters then local filters = {}; session.filters = filters; - + function session.filter(type, data) local filter_list = filters[type]; if filter_list then @@ -28,11 +28,11 @@ function initialize(session) return data; end end - + for i=1,#new_filter_hooks do new_filter_hooks[i](session); end - + return session.filter; end @@ -40,20 +40,20 @@ function add_filter(session, type, callback, priority) if not session.filters then initialize(session); end - + local filter_list = session.filters[type]; if not filter_list then filter_list = {}; session.filters[type] = filter_list; end - + priority = priority or 0; - + local i = 0; repeat i = i + 1; until not filter_list[i] or filter_list[filter_list[i]] >= priority; - + t_insert(filter_list, i, callback); filter_list[callback] = priority; end diff --git a/util/helpers.lua b/util/helpers.lua index 11356176..437a920c 100644 --- a/util/helpers.lua +++ b/util/helpers.lua @@ -1,17 +1,27 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- +local debug = require "util.debug"; + module("helpers", package.seeall); -- Helper functions for debugging local log = require "util.logger".init("util.debug"); +function log_host_events(host) + return log_events(prosody.hosts[host].events, host); +end + +function revert_log_host_events(host) + return revert_log_events(prosody.hosts[host].events); +end + function log_events(events, name, logger) local f = events.fire_event; if not f then @@ -28,7 +38,36 @@ function log_events(events, name, logger) end function revert_log_events(events) - events.fire_event, events[events.fire_event] = events[events.fire_event], nil; -- :) + events.fire_event, events[events.fire_event] = events[events.fire_event], nil; -- :)) +end + +function show_events(events, specific_event) + local event_handlers = events._handlers; + local events_array = {}; + local event_handler_arrays = {}; + for event in pairs(events._event_map) do + local handlers = event_handlers[event]; + if handlers and (event == specific_event or not specific_event) then + table.insert(events_array, event); + local handler_strings = {}; + for i, handler in ipairs(handlers) do + local upvals = debug.string_from_var_table(debug.get_upvalues_table(handler)); + handler_strings[i] = " "..i..": "..tostring(handler)..(upvals and ("\n "..upvals) or ""); + end + event_handler_arrays[event] = handler_strings; + end + end + table.sort(events_array); + local i = 1; + while i <= #events_array do + local handlers = event_handler_arrays[events_array[i]]; + for j=#handlers, 1, -1 do + table.insert(events_array, i+1, handlers[j]); + end + if i > 1 then events_array[i] = "\n"..events_array[i]; end + i = i + #handlers + 1 + end + return table.concat(events_array, "\n"); end function get_upvalue(f, get_name) diff --git a/util/hmac.lua b/util/hmac.lua index 6df6986e..2c4cc6ef 100644 --- a/util/hmac.lua +++ b/util/hmac.lua @@ -1,69 +1,15 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- -local hashes = require "util.hashes" - -local s_char = string.char; -local s_gsub = string.gsub; -local s_rep = string.rep; - -module "hmac" - -local xor_map = {0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;1;0;3;2;5;4;7;6;9;8;11;10;13;12;15;14;2;3;0;1;6;7;4;5;10;11;8;9;14;15;12;13;3;2;1;0;7;6;5;4;11;10;9;8;15;14;13;12;4;5;6;7;0;1;2;3;12;13;14;15;8;9;10;11;5;4;7;6;1;0;3;2;13;12;15;14;9;8;11;10;6;7;4;5;2;3;0;1;14;15;12;13;10;11;8;9;7;6;5;4;3;2;1;0;15;14;13;12;11;10;9;8;8;9;10;11;12;13;14;15;0;1;2;3;4;5;6;7;9;8;11;10;13;12;15;14;1;0;3;2;5;4;7;6;10;11;8;9;14;15;12;13;2;3;0;1;6;7;4;5;11;10;9;8;15;14;13;12;3;2;1;0;7;6;5;4;12;13;14;15;8;9;10;11;4;5;6;7;0;1;2;3;13;12;15;14;9;8;11;10;5;4;7;6;1;0;3;2;14;15;12;13;10;11;8;9;6;7;4;5;2;3;0;1;15;14;13;12;11;10;9;8;7;6;5;4;3;2;1;0;}; -local function xor(x, y) - local lowx, lowy = x % 16, y % 16; - local hix, hiy = (x - lowx) / 16, (y - lowy) / 16; - local lowr, hir = xor_map[lowx * 16 + lowy + 1], xor_map[hix * 16 + hiy + 1]; - local r = hir * 16 + lowr; - return r; -end -local opadc, ipadc = s_char(0x5c), s_char(0x36); -local ipad_map = {}; -local opad_map = {}; -for i=0,255 do - ipad_map[s_char(i)] = s_char(xor(0x36, i)); - opad_map[s_char(i)] = s_char(xor(0x5c, i)); -end - ---[[ -key - the key to use in the hash -message - the message to hash -hash - the hash function -blocksize - the blocksize for the hash function in bytes -hex - return raw hash or hexadecimal string ---]] -function hmac(key, message, hash, blocksize, hex) - if #key > blocksize then - key = hash(key) - end +-- COMPAT: Only for external pre-0.9 modules - local padding = blocksize - #key; - local ipad = s_gsub(key, ".", ipad_map)..s_rep(ipadc, padding); - local opad = s_gsub(key, ".", opad_map)..s_rep(opadc, padding); - - return hash(opad..hash(ipad..message), hex) -end - -function md5(key, message, hex) - return hmac(key, message, hashes.md5, 64, hex) -end - -function sha1(key, message, hex) - return hmac(key, message, hashes.sha1, 64, hex) -end - -function sha256(key, message, hex) - return hmac(key, message, hashes.sha256, 64, hex) -end +local hashes = require "util.hashes" -return _M +return { md5 = hashes.hmac_md5, + sha1 = hashes.hmac_sha1, + sha256 = hashes.hmac_sha256 }; diff --git a/util/http.lua b/util/http.lua new file mode 100644 index 00000000..f7259920 --- /dev/null +++ b/util/http.lua @@ -0,0 +1,64 @@ +-- Prosody IM +-- Copyright (C) 2013 Florian Zeitz +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local format, char = string.format, string.char; +local pairs, ipairs, tonumber = pairs, ipairs, tonumber; +local t_insert, t_concat = table.insert, table.concat; + +local function urlencode(s) + return s and (s:gsub("[^a-zA-Z0-9.~_-]", function (c) return format("%%%02x", c:byte()); end)); +end +local function urldecode(s) + return s and (s:gsub("%%(%x%x)", function (c) return char(tonumber(c,16)); end)); +end + +local function _formencodepart(s) + return s and (s:gsub("%W", function (c) + if c ~= " " then + return format("%%%02x", c:byte()); + else + return "+"; + end + end)); +end + +local function formencode(form) + local result = {}; + if form[1] then -- Array of ordered { name, value } + for _, field in ipairs(form) do + t_insert(result, _formencodepart(field.name).."=".._formencodepart(field.value)); + end + else -- Unordered map of name -> value + for name, value in pairs(form) do + t_insert(result, _formencodepart(name).."=".._formencodepart(value)); + end + end + return t_concat(result, "&"); +end + +local function formdecode(s) + if not s:match("=") then return urldecode(s); end + local r = {}; + for k, v in s:gmatch("([^=&]*)=([^&]*)") do + k, v = k:gsub("%+", "%%20"), v:gsub("%+", "%%20"); + k, v = urldecode(k), urldecode(v); + t_insert(r, { name = k, value = v }); + r[k] = v; + end + return r; +end + +local function contains_token(field, token) + field = ","..field:gsub("[ \t]", ""):lower()..","; + return field:find(","..token:lower()..",", 1, true) ~= nil; +end + +return { + urlencode = urlencode, urldecode = urldecode; + formencode = formencode, formdecode = formdecode; + contains_token = contains_token; +}; diff --git a/util/httpstream.lua b/util/httpstream.lua deleted file mode 100644 index bdc3fce7..00000000 --- a/util/httpstream.lua +++ /dev/null @@ -1,137 +0,0 @@ - -local coroutine = coroutine; -local tonumber = tonumber; - -local deadroutine = coroutine.create(function() end); -coroutine.resume(deadroutine); - -module("httpstream") - -local function parser(success_cb, parser_type, options_cb) - local data = coroutine.yield(); - local function readline() - local pos = data:find("\r\n", nil, true); - while not pos do - data = data..coroutine.yield(); - pos = data:find("\r\n", nil, true); - end - local r = data:sub(1, pos-1); - data = data:sub(pos+2); - return r; - end - local function readlength(n) - while #data < n do - data = data..coroutine.yield(); - end - local r = data:sub(1, n); - data = data:sub(n + 1); - return r; - end - local function readheaders() - local headers = {}; -- read headers - while true do - local line = readline(); - if line == "" then break; end -- headers done - local key, val = line:match("^([^%s:]+): *(.*)$"); - if not key then coroutine.yield("invalid-header-line"); end -- TODO handle multi-line and invalid headers - key = key:lower(); - headers[key] = headers[key] and headers[key]..","..val or val; - end - return headers; - end - - if not parser_type or parser_type == "server" then - while true do - -- read status line - local status_line = readline(); - local method, path, httpversion = status_line:match("^(%S+)%s+(%S+)%s+HTTP/(%S+)$"); - if not method then coroutine.yield("invalid-status-line"); end - path = path:gsub("^//+", "/"); -- TODO parse url more - local headers = readheaders(); - - -- read body - local len = tonumber(headers["content-length"]); - len = len or 0; -- TODO check for invalid len - local body = readlength(len); - - success_cb({ - method = method; - path = path; - httpversion = httpversion; - headers = headers; - body = body; - }); - end - elseif parser_type == "client" then - while true do - -- read status line - local status_line = readline(); - local httpversion, status_code, reason_phrase = status_line:match("^HTTP/(%S+)%s+(%d%d%d)%s+(.*)$"); - status_code = tonumber(status_code); - if not status_code then coroutine.yield("invalid-status-line"); end - local headers = readheaders(); - - -- read body - local have_body = not - ( (options_cb and options_cb().method == "HEAD") - or (status_code == 204 or status_code == 304 or status_code == 301) - or (status_code >= 100 and status_code < 200) ); - - local body; - if have_body then - local len = tonumber(headers["content-length"]); - if headers["transfer-encoding"] == "chunked" then - body = ""; - while true do - local chunk_size = readline():match("^%x+"); - if not chunk_size then coroutine.yield("invalid-chunk-size"); end - chunk_size = tonumber(chunk_size, 16) - if chunk_size == 0 then break; end - body = body..readlength(chunk_size); - if readline() ~= "" then coroutine.yield("invalid-chunk-ending"); end - end - local trailers = readheaders(); - elseif len then -- TODO check for invalid len - body = readlength(len); - else -- read to end - repeat - local newdata = coroutine.yield(); - data = data..newdata; - until newdata == ""; - body, data = data, ""; - end - end - - success_cb({ - code = status_code; - httpversion = httpversion; - headers = headers; - body = body; - -- COMPAT the properties below are deprecated - responseversion = httpversion; - responseheaders = headers; - }); - end - else coroutine.yield("unknown-parser-type"); end -end - -function new(success_cb, error_cb, parser_type, options_cb) - local co = coroutine.create(parser); - coroutine.resume(co, success_cb, parser_type, options_cb) - return { - feed = function(self, data) - if not data then - if parser_type == "client" then coroutine.resume(co, ""); end - co = deadroutine; - return error_cb(); - end - local success, result = coroutine.resume(co, data); - if result then - co = deadroutine; - return error_cb(result); - end - end; - }; -end - -return _M; diff --git a/util/import.lua b/util/import.lua index 81401e8b..174da0ca 100644 --- a/util/import.lua +++ b/util/import.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- diff --git a/util/ip.lua b/util/ip.lua new file mode 100644 index 00000000..d0ae07eb --- /dev/null +++ b/util/ip.lua @@ -0,0 +1,244 @@ +-- Prosody IM +-- Copyright (C) 2008-2011 Florian Zeitz +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local ip_methods = {}; +local ip_mt = { __index = function (ip, key) return (ip_methods[key])(ip); end, + __tostring = function (ip) return ip.addr; end, + __eq = function (ipA, ipB) return ipA.addr == ipB.addr; end}; +local hex2bits = { ["0"] = "0000", ["1"] = "0001", ["2"] = "0010", ["3"] = "0011", ["4"] = "0100", ["5"] = "0101", ["6"] = "0110", ["7"] = "0111", ["8"] = "1000", ["9"] = "1001", ["A"] = "1010", ["B"] = "1011", ["C"] = "1100", ["D"] = "1101", ["E"] = "1110", ["F"] = "1111" }; + +local function new_ip(ipStr, proto) + if not proto then + local sep = ipStr:match("^%x+(.)"); + if sep == ":" or (not(sep) and ipStr:sub(1,1) == ":") then + proto = "IPv6" + elseif sep == "." then + proto = "IPv4" + end + if not proto then + return nil, "invalid address"; + end + elseif proto ~= "IPv4" and proto ~= "IPv6" then + return nil, "invalid protocol"; + end + if proto == "IPv6" and ipStr:find('.', 1, true) then + local changed; + ipStr, changed = ipStr:gsub(":(%d+)%.(%d+)%.(%d+)%.(%d+)$", function(a,b,c,d) + return (":%04X:%04X"):format(a*256+b,c*256+d); + end); + if changed ~= 1 then return nil, "invalid-address"; end + end + + return setmetatable({ addr = ipStr, proto = proto }, ip_mt); +end + +local function toBits(ip) + local result = ""; + local fields = {}; + if ip.proto == "IPv4" then + ip = ip.toV4mapped; + end + ip = (ip.addr):upper(); + ip:gsub("([^:]*):?", function (c) fields[#fields + 1] = c end); + if not ip:match(":$") then fields[#fields] = nil; end + for i, field in ipairs(fields) do + if field:len() == 0 and i ~= 1 and i ~= #fields then + for i = 1, 16 * (9 - #fields) do + result = result .. "0"; + end + else + for i = 1, 4 - field:len() do + result = result .. "0000"; + end + for i = 1, field:len() do + result = result .. hex2bits[field:sub(i,i)]; + end + end + end + return result; +end + +local function commonPrefixLength(ipA, ipB) + ipA, ipB = toBits(ipA), toBits(ipB); + for i = 1, 128 do + if ipA:sub(i,i) ~= ipB:sub(i,i) then + return i-1; + end + end + return 128; +end + +local function v4scope(ip) + local fields = {}; + ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end); + -- Loopback: + if fields[1] == 127 then + return 0x2; + -- Link-local unicast: + elseif fields[1] == 169 and fields[2] == 254 then + return 0x2; + -- Global unicast: + else + return 0xE; + end +end + +local function v6scope(ip) + -- Loopback: + if ip:match("^[0:]*1$") then + return 0x2; + -- Link-local unicast: + elseif ip:match("^[Ff][Ee][89ABab]") then + return 0x2; + -- Site-local unicast: + elseif ip:match("^[Ff][Ee][CcDdEeFf]") then + return 0x5; + -- Multicast: + elseif ip:match("^[Ff][Ff]") then + return tonumber("0x"..ip:sub(4,4)); + -- Global unicast: + else + return 0xE; + end +end + +local function label(ip) + if commonPrefixLength(ip, new_ip("::1", "IPv6")) == 128 then + return 0; + elseif commonPrefixLength(ip, new_ip("2002::", "IPv6")) >= 16 then + return 2; + elseif commonPrefixLength(ip, new_ip("2001::", "IPv6")) >= 32 then + return 5; + elseif commonPrefixLength(ip, new_ip("fc00::", "IPv6")) >= 7 then + return 13; + elseif commonPrefixLength(ip, new_ip("fec0::", "IPv6")) >= 10 then + return 11; + elseif commonPrefixLength(ip, new_ip("3ffe::", "IPv6")) >= 16 then + return 12; + elseif commonPrefixLength(ip, new_ip("::", "IPv6")) >= 96 then + return 3; + elseif commonPrefixLength(ip, new_ip("::ffff:0:0", "IPv6")) >= 96 then + return 4; + else + return 1; + end +end + +local function precedence(ip) + if commonPrefixLength(ip, new_ip("::1", "IPv6")) == 128 then + return 50; + elseif commonPrefixLength(ip, new_ip("2002::", "IPv6")) >= 16 then + return 30; + elseif commonPrefixLength(ip, new_ip("2001::", "IPv6")) >= 32 then + return 5; + elseif commonPrefixLength(ip, new_ip("fc00::", "IPv6")) >= 7 then + return 3; + elseif commonPrefixLength(ip, new_ip("fec0::", "IPv6")) >= 10 then + return 1; + elseif commonPrefixLength(ip, new_ip("3ffe::", "IPv6")) >= 16 then + return 1; + elseif commonPrefixLength(ip, new_ip("::", "IPv6")) >= 96 then + return 1; + elseif commonPrefixLength(ip, new_ip("::ffff:0:0", "IPv6")) >= 96 then + return 35; + else + return 40; + end +end + +local function toV4mapped(ip) + local fields = {}; + local ret = "::ffff:"; + ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end); + ret = ret .. ("%02x"):format(fields[1]); + ret = ret .. ("%02x"):format(fields[2]); + ret = ret .. ":" + ret = ret .. ("%02x"):format(fields[3]); + ret = ret .. ("%02x"):format(fields[4]); + return new_ip(ret, "IPv6"); +end + +function ip_methods:toV4mapped() + if self.proto ~= "IPv4" then return nil, "No IPv4 address" end + local value = toV4mapped(self.addr); + self.toV4mapped = value; + return value; +end + +function ip_methods:label() + local value; + if self.proto == "IPv4" then + value = label(self.toV4mapped); + else + value = label(self); + end + self.label = value; + return value; +end + +function ip_methods:precedence() + local value; + if self.proto == "IPv4" then + value = precedence(self.toV4mapped); + else + value = precedence(self); + end + self.precedence = value; + return value; +end + +function ip_methods:scope() + local value; + if self.proto == "IPv4" then + value = v4scope(self.addr); + else + value = v6scope(self.addr); + end + self.scope = value; + return value; +end + +function ip_methods:private() + local private = self.scope ~= 0xE; + if not private and self.proto == "IPv4" then + local ip = self.addr; + local fields = {}; + ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end); + if fields[1] == 127 or fields[1] == 10 or (fields[1] == 192 and fields[2] == 168) + or (fields[1] == 172 and (fields[2] >= 16 or fields[2] <= 32)) then + private = true; + end + end + self.private = private; + return private; +end + +local function parse_cidr(cidr) + local bits; + local ip_len = cidr:find("/", 1, true); + if ip_len then + bits = tonumber(cidr:sub(ip_len+1, -1)); + cidr = cidr:sub(1, ip_len-1); + end + return new_ip(cidr), bits; +end + +local function match(ipA, ipB, bits) + local common_bits = commonPrefixLength(ipA, ipB); + if not bits then + return ipA == ipB; + end + if bits and ipB.proto == "IPv4" then + common_bits = common_bits - 96; -- v6 mapped addresses always share these bits + end + return common_bits >= bits; +end + +return {new_ip = new_ip, + commonPrefixLength = commonPrefixLength, + parse_cidr = parse_cidr, + match=match}; diff --git a/util/iterators.lua b/util/iterators.lua index dc692d64..aa9c3ec0 100644 --- a/util/iterators.lua +++ b/util/iterators.lua @@ -1,15 +1,21 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- --[[ Iterators ]]-- +local it = {}; + +local t_insert = table.insert; +local select, unpack, next = select, unpack, next; +local function pack(...) return { n = select("#", ...), ... }; end + -- Reverse an iterator -function reverse(f, s, var) +function it.reverse(f, s, var) local results = {}; -- First call the normal iterator @@ -17,9 +23,9 @@ function reverse(f, s, var) local ret = { f(s, var) }; var = ret[1]; if var == nil then break; end - table.insert(results, 1, ret); + t_insert(results, 1, ret); end - + -- Then return our reverse one local i,max = 0, #results; return function (results) @@ -34,12 +40,12 @@ end local function _keys_it(t, key) return (next(t, key)); end -function keys(t) +function it.keys(t) return _keys_it, t; end -- Iterate only over values in a table -function values(t) +function it.values(t) local key, val; return function (t) key, val = next(t, key); @@ -48,38 +54,37 @@ function values(t) end -- Given an iterator, iterate only over unique items -function unique(f, s, var) +function it.unique(f, s, var) local set = {}; - + return function () while true do - local ret = { f(s, var) }; + local ret = pack(f(s, var)); var = ret[1]; if var == nil then break; end if not set[var] then set[var] = true; - return var; + return unpack(ret, 1, ret.n); end end end; end --[[ Return the number of items an iterator returns ]]-- -function count(f, s, var) +function it.count(f, s, var) local x = 0; - + while true do - local ret = { f(s, var) }; - var = ret[1]; + var = f(s, var); if var == nil then break; end x = x + 1; end - + return x; end -- Return the first n items an iterator returns -function head(n, f, s, var) +function it.head(n, f, s, var) local c = 0; return function (s, var) if c >= n then @@ -91,7 +96,7 @@ function head(n, f, s, var) end -- Skip the first n items an iterator returns -function skip(n, f, s, var) +function it.skip(n, f, s, var) for i=1,n do var = f(s, var); end @@ -99,10 +104,10 @@ function skip(n, f, s, var) end -- Return the last n items an iterator returns -function tail(n, f, s, var) +function it.tail(n, f, s, var) local results, count = {}, 0; while true do - local ret = { f(s, var) }; + local ret = pack(f(s, var)); var = ret[1]; if var == nil then break; end results[(count%n)+1] = ret; @@ -115,26 +120,52 @@ function tail(n, f, s, var) return function () pos = pos + 1; if pos > n then return nil; end - return unpack(results[((count-1+pos)%n)+1]); + local ret = results[((count-1+pos)%n)+1]; + return unpack(ret, 1, ret.n); end - --return reverse(head(n, reverse(f, s, var))); + --return reverse(head(n, reverse(f, s, var))); -- ! +end + +function it.filter(filter, f, s, var) + if type(filter) ~= "function" then + local filter_value = filter; + function filter(x) return x ~= filter_value; end + end + return function (s, var) + local ret; + repeat ret = pack(f(s, var)); + var = ret[1]; + until var == nil or filter(unpack(ret, 1, ret.n)); + return unpack(ret, 1, ret.n); + end, s, var; +end + +local function _ripairs_iter(t, key) if key > 1 then return key-1, t[key-1]; end end +function it.ripairs(t) + return _ripairs_iter, t, #t+1; +end + +local function _range_iter(max, curr) if curr < max then return curr + 1; end end +function it.range(x, y) + if not y then x, y = 1, x; end -- Default to 1..x if y not given + return _range_iter, y, x-1; end -- Convert the values returned by an iterator to an array -function it2array(f, s, var) +function it.to_array(f, s, var) local t, var = {}; while true do var = f(s, var); if var == nil then break; end - table.insert(t, var); + t_insert(t, var); end return t; end -- Treat the return of an iterator as key,value pairs, -- and build a table -function it2table(f, s, var) - local t, var = {}; +function it.to_table(f, s, var) + local t, var2 = {}; while true do var, var2 = f(s, var); if var == nil then break; end @@ -142,3 +173,5 @@ function it2table(f, s, var) end return t; end + +return it; diff --git a/util/jid.lua b/util/jid.lua index 069817c6..0d9a864f 100644 --- a/util/jid.lua +++ b/util/jid.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -13,6 +13,16 @@ local nodeprep = require "util.encodings".stringprep.nodeprep; local nameprep = require "util.encodings".stringprep.nameprep; local resourceprep = require "util.encodings".stringprep.resourceprep; +local escapes = { + [" "] = "\\20"; ['"'] = "\\22"; + ["&"] = "\\26"; ["'"] = "\\27"; + ["/"] = "\\2f"; [":"] = "\\3a"; + ["<"] = "\\3c"; [">"] = "\\3e"; + ["@"] = "\\40"; ["\\"] = "\\5c"; +}; +local unescapes = {}; +for k,v in pairs(escapes) do unescapes[v] = k; end + module "jid" local function _split(jid) @@ -91,4 +101,7 @@ function compare(jid, acl) return false end +function escape(s) return s and (s:gsub(".", escapes)); end +function unescape(s) return s and (s:gsub("\\%x%x", unescapes)); end + return _M; diff --git a/util/json.lua b/util/json.lua index 40939bb4..a8a58afc 100644 --- a/util/json.lua +++ b/util/json.lua @@ -1,14 +1,24 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- local type = type; -local t_insert, t_concat, t_remove = table.insert, table.concat, table.remove; +local t_insert, t_concat, t_remove, t_sort = table.insert, table.concat, table.remove, table.sort; local s_char = string.char; local tostring, tonumber = tostring, tonumber; local pairs, ipairs = pairs, ipairs; local next = next; local error = error; -local newproxy, getmetatable = newproxy, getmetatable; +local newproxy, getmetatable, setmetatable = newproxy, getmetatable, setmetatable; local print = print; +local has_array, array = pcall(require, "util.array"); +local array_mt = has_array and getmetatable(array()) or {}; + --module("json") local json = {}; @@ -29,6 +39,19 @@ for i=0,31 do if not escapes[ch] then escapes[ch] = ("\\u%.4X"):format(i); end end +local function codepoint_to_utf8(code) + if code < 0x80 then return s_char(code); end + local bits0_6 = code % 64; + if code < 0x800 then + local bits6_5 = (code - bits0_6) / 64; + return s_char(0x80 + 0x40 + bits6_5, 0x80 + bits0_6); + end + local bits0_12 = code % 4096; + local bits6_6 = (bits0_12 - bits0_6) / 64; + local bits12_4 = (code - bits0_12) / 4096; + return s_char(0x80 + 0x40 + 0x20 + bits12_4, 0x80 + bits6_6, 0x80 + bits0_6); +end + local valid_types = { number = true, string = true, @@ -79,11 +102,25 @@ function tablesave(o, buffer) if next(__hash) ~= nil or next(hash) ~= nil or next(__array) == nil then t_insert(buffer, "{"); local mark = #buffer; - for k,v in pairs(hash) do - stringsave(k, buffer); - t_insert(buffer, ":"); - simplesave(v, buffer); - t_insert(buffer, ","); + if buffer.ordered then + local keys = {}; + for k in pairs(hash) do + t_insert(keys, k); + end + t_sort(keys); + for _,k in ipairs(keys) do + stringsave(k, buffer); + t_insert(buffer, ":"); + simplesave(hash[k], buffer); + t_insert(buffer, ","); + end + else + for k,v in pairs(hash) do + stringsave(k, buffer); + t_insert(buffer, ":"); + simplesave(v, buffer); + t_insert(buffer, ","); + end end if next(__hash) ~= nil then t_insert(buffer, "\"__hash\":["); @@ -116,7 +153,12 @@ function simplesave(o, buffer) elseif t == "string" then stringsave(o, buffer); elseif t == "table" then - tablesave(o, buffer); + local mt = getmetatable(o); + if mt == array_mt then + arraysave(o, buffer); + else + tablesave(o, buffer); + end elseif t == "boolean" then t_insert(buffer, (o and "true" or "false")); else @@ -129,214 +171,191 @@ function json.encode(obj) simplesave(obj, t); return t_concat(t); end +function json.encode_ordered(obj) + local t = { ordered = true }; + simplesave(obj, t); + return t_concat(t); +end +function json.encode_array(obj) + local t = {}; + arraysave(obj, t); + return t_concat(t); +end ----------------------------------- -function json.decode(json) - local pos = 1; - local current = {}; - local stack = {}; - local ch, peek; - local function next() - ch = json:sub(pos, pos); - pos = pos+1; - peek = json:sub(pos, pos); - return ch; - end - - local function skipwhitespace() - while ch and (ch == "\r" or ch == "\n" or ch == "\t" or ch == " ") do - next(); +local function _skip_whitespace(json, index) + return json:find("[^ \t\r\n]", index) or index; -- no need to check \r\n, we converted those to \t +end +local function _fixobject(obj) + local __array = obj.__array; + if __array then + obj.__array = nil; + for i,v in ipairs(__array) do + t_insert(obj, v); end end - local function skiplinecomment() - repeat next(); until not(ch) or ch == "\r" or ch == "\n"; - skipwhitespace(); - end - local function skipstarcomment() - next(); next(); -- skip '/', '*' - while peek and ch ~= "*" and peek ~= "/" do next(); end - if not peek then error("eof in star comment") end - next(); next(); -- skip '*', '/' - skipwhitespace(); - end - local function skipstuff() - while true do - skipwhitespace(); - if ch == "/" and peek == "*" then - skipstarcomment(); - elseif ch == "/" and peek == "*" then - skiplinecomment(); + local __hash = obj.__hash; + if __hash then + obj.__hash = nil; + local k; + for i,v in ipairs(__hash) do + if k ~= nil then + obj[k] = v; k = nil; else - return; + k = v; end end end - - local readvalue; - local function readarray() - local t = {}; - next(); -- skip '[' - skipstuff(); - if ch == "]" then next(); return t; end - t_insert(t, readvalue()); - while true do - skipstuff(); - if ch == "]" then next(); return t; end - if not ch then error("eof while reading array"); - elseif ch == "," then next(); - elseif ch then error("unexpected character in array, comma expected"); end - if not ch then error("eof while reading array"); end - t_insert(t, readvalue()); + return obj; +end +local _readvalue, _readstring; +local function _readobject(json, index) + local o = {}; + while true do + local key, val; + index = _skip_whitespace(json, index + 1); + if json:byte(index) ~= 0x22 then -- "\"" + if json:byte(index) == 0x7d then return o, index + 1; end -- "}" + return nil, "key expected"; end + key, index = _readstring(json, index); + if key == nil then return nil, index; end + index = _skip_whitespace(json, index); + if json:byte(index) ~= 0x3a then return nil, "colon expected"; end -- ":" + val, index = _readvalue(json, index + 1); + if val == nil then return nil, index; end + o[key] = val; + index = _skip_whitespace(json, index); + local b = json:byte(index); + if b == 0x7d then return _fixobject(o), index + 1; end -- "}" + if b ~= 0x2c then return nil, "object eof"; end -- "," end - - local function checkandskip(c) - local x = ch or "eof"; - if x ~= c then error("unexpected "..x..", '"..c.."' expected"); end - next(); - end - local function readliteral(lit, val) - for c in lit:gmatch(".") do - checkandskip(c); +end +local function _readarray(json, index) + local a = {}; + local oindex = index; + while true do + local val; + val, index = _readvalue(json, index + 1); + if val == nil then + if json:byte(oindex + 1) == 0x5d then return setmetatable(a, array_mt), oindex + 2; end -- "]" + return val, index; end - return val; + t_insert(a, val); + index = _skip_whitespace(json, index); + local b = json:byte(index); + if b == 0x5d then return setmetatable(a, array_mt), index + 1; end -- "]" + if b ~= 0x2c then return nil, "array eof"; end -- "," end - local function readstring() - local s = ""; - checkandskip("\""); - while ch do - while ch and ch ~= "\\" and ch ~= "\"" do - s = s..ch; next(); - end - if ch == "\\" then - next(); - if unescapes[ch] then - s = s..unescapes[ch]; - next(); - elseif ch == "u" then - local seq = ""; - for i=1,4 do - next(); - if not ch then error("unexpected eof in string"); end - if not ch:match("[0-9a-fA-F]") then error("invalid unicode escape sequence in string"); end - seq = seq..ch; - end - s = s..s.char(tonumber(seq, 16)); -- FIXME do proper utf-8 - next(); - else error("invalid escape sequence in string"); end - end - if ch == "\"" then - next(); - return s; - end - end - error("eof while reading string"); +end +local _unescape_error; +local function _unescape_surrogate_func(x) + local lead, trail = tonumber(x:sub(3, 6), 16), tonumber(x:sub(9, 12), 16); + local codepoint = lead * 0x400 + trail - 0x35FDC00; + local a = codepoint % 64; + codepoint = (codepoint - a) / 64; + local b = codepoint % 64; + codepoint = (codepoint - b) / 64; + local c = codepoint % 64; + codepoint = (codepoint - c) / 64; + return s_char(0xF0 + codepoint, 0x80 + c, 0x80 + b, 0x80 + a); +end +local function _unescape_func(x) + x = x:match("%x%x%x%x", 3); + if x then + --if x >= 0xD800 and x <= 0xDFFF then _unescape_error = true; end -- bad surrogate pair + return codepoint_to_utf8(tonumber(x, 16)); end - local function readnumber() - local s = ""; - if ch == "-" then - s = s..ch; next(); - if not ch:match("[0-9]") then error("number format error"); end - end - if ch == "0" then - s = s..ch; next(); - if ch:match("[0-9]") then error("number format error"); end - else - while ch and ch:match("[0-9]") do - s = s..ch; next(); - end - end - if ch == "." then - s = s..ch; next(); - if not ch:match("[0-9]") then error("number format error"); end - while ch and ch:match("[0-9]") do - s = s..ch; next(); - end - if ch == "e" or ch == "E" then - s = s..ch; next(); - if ch == "+" or ch == "-" then - s = s..ch; next(); - if not ch:match("[0-9]") then error("number format error"); end - while ch and ch:match("[0-9]") do - s = s..ch; next(); - end - end - end - end - return tonumber(s); + _unescape_error = true; +end +function _readstring(json, index) + index = index + 1; + local endindex = json:find("\"", index, true); + if endindex then + local s = json:sub(index, endindex - 1); + --if s:find("[%z-\31]") then return nil, "control char in string"; end + -- FIXME handle control characters + _unescape_error = nil; + --s = s:gsub("\\u[dD][89abAB]%x%x\\u[dD][cdefCDEF]%x%x", _unescape_surrogate_func); + -- FIXME handle escapes beyond BMP + s = s:gsub("\\u.?.?.?.?", _unescape_func); + if _unescape_error then return nil, "invalid escape"; end + return s, endindex + 1; end - local function readmember(t) - local k = readstring(); - checkandskip(":"); - t[k] = readvalue(); + return nil, "string eof"; +end +local function _readnumber(json, index) + local m = json:match("[0-9%.%-eE%+]+", index); -- FIXME do strict checking + return tonumber(m), index + #m; +end +local function _readnull(json, index) + local a, b, c = json:byte(index + 1, index + 3); + if a == 0x75 and b == 0x6c and c == 0x6c then + return null, index + 4; end - local function fixobject(obj) - local __array = obj.__array; - if __array then - obj.__array = nil; - for i,v in ipairs(__array) do - t_insert(obj, v); - end - end - local __hash = obj.__hash; - if __hash then - obj.__hash = nil; - local k; - for i,v in ipairs(__hash) do - if k ~= nil then - obj[k] = v; k = nil; - else - k = v; - end - end - end - return obj; + return nil, "null parse failed"; +end +local function _readtrue(json, index) + local a, b, c = json:byte(index + 1, index + 3); + if a == 0x72 and b == 0x75 and c == 0x65 then + return true, index + 4; end - local function readobject() - local t = {}; - next(); -- skip '{' - skipstuff(); - if ch == "}" then next(); return t; end - if not ch then error("eof while reading object"); end - readmember(t); - while true do - skipstuff(); - if ch == "}" then next(); return fixobject(t); end - if not ch then error("eof while reading object"); - elseif ch == "," then next(); - elseif ch then error("unexpected character in object, comma expected"); end - if not ch then error("eof while reading object"); end - readmember(t); - end + return nil, "true parse failed"; +end +local function _readfalse(json, index) + local a, b, c, d = json:byte(index + 1, index + 4); + if a == 0x61 and b == 0x6c and c == 0x73 and d == 0x65 then + return false, index + 5; end - - function readvalue() - skipstuff(); - while ch do - if ch == "{" then - return readobject(); - elseif ch == "[" then - return readarray(); - elseif ch == "\"" then - return readstring(); - elseif ch:match("[%-0-9%.]") then - return readnumber(); - elseif ch == "n" then - return readliteral("null", null); - elseif ch == "t" then - return readliteral("true", true); - elseif ch == "f" then - return readliteral("false", false); - else - error("invalid character at value start: "..ch); - end - end - error("eof while reading value"); + return nil, "false parse failed"; +end +function _readvalue(json, index) + index = _skip_whitespace(json, index); + local b = json:byte(index); + -- TODO try table lookup instead of if-else? + if b == 0x7B then -- "{" + return _readobject(json, index); + elseif b == 0x5B then -- "[" + return _readarray(json, index); + elseif b == 0x22 then -- "\"" + return _readstring(json, index); + elseif b ~= nil and b >= 0x30 and b <= 0x39 or b == 0x2d then -- "0"-"9" or "-" + return _readnumber(json, index); + elseif b == 0x6e then -- "n" + return _readnull(json, index); + elseif b == 0x74 then -- "t" + return _readtrue(json, index); + elseif b == 0x66 then -- "f" + return _readfalse(json, index); + else + return nil, "value expected"; end - next(); - return readvalue(); +end +local first_escape = { + ["\\\""] = "\\u0022"; + ["\\\\"] = "\\u005c"; + ["\\/" ] = "\\u002f"; + ["\\b" ] = "\\u0008"; + ["\\f" ] = "\\u000C"; + ["\\n" ] = "\\u000A"; + ["\\r" ] = "\\u000D"; + ["\\t" ] = "\\u0009"; + ["\\u" ] = "\\u"; +}; + +function json.decode(json) + json = json:gsub("\\.", first_escape) -- get rid of all escapes except \uXXXX, making string parsing much simpler + --:gsub("[\r\n]", "\t"); -- \r\n\t are equivalent, we care about none of them, and none of them can be in strings + + -- TODO do encoding verification + + local val, index = _readvalue(json, 1); + if val == nil then return val, index; end + if json:find("[^ \t\r\n]", index) then return nil, "garbage at eof"; end + + return val; end function json.test(object) diff --git a/util/logger.lua b/util/logger.lua index c3bf3992..cd0769f9 100644 --- a/util/logger.lua +++ b/util/logger.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -13,8 +13,7 @@ local ipairs, pairs, setmetatable = ipairs, pairs, setmetatable; module "logger" -local name_sinks, level_sinks = {}, {}; -local name_patterns = {}; +local level_sinks = {}; local make_logger; @@ -24,8 +23,6 @@ function init(name) local log_warn = make_logger(name, "warn"); local log_error = make_logger(name, "error"); - --name = nil; -- While this line is not commented, will automatically fill in file/line number info - local namelen = #name; return function (level, message, ...) if level == "debug" then return log_debug(message, ...); @@ -46,17 +43,7 @@ function make_logger(source_name, level) level_sinks[level] = level_handlers; end - local source_handlers = name_sinks[source_name]; - local logger = function (message, ...) - if source_handlers then - for i = 1,#source_handlers do - if source_handlers[i](source_name, level, message, ...) == false then - return; - end - end - end - for i = 1,#level_handlers do level_handlers[i](source_name, level, message, ...); end @@ -66,14 +53,12 @@ function make_logger(source_name, level) end function reset() - for k in pairs(name_sinks) do name_sinks[k] = nil; end for level, handler_list in pairs(level_sinks) do -- Clear all handlers for this level for i = 1, #handler_list do handler_list[i] = nil; end end - for k in pairs(name_patterns) do name_patterns[k] = nil; end end function add_level_sink(level, sink_function) @@ -84,22 +69,6 @@ function add_level_sink(level, sink_function) end end -function add_name_sink(name, sink_function, exclusive) - if not name_sinks[name] then - name_sinks[name] = { sink_function }; - else - name_sinks[name][#name_sinks[name] + 1] = sink_function; - end -end - -function add_name_pattern_sink(name_pattern, sink_function, exclusive) - if not name_patterns[name_pattern] then - name_patterns[name_pattern] = { sink_function }; - else - name_patterns[name_pattern][#name_patterns[name_pattern] + 1] = sink_function; - end -end - _M.new = make_logger; return _M; diff --git a/util/multitable.lua b/util/multitable.lua index 66b9bd8a..caf25118 100644 --- a/util/multitable.lua +++ b/util/multitable.lua @@ -1,17 +1,14 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- - - local select = select; local t_insert = table.insert; -local pairs = pairs; -local next = next; +local unpack, pairs, next, type = unpack, pairs, next, type; module "multitable" @@ -129,6 +126,41 @@ local function search_add(self, results, ...) return results; end +function iter(self, ...) + local query = { ... }; + local maxdepth = select("#", ...); + local stack = { self.data }; + local keys = { }; + local function it(self) + local depth = #stack; + local key = next(stack[depth], keys[depth]); + if key == nil then -- Go up the stack + stack[depth], keys[depth] = nil, nil; + if depth > 1 then + return it(self); + end + return; -- The end + else + keys[depth] = key; + end + local value = stack[depth][key]; + if query[depth] == nil or key == query[depth] then + if depth == maxdepth then -- Result + local result = {}; -- Collect keys forming path to result + for i = 1, depth do + result[i] = keys[i]; + end + result[depth+1] = value; + return unpack(result, 1, depth+1); + elseif type(value) == "table" then + t_insert(stack, value); -- Descend + end + end + return it(self); + end; + return it, self; +end + function new() return { data = {}; @@ -138,6 +170,7 @@ function new() remove = remove; search = search; search_add = search_add; + iter = iter; }; end diff --git a/util/openssl.lua b/util/openssl.lua new file mode 100644 index 00000000..ef3fba96 --- /dev/null +++ b/util/openssl.lua @@ -0,0 +1,172 @@ +local type, tostring, pairs, ipairs = type, tostring, pairs, ipairs; +local t_insert, t_concat = table.insert, table.concat; +local s_format = string.format; + +local oid_xmppaddr = "1.3.6.1.5.5.7.8.5"; -- [XMPP-CORE] +local oid_dnssrv = "1.3.6.1.5.5.7.8.7"; -- [SRV-ID] + +local idna_to_ascii = require "util.encodings".idna.to_ascii; + +local _M = {}; +local config = {}; +_M.config = config; + +local ssl_config = {}; +local ssl_config_mt = {__index=ssl_config}; + +function config.new() + return setmetatable({ + req = { + distinguished_name = "distinguished_name", + req_extensions = "v3_extensions", + x509_extensions = "v3_extensions", + prompt = "no", + }, + distinguished_name = { + countryName = "GB", + -- stateOrProvinceName = "", + localityName = "The Internet", + organizationName = "Your Organisation", + organizationalUnitName = "XMPP Department", + commonName = "example.com", + emailAddress = "xmpp@example.com", + }, + v3_extensions = { + basicConstraints = "CA:FALSE", + keyUsage = "digitalSignature,keyEncipherment", + extendedKeyUsage = "serverAuth,clientAuth", + subjectAltName = "@subject_alternative_name", + }, + subject_alternative_name = { + DNS = {}, + otherName = {}, + }, + }, ssl_config_mt); +end + +local DN_order = { + "countryName"; + "stateOrProvinceName"; + "localityName"; + "streetAddress"; + "organizationName"; + "organizationalUnitName"; + "commonName"; + "emailAddress"; +} +_M._DN_order = DN_order; +function ssl_config:serialize() + local s = ""; + for k, t in pairs(self) do + s = s .. ("[%s]\n"):format(k); + if k == "subject_alternative_name" then + for san, n in pairs(t) do + for i = 1,#n do + s = s .. s_format("%s.%d = %s\n", san, i -1, n[i]); + end + end + elseif k == "distinguished_name" then + for i=1,#DN_order do + local k = DN_order[i] + local v = t[k]; + if v then + s = s .. ("%s = %s\n"):format(k, v); + end + end + else + for k, v in pairs(t) do + s = s .. ("%s = %s\n"):format(k, v); + end + end + s = s .. "\n"; + end + return s; +end + +local function utf8string(s) + -- This is how we tell openssl not to encode UTF-8 strings as fake Latin1 + return s_format("FORMAT:UTF8,UTF8:%s", s); +end + +local function ia5string(s) + return s_format("IA5STRING:%s", s); +end + +_M.util = { + utf8string = utf8string, + ia5string = ia5string, +}; + +function ssl_config:add_dNSName(host) + t_insert(self.subject_alternative_name.DNS, idna_to_ascii(host)); +end + +function ssl_config:add_sRVName(host, service) + t_insert(self.subject_alternative_name.otherName, + s_format("%s;%s", oid_dnssrv, ia5string("_" .. service .."." .. idna_to_ascii(host)))); +end + +function ssl_config:add_xmppAddr(host) + t_insert(self.subject_alternative_name.otherName, + s_format("%s;%s", oid_xmppaddr, utf8string(host))); +end + +function ssl_config:from_prosody(hosts, config, certhosts) + -- TODO Decide if this should go elsewhere + local found_matching_hosts = false; + for i = 1,#certhosts do + local certhost = certhosts[i]; + for name in pairs(hosts) do + if name == certhost or name:sub(-1-#certhost) == "."..certhost then + found_matching_hosts = true; + self:add_dNSName(name); + --print(name .. "#component_module: " .. (config.get(name, "component_module") or "nil")); + if config.get(name, "component_module") == nil then + self:add_sRVName(name, "xmpp-client"); + end + --print(name .. "#anonymous_login: " .. tostring(config.get(name, "anonymous_login"))); + if not (config.get(name, "anonymous_login") or + config.get(name, "authentication") == "anonymous") then + self:add_sRVName(name, "xmpp-server"); + end + self:add_xmppAddr(name); + end + end + end + if not found_matching_hosts then + return nil, "no-matching-hosts"; + end +end + +do -- Lua to shell calls. + local function shell_escape(s) + return s:gsub("'",[['\'']]); + end + + local function serialize(f,o) + local r = {"openssl", f}; + for k,v in pairs(o) do + if type(k) == "string" then + t_insert(r, ("-%s"):format(k)); + if v ~= true then + t_insert(r, ("'%s'"):format(shell_escape(tostring(v)))); + end + end + end + for _,v in ipairs(o) do + t_insert(r, ("'%s'"):format(shell_escape(tostring(v)))); + end + return t_concat(r, " "); + end + + local os_execute = os.execute; + setmetatable(_M, { + __index=function(_,f) + return function(opts) + return 0 == os_execute(serialize(f, type(opts) == "table" and opts or {})); + end; + end; + }); +end + +return _M; diff --git a/util/pluginloader.lua b/util/pluginloader.lua index 31ab1e88..b894f527 100644 --- a/util/pluginloader.lua +++ b/util/pluginloader.lua @@ -1,58 +1,60 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- +local dir_sep, path_sep = package.config:match("^(%S+)%s(%S+)"); +local plugin_dir = {}; +for path in (CFG_PLUGINDIR or "./plugins/"):gsub("[/\\]", dir_sep):gmatch("[^"..path_sep.."]+") do + path = path..dir_sep; -- add path separator to path end + path = path:gsub(dir_sep..dir_sep.."+", dir_sep); -- coalesce multiple separaters + plugin_dir[#plugin_dir + 1] = path; +end -local plugin_dir = CFG_PLUGINDIR or "./plugins/"; - -local io_open, os_time = io.open, os.time; -local loadstring, pairs = loadstring, pairs; +local io_open = io.open; +local envload = require "util.envload".envload; module "pluginloader" -local function load_file(name) - local file, err = io_open(plugin_dir..name); - if not file then return file, err; end - local content = file:read("*a"); - file:close(); - return content, name; +function load_file(names) + local file, err, path; + for i=1,#plugin_dir do + for j=1,#names do + path = plugin_dir[i]..names[j]; + file, err = io_open(path); + if file then + local content = file:read("*a"); + file:close(); + return content, path; + end + end + end + return file, err; end -function load_resource(plugin, resource, loader) - local path, name = plugin:match("([^/]*)/?(.*)"); - if name == "" then - if not resource then - resource = "mod_"..plugin..".lua"; - end - loader = loader or load_file; - - local content, err = loader(plugin.."/"..resource); - if not content then content, err = loader(resource); end - -- TODO add support for packed plugins - - return content, err; - else - if not resource then - resource = "mod_"..name..".lua"; - end - loader = loader or load_file; +function load_resource(plugin, resource) + resource = resource or "mod_"..plugin..".lua"; - local content, err = loader(plugin.."/"..resource); - if not content then content, err = loader(path.."/"..resource); end - -- TODO add support for packed plugins - - return content, err; - end + local names = { + "mod_"..plugin.."/"..plugin.."/"..resource; -- mod_hello/hello/mod_hello.lua + "mod_"..plugin.."/"..resource; -- mod_hello/mod_hello.lua + plugin.."/"..resource; -- hello/mod_hello.lua + resource; -- mod_hello.lua + }; + + return load_file(names); end -function load_code(plugin, resource) +function load_code(plugin, resource, env) local content, err = load_resource(plugin, resource); if not content then return content, err; end - return loadstring(content, "@"..err); + local path = err; + local f, err = envload(content, "@"..path, env); + if not f then return f, err; end + return f, path; end return _M; diff --git a/util/prosodyctl.lua b/util/prosodyctl.lua index 40d21be8..fe862114 100644 --- a/util/prosodyctl.lua +++ b/util/prosodyctl.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -15,18 +15,119 @@ local usermanager = require "core.usermanager"; local signal = require "util.signal"; local set = require "util.set"; local lfs = require "lfs"; +local pcall = pcall; +local type = type; local nodeprep, nameprep = stringprep.nodeprep, stringprep.nameprep; local io, os = io, os; +local print = print; local tostring, tonumber = tostring, tonumber; local CFG_SOURCEDIR = _G.CFG_SOURCEDIR; +local _G = _G; local prosody = prosody; module "prosodyctl" +-- UI helpers +function show_message(msg, ...) + print(msg:format(...)); +end + +function show_warning(msg, ...) + print(msg:format(...)); +end + +function show_usage(usage, desc) + print("Usage: ".._G.arg[0].." "..usage); + if desc then + print(" "..desc); + end +end + +function getchar(n) + local stty_ret = os.execute("stty raw -echo 2>/dev/null"); + local ok, char; + if stty_ret == 0 then + ok, char = pcall(io.read, n or 1); + os.execute("stty sane"); + else + ok, char = pcall(io.read, "*l"); + if ok then + char = char:sub(1, n or 1); + end + end + if ok then + return char; + end +end + +function getline() + local ok, line = pcall(io.read, "*l"); + if ok then + return line; + end +end + +function getpass() + local stty_ret = os.execute("stty -echo 2>/dev/null"); + if stty_ret ~= 0 then + io.write("\027[08m"); -- ANSI 'hidden' text attribute + end + local ok, pass = pcall(io.read, "*l"); + if stty_ret == 0 then + os.execute("stty sane"); + else + io.write("\027[00m"); + end + io.write("\n"); + if ok then + return pass; + end +end + +function show_yesno(prompt) + io.write(prompt, " "); + local choice = getchar():lower(); + io.write("\n"); + if not choice:match("%a") then + choice = prompt:match("%[.-(%U).-%]$"); + if not choice then return nil; end + end + return (choice == "y"); +end + +function read_password() + local password; + while true do + io.write("Enter new password: "); + password = getpass(); + if not password then + show_message("No password - cancelled"); + return; + end + io.write("Retype new password: "); + if getpass() ~= password then + if not show_yesno [=[Passwords did not match, try again? [Y/n]]=] then + return; + end + else + break; + end + end + return password; +end + +function show_prompt(prompt) + io.write(prompt, " "); + local line = getline(); + line = line and line:gsub("\n$",""); + return (line and #line > 0) and line or nil; +end + +-- Server control function adduser(params) local user, host, password = nodeprep(params.user), nameprep(params.host), params.password; if not user then @@ -35,12 +136,17 @@ function adduser(params) return false, "invalid-hostname"; end - local provider = prosody.hosts[host].users; + local host_session = prosody.hosts[host]; + if not host_session then + return false, "no-such-host"; + end + + storagemanager.initialize_host(host); + local provider = host_session.users; if not(provider) or provider.name == "null" then usermanager.initialize_host(host); end - storagemanager.initialize_host(host); - + local ok, errmsg = usermanager.create_user(user, password, host); if not ok then return false, errmsg; @@ -50,12 +156,13 @@ end function user_exists(params) local user, host, password = nodeprep(params.user), nameprep(params.host), params.password; + + storagemanager.initialize_host(host); local provider = prosody.hosts[host].users; if not(provider) or provider.name == "null" then usermanager.initialize_host(host); end - storagemanager.initialize_host(host); - + return usermanager.user_exists(user, host); end @@ -63,7 +170,7 @@ function passwd(params) if not _M.user_exists(params) then return false, "no-such-user"; end - + return _M.adduser(params); end @@ -71,40 +178,40 @@ function deluser(params) if not _M.user_exists(params) then return false, "no-such-user"; end - params.password = nil; - - return _M.adduser(params); + local user, host = nodeprep(params.user), nameprep(params.host); + + return usermanager.delete_user(user, host); end function getpid() - local pidfile = config.get("*", "core", "pidfile"); + local pidfile = config.get("*", "pidfile"); if not pidfile then return false, "no-pidfile"; end - - local modules_enabled = set.new(config.get("*", "core", "modules_enabled")); + + local modules_enabled = set.new(config.get("*", "modules_enabled")); if not modules_enabled:contains("posix") then return false, "no-posix"; end - + local file, err = io.open(pidfile, "r+"); if not file then return false, "pidfile-read-failed", err; end - + local locked, err = lfs.lock(file, "w"); if locked then file:close(); return false, "pidfile-not-locked"; end - + local pid = tonumber(file:read("*a")); file:close(); - + if not pid then return false, "invalid-pid"; end - + return true, pid; end @@ -145,10 +252,28 @@ function stop() if not ret then return false, "not-running"; end - + local ok, pid = _M.getpid() if not ok then return false, pid; end - + signal.kill(pid, signal.SIGTERM); return true; end + +function reload() + local ok, ret = _M.isrunning(); + if not ok then + return ok, ret; + end + if not ret then + return false, "not-running"; + end + + local ok, pid = _M.getpid() + if not ok then return false, pid; end + + signal.kill(pid, signal.SIGHUP); + return true; +end + +return _M; diff --git a/util/pubsub.lua b/util/pubsub.lua index 3beafab5..0dfd196b 100644 --- a/util/pubsub.lua +++ b/util/pubsub.lua @@ -1,3 +1,5 @@ +local events = require "util.events"; + module("pubsub", package.seeall); local service = {}; @@ -16,6 +18,7 @@ function new(config) affiliations = {}; subscriptions = {}; nodes = {}; + events = events.new(); }, service_mt); end @@ -26,18 +29,15 @@ end function service:may(node, actor, action) if actor == true then return true; end - - + local node_obj = self.nodes[node]; local node_aff = node_obj and node_obj.affiliations[actor]; local service_aff = self.affiliations[actor] or self.config.get_affiliation(actor, node, action) or "none"; - + + -- Check if node allows/forbids it local node_capabilities = node_obj and node_obj.capabilities; - local service_capabilities = self.config.capabilities; - - -- Check if node allows/forbids it if node_capabilities then local caps = node_capabilities[node_aff or service_aff]; if caps then @@ -47,7 +47,9 @@ function service:may(node, actor, action) end end end + -- Check service-wide capabilities instead + local service_capabilities = self.config.capabilities; local caps = service_capabilities[node_aff or service_aff]; if caps then local can = caps[action]; @@ -55,7 +57,7 @@ function service:may(node, actor, action) return can; end end - + return false; end @@ -70,14 +72,14 @@ function service:set_affiliation(node, actor, jid, affiliation) return false, "item-not-found"; end node_obj.affiliations[jid] = affiliation; - local _, jid_sub = self:get_subscription(node, nil, jid); + local _, jid_sub = self:get_subscription(node, true, jid); if not jid_sub and not self:may(node, jid, "be_unsubscribed") then - local ok, err = self:add_subscription(node, nil, jid); + local ok, err = self:add_subscription(node, true, jid); if not ok then return ok, err; end elseif jid_sub and not self:may(node, jid, "be_subscribed") then - local ok, err = self:add_subscription(node, nil, jid); + local ok, err = self:add_subscription(node, true, jid); if not ok then return ok, err; end @@ -88,7 +90,7 @@ end function service:add_subscription(node, actor, jid, options) -- Access checking local cap; - if jid == actor or self:jids_equal(actor, jid) then + if actor == true or jid == actor or self:jids_equal(actor, jid) then cap = "subscribe"; else cap = "subscribe_other"; @@ -105,7 +107,7 @@ function service:add_subscription(node, actor, jid, options) if not self.config.autocreate_on_subscribe then return false, "item-not-found"; else - local ok, err = self:create(node, actor); + local ok, err = self:create(node, true); if not ok then return ok, err; end @@ -124,13 +126,14 @@ function service:add_subscription(node, actor, jid, options) else self.subscriptions[normal_jid] = { [jid] = { [node] = true } }; end + self.events.fire_event("subscription-added", { node = node, jid = jid, normalized_jid = normal_jid, options = options }); return true; end function service:remove_subscription(node, actor, jid) -- Access checking local cap; - if jid == actor or self:jids_equal(actor, jid) then + if actor == true or jid == actor or self:jids_equal(actor, jid) then cap = "unsubscribe"; else cap = "unsubscribe_other"; @@ -164,13 +167,26 @@ function service:remove_subscription(node, actor, jid) self.subscriptions[normal_jid] = nil; end end + self.events.fire_event("subscription-removed", { node = node, jid = jid, normalized_jid = normal_jid }); + return true; +end + +function service:remove_all_subscriptions(actor, jid) + local normal_jid = self.config.normalize_jid(jid); + local subs = self.subscriptions[normal_jid] + subs = subs and subs[jid]; + if subs then + for node in pairs(subs) do + self:remove_subscription(node, true, jid); + end + end return true; end function service:get_subscription(node, actor, jid) -- Access checking local cap; - if jid == actor or self:jids_equal(actor, jid) then + if actor == true or jid == actor or self:jids_equal(actor, jid) then cap = "get_subscription"; else cap = "get_subscription_other"; @@ -195,7 +211,7 @@ function service:create(node, actor) if self.nodes[node] then return false, "conflict"; end - + self.nodes[node] = { name = node; subscribers = {}; @@ -210,6 +226,21 @@ function service:create(node, actor) return ok, err; end +function service:delete(node, actor) + -- Access checking + if not self:may(node, actor, "delete") then + return false, "forbidden"; + end + -- + local node_obj = self.nodes[node]; + if not node_obj then + return false, "item-not-found"; + end + self.nodes[node] = nil; + self.config.broadcaster("delete", node, node_obj.subscribers); + return true; +end + function service:publish(node, actor, id, item) -- Access checking if not self:may(node, actor, "publish") then @@ -221,14 +252,15 @@ function service:publish(node, actor, id, item) if not self.config.autocreate_on_publish then return false, "item-not-found"; end - local ok, err = self:create(node, actor); + local ok, err = self:create(node, true); if not ok then return ok, err; end node_obj = self.nodes[node]; end node_obj.data[id] = item; - self.config.broadcaster(node, node_obj.subscribers, item); + self.events.fire_event("item-published", { node = node, actor = actor, id = id, item = item }); + self.config.broadcaster("items", node, node_obj.subscribers, item); return true; end @@ -244,7 +276,24 @@ function service:retract(node, actor, id, retract) end node_obj.data[id] = nil; if retract then - self.config.broadcaster(node, node_obj.subscribers, retract); + self.config.broadcaster("items", node, node_obj.subscribers, retract); + end + return true +end + +function service:purge(node, actor, notify) + -- Access checking + if not self:may(node, actor, "retract") then + return false, "forbidden"; + end + -- + local node_obj = self.nodes[node]; + if not node_obj then + return false, "item-not-found"; + end + node_obj.data = {}; -- Purge + if notify then + self.config.broadcaster("purge", node, node_obj.subscribers); end return true end @@ -278,7 +327,7 @@ end function service:get_subscriptions(node, actor, jid) -- Access checking local cap; - if jid == actor or self:jids_equal(actor, jid) then + if actor == true or jid == actor or self:jids_equal(actor, jid) then cap = "get_subscriptions"; else cap = "get_subscriptions_other"; @@ -304,7 +353,7 @@ function service:get_subscriptions(node, actor, jid) if node then -- Return only subscriptions to this node if subscribed_nodes[node] then ret[#ret+1] = { - node = subscribed_node; + node = node; jid = jid; subscription = node_obj.subscribers[jid]; }; diff --git a/util/rfc6724.lua b/util/rfc6724.lua new file mode 100644 index 00000000..c8aec631 --- /dev/null +++ b/util/rfc6724.lua @@ -0,0 +1,142 @@ +-- Prosody IM +-- Copyright (C) 2011-2013 Florian Zeitz +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +-- This is used to sort destination addresses by preference +-- during S2S connections. +-- We can't hand this off to getaddrinfo, since it blocks + +local ip_commonPrefixLength = require"util.ip".commonPrefixLength +local new_ip = require"util.ip".new_ip; + +local function commonPrefixLength(ipA, ipB) + local len = ip_commonPrefixLength(ipA, ipB); + return len < 64 and len or 64; +end + +local function t_sort(t, comp) + for i = 1, (#t - 1) do + for j = (i + 1), #t do + local a, b = t[i], t[j]; + if not comp(a,b) then + t[i], t[j] = b, a; + end + end + end +end + +local function source(dest, candidates) + local function comp(ipA, ipB) + -- Rule 1: Prefer same address + if dest == ipA then + return true; + elseif dest == ipB then + return false; + end + + -- Rule 2: Prefer appropriate scope + if ipA.scope < ipB.scope then + if ipA.scope < dest.scope then + return false; + else + return true; + end + elseif ipA.scope > ipB.scope then + if ipB.scope < dest.scope then + return true; + else + return false; + end + end + + -- Rule 3: Avoid deprecated addresses + -- XXX: No way to determine this + -- Rule 4: Prefer home addresses + -- XXX: Mobility Address related, no way to determine this + -- Rule 5: Prefer outgoing interface + -- XXX: Interface to address relation. No way to determine this + -- Rule 6: Prefer matching label + if ipA.label == dest.label and ipB.label ~= dest.label then + return true; + elseif ipB.label == dest.label and ipA.label ~= dest.label then + return false; + end + + -- Rule 7: Prefer temporary addresses (over public ones) + -- XXX: No way to determine this + -- Rule 8: Use longest matching prefix + if commonPrefixLength(ipA, dest) > commonPrefixLength(ipB, dest) then + return true; + else + return false; + end + end + + t_sort(candidates, comp); + return candidates[1]; +end + +local function destination(candidates, sources) + local sourceAddrs = {}; + local function comp(ipA, ipB) + local ipAsource = sourceAddrs[ipA]; + local ipBsource = sourceAddrs[ipB]; + -- Rule 1: Avoid unusable destinations + -- XXX: No such information + -- Rule 2: Prefer matching scope + if ipA.scope == ipAsource.scope and ipB.scope ~= ipBsource.scope then + return true; + elseif ipA.scope ~= ipAsource.scope and ipB.scope == ipBsource.scope then + return false; + end + + -- Rule 3: Avoid deprecated addresses + -- XXX: No way to determine this + -- Rule 4: Prefer home addresses + -- XXX: Mobility Address related, no way to determine this + -- Rule 5: Prefer matching label + if ipAsource.label == ipA.label and ipBsource.label ~= ipB.label then + return true; + elseif ipBsource.label == ipB.label and ipAsource.label ~= ipA.label then + return false; + end + + -- Rule 6: Prefer higher precedence + if ipA.precedence > ipB.precedence then + return true; + elseif ipA.precedence < ipB.precedence then + return false; + end + + -- Rule 7: Prefer native transport + -- XXX: No way to determine this + -- Rule 8: Prefer smaller scope + if ipA.scope < ipB.scope then + return true; + elseif ipA.scope > ipB.scope then + return false; + end + + -- Rule 9: Use longest matching prefix + if commonPrefixLength(ipA, ipAsource) > commonPrefixLength(ipB, ipBsource) then + return true; + elseif commonPrefixLength(ipA, ipAsource) < commonPrefixLength(ipB, ipBsource) then + return false; + end + + -- Rule 10: Otherwise, leave order unchanged + return true; + end + for _, ip in ipairs(candidates) do + sourceAddrs[ip] = source(ip, sources); + end + + t_sort(candidates, comp); + return candidates; +end + +return {source = source, + destination = destination}; diff --git a/util/sasl.lua b/util/sasl.lua index 393a0919..0d90880d 100644 --- a/util/sasl.lua +++ b/util/sasl.lua @@ -68,13 +68,17 @@ end -- create a new SASL object which can be used to authenticate clients function new(realm, profile) - local mechanisms = {}; - for backend, f in pairs(profile) do - if backend_mechanism[backend] then - for _, mechanism in ipairs(backend_mechanism[backend]) do - mechanisms[mechanism] = true; + local mechanisms = profile.mechanisms; + if not mechanisms then + mechanisms = {}; + for backend, f in pairs(profile) do + if backend_mechanism[backend] then + for _, mechanism in ipairs(backend_mechanism[backend]) do + mechanisms[mechanism] = true; + end end end + profile.mechanisms = mechanisms; end return setmetatable({ profile = profile, realm = realm, mechs = mechanisms }, method); end @@ -131,5 +135,6 @@ require "util.sasl.plain" .init(registerMechanism); require "util.sasl.digest-md5".init(registerMechanism); require "util.sasl.anonymous" .init(registerMechanism); require "util.sasl.scram" .init(registerMechanism); +require "util.sasl.external" .init(registerMechanism); return _M; diff --git a/util/sasl/anonymous.lua b/util/sasl/anonymous.lua index b9af17fe..ca5fe404 100644 --- a/util/sasl/anonymous.lua +++ b/util/sasl/anonymous.lua @@ -16,7 +16,7 @@ local s_match = string.match; local log = require "util.logger".init("sasl"); local generate_uuid = require "util.uuid".generate; -module "anonymous" +module "sasl.anonymous" --========================= --SASL ANONYMOUS according to RFC 4505 @@ -43,4 +43,4 @@ function init(registerMechanism) registerMechanism("ANONYMOUS", {"anonymous"}, anonymous); end -return _M;
\ No newline at end of file +return _M; diff --git a/util/sasl/digest-md5.lua b/util/sasl/digest-md5.lua index 6f2c765e..591d8537 100644 --- a/util/sasl/digest-md5.lua +++ b/util/sasl/digest-md5.lua @@ -23,8 +23,9 @@ local to_byte, to_char = string.byte, string.char; local md5 = require "util.hashes".md5; local log = require "util.logger".init("sasl"); local generate_uuid = require "util.uuid".generate; +local nodeprep = require "util.encodings".stringprep.nodeprep; -module "digest-md5" +module "sasl.digest-md5" --========================= --SASL DIGEST-MD5 according to RFC 2831 @@ -139,10 +140,15 @@ local function digest(self, message) end -- check for username, it's REQUIRED by RFC 2831 - if not response["username"] then + local username = response["username"]; + local _nodeprep = self.profile.nodeprep; + if username and _nodeprep ~= false then + username = (_nodeprep or nodeprep)(username); -- FIXME charset + end + if not username or username == "" then return "failure", "malformed-request"; end - self["username"] = response["username"]; + self.username = username; -- check for nonce, ... if not response["nonce"] then @@ -178,7 +184,6 @@ local function digest(self, message) end --TODO maybe realm support - self.username = response["username"]; local Y, state; if self.profile.plain then local password, state = self.profile.plain(self, response["username"], self.realm) @@ -240,4 +245,4 @@ function init(registerMechanism) registerMechanism("DIGEST-MD5", {"plain"}, digest); end -return _M;
\ No newline at end of file +return _M; diff --git a/util/sasl/external.lua b/util/sasl/external.lua new file mode 100644 index 00000000..4c5c4343 --- /dev/null +++ b/util/sasl/external.lua @@ -0,0 +1,25 @@ +local saslprep = require "util.encodings".stringprep.saslprep; + +module "sasl.external" + +local function external(self, message) + message = saslprep(message); + local state + self.username, state = self.profile.external(message); + + if state == false then + return "failure", "account-disabled"; + elseif state == nil then + return "failure", "not-authorized"; + elseif state == "expired" then + return "false", "credentials-expired"; + end + + return "success"; +end + +function init(registerMechanism) + registerMechanism("EXTERNAL", {"external"}, external); +end + +return _M; diff --git a/util/sasl/plain.lua b/util/sasl/plain.lua index d6ebe304..c9ec2911 100644 --- a/util/sasl/plain.lua +++ b/util/sasl/plain.lua @@ -13,9 +13,10 @@ local s_match = string.match; local saslprep = require "util.encodings".stringprep.saslprep; +local nodeprep = require "util.encodings".stringprep.nodeprep; local log = require "util.logger".init("sasl"); -module "plain" +module "sasl.plain" -- ================================ -- SASL PLAIN according to RFC 4616 @@ -54,6 +55,14 @@ local function plain(self, message) return "failure", "malformed-request", "Invalid username or password."; end + local _nodeprep = self.profile.nodeprep; + if _nodeprep ~= false then + authentication = (_nodeprep or nodeprep)(authentication); + if not authentication or authentication == "" then + return "failure", "malformed-request", "Invalid username or password." + end + end + local correct, state = false, false; if self.profile.plain then local correct_password; @@ -64,15 +73,13 @@ local function plain(self, message) end self.username = authentication - if not state then + if state == false then return "failure", "account-disabled"; - end - - if correct then - return "success"; - else + elseif state == nil or not correct then return "failure", "not-authorized", "Unable to authorize you with the authentication credentials you've sent."; end + + return "success"; end function init(registerMechanism) diff --git a/util/sasl/scram.lua b/util/sasl/scram.lua index 071de505..cf938dba 100644 --- a/util/sasl/scram.lua +++ b/util/sasl/scram.lua @@ -16,16 +16,18 @@ local type = type local string = string local tostring = tostring; local base64 = require "util.encodings".base64; -local hmac_sha1 = require "util.hmac".sha1; +local hmac_sha1 = require "util.hashes".hmac_sha1; local sha1 = require "util.hashes".sha1; +local Hi = require "util.hashes".scram_Hi_sha1; local generate_uuid = require "util.uuid".generate; local saslprep = require "util.encodings".stringprep.saslprep; +local nodeprep = require "util.encodings".stringprep.nodeprep; local log = require "util.logger".init("sasl"); local t_concat = table.concat; local char = string.char; local byte = string.byte; -module "scram" +module "sasl.scram" --========================= --SASL SCRAM-SHA-1 according to RFC 5802 @@ -69,33 +71,26 @@ local function binaryXOR( a, b ) return t_concat(result); end --- hash algorithm independent Hi(PBKDF2) implementation -function Hi(hmac, str, salt, i) - local Ust = hmac(str, salt.."\0\0\0\1"); - local res = Ust; - for n=1,i-1 do - local Und = hmac(str, Ust) - res = binaryXOR(res, Und) - Ust = Und - end - return res -end - -local function validate_username(username) +local function validate_username(username, _nodeprep) -- check for forbidden char sequences for eq in username:gmatch("=(.?.?)") do - if eq ~= "2D" and eq ~= "3D" then + if eq ~= "2C" and eq ~= "3D" then return false end end - - -- replace =2D with , and =3D with = - username = username:gsub("=2D", ","); + + -- replace =2C with , and =3D with = + username = username:gsub("=2C", ","); username = username:gsub("=3D", "="); - + -- apply SASLprep username = saslprep(username); - return username; + + if username and _nodeprep ~= false then + username = (_nodeprep or nodeprep)(username); + end + + return username and #username>0 and username; end local function hashprep(hashname) @@ -109,7 +104,7 @@ function getAuthenticationDatabaseSHA1(password, salt, iteration_count) if iteration_count < 4096 then log("warn", "Iteration count < 4096 which is the suggested minimum according to RFC 5802.") end - local salted_password = Hi(hmac_sha1, password, salt, iteration_count); + local salted_password = Hi(password, salt, iteration_count); local stored_key = sha1(hmac_sha1(salted_password, "Client Key")) local server_key = hmac_sha1(salted_password, "Server Key"); return true, stored_key, server_key @@ -120,12 +115,12 @@ local function scram_gen(hash_name, H_f, HMAC_f) if not self.state then self["state"] = {} end local support_channel_binding = false; if self.profile.cb then support_channel_binding = true; end - + if type(message) ~= "string" or #message == 0 then return "failure", "malformed-request" end if not self.state.name then -- we are processing client_first_message local client_first_message = message; - log("debug", client_first_message); + -- TODO: fail if authzid is provided, since we don't support them yet self.state["client_first_message"] = client_first_message; self.state["gs2_cbind_flag"], self.state["gs2_cbind_name"], self.state["authzid"], self.state["name"], self.state["clientnonce"] @@ -156,21 +151,21 @@ local function scram_gen(hash_name, H_f, HMAC_f) if not self.state.name or not self.state.clientnonce then return "failure", "malformed-request", "Channel binding isn't support at this time."; end - - self.state.name = validate_username(self.state.name); + + self.state.name = validate_username(self.state.name, self.profile.nodeprep); if not self.state.name then log("debug", "Username violates either SASLprep or contains forbidden character sequences.") return "failure", "malformed-request", "Invalid username."; end - + self.state["servernonce"] = generate_uuid(); - + -- retreive credentials if self.profile.plain then local password, state = self.profile.plain(self, self.state.name, self.realm) if state == nil then return "failure", "not-authorized" elseif state == false then return "failure", "account-disabled" end - + password = saslprep(password); if not password then log("debug", "Password violates SASLprep."); @@ -190,20 +185,20 @@ local function scram_gen(hash_name, H_f, HMAC_f) local stored_key, server_key, iteration_count, salt, state = self.profile["scram_"..hashprep(hash_name)](self, self.state.name, self.realm); if state == nil then return "failure", "not-authorized" elseif state == false then return "failure", "account-disabled" end - + self.state.stored_key = stored_key; self.state.server_key = server_key; self.state.iteration_count = iteration_count; self.state.salt = salt end - + local server_first_message = "r="..self.state.clientnonce..self.state.servernonce..",s="..base64.encode(self.state.salt)..",i="..self.state.iteration_count; self.state["server_first_message"] = server_first_message; return "challenge", server_first_message else -- we are processing client_final_message local client_final_message = message; - log("debug", "client_final_message: %s", client_final_message); + self.state["channelbinding"], self.state["nonce"], self.state["proof"] = client_final_message:match("^c=(.*),r=(.*),.*p=(.*)"); if not self.state.proof or not self.state.nonce or not self.state.channelbinding then @@ -223,10 +218,10 @@ local function scram_gen(hash_name, H_f, HMAC_f) if self.state.nonce ~= self.state.clientnonce..self.state.servernonce then return "failure", "malformed-request", "Wrong nonce in client-final-message."; end - + local ServerKey = self.state.server_key; local StoredKey = self.state.stored_key; - + local AuthMessage = "n=" .. s_match(self.state.client_first_message,"n=(.+)") .. "," .. self.state.server_first_message .. "," .. s_match(client_final_message, "(.+),p=.+") local ClientSignature = HMAC_f(StoredKey, AuthMessage) local ClientKey = binaryXOR(ClientSignature, base64.decode(self.state.proof)) diff --git a/util/sasl_cyrus.lua b/util/sasl_cyrus.lua index 002118fd..a0e8bd69 100644 --- a/util/sasl_cyrus.lua +++ b/util/sasl_cyrus.lua @@ -78,11 +78,15 @@ local function init(service_name) end -- create a new SASL object which can be used to authenticate clients -function new(realm, service_name, app_name) +-- host_fqdn may be nil in which case gethostname() gives the value. +-- For GSSAPI, this determines the hostname in the service ticket (after +-- reverse DNS canonicalization, only if [libdefaults] rdns = true which +-- is the default). +function new(realm, service_name, app_name, host_fqdn) init(app_name or service_name); - local st, ret = pcall(cyrussasl.server_new, service_name, nil, realm, nil, nil) + local st, ret = pcall(cyrussasl.server_new, service_name, host_fqdn, realm, nil, nil) if not st then log("error", "Creating SASL server connection failed: %s", ret); return nil; diff --git a/util/serialization.lua b/util/serialization.lua index e193b64f..06e45054 100644 --- a/util/serialization.lua +++ b/util/serialization.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -16,11 +16,12 @@ local pairs = pairs; local next = next; local loadstring = loadstring; -local setfenv = setfenv; local pcall = pcall; local debug_traceback = debug.traceback; local log = require "util.logger".init("serialization"); +local envload = require"util.envload".envload; + module "serialization" local indent = function(i) @@ -84,9 +85,8 @@ end function deserialize(str) if type(str) ~= "string" then return nil; end str = "return "..str; - local f, err = loadstring(str, "@data"); + local f, err = envload(str, "@data", {}); if not f then return nil, err; end - setfenv(f, {}); local success, ret = pcall(f); if not success then return nil, ret; end return ret; diff --git a/util/set.lua b/util/set.lua index e4cc2dff..89cd7cf3 100644 --- a/util/set.lua +++ b/util/set.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -26,8 +26,9 @@ function set_mt.__div(set, func) local new_set, new_items = _M.new(); local items, new_items = set._items, new_set._items; for item in pairs(items) do - if func(item) then - new_items[item] = true; + local new_item = func(item); + if new_item ~= nil then + new_items[new_item] = true; end end return new_set; @@ -39,13 +40,13 @@ function set_mt.__eq(set1, set2) return false; end end - + for item in pairs(set2) do if not set1[item] then return false; end end - + return true; end function set_mt.__tostring(set) @@ -64,56 +65,58 @@ end function new(list) local items = setmetatable({}, items_mt); local set = { _items = items }; - + function set:add(item) items[item] = true; end - + function set:contains(item) return items[item]; end - + function set:items() - return items; + return next, items; end - + function set:remove(item) items[item] = nil; end - + function set:add_list(list) - for _, item in ipairs(list) do - items[item] = true; + if list then + for _, item in ipairs(list) do + items[item] = true; + end end end - + function set:include(otherset) - for item in pairs(otherset) do + for item in otherset do items[item] = true; end end function set:exclude(otherset) - for item in pairs(otherset) do + for item in otherset do items[item] = nil; end end - + function set:empty() return not next(items); end - + if list then set:add_list(list); end - + return setmetatable(set, set_mt); end function union(set1, set2) local set = new(); local items = set._items; - + for item in pairs(set1._items) do items[item] = true; end @@ -121,14 +124,14 @@ function union(set1, set2) for item in pairs(set2._items) do items[item] = true; end - + return set; end function difference(set1, set2) local set = new(); local items = set._items; - + for item in pairs(set1._items) do items[item] = (not set2._items[item]) or nil; end @@ -139,13 +142,13 @@ end function intersection(set1, set2) local set = new(); local items = set._items; - + set1, set2 = set1._items, set2._items; - + for item in pairs(set1) do items[item] = (not not set2[item]) or nil; end - + return set; end diff --git a/util/sql.lua b/util/sql.lua new file mode 100644 index 00000000..b8c16e27 --- /dev/null +++ b/util/sql.lua @@ -0,0 +1,342 @@ + +local setmetatable, getmetatable = setmetatable, getmetatable; +local ipairs, unpack, select = ipairs, unpack, select; +local tonumber, tostring = tonumber, tostring; +local assert, xpcall, debug_traceback = assert, xpcall, debug.traceback; +local t_concat = table.concat; +local s_char = string.char; +local log = require "util.logger".init("sql"); + +local DBI = require "DBI"; +-- This loads all available drivers while globals are unlocked +-- LuaDBI should be fixed to not set globals. +DBI.Drivers(); +local build_url = require "socket.url".build; + +module("sql") + +local column_mt = {}; +local table_mt = {}; +local query_mt = {}; +--local op_mt = {}; +local index_mt = {}; + +function is_column(x) return getmetatable(x)==column_mt; end +function is_index(x) return getmetatable(x)==index_mt; end +function is_table(x) return getmetatable(x)==table_mt; end +function is_query(x) return getmetatable(x)==query_mt; end +--function is_op(x) return getmetatable(x)==op_mt; end +--function expr(...) return setmetatable({...}, op_mt); end +function Integer(n) return "Integer()" end +function String(n) return "String()" end + +--[[local ops = { + __add = function(a, b) return "("..a.."+"..b..")" end; + __sub = function(a, b) return "("..a.."-"..b..")" end; + __mul = function(a, b) return "("..a.."*"..b..")" end; + __div = function(a, b) return "("..a.."/"..b..")" end; + __mod = function(a, b) return "("..a.."%"..b..")" end; + __pow = function(a, b) return "POW("..a..","..b..")" end; + __unm = function(a) return "NOT("..a..")" end; + __len = function(a) return "COUNT("..a..")" end; + __eq = function(a, b) return "("..a.."=="..b..")" end; + __lt = function(a, b) return "("..a.."<"..b..")" end; + __le = function(a, b) return "("..a.."<="..b..")" end; +}; + +local functions = { + +}; + +local cmap = { + [Integer] = Integer(); + [String] = String(); +};]] + +function Column(definition) + return setmetatable(definition, column_mt); +end +function Table(definition) + local c = {} + for i,col in ipairs(definition) do + if is_column(col) then + c[i], c[col.name] = col, col; + elseif is_index(col) then + col.table = definition.name; + end + end + return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt); +end +function Index(definition) + return setmetatable(definition, index_mt); +end + +function table_mt:__tostring() + local s = { 'name="'..self.__table__.name..'"' } + for i,col in ipairs(self.__table__) do + s[#s+1] = tostring(col); + end + return 'Table{ '..t_concat(s, ", ")..' }' +end +table_mt.__index = {}; +function table_mt.__index:create(engine) + return engine:_create_table(self); +end +function table_mt:__call(...) + -- TODO +end +function column_mt:__tostring() + return 'Column{ name="'..self.name..'", type="'..self.type..'" }' +end +function index_mt:__tostring() + local s = 'Index{ name="'..self.name..'"'; + for i=1,#self do s = s..', "'..self[i]:gsub("[\\\"]", "\\%1")..'"'; end + return s..' }'; +-- return 'Index{ name="'..self.name..'", type="'..self.type..'" }' +end +-- + +local function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end)); end +local function parse_url(url) + local scheme, secondpart, database = url:match("^([%w%+]+)://([^/]*)/?(.*)"); + assert(scheme, "Invalid URL format"); + local username, password, host, port; + local authpart, hostpart = secondpart:match("([^@]+)@([^@+])"); + if not authpart then hostpart = secondpart; end + if authpart then + username, password = authpart:match("([^:]*):(.*)"); + username = username or authpart; + password = password and urldecode(password); + end + if hostpart then + host, port = hostpart:match("([^:]*):(.*)"); + host = host or hostpart; + port = port and assert(tonumber(port), "Invalid URL format"); + end + return { + scheme = scheme:lower(); + username = username; password = password; + host = host; port = port; + database = #database > 0 and database or nil; + }; +end + +--[[local session = {}; + +function session.query(...) + local rets = {...}; + local query = setmetatable({ __rets = rets, __filters }, query_mt); + return query; +end +-- + +local function db2uri(params) + return build_url{ + scheme = params.driver, + user = params.username, + password = params.password, + host = params.host, + port = params.port, + path = params.database, + }; +end]] + +local engine = {}; +function engine:connect() + if self.conn then return true; end + + local params = self.params; + assert(params.driver, "no driver") + local dbh, err = DBI.Connect( + params.driver, params.database, + params.username, params.password, + params.host, params.port + ); + if not dbh then return nil, err; end + dbh:autocommit(false); -- don't commit automatically + self.conn = dbh; + self.prepared = {}; + return true; +end +function engine:execute(sql, ...) + local success, err = self:connect(); + if not success then return success, err; end + local prepared = self.prepared; + + local stmt = prepared[sql]; + if not stmt then + local err; + stmt, err = self.conn:prepare(sql); + if not stmt then return stmt, err; end + prepared[sql] = stmt; + end + + local success, err = stmt:execute(...); + if not success then return success, err; end + return stmt; +end + +local result_mt = { __index = { + affected = function(self) return self.__stmt:affected(); end; + rowcount = function(self) return self.__stmt:rowcount(); end; +} }; + +function engine:execute_query(sql, ...) + if self.params.driver == "PostgreSQL" then + sql = sql:gsub("`", "\""); + end + local stmt = assert(self.conn:prepare(sql)); + assert(stmt:execute(...)); + return stmt:rows(); +end +function engine:execute_update(sql, ...) + if self.params.driver == "PostgreSQL" then + sql = sql:gsub("`", "\""); + end + local prepared = self.prepared; + local stmt = prepared[sql]; + if not stmt then + stmt = assert(self.conn:prepare(sql)); + prepared[sql] = stmt; + end + assert(stmt:execute(...)); + return setmetatable({ __stmt = stmt }, result_mt); +end +engine.insert = engine.execute_update; +engine.select = engine.execute_query; +engine.delete = engine.execute_update; +engine.update = engine.execute_update; +function engine:_transaction(func, ...) + if not self.conn then + local a,b = self:connect(); + if not a then return a,b; end + end + --assert(not self.__transaction, "Recursive transactions not allowed"); + local args, n_args = {...}, select("#", ...); + local function f() return func(unpack(args, 1, n_args)); end + self.__transaction = true; + local success, a, b, c = xpcall(f, debug_traceback); + self.__transaction = nil; + if success then + log("debug", "SQL transaction success [%s]", tostring(func)); + local ok, err = self.conn:commit(); + if not ok then return ok, err; end -- commit failed + return success, a, b, c; + else + log("debug", "SQL transaction failure [%s]: %s", tostring(func), a); + if self.conn then self.conn:rollback(); end + return success, a; + end +end +function engine:transaction(...) + local a,b = self:_transaction(...); + if not a then + local conn = self.conn; + if not conn or not conn:ping() then + self.conn = nil; + a,b = self:_transaction(...); + end + end + return a,b; +end +function engine:_create_index(index) + local sql = "CREATE INDEX `"..index.name.."` ON `"..index.table.."` ("; + for i=1,#index do + sql = sql.."`"..index[i].."`"; + if i ~= #index then sql = sql..", "; end + end + sql = sql..");" + if self.params.driver == "PostgreSQL" then + sql = sql:gsub("`", "\""); + elseif self.params.driver == "MySQL" then + sql = sql:gsub("`([,)])", "`(20)%1"); + end + --print(sql); + return self:execute(sql); +end +function engine:_create_table(table) + local sql = "CREATE TABLE `"..table.name.."` ("; + for i,col in ipairs(table.c) do + sql = sql.."`"..col.name.."` "..col.type; + if col.nullable == false then sql = sql.." NOT NULL"; end + if i ~= #table.c then sql = sql..", "; end + end + sql = sql.. ");" + if self.params.driver == "PostgreSQL" then + sql = sql:gsub("`", "\""); + elseif self.params.driver == "MySQL" then + sql = sql:gsub(";$", " CHARACTER SET 'utf8' COLLATE 'utf8_bin';"); + end + local success,err = self:execute(sql); + if not success then return success,err; end + for i,v in ipairs(table.__table__) do + if is_index(v) then + self:_create_index(v); + end + end + return success; +end +local engine_mt = { __index = engine }; + +local function db2uri(params) + return build_url{ + scheme = params.driver, + user = params.username, + password = params.password, + host = params.host, + port = params.port, + path = params.database, + }; +end +local engine_cache = {}; -- TODO make weak valued +function create_engine(self, params) + local url = db2uri(params); + if not engine_cache[url] then + local engine = setmetatable({ url = url, params = params }, engine_mt); + engine_cache[url] = engine; + end + return engine_cache[url]; +end + + +--[[Users = Table { + name="users"; + Column { name="user_id", type=String(), primary_key=true }; +}; +print(Users) +print(Users.c.user_id)]] + +--local engine = create_engine('postgresql://scott:tiger@localhost:5432/mydatabase'); +--[[local engine = create_engine{ driver = "SQLite3", database = "./alchemy.sqlite" }; + +local i = 0; +for row in assert(engine:execute("select * from sqlite_master")):rows(true) do + i = i+1; + print(i); + for k,v in pairs(row) do + print("",k,v); + end +end +print("---") + +Prosody = Table { + name="prosody"; + Column { name="host", type="TEXT", nullable=false }; + Column { name="user", type="TEXT", nullable=false }; + Column { name="store", type="TEXT", nullable=false }; + Column { name="key", type="TEXT", nullable=false }; + Column { name="type", type="TEXT", nullable=false }; + Column { name="value", type="TEXT", nullable=false }; + Index { name="prosody_index", "host", "user", "store", "key" }; +}; +--print(Prosody); +assert(engine:transaction(function() + assert(Prosody:create(engine)); +end)); + +for row in assert(engine:execute("select user from prosody")):rows(true) do + print("username:", row['username']) +end +--result.close();]] + +return _M; diff --git a/util/stanza.lua b/util/stanza.lua index afaf9ce9..82601e63 100644 --- a/util/stanza.lua +++ b/util/stanza.lua @@ -1,29 +1,24 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- local t_insert = table.insert; -local t_concat = table.concat; local t_remove = table.remove; local t_concat = table.concat; local s_format = string.format; local s_match = string.match; local tostring = tostring; local setmetatable = setmetatable; -local getmetatable = getmetatable; local pairs = pairs; local ipairs = ipairs; local type = type; -local next = next; -local print = print; -local unpack = unpack; local s_gsub = string.gsub; -local s_char = string.char; +local s_sub = string.sub; local s_find = string.find; local os = os; @@ -44,11 +39,13 @@ module "stanza" stanza_mt = { __type = "stanza" }; stanza_mt.__index = stanza_mt; +local stanza_mt = stanza_mt; function stanza(name, attr) local stanza = { name = name, attr = attr or {}, tags = {} }; return setmetatable(stanza, stanza_mt); end +local stanza = stanza; function stanza_mt:query(xmlns) return self:tag("query", { xmlns = xmlns }); @@ -102,12 +99,20 @@ function stanza_mt:get_child(name, xmlns) if (not name or child.name == name) and ((not xmlns and self.attr.xmlns == child.attr.xmlns) or child.attr.xmlns == xmlns) then - + return child; end end end +function stanza_mt:get_child_text(name, xmlns) + local tag = self:get_child(name, xmlns); + if tag then + return tag:get_text(); + end + return nil; +end + function stanza_mt:child_with_name(name) for _, child in ipairs(self.tags) do if child.name == name then return child; end @@ -128,37 +133,28 @@ function stanza_mt:children() end, self, i; end -function stanza_mt:matching_tags(name, xmlns) - xmlns = xmlns or self.attr.xmlns; +function stanza_mt:childtags(name, xmlns) local tags = self.tags; local start_i, max_i = 1, #tags; return function () - for i=start_i,max_i do - v = tags[i]; - if (not name or v.name == name) - and (not xmlns or xmlns == v.attr.xmlns) then - start_i = i+1; - return v; - end + for i = start_i, max_i do + local v = tags[i]; + if (not name or v.name == name) + and ((not xmlns and self.attr.xmlns == v.attr.xmlns) + or v.attr.xmlns == xmlns) then + start_i = i+1; + return v; end - end, tags, i; -end - -function stanza_mt:childtags() - local i = 0; - return function (a) - i = i + 1 - local v = self.tags[i] - if v then return v; end - end, self.tags[1], i; + end + end; end function stanza_mt:maptags(callback) local tags, curr_tag = self.tags, 1; local n_children, n_tags = #self, #tags; - + local i = 1; - while curr_tag <= n_tags do + while curr_tag <= n_tags and n_tags > 0 do if self[i] == tags[curr_tag] then local ret = callback(self[i]); if ret == nil then @@ -166,17 +162,44 @@ function stanza_mt:maptags(callback) t_remove(tags, curr_tag); n_children = n_children - 1; n_tags = n_tags - 1; + i = i - 1; + curr_tag = curr_tag - 1; else self[i] = ret; - tags[i] = ret; + tags[curr_tag] = ret; end - i = i + 1; curr_tag = curr_tag + 1; end + i = i + 1; end return self; end +function stanza_mt:find(path) + local pos = 1; + local len = #path + 1; + + repeat + local xmlns, name, text; + local char = s_sub(path, pos, pos); + if char == "@" then + return self.attr[s_sub(path, pos + 1)]; + elseif char == "{" then + xmlns, pos = s_match(path, "^([^}]+)}()", pos + 1); + end + name, text, pos = s_match(path, "^([^@/#]*)([/#]?)()", pos); + name = name ~= "" and name or nil; + if pos == len then + if text == "#" then + return self:get_child_text(name, xmlns); + end + return self:get_child(name, xmlns); + end + self = self:get_child(name, xmlns); + until not self +end + + local xml_escape do local escape_table = { ["'"] = "'", ["\""] = """, ["<"] = "<", [">"] = ">", ["&"] = "&" }; @@ -235,14 +258,14 @@ end function stanza_mt.get_error(stanza) local type, condition, text; - + local error_tag = stanza:get_child("error"); if not error_tag then return nil, nil, nil; end type = error_tag.attr.type; - - for child in error_tag:childtags() do + + for _, child in ipairs(error_tag.tags) do if child.attr.xmlns == xmlns_stanzas then if not text and child.name == "text" then text = child:get_text(); @@ -257,11 +280,6 @@ function stanza_mt.get_error(stanza) return type, condition or "undefined-condition", text; end -function stanza_mt.__add(s1, s2) - return s1:add_direct_child(s2); -end - - do local id = 0; function new_id() @@ -315,28 +333,25 @@ function deserialize(stanza) stanza.tags = tags; end end - + return stanza; end -function clone(stanza) - local lookup_table = {}; - local function _copy(object) - if type(object) ~= "table" then - return object; - elseif lookup_table[object] then - return lookup_table[object]; +local function _clone(stanza) + local attr, tags = {}, {}; + for k,v in pairs(stanza.attr) do attr[k] = v; end + local new = { name = stanza.name, attr = attr, tags = tags }; + for i=1,#stanza do + local child = stanza[i]; + if child.name then + child = _clone(child); + t_insert(tags, child); end - local new_table = {}; - lookup_table[object] = new_table; - for index, value in pairs(object) do - new_table[_copy(index)] = _copy(value); - end - return setmetatable(new_table, getmetatable(object)); + t_insert(new, child); end - - return _copy(stanza) + return setmetatable(new, stanza_mt); end +clone = _clone; function message(attr, body) if not body then @@ -375,7 +390,7 @@ if do_pretty_printing then local style_attrv = getstyle("red"); local style_tagname = getstyle("red"); local style_punc = getstyle("magenta"); - + local attr_format = " "..getstring(style_attrk, "%s")..getstring(style_punc, "=")..getstring(style_attrv, "'%s'"); local top_tag_format = getstring(style_punc, "<")..getstring(style_tagname, "%s").."%s"..getstring(style_punc, ">"); --local tag_format = getstring(style_punc, "<")..getstring(style_tagname, "%s").."%s"..getstring(style_punc, ">").."%s"..getstring(style_punc, "</")..getstring(style_tagname, "%s")..getstring(style_punc, ">"); @@ -396,7 +411,7 @@ if do_pretty_printing then end return s_format(tag_format, t.name, attr_string, children_text, t.name); end - + function stanza_mt.pretty_top_tag(t) local attr_string = ""; if t.attr then diff --git a/util/template.lua b/util/template.lua index ebd8be14..66d4fca7 100644 --- a/util/template.lua +++ b/util/template.lua @@ -1,64 +1,28 @@ -local st = require "util.stanza"; -local lxp = require "lxp"; +local stanza_mt = require "util.stanza".stanza_mt; local setmetatable = setmetatable; local pairs = pairs; local ipairs = ipairs; local error = error; local loadstring = loadstring; local debug = debug; +local t_remove = table.remove; +local parse_xml = require "util.xml".parse; module("template") -local parse_xml = (function() - local ns_prefixes = { - ["http://www.w3.org/XML/1998/namespace"] = "xml"; - }; - local ns_separator = "\1"; - local ns_pattern = "^([^"..ns_separator.."]*)"..ns_separator.."?(.*)$"; - return function(xml) - local handler = {}; - local stanza = st.stanza("root"); - function handler:StartElement(tagname, attr) - local curr_ns,name = tagname:match(ns_pattern); - if name == "" then - curr_ns, name = "", curr_ns; - end - if curr_ns ~= "" then - attr.xmlns = curr_ns; - end - for i=1,#attr do - local k = attr[i]; - attr[i] = nil; - local ns, nm = k:match(ns_pattern); - if nm ~= "" then - ns = ns_prefixes[ns]; - if ns then - attr[ns..":"..nm] = attr[k]; - attr[k] = nil; - end - end - end - stanza:tag(name, attr); - end - function handler:CharacterData(data) - data = data:gsub("^%s*", ""):gsub("%s*$", ""); - stanza:text(data); - end - function handler:EndElement(tagname) - stanza:up(); - end - local parser = lxp.new(handler, "\1"); - local ok, err, line, col = parser:parse(xml); - if ok then ok, err, line, col = parser:parse(); end - --parser:close(); - if ok then - return stanza.tags[1]; +local function trim_xml(stanza) + for i=#stanza,1,-1 do + local child = stanza[i]; + if child.name then + trim_xml(child); else - return ok, err.." (line "..line..", col "..col..")"; + child = child:gsub("^%s*", ""):gsub("%s*$", ""); + stanza[i] = child; + if child == "" then t_remove(stanza, i); end end - end; -end)(); + end +end local function create_string_string(str) str = ("%q"):format(str); @@ -100,7 +64,6 @@ local function create_clone_string(stanza, lookup, xmlns) end return lookup[stanza]; end -local stanza_mt = st.stanza_mt; local function create_cloner(stanza, chunkname) local lookup = {}; local name = create_clone_string(stanza, lookup, ""); @@ -118,6 +81,7 @@ local template_mt = { __tostring = function(t) return t.name end }; local function create_template(templates, text) local stanza, err = parse_xml(text); if not stanza then error(err); end + trim_xml(stanza); local info = debug.getinfo(3, "Sl"); info = info and ("template(%s:%d)"):format(info.short_src:match("[^\\/]*$"), info.currentline) or "template(unknown)"; diff --git a/util/termcolours.lua b/util/termcolours.lua index df204688..ef978364 100644 --- a/util/termcolours.lua +++ b/util/termcolours.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -9,6 +9,7 @@ local t_concat, t_insert = table.concat, table.insert; local char, format = string.char, string.format; +local tonumber = tonumber; local ipairs = ipairs; local io_write = io.write; @@ -34,6 +35,15 @@ local winstylemap = { ["1;31"] = 4+8 -- bold red } +local cssmap = { + [1] = "font-weight: bold", [2] = "opacity: 0.5", [4] = "text-decoration: underline", [8] = "visibility: hidden", + [30] = "color:black", [31] = "color:red", [32]="color:green", [33]="color:#FFD700", + [34] = "color:blue", [35] = "color: magenta", [36] = "color:cyan", [37] = "color: white", + [40] = "background-color:black", [41] = "background-color:red", [42]="background-color:green", + [43]="background-color:yellow", [44] = "background-color:blue", [45] = "background-color: magenta", + [46] = "background-color:cyan", [47] = "background-color: white"; +}; + local fmt_string = char(0x1B).."[%sm%s"..char(0x1B).."[0m"; function getstring(style, text) if style then @@ -76,4 +86,17 @@ if windows then end end +local function ansi2css(ansi_codes) + if ansi_codes == "0" then return "</span>"; end + local css = {}; + for code in ansi_codes:gmatch("[^;]+") do + t_insert(css, cssmap[tonumber(code)]); + end + return "</span><span style='"..t_concat(css, ";").."'>"; +end + +function tohtml(input) + return input:gsub("\027%[(.-)m", ansi2css); +end + return _M; diff --git a/util/throttle.lua b/util/throttle.lua new file mode 100644 index 00000000..55e1d07b --- /dev/null +++ b/util/throttle.lua @@ -0,0 +1,46 @@ + +local gettime = require "socket".gettime; +local setmetatable = setmetatable; +local floor = math.floor; + +module "throttle" + +local throttle = {}; +local throttle_mt = { __index = throttle }; + +function throttle:update() + local newt = gettime(); + local elapsed = newt - self.t; + self.t = newt; + local balance = floor(self.rate * elapsed) + self.balance; + if balance > self.max then + self.balance = self.max; + else + self.balance = balance; + end + return self.balance; +end + +function throttle:peek(cost) + cost = cost or 1; + return self.balance >= cost or self:update() >= cost; +end + +function throttle:poll(cost, split) + if self:peek(cost) then + self.balance = self.balance - cost; + return true; + else + local balance = self.balance; + if split then + self.balance = 0; + end + return false, balance, (cost-balance); + end +end + +function create(max, period) + return setmetatable({ rate = max / period, max = max, t = 0, balance = max }, throttle_mt); +end + +return _M; diff --git a/util/timer.lua b/util/timer.lua index 3061da72..0e10e144 100644 --- a/util/timer.lua +++ b/util/timer.lua @@ -1,22 +1,17 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- - -local ns_addtimer = require "net.server".addtimer; -local event = require "net.server".event; -local event_base = require "net.server".event_base; - +local server = require "net.server"; local math_min = math.min local math_huge = math.huge local get_time = require "socket".gettime; local t_insert = table.insert; -local t_remove = table.remove; -local ipairs, pairs = ipairs, pairs; +local pairs = pairs; local type = type; local data = {}; @@ -25,18 +20,21 @@ local new_data = {}; module "timer" local _add_task; -if not event then - function _add_task(delay, func) +if not server.event then + function _add_task(delay, callback) local current_time = get_time(); delay = delay + current_time; if delay >= current_time then - t_insert(new_data, {delay, func}); + t_insert(new_data, {delay, callback}); else - func(); + local r = callback(current_time); + if r and type(r) == "number" then + return _add_task(r, callback); + end end end - ns_addtimer(function() + server._addtimer(function() local current_time = get_time(); if #new_data > 0 then for _, d in pairs(new_data) do @@ -44,15 +42,15 @@ if not event then end new_data = {}; end - + local next_time = math_huge; for i, d in pairs(data) do - local t, func = d[1], d[2]; + local t, callback = d[1], d[2]; if t <= current_time then data[i] = nil; - local r = func(current_time); + local r = callback(current_time); if type(r) == "number" then - _add_task(r, func); + _add_task(r, callback); next_time = math_min(next_time, r); end else @@ -62,11 +60,14 @@ if not event then return next_time; end); else + local event = server.event; + local event_base = server.event_base; local EVENT_LEAVE = (event.core and event.core.LEAVE) or -1; - function _add_task(delay, func) + + function _add_task(delay, callback) local event_handle; event_handle = event_base:addevent(nil, 0, function () - local ret = func(); + local ret = callback(get_time()); if ret then return 0, ret; elseif event_handle then diff --git a/util/uuid.lua b/util/uuid.lua index 796c8ee4..fc487c72 100644 --- a/util/uuid.lua +++ b/util/uuid.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- diff --git a/util/watchdog.lua b/util/watchdog.lua new file mode 100644 index 00000000..bcb2e274 --- /dev/null +++ b/util/watchdog.lua @@ -0,0 +1,34 @@ +local timer = require "util.timer"; +local setmetatable = setmetatable; +local os_time = os.time; + +module "watchdog" + +local watchdog_methods = {}; +local watchdog_mt = { __index = watchdog_methods }; + +function new(timeout, callback) + local watchdog = setmetatable({ timeout = timeout, last_reset = os_time(), callback = callback }, watchdog_mt); + timer.add_task(timeout+1, function (current_time) + local last_reset = watchdog.last_reset; + if not last_reset then + return; + end + local time_left = (last_reset + timeout) - current_time; + if time_left < 0 then + return watchdog:callback(); + end + return time_left + 1; + end); + return watchdog; +end + +function watchdog_methods:reset() + self.last_reset = os_time(); +end + +function watchdog_methods:cancel() + self.last_reset = nil; +end + +return _M; diff --git a/util/x509.lua b/util/x509.lua index 11f231a0..19d4ec6d 100644 --- a/util/x509.lua +++ b/util/x509.lua @@ -11,8 +11,8 @@ -- IDN libraries complicate that. --- [TLS-CERTS] - http://tools.ietf.org/html/draft-saintandre-tls-server-id-check-10 --- [XMPP-CORE] - http://tools.ietf.org/html/draft-ietf-xmpp-3920bis-18 +-- [TLS-CERTS] - http://tools.ietf.org/html/rfc6125 +-- [XMPP-CORE] - http://tools.ietf.org/html/rfc6120 -- [SRV-ID] - http://tools.ietf.org/html/rfc4985 -- [IDNA] - http://tools.ietf.org/html/rfc5890 -- [LDAP] - http://tools.ietf.org/html/rfc4519 @@ -21,6 +21,10 @@ local nameprep = require "util.encodings".stringprep.nameprep; local idna_to_ascii = require "util.encodings".idna.to_ascii; local log = require "util.logger".init("x509"); +local pairs, ipairs = pairs, ipairs; +local s_format = string.format; +local t_insert = table.insert; +local t_concat = table.concat; module "x509" @@ -32,7 +36,7 @@ local oid_dnssrv = "1.3.6.1.5.5.7.8.7"; -- [SRV-ID] -- Compare a hostname (possibly international) with asserted names -- extracted from a certificate. -- This function follows the rules laid out in --- sections 4.4.1 and 4.4.2 of [TLS-CERTS] +-- sections 6.4.1 and 6.4.2 of [TLS-CERTS] -- -- A wildcard ("*") all by itself is allowed only as the left-most label local function compare_dnsname(host, asserted_names) @@ -150,7 +154,7 @@ function verify_identity(host, service, cert) if ext[oid_subjectaltname] then local sans = ext[oid_subjectaltname]; - -- Per [TLS-CERTS] 4.3, 4.4.4, "a client MUST NOT seek a match for a + -- Per [TLS-CERTS] 6.3, 6.4.4, "a client MUST NOT seek a match for a -- reference identifier if the presented identifiers include a DNS-ID -- SRV-ID, URI-ID, or any application-specific identifier types" local had_supported_altnames = false @@ -183,7 +187,7 @@ function verify_identity(host, service, cert) -- a dNSName subjectAltName (wildcards may apply for, and receive, -- cat treats) -- - -- Per [TLS-CERTS] 1.5, a CN-ID is the Common Name from a cert subject + -- Per [TLS-CERTS] 1.8, a CN-ID is the Common Name from a cert subject -- which has one and only one Common Name local subject = cert:subject() local cn = nil @@ -200,7 +204,7 @@ function verify_identity(host, service, cert) end if cn then - -- Per [TLS-CERTS] 4.4.4, follow the comparison rules for dNSName SANs. + -- Per [TLS-CERTS] 6.4.4, follow the comparison rules for dNSName SANs. return compare_dnsname(host, { cn }) end diff --git a/util/xml.lua b/util/xml.lua new file mode 100644 index 00000000..6dbed65d --- /dev/null +++ b/util/xml.lua @@ -0,0 +1,57 @@ + +local st = require "util.stanza"; +local lxp = require "lxp"; + +module("xml") + +local parse_xml = (function() + local ns_prefixes = { + ["http://www.w3.org/XML/1998/namespace"] = "xml"; + }; + local ns_separator = "\1"; + local ns_pattern = "^([^"..ns_separator.."]*)"..ns_separator.."?(.*)$"; + return function(xml) + local handler = {}; + local stanza = st.stanza("root"); + function handler:StartElement(tagname, attr) + local curr_ns,name = tagname:match(ns_pattern); + if name == "" then + curr_ns, name = "", curr_ns; + end + if curr_ns ~= "" then + attr.xmlns = curr_ns; + end + for i=1,#attr do + local k = attr[i]; + attr[i] = nil; + local ns, nm = k:match(ns_pattern); + if nm ~= "" then + ns = ns_prefixes[ns]; + if ns then + attr[ns..":"..nm] = attr[k]; + attr[k] = nil; + end + end + end + stanza:tag(name, attr); + end + function handler:CharacterData(data) + stanza:text(data); + end + function handler:EndElement(tagname) + stanza:up(); + end + local parser = lxp.new(handler, "\1"); + local ok, err, line, col = parser:parse(xml); + if ok then ok, err, line, col = parser:parse(); end + --parser:close(); + if ok then + return stanza.tags[1]; + else + return ok, err.." (line "..line..", col "..col..")"; + end + end; +end)(); + +parse = parse_xml; +return _M; diff --git a/util/xmlrpc.lua b/util/xmlrpc.lua deleted file mode 100644 index 29815b0d..00000000 --- a/util/xmlrpc.lua +++ /dev/null @@ -1,182 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - - -local pairs = pairs; -local type = type; -local error = error; -local t_concat = table.concat; -local t_insert = table.insert; -local tostring = tostring; -local tonumber = tonumber; -local select = select; -local st = require "util.stanza"; - -module "xmlrpc" - -local _lua_to_xmlrpc; -local map = { - table=function(stanza, object) - stanza:tag("struct"); - for name, value in pairs(object) do - stanza:tag("member"); - stanza:tag("name"):text(tostring(name)):up(); - stanza:tag("value"); - _lua_to_xmlrpc(stanza, value); - stanza:up(); - stanza:up(); - end - stanza:up(); - end; - boolean=function(stanza, object) - stanza:tag("boolean"):text(object and "1" or "0"):up(); - end; - string=function(stanza, object) - stanza:tag("string"):text(object):up(); - end; - number=function(stanza, object) - stanza:tag("int"):text(tostring(object)):up(); - end; - ["nil"]=function(stanza, object) -- nil extension - stanza:tag("nil"):up(); - end; -}; -_lua_to_xmlrpc = function(stanza, object) - local h = map[type(object)]; - if h then - h(stanza, object); - else - error("Type not supported by XML-RPC: " .. type(object)); - end -end -function create_response(object) - local stanza = st.stanza("methodResponse"):tag("params"):tag("param"):tag("value"); - _lua_to_xmlrpc(stanza, object); - stanza:up():up():up(); - return stanza; -end -function create_error_response(faultCode, faultString) - local stanza = st.stanza("methodResponse"):tag("fault"):tag("value"); - _lua_to_xmlrpc(stanza, {faultCode=faultCode, faultString=faultString}); - stanza:up():up(); - return stanza; -end - -function create_request(method_name, ...) - local stanza = st.stanza("methodCall") - :tag("methodName"):text(method_name):up() - :tag("params"); - for i=1,select('#', ...) do - stanza:tag("param"):tag("value"); - _lua_to_xmlrpc(stanza, select(i, ...)); - stanza:up():up(); - end - stanza:up():up():up(); - return stanza; -end - -local _xmlrpc_to_lua; -local int_parse = function(stanza) - if #stanza.tags ~= 0 or #stanza == 0 then error("<"..stanza.name.."> must have a single text child"); end - local n = tonumber(t_concat(stanza)); - if n then return n; end - error("Failed to parse content of <"..stanza.name..">"); -end -local rmap = { - methodCall=function(stanza) - if #stanza.tags ~= 2 then error("<methodCall> must have exactly two subtags"); end -- FIXME <params> is optional - if stanza.tags[1].name ~= "methodName" then error("First <methodCall> child tag must be <methodName>") end - if stanza.tags[2].name ~= "params" then error("Second <methodCall> child tag must be <params>") end - return _xmlrpc_to_lua(stanza.tags[1]), _xmlrpc_to_lua(stanza.tags[2]); - end; - methodName=function(stanza) - if #stanza.tags ~= 0 then error("<methodName> must not have any subtags"); end - if #stanza == 0 then error("<methodName> must have text content"); end - return t_concat(stanza); - end; - params=function(stanza) - local t = {}; - for _, child in pairs(stanza.tags) do - if child.name ~= "param" then error("<params> can only have <param> children"); end; - t_insert(t, _xmlrpc_to_lua(child)); - end - return t; - end; - param=function(stanza) - if not(#stanza.tags == 1 and stanza.tags[1].name == "value") then error("<param> must have exactly one <value> child"); end - return _xmlrpc_to_lua(stanza.tags[1]); - end; - value=function(stanza) - if #stanza.tags == 0 then return t_concat(stanza); end - if #stanza.tags ~= 1 then error("<value> must have a single child"); end - return _xmlrpc_to_lua(stanza.tags[1]); - end; - int=int_parse; - i4=int_parse; - double=int_parse; - boolean=function(stanza) - if #stanza.tags ~= 0 or #stanza == 0 then error("<boolean> must have a single text child"); end - local b = t_concat(stanza); - if b ~= "1" and b ~= "0" then error("Failed to parse content of <boolean>"); end - return b == "1" and true or false; - end; - string=function(stanza) - if #stanza.tags ~= 0 then error("<string> must have a single text child"); end - return t_concat(stanza); - end; - array=function(stanza) - if #stanza.tags ~= 1 then error("<array> must have a single <data> child"); end - return _xmlrpc_to_lua(stanza.tags[1]); - end; - data=function(stanza) - local t = {}; - for _,child in pairs(stanza.tags) do - if child.name ~= "value" then error("<data> can only have <value> children"); end - t_insert(t, _xmlrpc_to_lua(child)); - end - return t; - end; - struct=function(stanza) - local t = {}; - for _,child in pairs(stanza.tags) do - if child.name ~= "member" then error("<struct> can only have <member> children"); end - local name, value = _xmlrpc_to_lua(child); - t[name] = value; - end - return t; - end; - member=function(stanza) - if #stanza.tags ~= 2 then error("<member> must have exactly two subtags"); end -- FIXME <params> is optional - if stanza.tags[1].name ~= "name" then error("First <member> child tag must be <name>") end - if stanza.tags[2].name ~= "value" then error("Second <member> child tag must be <value>") end - return _xmlrpc_to_lua(stanza.tags[1]), _xmlrpc_to_lua(stanza.tags[2]); - end; - name=function(stanza) - if #stanza.tags ~= 0 then error("<name> must have a single text child"); end - local n = t_concat(stanza) - if tostring(tonumber(n)) == n then n = tonumber(n); end - return n; - end; - ["nil"]=function(stanza) -- nil extension - return nil; - end; -} -_xmlrpc_to_lua = function(stanza) - local h = rmap[stanza.name]; - if h then - return h(stanza); - else - error("Unknown element: "..stanza.name); - end -end -function translate_request(stanza) - if stanza.name ~= "methodCall" then error("XML-RPC requests must have <methodCall> as root element"); end - return _xmlrpc_to_lua(stanza); -end - -return _M; diff --git a/util/xmppstream.lua b/util/xmppstream.lua index cbdadd9b..550170c9 100644 --- a/util/xmppstream.lua +++ b/util/xmppstream.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -9,21 +9,27 @@ local lxp = require "lxp"; local st = require "util.stanza"; +local stanza_mt = st.stanza_mt; +local error = error; local tostring = tostring; local t_insert = table.insert; local t_concat = table.concat; +local t_remove = table.remove; +local setmetatable = setmetatable; -local default_log = require "util.logger".init("xmppstream"); - -local error = error; +-- COMPAT: w/LuaExpat 1.1.0 +local lxp_supports_doctype = pcall(lxp.new, { StartDoctypeDecl = false }); module "xmppstream" local new_parser = lxp.new; -local ns_prefixes = { - ["http://www.w3.org/XML/1998/namespace"] = "xml"; +local xml_namespace = { + ["http://www.w3.org/XML/1998/namespace\1lang"] = "xml:lang"; + ["http://www.w3.org/XML/1998/namespace\1space"] = "xml:space"; + ["http://www.w3.org/XML/1998/namespace\1base"] = "xml:base"; + ["http://www.w3.org/XML/1998/namespace\1id"] = "xml:id"; }; local xmlns_streams = "http://etherx.jabber.org/streams"; @@ -36,29 +42,28 @@ _M.ns_pattern = ns_pattern; function new_sax_handlers(session, stream_callbacks) local xml_handlers = {}; - - local log = session.log or default_log; - + local cb_streamopened = stream_callbacks.streamopened; local cb_streamclosed = stream_callbacks.streamclosed; - local cb_error = stream_callbacks.error or function(session, e) error("XML stream error: "..tostring(e)); end; + local cb_error = stream_callbacks.error or function(session, e, stanza) error("XML stream error: "..tostring(e)..(stanza and ": "..tostring(stanza) or ""),2); end; local cb_handlestanza = stream_callbacks.handlestanza; - + local stream_ns = stream_callbacks.stream_ns or xmlns_streams; local stream_tag = stream_callbacks.stream_tag or "stream"; if stream_ns ~= "" then stream_tag = stream_ns..ns_separator..stream_tag; end local stream_error_tag = stream_ns..ns_separator..(stream_callbacks.error_tag or "error"); - + local stream_default_ns = stream_callbacks.default_ns; - + + local stack = {}; local chardata, stanza = {}; local non_streamns_depth = 0; function xml_handlers:StartElement(tagname, attr) if stanza and #chardata > 0 then -- We have some character data in the buffer - stanza:text(t_concat(chardata)); + t_insert(stanza, t_concat(chardata)); chardata = {}; end local curr_ns,name = tagname:match(ns_pattern); @@ -70,21 +75,17 @@ function new_sax_handlers(session, stream_callbacks) attr.xmlns = curr_ns; non_streamns_depth = non_streamns_depth + 1; end - - -- FIXME !!!!! + for i=1,#attr do local k = attr[i]; attr[i] = nil; - local ns, nm = k:match(ns_pattern); - if nm ~= "" then - ns = ns_prefixes[ns]; - if ns then - attr[ns..":"..nm] = attr[k]; - attr[k] = nil; - end + local xmlk = xml_namespace[k]; + if xmlk then + attr[xmlk] = attr[k]; + attr[k] = nil; end end - + if not stanza then --if we are not currently inside a stanza if session.notopen then if tagname == stream_tag then @@ -101,10 +102,14 @@ function new_sax_handlers(session, stream_callbacks) if curr_ns == "jabber:client" and name ~= "iq" and name ~= "presence" and name ~= "message" then cb_error(session, "invalid-top-level-element"); end - - stanza = st.stanza(name, attr); + + stanza = setmetatable({ name = name, attr = attr, tags = {} }, stanza_mt); else -- we are inside a stanza, so add a tag - stanza:tag(name, attr); + t_insert(stack, stanza); + local oldstanza = stanza; + stanza = setmetatable({ name = name, attr = attr, tags = {} }, stanza_mt); + t_insert(oldstanza, stanza); + t_insert(oldstanza.tags, stanza); end end function xml_handlers:CharacterData(data) @@ -119,12 +124,11 @@ function new_sax_handlers(session, stream_callbacks) if stanza then if #chardata > 0 then -- We have some character data in the buffer - stanza:text(t_concat(chardata)); + t_insert(stanza, t_concat(chardata)); chardata = {}; end -- Complete stanza - local last_add = stanza.last_add; - if not last_add or #last_add == 0 then + if #stack == 0 then if tagname ~= stream_error_tag then cb_handlestanza(session, stanza); else @@ -132,33 +136,37 @@ function new_sax_handlers(session, stream_callbacks) end stanza = nil; else - stanza:up(); + stanza = t_remove(stack); end else - if tagname == stream_tag then - if cb_streamclosed then - cb_streamclosed(session); - end - else - local curr_ns,name = tagname:match(ns_pattern); - if name == "" then - curr_ns, name = "", curr_ns; - end - cb_error(session, "parse-error", "unexpected-element-close", name); + if cb_streamclosed then + cb_streamclosed(session); end - stanza, chardata = nil, {}; end end - + + local function restricted_handler(parser) + cb_error(session, "parse-error", "restricted-xml", "Restricted XML, see RFC 6120 section 11.1."); + if not parser.stop or not parser:stop() then + error("Failed to abort parsing"); + end + end + + if lxp_supports_doctype then + xml_handlers.StartDoctypeDecl = restricted_handler; + end + xml_handlers.Comment = restricted_handler; + xml_handlers.ProcessingInstruction = restricted_handler; + local function reset() stanza, chardata = nil, {}; + stack = {}; end - + local function set_session(stream, new_session) session = new_session; - log = new_session.log or default_log; end - + return xml_handlers, { reset = reset, set_session = set_session }; end |