diff options
Diffstat (limited to 'util')
-rw-r--r-- | util/array.lua | 4 | ||||
-rw-r--r-- | util/broadcast.lua | 4 | ||||
-rw-r--r-- | util/caps.lua | 61 | ||||
-rw-r--r-- | util/dataforms.lua | 41 | ||||
-rw-r--r-- | util/datamanager.lua | 18 | ||||
-rw-r--r-- | util/datetime.lua | 22 | ||||
-rw-r--r-- | util/dependencies.lua | 26 | ||||
-rw-r--r-- | util/events.lua | 56 | ||||
-rw-r--r-- | util/filters.lua | 87 | ||||
-rw-r--r-- | util/hmac.lua | 2 | ||||
-rw-r--r-- | util/httpstream.lua | 137 | ||||
-rw-r--r-- | util/iterators.lua | 13 | ||||
-rw-r--r-- | util/jid.lua | 15 | ||||
-rw-r--r-- | util/json.lua | 358 | ||||
-rw-r--r-- | util/logger.lua | 55 | ||||
-rw-r--r-- | util/pluginloader.lua | 56 | ||||
-rw-r--r-- | util/prosodyctl.lua | 114 | ||||
-rw-r--r-- | util/sasl.lua | 88 | ||||
-rw-r--r-- | util/sasl/anonymous.lua | 18 | ||||
-rw-r--r-- | util/sasl/digest-md5.lua | 8 | ||||
-rw-r--r-- | util/sasl/plain.lua | 10 | ||||
-rw-r--r-- | util/sasl/scram.lua | 57 | ||||
-rw-r--r-- | util/sasl_cyrus.lua | 72 | ||||
-rw-r--r-- | util/serialization.lua | 23 | ||||
-rw-r--r-- | util/set.lua | 2 | ||||
-rw-r--r-- | util/stanza.lua | 117 | ||||
-rw-r--r-- | util/template.lua | 133 | ||||
-rw-r--r-- | util/termcolours.lua | 37 | ||||
-rw-r--r-- | util/timer.lua | 13 | ||||
-rw-r--r-- | util/xmppstream.lua | 204 | ||||
-rw-r--r-- | util/ztact.lua | 366 |
31 files changed, 1523 insertions, 694 deletions
diff --git a/util/array.lua b/util/array.lua index 98c0ebe8..6c1f0460 100644 --- a/util/array.lua +++ b/util/array.lua @@ -6,8 +6,8 @@ -- COPYING file in the source package for more information. -- -local t_insert, t_sort, t_remove, t_concat - = table.insert, table.sort, table.remove, table.concat; +local t_insert, t_sort, t_remove, t_concat + = table.insert, table.sort, table.remove, table.concat; local array = {}; local array_base = {}; diff --git a/util/broadcast.lua b/util/broadcast.lua index c74bf4e1..be17461d 100644 --- a/util/broadcast.lua +++ b/util/broadcast.lua @@ -7,8 +7,8 @@ -- -local ipairs, pairs, setmetatable, type = - ipairs, pairs, setmetatable, type; +local ipairs, pairs, setmetatable, type = + ipairs, pairs, setmetatable, type; module "pubsub" diff --git a/util/caps.lua b/util/caps.lua new file mode 100644 index 00000000..a61e7403 --- /dev/null +++ b/util/caps.lua @@ -0,0 +1,61 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local base64 = require "util.encodings".base64.encode; +local sha1 = require "util.hashes".sha1; + +local t_insert, t_sort, t_concat = table.insert, table.sort, table.concat; +local ipairs = ipairs; + +module "caps" + +function calculate_hash(disco_info) + local identities, features, extensions = {}, {}, {}; + for _, tag in ipairs(disco_info) do + if tag.name == "identity" then + t_insert(identities, (tag.attr.category or "").."\0"..(tag.attr.type or "").."\0"..(tag.attr["xml:lang"] or "").."\0"..(tag.attr.name or "")); + elseif tag.name == "feature" then + t_insert(features, tag.attr.var or ""); + elseif tag.name == "x" and tag.attr.xmlns == "jabber:x:data" then + local form = {}; + local FORM_TYPE; + for _, field in ipairs(tag.tags) do + if field.name == "field" and field.attr.var then + local values = {}; + for _, val in ipairs(field.tags) do + val = #val.tags == 0 and val:get_text(); + if val then t_insert(values, val); end + end + t_sort(values); + if field.attr.var == "FORM_TYPE" then + FORM_TYPE = values[1]; + elseif #values > 0 then + t_insert(form, field.attr.var.."\0"..t_concat(values, "<")); + else + t_insert(form, field.attr.var); + end + end + end + t_sort(form); + form = t_concat(form, "<"); + if FORM_TYPE then form = FORM_TYPE.."\0"..form; end + t_insert(extensions, form); + end + end + t_sort(identities); + t_sort(features); + t_sort(extensions); + if #identities > 0 then identities = t_concat(identities, "<"):gsub("%z", "/").."<"; else identities = ""; end + if #features > 0 then features = t_concat(features, "<").."<"; else features = ""; end + if #extensions > 0 then extensions = t_concat(extensions, "<"):gsub("%z", "<").."<"; else extensions = ""; end + local S = identities..features..extensions; + local ver = base64(sha1(S)); + return ver, S; +end + +return _M; diff --git a/util/dataforms.lua b/util/dataforms.lua index 5a3b1fb5..ae745e03 100644 --- a/util/dataforms.lua +++ b/util/dataforms.lua @@ -67,9 +67,25 @@ function form_t.form(layout, data, formtype) form:tag("value"):text(line):up(); end elseif field_type == "list-single" then + local has_default = false; for _, val in ipairs(value) do if type(val) == "table" then form:tag("option", { label = val.label }):tag("value"):text(val.value):up():up(); + if val.default and (not has_default) then + form:tag("value"):text(val.value):up(); + has_default = true; + end + else + form:tag("option", { label= val }):tag("value"):text(tostring(val)):up():up(); + end + end + elseif field_type == "list-multi" then + for _, val in ipairs(value) do + if type(val) == "table" then + form:tag("option", { label = val.label }):tag("value"):text(val.value):up():up(); + if val.default then + form:tag("value"):text(val.value):up(); + end else form:tag("option", { label= val }):tag("value"):text(tostring(val)):up():up(); end @@ -110,7 +126,7 @@ function form_t.data(layout, stanza) return data; end -field_readers["text-single"] = +field_readers["text-single"] = function (field_tag) local value = field_tag:child_with_name("value"); if value then @@ -118,13 +134,13 @@ field_readers["text-single"] = end end -field_readers["text-private"] = +field_readers["text-private"] = field_readers["text-single"]; field_readers["jid-single"] = field_readers["text-single"]; -field_readers["jid-multi"] = +field_readers["jid-multi"] = function (field_tag) local result = {}; for value_tag in field_tag:childtags() do @@ -135,7 +151,7 @@ field_readers["jid-multi"] = return result; end -field_readers["text-multi"] = +field_readers["text-multi"] = function (field_tag) local result = {}; for value_tag in field_tag:childtags() do @@ -149,7 +165,18 @@ field_readers["text-multi"] = field_readers["list-single"] = field_readers["text-single"]; -field_readers["boolean"] = +field_readers["list-multi"] = + function (field_tag) + local result = {}; + for value_tag in field_tag:childtags() do + if value_tag.name == "value" then + result[#result+1] = value_tag[1]; + end + end + return result; + end + +field_readers["boolean"] = function (field_tag) local value = field_tag:child_with_name("value"); if value then @@ -158,10 +185,10 @@ field_readers["boolean"] = else return false; end - end + end end -field_readers["hidden"] = +field_readers["hidden"] = function (field_tag) local value = field_tag:child_with_name("value"); if value then diff --git a/util/datamanager.lua b/util/datamanager.lua index 57cd2594..d5e9c88c 100644 --- a/util/datamanager.lua +++ b/util/datamanager.lua @@ -22,6 +22,7 @@ local t_insert = table.insert; local append = require "util.serialization".append; local path_separator = "/"; if os.getenv("WINDIR") then path_separator = "\\" end local lfs = require "lfs"; +local prosody = prosody; local raw_mkdir; if prosody.platform == "posix" then @@ -56,7 +57,7 @@ local function mkdir(path) return path; end -local data_path = "data"; +local data_path = (prosody and prosody.paths and prosody.paths.data) or "."; local callbacks = {}; ------- API ------------- @@ -114,7 +115,7 @@ function load(username, host, datastore) if not data then local mode = lfs.attributes(getpath(username, host, datastore), "mode"); if not mode then - log("debug", "Failed to load "..datastore.." storage ('"..ret.."') for user: "..(username or "nil").."@"..(host or "nil")); + log("debug", "Assuming empty "..datastore.." storage ('"..ret.."') for user: "..(username or "nil").."@"..(host or "nil")); return nil; else -- file exists, but can't be read -- TODO more detailed error checking and logging? @@ -204,15 +205,22 @@ end function list_load(username, host, datastore) local data, ret = loadfile(getpath(username, host, datastore, "list")); if not data then - log("debug", "Failed to load "..datastore.." storage ('"..ret.."') for user: "..(username or "nil").."@"..(host or "nil")); - return nil; + local mode = lfs.attributes(getpath(username, host, datastore, "list"), "mode"); + if not mode then + log("debug", "Assuming empty "..datastore.." storage ('"..ret.."') for user: "..(username or "nil").."@"..(host or "nil")); + return nil; + else -- file exists, but can't be read + -- TODO more detailed error checking and logging? + log("error", "Failed to load "..datastore.." storage ('"..ret.."') for user: "..(username or "nil").."@"..(host or "nil")); + return nil, "Error reading storage"; + end end local items = {}; setfenv(data, {item = function(i) t_insert(items, i); end}); local success, ret = pcall(data); if not success then log("error", "Unable to load "..datastore.." storage ('"..ret.."') for user: "..(username or "nil").."@"..(host or "nil")); - return nil; + return nil, "Error reading storage"; end return items; end diff --git a/util/datetime.lua b/util/datetime.lua index cf00e4c3..c73d8e76 100644 --- a/util/datetime.lua +++ b/util/datetime.lua @@ -10,7 +10,10 @@ -- XEP-0082: XMPP Date and Time Profiles local os_date = os.date; +local os_time = os.time; +local os_difftime = os.difftime; local error = error; +local tonumber = tonumber; module "datetime" @@ -31,7 +34,24 @@ function legacy(t) end function parse(s) - error("datetime.parse: Not implemented"); -- TODO + if s then + local year, month, day, hour, min, sec, tzd; + year, month, day, hour, min, sec, tzd = s:match("^(%d%d%d%d)-?(%d%d)-?(%d%d)T(%d%d):(%d%d):(%d%d)%.?%d*([Z+%-].*)$"); + if year then + local time_offset = os_difftime(os_time(os_date("*t")), os_time(os_date("!*t"))); -- to deal with local timezone + local tzd_offset = 0; + if tzd ~= "" and tzd ~= "Z" then + local sign, h, m = tzd:match("([+%-])(%d%d):?(%d*)"); + if not sign then return; end + if #m ~= 2 then m = "0"; end + h, m = tonumber(h), tonumber(m); + tzd_offset = h * 60 * 60 + m * 60; + if sign == "-" then tzd_offset = -tzd_offset; end + end + sec = (sec + time_offset) - tzd_offset; + return os_time({year=year, month=month, day=day, hour=hour, min=min, sec=sec, isdst=false}); + end + end end return _M; diff --git a/util/dependencies.lua b/util/dependencies.lua index 6024dd63..5baea942 100644 --- a/util/dependencies.lua +++ b/util/dependencies.lua @@ -35,6 +35,19 @@ function missingdep(name, sources, msg) print(""); end +-- COMPAT w/pre-0.8 Debian: The Debian config file used to use +-- util.ztact, which has been removed from Prosody in 0.8. This +-- is to log an error for people who still use it, so they can +-- update their configs. +package.preload["util.ztact"] = function () + if not package.loaded["core.loggingmanager"] then + error("util.ztact has been removed from Prosody and you need to fix your config " + .."file. More information can be found at http://prosody.im/doc/packagers#ztact", 0); + else + error("module 'util.ztact' has been deprecated in Prosody 0.8."); + end +end; + function check_dependencies() local fatal; @@ -78,11 +91,6 @@ function check_dependencies() ["luarocks"] = "luarocks install luasec"; ["Source"] = "http://www.inf.puc-rio.br/~brunoos/luasec/"; }, "SSL/TLS support will not be available"); - else - local major, minor, veryminor, patched = ssl._VERSION:match("(%d+)%.(%d+)%.?(%d*)(M?)"); - if not major or ((tonumber(major) == 0 and (tonumber(minor) or 0) <= 3 and (tonumber(veryminor) or 0) <= 2) and patched ~= "M") then - log("error", "This version of LuaSec contains a known bug that causes disconnects, see http://prosody.im/doc/depends"); - end end local encodings, err = softreq "util.encodings" @@ -121,5 +129,13 @@ function check_dependencies() return not fatal; end +function log_warnings() + if ssl then + local major, minor, veryminor, patched = ssl._VERSION:match("(%d+)%.(%d+)%.?(%d*)(M?)"); + if not major or ((tonumber(major) == 0 and (tonumber(minor) or 0) <= 3 and (tonumber(veryminor) or 0) <= 2) and patched ~= "M") then + log("error", "This version of LuaSec contains a known bug that causes disconnects, see http://prosody.im/doc/depends"); + end + end +end return _M; diff --git a/util/events.lua b/util/events.lua index 363d2ac6..412acccd 100644 --- a/util/events.lua +++ b/util/events.lua @@ -7,29 +7,29 @@ -- -local ipairs = ipairs; local pairs = pairs; local t_insert = table.insert; local t_sort = table.sort; -local select = select; +local setmetatable = setmetatable; +local next = next; module "events" function new() - local dispatchers = {}; local handlers = {}; local event_map = {}; - local function _rebuild_index(event) -- TODO optimize index rebuilding + local function _rebuild_index(handlers, event) local _handlers = event_map[event]; - local index = handlers[event]; - if index then - for i=#index,1,-1 do index[i] = nil; end - else index = {}; handlers[event] = index; end + if not _handlers or next(_handlers) == nil then return; end + local index = {}; for handler in pairs(_handlers) do t_insert(index, handler); end t_sort(index, function(a, b) return _handlers[a] > _handlers[b]; end); + handlers[event] = index; + return index; end; + setmetatable(handlers, { __index = _rebuild_index }); local function add_handler(event, handler, priority) local map = event_map[event]; if map then @@ -38,13 +38,16 @@ function new() map = {[handler] = priority or 0}; event_map[event] = map; end - _rebuild_index(event); + handlers[event] = nil; end; local function remove_handler(event, handler) local map = event_map[event]; if map then map[handler] = nil; - _rebuild_index(event); + handlers[event] = nil; + if next(map) == nil then + event_map[event] = nil; + end end end; local function add_handlers(handlers) @@ -57,22 +60,7 @@ function new() remove_handler(event, handler); end end; - local function _create_dispatcher(event) -- FIXME duplicate code in fire_event - local h = handlers[event]; - if not h then h = {}; handlers[event] = h; end - local dispatcher = function(...) - for i=1,#h do - local ret = h[i](...); - if ret ~= nil then return ret; end - end - end; - dispatchers[event] = dispatcher; - return dispatcher; - end; - local function get_dispatcher(event) - return dispatchers[event] or _create_dispatcher(event); - end; - local function fire_event(event, ...) -- FIXME duplicates dispatcher code + local function fire_event(event, ...) local h = handlers[event]; if h then for i=1,#h do @@ -81,24 +69,12 @@ function new() end end end; - local function get_named_arg_dispatcher(event, ...) - local dispatcher = get_dispatcher(event); - local keys = {...}; - local data = {}; - return function(...) - for i, key in ipairs(keys) do data[key] = select(i, ...); end - dispatcher(data); - end; - end; return { add_handler = add_handler; remove_handler = remove_handler; - add_plugin = add_plugin; - remove_plugin = remove_plugin; - get_dispatcher = get_dispatcher; + add_handlers = add_handlers; + remove_handlers = remove_handlers; fire_event = fire_event; - get_named_arg_dispatcher = get_named_arg_dispatcher; - _dispatchers = dispatchers; _handlers = handlers; _event_map = event_map; }; diff --git a/util/filters.lua b/util/filters.lua new file mode 100644 index 00000000..d143666b --- /dev/null +++ b/util/filters.lua @@ -0,0 +1,87 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local t_insert, t_remove = table.insert, table.remove; + +module "filters" + +local new_filter_hooks = {}; + +function initialize(session) + if not session.filters then + local filters = {}; + session.filters = filters; + + function session.filter(type, data) + local filter_list = filters[type]; + if filter_list then + for i = 1, #filter_list do + data = filter_list[i](data, session); + if data == nil then break; end + end + end + return data; + end + end + + for i=1,#new_filter_hooks do + new_filter_hooks[i](session); + end + + return session.filter; +end + +function add_filter(session, type, callback, priority) + if not session.filters then + initialize(session); + end + + local filter_list = session.filters[type]; + if not filter_list then + filter_list = {}; + session.filters[type] = filter_list; + end + + priority = priority or 0; + + local i = 0; + repeat + i = i + 1; + until not filter_list[i] or filter_list[filter_list[i]] >= priority; + + t_insert(filter_list, i, callback); + filter_list[callback] = priority; +end + +function remove_filter(session, type, callback) + if not session.filters then return; end + local filter_list = session.filters[type]; + if filter_list and filter_list[callback] then + for i=1, #filter_list do + if filter_list[i] == callback then + t_remove(filter_list, i); + filter_list[callback] = nil; + return true; + end + end + end +end + +function add_filter_hook(callback) + t_insert(new_filter_hooks, callback); +end + +function remove_filter_hook(callback) + for i=1,#new_filter_hooks do + if new_filter_hooks[i] == callback then + t_remove(new_filter_hooks, i); + end + end +end + +return _M; diff --git a/util/hmac.lua b/util/hmac.lua index 66dd41d8..6df6986e 100644 --- a/util/hmac.lua +++ b/util/hmac.lua @@ -40,7 +40,7 @@ hash blocksize the blocksize for the hash function in bytes hex - return raw hash or hexadecimal string + return raw hash or hexadecimal string --]] function hmac(key, message, hash, blocksize, hex) if #key > blocksize then diff --git a/util/httpstream.lua b/util/httpstream.lua new file mode 100644 index 00000000..bdc3fce7 --- /dev/null +++ b/util/httpstream.lua @@ -0,0 +1,137 @@ + +local coroutine = coroutine; +local tonumber = tonumber; + +local deadroutine = coroutine.create(function() end); +coroutine.resume(deadroutine); + +module("httpstream") + +local function parser(success_cb, parser_type, options_cb) + local data = coroutine.yield(); + local function readline() + local pos = data:find("\r\n", nil, true); + while not pos do + data = data..coroutine.yield(); + pos = data:find("\r\n", nil, true); + end + local r = data:sub(1, pos-1); + data = data:sub(pos+2); + return r; + end + local function readlength(n) + while #data < n do + data = data..coroutine.yield(); + end + local r = data:sub(1, n); + data = data:sub(n + 1); + return r; + end + local function readheaders() + local headers = {}; -- read headers + while true do + local line = readline(); + if line == "" then break; end -- headers done + local key, val = line:match("^([^%s:]+): *(.*)$"); + if not key then coroutine.yield("invalid-header-line"); end -- TODO handle multi-line and invalid headers + key = key:lower(); + headers[key] = headers[key] and headers[key]..","..val or val; + end + return headers; + end + + if not parser_type or parser_type == "server" then + while true do + -- read status line + local status_line = readline(); + local method, path, httpversion = status_line:match("^(%S+)%s+(%S+)%s+HTTP/(%S+)$"); + if not method then coroutine.yield("invalid-status-line"); end + path = path:gsub("^//+", "/"); -- TODO parse url more + local headers = readheaders(); + + -- read body + local len = tonumber(headers["content-length"]); + len = len or 0; -- TODO check for invalid len + local body = readlength(len); + + success_cb({ + method = method; + path = path; + httpversion = httpversion; + headers = headers; + body = body; + }); + end + elseif parser_type == "client" then + while true do + -- read status line + local status_line = readline(); + local httpversion, status_code, reason_phrase = status_line:match("^HTTP/(%S+)%s+(%d%d%d)%s+(.*)$"); + status_code = tonumber(status_code); + if not status_code then coroutine.yield("invalid-status-line"); end + local headers = readheaders(); + + -- read body + local have_body = not + ( (options_cb and options_cb().method == "HEAD") + or (status_code == 204 or status_code == 304 or status_code == 301) + or (status_code >= 100 and status_code < 200) ); + + local body; + if have_body then + local len = tonumber(headers["content-length"]); + if headers["transfer-encoding"] == "chunked" then + body = ""; + while true do + local chunk_size = readline():match("^%x+"); + if not chunk_size then coroutine.yield("invalid-chunk-size"); end + chunk_size = tonumber(chunk_size, 16) + if chunk_size == 0 then break; end + body = body..readlength(chunk_size); + if readline() ~= "" then coroutine.yield("invalid-chunk-ending"); end + end + local trailers = readheaders(); + elseif len then -- TODO check for invalid len + body = readlength(len); + else -- read to end + repeat + local newdata = coroutine.yield(); + data = data..newdata; + until newdata == ""; + body, data = data, ""; + end + end + + success_cb({ + code = status_code; + httpversion = httpversion; + headers = headers; + body = body; + -- COMPAT the properties below are deprecated + responseversion = httpversion; + responseheaders = headers; + }); + end + else coroutine.yield("unknown-parser-type"); end +end + +function new(success_cb, error_cb, parser_type, options_cb) + local co = coroutine.create(parser); + coroutine.resume(co, success_cb, parser_type, options_cb) + return { + feed = function(self, data) + if not data then + if parser_type == "client" then coroutine.resume(co, ""); end + co = deadroutine; + return error_cb(); + end + local success, result = coroutine.resume(co, data); + if result then + co = deadroutine; + return error_cb(result); + end + end; + }; +end + +return _M; diff --git a/util/iterators.lua b/util/iterators.lua index 318c1a96..dc692d64 100644 --- a/util/iterators.lua +++ b/util/iterators.lua @@ -73,7 +73,7 @@ function count(f, s, var) var = ret[1]; if var == nil then break; end x = x + 1; - end + end return x; end @@ -90,6 +90,15 @@ function head(n, f, s, var) end, s; end +-- Skip the first n items an iterator returns +function skip(n, f, s, var) + for i=1,n do + var = f(s, var); + end + return f, s, var; +end + +-- Return the last n items an iterator returns function tail(n, f, s, var) local results, count = {}, 0; while true do @@ -122,7 +131,7 @@ function it2array(f, s, var) return t; end --- Treat the return of an iterator as key,value pairs, +-- Treat the return of an iterator as key,value pairs, -- and build a table function it2table(f, s, var) local t, var = {}; diff --git a/util/jid.lua b/util/jid.lua index ba9730fa..069817c6 100644 --- a/util/jid.lua +++ b/util/jid.lua @@ -17,7 +17,7 @@ module "jid" local function _split(jid) if not jid then return; end - local node, nodepos = match(jid, "^([^@]+)@()"); + local node, nodepos = match(jid, "^([^@/]+)@()"); local host, hostpos = match(jid, "^([^@/]+)()", nodepos) if node and not host then return nil, nil, nil; end local resource = match(jid, "^/(.+)$", hostpos); @@ -78,4 +78,17 @@ function join(node, host, resource) return nil; -- Invalid JID end +function compare(jid, acl) + -- compare jid to single acl rule + -- TODO compare to table of rules? + local jid_node, jid_host, jid_resource = _split(jid); + local acl_node, acl_host, acl_resource = _split(acl); + if ((acl_node ~= nil and acl_node == jid_node) or acl_node == nil) and + ((acl_host ~= nil and acl_host == jid_host) or acl_host == nil) and + ((acl_resource ~= nil and acl_resource == jid_resource) or acl_resource == nil) then + return true + end + return false +end + return _M; diff --git a/util/json.lua b/util/json.lua new file mode 100644 index 00000000..05453703 --- /dev/null +++ b/util/json.lua @@ -0,0 +1,358 @@ + +local type = type; +local t_insert, t_concat, t_remove = table.insert, table.concat, table.remove; +local s_char = string.char; +local tostring, tonumber = tostring, tonumber; +local pairs, ipairs = pairs, ipairs; +local next = next; +local error = error; +local newproxy, getmetatable = newproxy, getmetatable; +local print = print; + +--module("json") +local json = {}; + +local null = newproxy and newproxy(true) or {}; +if getmetatable and getmetatable(null) then + getmetatable(null).__tostring = function() return "null"; end; +end +json.null = null; + +local escapes = { + ["\""] = "\\\"", ["\\"] = "\\\\", ["\b"] = "\\b", + ["\f"] = "\\f", ["\n"] = "\\n", ["\r"] = "\\r", ["\t"] = "\\t"}; +local unescapes = { + ["\""] = "\"", ["\\"] = "\\", ["/"] = "/", + b = "\b", f = "\f", n = "\n", r = "\r", t = "\t"}; +for i=0,31 do + local ch = s_char(i); + if not escapes[ch] then escapes[ch] = ("\\u%.4X"):format(i); end +end + +local valid_types = { + number = true, + string = true, + table = true, + boolean = true +}; +local special_keys = { + __array = true; + __hash = true; +}; + +local simplesave, tablesave, arraysave, stringsave; + +function stringsave(o, buffer) + -- FIXME do proper utf-8 and binary data detection + t_insert(buffer, "\""..(o:gsub(".", escapes)).."\""); +end + +function arraysave(o, buffer) + t_insert(buffer, "["); + if next(o) then + for i,v in ipairs(o) do + simplesave(v, buffer); + t_insert(buffer, ","); + end + t_remove(buffer); + end + t_insert(buffer, "]"); +end + +function tablesave(o, buffer) + local __array = {}; + local __hash = {}; + local hash = {}; + for i,v in ipairs(o) do + __array[i] = v; + end + for k,v in pairs(o) do + local ktype, vtype = type(k), type(v); + if valid_types[vtype] or v == null then + if ktype == "string" and not special_keys[k] then + hash[k] = v; + elseif (valid_types[ktype] or k == null) and __array[k] == nil then + __hash[k] = v; + end + end + end + if next(__hash) ~= nil or next(hash) ~= nil or next(__array) == nil then + t_insert(buffer, "{"); + local mark = #buffer; + for k,v in pairs(hash) do + stringsave(k, buffer); + t_insert(buffer, ":"); + simplesave(v, buffer); + t_insert(buffer, ","); + end + if next(__hash) ~= nil then + t_insert(buffer, "\"__hash\":["); + for k,v in pairs(__hash) do + simplesave(k, buffer); + t_insert(buffer, ","); + simplesave(v, buffer); + t_insert(buffer, ","); + end + t_remove(buffer); + t_insert(buffer, "]"); + t_insert(buffer, ","); + end + if next(__array) then + t_insert(buffer, "\"__array\":"); + arraysave(__array, buffer); + t_insert(buffer, ","); + end + if mark ~= #buffer then t_remove(buffer); end + t_insert(buffer, "}"); + else + arraysave(__array, buffer); + end +end + +function simplesave(o, buffer) + local t = type(o); + if t == "number" then + t_insert(buffer, tostring(o)); + elseif t == "string" then + stringsave(o, buffer); + elseif t == "table" then + tablesave(o, buffer); + elseif t == "boolean" then + t_insert(buffer, (o and "true" or "false")); + else + t_insert(buffer, "null"); + end +end + +function json.encode(obj) + local t = {}; + simplesave(obj, t); + return t_concat(t); +end + +----------------------------------- + + +function json.decode(json) + local pos = 1; + local current = {}; + local stack = {}; + local ch, peek; + local function next() + ch = json:sub(pos, pos); + pos = pos+1; + peek = json:sub(pos, pos); + return ch; + end + + local function skipwhitespace() + while ch and (ch == "\r" or ch == "\n" or ch == "\t" or ch == " ") do + next(); + end + end + local function skiplinecomment() + repeat next(); until not(ch) or ch == "\r" or ch == "\n"; + skipwhitespace(); + end + local function skipstarcomment() + next(); next(); -- skip '/', '*' + while peek and ch ~= "*" and peek ~= "/" do next(); end + if not peek then error("eof in star comment") end + next(); next(); -- skip '*', '/' + skipwhitespace(); + end + local function skipstuff() + while true do + skipwhitespace(); + if ch == "/" and peek == "*" then + skipstarcomment(); + elseif ch == "/" and peek == "*" then + skiplinecomment(); + else + return; + end + end + end + + local readvalue; + local function readarray() + local t = {}; + next(); -- skip '[' + skipstuff(); + if ch == "]" then next(); return t; end + t_insert(t, readvalue()); + while true do + skipstuff(); + if ch == "]" then next(); return t; end + if not ch then error("eof while reading array"); + elseif ch == "," then next(); + elseif ch then error("unexpected character in array, comma expected"); end + if not ch then error("eof while reading array"); end + t_insert(t, readvalue()); + end + end + + local function checkandskip(c) + local x = ch or "eof"; + if x ~= c then error("unexpected "..x..", '"..c.."' expected"); end + next(); + end + local function readliteral(lit, val) + for c in lit:gmatch(".") do + checkandskip(c); + end + return val; + end + local function readstring() + local s = ""; + checkandskip("\""); + while ch do + while ch and ch ~= "\\" and ch ~= "\"" do + s = s..ch; next(); + end + if ch == "\\" then + next(); + if unescapes[ch] then + s = s..unescapes[ch]; + next(); + elseif ch == "u" then + local seq = ""; + for i=1,4 do + next(); + if not ch then error("unexpected eof in string"); end + if not ch:match("[0-9a-fA-F]") then error("invalid unicode escape sequence in string"); end + seq = seq..ch; + end + s = s..s.char(tonumber(seq, 16)); -- FIXME do proper utf-8 + next(); + else error("invalid escape sequence in string"); end + end + if ch == "\"" then + next(); + return s; + end + end + error("eof while reading string"); + end + local function readnumber() + local s = ""; + if ch == "-" then + s = s..ch; next(); + if not ch:match("[0-9]") then error("number format error"); end + end + if ch == "0" then + s = s..ch; next(); + if ch:match("[0-9]") then error("number format error"); end + else + while ch and ch:match("[0-9]") do + s = s..ch; next(); + end + end + if ch == "." then + s = s..ch; next(); + if not ch:match("[0-9]") then error("number format error"); end + while ch and ch:match("[0-9]") do + s = s..ch; next(); + end + if ch == "e" or ch == "E" then + s = s..ch; next(); + if ch == "+" or ch == "-" then + s = s..ch; next(); + if not ch:match("[0-9]") then error("number format error"); end + while ch and ch:match("[0-9]") do + s = s..ch; next(); + end + end + end + end + return tonumber(s); + end + local function readmember(t) + skipstuff(); + local k = readstring(); + skipstuff(); + checkandskip(":"); + t[k] = readvalue(); + end + local function fixobject(obj) + local __array = obj.__array; + if __array then + obj.__array = nil; + for i,v in ipairs(__array) do + t_insert(obj, v); + end + end + local __hash = obj.__hash; + if __hash then + obj.__hash = nil; + local k; + for i,v in ipairs(__hash) do + if k ~= nil then + obj[k] = v; k = nil; + else + k = v; + end + end + end + return obj; + end + local function readobject() + local t = {}; + next(); -- skip '{' + skipstuff(); + if ch == "}" then next(); return t; end + if not ch then error("eof while reading object"); end + readmember(t); + while true do + skipstuff(); + if ch == "}" then next(); return fixobject(t); end + if not ch then error("eof while reading object"); + elseif ch == "," then next(); + elseif ch then error("unexpected character in object, comma expected"); end + if not ch then error("eof while reading object"); end + readmember(t); + end + end + + function readvalue() + skipstuff(); + while ch do + if ch == "{" then + return readobject(); + elseif ch == "[" then + return readarray(); + elseif ch == "\"" then + return readstring(); + elseif ch:match("[%-0-9%.]") then + return readnumber(); + elseif ch == "n" then + return readliteral("null", null); + elseif ch == "t" then + return readliteral("true", true); + elseif ch == "f" then + return readliteral("false", false); + else + error("invalid character at value start: "..ch); + end + end + error("eof while reading value"); + end + next(); + return readvalue(); +end + +function json.test(object) + local encoded = json.encode(object); + local decoded = json.decode(encoded); + local recoded = json.encode(decoded); + if encoded ~= recoded then + print("FAILED"); + print("encoded:", encoded); + print("recoded:", recoded); + else + print(encoded); + end + return encoded == recoded; +end + +return json; diff --git a/util/logger.lua b/util/logger.lua index fb0bc37b..c3bf3992 100644 --- a/util/logger.lua +++ b/util/logger.lua @@ -8,9 +8,6 @@ local pcall = pcall; -local config = require "core.configmanager"; -local log_sources = config.get("*", "core", "log_sources"); - local find = string.find; local ipairs, pairs, setmetatable = ipairs, pairs, setmetatable; @@ -19,25 +16,9 @@ module "logger" local name_sinks, level_sinks = {}, {}; local name_patterns = {}; --- Weak-keyed so that loggers are collected -local modify_hooks = setmetatable({}, { __mode = "k" }); - local make_logger; -local outfunction = nil; function init(name) - if log_sources then - local log_this = false; - for _, source in ipairs(log_sources) do - if find(name, source) then - log_this = true; - break; - end - end - - if not log_this then return function () end end - end - local log_debug = make_logger(name, "debug"); local log_info = make_logger(name, "info"); local log_warn = make_logger(name, "warn"); @@ -46,8 +27,6 @@ function init(name) --name = nil; -- While this line is not commented, will automatically fill in file/line number info local namelen = #name; return function (level, message, ...) - if outfunction then return outfunction(name, level, message, ...); end - if level == "debug" then return log_debug(message, ...); elseif level == "info" then @@ -69,38 +48,32 @@ function make_logger(source_name, level) local source_handlers = name_sinks[source_name]; - -- All your premature optimisation is belong to me! - local num_level_handlers, num_source_handlers = #level_handlers, source_handlers and #source_handlers; - local logger = function (message, ...) if source_handlers then - for i = 1,num_source_handlers do + for i = 1,#source_handlers do if source_handlers[i](source_name, level, message, ...) == false then return; end end end - for i = 1,num_level_handlers do + for i = 1,#level_handlers do level_handlers[i](source_name, level, message, ...); end end - -- To make sure our cached lengths stay in sync with reality - modify_hooks[logger] = function () num_level_handlers, num_source_handlers = #level_handlers, source_handlers and #source_handlers; end; - return logger; end -function setwriter(f) - local old_func = outfunction; - if not f then outfunction = nil; return true, old_func; end - local ok, ret = pcall(f, "logger", "info", "Switched logging output successfully"); - if ok then - outfunction = f; - ret = old_func; +function reset() + for k in pairs(name_sinks) do name_sinks[k] = nil; end + for level, handler_list in pairs(level_sinks) do + -- Clear all handlers for this level + for i = 1, #handler_list do + handler_list[i] = nil; + end end - return ok, ret; + for k in pairs(name_patterns) do name_patterns[k] = nil; end end function add_level_sink(level, sink_function) @@ -109,10 +82,6 @@ function add_level_sink(level, sink_function) else level_sinks[level][#level_sinks[level] + 1 ] = sink_function; end - - for _, modify_hook in pairs(modify_hooks) do - modify_hook(); - end end function add_name_sink(name, sink_function, exclusive) @@ -121,10 +90,6 @@ function add_name_sink(name, sink_function, exclusive) else name_sinks[name][#name_sinks[name] + 1] = sink_function; end - - for _, modify_hook in pairs(modify_hooks) do - modify_hook(); - end end function add_name_pattern_sink(name_pattern, sink_function, exclusive) diff --git a/util/pluginloader.lua b/util/pluginloader.lua index 956b92bd..555e41bf 100644 --- a/util/pluginloader.lua +++ b/util/pluginloader.lua @@ -6,41 +6,55 @@ -- COPYING file in the source package for more information. -- - -local plugin_dir = CFG_PLUGINDIR or "./plugins/"; +local dir_sep, path_sep = package.config:match("^(%S+)%s(%S+)"); +local plugin_dir = {}; +for path in (CFG_PLUGINDIR or "./plugins/"):gsub("[/\\]", dir_sep):gmatch("[^"..path_sep.."]+") do + path = path..dir_sep; -- add path separator to path end + path = path:gsub(dir_sep..dir_sep.."+", dir_sep); -- coalesce multiple separaters + plugin_dir[#plugin_dir + 1] = path; +end local io_open, os_time = io.open, os.time; local loadstring, pairs = loadstring, pairs; -local datamanager = require "util.datamanager"; - module "pluginloader" -local function load_file(name) - local file, err = io_open(plugin_dir..name); - if not file then return file, err; end - local content = file:read("*a"); - file:close(); - return content, name; +local function load_file(names) + local file, err, path; + for i=1,#plugin_dir do + for j=1,#names do + path = plugin_dir[i]..names[j]; + file, err = io_open(path); + if file then + local content = file:read("*a"); + file:close(); + return content, path; + end + end + end + return file, err; end -function load_resource(plugin, resource, loader) - if not resource then - resource = "mod_"..plugin..".lua"; - end - loader = loader or load_file; +function load_resource(plugin, resource) + resource = resource or "mod_"..plugin..".lua"; + + local names = { + "mod_"..plugin.."/"..plugin.."/"..resource; -- mod_hello/hello/mod_hello.lua + "mod_"..plugin.."/"..resource; -- mod_hello/mod_hello.lua + plugin.."/"..resource; -- hello/mod_hello.lua + resource; -- mod_hello.lua + }; - local content, err = loader(plugin.."/"..resource); - if not content then content, err = loader(resource); end - -- TODO add support for packed plugins - - return content, err; + return load_file(names); end function load_code(plugin, resource) local content, err = load_resource(plugin, resource); if not content then return content, err; end - return loadstring(content, "@"..err); + local path = err; + local f, err = loadstring(content, "@"..path); + if not f then return f, err; end + return f, path; end return _M; diff --git a/util/prosodyctl.lua b/util/prosodyctl.lua index 04d58d1d..aa1850b2 100644 --- a/util/prosodyctl.lua +++ b/util/prosodyctl.lua @@ -10,19 +10,109 @@ local config = require "core.configmanager"; local encodings = require "util.encodings"; local stringprep = encodings.stringprep; +local storagemanager = require "core.storagemanager"; local usermanager = require "core.usermanager"; local signal = require "util.signal"; +local set = require "util.set"; local lfs = require "lfs"; +local pcall = pcall; local nodeprep, nameprep = stringprep.nodeprep, stringprep.nameprep; local io, os = io, os; +local print = print; local tostring, tonumber = tostring, tonumber; local CFG_SOURCEDIR = _G.CFG_SOURCEDIR; +local _G = _G; +local prosody = prosody; + module "prosodyctl" +-- UI helpers +function show_message(msg, ...) + print(msg:format(...)); +end + +function show_warning(msg, ...) + print(msg:format(...)); +end + +function show_usage(usage, desc) + print("Usage: ".._G.arg[0].." "..usage); + if desc then + print(" "..desc); + end +end + +function getchar(n) + local stty_ret = os.execute("stty raw -echo 2>/dev/null"); + local ok, char; + if stty_ret == 0 then + ok, char = pcall(io.read, n or 1); + os.execute("stty sane"); + else + ok, char = pcall(io.read, "*l"); + if ok then + char = char:sub(1, n or 1); + end + end + if ok then + return char; + end +end + +function getpass() + local stty_ret = os.execute("stty -echo 2>/dev/null"); + if stty_ret ~= 0 then + io.write("\027[08m"); -- ANSI 'hidden' text attribute + end + local ok, pass = pcall(io.read, "*l"); + if stty_ret == 0 then + os.execute("stty sane"); + else + io.write("\027[00m"); + end + io.write("\n"); + if ok then + return pass; + end +end + +function show_yesno(prompt) + io.write(prompt, " "); + local choice = getchar():lower(); + io.write("\n"); + if not choice:match("%a") then + choice = prompt:match("%[.-(%U).-%]$"); + if not choice then return nil; end + end + return (choice == "y"); +end + +function read_password() + local password; + while true do + io.write("Enter new password: "); + password = getpass(); + if not password then + show_message("No password - cancelled"); + return; + end + io.write("Retype new password: "); + if getpass() ~= password then + if not show_yesno [=[Passwords did not match, try again? [Y/n]]=] then + return; + end + else + break; + end + end + return password; +end + +-- Server control function adduser(params) local user, host, password = nodeprep(params.user), nameprep(params.host), params.password; if not user then @@ -30,16 +120,29 @@ function adduser(params) elseif not host then return false, "invalid-hostname"; end + + local provider = prosody.hosts[host].users; + if not(provider) or provider.name == "null" then + usermanager.initialize_host(host); + end + storagemanager.initialize_host(host); - local ok = usermanager.create_user(user, password, host); + local ok, errmsg = usermanager.create_user(user, password, host); if not ok then - return false, "unable-to-save-data"; + return false, errmsg; end return true; end function user_exists(params) - return usermanager.user_exists(params.user, params.host); + local user, host, password = nodeprep(params.user), nameprep(params.host), params.password; + local provider = prosody.hosts[host].users; + if not(provider) or provider.name == "null" then + usermanager.initialize_host(host); + end + storagemanager.initialize_host(host); + + return usermanager.user_exists(user, host); end function passwd(params) @@ -65,6 +168,11 @@ function getpid() return false, "no-pidfile"; end + local modules_enabled = set.new(config.get("*", "core", "modules_enabled")); + if not modules_enabled:contains("posix") then + return false, "no-posix"; + end + local file, err = io.open(pidfile, "r+"); if not file then return false, "pidfile-read-failed", err; diff --git a/util/sasl.lua b/util/sasl.lua index 306acc0c..17d10b80 100644 --- a/util/sasl.lua +++ b/util/sasl.lua @@ -12,27 +12,13 @@ -- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -local md5 = require "util.hashes".md5; -local log = require "util.logger".init("sasl"); -local st = require "util.stanza"; -local set = require "util.set"; -local array = require "util.array"; -local to_unicode = require "util.encodings".idna.to_unicode; - -local tostring = tostring; local pairs, ipairs = pairs, ipairs; -local t_insert, t_concat = table.insert, table.concat; -local s_match = string.match; +local t_insert = table.insert; local type = type -local error = error local setmetatable = setmetatable; local assert = assert; local require = require; -require "util.iterators" -local keys = keys - -local array = require "util.array" module "sasl" --[[ @@ -61,72 +47,50 @@ local function registerMechanism(name, backends, f) end -- create a new SASL object which can be used to authenticate clients -function new(realm, profile, forbidden) - local sasl_i = {profile = profile}; - sasl_i.realm = realm; - local s = setmetatable(sasl_i, method); - if forbidden == nil then forbidden = {} end - s:forbidden(forbidden) - return s; +function new(realm, profile) + local mechanisms = profile.mechanisms; + if not mechanisms then + mechanisms = {}; + for backend, f in pairs(profile) do + if backend_mechanism[backend] then + for _, mechanism in ipairs(backend_mechanism[backend]) do + mechanisms[mechanism] = true; + end + end + end + profile.mechanisms = mechanisms; + end + return setmetatable({ profile = profile, realm = realm, mechs = mechanisms }, method); end --- get a fresh clone with the same realm, profiles and forbidden mechanisms +-- get a fresh clone with the same realm and profile function method:clean_clone() - return new(self.realm, self.profile, self:forbidden()) -end - --- set the forbidden mechanisms -function method:forbidden( restrict ) - if restrict then - -- set forbidden - self.restrict = set.new(restrict); - else - -- get forbidden - return array.collect(self.restrict:items()); - end + return new(self.realm, self.profile) end -- get a list of possible SASL mechanims to use function method:mechanisms() - local mechanisms = {} - for backend, f in pairs(self.profile) do - if backend_mechanism[backend] then - for _, mechanism in ipairs(backend_mechanism[backend]) do - if not self.restrict:contains(mechanism) then - mechanisms[mechanism] = true; - end - end - end - end - self["possible_mechanisms"] = mechanisms; - return array.collect(keys(mechanisms)); + return self.mechs; end -- select a mechanism to use function method:select(mechanism) - if self.mech_i then - return false; + if not self.selected and self.mechs[mechanism] then + self.selected = mechanism; + return true; end - - self.mech_i = mechanisms[mechanism] - if self.mech_i == nil then - return false; - end - return true; end -- feed new messages to process into the library function method:process(message) --if message == "" or message == nil then return "failure", "malformed-request" end - return self.mech_i(self, message); + return mechanisms[self.selected](self, message); end -- load the mechanisms -local load_mechs = {"plain", "digest-md5", "anonymous", "scram"} -for _, mech in ipairs(load_mechs) do - local name = "util.sasl."..mech; - local m = require(name); - m.init(registerMechanism) -end +require "util.sasl.plain" .init(registerMechanism); +require "util.sasl.digest-md5".init(registerMechanism); +require "util.sasl.anonymous" .init(registerMechanism); +require "util.sasl.scram" .init(registerMechanism); return _M; diff --git a/util/sasl/anonymous.lua b/util/sasl/anonymous.lua index 7b5a5081..ca5fe404 100644 --- a/util/sasl/anonymous.lua +++ b/util/sasl/anonymous.lua @@ -16,16 +16,26 @@ local s_match = string.match; local log = require "util.logger".init("sasl"); local generate_uuid = require "util.uuid".generate; -module "anonymous" +module "sasl.anonymous" --========================= --SASL ANONYMOUS according to RFC 4505 + +--[[ +Supported Authentication Backends + +anonymous: + function(username, realm) + return true; --for normal usage just return true; if you don't like the supplied username you can return false. + end +]] + local function anonymous(self, message) local username; repeat username = generate_uuid(); - until self.profile.anonymous(username, self.realm); - self["username"] = username; + until self.profile.anonymous(self, username, self.realm); + self.username = username; return "success" end @@ -33,4 +43,4 @@ function init(registerMechanism) registerMechanism("ANONYMOUS", {"anonymous"}, anonymous); end -return _M;
\ No newline at end of file +return _M; diff --git a/util/sasl/digest-md5.lua b/util/sasl/digest-md5.lua index 2837148e..de2538fc 100644 --- a/util/sasl/digest-md5.lua +++ b/util/sasl/digest-md5.lua @@ -24,7 +24,7 @@ local md5 = require "util.hashes".md5; local log = require "util.logger".init("sasl"); local generate_uuid = require "util.uuid".generate; -module "digest-md5" +module "sasl.digest-md5" --========================= --SASL DIGEST-MD5 according to RFC 2831 @@ -181,12 +181,12 @@ local function digest(self, message) self.username = response["username"]; local Y, state; if self.profile.plain then - local password, state = self.profile.plain(response["username"], self.realm) + local password, state = self.profile.plain(self, response["username"], self.realm) if state == nil then return "failure", "not-authorized" elseif state == false then return "failure", "account-disabled" end Y = md5(response["username"]..":"..response["realm"]..":"..password); elseif self.profile["digest-md5"] then - Y, state = self.profile["digest-md5"](response["username"], self.realm, response["realm"], response["charset"]) + Y, state = self.profile["digest-md5"](self, response["username"], self.realm, response["realm"], response["charset"]) if state == nil then return "failure", "not-authorized" elseif state == false then return "failure", "account-disabled" end elseif self.profile["digest-md5-test"] then @@ -240,4 +240,4 @@ function init(registerMechanism) registerMechanism("DIGEST-MD5", {"plain"}, digest); end -return _M;
\ No newline at end of file +return _M; diff --git a/util/sasl/plain.lua b/util/sasl/plain.lua index 39821182..fb20cf97 100644 --- a/util/sasl/plain.lua +++ b/util/sasl/plain.lua @@ -15,7 +15,7 @@ local s_match = string.match; local saslprep = require "util.encodings".stringprep.saslprep; local log = require "util.logger".init("sasl"); -module "plain" +module "sasl.plain" -- ================================ -- SASL PLAIN according to RFC 4616 @@ -29,7 +29,7 @@ plain: end plain_test: - function(username, realm, password) + function(username, password, realm) return true or false, state; end ]] @@ -57,10 +57,10 @@ local function plain(self, message) local correct, state = false, false; if self.profile.plain then local correct_password; - correct_password, state = self.profile.plain(authentication, self.realm); - if correct_password == password then correct = true; else correct = false; end + correct_password, state = self.profile.plain(self, authentication, self.realm); + correct = (correct_password == password); elseif self.profile.plain_test then - correct, state = self.profile.plain_test(authentication, self.realm, password); + correct, state = self.profile.plain_test(self, authentication, password, self.realm); end self.username = authentication diff --git a/util/sasl/scram.lua b/util/sasl/scram.lua index 1340423c..aad33ebc 100644 --- a/util/sasl/scram.lua +++ b/util/sasl/scram.lua @@ -24,10 +24,10 @@ local t_concat = table.concat; local char = string.char; local byte = string.byte; -module "scram" +module "sasl.scram" --========================= ---SASL SCRAM-SHA-1 according to draft-ietf-sasl-scram-10 +--SASL SCRAM-SHA-1 according to RFC 5802 --[[ Supported Authentication Backends @@ -35,7 +35,7 @@ Supported Authentication Backends scram_{MECH}: -- MECH being a standard hash name (like those at IANA's hash registry) with '-' replaced with '_' function(username, realm) - return salted_password, iteration_count, salt, state; + return stored_key, server_key, iteration_count, salt, state; end ]] @@ -65,9 +65,9 @@ local function binaryXOR( a, b ) end -- hash algorithm independent Hi(PBKDF2) implementation -local function Hi(hmac, str, salt, i) +function Hi(hmac, str, salt, i) local Ust = hmac(str, salt.."\0\0\0\1"); - local res = Ust; + local res = Ust; for n=1,i-1 do local Und = hmac(str, Ust) res = binaryXOR(res, Und) @@ -79,13 +79,13 @@ end local function validate_username(username) -- check for forbidden char sequences for eq in username:gmatch("=(.?.?)") do - if eq ~= "2D" and eq ~= "3D" then - return false - end + if eq ~= "2C" and eq ~= "3D" then + return false + end end - -- replace =2D with , and =3D with = - username = username:gsub("=2D", ","); + -- replace =2C with , and =3D with = + username = username:gsub("=2C", ","); username = username:gsub("=3D", "="); -- apply SASLprep @@ -93,22 +93,21 @@ local function validate_username(username) return username; end -local function hashprep( hashname ) - local hash = hashname:lower() - hash = hash:gsub("-", "_") - return hash +local function hashprep(hashname) + return hashname:lower():gsub("-", "_"); end -function saltedPasswordSHA1(password, salt, iteration_count) - local salted_password +function getAuthenticationDatabaseSHA1(password, salt, iteration_count) if type(password) ~= "string" or type(salt) ~= "string" or type(iteration_count) ~= "number" then return false, "inappropriate argument types" end if iteration_count < 4096 then log("warn", "Iteration count < 4096 which is the suggested minimum according to RFC 5802.") end - - return true, Hi(hmac_sha1, password, salt, iteration_count); + local salted_password = Hi(hmac_sha1, password, salt, iteration_count); + local stored_key = sha1(hmac_sha1(salted_password, "Client Key")) + local server_key = hmac_sha1(salted_password, "Server Key"); + return true, stored_key, server_key end local function scram_gen(hash_name, H_f, HMAC_f) @@ -144,7 +143,7 @@ local function scram_gen(hash_name, H_f, HMAC_f) -- retreive credentials if self.profile.plain then - local password, state = self.profile.plain(self.state.name, self.realm) + local password, state = self.profile.plain(self, self.state.name, self.realm) if state == nil then return "failure", "not-authorized" elseif state == false then return "failure", "account-disabled" end @@ -158,17 +157,18 @@ local function scram_gen(hash_name, H_f, HMAC_f) self.state.iteration_count = default_i; local succ = false; - succ, self.state.salted_password = saltedPasswordSHA1(password, self.state.salt, default_i, self.state.iteration_count); + succ, self.state.stored_key, self.state.server_key = getAuthenticationDatabaseSHA1(password, self.state.salt, default_i, self.state.iteration_count); if not succ then - log("error", "Generating salted password failed. Reason: %s", self.state.salted_password); + log("error", "Generating authentication database failed. Reason: %s", self.state.stored_key); return "failure", "temporary-auth-failure"; end elseif self.profile["scram_"..hashprep(hash_name)] then - local salted_password, iteration_count, salt, state = self.profile["scram_"..hashprep(hash_name)](self.state.name, self.realm); + local stored_key, server_key, iteration_count, salt, state = self.profile["scram_"..hashprep(hash_name)](self, self.state.name, self.realm); if state == nil then return "failure", "not-authorized" elseif state == false then return "failure", "account-disabled" end - self.state.salted_password = salted_password; + self.state.stored_key = stored_key; + self.state.server_key = server_key; self.state.iteration_count = iteration_count; self.state.salt = salt end @@ -190,16 +190,15 @@ local function scram_gen(hash_name, H_f, HMAC_f) return "failure", "malformed-request", "Wrong nonce in client-final-message."; end - local SaltedPassword = self.state.salted_password; - local ClientKey = HMAC_f(SaltedPassword, "Client Key") - local ServerKey = HMAC_f(SaltedPassword, "Server Key") - local StoredKey = H_f(ClientKey) + local ServerKey = self.state.server_key; + local StoredKey = self.state.stored_key; + local AuthMessage = "n=" .. s_match(self.state.client_first_message,"n=(.+)") .. "," .. self.state.server_first_message .. "," .. s_match(client_final_message, "(.+),p=.+") local ClientSignature = HMAC_f(StoredKey, AuthMessage) - local ClientProof = binaryXOR(ClientKey, ClientSignature) + local ClientKey = binaryXOR(ClientSignature, base64.decode(self.state.proof)) local ServerSignature = HMAC_f(ServerKey, AuthMessage) - if base64.encode(ClientProof) == self.state.proof then + if StoredKey == H_f(ClientKey) then local server_final_message = "v="..base64.encode(ServerSignature); self["username"] = self.state.name; return "success", server_final_message; diff --git a/util/sasl_cyrus.lua b/util/sasl_cyrus.lua index 7d35b5e4..002118fd 100644 --- a/util/sasl_cyrus.lua +++ b/util/sasl_cyrus.lua @@ -13,17 +13,9 @@ local cyrussasl = require "cyrussasl"; local log = require "util.logger".init("sasl_cyrus"); -local array = require "util.array"; -local tostring = tostring; -local pairs, ipairs = pairs, ipairs; -local t_insert, t_concat = table.insert, table.concat; -local s_match = string.match; local setmetatable = setmetatable -local keys = keys; - -local print = print local pcall = pcall local s_match, s_gmatch = string.match, string.gmatch @@ -87,21 +79,17 @@ end -- create a new SASL object which can be used to authenticate clients function new(realm, service_name, app_name) - local sasl_i = {}; init(app_name or service_name); - sasl_i.realm = realm; - sasl_i.service_name = service_name; - local st, ret = pcall(cyrussasl.server_new, service_name, nil, realm, nil, nil) - if st then - sasl_i.cyrus = ret; - else + if not st then log("error", "Creating SASL server connection failed: %s", ret); return nil; end + local sasl_i = { realm = realm, service_name = service_name, cyrus = ret }; + if cyrussasl.set_canon_cb then local c14n_cb = function (user) local node = s_match(user, "^([^@]+)"); @@ -112,37 +100,31 @@ function new(realm, service_name, app_name) end cyrussasl.setssf(sasl_i.cyrus, 0, 0xffffffff) - local s = setmetatable(sasl_i, method); - return s; + local mechanisms = {}; + local cyrus_mechs = cyrussasl.listmech(sasl_i.cyrus, nil, "", " ", ""); + for w in s_gmatch(cyrus_mechs, "[^ ]+") do + mechanisms[w] = true; + end + sasl_i.mechs = mechanisms; + return setmetatable(sasl_i, method); end --- get a fresh clone with the same realm, profiles and forbidden mechanisms +-- get a fresh clone with the same realm and service name function method:clean_clone() return new(self.realm, self.service_name) end --- set the forbidden mechanisms -function method:forbidden( restrict ) - log("warn", "Called method:forbidden. NOT IMPLEMENTED.") - return {} -end - -- get a list of possible SASL mechanims to use function method:mechanisms() - local mechanisms = {} - local cyrus_mechs = cyrussasl.listmech(self.cyrus, nil, "", " ", "") - for w in s_gmatch(cyrus_mechs, "[^ ]+") do - mechanisms[w] = true; - end - self.mechs = mechanisms - return array.collect(keys(mechanisms)); + return self.mechs; end -- select a mechanism to use function method:select(mechanism) - self.mechanism = mechanism; - if not self.mechs then self:mechanisms(); end - return self.mechs[mechanism]; + if not self.selected and self.mechs[mechanism] then + self.selected = mechanism; + return true; + end end -- feed new messages to process into the library @@ -150,8 +132,9 @@ function method:process(message) local err; local data; - if self.mechanism then - err, data = cyrussasl.server_start(self.cyrus, self.mechanism, message or "") + if not self.first_step_done then + err, data = cyrussasl.server_start(self.cyrus, self.selected, message or "") + self.first_step_done = true; else err, data = cyrussasl.server_step(self.cyrus, message or "") end @@ -159,17 +142,20 @@ function method:process(message) self.username = cyrussasl.get_username(self.cyrus) if (err == 0) then -- SASL_OK - return "success", data + if self.require_provisioning and not self.require_provisioning(self.username) then + return "failure", "not-authorized", "User authenticated successfully, but not provisioned for XMPP"; + end + return "success", data elseif (err == 1) then -- SASL_CONTINUE - return "challenge", data + return "challenge", data elseif (err == -4) then -- SASL_NOMECH - log("debug", "SASL mechanism not available from remote end") - return "failure", "invalid-mechanism", "SASL mechanism not available" + log("debug", "SASL mechanism not available from remote end") + return "failure", "invalid-mechanism", "SASL mechanism not available" elseif (err == -13) then -- SASL_BADAUTH - return "failure", "not-authorized", sasl_errstring[err]; + return "failure", "not-authorized", sasl_errstring[err]; else - log("debug", "Got SASL error condition %d: %s", err, sasl_errstring[err]); - return "failure", "undefined-condition", sasl_errstring[err]; + log("debug", "Got SASL error condition %d: %s", err, sasl_errstring[err]); + return "failure", "undefined-condition", sasl_errstring[err]; end end diff --git a/util/serialization.lua b/util/serialization.lua index bad2fe43..e193b64f 100644 --- a/util/serialization.lua +++ b/util/serialization.lua @@ -15,6 +15,10 @@ local error = error; local pairs = pairs; local next = next; +local loadstring = loadstring; +local setfenv = setfenv; +local pcall = pcall; + local debug_traceback = debug.traceback; local log = require "util.logger".init("serialization"); module "serialization" @@ -24,14 +28,20 @@ local indent = function(i) end local function basicSerialize (o) if type(o) == "number" or type(o) == "boolean" then - return tostring(o); + -- no need to check for NaN, as that's not a valid table index + if o == 1/0 then return "(1/0)"; + elseif o == -1/0 then return "(-1/0)"; + else return tostring(o); end else -- assume it is a string -- FIXME make sure it's a string. throw an error otherwise. return (("%q"):format(tostring(o)):gsub("\\\n", "\\n")); end end local function _simplesave(o, ind, t, func) if type(o) == "number" then - func(t, tostring(o)); + if o ~= o then func(t, "(0/0)"); + elseif o == 1/0 then func(t, "(1/0)"); + elseif o == -1/0 then func(t, "(-1/0)"); + else func(t, tostring(o)); end elseif type(o) == "string" then func(t, (("%q"):format(o):gsub("\\\n", "\\n"))); elseif type(o) == "table" then @@ -72,7 +82,14 @@ function serialize(o) end function deserialize(str) - error("Not implemented"); + if type(str) ~= "string" then return nil; end + str = "return "..str; + local f, err = loadstring(str, "@data"); + if not f then return nil, err; end + setfenv(f, {}); + local success, ret = pcall(f); + if not success then return nil, ret; end + return ret; end return _M; diff --git a/util/set.lua b/util/set.lua index ee154ece..e4cc2dff 100644 --- a/util/set.lua +++ b/util/set.lua @@ -6,7 +6,7 @@ -- COPYING file in the source package for more information. -- -local ipairs, pairs, setmetatable, next, tostring = +local ipairs, pairs, setmetatable, next, tostring = ipairs, pairs, setmetatable, next, tostring; local t_concat = table.concat; diff --git a/util/stanza.lua b/util/stanza.lua index 08ef2c9a..de83977f 100644 --- a/util/stanza.lua +++ b/util/stanza.lua @@ -44,11 +44,13 @@ module "stanza" stanza_mt = { __type = "stanza" }; stanza_mt.__index = stanza_mt; +local stanza_mt = stanza_mt; function stanza(name, attr) - local stanza = { name = name, attr = attr or {}, tags = {}, last_add = {}}; + local stanza = { name = name, attr = attr or {}, tags = {} }; return setmetatable(stanza, stanza_mt); end +local stanza = stanza; function stanza_mt:query(xmlns) return self:tag("query", { xmlns = xmlns }); @@ -60,26 +62,27 @@ end function stanza_mt:tag(name, attrs) local s = stanza(name, attrs); - (self.last_add[#self.last_add] or self):add_direct_child(s); - t_insert(self.last_add, s); + local last_add = self.last_add; + if not last_add then last_add = {}; self.last_add = last_add; end + (last_add[#last_add] or self):add_direct_child(s); + t_insert(last_add, s); return self; end function stanza_mt:text(text) - (self.last_add[#self.last_add] or self):add_direct_child(text); + local last_add = self.last_add; + (last_add and last_add[#last_add] or self):add_direct_child(text); return self; end function stanza_mt:up() - t_remove(self.last_add); + local last_add = self.last_add; + if last_add then t_remove(last_add); end return self; end function stanza_mt:reset() - local last_add = self.last_add; - for i = 1,#last_add do - last_add[i] = nil; - end + self.last_add = nil; return self; end @@ -91,7 +94,8 @@ function stanza_mt:add_direct_child(child) end function stanza_mt:add_child(child) - (self.last_add[#self.last_add] or self):add_direct_child(child); + local last_add = self.last_add; + (last_add and last_add[#last_add] or self):add_direct_child(child); return self; end @@ -106,6 +110,14 @@ function stanza_mt:get_child(name, xmlns) end end +function stanza_mt:get_child_text(name, xmlns) + local tag = self:get_child(name, xmlns); + if tag then + return tag:get_text(); + end + return nil; +end + function stanza_mt:child_with_name(name) for _, child in ipairs(self.tags) do if child.name == name then return child; end @@ -122,17 +134,48 @@ function stanza_mt:children() local i = 0; return function (a) i = i + 1 - local v = a[i] - if v then return v; end + return a[i]; end, self, i; end -function stanza_mt:childtags() - local i = 0; - return function (a) - i = i + 1 - local v = self.tags[i] - if v then return v; end - end, self.tags[1], i; + +function stanza_mt:childtags(name, xmlns) + xmlns = xmlns or self.attr.xmlns; + local tags = self.tags; + local start_i, max_i = 1, #tags; + return function () + for i = start_i, max_i do + local v = tags[i]; + if (not name or v.name == name) + and (not xmlns or xmlns == v.attr.xmlns) then + start_i = i+1; + return v; + end + end + end; +end + +function stanza_mt:maptags(callback) + local tags, curr_tag = self.tags, 1; + local n_children, n_tags = #self, #tags; + + local i = 1; + while curr_tag <= n_tags do + if self[i] == tags[curr_tag] then + local ret = callback(self[i]); + if ret == nil then + t_remove(self, i); + t_remove(tags, curr_tag); + n_children = n_children - 1; + n_tags = n_tags - 1; + else + self[i] = ret; + tags[i] = ret; + end + i = i + 1; + curr_tag = curr_tag + 1; + end + end + return self; end local xml_escape @@ -200,7 +243,7 @@ function stanza_mt.get_error(stanza) end type = error_tag.attr.type; - for child in error_tag:children() do + for child in error_tag:childtags() do if child.attr.xmlns == xmlns_stanzas then if not text and child.name == "text" then text = child:get_text(); @@ -212,7 +255,7 @@ function stanza_mt.get_error(stanza) end end end - return type, condition or "undefined-condition", text or ""; + return type, condition or "undefined-condition", text; end function stanza_mt.__add(s1, s2) @@ -271,39 +314,33 @@ function deserialize(stanza) end end stanza.tags = tags; - if not stanza.last_add then - stanza.last_add = {}; - end end end return stanza; end -function clone(stanza) - local lookup_table = {}; - local function _copy(object) - if type(object) ~= "table" then - return object; - elseif lookup_table[object] then - return lookup_table[object]; - end - local new_table = {}; - lookup_table[object] = new_table; - for index, value in pairs(object) do - new_table[_copy(index)] = _copy(value); +local function _clone(stanza) + local attr, tags = {}, {}; + for k,v in pairs(stanza.attr) do attr[k] = v; end + local new = { name = stanza.name, attr = attr, tags = tags }; + for i=1,#stanza do + local child = stanza[i]; + if child.name then + child = _clone(child); + t_insert(tags, child); end - return setmetatable(new_table, getmetatable(object)); + t_insert(new, child); end - - return _copy(stanza) + return setmetatable(new, stanza_mt); end +clone = _clone; function message(attr, body) if not body then return stanza("message", attr); else - return stanza("message", attr):tag("body"):text(body); + return stanza("message", attr):tag("body"):text(body):up(); end end function iq(attr) diff --git a/util/template.lua b/util/template.lua new file mode 100644 index 00000000..ebd8be14 --- /dev/null +++ b/util/template.lua @@ -0,0 +1,133 @@ + +local st = require "util.stanza"; +local lxp = require "lxp"; +local setmetatable = setmetatable; +local pairs = pairs; +local ipairs = ipairs; +local error = error; +local loadstring = loadstring; +local debug = debug; + +module("template") + +local parse_xml = (function() + local ns_prefixes = { + ["http://www.w3.org/XML/1998/namespace"] = "xml"; + }; + local ns_separator = "\1"; + local ns_pattern = "^([^"..ns_separator.."]*)"..ns_separator.."?(.*)$"; + return function(xml) + local handler = {}; + local stanza = st.stanza("root"); + function handler:StartElement(tagname, attr) + local curr_ns,name = tagname:match(ns_pattern); + if name == "" then + curr_ns, name = "", curr_ns; + end + if curr_ns ~= "" then + attr.xmlns = curr_ns; + end + for i=1,#attr do + local k = attr[i]; + attr[i] = nil; + local ns, nm = k:match(ns_pattern); + if nm ~= "" then + ns = ns_prefixes[ns]; + if ns then + attr[ns..":"..nm] = attr[k]; + attr[k] = nil; + end + end + end + stanza:tag(name, attr); + end + function handler:CharacterData(data) + data = data:gsub("^%s*", ""):gsub("%s*$", ""); + stanza:text(data); + end + function handler:EndElement(tagname) + stanza:up(); + end + local parser = lxp.new(handler, "\1"); + local ok, err, line, col = parser:parse(xml); + if ok then ok, err, line, col = parser:parse(); end + --parser:close(); + if ok then + return stanza.tags[1]; + else + return ok, err.." (line "..line..", col "..col..")"; + end + end; +end)(); + +local function create_string_string(str) + str = ("%q"):format(str); + str = str:gsub("{([^}]*)}", function(s) + return '"..(data["'..s..'"]or"").."'; + end); + return str; +end +local function create_attr_string(attr, xmlns) + local str = '{'; + for name,value in pairs(attr) do + if name ~= "xmlns" or value ~= xmlns then + str = str..("[%q]=%s;"):format(name, create_string_string(value)); + end + end + return str..'}'; +end +local function create_clone_string(stanza, lookup, xmlns) + if not lookup[stanza] then + local s = ('setmetatable({name=%q,attr=%s,tags={'):format(stanza.name, create_attr_string(stanza.attr, xmlns)); + -- add tags + for i,tag in ipairs(stanza.tags) do + s = s..create_clone_string(tag, lookup, stanza.attr.xmlns)..";"; + end + s = s..'};'; + -- add children + for i,child in ipairs(stanza) do + if child.name then + s = s..create_clone_string(child, lookup, stanza.attr.xmlns)..";"; + else + s = s..create_string_string(child)..";" + end + end + s = s..'}, stanza_mt)'; + s = s:gsub('%.%.""', ""):gsub('([=;])""%.%.', "%1"):gsub(';"";', ";"); -- strip empty strings + local n = #lookup + 1; + lookup[n] = s; + lookup[stanza] = "_"..n; + end + return lookup[stanza]; +end +local stanza_mt = st.stanza_mt; +local function create_cloner(stanza, chunkname) + local lookup = {}; + local name = create_clone_string(stanza, lookup, ""); + local f = "local setmetatable,stanza_mt=...;return function(data)"; + for i=1,#lookup do + f = f.."local _"..i.."="..lookup[i]..";"; + end + f = f.."return "..name..";end"; + local f,err = loadstring(f, chunkname); + if not f then error(err); end + return f(setmetatable, stanza_mt); +end + +local template_mt = { __tostring = function(t) return t.name end }; +local function create_template(templates, text) + local stanza, err = parse_xml(text); + if not stanza then error(err); end + + local info = debug.getinfo(3, "Sl"); + info = info and ("template(%s:%d)"):format(info.short_src:match("[^\\/]*$"), info.currentline) or "template(unknown)"; + + local template = setmetatable({ apply = create_cloner(stanza, info), name = info, text = text }, template_mt); + templates[text] = template; + return template; +end + +local templates = setmetatable({}, { __mode = 'k', __index = create_template }); +return function(text) + return templates[text]; +end; diff --git a/util/termcolours.lua b/util/termcolours.lua index 4e267bee..df204688 100644 --- a/util/termcolours.lua +++ b/util/termcolours.lua @@ -10,6 +10,14 @@ local t_concat, t_insert = table.concat, table.insert; local char, format = string.char, string.format; local ipairs = ipairs; +local io_write = io.write; + +local windows; +if os.getenv("WINDIR") then + windows = require "util.windows"; +end +local orig_color = windows and windows.get_consolecolor and windows.get_consolecolor(); + module "termcolours" local stylemap = { @@ -19,6 +27,13 @@ local stylemap = { bold = 1, dark = 2, underline = 4, underlined = 4, normal = 0; } +local winstylemap = { + ["0"] = orig_color, -- reset + ["1"] = 7+8, -- bold + ["1;33"] = 2+4+8, -- bold yellow + ["1;31"] = 4+8 -- bold red +} + local fmt_string = char(0x1B).."[%sm%s"..char(0x1B).."[0m"; function getstring(style, text) if style then @@ -39,4 +54,26 @@ function getstyle(...) return t_concat(result, ";"); end +local last = "0"; +function setstyle(style) + style = style or "0"; + if style ~= last then + io_write("\27["..style.."m"); + last = style; + end +end + +if windows then + function setstyle(style) + style = style or "0"; + if style ~= last then + windows.set_consolecolor(winstylemap[style] or orig_color); + last = style; + end + end + if not orig_color then + function setstyle(style) end + end +end + return _M; diff --git a/util/timer.lua b/util/timer.lua index fa1dd7c5..3061da72 100644 --- a/util/timer.lua +++ b/util/timer.lua @@ -11,7 +11,9 @@ local ns_addtimer = require "net.server".addtimer; local event = require "net.server".event; local event_base = require "net.server".event_base; -local get_time = os.time; +local math_min = math.min +local math_huge = math.huge +local get_time = require "socket".gettime; local t_insert = table.insert; local t_remove = table.remove; local ipairs, pairs = ipairs, pairs; @@ -43,14 +45,21 @@ if not event then new_data = {}; end + local next_time = math_huge; for i, d in pairs(data) do local t, func = d[1], d[2]; if t <= current_time then data[i] = nil; local r = func(current_time); - if type(r) == "number" then _add_task(r, func); end + if type(r) == "number" then + _add_task(r, func); + next_time = math_min(next_time, r); + end + else + next_time = math_min(next_time, t - current_time); end end + return next_time; end); else local EVENT_LEAVE = (event.core and event.core.LEAVE) or -1; diff --git a/util/xmppstream.lua b/util/xmppstream.lua new file mode 100644 index 00000000..d1cb652d --- /dev/null +++ b/util/xmppstream.lua @@ -0,0 +1,204 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + + +local lxp = require "lxp"; +local st = require "util.stanza"; + +local tostring = tostring; +local t_insert = table.insert; +local t_concat = table.concat; + +local default_log = require "util.logger".init("xmppstream"); + +-- COMPAT: w/LuaExpat 1.1.0 +local lxp_supports_doctype = pcall(lxp.new, { StartDoctypeDecl = false }); + +if not lxp_supports_doctype then + default_log("warn", "The version of LuaExpat on your system leaves Prosody " + .."vulnerable to denial-of-service attacks. You should upgrade to " + .."LuaExpat 1.1.1 or higher as soon as possible. See " + .."http://prosody.im/doc/depends#luaexpat for more information."); +end + +local error = error; + +module "xmppstream" + +local new_parser = lxp.new; + +local ns_prefixes = { + ["http://www.w3.org/XML/1998/namespace"] = "xml"; +}; + +local xmlns_streams = "http://etherx.jabber.org/streams"; + +local ns_separator = "\1"; +local ns_pattern = "^([^"..ns_separator.."]*)"..ns_separator.."?(.*)$"; + +_M.ns_separator = ns_separator; +_M.ns_pattern = ns_pattern; + +function new_sax_handlers(session, stream_callbacks) + local xml_handlers = {}; + + local log = session.log or default_log; + + local cb_streamopened = stream_callbacks.streamopened; + local cb_streamclosed = stream_callbacks.streamclosed; + local cb_error = stream_callbacks.error or function(session, e) error("XML stream error: "..tostring(e)); end; + local cb_handlestanza = stream_callbacks.handlestanza; + + local stream_ns = stream_callbacks.stream_ns or xmlns_streams; + local stream_tag = stream_callbacks.stream_tag or "stream"; + if stream_ns ~= "" then + stream_tag = stream_ns..ns_separator..stream_tag; + end + local stream_error_tag = stream_ns..ns_separator..(stream_callbacks.error_tag or "error"); + + local stream_default_ns = stream_callbacks.default_ns; + + local chardata, stanza = {}; + local non_streamns_depth = 0; + function xml_handlers:StartElement(tagname, attr) + if stanza and #chardata > 0 then + -- We have some character data in the buffer + stanza:text(t_concat(chardata)); + chardata = {}; + end + local curr_ns,name = tagname:match(ns_pattern); + if name == "" then + curr_ns, name = "", curr_ns; + end + + if curr_ns ~= stream_default_ns or non_streamns_depth > 0 then + attr.xmlns = curr_ns; + non_streamns_depth = non_streamns_depth + 1; + end + + -- FIXME !!!!! + for i=1,#attr do + local k = attr[i]; + attr[i] = nil; + local ns, nm = k:match(ns_pattern); + if nm ~= "" then + ns = ns_prefixes[ns]; + if ns then + attr[ns..":"..nm] = attr[k]; + attr[k] = nil; + end + end + end + + if not stanza then --if we are not currently inside a stanza + if session.notopen then + if tagname == stream_tag then + non_streamns_depth = 0; + if cb_streamopened then + cb_streamopened(session, attr); + end + else + -- Garbage before stream? + cb_error(session, "no-stream"); + end + return; + end + if curr_ns == "jabber:client" and name ~= "iq" and name ~= "presence" and name ~= "message" then + cb_error(session, "invalid-top-level-element"); + end + + stanza = st.stanza(name, attr); + else -- we are inside a stanza, so add a tag + stanza:tag(name, attr); + end + end + function xml_handlers:CharacterData(data) + if stanza then + t_insert(chardata, data); + end + end + function xml_handlers:EndElement(tagname) + if non_streamns_depth > 0 then + non_streamns_depth = non_streamns_depth - 1; + end + if stanza then + if #chardata > 0 then + -- We have some character data in the buffer + stanza:text(t_concat(chardata)); + chardata = {}; + end + -- Complete stanza + local last_add = stanza.last_add; + if not last_add or #last_add == 0 then + if tagname ~= stream_error_tag then + cb_handlestanza(session, stanza); + else + cb_error(session, "stream-error", stanza); + end + stanza = nil; + else + stanza:up(); + end + else + if tagname == stream_tag then + if cb_streamclosed then + cb_streamclosed(session); + end + else + local curr_ns,name = tagname:match(ns_pattern); + if name == "" then + curr_ns, name = "", curr_ns; + end + cb_error(session, "parse-error", "unexpected-element-close", name); + end + stanza, chardata = nil, {}; + end + end + + local function restricted_handler() + cb_error(session, "parse-error", "restricted-xml", "Restricted XML, see RFC 6120 section 11.1."); + end + + if lxp_supports_doctype then + xml_handlers.StartDoctypeDecl = restricted_handler; + end + xml_handlers.Comment = restricted_handler; + xml_handlers.StartCdataSection = restricted_handler; + xml_handlers.ProcessingInstruction = restricted_handler; + + local function reset() + stanza, chardata = nil, {}; + end + + local function set_session(stream, new_session) + session = new_session; + log = new_session.log or default_log; + end + + return xml_handlers, { reset = reset, set_session = set_session }; +end + +function new(session, stream_callbacks) + local handlers, meta = new_sax_handlers(session, stream_callbacks); + local parser = new_parser(handlers, ns_separator); + local parse = parser.parse; + + return { + reset = function () + parser = new_parser(handlers, ns_separator); + parse = parser.parse; + meta.reset(); + end, + feed = function (self, data) + return parse(parser, data); + end, + set_session = meta.set_session; + }; +end + +return _M; diff --git a/util/ztact.lua b/util/ztact.lua deleted file mode 100644 index 2507bf8e..00000000 --- a/util/ztact.lua +++ /dev/null @@ -1,366 +0,0 @@ --- Prosody IM --- This file is included with Prosody IM. It has modifications, --- which are hereby placed in the public domain. - --- public domain 20080410 lua@ztact.com - - -pcall (require, 'lfs') -- lfs may not be installed/necessary. -pcall (require, 'pozix') -- pozix may not be installed/necessary. - - -local getfenv, ipairs, next, pairs, pcall, require, select, tostring, type = - getfenv, ipairs, next, pairs, pcall, require, select, tostring, type -local unpack, xpcall = - unpack, xpcall - -local io, lfs, os, string, table, pozix = io, lfs, os, string, table, pozix - -local assert, print = assert, print - -local error = error - - -module ((...) or 'ztact') ------------------------------------- module ztact - - --- dir -------------------------------------------------------------------- dir - - -function dir (path) -- - - - - - - - - - - - - - - - - - - - - - - - - - dir - local it = lfs.dir (path) - return function () - repeat - local dir = it () - if dir ~= '.' and dir ~= '..' then return dir end - until not dir - end end - - -function is_file (path) -- - - - - - - - - - - - - - - - - - is_file (path) - local mode = lfs.attributes (path, 'mode') - return mode == 'file' and path - end - - --- network byte ordering -------------------------------- network byte ordering - - -function htons (word) -- - - - - - - - - - - - - - - - - - - - - - - - htons - return (word-word%0x100)/0x100, word%0x100 - end - - --- pcall2 -------------------------------------------------------------- pcall2 - - -getfenv ().pcall = pcall -- store the original pcall as ztact.pcall - - -local argc, argv, errorhandler, pcall2_f - - -local function _pcall2 () -- - - - - - - - - - - - - - - - - - - - - _pcall2 - local tmpv = argv - argv = nil - return pcall2_f (unpack (tmpv, 1, argc)) - end - - -function seterrorhandler (func) -- - - - - - - - - - - - - - seterrorhandler - errorhandler = func - end - - -function pcall2 (f, ...) -- - - - - - - - - - - - - - - - - - - - - - pcall2 - - pcall2_f = f - argc = select ('#', ...) - argv = { ... } - - if not errorhandler then - local debug = require ('debug') - errorhandler = debug.traceback - end - - return xpcall (_pcall2, errorhandler) - end - - -function append (t, ...) -- - - - - - - - - - - - - - - - - - - - - - append - local insert = table.insert - for i,v in ipairs {...} do - insert (t, v) - end end - - -function print_r (d, indent) -- - - - - - - - - - - - - - - - - - - print_r - local rep = string.rep (' ', indent or 0) - if type (d) == 'table' then - for k,v in pairs (d) do - if type (v) == 'table' then - io.write (rep, k, '\n') - print_r (v, (indent or 0) + 1) - else io.write (rep, k, ' = ', tostring (v), '\n') end - end - else io.write (d, '\n') end - end - - -function tohex (s) -- - - - - - - - - - - - - - - - - - - - - - - - - tohex - return string.format (string.rep ('%02x ', #s), string.byte (s, 1, #s)) - end - - -function tostring_r (d, indent, tab0) -- - - - - - - - - - - - - tostring_r - - local tab1 = tab0 or {} - local rep = string.rep (' ', indent or 0) - if type (d) == 'table' then - for k,v in pairs (d) do - if type (v) == 'table' then - append (tab1, rep, k, '\n') - tostring_r (v, (indent or 0) + 1, tab1) - else append (tab1, rep, k, ' = ', tostring (v), '\n') end - end - else append (tab1, d, '\n') end - - if not tab0 then return table.concat (tab1) end - end - - --- queue manipulation -------------------------------------- queue manipulation - - --- Possible queue states. 1 (i.e. queue.p[1]) is head of queue. --- --- 1..2 --- 3..4 1..2 --- 3..4 1..2 5..6 --- 1..2 5..6 --- 1..2 - - -local function print_queue (queue, ...) -- - - - - - - - - - - - print_queue - for i=1,10 do io.write ((queue[i] or '.')..' ') end - io.write ('\t') - for i=1,6 do io.write ((queue.p[i] or '.')..' ') end - print (...) - end - - -function dequeue (queue) -- - - - - - - - - - - - - - - - - - - - - dequeue - - local p = queue.p - if not p and queue[1] then queue.p = { 1, #queue } p = queue.p end - - if not p[1] then return nil end - - local element = queue[p[1]] - queue[p[1]] = nil - - if p[1] < p[2] then p[1] = p[1] + 1 - - elseif p[4] then p[1], p[2], p[3], p[4] = p[3], p[4], nil, nil - - elseif p[5] then p[1], p[2], p[5], p[6] = p[5], p[6], nil, nil - - else p[1], p[2] = nil, nil end - - print_queue (queue, ' de '..element) - return element - end - - -function enqueue (queue, element) -- - - - - - - - - - - - - - - - - enqueue - - local p = queue.p - if not p then queue.p = {} p = queue.p end - - if p[5] then -- p3..p4 p1..p2 p5..p6 - p[6] = p[6]+1 - queue[p[6]] = element - - elseif p[3] then -- p3..p4 p1..p2 - - if p[4]+1 < p[1] then - p[4] = p[4] + 1 - queue[p[4]] = element - - else - p[5] = p[2]+1 - p[6], queue[p[5]] = p[5], element - end - - elseif p[1] then -- p1..p2 - if p[1] == 1 then - p[2] = p[2] + 1 - queue[p[2]] = element - - else - p[3], p[4], queue[1] = 1, 1, element - end - - else -- empty queue - p[1], p[2], queue[1] = 1, 1, element - end - - print_queue (queue, ' '..element) - end - - -local function test_queue () - local t = {} - enqueue (t, 1) - enqueue (t, 2) - enqueue (t, 3) - enqueue (t, 4) - enqueue (t, 5) - dequeue (t) - dequeue (t) - enqueue (t, 6) - enqueue (t, 7) - enqueue (t, 8) - enqueue (t, 9) - dequeue (t) - dequeue (t) - dequeue (t) - dequeue (t) - enqueue (t, 'a') - dequeue (t) - enqueue (t, 'b') - enqueue (t, 'c') - dequeue (t) - dequeue (t) - dequeue (t) - dequeue (t) - dequeue (t) - enqueue (t, 'd') - dequeue (t) - dequeue (t) - dequeue (t) - end - - --- test_queue () - - -function queue_len (queue) - end - - -function queue_peek (queue) - end - - --- tree manipulation ---------------------------------------- tree manipulation - - -function set (parent, ...) --- - - - - - - - - - - - - - - - - - - - - - set - - -- print ('set', ...) - - local len = select ('#', ...) - local key, value = select (len-1, ...) - local cutpoint, cutkey - - for i=1,len-2 do - - local key = select (i, ...) - local child = parent[key] - - if value == nil then - if child == nil then return - elseif next (child, next (child)) then cutpoint = nil cutkey = nil - elseif cutpoint == nil then cutpoint = parent cutkey = key end - - elseif child == nil then child = {} parent[key] = child end - - parent = child - end - - if value == nil and cutpoint then cutpoint[cutkey] = nil - else parent[key] = value return value end - end - - -function get (parent, ...) --- - - - - - - - - - - - - - - - - - - - - - get - local len = select ('#', ...) - for i=1,len do - parent = parent[select (i, ...)] - if parent == nil then break end - end - return parent - end - - --- misc ------------------------------------------------------------------ misc - - -function find (path, ...) --------------------------------------------- find - - local dirs, operators = { path }, {...} - for operator in ivalues (operators) do - if not operator (path) then break end end - - while next (dirs) do - local parent = table.remove (dirs) - for child in assert (pozix.opendir (parent)) do - if child and child ~= '.' and child ~= '..' then - local path = parent..'/'..child - if pozix.stat (path, 'is_dir') then table.insert (dirs, path) end - for operator in ivalues (operators) do - if not operator (path) then break end end - end end end end - - -function ivalues (t) ----------------------------------------------- ivalues - local i = 0 - return function () if t[i+1] then i = i + 1 return t[i] end end - end - - -function lson_encode (mixed, f, indent, indents) --------------- lson_encode - - - local capture - if not f then - capture = {} - f = function (s) append (capture, s) end - end - - indent = indent or 0 - indents = indents or {} - indents[indent] = indents[indent] or string.rep (' ', 2*indent) - - local type = type (mixed) - - if type == 'number' then f (mixed) - - else if type == 'string' then f (string.format ('%q', mixed)) - - else if type == 'table' then - f ('{') - for k,v in pairs (mixed) do - f ('\n') - f (indents[indent]) - f ('[') f (lson_encode (k)) f ('] = ') - lson_encode (v, f, indent+1, indents) - f (',') - end - f (' }') - end end end - - if capture then return table.concat (capture) end - end - - -function timestamp (time) ---------------------------------------- timestamp - return os.date ('%Y%m%d.%H%M%S', time) - end - - -function values (t) ------------------------------------------------- values - local k, v - return function () k, v = next (t, k) return v end - end |