aboutsummaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
authorMatthew Wild <mwild1@gmail.com>2009-03-04 12:58:56 +0000
committerMatthew Wild <mwild1@gmail.com>2009-03-04 12:58:56 +0000
commit775b18bd76bb0434214b0a92e1ec5d31d252cc26 (patch)
tree541105895a3dd8fb450e2ac8ce6300ce07e79697 /net
parent6a5be713088b28e3fd2cbb46ba38b9d13c26090d (diff)
downloadprosody-775b18bd76bb0434214b0a92e1ec5d31d252cc26.tar.gz
prosody-775b18bd76bb0434214b0a92e1ec5d31d252cc26.zip
net.dns: Add methods necessary for allowing non-blocking DNS lookups
Diffstat (limited to 'net')
-rw-r--r--net/dns.lua94
1 files changed, 72 insertions, 22 deletions
diff --git a/net/dns.lua b/net/dns.lua
index c86a0886..70d731b7 100644
--- a/net/dns.lua
+++ b/net/dns.lua
@@ -16,7 +16,7 @@
require 'socket'
local ztact = require 'util.ztact'
-
+local require = require
local coroutine, io, math, socket, string, table =
coroutine, io, math, socket, string, table
@@ -253,7 +253,7 @@ function resolver:word () -- - - - - - - - - - - - - - - - - - - - - - word
function resolver:dword () -- - - - - - - - - - - - - - - - - - - - - dword
local b1, b2, b3, b4 = self:byte (4)
- -- print ('dword', b1, b2, b3, b4)
+ --print ('dword', b1, b2, b3, b4)
return 0x1000000*b1 + 0x10000*b2 + 0x100*b3 + b4
end
@@ -269,7 +269,7 @@ function resolver:sub (len) -- - - - - - - - - - - - - - - - - - - - - - sub
function resolver:header (force) -- - - - - - - - - - - - - - - - - - header
local id = self:word ()
- -- print (string.format (':header id %x', id))
+ --print (string.format (':header id %x', id))
if not self.active[id] and not force then return nil end
local h = { id = id }
@@ -322,7 +322,7 @@ function resolver:question () -- - - - - - - - - - - - - - - - - - question
local q = {}
q.name = self:name ()
q.type = dns.type[self:word ()]
- q.class = dns.type[self:word ()]
+ q.class = dns.class[self:word ()]
return q
end
@@ -346,7 +346,7 @@ function resolver:MX (rr) -- - - - - - - - - - - - - - - - - - - - - - - MX
function resolver:LOC_nibble_power () -- - - - - - - - - - LOC_nibble_power
local b = self:byte ()
- -- print ('nibbles', ((b-(b%0x10))/0x10), (b%0x10))
+ --print ('nibbles', ((b-(b%0x10))/0x10), (b%0x10))
return ((b-(b%0x10))/0x10) * (10^(b%0x10))
end
@@ -549,12 +549,12 @@ function resolver:closeall () -- - - - - - - - - - - - - - - - - - closeall
function resolver:remember (rr, type) -- - - - - - - - - - - - - - remember
- -- print ('remember', type, rr.class, rr.type, rr.name)
+ --print ('remember', type, rr.class, rr.type, rr.name)
if type ~= '*' then
type = rr.type
local all = get (self.cache, rr.class, '*', rr.name)
- -- print ('remember all', all)
+ --print ('remember all', all)
if all then append (all, rr) end
end
@@ -599,14 +599,14 @@ function resolver:query (qname, qtype, qclass) -- - - - - - - - - - -- query
qname, qtype, qclass = standardize (qname, qtype, qclass)
- if not self.server then self:adddefaultnameservers () end
+ if not self.server then self:adddefaultnameservers () end
local question = encodeQuestion (qname, qtype, qclass)
local peek = self:peek (qname, qtype, qclass)
if peek then return peek end
local header, id = encodeHeader ()
- -- print ('query id', id, qclass, qtype, qname)
+ --print ('query id', id, qclass, qtype, qname)
local o = { packet = header..question,
server = 1,
delay = 1,
@@ -621,13 +621,15 @@ function resolver:query (qname, qtype, qclass) -- - - - - - - - - - -- query
local co = coroutine.running ()
if co then
set (self.wanted, qclass, qtype, qname, co, true)
- set (self.yielded, co, qclass, qtype, qname, true)
- end end
+ --set (self.yielded, co, qclass, qtype, qname, true)
+ end
+end
+
function resolver:receive (rset) -- - - - - - - - - - - - - - - - - receive
- -- print 'receive' print (self.socket)
+ --print 'receive' print (self.socket)
self.time = socket.gettime ()
rset = rset or self.socket
@@ -640,8 +642,8 @@ function resolver:receive (rset) -- - - - - - - - - - - - - - - - - receive
response = self:decode (packet)
if response then
- -- print 'received response'
- -- self.print (response)
+ --print 'received response'
+ --self.print (response)
for i,section in pairs { 'answer', 'authority', 'additional' } do
for j,rr in pairs (response[section]) do
@@ -660,7 +662,7 @@ function resolver:receive (rset) -- - - - - - - - - - - - - - - - - receive
if cos then
for co in pairs (cos) do
set (self.yielded, co, q.class, q.type, q.name, nil)
- if not self.yielded[co] then coroutine.resume (co) end
+ if coroutine.status(co) == "suspended" then coroutine.resume (co) end
end
set (self.wanted, q.class, q.type, q.name, nil)
end end end end end
@@ -669,10 +671,51 @@ function resolver:receive (rset) -- - - - - - - - - - - - - - - - - receive
end
+function resolver:feed(sock, packet)
+ --print 'receive' print (self.socket)
+ self.time = socket.gettime ()
+
+ local response = self:decode (packet)
+ if response then
+ --print 'received response'
+ --self.print (response)
+
+ for i,section in pairs { 'answer', 'authority', 'additional' } do
+ for j,rr in pairs (response[section]) do
+ self:remember (rr, response.question[1].type)
+ end
+ end
+
+ -- retire the query
+ local queries = self.active[response.header.id]
+ if queries[response.question.raw] then
+ queries[response.question.raw] = nil
+ end
+ if not next (queries) then self.active[response.header.id] = nil end
+ if not next (self.active) then self:closeall () end
+
+ -- was the query on the wanted list?
+ local q = response.question[1]
+ if q then
+ local cos = get (self.wanted, q.class, q.type, q.name)
+ if cos then
+ for co in pairs (cos) do
+ set (self.yielded, co, q.class, q.type, q.name, nil)
+ if coroutine.status(co) == "suspended" then coroutine.resume (co) end
+ end
+ set (self.wanted, q.class, q.type, q.name, nil)
+ end
+ end
+ end
+
+ return response
+end
+
+
function resolver:pulse () -- - - - - - - - - - - - - - - - - - - - - pulse
- -- print ':pulse'
- while self:receive () do end
+ --print ':pulse'
+ while self:receive() do end
if not next (self.active) then return nil end
self.time = socket.gettime ()
@@ -687,12 +730,12 @@ function resolver:pulse () -- - - - - - - - - - - - - - - - - - - - - pulse
end
if o.delay > #self.delays then
- print ('timeout')
+ --print ('timeout')
queries[question] = nil
if not next (queries) then self.active[id] = nil end
if not next (self.active) then return nil end
else
- -- print ('retry', o.server, o.delay)
+ --print ('retry', o.server, o.delay)
local _a = self.socket[o.server];
if _a then _a:send (o.packet) end
o.retry = self.time + self.delays[o.delay]
@@ -706,12 +749,16 @@ function resolver:pulse () -- - - - - - - - - - - - - - - - - - - - - pulse
function resolver:lookup (qname, qtype, qclass) -- - - - - - - - - - lookup
self:query (qname, qtype, qclass)
while self:pulse () do socket.select (self.socket, nil, 4) end
- -- print (self.cache)
+ --print (self.cache)
return self:peek (qname, qtype, qclass)
end
+function resolver:lookupex (handler, qname, qtype, qclass) -- - - - - - - - - - lookup
+ return self:peek (qname, qtype, qclass) or self:query (qname, qtype, qclass)
+ end
+
--- print ---------------------------------------------------------------- print
+--print ---------------------------------------------------------------- print
local hints = { -- - - - - - - - - - - - - - - - - - - - - - - - - - - hints
@@ -758,7 +805,7 @@ function resolver.print (response) -- - - - - - - - - - - - - resolver.print
for j,t in pairs (rr) do
if not common[j] then
tmp = string.format ('%s[%i].%s', s, i, j)
- print (string.format ('%-30s %s', tmp, t))
+ print (string.format ('%-30s %s', tostring(tmp), tostring(t)))
end end end end end
@@ -797,6 +844,9 @@ function dns.peek (...) -- - - - - - - - - - - - - - - - - - - - - - - peek
function dns.query (...) -- - - - - - - - - - - - - - - - - - - - - - query
return resolve (resolver.query, ...) end
+function dns.feed (...) -- - - - - - - - - - - - - - - - - - - - - - feed
+ return resolve (resolver.feed, ...) end
+
function dns:socket_wrapper_set (...) -- - - - - - - - - socket_wrapper_set
return resolve (resolver.socket_wrapper_set, ...) end