diff options
Diffstat (limited to 'net')
-rw-r--r-- | net/adns.lua | 24 | ||||
-rw-r--r-- | net/connlisteners.lua | 6 | ||||
-rw-r--r-- | net/dns.lua | 1409 | ||||
-rw-r--r-- | net/httpserver.lua | 31 | ||||
-rw-r--r-- | net/server.lua | 141 | ||||
-rw-r--r-- | net/xmppclient_listener.lua | 11 | ||||
-rw-r--r-- | net/xmppcomponent_listener.lua | 6 | ||||
-rw-r--r-- | net/xmppserver_listener.lua | 24 |
8 files changed, 908 insertions, 744 deletions
diff --git a/net/adns.lua b/net/adns.lua index 34ef5d77..b0c9a625 100644 --- a/net/adns.lua +++ b/net/adns.lua @@ -11,6 +11,7 @@ local dns = require "net.dns"; local log = require "util.logger".init("adns"); +local t_insert, t_remove = table.insert, table.remove; local coroutine, tostring, pcall = coroutine, tostring, pcall; module "adns" @@ -28,7 +29,7 @@ function lookup(handler, qname, qtype, qclass) log("debug", "Reply for %s (%s)", qname, tostring(coroutine.running())); local ok, err = pcall(handler, dns.peek(qname, qtype, qclass)); if not ok then - log("debug", "Error in DNS response handler: %s", tostring(err)); + log("error", "Error in DNS response handler: %s", tostring(err)); end end)(dns.peek(qname, qtype, qclass)); end @@ -41,18 +42,31 @@ function cancel(handle, call_handler) end end -function new_async_socket(sock) - local newconn = {}; +function new_async_socket(sock, resolver) + local newconn, peername = {}, "<unknown>"; local listener = {}; function listener.incoming(conn, data) dns.feed(sock, data); end - function listener.disconnect() + function listener.disconnect(conn, err) + log("warn", "DNS socket for %s disconnected: %s", peername, err); + local servers = resolver.server; + if resolver.socketset[newconn.handler] == resolver.best_server and resolver.best_server == #servers then + log("error", "Exhausted all %d configured DNS servers, next lookup will try %s again", #servers, servers[1]); + end + + resolver:servfail(conn); -- Let the magic commence end newconn.handler, newconn._socket = server.wrapclient(sock, "dns", 53, listener); + if not newconn.handler then + log("warn", "handler is nil"); + end + if not newconn._socket then + log("warn", "socket is nil"); + end newconn.handler.settimeout = function () end newconn.handler.setsockname = function (_, ...) return sock:setsockname(...); end - newconn.handler.setpeername = function (_, ...) local ret = sock:setpeername(...); _.setsend(sock.send); return ret; end + newconn.handler.setpeername = function (_, ...) peername = (...); local ret = sock:setpeername(...); _.setsend(sock.send); return ret; end newconn.handler.connect = function (_, ...) return sock:connect(...) end newconn.handler.send = function (_, data) _.write(data); return _.sendbuffer(); end return newconn.handler; diff --git a/net/connlisteners.lua b/net/connlisteners.lua index ebb3cc18..230d92a4 100644 --- a/net/connlisteners.lua +++ b/net/connlisteners.lua @@ -11,6 +11,7 @@ local listeners_dir = (CFG_SOURCEDIR or ".").."/net/"; local server = require "net.server"; local log = require "util.logger".init("connlisteners"); +local tostring = tostring; local dofile, pcall, error = dofile, pcall, error @@ -37,7 +38,10 @@ function get(name) local h = listeners[name]; if not h then local ok, ret = pcall(dofile, listeners_dir..name:gsub("[^%w%-]", "_").."_listener.lua"); - if not ok then return nil, ret; end + if not ok then + log("error", "Error while loading listener '%s': %s", tostring(name), tostring(ret)); + return nil, ret; + end h = listeners[name]; end return h; diff --git a/net/dns.lua b/net/dns.lua index 48c08218..04b2cf22 100644 --- a/net/dns.lua +++ b/net/dns.lua @@ -14,21 +14,22 @@ -- reference: http://tools.ietf.org/html/rfc1876 (LOC) -require 'socket' -local ztact = require 'util.ztact' -local require = require +local socket = require "socket"; +local ztact = require "util.ztact"; +local _, windows = pcall(require, "util.windows"); +local is_windows = (_ and windows) or os.getenv("WINDIR"); -local coroutine, io, math, socket, string, table = - coroutine, io, math, socket, string, table +local coroutine, io, math, string, table = + coroutine, io, math, string, table; local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack = - ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack + ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack; -local get, set = ztact.get, ztact.set +local get, set = ztact.get, ztact.set; -------------------------------------------------- module dns -module ('dns') +module('dns') local dns = _M; @@ -38,826 +39,928 @@ local dns = _M; local append = table.insert -local function highbyte (i) -- - - - - - - - - - - - - - - - - - - highbyte - return (i-(i%0x100))/0x100 - end +local function highbyte(i) -- - - - - - - - - - - - - - - - - - - highbyte + return (i-(i%0x100))/0x100; +end local function augment (t) -- - - - - - - - - - - - - - - - - - - - augment - local a = {} - for i,s in pairs (t) do a[i] = s a[s] = s a[string.lower (s)] = s end - return a - end + local a = {}; + for i,s in pairs(t) do + a[i] = s; + a[s] = s; + a[string.lower(s)] = s; + end + return a; +end local function encode (t) -- - - - - - - - - - - - - - - - - - - - - encode - local code = {} - for i,s in pairs (t) do - local word = string.char (highbyte (i), i %0x100) - code[i] = word - code[s] = word - code[string.lower (s)] = word - end - return code - end + local code = {}; + for i,s in pairs(t) do + local word = string.char(highbyte(i), i%0x100); + code[i] = word; + code[s] = word; + code[string.lower(s)] = word; + end + return code; +end dns.types = { - 'A', 'NS', 'MD', 'MF', 'CNAME', 'SOA', 'MB', 'MG', 'MR', 'NULL', 'WKS', - 'PTR', 'HINFO', 'MINFO', 'MX', 'TXT', - [ 28] = 'AAAA', [ 29] = 'LOC', [ 33] = 'SRV', - [252] = 'AXFR', [253] = 'MAILB', [254] = 'MAILA', [255] = '*' } + 'A', 'NS', 'MD', 'MF', 'CNAME', 'SOA', 'MB', 'MG', 'MR', 'NULL', 'WKS', + 'PTR', 'HINFO', 'MINFO', 'MX', 'TXT', + [ 28] = 'AAAA', [ 29] = 'LOC', [ 33] = 'SRV', + [252] = 'AXFR', [253] = 'MAILB', [254] = 'MAILA', [255] = '*' }; -dns.classes = { 'IN', 'CS', 'CH', 'HS', [255] = '*' } +dns.classes = { 'IN', 'CS', 'CH', 'HS', [255] = '*' }; -dns.type = augment (dns.types) -dns.class = augment (dns.classes) -dns.typecode = encode (dns.types) -dns.classcode = encode (dns.classes) +dns.type = augment (dns.types); +dns.class = augment (dns.classes); +dns.typecode = encode (dns.types); +dns.classcode = encode (dns.classes); -local function standardize (qname, qtype, qclass) -- - - - - - - standardize - if string.byte (qname, -1) ~= 0x2E then qname = qname..'.' end - qname = string.lower (qname) - return qname, dns.type[qtype or 'A'], dns.class[qclass or 'IN'] - end - - -local function prune (rrs, time, soft) -- - - - - - - - - - - - - - - prune - - time = time or socket.gettime () - for i,rr in pairs (rrs) do +local function standardize(qname, qtype, qclass) -- - - - - - - standardize + if string.byte(qname, -1) ~= 0x2E then qname = qname..'.'; end + qname = string.lower(qname); + return qname, dns.type[qtype or 'A'], dns.class[qclass or 'IN']; +end - if rr.tod then - -- rr.tod = rr.tod - 50 -- accelerated decripitude - rr.ttl = math.floor (rr.tod - time) - if rr.ttl <= 0 then rrs[i] = nil end - elseif soft == 'soft' then -- What is this? I forget! - assert (rr.ttl == 0) - rrs[i] = nil - end end end +local function prune(rrs, time, soft) -- - - - - - - - - - - - - - - prune + time = time or socket.gettime(); + for i,rr in pairs(rrs) do + if rr.tod then + -- rr.tod = rr.tod - 50 -- accelerated decripitude + rr.ttl = math.floor(rr.tod - time); + if rr.ttl <= 0 then + table.remove(rrs, i); + return prune(rrs, time, soft); -- Re-iterate + end + elseif soft == 'soft' then -- What is this? I forget! + assert(rr.ttl == 0); + rrs[i] = nil; + end + end +end -- metatables & co. ------------------------------------------ metatables & co. -local resolver = {} -resolver.__index = resolver - - -local SRV_tostring - - -local rr_metatable = {} -- - - - - - - - - - - - - - - - - - - rr_metatable -function rr_metatable.__tostring (rr) - local s0 = string.format ( - '%2s %-5s %6i %-28s', rr.class, rr.type, rr.ttl, rr.name ) - local s1 = '' - if rr.type == 'A' then s1 = ' '..rr.a - elseif rr.type == 'MX' then - s1 = string.format (' %2i %s', rr.pref, rr.mx) - elseif rr.type == 'CNAME' then s1 = ' '..rr.cname - elseif rr.type == 'LOC' then s1 = ' '..resolver.LOC_tostring (rr) - elseif rr.type == 'NS' then s1 = ' '..rr.ns - elseif rr.type == 'SRV' then s1 = ' '..SRV_tostring (rr) - elseif rr.type == 'TXT' then s1 = ' '..rr.txt - else s1 = ' <UNKNOWN RDATA TYPE>' end - return s0..s1 - end +local resolver = {}; +resolver.__index = resolver; + + +local SRV_tostring; + + +local rr_metatable = {}; -- - - - - - - - - - - - - - - - - - - rr_metatable +function rr_metatable.__tostring(rr) + local s0 = string.format('%2s %-5s %6i %-28s', rr.class, rr.type, rr.ttl, rr.name); + local s1 = ''; + if rr.type == 'A' then + s1 = ' '..rr.a; + elseif rr.type == 'MX' then + s1 = string.format(' %2i %s', rr.pref, rr.mx); + elseif rr.type == 'CNAME' then + s1 = ' '..rr.cname; + elseif rr.type == 'LOC' then + s1 = ' '..resolver.LOC_tostring(rr); + elseif rr.type == 'NS' then + s1 = ' '..rr.ns; + elseif rr.type == 'SRV' then + s1 = ' '..SRV_tostring(rr); + elseif rr.type == 'TXT' then + s1 = ' '..rr.txt; + else + s1 = ' <UNKNOWN RDATA TYPE>'; + end + return s0..s1; +end -local rrs_metatable = {} -- - - - - - - - - - - - - - - - - - rrs_metatable -function rrs_metatable.__tostring (rrs) - local t = {} - for i,rr in pairs (rrs) do append (t, tostring (rr)..'\n') end - return table.concat (t) - end +local rrs_metatable = {}; -- - - - - - - - - - - - - - - - - - rrs_metatable +function rrs_metatable.__tostring(rrs) + local t = {}; + for i,rr in pairs(rrs) do + append(t, tostring(rr)..'\n'); + end + return table.concat(t); +end -local cache_metatable = {} -- - - - - - - - - - - - - - - - cache_metatable -function cache_metatable.__tostring (cache) - local time = socket.gettime () - local t = {} - for class,types in pairs (cache) do - for type,names in pairs (types) do - for name,rrs in pairs (names) do - prune (rrs, time) - append (t, tostring (rrs)) end end end - return table.concat (t) - end +local cache_metatable = {}; -- - - - - - - - - - - - - - - - cache_metatable +function cache_metatable.__tostring(cache) + local time = socket.gettime(); + local t = {}; + for class,types in pairs(cache) do + for type,names in pairs(types) do + for name,rrs in pairs(names) do + prune(rrs, time); + append(t, tostring(rrs)); + end + end + end + return table.concat(t); +end -function resolver:new () -- - - - - - - - - - - - - - - - - - - - - resolver - local r = { active = {}, cache = {}, unsorted = {} } - setmetatable (r, resolver) - setmetatable (r.cache, cache_metatable) - setmetatable (r.unsorted, { __mode = 'kv' }) - return r - end +function resolver:new() -- - - - - - - - - - - - - - - - - - - - - resolver + local r = { active = {}, cache = {}, unsorted = {} }; + setmetatable(r, resolver); + setmetatable(r.cache, cache_metatable); + setmetatable(r.unsorted, { __mode = 'kv' }); + return r; +end -- packet layer -------------------------------------------------- packet layer -function dns.random (...) -- - - - - - - - - - - - - - - - - - - dns.random - math.randomseed (10000*socket.gettime ()) - dns.random = math.random - return dns.random (...) - end - - -local function encodeHeader (o) -- - - - - - - - - - - - - - - encodeHeader +function dns.random(...) -- - - - - - - - - - - - - - - - - - - dns.random + math.randomseed(10000*socket.gettime()); + dns.random = math.random; + return dns.random(...); +end - o = o or {} - o.id = o.id or -- 16b (random) id - dns.random (0, 0xffff) +local function encodeHeader(o) -- - - - - - - - - - - - - - - encodeHeader + o = o or {}; + o.id = o.id or dns.random(0, 0xffff); -- 16b (random) id - o.rd = o.rd or 1 -- 1b 1 recursion desired - o.tc = o.tc or 0 -- 1b 1 truncated response - o.aa = o.aa or 0 -- 1b 1 authoritative response - o.opcode = o.opcode or 0 -- 4b 0 query - -- 1 inverse query + o.rd = o.rd or 1; -- 1b 1 recursion desired + o.tc = o.tc or 0; -- 1b 1 truncated response + o.aa = o.aa or 0; -- 1b 1 authoritative response + o.opcode = o.opcode or 0; -- 4b 0 query + -- 1 inverse query -- 2 server status request -- 3-15 reserved - o.qr = o.qr or 0 -- 1b 0 query, 1 response + o.qr = o.qr or 0; -- 1b 0 query, 1 response - o.rcode = o.rcode or 0 -- 4b 0 no error + o.rcode = o.rcode or 0; -- 4b 0 no error -- 1 format error -- 2 server failure -- 3 name error -- 4 not implemented -- 5 refused -- 6-15 reserved - o.z = o.z or 0 -- 3b 0 resvered - o.ra = o.ra or 0 -- 1b 1 recursion available - - o.qdcount = o.qdcount or 1 -- 16b number of question RRs - o.ancount = o.ancount or 0 -- 16b number of answers RRs - o.nscount = o.nscount or 0 -- 16b number of nameservers RRs - o.arcount = o.arcount or 0 -- 16b number of additional RRs - - -- string.char() rounds, so prevent roundup with -0.4999 - local header = string.char ( - highbyte (o.id), o.id %0x100, - o.rd + 2*o.tc + 4*o.aa + 8*o.opcode + 128*o.qr, - o.rcode + 16*o.z + 128*o.ra, - highbyte (o.qdcount), o.qdcount %0x100, - highbyte (o.ancount), o.ancount %0x100, - highbyte (o.nscount), o.nscount %0x100, - highbyte (o.arcount), o.arcount %0x100 ) - - return header, o.id - end - - -local function encodeName (name) -- - - - - - - - - - - - - - - - encodeName - local t = {} - for part in string.gmatch (name, '[^.]+') do - append (t, string.char (string.len (part))) - append (t, part) - end - append (t, string.char (0)) - return table.concat (t) - end - - -local function encodeQuestion (qname, qtype, qclass) -- - - - encodeQuestion - qname = encodeName (qname) - qtype = dns.typecode[qtype or 'a'] - qclass = dns.classcode[qclass or 'in'] - return qname..qtype..qclass; - end - - -function resolver:byte (len) -- - - - - - - - - - - - - - - - - - - - - byte - len = len or 1 - local offset = self.offset - local last = offset + len - 1 - if last > #self.packet then - error (string.format ('out of bounds: %i>%i', last, #self.packet)) end - self.offset = offset + len - return string.byte (self.packet, offset, last) - end - - -function resolver:word () -- - - - - - - - - - - - - - - - - - - - - - word - local b1, b2 = self:byte (2) - return 0x100*b1 + b2 - end + o.z = o.z or 0; -- 3b 0 resvered + o.ra = o.ra or 0; -- 1b 1 recursion available + + o.qdcount = o.qdcount or 1; -- 16b number of question RRs + o.ancount = o.ancount or 0; -- 16b number of answers RRs + o.nscount = o.nscount or 0; -- 16b number of nameservers RRs + o.arcount = o.arcount or 0; -- 16b number of additional RRs + + -- string.char() rounds, so prevent roundup with -0.4999 + local header = string.char( + highbyte(o.id), o.id %0x100, + o.rd + 2*o.tc + 4*o.aa + 8*o.opcode + 128*o.qr, + o.rcode + 16*o.z + 128*o.ra, + highbyte(o.qdcount), o.qdcount %0x100, + highbyte(o.ancount), o.ancount %0x100, + highbyte(o.nscount), o.nscount %0x100, + highbyte(o.arcount), o.arcount %0x100 + ); + + return header, o.id; +end + + +local function encodeName(name) -- - - - - - - - - - - - - - - - encodeName + local t = {}; + for part in string.gmatch(name, '[^.]+') do + append(t, string.char(string.len(part))); + append(t, part); + end + append(t, string.char(0)); + return table.concat(t); +end + + +local function encodeQuestion(qname, qtype, qclass) -- - - - encodeQuestion + qname = encodeName(qname); + qtype = dns.typecode[qtype or 'a']; + qclass = dns.classcode[qclass or 'in']; + return qname..qtype..qclass; +end + + +function resolver:byte(len) -- - - - - - - - - - - - - - - - - - - - - byte + len = len or 1; + local offset = self.offset; + local last = offset + len - 1; + if last > #self.packet then + error(string.format('out of bounds: %i>%i', last, #self.packet)); + end + self.offset = offset + len; + return string.byte(self.packet, offset, last); +end + + +function resolver:word() -- - - - - - - - - - - - - - - - - - - - - - word + local b1, b2 = self:byte(2); + return 0x100*b1 + b2; +end function resolver:dword () -- - - - - - - - - - - - - - - - - - - - - dword - local b1, b2, b3, b4 = self:byte (4) - --print ('dword', b1, b2, b3, b4) - return 0x1000000*b1 + 0x10000*b2 + 0x100*b3 + b4 - end + local b1, b2, b3, b4 = self:byte(4); + --print('dword', b1, b2, b3, b4); + return 0x1000000*b1 + 0x10000*b2 + 0x100*b3 + b4; +end -function resolver:sub (len) -- - - - - - - - - - - - - - - - - - - - - - sub - len = len or 1 - local s = string.sub (self.packet, self.offset, self.offset + len - 1) - self.offset = self.offset + len - return s - end +function resolver:sub(len) -- - - - - - - - - - - - - - - - - - - - - - sub + len = len or 1; + local s = string.sub(self.packet, self.offset, self.offset + len - 1); + self.offset = self.offset + len; + return s; +end -function resolver:header (force) -- - - - - - - - - - - - - - - - - - header - - local id = self:word () - --print (string.format (':header id %x', id)) - if not self.active[id] and not force then return nil end - - local h = { id = id } - - local b1, b2 = self:byte (2) - - h.rd = b1 %2 - h.tc = b1 /2%2 - h.aa = b1 /4%2 - h.opcode = b1 /8%16 - h.qr = b1 /128 +function resolver:header(force) -- - - - - - - - - - - - - - - - - - header + local id = self:word(); + --print(string.format(':header id %x', id)); + if not self.active[id] and not force then return nil; end - h.rcode = b2 %16 - h.z = b2 /16%8 - h.ra = b2 /128 - - h.qdcount = self:word () - h.ancount = self:word () - h.nscount = self:word () - h.arcount = self:word () - - for k,v in pairs (h) do h[k] = v-v%1 end - - return h - end - - -function resolver:name () -- - - - - - - - - - - - - - - - - - - - - - name - local remember, pointers = nil, 0 - local len = self:byte () - local n = {} - while len > 0 do - if len >= 0xc0 then -- name is "compressed" - pointers = pointers + 1 - if pointers >= 20 then error ('dns error: 20 pointers') end - local offset = ((len-0xc0)*0x100) + self:byte () - remember = remember or self.offset - self.offset = offset + 1 -- +1 for lua - else -- name is not compressed - append (n, self:sub (len)..'.') - end - len = self:byte () - end - self.offset = remember or self.offset - return table.concat (n) - end + local h = { id = id }; + local b1, b2 = self:byte(2); -function resolver:question () -- - - - - - - - - - - - - - - - - - question - local q = {} - q.name = self:name () - q.type = dns.type[self:word ()] - q.class = dns.class[self:word ()] - return q - end + h.rd = b1 %2; + h.tc = b1 /2%2; + h.aa = b1 /4%2; + h.opcode = b1 /8%16; + h.qr = b1 /128; + h.rcode = b2 %16; + h.z = b2 /16%8; + h.ra = b2 /128; -function resolver:A (rr) -- - - - - - - - - - - - - - - - - - - - - - - - A - local b1, b2, b3, b4 = self:byte (4) - rr.a = string.format ('%i.%i.%i.%i', b1, b2, b3, b4) - end + h.qdcount = self:word(); + h.ancount = self:word(); + h.nscount = self:word(); + h.arcount = self:word(); + for k,v in pairs(h) do h[k] = v-v%1; end -function resolver:CNAME (rr) -- - - - - - - - - - - - - - - - - - - - CNAME - rr.cname = self:name () - end + return h; +end -function resolver:MX (rr) -- - - - - - - - - - - - - - - - - - - - - - - MX - rr.pref = self:word () - rr.mx = self:name () - end +function resolver:name() -- - - - - - - - - - - - - - - - - - - - - - name + local remember, pointers = nil, 0; + local len = self:byte(); + local n = {}; + while len > 0 do + if len >= 0xc0 then -- name is "compressed" + pointers = pointers + 1; + if pointers >= 20 then error('dns error: 20 pointers'); end; + local offset = ((len-0xc0)*0x100) + self:byte(); + remember = remember or self.offset; + self.offset = offset + 1; -- +1 for lua + else -- name is not compressed + append(n, self:sub(len)..'.'); + end + len = self:byte(); + end + self.offset = remember or self.offset; + return table.concat(n); +end -function resolver:LOC_nibble_power () -- - - - - - - - - - LOC_nibble_power - local b = self:byte () - --print ('nibbles', ((b-(b%0x10))/0x10), (b%0x10)) - return ((b-(b%0x10))/0x10) * (10^(b%0x10)) - end +function resolver:question() -- - - - - - - - - - - - - - - - - - question + local q = {}; + q.name = self:name(); + q.type = dns.type[self:word()]; + q.class = dns.class[self:word()]; + return q; +end -function resolver:LOC (rr) -- - - - - - - - - - - - - - - - - - - - - - LOC - rr.version = self:byte () - if rr.version == 0 then - rr.loc = rr.loc or {} - rr.loc.size = self:LOC_nibble_power () - rr.loc.horiz_pre = self:LOC_nibble_power () - rr.loc.vert_pre = self:LOC_nibble_power () - rr.loc.latitude = self:dword () - rr.loc.longitude = self:dword () - rr.loc.altitude = self:dword () - end end +function resolver:A(rr) -- - - - - - - - - - - - - - - - - - - - - - - - A + local b1, b2, b3, b4 = self:byte(4); + rr.a = string.format('%i.%i.%i.%i', b1, b2, b3, b4); +end -local function LOC_tostring_degrees (f, pos, neg) -- - - - - - - - - - - - - - f = f - 0x80000000 - if f < 0 then pos = neg f = -f end - local deg, min, msec - msec = f%60000 - f = (f-msec)/60000 - min = f%60 - deg = (f-min)/60 - return string.format ('%3d %2d %2.3f %s', deg, min, msec/1000, pos) - end - - -function resolver.LOC_tostring (rr) -- - - - - - - - - - - - - LOC_tostring - - local t = {} - - --[[ - for k,name in pairs { 'size', 'horiz_pre', 'vert_pre', - 'latitude', 'longitude', 'altitude' } do - append (t, string.format ('%4s%-10s: %12.0f\n', '', name, rr.loc[name])) - end - --]] - - append ( t, string.format ( - '%s %s %.2fm %.2fm %.2fm %.2fm', - LOC_tostring_degrees (rr.loc.latitude, 'N', 'S'), - LOC_tostring_degrees (rr.loc.longitude, 'E', 'W'), - (rr.loc.altitude - 10000000) / 100, - rr.loc.size / 100, - rr.loc.horiz_pre / 100, - rr.loc.vert_pre / 100 ) ) - - return table.concat (t) - end - - -function resolver:NS (rr) -- - - - - - - - - - - - - - - - - - - - - - - NS - rr.ns = self:name () - end - +function resolver:CNAME(rr) -- - - - - - - - - - - - - - - - - - - - CNAME + rr.cname = self:name(); +end -function resolver:SOA (rr) -- - - - - - - - - - - - - - - - - - - - - - SOA - end +function resolver:MX(rr) -- - - - - - - - - - - - - - - - - - - - - - - MX + rr.pref = self:word(); + rr.mx = self:name(); +end -function resolver:SRV (rr) -- - - - - - - - - - - - - - - - - - - - - - SRV - rr.srv = {} - rr.srv.priority = self:word () - rr.srv.weight = self:word () - rr.srv.port = self:word () - rr.srv.target = self:name () - end +function resolver:LOC_nibble_power() -- - - - - - - - - - LOC_nibble_power + local b = self:byte(); + --print('nibbles', ((b-(b%0x10))/0x10), (b%0x10)); + return ((b-(b%0x10))/0x10) * (10^(b%0x10)); +end -function SRV_tostring (rr) -- - - - - - - - - - - - - - - - - - SRV_tostring - local s = rr.srv - return string.format ( '%5d %5d %5d %s', - s.priority, s.weight, s.port, s.target ) - end +function resolver:LOC(rr) -- - - - - - - - - - - - - - - - - - - - - - LOC + rr.version = self:byte(); + if rr.version == 0 then + rr.loc = rr.loc or {}; + rr.loc.size = self:LOC_nibble_power(); + rr.loc.horiz_pre = self:LOC_nibble_power(); + rr.loc.vert_pre = self:LOC_nibble_power(); + rr.loc.latitude = self:dword(); + rr.loc.longitude = self:dword(); + rr.loc.altitude = self:dword(); + end +end -function resolver:TXT (rr) -- - - - - - - - - - - - - - - - - - - - - - TXT - rr.txt = self:sub (rr.rdlength) - end +local function LOC_tostring_degrees(f, pos, neg) -- - - - - - - - - - - - - + f = f - 0x80000000; + if f < 0 then pos = neg; f = -f; end + local deg, min, msec; + msec = f%60000; + f = (f-msec)/60000; + min = f%60; + deg = (f-min)/60; + return string.format('%3d %2d %2.3f %s', deg, min, msec/1000, pos); +end -function resolver:rr () -- - - - - - - - - - - - - - - - - - - - - - - - rr - local rr = {} - setmetatable (rr, rr_metatable) - rr.name = self:name (self) - rr.type = dns.type[self:word ()] or rr.type - rr.class = dns.class[self:word ()] or rr.class - rr.ttl = 0x10000*self:word () + self:word () - rr.rdlength = self:word () - if rr.ttl == 0 then -- pass - else rr.tod = self.time + rr.ttl end +function resolver.LOC_tostring(rr) -- - - - - - - - - - - - - LOC_tostring + local t = {}; - local remember = self.offset - local rr_parser = self[dns.type[rr.type]] - if rr_parser then rr_parser (self, rr) end - self.offset = remember - rr.rdata = self:sub (rr.rdlength) - return rr - end + --[[ + for k,name in pairs { 'size', 'horiz_pre', 'vert_pre', 'latitude', 'longitude', 'altitude' } do + append(t, string.format('%4s%-10s: %12.0f\n', '', name, rr.loc[name])); + end + --]] + + append(t, string.format( + '%s %s %.2fm %.2fm %.2fm %.2fm', + LOC_tostring_degrees (rr.loc.latitude, 'N', 'S'), + LOC_tostring_degrees (rr.loc.longitude, 'E', 'W'), + (rr.loc.altitude - 10000000) / 100, + rr.loc.size / 100, + rr.loc.horiz_pre / 100, + rr.loc.vert_pre / 100 + )); + + return table.concat(t); +end -function resolver:rrs (count) -- - - - - - - - - - - - - - - - - - - - - rrs - local rrs = {} - for i = 1,count do append (rrs, self:rr ()) end - return rrs - end +function resolver:NS(rr) -- - - - - - - - - - - - - - - - - - - - - - - NS + rr.ns = self:name(); +end -function resolver:decode (packet, force) -- - - - - - - - - - - - - - decode +function resolver:SOA(rr) -- - - - - - - - - - - - - - - - - - - - - - SOA +end - self.packet, self.offset = packet, 1 - local header = self:header (force) - if not header then return nil end - local response = { header = header } - response.question = {} - local offset = self.offset - for i = 1,response.header.qdcount do - append (response.question, self:question ()) end - response.question.raw = string.sub (self.packet, offset, self.offset - 1) +function resolver:SRV(rr) -- - - - - - - - - - - - - - - - - - - - - - SRV + rr.srv = {}; + rr.srv.priority = self:word(); + rr.srv.weight = self:word(); + rr.srv.port = self:word(); + rr.srv.target = self:name(); +end - if not force then - if not self.active[response.header.id] or - not self.active[response.header.id][response.question.raw] then - return nil end end - response.answer = self:rrs (response.header.ancount) - response.authority = self:rrs (response.header.nscount) - response.additional = self:rrs (response.header.arcount) +function SRV_tostring(rr) -- - - - - - - - - - - - - - - - - - SRV_tostring + local s = rr.srv; + return string.format( '%5d %5d %5d %s', s.priority, s.weight, s.port, s.target ); +end - return response - end +function resolver:TXT(rr) -- - - - - - - - - - - - - - - - - - - - - - TXT + rr.txt = self:sub (rr.rdlength); +end --- socket layer -------------------------------------------------- socket layer +function resolver:rr() -- - - - - - - - - - - - - - - - - - - - - - - - rr + local rr = {}; + setmetatable(rr, rr_metatable); + rr.name = self:name(self); + rr.type = dns.type[self:word()] or rr.type; + rr.class = dns.class[self:word()] or rr.class; + rr.ttl = 0x10000*self:word() + self:word(); + rr.rdlength = self:word(); -resolver.delays = { 1, 3, 11, 45 } + if rr.ttl <= 0 then + rr.tod = self.time + 30; + else + rr.tod = self.time + rr.ttl; + end + local remember = self.offset; + local rr_parser = self[dns.type[rr.type]]; + if rr_parser then rr_parser(self, rr); end + self.offset = remember; + rr.rdata = self:sub(rr.rdlength); + return rr; +end -function resolver:addnameserver (address) -- - - - - - - - - - addnameserver - self.server = self.server or {} - append (self.server, address) - end +function resolver:rrs (count) -- - - - - - - - - - - - - - - - - - - - - rrs + local rrs = {}; + for i = 1,count do append(rrs, self:rr()); end + return rrs; +end -function resolver:setnameserver (address) -- - - - - - - - - - setnameserver - self.server = {} - self:addnameserver (address) - end +function resolver:decode(packet, force) -- - - - - - - - - - - - - - decode + self.packet, self.offset = packet, 1; + local header = self:header(force); + if not header then return nil; end + local response = { header = header }; -function resolver:adddefaultnameservers () -- - - - - adddefaultnameservers - local resolv_conf = io.open("/etc/resolv.conf"); - if resolv_conf then - for line in resolv_conf:lines() do - local address = string.match (line, 'nameserver%s+(%d+%.%d+%.%d+%.%d+)') - if address then self:addnameserver (address) end - end - else -- FIXME correct for windows, using opendns nameservers for now - self:addnameserver ("208.67.222.222") - self:addnameserver ("208.67.220.220") - end + response.question = {}; + local offset = self.offset; + for i = 1,response.header.qdcount do + append(response.question, self:question()); + end + response.question.raw = string.sub(self.packet, offset, self.offset - 1); + + if not force then + if not self.active[response.header.id] or not self.active[response.header.id][response.question.raw] then + return nil; + end + end + + response.answer = self:rrs(response.header.ancount); + response.authority = self:rrs(response.header.nscount); + response.additional = self:rrs(response.header.arcount); + + return response; end -function resolver:getsocket (servernum) -- - - - - - - - - - - - - getsocket +-- socket layer -------------------------------------------------- socket layer - self.socket = self.socket or {} - self.socketset = self.socketset or {} - local sock = self.socket[servernum] - if sock then return sock end +resolver.delays = { 1, 3 }; - sock = socket.udp () - if self.socket_wrapper then sock = self.socket_wrapper (sock) end - sock:settimeout (0) - -- todo: attempt to use a random port, fallback to 0 - sock:setsockname ('*', 0) - sock:setpeername (self.server[servernum], 53) - self.socket[servernum] = sock - self.socketset[sock] = sock - return sock - end +function resolver:addnameserver(address) -- - - - - - - - - - addnameserver + self.server = self.server or {}; + append(self.server, address); +end -function resolver:socket_wrapper_set (func) -- - - - - - - socket_wrapper_set - self.socket_wrapper = func - end +function resolver:setnameserver(address) -- - - - - - - - - - setnameserver + self.server = {}; + self:addnameserver(address); +end -function resolver:closeall () -- - - - - - - - - - - - - - - - - - closeall - for i,sock in ipairs (self.socket) do self.socket[i]:close () end - self.socket = {} - end +function resolver:adddefaultnameservers() -- - - - - adddefaultnameservers + if is_windows then + if windows then + for _, server in ipairs(windows.get_nameservers()) do + self:addnameserver(server); + end + end + if not self.server or #self.server == 0 then + -- TODO log warning about no nameservers, adding opendns servers as fallback + self:addnameserver("208.67.222.222"); + self:addnameserver("208.67.220.220") ; + end + else -- posix + local resolv_conf = io.open("/etc/resolv.conf"); + if resolv_conf then + for line in resolv_conf:lines() do + local address = line:gsub("#.*$", ""):match('^%s*nameserver%s+(%d+%.%d+%.%d+%.%d+)%s*$'); + if address then self:addnameserver(address) end + end + end + if not self.server or #self.server == 0 then + -- TODO log warning about no nameservers, adding localhost as the default nameserver + self:addnameserver("127.0.0.1"); + end + end +end -function resolver:remember (rr, type) -- - - - - - - - - - - - - - remember - --print ('remember', type, rr.class, rr.type, rr.name) +function resolver:getsocket(servernum) -- - - - - - - - - - - - - getsocket + self.socket = self.socket or {}; + self.socketset = self.socketset or {}; - if type ~= '*' then - type = rr.type - local all = get (self.cache, rr.class, '*', rr.name) - --print ('remember all', all) - if all then append (all, rr) end - end + local sock = self.socket[servernum]; + if sock then return sock; end - self.cache = self.cache or setmetatable ({}, cache_metatable) - local rrs = get (self.cache, rr.class, type, rr.name) or - set (self.cache, rr.class, type, rr.name, setmetatable ({}, rrs_metatable)) - append (rrs, rr) + sock = socket.udp(); + if self.socket_wrapper then sock = self.socket_wrapper(sock, self); end + sock:settimeout(0); + -- todo: attempt to use a random port, fallback to 0 + sock:setsockname('*', 0); + sock:setpeername(self.server[servernum], 53); + self.socket[servernum] = sock; + self.socketset[sock] = servernum; + return sock; +end - if type == 'MX' then self.unsorted[rrs] = true end - end +function resolver:voidsocket(sock) + if self.socket[sock] then + self.socketset[self.socket[sock]] = nil; + self.socket[sock] = nil; + elseif self.socketset[sock] then + self.socket[self.socketset[sock]] = nil; + self.socketset[sock] = nil; + end +end +function resolver:socket_wrapper_set(func) -- - - - - - - socket_wrapper_set + self.socket_wrapper = func; +end -local function comp_mx (a, b) -- - - - - - - - - - - - - - - - - - - comp_mx - return (a.pref == b.pref) and (a.mx < b.mx) or (a.pref < b.pref) - end + +function resolver:closeall () -- - - - - - - - - - - - - - - - - - closeall + for i,sock in ipairs(self.socket) do + self.socket[i] = nil; + self.socketset[sock] = nil; + sock:close(); + end +end + + +function resolver:remember(rr, type) -- - - - - - - - - - - - - - remember + --print ('remember', type, rr.class, rr.type, rr.name) + + if type ~= '*' then + type = rr.type; + local all = get(self.cache, rr.class, '*', rr.name); + --print('remember all', all); + if all then append(all, rr); end + end + + self.cache = self.cache or setmetatable({}, cache_metatable); + local rrs = get(self.cache, rr.class, type, rr.name) or + set(self.cache, rr.class, type, rr.name, setmetatable({}, rrs_metatable)); + append(rrs, rr); + + if type == 'MX' then self.unsorted[rrs] = true; end +end + + +local function comp_mx(a, b) -- - - - - - - - - - - - - - - - - - - comp_mx + return (a.pref == b.pref) and (a.mx < b.mx) or (a.pref < b.pref); +end function resolver:peek (qname, qtype, qclass) -- - - - - - - - - - - - peek - qname, qtype, qclass = standardize (qname, qtype, qclass) - local rrs = get (self.cache, qclass, qtype, qname) - if not rrs then return nil end - if prune (rrs, socket.gettime ()) and qtype == '*' or not next (rrs) then - set (self.cache, qclass, qtype, qname, nil) return nil end - if self.unsorted[rrs] then table.sort (rrs, comp_mx) end - return rrs - end - - -function resolver:purge (soft) -- - - - - - - - - - - - - - - - - - - purge - if soft == 'soft' then - self.time = socket.gettime () - for class,types in pairs (self.cache or {}) do - for type,names in pairs (types) do - for name,rrs in pairs (names) do - prune (rrs, self.time, 'soft') - end end end - else self.cache = {} end - end - - -function resolver:query (qname, qtype, qclass) -- - - - - - - - - - -- query - - qname, qtype, qclass = standardize (qname, qtype, qclass) - - if not self.server then self:adddefaultnameservers () end - - local question = encodeQuestion (qname, qtype, qclass) - local peek = self:peek (qname, qtype, qclass) - if peek then return peek end - - local header, id = encodeHeader () - --print ('query id', id, qclass, qtype, qname) - local o = { packet = header..question, - server = 1, - delay = 1, - retry = socket.gettime () + self.delays[1] } - self:getsocket (o.server):send (o.packet) + qname, qtype, qclass = standardize(qname, qtype, qclass); + local rrs = get(self.cache, qclass, qtype, qname); + if not rrs then return nil; end + if prune(rrs, socket.gettime()) and qtype == '*' or not next(rrs) then + set(self.cache, qclass, qtype, qname, nil); + return nil; + end + if self.unsorted[rrs] then table.sort (rrs, comp_mx); end + return rrs; +end - -- remember the query - self.active[id] = self.active[id] or {} - self.active[id][question] = o - -- remember which coroutine wants the answer - local co = coroutine.running () - if co then - set (self.wanted, qclass, qtype, qname, co, true) - --set (self.yielded, co, qclass, qtype, qname, true) - end +function resolver:purge(soft) -- - - - - - - - - - - - - - - - - - - purge + if soft == 'soft' then + self.time = socket.gettime(); + for class,types in pairs(self.cache or {}) do + for type,names in pairs(types) do + for name,rrs in pairs(names) do + prune(rrs, self.time, 'soft') + end + end + end + else self.cache = {}; end end +function resolver:query(qname, qtype, qclass) -- - - - - - - - - - -- query + qname, qtype, qclass = standardize(qname, qtype, qclass) -function resolver:receive (rset) -- - - - - - - - - - - - - - - - - receive + if not self.server then self:adddefaultnameservers(); end - --print 'receive' print (self.socket) - self.time = socket.gettime () - rset = rset or self.socket + local question = encodeQuestion(qname, qtype, qclass); + local peek = self:peek (qname, qtype, qclass); + if peek then return peek; end - local response - for i,sock in pairs (rset) do + local header, id = encodeHeader(); + --print ('query id', id, qclass, qtype, qname) + local o = { + packet = header..question, + server = self.best_server, + delay = 1, + retry = socket.gettime() + self.delays[1] + }; - if self.socketset[sock] then - local packet = sock:receive () - if packet then + -- remember the query + self.active[id] = self.active[id] or {}; + self.active[id][question] = o; - response = self:decode (packet) - if response then - --print 'received response' - --self.print (response) + -- remember which coroutine wants the answer + local co = coroutine.running(); + if co then + set(self.wanted, qclass, qtype, qname, co, true); + --set(self.yielded, co, qclass, qtype, qname, true); + end - for i,section in pairs { 'answer', 'authority', 'additional' } do - for j,rr in pairs (response[section]) do - self:remember (rr, response.question[1].type) end end + self:getsocket (o.server):send (o.packet) +end - -- retire the query - local queries = self.active[response.header.id] - if queries[response.question.raw] then - queries[response.question.raw] = nil end - if not next (queries) then self.active[response.header.id] = nil end - if not next (self.active) then self:closeall () end +function resolver:servfail(sock) + -- Resend all queries for this server + + local num = self.socketset[sock] + + -- Socket is dead now + self:voidsocket(sock); + + -- Find all requests to the down server, and retry on the next server + self.time = socket.gettime(); + for id,queries in pairs(self.active) do + for question,o in pairs(queries) do + if o.server == num then -- This request was to the broken server + o.server = o.server + 1 -- Use next server + if o.server > #self.server then + o.server = 1; + end + + o.retries = (o.retries or 0) + 1; + if o.retries >= #self.server then + --print('timeout'); + queries[question] = nil; + else + local _a = self:getsocket(o.server); + if _a then _a:send(o.packet); end + end + end + end + end + + if num == self.best_server then + self.best_server = self.best_server + 1; + if self.best_server > #self.server then + -- Exhausted all servers, try first again + self.best_server = 1; + end + end +end - -- was the query on the wanted list? - local q = response.question - local cos = get (self.wanted, q.class, q.type, q.name) - if cos then - for co in pairs (cos) do - set (self.yielded, co, q.class, q.type, q.name, nil) - if coroutine.status(co) == "suspended" then coroutine.resume (co) end - end - set (self.wanted, q.class, q.type, q.name, nil) - end end end end end +function resolver:receive(rset) -- - - - - - - - - - - - - - - - - receive + --print('receive'); print(self.socket); + self.time = socket.gettime(); + rset = rset or self.socket; + + local response; + for i,sock in pairs(rset) do + + if self.socketset[sock] then + local packet = sock:receive(); + if packet then + response = self:decode(packet); + if response then + --print('received response'); + --self.print(response); + + for i,section in pairs({ 'answer', 'authority', 'additional' }) do + for j,rr in pairs(response[section]) do + self:remember(rr, response.question[1].type) + end + end + + -- retire the query + local queries = self.active[response.header.id]; + if queries[response.question.raw] then + queries[response.question.raw] = nil; + end + if not next(queries) then self.active[response.header.id] = nil; end + if not next(self.active) then self:closeall(); end + + -- was the query on the wanted list? + local q = response.question; + local cos = get(self.wanted, q.class, q.type, q.name); + if cos then + for co in pairs(cos) do + set(self.yielded, co, q.class, q.type, q.name, nil); + if coroutine.status(co) == "suspended" then coroutine.resume(co); end + end + set(self.wanted, q.class, q.type, q.name, nil); + end + end + end + end + end - return response - end + return response; +end function resolver:feed(sock, packet) - --print 'receive' print (self.socket) - self.time = socket.gettime () - - local response = self:decode (packet) - if response then - --print 'received response' - --self.print (response) - - for i,section in pairs { 'answer', 'authority', 'additional' } do - for j,rr in pairs (response[section]) do - self:remember (rr, response.question[1].type) - end - end - - -- retire the query - local queries = self.active[response.header.id] - if queries[response.question.raw] then - queries[response.question.raw] = nil - end - if not next (queries) then self.active[response.header.id] = nil end - if not next (self.active) then self:closeall () end - - -- was the query on the wanted list? - local q = response.question[1] - if q then - local cos = get (self.wanted, q.class, q.type, q.name) - if cos then - for co in pairs (cos) do - set (self.yielded, co, q.class, q.type, q.name, nil) - if coroutine.status(co) == "suspended" then coroutine.resume (co) end - end - set (self.wanted, q.class, q.type, q.name, nil) - end - end - end - - return response + --print('receive'); print(self.socket); + self.time = socket.gettime(); + + local response = self:decode(packet); + if response then + --print('received response'); + --self.print(response); + + for i,section in pairs({ 'answer', 'authority', 'additional' }) do + for j,rr in pairs(response[section]) do + self:remember(rr, response.question[1].type); + end + end + + -- retire the query + local queries = self.active[response.header.id]; + if queries[response.question.raw] then + queries[response.question.raw] = nil; + end + if not next(queries) then self.active[response.header.id] = nil; end + if not next(self.active) then self:closeall(); end + + -- was the query on the wanted list? + local q = response.question[1]; + if q then + local cos = get(self.wanted, q.class, q.type, q.name); + if cos then + for co in pairs(cos) do + set(self.yielded, co, q.class, q.type, q.name, nil); + if coroutine.status(co) == "suspended" then coroutine.resume(co); end + end + set(self.wanted, q.class, q.type, q.name, nil); + end + end + end + + return response; end function resolver:cancel(data) - local cos = get (self.wanted, unpack(data, 1, 3)) + local cos = get(self.wanted, unpack(data, 1, 3)); if cos then cos[data[4]] = nil; end end -function resolver:pulse () -- - - - - - - - - - - - - - - - - - - - - pulse +function resolver:pulse() -- - - - - - - - - - - - - - - - - - - - - pulse + --print(':pulse'); + while self:receive() do end + if not next(self.active) then return nil; end + + self.time = socket.gettime(); + for id,queries in pairs(self.active) do + for question,o in pairs(queries) do + if self.time >= o.retry then + + o.server = o.server + 1; + if o.server > #self.server then + o.server = 1; + o.delay = o.delay + 1; + end + + if o.delay > #self.delays then + --print('timeout'); + queries[question] = nil; + if not next(queries) then self.active[id] = nil; end + if not next(self.active) then return nil; end + else + --print('retry', o.server, o.delay); + local _a = self.socket[o.server]; + if _a then _a:send(o.packet); end + o.retry = self.time + self.delays[o.delay]; + end + end + end + end - --print ':pulse' - while self:receive() do end - if not next (self.active) then return nil end + if next(self.active) then return true; end + return nil; +end - self.time = socket.gettime () - for id,queries in pairs (self.active) do - for question,o in pairs (queries) do - if self.time >= o.retry then - o.server = o.server + 1 - if o.server > #self.server then - o.server = 1 - o.delay = o.delay + 1 - end +function resolver:lookup(qname, qtype, qclass) -- - - - - - - - - - lookup + self:query (qname, qtype, qclass) + while self:pulse() do socket.select(self.socket, nil, 4); end + --print(self.cache); + return self:peek(qname, qtype, qclass); +end - if o.delay > #self.delays then - --print ('timeout') - queries[question] = nil - if not next (queries) then self.active[id] = nil end - if not next (self.active) then return nil end - else - --print ('retry', o.server, o.delay) - local _a = self.socket[o.server]; - if _a then _a:send (o.packet) end - o.retry = self.time + self.delays[o.delay] - end end end end +function resolver:lookupex(handler, qname, qtype, qclass) -- - - - - - - - - - lookup + return self:peek(qname, qtype, qclass) or self:query(qname, qtype, qclass); +end - if next (self.active) then return true end - return nil - end +--print ---------------------------------------------------------------- print -function resolver:lookup (qname, qtype, qclass) -- - - - - - - - - - lookup - self:query (qname, qtype, qclass) - while self:pulse () do socket.select (self.socket, nil, 4) end - --print (self.cache) - return self:peek (qname, qtype, qclass) - end -function resolver:lookupex (handler, qname, qtype, qclass) -- - - - - - - - - - lookup - return self:peek (qname, qtype, qclass) or self:query (qname, qtype, qclass) - end +local hints = { -- - - - - - - - - - - - - - - - - - - - - - - - - - - hints + qr = { [0]='query', 'response' }, + opcode = { [0]='query', 'inverse query', 'server status request' }, + aa = { [0]='non-authoritative', 'authoritative' }, + tc = { [0]='complete', 'truncated' }, + rd = { [0]='recursion not desired', 'recursion desired' }, + ra = { [0]='recursion not available', 'recursion available' }, + z = { [0]='(reserved)' }, + rcode = { [0]='no error', 'format error', 'server failure', 'name error', 'not implemented' }, + + type = dns.type, + class = dns.class +}; + + +local function hint(p, s) -- - - - - - - - - - - - - - - - - - - - - - hint + return (hints[s] and hints[s][p[s]]) or ''; +end ---print ---------------------------------------------------------------- print +function resolver.print(response) -- - - - - - - - - - - - - resolver.print + for s,s in pairs { 'id', 'qr', 'opcode', 'aa', 'tc', 'rd', 'ra', 'z', + 'rcode', 'qdcount', 'ancount', 'nscount', 'arcount' } do + print( string.format('%-30s', 'header.'..s), response.header[s], hint(response.header, s) ); + end + for i,question in ipairs(response.question) do + print(string.format ('question[%i].name ', i), question.name); + print(string.format ('question[%i].type ', i), question.type); + print(string.format ('question[%i].class ', i), question.class); + end -local hints = { -- - - - - - - - - - - - - - - - - - - - - - - - - - - hints - qr = { [0]='query', 'response' }, - opcode = { [0]='query', 'inverse query', 'server status request' }, - aa = { [0]='non-authoritative', 'authoritative' }, - tc = { [0]='complete', 'truncated' }, - rd = { [0]='recursion not desired', 'recursion desired' }, - ra = { [0]='recursion not available', 'recursion available' }, - z = { [0]='(reserved)' }, - rcode = { [0]='no error', 'format error', 'server failure', 'name error', - 'not implemented' }, - - type = dns.type, - class = dns.class, } - - -local function hint (p, s) -- - - - - - - - - - - - - - - - - - - - - - hint - return (hints[s] and hints[s][p[s]]) or '' end - - -function resolver.print (response) -- - - - - - - - - - - - - resolver.print - - for s,s in pairs { 'id', 'qr', 'opcode', 'aa', 'tc', 'rd', 'ra', 'z', - 'rcode', 'qdcount', 'ancount', 'nscount', 'arcount' } do - print ( string.format ('%-30s', 'header.'..s), - response.header[s], hint (response.header, s) ) - end - - for i,question in ipairs (response.question) do - print (string.format ('question[%i].name ', i), question.name) - print (string.format ('question[%i].type ', i), question.type) - print (string.format ('question[%i].class ', i), question.class) - end - - local common = { name=1, type=1, class=1, ttl=1, rdlength=1, rdata=1 } - local tmp - for s,s in pairs {'answer', 'authority', 'additional'} do - for i,rr in pairs (response[s]) do - for j,t in pairs { 'name', 'type', 'class', 'ttl', 'rdlength' } do - tmp = string.format ('%s[%i].%s', s, i, t) - print (string.format ('%-30s', tmp), rr[t], hint (rr, t)) - end - for j,t in pairs (rr) do - if not common[j] then - tmp = string.format ('%s[%i].%s', s, i, j) - print (string.format ('%-30s %s', tostring(tmp), tostring(t))) - end end end end end + local common = { name=1, type=1, class=1, ttl=1, rdlength=1, rdata=1 }; + local tmp; + for s,s in pairs({'answer', 'authority', 'additional'}) do + for i,rr in pairs(response[s]) do + for j,t in pairs({ 'name', 'type', 'class', 'ttl', 'rdlength' }) do + tmp = string.format('%s[%i].%s', s, i, t); + print(string.format('%-30s', tmp), rr[t], hint(rr, t)); + end + for j,t in pairs(rr) do + if not common[j] then + tmp = string.format('%s[%i].%s', s, i, j); + print(string.format('%-30s %s', tostring(tmp), tostring(t))); + end + end + end + end +end -- module api ------------------------------------------------------ module api -local function resolve (func, ...) -- - - - - - - - - - - - - - resolver_get - dns._resolver = dns._resolver or dns.resolver () - return func (dns._resolver, ...) - end +local function resolve(func, ...) -- - - - - - - - - - - - - - resolver_get + return func(dns._resolver, ...); +end function dns.resolver () -- - - - - - - - - - - - - - - - - - - - - resolver + -- this function seems to be redundant with resolver.new () - -- this function seems to be redundant with resolver.new () - - local r = { active = {}, cache = {}, unsorted = {}, wanted = {}, yielded = {} } - setmetatable (r, resolver) - setmetatable (r.cache, cache_metatable) - setmetatable (r.unsorted, { __mode = 'kv' }) - return r - end + local r = { active = {}, cache = {}, unsorted = {}, wanted = {}, yielded = {}, best_server = 1 }; + setmetatable (r, resolver); + setmetatable (r.cache, cache_metatable); + setmetatable (r.unsorted, { __mode = 'kv' }); + return r; +end -function dns.lookup (...) -- - - - - - - - - - - - - - - - - - - - - lookup - return resolve (resolver.lookup, ...) end +function dns.lookup(...) -- - - - - - - - - - - - - - - - - - - - - lookup + return resolve(resolver.lookup, ...); +end -function dns.purge (...) -- - - - - - - - - - - - - - - - - - - - - - purge - return resolve (resolver.purge, ...) end +function dns.purge(...) -- - - - - - - - - - - - - - - - - - - - - - purge + return resolve(resolver.purge, ...); +end -function dns.peek (...) -- - - - - - - - - - - - - - - - - - - - - - - peek - return resolve (resolver.peek, ...) end +function dns.peek(...) -- - - - - - - - - - - - - - - - - - - - - - - peek + return resolve(resolver.peek, ...); +end -function dns.query (...) -- - - - - - - - - - - - - - - - - - - - - - query - return resolve (resolver.query, ...) end +function dns.query(...) -- - - - - - - - - - - - - - - - - - - - - - query + return resolve(resolver.query, ...); +end -function dns.feed (...) -- - - - - - - - - - - - - - - - - - - - - - feed - return resolve (resolver.feed, ...) end +function dns.feed(...) -- - - - - - - - - - - - - - - - - - - - - - feed + return resolve(resolver.feed, ...); +end function dns.cancel(...) -- - - - - - - - - - - - - - - - - - - - - - cancel - return resolve(resolver.cancel, ...) end + return resolve(resolver.cancel, ...); +end -function dns:socket_wrapper_set (...) -- - - - - - - - - socket_wrapper_set - return resolve (resolver.socket_wrapper_set, ...) end +function dns:socket_wrapper_set(...) -- - - - - - - - - socket_wrapper_set + return resolve(resolver.socket_wrapper_set, ...); +end +dns._resolver = dns.resolver(); -return dns +return dns; diff --git a/net/httpserver.lua b/net/httpserver.lua index 57c8eede..ddb4475c 100644 --- a/net/httpserver.lua +++ b/net/httpserver.lua @@ -61,7 +61,7 @@ local function send_response(request, response) end else -- Response we have is just a string (the body) - log("debug", "Sending response to %s: %s", request.id or "<none>", response or "<none>"); + log("debug", "Sending 200 response to %s", request.id or "<none>"); resp = { "HTTP/1.0 200 OK\r\n" }; t_insert(resp, "Connection: close\r\n"); @@ -89,9 +89,6 @@ local function call_callback(request, err) end callback = (request.server and request.server.handlers[base]) or default_handler; - if callback == default_handler then - log("debug", "Default callback for this request (base: "..tostring(base)..")") - end end if callback then if err then @@ -233,7 +230,7 @@ function destroy_request(request) end request.handler.close() if request.conn then - listener.disconnect(request.conn, "closed"); + listener.disconnect(request.handler, "closed"); end end end @@ -251,13 +248,27 @@ function new(params) end end -function new_from_config(ports, default_base, handle_request) +function set_default_handler(handler) + default_handler = handler; +end + +function new_from_config(ports, handle_request, default_options) + if type(handle_request) == "string" then -- COMPAT with old plugins + log("warn", "Old syntax of httpserver.new_from_config being used to register %s", handle_request); + handle_request, default_options = default_options, { base = handle_request }; + end for _, options in ipairs(ports) do - local port, base, ssl, interface = 5280, default_base, false, nil; + local port = default_options.port or 5280; + local base = default_options.base; + local ssl = default_options.ssl or false; + local interface = default_options.interface; if type(options) == "number" then port = options; elseif type(options) == "table" then - port, base, ssl, interface = options.port or 5280, options.path or default_base, options.ssl or false, options.interface; + port = options.port or port; + base = options.path or base; + ssl = options.ssl or ssl; + interface = options.interface or interface; elseif type(options) == "string" then base = options; end @@ -267,7 +278,9 @@ function new_from_config(ports, default_base, handle_request) ssl.protocol = "sslv23"; end - new{ port = port, base = base, handler = handle_request, ssl = ssl, type = (ssl and "ssl") or "tcp" } + new{ port = port, interface = interface, + base = base, handler = handle_request, + ssl = ssl, type = (ssl and "ssl") or "tcp" }; end end diff --git a/net/server.lua b/net/server.lua index 966006c1..6ab8ce91 100644 --- a/net/server.lua +++ b/net/server.lua @@ -157,6 +157,7 @@ _cleanqueue = false -- clean bufferqueue after using _maxclientsperserver = 1000
+_maxsslhandshake = 30 -- max handshake round-trips
----------------------------------// PRIVATE //--
wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxconnections, startssl ) -- this function wraps a server
@@ -230,6 +231,9 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco handler.ssl = function( )
return ssl
end
+ handler.sslctx = function( )
+ return sslctx
+ end
handler.remove = function( )
connections = connections - 1
end
@@ -246,7 +250,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco _socketlist[ socket ] = nil
handler = nil
socket = nil
- mem_free( )
+ --mem_free( )
out_put "server.lua: closed server handler and removed sockets from list"
end
handler.ip = function( )
@@ -297,6 +301,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport local ssl
local dispatch = listeners.incoming or listeners.listener
+ local status = listeners.status
local disconnect = listeners.disconnect
local bufferqueue = { } -- buffer array
@@ -336,6 +341,9 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport handler.ssl = function( )
return ssl
end
+ handler.sslctx = function ( )
+ return sslctx
+ end
handler.send = function( _, data, i, j )
return send( socket, data, i, j )
end
@@ -363,17 +371,20 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport send( socket, table_concat( bufferqueue, "", 1, bufferqueuelen ), 1, bufferlen ) -- forced send
end
end
- _ = shutdown and shutdown( socket )
- socket:close( )
- _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
- _socketlist[ socket ] = nil
+ if socket then
+ _ = shutdown and shutdown( socket )
+ socket:close( )
+ _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
+ _socketlist[ socket ] = nil
+ socket = nil
+ else
+ out_put "server.lua: socket already closed"
+ end
if handler then
_writetimes[ handler ] = nil
_closelist[ handler ] = nil
handler = nil
end
- socket = nil
- mem_free( )
if server then
server.remove( )
end
@@ -396,9 +407,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport handler.write = idfalse -- dont write anymore
return false
elseif socket and not _sendlist[ socket ] then
- _sendlistlen = _sendlistlen + 1
- _sendlist[ _sendlistlen ] = socket
- _sendlist[ socket ] = _sendlistlen
+ _sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
end
bufferqueuelen = bufferqueuelen + 1
bufferqueue[ bufferqueuelen ] = data
@@ -446,9 +455,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport handler.write = write
if noread then
noread = false
- _readlistlen = _readlistlen + 1
- _readlist[ socket ] = _readlistlen
- _readlist[ _readlistlen ] = socket
+ _readlistlen = addsocket(_readlist, socket, _readlistlen)
_readtimes[ handler ] = _currenttime
end
if nosend then
@@ -472,10 +479,10 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport readtraffic = readtraffic + count
_readtraffic = _readtraffic + count
_readtimes[ handler ] = _currenttime
- --out_put( "server.lua: read data '", buffer, "', error: ", err )
+ --out_put( "server.lua: read data '", buffer:gsub("[^%w%p ]", "."), "', error: ", err )
return dispatch( handler, buffer, err )
else -- connections was closed or fatal error
- out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " error: ", tostring(err) )
+ out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) )
fatalerror = true
disconnect( handler, err )
_ = handler and handler.close( )
@@ -483,13 +490,19 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport end
end
local _sendbuffer = function( ) -- this function sends data
- local buffer = table_concat( bufferqueue, "", 1, bufferqueuelen )
- local succ, err, byte = send( socket, buffer, 1, bufferlen )
- local count = ( succ or byte or 0 ) * STAT_UNIT
- sendtraffic = sendtraffic + count
- _sendtraffic = _sendtraffic + count
- _ = _cleanqueue and clean( bufferqueue )
- --out_put( "server.lua: sended '", buffer, "', bytes: ", tostring(succ), ", error: ", tostring(err), ", part: ", tostring(byte), ", to: ", tostring(ip), ":", tostring(clientport) )
+ local succ, err, byte, buffer, count;
+ local count;
+ if socket then
+ buffer = table_concat( bufferqueue, "", 1, bufferqueuelen )
+ succ, err, byte = send( socket, buffer, 1, bufferlen )
+ count = ( succ or byte or 0 ) * STAT_UNIT
+ sendtraffic = sendtraffic + count
+ _sendtraffic = _sendtraffic + count
+ _ = _cleanqueue and clean( bufferqueue )
+ --out_put( "server.lua: sended '", buffer, "', bytes: ", tostring(succ), ", error: ", tostring(err), ", part: ", tostring(byte), ", to: ", tostring(ip), ":", tostring(clientport) )
+ else
+ succ, err, count = false, "closed", 0;
+ end
if succ then -- sending succesful
bufferqueuelen = 0
bufferlen = 0
@@ -506,7 +519,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport _writetimes[ handler ] = _currenttime
return true
else -- connection was closed during sending or fatal error
- out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " error: ", tostring(err) )
+ out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) )
fatalerror = true
disconnect( handler, err )
_ = handler and handler.close( )
@@ -514,38 +527,40 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport end
end
- if sslctx then -- ssl?
+ -- Set the sslctx
+ local handshake;
+ function handler.set_sslctx(new_sslctx)
ssl = true
+ sslctx = new_sslctx;
local wrote
local read
- local handshake = coroutine_wrap( function( client ) -- create handshake coroutine
+ handshake = coroutine_wrap( function( client ) -- create handshake coroutine
local err
- for i = 1, 10 do -- 10 handshake attemps
- _sendlistlen = ( wrote and removesocket( _sendlist, socket, _sendlistlen ) ) or _sendlistlen
- _readlistlen = ( read and removesocket( _readlist, socket, _readlistlen ) ) or _readlistlen
+ for i = 1, _maxsslhandshake do
+ _sendlistlen = ( wrote and removesocket( _sendlist, client, _sendlistlen ) ) or _sendlistlen
+ _readlistlen = ( read and removesocket( _readlist, client, _readlistlen ) ) or _readlistlen
read, wrote = nil, nil
_, err = client:dohandshake( )
if not err then
out_put( "server.lua: ssl handshake done" )
handler.readbuffer = _readbuffer -- when handshake is done, replace the handshake function with regular functions
handler.sendbuffer = _sendbuffer
- -- return dispatch( handler )
+ _ = status and status( handler, "ssl-handshake-complete" )
+ _readlistlen = addsocket(_readlist, client, _readlistlen)
return true
else
- out_put( "server.lua: error during ssl handshake: ", tostring(err) )
- if err == "wantwrite" and not wrote then
- _sendlistlen = _sendlistlen + 1
- _sendlist[ _sendlistlen ] = client
- wrote = true
- elseif err == "wantread" and not read then
- _readlistlen = _readlistlen + 1
- _readlist [ _readlistlen ] = client
- read = true
- else
- break;
- end
- --coroutine_yield( handler, nil, err ) -- handshake not finished
- coroutine_yield( )
+ out_put( "server.lua: error during ssl handshake: ", tostring(err) )
+ if err == "wantwrite" and not wrote then
+ _sendlistlen = addsocket(_sendlist, client, _sendlistlen)
+ wrote = true
+ elseif err == "wantread" and not read then
+ _readlistlen = addsocket(_readlist, client, _readlistlen)
+ read = true
+ else
+ break;
+ end
+ --coroutine_yield( handler, nil, err ) -- handshake not finished
+ coroutine_yield( )
end
end
disconnect( handler, "ssl handshake failed" )
@@ -553,13 +568,16 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport return false -- handshake failed
end
)
+ end
+ if sslctx then -- ssl?
+ handler.set_sslctx(sslctx);
if startssl then -- ssl now?
--out_put("server.lua: ", "starting ssl handshake")
local err
socket, err = ssl_wrap( socket, sslctx ) -- wrap socket
if err then
out_put( "server.lua: ssl error: ", tostring(err) )
- mem_free( )
+ --mem_free( )
return nil, nil, err -- fatal error
end
socket:settimeout( 0 )
@@ -596,9 +614,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport shutdown = id
_socketlist[ socket ] = handler
- _readlistlen = _readlistlen + 1
- _readlist[ _readlistlen ] = socket
- _readlist[ socket ] = _readlistlen
+ _readlistlen = addsocket(_readlist, socket, _readlistlen)
-- remove traces of the old socket
@@ -630,9 +646,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport shutdown = ( ssl and id ) or socket.shutdown
_socketlist[ socket ] = handler
- _readlistlen = _readlistlen + 1
- _readlist[ _readlistlen ] = socket
- _readlist[ socket ] = _readlistlen
+ _readlistlen = addsocket(_readlist, socket, _readlistlen)
return handler, socket
end
@@ -644,6 +658,15 @@ idfalse = function( ) return false
end
+addsocket = function( list, socket, len )
+ if not list[ socket ] then
+ len = len + 1
+ list[ len ] = socket
+ list[ socket ] = len
+ end
+ return len;
+end
+
removesocket = function( list, socket, len ) -- this function removes sockets from a list ( copied from copas )
local pos = list[ socket ]
if pos then
@@ -664,7 +687,7 @@ closesocket = function( socket ) _readlistlen = removesocket( _readlist, socket, _readlistlen )
_socketlist[ socket ] = nil
socket:close( )
- mem_free( )
+ --mem_free( )
end
----------------------------------// PUBLIC //--
@@ -698,8 +721,7 @@ addserver = function( listeners, port, addr, pattern, sslctx, maxconnections, st return nil, err
end
server:settimeout( 0 )
- _readlistlen = _readlistlen + 1
- _readlist[ _readlistlen ] = server
+ _readlistlen = addsocket(_readlist, server, _readlistlen)
_server[ port ] = handler
_socketlist[ server ] = handler
out_put( "server.lua: new server listener on '", addr, ":", port, "'" )
@@ -713,7 +735,7 @@ end removeserver = function( port )
local handler = _server[ port ]
if not handler then
- return nil, "no server found on port '" .. tostring( port ) "'"
+ return nil, "no server found on port '" .. tostring( port ) .. "'"
end
handler.close( )
_server[ port ] = nil
@@ -733,11 +755,11 @@ closeall = function( ) _sendlist = { }
_timerlist = { }
_socketlist = { }
- mem_free( )
+ --mem_free( )
end
getsettings = function( )
- return _selecttimeout, _sleeptime, _maxsendlen, _maxreadlen, _checkinterval, _sendtimeout, _readtimeout, _cleanqueue, _maxclientsperserver
+ return _selecttimeout, _sleeptime, _maxsendlen, _maxreadlen, _checkinterval, _sendtimeout, _readtimeout, _cleanqueue, _maxclientsperserver, _maxsslhandshake
end
changesettings = function( new )
@@ -753,6 +775,7 @@ changesettings = function( new ) _readtimeout = tonumber( new.readtimeout ) or _readtimeout
_cleanqueue = new.cleanqueue
_maxclientsperserver = new._maxclientsperserver or _maxclientsperserver
+ _maxsslhandshake = new._maxsslhandshake or _maxsslhandshake
return true
end
@@ -805,7 +828,7 @@ loop = function( ) -- this is the main loop of the program _currenttime = os_time( )
if os_difftime( _currenttime - _timer ) >= 1 then
for i = 1, _timerlistlen do
- _timerlist[ i ]( ) -- fire timers
+ _timerlist[ i ]( _currenttime ) -- fire timers
end
_timer = _currenttime
end
@@ -820,9 +843,7 @@ end local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx, startssl )
local handler = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx, startssl )
_socketlist[ socket ] = handler
- _sendlistlen = _sendlistlen + 1
- _sendlist[ _sendlistlen ] = socket
- _sendlist[ socket ] = _sendlistlen
+ _sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
return handler, socket
end
diff --git a/net/xmppclient_listener.lua b/net/xmppclient_listener.lua index dcc561f3..417dfd4a 100644 --- a/net/xmppclient_listener.lua +++ b/net/xmppclient_listener.lua @@ -27,7 +27,7 @@ local sm_streamopened = sessionmanager.streamopened; local sm_streamclosed = sessionmanager.streamclosed; local st = require "util.stanza"; -local stream_callbacks = { stream_tag = "http://etherx.jabber.org/streams|stream", +local stream_callbacks = { stream_tag = "http://etherx.jabber.org/streams\1stream", default_ns = "jabber:client", streamopened = sm_streamopened, streamclosed = sm_streamclosed, handlestanza = core_process_stanza }; @@ -53,7 +53,7 @@ local xmppclient = { default_port = 5222, default_mode = "*a" }; local function session_reset_stream(session) -- Reset stream - local parser = lxp.new(init_xmlhandlers(session, stream_callbacks), "|"); + local parser = lxp.new(init_xmlhandlers(session, stream_callbacks), "\1"); session.parser = parser; session.notopen = true; @@ -70,7 +70,7 @@ end local stream_xmlns_attr = {xmlns='urn:ietf:params:xml:ns:xmpp-streams'}; -local default_stream_attr = { ["xmlns:stream"] = stream_callbacks.stream_tag:gsub("%|[^|]+$", ""), xmlns = stream_callbacks.default_ns, version = "1.0", id = "" }; +local default_stream_attr = { ["xmlns:stream"] = stream_callbacks.stream_tag:match("[^\1]*"), xmlns = stream_callbacks.default_ns, version = "1.0", id = "" }; local function session_close(session, reason) local log = session.log or log; if session.conn then @@ -114,11 +114,6 @@ function xmppclient.listener(conn, data) session = sm_new_session(conn); sessions[conn] = session; - -- Logging functions -- - - local conn_name = "c2s"..tostring(conn):match("[a-f0-9]+$"); - session.log = logger.init(conn_name); - session.log("info", "Client connected"); -- Client is using legacy SSL (otherwise mod_tls sets this flag) diff --git a/net/xmppcomponent_listener.lua b/net/xmppcomponent_listener.lua index bee05967..c16f41a0 100644 --- a/net/xmppcomponent_listener.lua +++ b/net/xmppcomponent_listener.lua @@ -32,7 +32,7 @@ local xmlns_component = 'jabber:component:accept'; --- Callbacks/data for xmlhandlers to handle streams for us --- -local stream_callbacks = { stream_tag = "http://etherx.jabber.org/streams|stream", default_ns = xmlns_component }; +local stream_callbacks = { stream_tag = "http://etherx.jabber.org/streams\1stream", default_ns = xmlns_component }; function stream_callbacks.error(session, error, data, data2) log("warn", "Error processing component stream: "..tostring(error)); @@ -87,7 +87,7 @@ end --- Closing a component connection local stream_xmlns_attr = {xmlns='urn:ietf:params:xml:ns:xmpp-streams'}; -local default_stream_attr = { ["xmlns:stream"] = stream_callbacks.stream_tag:gsub("%|[^|]+$", ""), xmlns = stream_callbacks.default_ns, version = "1.0", id = "" }; +local default_stream_attr = { ["xmlns:stream"] = stream_callbacks.stream_tag:match("[^\1]*"), xmlns = stream_callbacks.default_ns, version = "1.0", id = "" }; local function session_close(session, reason) local log = session.log or log; if session.conn then @@ -138,7 +138,7 @@ function component_listener.listener(conn, data) session.log("info", "Incoming Jabber component connection"); - local parser = lxp.new(init_xmlhandlers(session, stream_callbacks), "|"); + local parser = lxp.new(init_xmlhandlers(session, stream_callbacks), "\1"); session.parser = parser; session.notopen = true; diff --git a/net/xmppserver_listener.lua b/net/xmppserver_listener.lua index 1f27d841..c7e02ec5 100644 --- a/net/xmppserver_listener.lua +++ b/net/xmppserver_listener.lua @@ -17,7 +17,7 @@ local s2s_streamopened = require "core.s2smanager".streamopened; local s2s_streamclosed = require "core.s2smanager".streamclosed; local s2s_destroy_session = require "core.s2smanager".destroy_session; local s2s_attempt_connect = require "core.s2smanager".attempt_connection; -local stream_callbacks = { stream_tag = "http://etherx.jabber.org/streams|stream", +local stream_callbacks = { stream_tag = "http://etherx.jabber.org/streams\1stream", default_ns = "jabber:server", streamopened = s2s_streamopened, streamclosed = s2s_streamclosed, handlestanza = core_process_stanza }; @@ -53,7 +53,7 @@ local xmppserver = { default_port = 5269, default_mode = "*a" }; local function session_reset_stream(session) -- Reset stream - local parser = lxp.new(init_xmlhandlers(session, stream_callbacks), "|"); + local parser = lxp.new(init_xmlhandlers(session, stream_callbacks), "\1"); session.parser = parser; session.notopen = true; @@ -61,16 +61,16 @@ local function session_reset_stream(session) function session.data(conn, data) local ok, err = parser:parse(data); if ok then return; end - log("debug", "Received invalid XML (%s) %d bytes: %s", tostring(err), #data, data:sub(1, 300):gsub("[\r\n]+", " ")); + (session.log or log)("warn", "Received invalid XML: %s", data); + (session.log or log)("warn", "Problem was: %s", err); session:close("xml-not-well-formed"); end return true; end - local stream_xmlns_attr = {xmlns='urn:ietf:params:xml:ns:xmpp-streams'}; -local default_stream_attr = { ["xmlns:stream"] = stream_callbacks.stream_tag:gsub("%|[^|]+$", ""), xmlns = stream_callbacks.default_ns, version = "1.0", id = "" }; +local default_stream_attr = { ["xmlns:stream"] = stream_callbacks.stream_tag:match("[^\1]*"), xmlns = stream_callbacks.default_ns, version = "1.0", id = "" }; local function session_close(session, reason) local log = session.log or log; if session.conn then @@ -100,6 +100,9 @@ local function session_close(session, reason) end end session.sends2s("</stream:stream>"); + if session.notopen or not session.conn.close() then + session.conn.close(true); -- Force FIXME: timer? + end session.conn.close(); xmppserver.disconnect(session.conn, "stream error"); end @@ -134,6 +137,17 @@ function xmppserver.listener(conn, data) end end +function xmppserver.status(conn, status) + if status == "ssl-handshake-complete" then + local session = sessions[conn]; + if session and session.direction == "outgoing" then + local format, to_host, from_host = string.format, session.to_host, session.from_host; + session.log("debug", "Sending stream header..."); + session.sends2s(format([[<stream:stream xmlns='jabber:server' xmlns:db='jabber:server:dialback' xmlns:stream='http://etherx.jabber.org/streams' from='%s' to='%s' version='1.0'>]], from_host, to_host)); + end + end +end + function xmppserver.disconnect(conn, err) local session = sessions[conn]; if session then |