aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMatthew Wild <mwild1@gmail.com>2008-11-18 22:41:04 +0000
committerMatthew Wild <mwild1@gmail.com>2008-11-18 22:41:04 +0000
commit801e99fcbbfd667fb3d8779782a6d9fb214d1685 (patch)
treec5ad2998f07a3f43a980ff94786693c003095c42
parentd73e81900b49267c5dbe73c536a4c2c1793b61cb (diff)
downloadprosody-801e99fcbbfd667fb3d8779782a6d9fb214d1685.tar.gz
prosody-801e99fcbbfd667fb3d8779782a6d9fb214d1685.zip
We have SRV resolving \o/
-rw-r--r--core/s2smanager.lua34
-rw-r--r--net/dns.lua795
-rw-r--r--tests/test.lua1
-rw-r--r--util/ztact.lua364
4 files changed, 1188 insertions, 6 deletions
diff --git a/core/s2smanager.lua b/core/s2smanager.lua
index 1fc2715d..d6ad2be1 100644
--- a/core/s2smanager.lua
+++ b/core/s2smanager.lua
@@ -3,7 +3,7 @@ local hosts = hosts;
local sessions = sessions;
local socket = require "socket";
local format = string.format;
-local t_insert = table.insert;
+local t_insert, t_sort = table.insert, table.sort;
local get_traceback = debug.traceback;
local tostring, pairs, ipairs, getmetatable, print, newproxy, error, tonumber
= tostring, pairs, ipairs, getmetatable, print, newproxy, error, tonumber;
@@ -24,17 +24,19 @@ local md5_hash = require "util.hashes".md5;
local dialback_secret = "This is very secret!!! Ha!";
-local srvmap = { ["gmail.com"] = "talk.google.com", ["identi.ca"] = "hampton.controlezvous.ca", ["cdr.se"] = "jabber.cdr.se" };
+local dns = require "net.dns";
module "s2smanager"
+local function compare_srv_priorities(a,b) return a.priority < b.priority or a.weight < b.weight; end
+
function send_to_host(from_host, to_host, data)
if data.name then data = tostring(data); end
local host = hosts[from_host].s2sout[to_host];
if host then
-- We have a connection to this host already
if host.type == "s2sout_unauthed" then
- host.log("debug", "trying to send over unauthed s2sout to "..to_host..", authing it now...");
+ (host.log or log)("debug", "trying to send over unauthed s2sout to "..to_host..", authing it now...");
if not host.notopen and not host.dialback_key then
host.log("debug", "dialback had not been initiated");
initiate_dialback(host);
@@ -87,11 +89,31 @@ function new_outgoing(from_host, to_host)
local conn, handler = socket.tcp()
--FIXME: Below parameters (ports/ip) are incorrect (use SRV)
- to_host = srvmap[to_host] or to_host;
+
+ local connect_host, connect_port = to_host, 5269;
+
+ local answer = dns.lookup("_xmpp-server._tcp."..to_host..".", "SRV");
+
+ if answer then
+ log("debug", to_host.." has SRV records, handling...");
+ local srv_hosts = {};
+ host_session.srv_hosts = srv_hosts;
+ for _, record in ipairs(answer) do
+ t_insert(srv_hosts, record.srv);
+ end
+ t_sort(srv_hosts, compare_srv_priorities);
+
+ local srv_choice = srv_hosts[1];
+ if srv_choice then
+ log("debug", "Best record found");
+ connect_host, connect_port = srv_choice.target or to_host, srv_choice.port or connect_port;
+ log("debug", "Best record found, will connect to %s:%d", connect_host, connect_port);
+ end
+ end
conn:settimeout(0);
- local success, err = conn:connect(to_host, 5269);
- if not success then
+ local success, err = conn:connect(connect_host, connect_port);
+ if not success and err ~= "timeout" then
log("warn", "s2s connect() failed: %s", err);
end
diff --git a/net/dns.lua b/net/dns.lua
new file mode 100644
index 00000000..a75c1bf5
--- /dev/null
+++ b/net/dns.lua
@@ -0,0 +1,795 @@
+
+
+-- public domain 20080404 lua@ztact.com
+
+
+-- todo: quick (default) header generation
+-- todo: nxdomain, error handling
+-- todo: cache results of encodeName
+
+
+-- reference: http://tools.ietf.org/html/rfc1035
+-- reference: http://tools.ietf.org/html/rfc1876 (LOC)
+
+
+require 'socket'
+local ztact = require 'util.ztact'
+
+
+local coroutine, io, math, socket, string, table =
+ coroutine, io, math, socket, string, table
+
+local ipairs, next, pairs, print, setmetatable, tostring =
+ ipairs, next, pairs, print, setmetatable, tostring
+
+local get, set = ztact.get, ztact.set
+
+
+-------------------------------------------------- module dns
+module ('dns')
+local dns = _M;
+
+
+-- dns type & class codes ------------------------------ dns type & class codes
+
+
+local append = table.insert
+
+
+local function highbyte (i) -- - - - - - - - - - - - - - - - - - - highbyte
+ return (i-(i%0x100))/0x100
+ end
+
+
+local function augment (t) -- - - - - - - - - - - - - - - - - - - - augment
+ local a = {}
+ for i,s in pairs (t) do a[i] = s a[s] = s a[string.lower (s)] = s end
+ return a
+ end
+
+
+local function encode (t) -- - - - - - - - - - - - - - - - - - - - - encode
+ local code = {}
+ for i,s in pairs (t) do
+ local word = string.char (highbyte (i), i %0x100)
+ code[i] = word
+ code[s] = word
+ code[string.lower (s)] = word
+ end
+ return code
+ end
+
+
+dns.types = {
+ 'A', 'NS', 'MD', 'MF', 'CNAME', 'SOA', 'MB', 'MG', 'MR', 'NULL', 'WKS',
+ 'PTR', 'HINFO', 'MINFO', 'MX', 'TXT',
+ [ 28] = 'AAAA', [ 29] = 'LOC', [ 33] = 'SRV',
+ [252] = 'AXFR', [253] = 'MAILB', [254] = 'MAILA', [255] = '*' }
+
+
+dns.classes = { 'IN', 'CS', 'CH', 'HS', [255] = '*' }
+
+
+dns.type = augment (dns.types)
+dns.class = augment (dns.classes)
+dns.typecode = encode (dns.types)
+dns.classcode = encode (dns.classes)
+
+
+
+local function standardize (qname, qtype, qclass) -- - - - - - - standardize
+ if string.byte (qname, -1) ~= 0x2E then qname = qname..'.' end
+ qname = string.lower (qname)
+ return qname, dns.type[qtype or 'A'], dns.class[qclass or 'IN']
+ end
+
+
+local function prune (rrs, time, soft) -- - - - - - - - - - - - - - - prune
+
+ time = time or socket.gettime ()
+ for i,rr in pairs (rrs) do
+
+ if rr.tod then
+ -- rr.tod = rr.tod - 50 -- accelerated decripitude
+ rr.ttl = math.floor (rr.tod - time)
+ if rr.ttl <= 0 then rrs[i] = nil end
+
+ elseif soft == 'soft' then -- What is this? I forget!
+ assert (rr.ttl == 0)
+ rrs[i] = nil
+ end end end
+
+
+-- metatables & co. ------------------------------------------ metatables & co.
+
+
+local resolver = {}
+resolver.__index = resolver
+
+
+local SRV_tostring
+
+
+local rr_metatable = {} -- - - - - - - - - - - - - - - - - - - rr_metatable
+function rr_metatable.__tostring (rr)
+ local s0 = string.format (
+ '%2s %-5s %6i %-28s', rr.class, rr.type, rr.ttl, rr.name )
+ local s1 = ''
+ if rr.type == 'A' then s1 = ' '..rr.a
+ elseif rr.type == 'MX' then
+ s1 = string.format (' %2i %s', rr.pref, rr.mx)
+ elseif rr.type == 'CNAME' then s1 = ' '..rr.cname
+ elseif rr.type == 'LOC' then s1 = ' '..resolver.LOC_tostring (rr)
+ elseif rr.type == 'NS' then s1 = ' '..rr.ns
+ elseif rr.type == 'SRV' then s1 = ' '..SRV_tostring (rr)
+ elseif rr.type == 'TXT' then s1 = ' '..rr.txt
+ else s1 = ' <UNKNOWN RDATA TYPE>' end
+ return s0..s1
+ end
+
+
+local rrs_metatable = {} -- - - - - - - - - - - - - - - - - - rrs_metatable
+function rrs_metatable.__tostring (rrs)
+ t = {}
+ for i,rr in pairs (rrs) do append (t, tostring (rr)..'\n') end
+ return table.concat (t)
+ end
+
+
+local cache_metatable = {} -- - - - - - - - - - - - - - - - cache_metatable
+function cache_metatable.__tostring (cache)
+ local time = socket.gettime ()
+ local t = {}
+ for class,types in pairs (cache) do
+ for type,names in pairs (types) do
+ for name,rrs in pairs (names) do
+ prune (rrs, time)
+ append (t, tostring (rrs)) end end end
+ return table.concat (t)
+ end
+
+
+function resolver:new () -- - - - - - - - - - - - - - - - - - - - - resolver
+ local r = { active = {}, cache = {}, unsorted = {} }
+ setmetatable (r, resolver)
+ setmetatable (r.cache, cache_metatable)
+ setmetatable (r.unsorted, { __mode = 'kv' })
+ return r
+ end
+
+
+-- packet layer -------------------------------------------------- packet layer
+
+
+function dns.random (...) -- - - - - - - - - - - - - - - - - - - dns.random
+ math.randomseed (10000*socket.gettime ())
+ dns.random = math.random
+ return dns.random (...)
+ end
+
+
+local function encodeHeader (o) -- - - - - - - - - - - - - - - encodeHeader
+
+ o = o or {}
+
+ o.id = o.id or -- 16b (random) id
+ dns.random (0, 0xffff)
+
+ o.rd = o.rd or 1 -- 1b 1 recursion desired
+ o.tc = o.tc or 0 -- 1b 1 truncated response
+ o.aa = o.aa or 0 -- 1b 1 authoritative response
+ o.opcode = o.opcode or 0 -- 4b 0 query
+ -- 1 inverse query
+ -- 2 server status request
+ -- 3-15 reserved
+ o.qr = o.qr or 0 -- 1b 0 query, 1 response
+
+ o.rcode = o.rcode or 0 -- 4b 0 no error
+ -- 1 format error
+ -- 2 server failure
+ -- 3 name error
+ -- 4 not implemented
+ -- 5 refused
+ -- 6-15 reserved
+ o.z = o.z or 0 -- 3b 0 resvered
+ o.ra = o.ra or 0 -- 1b 1 recursion available
+
+ o.qdcount = o.qdcount or 1 -- 16b number of question RRs
+ o.ancount = o.ancount or 0 -- 16b number of answers RRs
+ o.nscount = o.nscount or 0 -- 16b number of nameservers RRs
+ o.arcount = o.arcount or 0 -- 16b number of additional RRs
+
+ -- string.char() rounds, so prevent roundup with -0.4999
+ local header = string.char (
+ highbyte (o.id), o.id %0x100,
+ o.rd + 2*o.tc + 4*o.aa + 8*o.opcode + 128*o.qr,
+ o.rcode + 16*o.z + 128*o.ra,
+ highbyte (o.qdcount), o.qdcount %0x100,
+ highbyte (o.ancount), o.ancount %0x100,
+ highbyte (o.nscount), o.nscount %0x100,
+ highbyte (o.arcount), o.arcount %0x100 )
+
+ return header, o.id
+ end
+
+
+local function encodeName (name) -- - - - - - - - - - - - - - - - encodeName
+ local t = {}
+ for part in string.gmatch (name, '[^.]+') do
+ append (t, string.char (string.len (part)))
+ append (t, part)
+ end
+ append (t, string.char (0))
+ return table.concat (t)
+ end
+
+
+local function encodeQuestion (qname, qtype, qclass) -- - - - encodeQuestion
+ qname = encodeName (qname)
+ qtype = dns.typecode[qtype or 'a']
+ qclass = dns.classcode[qclass or 'in']
+ return qname..qtype..qclass;
+ end
+
+
+function resolver:byte (len) -- - - - - - - - - - - - - - - - - - - - - byte
+ len = len or 1
+ local offset = self.offset
+ local last = offset + len - 1
+ if last > #self.packet then
+ error (string.format ('out of bounds: %i>%i', last, #self.packet)) end
+ self.offset = offset + len
+ return string.byte (self.packet, offset, last)
+ end
+
+
+function resolver:word () -- - - - - - - - - - - - - - - - - - - - - - word
+ local b1, b2 = self:byte (2)
+ return 0x100*b1 + b2
+ end
+
+
+function resolver:dword () -- - - - - - - - - - - - - - - - - - - - - dword
+ local b1, b2, b3, b4 = self:byte (4)
+ -- print ('dword', b1, b2, b3, b4)
+ return 0x1000000*b1 + 0x10000*b2 + 0x100*b3 + b4
+ end
+
+
+function resolver:sub (len) -- - - - - - - - - - - - - - - - - - - - - - sub
+ len = len or 1
+ local s = string.sub (self.packet, self.offset, self.offset + len - 1)
+ self.offset = self.offset + len
+ return s
+ end
+
+
+function resolver:header (force) -- - - - - - - - - - - - - - - - - - header
+
+ local id = self:word ()
+ -- print (string.format (':header id %x', id))
+ if not self.active[id] and not force then return nil end
+
+ local h = { id = id }
+
+ local b1, b2 = self:byte (2)
+
+ h.rd = b1 %2
+ h.tc = b1 /2%2
+ h.aa = b1 /4%2
+ h.opcode = b1 /8%16
+ h.qr = b1 /128
+
+ h.rcode = b2 %16
+ h.z = b2 /16%8
+ h.ra = b2 /128
+
+ h.qdcount = self:word ()
+ h.ancount = self:word ()
+ h.nscount = self:word ()
+ h.arcount = self:word ()
+
+ for k,v in pairs (h) do h[k] = v-v%1 end
+
+ return h
+ end
+
+
+function resolver:name () -- - - - - - - - - - - - - - - - - - - - - - name
+ local remember, pointers = nil, 0
+ local len = self:byte ()
+ local n = {}
+ while len > 0 do
+ if len >= 0xc0 then -- name is "compressed"
+ pointers = pointers + 1
+ if pointers >= 20 then error ('dns error: 20 pointers') end
+ local offset = ((len-0xc0)*0x100) + self:byte ()
+ remember = remember or self.offset
+ self.offset = offset + 1 -- +1 for lua
+ else -- name is not compressed
+ append (n, self:sub (len)..'.')
+ end
+ len = self:byte ()
+ end
+ self.offset = remember or self.offset
+ return table.concat (n)
+ end
+
+
+function resolver:question () -- - - - - - - - - - - - - - - - - - question
+ local q = {}
+ q.name = self:name ()
+ q.type = dns.type[self:word ()]
+ q.class = dns.type[self:word ()]
+ return q
+ end
+
+
+function resolver:A (rr) -- - - - - - - - - - - - - - - - - - - - - - - - A
+ local b1, b2, b3, b4 = self:byte (4)
+ rr.a = string.format ('%i.%i.%i.%i', b1, b2, b3, b4)
+ end
+
+
+function resolver:CNAME (rr) -- - - - - - - - - - - - - - - - - - - - CNAME
+ rr.cname = self:name ()
+ end
+
+
+function resolver:MX (rr) -- - - - - - - - - - - - - - - - - - - - - - - MX
+ rr.pref = self:word ()
+ rr.mx = self:name ()
+ end
+
+
+function resolver:LOC_nibble_power () -- - - - - - - - - - LOC_nibble_power
+ local b = self:byte ()
+ -- print ('nibbles', ((b-(b%0x10))/0x10), (b%0x10))
+ return ((b-(b%0x10))/0x10) * (10^(b%0x10))
+ end
+
+
+function resolver:LOC (rr) -- - - - - - - - - - - - - - - - - - - - - - LOC
+ rr.version = self:byte ()
+ if rr.version == 0 then
+ rr.loc = rr.loc or {}
+ rr.loc.size = self:LOC_nibble_power ()
+ rr.loc.horiz_pre = self:LOC_nibble_power ()
+ rr.loc.vert_pre = self:LOC_nibble_power ()
+ rr.loc.latitude = self:dword ()
+ rr.loc.longitude = self:dword ()
+ rr.loc.altitude = self:dword ()
+ end end
+
+
+local function LOC_tostring_degrees (f, pos, neg) -- - - - - - - - - - - - -
+ f = f - 0x80000000
+ if f < 0 then pos = neg f = -f end
+ local deg, min, msec
+ msec = f%60000
+ f = (f-msec)/60000
+ min = f%60
+ deg = (f-min)/60
+ return string.format ('%3d %2d %2.3f %s', deg, min, msec/1000, pos)
+ end
+
+
+function resolver.LOC_tostring (rr) -- - - - - - - - - - - - - LOC_tostring
+
+ local t = {}
+
+ --[[
+ for k,name in pairs { 'size', 'horiz_pre', 'vert_pre',
+ 'latitude', 'longitude', 'altitude' } do
+ append (t, string.format ('%4s%-10s: %12.0f\n', '', name, rr.loc[name]))
+ end
+ --]]
+
+ append ( t, string.format (
+ '%s %s %.2fm %.2fm %.2fm %.2fm',
+ LOC_tostring_degrees (rr.loc.latitude, 'N', 'S'),
+ LOC_tostring_degrees (rr.loc.longitude, 'E', 'W'),
+ (rr.loc.altitude - 10000000) / 100,
+ rr.loc.size / 100,
+ rr.loc.horiz_pre / 100,
+ rr.loc.vert_pre / 100 ) )
+
+ return table.concat (t)
+ end
+
+
+function resolver:NS (rr) -- - - - - - - - - - - - - - - - - - - - - - - NS
+ rr.ns = self:name ()
+ end
+
+
+function resolver:SOA (rr) -- - - - - - - - - - - - - - - - - - - - - - SOA
+ end
+
+
+function resolver:SRV (rr) -- - - - - - - - - - - - - - - - - - - - - - SRV
+ rr.srv = {}
+ rr.srv.priority = self:word ()
+ rr.srv.weight = self:word ()
+ rr.srv.port = self:word ()
+ rr.srv.target = self:name ()
+ end
+
+
+function SRV_tostring (rr) -- - - - - - - - - - - - - - - - - - SRV_tostring
+ local s = rr.srv
+ return string.format ( '%5d %5d %5d %s',
+ s.priority, s.weight, s.port, s.target )
+ end
+
+
+function resolver:TXT (rr) -- - - - - - - - - - - - - - - - - - - - - - TXT
+ rr.txt = self:sub (rr.rdlength)
+ end
+
+
+function resolver:rr () -- - - - - - - - - - - - - - - - - - - - - - - - rr
+ local rr = {}
+ setmetatable (rr, rr_metatable)
+ rr.name = self:name (self)
+ rr.type = dns.type[self:word ()] or rr.type
+ rr.class = dns.class[self:word ()] or rr.class
+ rr.ttl = 0x10000*self:word () + self:word ()
+ rr.rdlength = self:word ()
+
+ if rr.ttl == 0 then -- pass
+ else rr.tod = self.time + rr.ttl end
+
+ local remember = self.offset
+ local rr_parser = self[dns.type[rr.type]]
+ if rr_parser then rr_parser (self, rr) end
+ self.offset = remember
+ rr.rdata = self:sub (rr.rdlength)
+ return rr
+ end
+
+
+function resolver:rrs (count) -- - - - - - - - - - - - - - - - - - - - - rrs
+ local rrs = {}
+ for i = 1,count do append (rrs, self:rr ()) end
+ return rrs
+ end
+
+
+function resolver:decode (packet, force) -- - - - - - - - - - - - - - decode
+
+ self.packet, self.offset = packet, 1
+ local header = self:header (force)
+ if not header then return nil end
+ local response = { header = header }
+
+ response.question = {}
+ local offset = self.offset
+ for i = 1,response.header.qdcount do
+ append (response.question, self:question ()) end
+ response.question.raw = string.sub (self.packet, offset, self.offset - 1)
+
+ if not force then
+ if not self.active[response.header.id] or
+ not self.active[response.header.id][response.question.raw] then
+ return nil end end
+
+ response.answer = self:rrs (response.header.ancount)
+ response.authority = self:rrs (response.header.nscount)
+ response.additional = self:rrs (response.header.arcount)
+
+ return response
+ end
+
+
+-- socket layer -------------------------------------------------- socket layer
+
+
+resolver.delays = { 1, 3, 11, 45 }
+
+
+function resolver:addnameserver (address) -- - - - - - - - - - addnameserver
+ self.server = self.server or {}
+ append (self.server, address)
+ end
+
+
+function resolver:setnameserver (address) -- - - - - - - - - - setnameserver
+ self.server = {}
+ self:addnameserver (address)
+ end
+
+
+function resolver:adddefaultnameservers () -- - - - - adddefaultnameservers
+ for line in io.lines ('/etc/resolv.conf') do
+ address = string.match (line, 'nameserver%s+(%d+%.%d+%.%d+%.%d+)')
+ if address then self:addnameserver (address) end
+ end end
+
+
+function resolver:getsocket (servernum) -- - - - - - - - - - - - - getsocket
+
+ self.socket = self.socket or {}
+ self.socketset = self.socketset or {}
+
+ local sock = self.socket[servernum]
+ if sock then return sock end
+
+ sock = socket.udp ()
+ if self.socket_wrapper then sock = self.socket_wrapper (sock) end
+ sock:settimeout (0)
+ -- todo: attempt to use a random port, fallback to 0
+ sock:setsockname ('*', 0)
+ sock:setpeername (self.server[servernum], 53)
+ self.socket[servernum] = sock
+ self.socketset[sock] = sock
+ return sock
+ end
+
+
+function resolver:socket_wrapper_set (func) -- - - - - - - socket_wrapper_set
+ self.socket_wrapper = func
+ end
+
+
+function resolver:closeall () -- - - - - - - - - - - - - - - - - - closeall
+ for i,sock in ipairs (self.socket) do self.socket[i]:close () end
+ self.socket = {}
+ end
+
+
+function resolver:remember (rr, type) -- - - - - - - - - - - - - - remember
+
+ -- print ('remember', type, rr.class, rr.type, rr.name)
+
+ if type ~= '*' then
+ type = rr.type
+ local all = get (self.cache, rr.class, '*', rr.name)
+ -- print ('remember all', all)
+ if all then append (all, rr) end
+ end
+
+ self.cache = self.cache or setmetatable ({}, cache_metatable)
+ local rrs = get (self.cache, rr.class, type, rr.name) or
+ set (self.cache, rr.class, type, rr.name, setmetatable ({}, rrs_metatable))
+ append (rrs, rr)
+
+ if type == 'MX' then self.unsorted[rrs] = true end
+ end
+
+
+local function comp_mx (a, b) -- - - - - - - - - - - - - - - - - - - comp_mx
+ return (a.pref == b.pref) and (a.mx < b.mx) or (a.pref < b.pref)
+ end
+
+
+function resolver:peek (qname, qtype, qclass) -- - - - - - - - - - - - peek
+ qname, qtype, qclass = standardize (qname, qtype, qclass)
+ local rrs = get (self.cache, qclass, qtype, qname)
+ if not rrs then return nil end
+ if prune (rrs, socket.gettime ()) and qtype == '*' or not next (rrs) then
+ set (self.cache, qclass, qtype, qname, nil) return nil end
+ if self.unsorted[rrs] then table.sort (rrs, comp_mx) end
+ return rrs
+ end
+
+
+function resolver:purge (soft) -- - - - - - - - - - - - - - - - - - - purge
+ if soft == 'soft' then
+ self.time = socket.gettime ()
+ for class,types in pairs (self.cache or {}) do
+ for type,names in pairs (types) do
+ for name,rrs in pairs (names) do
+ prune (rrs, time, 'soft')
+ end end end
+ else self.cache = {} end
+ end
+
+
+function resolver:query (qname, qtype, qclass) -- - - - - - - - - - -- query
+
+ qname, qtype, qclass = standardize (qname, qtype, qclass)
+
+ if not self.server then self:adddefaultnameservers () end
+
+ local question = question or encodeQuestion (qname, qtype, qclass)
+ local peek = self:peek (qname, qtype, qclass)
+ if peek then return peek end
+
+ local header, id = encodeHeader ()
+ -- print ('query id', id, qclass, qtype, qname)
+ local o = { packet = header..question,
+ server = 1,
+ delay = 1,
+ retry = socket.gettime () + self.delays[1] }
+ self:getsocket (o.server):send (o.packet)
+
+ -- remember the query
+ self.active[id] = self.active[id] or {}
+ self.active[id][question] = o
+
+ -- remember which coroutine wants the answer
+ local co = coroutine.running ()
+ if co then
+ set (self.wanted, qclass, qtype, qname, co, true)
+ set (self.yielded, co, qclass, qtype, qname, true)
+ end end
+
+
+function resolver:receive (rset) -- - - - - - - - - - - - - - - - - receive
+
+ -- print 'receive' print (self.socket)
+ self.time = socket.gettime ()
+ rset = rset or self.socket
+
+ local response
+ for i,sock in pairs (rset) do
+
+ if self.socketset[sock] then
+ local packet = sock:receive ()
+ if packet then
+
+ response = self:decode (packet)
+ if response then
+ -- print 'received response'
+ -- self.print (response)
+
+ for i,section in pairs { 'answer', 'authority', 'additional' } do
+ for j,rr in pairs (response[section]) do
+ self:remember (rr, response.question[1].type) end end
+
+ -- retire the query
+ local queries = self.active[response.header.id]
+ if queries[response.question.raw] then
+ queries[response.question.raw] = nil end
+ if not next (queries) then self.active[response.header.id] = nil end
+ if not next (self.active) then self:closeall () end
+
+ -- was the query on the wanted list?
+ local q = response.question
+ local cos = get (self.wanted, q.class, q.type, q.name)
+ if cos then
+ for co in pairs (cos) do
+ set (self.yielded, co, q.class, q.type, q.name, nil)
+ if not self.yielded[co] then coroutine.resume (co) end
+ end
+ set (self.wanted, q.class, q.type, q.name, nil)
+ end end end end end
+
+ return response
+ end
+
+
+function resolver:pulse () -- - - - - - - - - - - - - - - - - - - - - pulse
+
+ -- print ':pulse'
+ while self:receive () do end
+ if not next (self.active) then return nil end
+
+ self.time = socket.gettime ()
+ for id,queries in pairs (self.active) do
+ for question,o in pairs (queries) do
+ if self.time >= o.retry then
+
+ o.server = o.server + 1
+ if o.server > #self.server then
+ o.server = 1
+ o.delay = o.delay + 1
+ end
+
+ if o.delay > #self.delays then
+ print ('timeout')
+ queries[question] = nil
+ if not next (queries) then self.active[id] = nil end
+ if not next (self.active) then return nil end
+ else
+ -- print ('retry', o.server, o.delay)
+ self.socket[o.server]:send (o.packet)
+ o.retry = self.time + self.delays[o.delay]
+ end end end end
+
+ if next (self.active) then return true end
+ return nil
+ end
+
+
+function resolver:lookup (qname, qtype, qclass) -- - - - - - - - - - lookup
+ self:query (qname, qtype, qclass)
+ while self:pulse () do socket.select (self.socket, nil, 4) end
+ -- print (self.cache)
+ return self:peek (qname, qtype, qclass)
+ end
+
+
+-- print ---------------------------------------------------------------- print
+
+
+local hints = { -- - - - - - - - - - - - - - - - - - - - - - - - - - - hints
+ qr = { [0]='query', 'response' },
+ opcode = { [0]='query', 'inverse query', 'server status request' },
+ aa = { [0]='non-authoritative', 'authoritative' },
+ tc = { [0]='complete', 'truncated' },
+ rd = { [0]='recursion not desired', 'recursion desired' },
+ ra = { [0]='recursion not available', 'recursion available' },
+ z = { [0]='(reserved)' },
+ rcode = { [0]='no error', 'format error', 'server failure', 'name error',
+ 'not implemented' },
+
+ type = dns.type,
+ class = dns.class, }
+
+
+local function hint (p, s) -- - - - - - - - - - - - - - - - - - - - - - hint
+ return (hints[s] and hints[s][p[s]]) or '' end
+
+
+function resolver.print (response) -- - - - - - - - - - - - - resolver.print
+
+ for s,s in pairs { 'id', 'qr', 'opcode', 'aa', 'tc', 'rd', 'ra', 'z',
+ 'rcode', 'qdcount', 'ancount', 'nscount', 'arcount' } do
+ print ( string.format ('%-30s', 'header.'..s),
+ response.header[s], hint (response.header, s) )
+ end
+
+ for i,question in ipairs (response.question) do
+ print (string.format ('question[%i].name ', i), question.name)
+ print (string.format ('question[%i].type ', i), question.type)
+ print (string.format ('question[%i].class ', i), question.class)
+ end
+
+ local common = { name=1, type=1, class=1, ttl=1, rdlength=1, rdata=1 }
+ local tmp
+ for s,s in pairs {'answer', 'authority', 'additional'} do
+ for i,rr in pairs (response[s]) do
+ for j,t in pairs { 'name', 'type', 'class', 'ttl', 'rdlength' } do
+ tmp = string.format ('%s[%i].%s', s, i, t)
+ print (string.format ('%-30s', tmp), rr[t], hint (rr, t))
+ end
+ for j,t in pairs (rr) do
+ if not common[j] then
+ tmp = string.format ('%s[%i].%s', s, i, j)
+ print (string.format ('%-30s %s', tmp, t))
+ end end end end end
+
+
+-- module api ------------------------------------------------------ module api
+
+
+local function resolve (func, ...) -- - - - - - - - - - - - - - resolver_get
+ dns._resolver = dns._resolver or dns.resolver ()
+ return func (dns._resolver, ...)
+ end
+
+
+function dns.resolver () -- - - - - - - - - - - - - - - - - - - - - resolver
+
+ -- this function seems to be redundant with resolver.new ()
+
+ r = { active = {}, cache = {}, unsorted = {}, wanted = {}, yielded = {} }
+ setmetatable (r, resolver)
+ setmetatable (r.cache, cache_metatable)
+ setmetatable (r.unsorted, { __mode = 'kv' })
+ return r
+ end
+
+
+function dns.lookup (...) -- - - - - - - - - - - - - - - - - - - - - lookup
+ return resolve (resolver.lookup, ...) end
+
+
+function dns.purge (...) -- - - - - - - - - - - - - - - - - - - - - - purge
+ return resolve (resolver.purge, ...) end
+
+function dns.peek (...) -- - - - - - - - - - - - - - - - - - - - - - - peek
+ return resolve (resolver.peek, ...) end
+
+
+function dns.query (...) -- - - - - - - - - - - - - - - - - - - - - - query
+ return resolve (resolver.query, ...) end
+
+
+function dns:socket_wrapper_set (...) -- - - - - - - - - socket_wrapper_set
+ return resolve (resolver.socket_wrapper_set, ...) end
+
+
+return dns
diff --git a/tests/test.lua b/tests/test.lua
index c028e859..aa0275d4 100644
--- a/tests/test.lua
+++ b/tests/test.lua
@@ -83,3 +83,4 @@ end
dotest "util.jid"
dotest "core.stanza_router"
+dotest "core.s2smanager"
diff --git a/util/ztact.lua b/util/ztact.lua
new file mode 100644
index 00000000..15bcffad
--- /dev/null
+++ b/util/ztact.lua
@@ -0,0 +1,364 @@
+
+
+-- public domain 20080410 lua@ztact.com
+
+
+pcall (require, 'lfs') -- lfs may not be installed/necessary.
+pcall (require, 'pozix') -- pozix may not be installed/necessary.
+
+
+local getfenv, ipairs, next, pairs, pcall, require, select, tostring, type =
+ getfenv, ipairs, next, pairs, pcall, require, select, tostring, type
+local unpack, xpcall =
+ unpack, xpcall
+
+local io, lfs, os, string, table, pozix = io, lfs, os, string, table, pozix
+
+local assert, print = assert, print
+
+local error = error
+
+
+module ((...) or 'ztact') ------------------------------------- module ztact
+
+
+-- dir -------------------------------------------------------------------- dir
+
+
+function dir (path) -- - - - - - - - - - - - - - - - - - - - - - - - - - dir
+ local it = lfs.dir (path)
+ return function ()
+ repeat
+ local dir = it ()
+ if dir ~= '.' and dir ~= '..' then return dir end
+ until not dir
+ end end
+
+
+function is_file (path) -- - - - - - - - - - - - - - - - - - is_file (path)
+ local mode = lfs.attributes (path, 'mode')
+ return mode == 'file' and path
+ end
+
+
+-- network byte ordering -------------------------------- network byte ordering
+
+
+function htons (word) -- - - - - - - - - - - - - - - - - - - - - - - - htons
+ return (word-word%0x100)/0x100, word%0x100
+ end
+
+
+-- pcall2 -------------------------------------------------------------- pcall2
+
+
+getfenv ().pcall = pcall -- store the original pcall as ztact.pcall
+
+
+local argc, argv, errorhandler, pcall2_f
+
+
+local function _pcall2 () -- - - - - - - - - - - - - - - - - - - - - _pcall2
+ local tmpv = argv
+ argv = nil
+ return pcall2_f (unpack (tmpv, 1, argc))
+ end
+
+
+function seterrorhandler (func) -- - - - - - - - - - - - - - seterrorhandler
+ errorhandler = func
+ end
+
+
+function pcall2 (f, ...) -- - - - - - - - - - - - - - - - - - - - - - pcall2
+
+ pcall2_f = f
+ argc = select ('#', ...)
+ argv = { ... }
+
+ if not errorhandler then
+ local debug = require ('debug')
+ errorhandler = debug.traceback
+ end
+
+ return xpcall (_pcall2, errorhandler)
+ end
+
+
+function append (t, ...) -- - - - - - - - - - - - - - - - - - - - - - append
+ local insert = table.insert
+ for i,v in ipairs {...} do
+ insert (t, v)
+ end end
+
+
+function print_r (d, indent) -- - - - - - - - - - - - - - - - - - - print_r
+ local rep = string.rep (' ', indent or 0)
+ if type (d) == 'table' then
+ for k,v in pairs (d) do
+ if type (v) == 'table' then
+ io.write (rep, k, '\n')
+ print_r (v, (indent or 0) + 1)
+ else io.write (rep, k, ' = ', tostring (v), '\n') end
+ end
+ else io.write (d, '\n') end
+ end
+
+
+function tohex (s) -- - - - - - - - - - - - - - - - - - - - - - - - - tohex
+ return string.format (string.rep ('%02x ', #s), string.byte (s, 1, #s))
+ end
+
+
+function tostring_r (d, indent, tab0) -- - - - - - - - - - - - - tostring_r
+
+ tab1 = tab0 or {}
+ local rep = string.rep (' ', indent or 0)
+ if type (d) == 'table' then
+ for k,v in pairs (d) do
+ if type (v) == 'table' then
+ append (tab1, rep, k, '\n')
+ tostring_r (v, (indent or 0) + 1, tab1)
+ else append (tab1, rep, k, ' = ', tostring (v), '\n') end
+ end
+ else append (tab1, d, '\n') end
+
+ if not tab0 then return table.concat (tab1) end
+ end
+
+
+-- queue manipulation -------------------------------------- queue manipulation
+
+
+-- Possible queue states. 1 (i.e. queue.p[1]) is head of queue.
+--
+-- 1..2
+-- 3..4 1..2
+-- 3..4 1..2 5..6
+-- 1..2 5..6
+-- 1..2
+
+
+local function print_queue (queue, ...) -- - - - - - - - - - - - print_queue
+ for i=1,10 do io.write ((queue[i] or '.')..' ') end
+ io.write ('\t')
+ for i=1,6 do io.write ((queue.p[i] or '.')..' ') end
+ print (...)
+ end
+
+
+function dequeue (queue) -- - - - - - - - - - - - - - - - - - - - - dequeue
+
+ local p = queue.p
+ if not p and queue[1] then queue.p = { 1, #queue } p = queue.p end
+
+ if not p[1] then return nil end
+
+ local element = queue[p[1]]
+ queue[p[1]] = nil
+
+ if p[1] < p[2] then p[1] = p[1] + 1
+
+ elseif p[4] then p[1], p[2], p[3], p[4] = p[3], p[4], nil, nil
+
+ elseif p[5] then p[1], p[2], p[5], p[6] = p[5], p[6], nil, nil
+
+ else p[1], p[2] = nil, nil end
+
+ print_queue (queue, ' de '..element)
+ return element
+ end
+
+
+function enqueue (queue, element) -- - - - - - - - - - - - - - - - - enqueue
+
+ local p = queue.p
+ if not p then queue.p = {} p = queue.p end
+
+ if p[5] then -- p3..p4 p1..p2 p5..p6
+ p[6] = p[6]+1
+ queue[p[6]] = element
+
+ elseif p[3] then -- p3..p4 p1..p2
+
+ if p[4]+1 < p[1] then
+ p[4] = p[4] + 1
+ queue[p[4]] = element
+
+ else
+ p[5] = p[2]+1
+ p[6], queue[p[5]] = p[5], element
+ end
+
+ elseif p[1] then -- p1..p2
+ if p[1] == 1 then
+ p[2] = p[2] + 1
+ queue[p[2]] = element
+
+ else
+ p[3], p[4], queue[1] = 1, 1, element
+ end
+
+ else -- empty queue
+ p[1], p[2], queue[1] = 1, 1, element
+ end
+
+ print_queue (queue, ' '..element)
+ end
+
+
+local function test_queue ()
+ t = {}
+ enqueue (t, 1)
+ enqueue (t, 2)
+ enqueue (t, 3)
+ enqueue (t, 4)
+ enqueue (t, 5)
+ dequeue (t)
+ dequeue (t)
+ enqueue (t, 6)
+ enqueue (t, 7)
+ enqueue (t, 8)
+ enqueue (t, 9)
+ dequeue (t)
+ dequeue (t)
+ dequeue (t)
+ dequeue (t)
+ enqueue (t, 'a')
+ dequeue (t)
+ enqueue (t, 'b')
+ enqueue (t, 'c')
+ dequeue (t)
+ dequeue (t)
+ dequeue (t)
+ dequeue (t)
+ dequeue (t)
+ enqueue (t, 'd')
+ dequeue (t)
+ dequeue (t)
+ dequeue (t)
+ end
+
+
+-- test_queue ()
+
+
+function queue_len (queue)
+ end
+
+
+function queue_peek (queue)
+ end
+
+
+-- tree manipulation ---------------------------------------- tree manipulation
+
+
+function set (parent, ...) --- - - - - - - - - - - - - - - - - - - - - - set
+
+ -- print ('set', ...)
+
+ local len = select ('#', ...)
+ local key, value = select (len-1, ...)
+ local cutpoint, cutkey
+
+ for i=1,len-2 do
+
+ local key = select (i, ...)
+ local child = parent[key]
+
+ if value == nil then
+ if child == nil then return
+ elseif next (child, next (child)) then cutpoint = nil cutkey = nil
+ elseif cutpoint == nil then cutpoint = parent cutkey = key end
+
+ elseif child == nil then child = {} parent[key] = child end
+
+ parent = child
+ end
+
+ if value == nil and cutpoint then cutpoint[cutkey] = nil
+ else parent[key] = value return value end
+ end
+
+
+function get (parent, ...) --- - - - - - - - - - - - - - - - - - - - - - get
+ local len = select ('#', ...)
+ for i=1,len do
+ parent = parent[select (i, ...)]
+ if parent == nil then break end
+ end
+ return parent
+ end
+
+
+-- misc ------------------------------------------------------------------ misc
+
+
+function find (path, ...) --------------------------------------------- find
+
+ local dirs, operators = { path }, {...}
+ for operator in ivalues (operators) do
+ if not operator (path) then break end end
+
+ while next (dirs) do
+ local parent = table.remove (dirs)
+ for child in assert (pozix.opendir (parent)) do
+ if child and child ~= '.' and child ~= '..' then
+ local path = parent..'/'..child
+ if pozix.stat (path, 'is_dir') then table.insert (dirs, path) end
+ for operator in ivalues (operators) do
+ if not operator (path) then break end end
+ end end end end
+
+
+function ivalues (t) ----------------------------------------------- ivalues
+ local i = 0
+ return function () if t[i+1] then i = i + 1 return t[i] end end
+ end
+
+
+function lson_encode (mixed, f, indent, indents) --------------- lson_encode
+
+
+ local capture
+ if not f then
+ capture = {}
+ f = function (s) append (capture, s) end
+ end
+
+ indent = indent or 0
+ indents = indents or {}
+ indents[indent] = indents[indent] or string.rep (' ', 2*indent)
+
+ local type = type (mixed)
+
+ if type == 'number' then f (mixed)
+
+ else if type == 'string' then f (string.format ('%q', mixed))
+
+ else if type == 'table' then
+ f ('{')
+ for k,v in pairs (mixed) do
+ f ('\n')
+ f (indents[indent])
+ f ('[') f (lson_encode (k)) f ('] = ')
+ lson_encode (v, f, indent+1, indents)
+ f (',')
+ end
+ f (' }')
+ end end end
+
+ if capture then return table.concat (capture) end
+ end
+
+
+function timestamp (time) ---------------------------------------- timestamp
+ return os.date ('%Y%m%d.%H%M%S', time)
+ end
+
+
+function values (t) ------------------------------------------------- values
+ local k, v
+ return function () k, v = next (t, k) return v end
+ end