diff options
author | Matthew Wild <mwild1@gmail.com> | 2009-03-04 12:58:56 +0000 |
---|---|---|
committer | Matthew Wild <mwild1@gmail.com> | 2009-03-04 12:58:56 +0000 |
commit | 775b18bd76bb0434214b0a92e1ec5d31d252cc26 (patch) | |
tree | 541105895a3dd8fb450e2ac8ce6300ce07e79697 /net/dns.lua | |
parent | 6a5be713088b28e3fd2cbb46ba38b9d13c26090d (diff) | |
download | prosody-775b18bd76bb0434214b0a92e1ec5d31d252cc26.tar.gz prosody-775b18bd76bb0434214b0a92e1ec5d31d252cc26.zip |
net.dns: Add methods necessary for allowing non-blocking DNS lookups
Diffstat (limited to 'net/dns.lua')
-rw-r--r-- | net/dns.lua | 94 |
1 files changed, 72 insertions, 22 deletions
diff --git a/net/dns.lua b/net/dns.lua index c86a0886..70d731b7 100644 --- a/net/dns.lua +++ b/net/dns.lua @@ -16,7 +16,7 @@ require 'socket' local ztact = require 'util.ztact' - +local require = require local coroutine, io, math, socket, string, table = coroutine, io, math, socket, string, table @@ -253,7 +253,7 @@ function resolver:word () -- - - - - - - - - - - - - - - - - - - - - - word function resolver:dword () -- - - - - - - - - - - - - - - - - - - - - dword local b1, b2, b3, b4 = self:byte (4) - -- print ('dword', b1, b2, b3, b4) + --print ('dword', b1, b2, b3, b4) return 0x1000000*b1 + 0x10000*b2 + 0x100*b3 + b4 end @@ -269,7 +269,7 @@ function resolver:sub (len) -- - - - - - - - - - - - - - - - - - - - - - sub function resolver:header (force) -- - - - - - - - - - - - - - - - - - header local id = self:word () - -- print (string.format (':header id %x', id)) + --print (string.format (':header id %x', id)) if not self.active[id] and not force then return nil end local h = { id = id } @@ -322,7 +322,7 @@ function resolver:question () -- - - - - - - - - - - - - - - - - - question local q = {} q.name = self:name () q.type = dns.type[self:word ()] - q.class = dns.type[self:word ()] + q.class = dns.class[self:word ()] return q end @@ -346,7 +346,7 @@ function resolver:MX (rr) -- - - - - - - - - - - - - - - - - - - - - - - MX function resolver:LOC_nibble_power () -- - - - - - - - - - LOC_nibble_power local b = self:byte () - -- print ('nibbles', ((b-(b%0x10))/0x10), (b%0x10)) + --print ('nibbles', ((b-(b%0x10))/0x10), (b%0x10)) return ((b-(b%0x10))/0x10) * (10^(b%0x10)) end @@ -549,12 +549,12 @@ function resolver:closeall () -- - - - - - - - - - - - - - - - - - closeall function resolver:remember (rr, type) -- - - - - - - - - - - - - - remember - -- print ('remember', type, rr.class, rr.type, rr.name) + --print ('remember', type, rr.class, rr.type, rr.name) if type ~= '*' then type = rr.type local all = get (self.cache, rr.class, '*', rr.name) - -- print ('remember all', all) + --print ('remember all', all) if all then append (all, rr) end end @@ -599,14 +599,14 @@ function resolver:query (qname, qtype, qclass) -- - - - - - - - - - -- query qname, qtype, qclass = standardize (qname, qtype, qclass) - if not self.server then self:adddefaultnameservers () end + if not self.server then self:adddefaultnameservers () end local question = encodeQuestion (qname, qtype, qclass) local peek = self:peek (qname, qtype, qclass) if peek then return peek end local header, id = encodeHeader () - -- print ('query id', id, qclass, qtype, qname) + --print ('query id', id, qclass, qtype, qname) local o = { packet = header..question, server = 1, delay = 1, @@ -621,13 +621,15 @@ function resolver:query (qname, qtype, qclass) -- - - - - - - - - - -- query local co = coroutine.running () if co then set (self.wanted, qclass, qtype, qname, co, true) - set (self.yielded, co, qclass, qtype, qname, true) - end end + --set (self.yielded, co, qclass, qtype, qname, true) + end +end + function resolver:receive (rset) -- - - - - - - - - - - - - - - - - receive - -- print 'receive' print (self.socket) + --print 'receive' print (self.socket) self.time = socket.gettime () rset = rset or self.socket @@ -640,8 +642,8 @@ function resolver:receive (rset) -- - - - - - - - - - - - - - - - - receive response = self:decode (packet) if response then - -- print 'received response' - -- self.print (response) + --print 'received response' + --self.print (response) for i,section in pairs { 'answer', 'authority', 'additional' } do for j,rr in pairs (response[section]) do @@ -660,7 +662,7 @@ function resolver:receive (rset) -- - - - - - - - - - - - - - - - - receive if cos then for co in pairs (cos) do set (self.yielded, co, q.class, q.type, q.name, nil) - if not self.yielded[co] then coroutine.resume (co) end + if coroutine.status(co) == "suspended" then coroutine.resume (co) end end set (self.wanted, q.class, q.type, q.name, nil) end end end end end @@ -669,10 +671,51 @@ function resolver:receive (rset) -- - - - - - - - - - - - - - - - - receive end +function resolver:feed(sock, packet) + --print 'receive' print (self.socket) + self.time = socket.gettime () + + local response = self:decode (packet) + if response then + --print 'received response' + --self.print (response) + + for i,section in pairs { 'answer', 'authority', 'additional' } do + for j,rr in pairs (response[section]) do + self:remember (rr, response.question[1].type) + end + end + + -- retire the query + local queries = self.active[response.header.id] + if queries[response.question.raw] then + queries[response.question.raw] = nil + end + if not next (queries) then self.active[response.header.id] = nil end + if not next (self.active) then self:closeall () end + + -- was the query on the wanted list? + local q = response.question[1] + if q then + local cos = get (self.wanted, q.class, q.type, q.name) + if cos then + for co in pairs (cos) do + set (self.yielded, co, q.class, q.type, q.name, nil) + if coroutine.status(co) == "suspended" then coroutine.resume (co) end + end + set (self.wanted, q.class, q.type, q.name, nil) + end + end + end + + return response +end + + function resolver:pulse () -- - - - - - - - - - - - - - - - - - - - - pulse - -- print ':pulse' - while self:receive () do end + --print ':pulse' + while self:receive() do end if not next (self.active) then return nil end self.time = socket.gettime () @@ -687,12 +730,12 @@ function resolver:pulse () -- - - - - - - - - - - - - - - - - - - - - pulse end if o.delay > #self.delays then - print ('timeout') + --print ('timeout') queries[question] = nil if not next (queries) then self.active[id] = nil end if not next (self.active) then return nil end else - -- print ('retry', o.server, o.delay) + --print ('retry', o.server, o.delay) local _a = self.socket[o.server]; if _a then _a:send (o.packet) end o.retry = self.time + self.delays[o.delay] @@ -706,12 +749,16 @@ function resolver:pulse () -- - - - - - - - - - - - - - - - - - - - - pulse function resolver:lookup (qname, qtype, qclass) -- - - - - - - - - - lookup self:query (qname, qtype, qclass) while self:pulse () do socket.select (self.socket, nil, 4) end - -- print (self.cache) + --print (self.cache) return self:peek (qname, qtype, qclass) end +function resolver:lookupex (handler, qname, qtype, qclass) -- - - - - - - - - - lookup + return self:peek (qname, qtype, qclass) or self:query (qname, qtype, qclass) + end + --- print ---------------------------------------------------------------- print +--print ---------------------------------------------------------------- print local hints = { -- - - - - - - - - - - - - - - - - - - - - - - - - - - hints @@ -758,7 +805,7 @@ function resolver.print (response) -- - - - - - - - - - - - - resolver.print for j,t in pairs (rr) do if not common[j] then tmp = string.format ('%s[%i].%s', s, i, j) - print (string.format ('%-30s %s', tmp, t)) + print (string.format ('%-30s %s', tostring(tmp), tostring(t))) end end end end end @@ -797,6 +844,9 @@ function dns.peek (...) -- - - - - - - - - - - - - - - - - - - - - - - peek function dns.query (...) -- - - - - - - - - - - - - - - - - - - - - - query return resolve (resolver.query, ...) end +function dns.feed (...) -- - - - - - - - - - - - - - - - - - - - - - feed + return resolve (resolver.feed, ...) end + function dns:socket_wrapper_set (...) -- - - - - - - - - socket_wrapper_set return resolve (resolver.socket_wrapper_set, ...) end |