diff options
Diffstat (limited to 'net')
-rw-r--r-- | net/adns.lua | 78 | ||||
-rw-r--r-- | net/connlisteners.lua | 68 | ||||
-rw-r--r-- | net/dns.lua | 1574 | ||||
-rw-r--r-- | net/http.lua | 246 | ||||
-rw-r--r-- | net/http/codes.lua | 67 | ||||
-rw-r--r-- | net/http/parser.lua | 160 | ||||
-rw-r--r-- | net/http/server.lua | 303 | ||||
-rw-r--r-- | net/httpclient_listener.lua | 44 | ||||
-rw-r--r-- | net/httpserver.lua | 279 | ||||
-rw-r--r-- | net/httpserver_listener.lua | 46 | ||||
-rw-r--r-- | net/server.lua | 977 | ||||
-rw-r--r-- | net/server_event.lua | 872 | ||||
-rw-r--r-- | net/server_select.lua | 984 | ||||
-rw-r--r-- | net/xmppclient_listener.lua | 152 | ||||
-rw-r--r-- | net/xmppcomponent_listener.lua | 176 | ||||
-rw-r--r-- | net/xmppserver_listener.lua | 174 |
16 files changed, 3549 insertions, 2651 deletions
diff --git a/net/adns.lua b/net/adns.lua index 34ef5d77..cd69a627 100644 --- a/net/adns.lua +++ b/net/adns.lua @@ -1,6 +1,6 @@ -- Prosody IM --- Copyright (C) 2008-2009 Matthew Wild --- Copyright (C) 2008-2009 Waqas Hussain +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain -- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. @@ -11,8 +11,11 @@ local dns = require "net.dns"; local log = require "util.logger".init("adns"); +local t_insert, t_remove = table.insert, table.remove; local coroutine, tostring, pcall = coroutine, tostring, pcall; +local function dummy_send(sock, data, i, j) return (j-i)+1; end + module "adns" function lookup(handler, qname, qtype, qclass) @@ -23,41 +26,66 @@ function lookup(handler, qname, qtype, qclass) return; end log("debug", "Records for %s not in cache, sending query (%s)...", qname, tostring(coroutine.running())); - dns.query(qname, qtype, qclass); - coroutine.yield({ qclass or "IN", qtype or "A", qname, coroutine.running()}); -- Wait for reply - log("debug", "Reply for %s (%s)", qname, tostring(coroutine.running())); - local ok, err = pcall(handler, dns.peek(qname, qtype, qclass)); + local ok, err = dns.query(qname, qtype, qclass); + if ok then + coroutine.yield({ qclass or "IN", qtype or "A", qname, coroutine.running()}); -- Wait for reply + log("debug", "Reply for %s (%s)", qname, tostring(coroutine.running())); + end + if ok then + ok, err = pcall(handler, dns.peek(qname, qtype, qclass)); + else + log("error", "Error sending DNS query: %s", err); + ok, err = pcall(handler, nil, err); + end if not ok then - log("debug", "Error in DNS response handler: %s", tostring(err)); + log("error", "Error in DNS response handler: %s", tostring(err)); end end)(dns.peek(qname, qtype, qclass)); end -function cancel(handle, call_handler) +function cancel(handle, call_handler, reason) log("warn", "Cancelling DNS lookup for %s", tostring(handle[3])); - dns.cancel(handle); - if call_handler then - coroutine.resume(handle[4]); - end + dns.cancel(handle[1], handle[2], handle[3], handle[4], call_handler); end -function new_async_socket(sock) - local newconn = {}; +function new_async_socket(sock, resolver) + local peername = "<unknown>"; local listener = {}; - function listener.incoming(conn, data) - dns.feed(sock, data); + local handler = {}; + function listener.onincoming(conn, data) + if data then + dns.feed(handler, data); + end + end + function listener.ondisconnect(conn, err) + if err then + log("warn", "DNS socket for %s disconnected: %s", peername, err); + local servers = resolver.server; + if resolver.socketset[conn] == resolver.best_server and resolver.best_server == #servers then + log("error", "Exhausted all %d configured DNS servers, next lookup will try %s again", #servers, servers[1]); + end + + resolver:servfail(conn); -- Let the magic commence + end + end + handler = server.wrapclient(sock, "dns", 53, listener); + if not handler then + log("warn", "handler is nil"); end - function listener.disconnect() + + handler.settimeout = function () end + handler.setsockname = function (_, ...) return sock:setsockname(...); end + handler.setpeername = function (_, ...) peername = (...); local ret = sock:setpeername(...); _:set_send(dummy_send); return ret; end + handler.connect = function (_, ...) return sock:connect(...) end + --handler.send = function (_, data) _:write(data); return _.sendbuffer and _.sendbuffer(); end + handler.send = function (_, data) + local getpeername = sock.getpeername; + log("debug", "Sending DNS query to %s", (getpeername and getpeername(sock)) or "<unconnected>"); + return sock:send(data); end - newconn.handler, newconn._socket = server.wrapclient(sock, "dns", 53, listener); - newconn.handler.settimeout = function () end - newconn.handler.setsockname = function (_, ...) return sock:setsockname(...); end - newconn.handler.setpeername = function (_, ...) local ret = sock:setpeername(...); _.setsend(sock.send); return ret; end - newconn.handler.connect = function (_, ...) return sock:connect(...) end - newconn.handler.send = function (_, data) _.write(data); return _.sendbuffer(); end - return newconn.handler; + return handler; end -dns:socket_wrapper_set(new_async_socket); +dns.socket_wrapper_set(new_async_socket); return _M; diff --git a/net/connlisteners.lua b/net/connlisteners.lua index ebb3cc18..99ddc720 100644 --- a/net/connlisteners.lua +++ b/net/connlisteners.lua @@ -1,65 +1,15 @@ --- Prosody IM --- Copyright (C) 2008-2009 Matthew Wild --- Copyright (C) 2008-2009 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- +-- COMPAT w/pre-0.9 +local log = require "util.logger".init("net.connlisteners"); +local traceback = debug.traceback; +module "httpserver" - -local listeners_dir = (CFG_SOURCEDIR or ".").."/net/"; -local server = require "net.server"; -local log = require "util.logger".init("connlisteners"); - -local dofile, pcall, error = - dofile, pcall, error - -module "connlisteners" - -local listeners = {}; - -function register(name, listener) - if listeners[name] and listeners[name] ~= listener then - log("debug", "Listener %s is already registered, not registering any more", name); - return false; - end - listeners[name] = listener; - log("debug", "Registered connection listener %s", name); - return true; +function fail() + log("error", "Attempt to use legacy connlisteners API. For more info see http://prosody.im/doc/developers/network"); + log("error", "Legacy connlisteners API usage, %s", traceback("", 2)); end -function deregister(name) - listeners[name] = nil; -end - -function get(name) - local h = listeners[name]; - if not h then - local ok, ret = pcall(dofile, listeners_dir..name:gsub("[^%w%-]", "_").."_listener.lua"); - if not ok then return nil, ret; end - h = listeners[name]; - end - return h; -end - -function start(name, udata) - local h, err = get(name); - if not h then - error("No such connection module: "..name.. (err and (" ("..err..")") or ""), 0); - end - - if udata then - if (udata.type == "ssl" or udata.type == "tls") and not udata.ssl then - error("No SSL context supplied for a "..tostring(udata.type):upper().." connection!", 0); - elseif udata.ssl and udata.type == "tcp" then - error("SSL context supplied for a TCP connection!", 0); - end - end - - return server.addserver(h, - (udata and udata.port) or h.default_port or error("Can't start listener "..name.." because no port was specified, and it has no default port", 0), - (udata and udata.interface) or h.default_interface or "*", (udata and udata.mode) or h.default_mode or 1, (udata and udata.ssl) or nil, 99999999, udata and udata.type == "ssl"); -end +register, deregister = fail, fail; +get, start = fail, fail, epic_fail; return _M; diff --git a/net/dns.lua b/net/dns.lua index 48c08218..c9c51fe8 100644 --- a/net/dns.lua +++ b/net/dns.lua @@ -2,8 +2,6 @@ -- This file is included with Prosody IM. It has modifications, -- which are hereby placed in the public domain. --- public domain 20080404 lua@ztact.com - -- todo: quick (default) header generation -- todo: nxdomain, error handling @@ -14,21 +12,65 @@ -- reference: http://tools.ietf.org/html/rfc1876 (LOC) -require 'socket' -local ztact = require 'util.ztact' -local require = require - -local coroutine, io, math, socket, string, table = - coroutine, io, math, socket, string, table - -local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack = - ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack - -local get, set = ztact.get, ztact.set - +local socket = require "socket"; +local timer = require "util.timer"; + +local _, windows = pcall(require, "util.windows"); +local is_windows = (_ and windows) or os.getenv("WINDIR"); + +local coroutine, io, math, string, table = + coroutine, io, math, string, table; + +local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack, select, type= + ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack, select, type; + +local ztact = { -- public domain 20080404 lua@ztact.com + get = function(parent, ...) + local len = select('#', ...); + for i=1,len do + parent = parent[select(i, ...)]; + if parent == nil then break; end + end + return parent; + end; + set = function(parent, ...) + local len = select('#', ...); + local key, value = select(len-1, ...); + local cutpoint, cutkey; + + for i=1,len-2 do + local key = select (i, ...) + local child = parent[key] + + if value == nil then + if child == nil then + return; + elseif next(child, next(child)) then + cutpoint = nil; cutkey = nil; + elseif cutpoint == nil then + cutpoint = parent; cutkey = key; + end + elseif child == nil then + child = {}; + parent[key] = child; + end + parent = child + end + + if value == nil and cutpoint then + cutpoint[cutkey] = nil; + else + parent[key] = value; + return value; + end + end; +}; +local get, set = ztact.get, ztact.set; + +local default_timeout = 15; -------------------------------------------------- module dns -module ('dns') +module('dns') local dns = _M; @@ -38,826 +80,1000 @@ local dns = _M; local append = table.insert -local function highbyte (i) -- - - - - - - - - - - - - - - - - - - highbyte - return (i-(i%0x100))/0x100 - end +local function highbyte(i) -- - - - - - - - - - - - - - - - - - - highbyte + return (i-(i%0x100))/0x100; +end local function augment (t) -- - - - - - - - - - - - - - - - - - - - augment - local a = {} - for i,s in pairs (t) do a[i] = s a[s] = s a[string.lower (s)] = s end - return a - end + local a = {}; + for i,s in pairs(t) do + a[i] = s; + a[s] = s; + a[string.lower(s)] = s; + end + return a; +end local function encode (t) -- - - - - - - - - - - - - - - - - - - - - encode - local code = {} - for i,s in pairs (t) do - local word = string.char (highbyte (i), i %0x100) - code[i] = word - code[s] = word - code[string.lower (s)] = word - end - return code - end + local code = {}; + for i,s in pairs(t) do + local word = string.char(highbyte(i), i%0x100); + code[i] = word; + code[s] = word; + code[string.lower(s)] = word; + end + return code; +end dns.types = { - 'A', 'NS', 'MD', 'MF', 'CNAME', 'SOA', 'MB', 'MG', 'MR', 'NULL', 'WKS', - 'PTR', 'HINFO', 'MINFO', 'MX', 'TXT', - [ 28] = 'AAAA', [ 29] = 'LOC', [ 33] = 'SRV', - [252] = 'AXFR', [253] = 'MAILB', [254] = 'MAILA', [255] = '*' } + 'A', 'NS', 'MD', 'MF', 'CNAME', 'SOA', 'MB', 'MG', 'MR', 'NULL', 'WKS', + 'PTR', 'HINFO', 'MINFO', 'MX', 'TXT', + [ 28] = 'AAAA', [ 29] = 'LOC', [ 33] = 'SRV', + [252] = 'AXFR', [253] = 'MAILB', [254] = 'MAILA', [255] = '*' }; -dns.classes = { 'IN', 'CS', 'CH', 'HS', [255] = '*' } +dns.classes = { 'IN', 'CS', 'CH', 'HS', [255] = '*' }; -dns.type = augment (dns.types) -dns.class = augment (dns.classes) -dns.typecode = encode (dns.types) -dns.classcode = encode (dns.classes) +dns.type = augment (dns.types); +dns.class = augment (dns.classes); +dns.typecode = encode (dns.types); +dns.classcode = encode (dns.classes); -local function standardize (qname, qtype, qclass) -- - - - - - - standardize - if string.byte (qname, -1) ~= 0x2E then qname = qname..'.' end - qname = string.lower (qname) - return qname, dns.type[qtype or 'A'], dns.class[qclass or 'IN'] - end - - -local function prune (rrs, time, soft) -- - - - - - - - - - - - - - - prune - - time = time or socket.gettime () - for i,rr in pairs (rrs) do +local function standardize(qname, qtype, qclass) -- - - - - - - standardize + if string.byte(qname, -1) ~= 0x2E then qname = qname..'.'; end + qname = string.lower(qname); + return qname, dns.type[qtype or 'A'], dns.class[qclass or 'IN']; +end - if rr.tod then - -- rr.tod = rr.tod - 50 -- accelerated decripitude - rr.ttl = math.floor (rr.tod - time) - if rr.ttl <= 0 then rrs[i] = nil end - elseif soft == 'soft' then -- What is this? I forget! - assert (rr.ttl == 0) - rrs[i] = nil - end end end +local function prune(rrs, time, soft) -- - - - - - - - - - - - - - - prune + time = time or socket.gettime(); + for i,rr in pairs(rrs) do + if rr.tod then + -- rr.tod = rr.tod - 50 -- accelerated decripitude + rr.ttl = math.floor(rr.tod - time); + if rr.ttl <= 0 then + table.remove(rrs, i); + return prune(rrs, time, soft); -- Re-iterate + end + elseif soft == 'soft' then -- What is this? I forget! + assert(rr.ttl == 0); + rrs[i] = nil; + end + end +end -- metatables & co. ------------------------------------------ metatables & co. -local resolver = {} -resolver.__index = resolver - +local resolver = {}; +resolver.__index = resolver; -local SRV_tostring +resolver.timeout = default_timeout; +local function default_rr_tostring(rr) + local rr_val = rr.type and rr[rr.type:lower()]; + if type(rr_val) ~= "string" then + return "<UNKNOWN RDATA TYPE>"; + end + return rr_val; +end -local rr_metatable = {} -- - - - - - - - - - - - - - - - - - - rr_metatable -function rr_metatable.__tostring (rr) - local s0 = string.format ( - '%2s %-5s %6i %-28s', rr.class, rr.type, rr.ttl, rr.name ) - local s1 = '' - if rr.type == 'A' then s1 = ' '..rr.a - elseif rr.type == 'MX' then - s1 = string.format (' %2i %s', rr.pref, rr.mx) - elseif rr.type == 'CNAME' then s1 = ' '..rr.cname - elseif rr.type == 'LOC' then s1 = ' '..resolver.LOC_tostring (rr) - elseif rr.type == 'NS' then s1 = ' '..rr.ns - elseif rr.type == 'SRV' then s1 = ' '..SRV_tostring (rr) - elseif rr.type == 'TXT' then s1 = ' '..rr.txt - else s1 = ' <UNKNOWN RDATA TYPE>' end - return s0..s1 - end +local special_tostrings = { + LOC = resolver.LOC_tostring; + MX = function (rr) + return string.format('%2i %s', rr.pref, rr.mx); + end; + SRV = function (rr) + local s = rr.srv; + return string.format('%5d %5d %5d %s', s.priority, s.weight, s.port, s.target); + end; +}; + +local rr_metatable = {}; -- - - - - - - - - - - - - - - - - - - rr_metatable +function rr_metatable.__tostring(rr) + local rr_string = (special_tostrings[rr.type] or default_rr_tostring)(rr); + return string.format('%2s %-5s %6i %-28s %s', rr.class, rr.type, rr.ttl, rr.name, rr_string); +end -local rrs_metatable = {} -- - - - - - - - - - - - - - - - - - rrs_metatable -function rrs_metatable.__tostring (rrs) - local t = {} - for i,rr in pairs (rrs) do append (t, tostring (rr)..'\n') end - return table.concat (t) - end +local rrs_metatable = {}; -- - - - - - - - - - - - - - - - - - rrs_metatable +function rrs_metatable.__tostring(rrs) + local t = {}; + for i,rr in pairs(rrs) do + append(t, tostring(rr)..'\n'); + end + return table.concat(t); +end -local cache_metatable = {} -- - - - - - - - - - - - - - - - cache_metatable -function cache_metatable.__tostring (cache) - local time = socket.gettime () - local t = {} - for class,types in pairs (cache) do - for type,names in pairs (types) do - for name,rrs in pairs (names) do - prune (rrs, time) - append (t, tostring (rrs)) end end end - return table.concat (t) - end +local cache_metatable = {}; -- - - - - - - - - - - - - - - - cache_metatable +function cache_metatable.__tostring(cache) + local time = socket.gettime(); + local t = {}; + for class,types in pairs(cache) do + for type,names in pairs(types) do + for name,rrs in pairs(names) do + prune(rrs, time); + append(t, tostring(rrs)); + end + end + end + return table.concat(t); +end -function resolver:new () -- - - - - - - - - - - - - - - - - - - - - resolver - local r = { active = {}, cache = {}, unsorted = {} } - setmetatable (r, resolver) - setmetatable (r.cache, cache_metatable) - setmetatable (r.unsorted, { __mode = 'kv' }) - return r - end +function resolver:new() -- - - - - - - - - - - - - - - - - - - - - resolver + local r = { active = {}, cache = {}, unsorted = {} }; + setmetatable(r, resolver); + setmetatable(r.cache, cache_metatable); + setmetatable(r.unsorted, { __mode = 'kv' }); + return r; +end -- packet layer -------------------------------------------------- packet layer -function dns.random (...) -- - - - - - - - - - - - - - - - - - - dns.random - math.randomseed (10000*socket.gettime ()) - dns.random = math.random - return dns.random (...) - end - - -local function encodeHeader (o) -- - - - - - - - - - - - - - - encodeHeader +function dns.random(...) -- - - - - - - - - - - - - - - - - - - dns.random + math.randomseed(math.floor(10000*socket.gettime()) % 0x100000000); + dns.random = math.random; + return dns.random(...); +end - o = o or {} - o.id = o.id or -- 16b (random) id - dns.random (0, 0xffff) +local function encodeHeader(o) -- - - - - - - - - - - - - - - encodeHeader + o = o or {}; + o.id = o.id or dns.random(0, 0xffff); -- 16b (random) id - o.rd = o.rd or 1 -- 1b 1 recursion desired - o.tc = o.tc or 0 -- 1b 1 truncated response - o.aa = o.aa or 0 -- 1b 1 authoritative response - o.opcode = o.opcode or 0 -- 4b 0 query - -- 1 inverse query + o.rd = o.rd or 1; -- 1b 1 recursion desired + o.tc = o.tc or 0; -- 1b 1 truncated response + o.aa = o.aa or 0; -- 1b 1 authoritative response + o.opcode = o.opcode or 0; -- 4b 0 query + -- 1 inverse query -- 2 server status request -- 3-15 reserved - o.qr = o.qr or 0 -- 1b 0 query, 1 response + o.qr = o.qr or 0; -- 1b 0 query, 1 response - o.rcode = o.rcode or 0 -- 4b 0 no error + o.rcode = o.rcode or 0; -- 4b 0 no error -- 1 format error -- 2 server failure -- 3 name error -- 4 not implemented -- 5 refused -- 6-15 reserved - o.z = o.z or 0 -- 3b 0 resvered - o.ra = o.ra or 0 -- 1b 1 recursion available - - o.qdcount = o.qdcount or 1 -- 16b number of question RRs - o.ancount = o.ancount or 0 -- 16b number of answers RRs - o.nscount = o.nscount or 0 -- 16b number of nameservers RRs - o.arcount = o.arcount or 0 -- 16b number of additional RRs - - -- string.char() rounds, so prevent roundup with -0.4999 - local header = string.char ( - highbyte (o.id), o.id %0x100, - o.rd + 2*o.tc + 4*o.aa + 8*o.opcode + 128*o.qr, - o.rcode + 16*o.z + 128*o.ra, - highbyte (o.qdcount), o.qdcount %0x100, - highbyte (o.ancount), o.ancount %0x100, - highbyte (o.nscount), o.nscount %0x100, - highbyte (o.arcount), o.arcount %0x100 ) - - return header, o.id - end - - -local function encodeName (name) -- - - - - - - - - - - - - - - - encodeName - local t = {} - for part in string.gmatch (name, '[^.]+') do - append (t, string.char (string.len (part))) - append (t, part) - end - append (t, string.char (0)) - return table.concat (t) - end - - -local function encodeQuestion (qname, qtype, qclass) -- - - - encodeQuestion - qname = encodeName (qname) - qtype = dns.typecode[qtype or 'a'] - qclass = dns.classcode[qclass or 'in'] - return qname..qtype..qclass; - end - - -function resolver:byte (len) -- - - - - - - - - - - - - - - - - - - - - byte - len = len or 1 - local offset = self.offset - local last = offset + len - 1 - if last > #self.packet then - error (string.format ('out of bounds: %i>%i', last, #self.packet)) end - self.offset = offset + len - return string.byte (self.packet, offset, last) - end - - -function resolver:word () -- - - - - - - - - - - - - - - - - - - - - - word - local b1, b2 = self:byte (2) - return 0x100*b1 + b2 - end + o.z = o.z or 0; -- 3b 0 resvered + o.ra = o.ra or 0; -- 1b 1 recursion available + + o.qdcount = o.qdcount or 1; -- 16b number of question RRs + o.ancount = o.ancount or 0; -- 16b number of answers RRs + o.nscount = o.nscount or 0; -- 16b number of nameservers RRs + o.arcount = o.arcount or 0; -- 16b number of additional RRs + + -- string.char() rounds, so prevent roundup with -0.4999 + local header = string.char( + highbyte(o.id), o.id %0x100, + o.rd + 2*o.tc + 4*o.aa + 8*o.opcode + 128*o.qr, + o.rcode + 16*o.z + 128*o.ra, + highbyte(o.qdcount), o.qdcount %0x100, + highbyte(o.ancount), o.ancount %0x100, + highbyte(o.nscount), o.nscount %0x100, + highbyte(o.arcount), o.arcount %0x100 + ); + + return header, o.id; +end + + +local function encodeName(name) -- - - - - - - - - - - - - - - - encodeName + local t = {}; + for part in string.gmatch(name, '[^.]+') do + append(t, string.char(string.len(part))); + append(t, part); + end + append(t, string.char(0)); + return table.concat(t); +end + + +local function encodeQuestion(qname, qtype, qclass) -- - - - encodeQuestion + qname = encodeName(qname); + qtype = dns.typecode[qtype or 'a']; + qclass = dns.classcode[qclass or 'in']; + return qname..qtype..qclass; +end + + +function resolver:byte(len) -- - - - - - - - - - - - - - - - - - - - - byte + len = len or 1; + local offset = self.offset; + local last = offset + len - 1; + if last > #self.packet then + error(string.format('out of bounds: %i>%i', last, #self.packet)); + end + self.offset = offset + len; + return string.byte(self.packet, offset, last); +end + + +function resolver:word() -- - - - - - - - - - - - - - - - - - - - - - word + local b1, b2 = self:byte(2); + return 0x100*b1 + b2; +end function resolver:dword () -- - - - - - - - - - - - - - - - - - - - - dword - local b1, b2, b3, b4 = self:byte (4) - --print ('dword', b1, b2, b3, b4) - return 0x1000000*b1 + 0x10000*b2 + 0x100*b3 + b4 - end + local b1, b2, b3, b4 = self:byte(4); + --print('dword', b1, b2, b3, b4); + return 0x1000000*b1 + 0x10000*b2 + 0x100*b3 + b4; +end -function resolver:sub (len) -- - - - - - - - - - - - - - - - - - - - - - sub - len = len or 1 - local s = string.sub (self.packet, self.offset, self.offset + len - 1) - self.offset = self.offset + len - return s - end +function resolver:sub(len) -- - - - - - - - - - - - - - - - - - - - - - sub + len = len or 1; + local s = string.sub(self.packet, self.offset, self.offset + len - 1); + self.offset = self.offset + len; + return s; +end -function resolver:header (force) -- - - - - - - - - - - - - - - - - - header - - local id = self:word () - --print (string.format (':header id %x', id)) - if not self.active[id] and not force then return nil end - - local h = { id = id } - - local b1, b2 = self:byte (2) - - h.rd = b1 %2 - h.tc = b1 /2%2 - h.aa = b1 /4%2 - h.opcode = b1 /8%16 - h.qr = b1 /128 +function resolver:header(force) -- - - - - - - - - - - - - - - - - - header + local id = self:word(); + --print(string.format(':header id %x', id)); + if not self.active[id] and not force then return nil; end - h.rcode = b2 %16 - h.z = b2 /16%8 - h.ra = b2 /128 - - h.qdcount = self:word () - h.ancount = self:word () - h.nscount = self:word () - h.arcount = self:word () - - for k,v in pairs (h) do h[k] = v-v%1 end - - return h - end - - -function resolver:name () -- - - - - - - - - - - - - - - - - - - - - - name - local remember, pointers = nil, 0 - local len = self:byte () - local n = {} - while len > 0 do - if len >= 0xc0 then -- name is "compressed" - pointers = pointers + 1 - if pointers >= 20 then error ('dns error: 20 pointers') end - local offset = ((len-0xc0)*0x100) + self:byte () - remember = remember or self.offset - self.offset = offset + 1 -- +1 for lua - else -- name is not compressed - append (n, self:sub (len)..'.') - end - len = self:byte () - end - self.offset = remember or self.offset - return table.concat (n) - end + local h = { id = id }; + local b1, b2 = self:byte(2); -function resolver:question () -- - - - - - - - - - - - - - - - - - question - local q = {} - q.name = self:name () - q.type = dns.type[self:word ()] - q.class = dns.class[self:word ()] - return q - end + h.rd = b1 %2; + h.tc = b1 /2%2; + h.aa = b1 /4%2; + h.opcode = b1 /8%16; + h.qr = b1 /128; + h.rcode = b2 %16; + h.z = b2 /16%8; + h.ra = b2 /128; -function resolver:A (rr) -- - - - - - - - - - - - - - - - - - - - - - - - A - local b1, b2, b3, b4 = self:byte (4) - rr.a = string.format ('%i.%i.%i.%i', b1, b2, b3, b4) - end + h.qdcount = self:word(); + h.ancount = self:word(); + h.nscount = self:word(); + h.arcount = self:word(); + for k,v in pairs(h) do h[k] = v-v%1; end -function resolver:CNAME (rr) -- - - - - - - - - - - - - - - - - - - - CNAME - rr.cname = self:name () - end + return h; +end -function resolver:MX (rr) -- - - - - - - - - - - - - - - - - - - - - - - MX - rr.pref = self:word () - rr.mx = self:name () - end +function resolver:name() -- - - - - - - - - - - - - - - - - - - - - - name + local remember, pointers = nil, 0; + local len = self:byte(); + local n = {}; + if len == 0 then return "." end -- Root label + while len > 0 do + if len >= 0xc0 then -- name is "compressed" + pointers = pointers + 1; + if pointers >= 20 then error('dns error: 20 pointers'); end; + local offset = ((len-0xc0)*0x100) + self:byte(); + remember = remember or self.offset; + self.offset = offset + 1; -- +1 for lua + else -- name is not compressed + append(n, self:sub(len)..'.'); + end + len = self:byte(); + end + self.offset = remember or self.offset; + return table.concat(n); +end -function resolver:LOC_nibble_power () -- - - - - - - - - - LOC_nibble_power - local b = self:byte () - --print ('nibbles', ((b-(b%0x10))/0x10), (b%0x10)) - return ((b-(b%0x10))/0x10) * (10^(b%0x10)) - end +function resolver:question() -- - - - - - - - - - - - - - - - - - question + local q = {}; + q.name = self:name(); + q.type = dns.type[self:word()]; + q.class = dns.class[self:word()]; + return q; +end -function resolver:LOC (rr) -- - - - - - - - - - - - - - - - - - - - - - LOC - rr.version = self:byte () - if rr.version == 0 then - rr.loc = rr.loc or {} - rr.loc.size = self:LOC_nibble_power () - rr.loc.horiz_pre = self:LOC_nibble_power () - rr.loc.vert_pre = self:LOC_nibble_power () - rr.loc.latitude = self:dword () - rr.loc.longitude = self:dword () - rr.loc.altitude = self:dword () - end end +function resolver:A(rr) -- - - - - - - - - - - - - - - - - - - - - - - - A + local b1, b2, b3, b4 = self:byte(4); + rr.a = string.format('%i.%i.%i.%i', b1, b2, b3, b4); +end +function resolver:AAAA(rr) + local addr = {}; + for i = 1, rr.rdlength, 2 do + local b1, b2 = self:byte(2); + table.insert(addr, ("%02x%02x"):format(b1, b2)); + end + addr = table.concat(addr, ":"):gsub("%f[%x]0+(%x)","%1"); + local zeros = {}; + for item in addr:gmatch(":[0:]+:") do + table.insert(zeros, item) + end + if #zeros == 0 then + rr.aaaa = addr; + return + elseif #zeros > 1 then + table.sort(zeros, function(a, b) return #a > #b end); + end + rr.aaaa = addr:gsub(zeros[1], "::", 1):gsub("^0::", "::"):gsub("::0$", "::"); +end -local function LOC_tostring_degrees (f, pos, neg) -- - - - - - - - - - - - - - f = f - 0x80000000 - if f < 0 then pos = neg f = -f end - local deg, min, msec - msec = f%60000 - f = (f-msec)/60000 - min = f%60 - deg = (f-min)/60 - return string.format ('%3d %2d %2.3f %s', deg, min, msec/1000, pos) - end - - -function resolver.LOC_tostring (rr) -- - - - - - - - - - - - - LOC_tostring - - local t = {} - - --[[ - for k,name in pairs { 'size', 'horiz_pre', 'vert_pre', - 'latitude', 'longitude', 'altitude' } do - append (t, string.format ('%4s%-10s: %12.0f\n', '', name, rr.loc[name])) - end - --]] - - append ( t, string.format ( - '%s %s %.2fm %.2fm %.2fm %.2fm', - LOC_tostring_degrees (rr.loc.latitude, 'N', 'S'), - LOC_tostring_degrees (rr.loc.longitude, 'E', 'W'), - (rr.loc.altitude - 10000000) / 100, - rr.loc.size / 100, - rr.loc.horiz_pre / 100, - rr.loc.vert_pre / 100 ) ) - - return table.concat (t) - end - - -function resolver:NS (rr) -- - - - - - - - - - - - - - - - - - - - - - - NS - rr.ns = self:name () - end - +function resolver:CNAME(rr) -- - - - - - - - - - - - - - - - - - - - CNAME + rr.cname = self:name(); +end -function resolver:SOA (rr) -- - - - - - - - - - - - - - - - - - - - - - SOA - end +function resolver:MX(rr) -- - - - - - - - - - - - - - - - - - - - - - - MX + rr.pref = self:word(); + rr.mx = self:name(); +end -function resolver:SRV (rr) -- - - - - - - - - - - - - - - - - - - - - - SRV - rr.srv = {} - rr.srv.priority = self:word () - rr.srv.weight = self:word () - rr.srv.port = self:word () - rr.srv.target = self:name () - end +function resolver:LOC_nibble_power() -- - - - - - - - - - LOC_nibble_power + local b = self:byte(); + --print('nibbles', ((b-(b%0x10))/0x10), (b%0x10)); + return ((b-(b%0x10))/0x10) * (10^(b%0x10)); +end -function SRV_tostring (rr) -- - - - - - - - - - - - - - - - - - SRV_tostring - local s = rr.srv - return string.format ( '%5d %5d %5d %s', - s.priority, s.weight, s.port, s.target ) - end +function resolver:LOC(rr) -- - - - - - - - - - - - - - - - - - - - - - LOC + rr.version = self:byte(); + if rr.version == 0 then + rr.loc = rr.loc or {}; + rr.loc.size = self:LOC_nibble_power(); + rr.loc.horiz_pre = self:LOC_nibble_power(); + rr.loc.vert_pre = self:LOC_nibble_power(); + rr.loc.latitude = self:dword(); + rr.loc.longitude = self:dword(); + rr.loc.altitude = self:dword(); + end +end -function resolver:TXT (rr) -- - - - - - - - - - - - - - - - - - - - - - TXT - rr.txt = self:sub (rr.rdlength) - end +local function LOC_tostring_degrees(f, pos, neg) -- - - - - - - - - - - - - + f = f - 0x80000000; + if f < 0 then pos = neg; f = -f; end + local deg, min, msec; + msec = f%60000; + f = (f-msec)/60000; + min = f%60; + deg = (f-min)/60; + return string.format('%3d %2d %2.3f %s', deg, min, msec/1000, pos); +end -function resolver:rr () -- - - - - - - - - - - - - - - - - - - - - - - - rr - local rr = {} - setmetatable (rr, rr_metatable) - rr.name = self:name (self) - rr.type = dns.type[self:word ()] or rr.type - rr.class = dns.class[self:word ()] or rr.class - rr.ttl = 0x10000*self:word () + self:word () - rr.rdlength = self:word () - if rr.ttl == 0 then -- pass - else rr.tod = self.time + rr.ttl end +function resolver.LOC_tostring(rr) -- - - - - - - - - - - - - LOC_tostring + local t = {}; - local remember = self.offset - local rr_parser = self[dns.type[rr.type]] - if rr_parser then rr_parser (self, rr) end - self.offset = remember - rr.rdata = self:sub (rr.rdlength) - return rr - end + --[[ + for k,name in pairs { 'size', 'horiz_pre', 'vert_pre', 'latitude', 'longitude', 'altitude' } do + append(t, string.format('%4s%-10s: %12.0f\n', '', name, rr.loc[name])); + end + --]] + + append(t, string.format( + '%s %s %.2fm %.2fm %.2fm %.2fm', + LOC_tostring_degrees (rr.loc.latitude, 'N', 'S'), + LOC_tostring_degrees (rr.loc.longitude, 'E', 'W'), + (rr.loc.altitude - 10000000) / 100, + rr.loc.size / 100, + rr.loc.horiz_pre / 100, + rr.loc.vert_pre / 100 + )); + + return table.concat(t); +end -function resolver:rrs (count) -- - - - - - - - - - - - - - - - - - - - - rrs - local rrs = {} - for i = 1,count do append (rrs, self:rr ()) end - return rrs - end +function resolver:NS(rr) -- - - - - - - - - - - - - - - - - - - - - - - NS + rr.ns = self:name(); +end -function resolver:decode (packet, force) -- - - - - - - - - - - - - - decode +function resolver:SOA(rr) -- - - - - - - - - - - - - - - - - - - - - - SOA +end - self.packet, self.offset = packet, 1 - local header = self:header (force) - if not header then return nil end - local response = { header = header } - response.question = {} - local offset = self.offset - for i = 1,response.header.qdcount do - append (response.question, self:question ()) end - response.question.raw = string.sub (self.packet, offset, self.offset - 1) +function resolver:SRV(rr) -- - - - - - - - - - - - - - - - - - - - - - SRV + rr.srv = {}; + rr.srv.priority = self:word(); + rr.srv.weight = self:word(); + rr.srv.port = self:word(); + rr.srv.target = self:name(); +end - if not force then - if not self.active[response.header.id] or - not self.active[response.header.id][response.question.raw] then - return nil end end +function resolver:PTR(rr) + rr.ptr = self:name(); +end - response.answer = self:rrs (response.header.ancount) - response.authority = self:rrs (response.header.nscount) - response.additional = self:rrs (response.header.arcount) +function resolver:TXT(rr) -- - - - - - - - - - - - - - - - - - - - - - TXT + rr.txt = self:sub (self:byte()); +end - return response - end +function resolver:rr() -- - - - - - - - - - - - - - - - - - - - - - - - rr + local rr = {}; + setmetatable(rr, rr_metatable); + rr.name = self:name(self); + rr.type = dns.type[self:word()] or rr.type; + rr.class = dns.class[self:word()] or rr.class; + rr.ttl = 0x10000*self:word() + self:word(); + rr.rdlength = self:word(); --- socket layer -------------------------------------------------- socket layer + if rr.ttl <= 0 then + rr.tod = self.time + 30; + else + rr.tod = self.time + rr.ttl; + end + + local remember = self.offset; + local rr_parser = self[dns.type[rr.type]]; + if rr_parser then rr_parser(self, rr); end + self.offset = remember; + rr.rdata = self:sub(rr.rdlength); + return rr; +end + + +function resolver:rrs (count) -- - - - - - - - - - - - - - - - - - - - - rrs + local rrs = {}; + for i = 1,count do append(rrs, self:rr()); end + return rrs; +end + + +function resolver:decode(packet, force) -- - - - - - - - - - - - - - decode + self.packet, self.offset = packet, 1; + local header = self:header(force); + if not header then return nil; end + local response = { header = header }; + + response.question = {}; + local offset = self.offset; + for i = 1,response.header.qdcount do + append(response.question, self:question()); + end + response.question.raw = string.sub(self.packet, offset, self.offset - 1); + + if not force then + if not self.active[response.header.id] or not self.active[response.header.id][response.question.raw] then + self.active[response.header.id] = nil; + return nil; + end + end + + response.answer = self:rrs(response.header.ancount); + response.authority = self:rrs(response.header.nscount); + response.additional = self:rrs(response.header.arcount); + + return response; +end -resolver.delays = { 1, 3, 11, 45 } +-- socket layer -------------------------------------------------- socket layer -function resolver:addnameserver (address) -- - - - - - - - - - addnameserver - self.server = self.server or {} - append (self.server, address) - end +resolver.delays = { 1, 3 }; -function resolver:setnameserver (address) -- - - - - - - - - - setnameserver - self.server = {} - self:addnameserver (address) - end +function resolver:addnameserver(address) -- - - - - - - - - - addnameserver + self.server = self.server or {}; + append(self.server, address); +end -function resolver:adddefaultnameservers () -- - - - - adddefaultnameservers - local resolv_conf = io.open("/etc/resolv.conf"); - if resolv_conf then - for line in resolv_conf:lines() do - local address = string.match (line, 'nameserver%s+(%d+%.%d+%.%d+%.%d+)') - if address then self:addnameserver (address) end - end - else -- FIXME correct for windows, using opendns nameservers for now - self:addnameserver ("208.67.222.222") - self:addnameserver ("208.67.220.220") - end +function resolver:setnameserver(address) -- - - - - - - - - - setnameserver + self.server = {}; + self:addnameserver(address); end -function resolver:getsocket (servernum) -- - - - - - - - - - - - - getsocket +function resolver:adddefaultnameservers() -- - - - - adddefaultnameservers + if is_windows then + if windows and windows.get_nameservers then + for _, server in ipairs(windows.get_nameservers()) do + self:addnameserver(server); + end + end + if not self.server or #self.server == 0 then + -- TODO log warning about no nameservers, adding opendns servers as fallback + self:addnameserver("208.67.222.222"); + self:addnameserver("208.67.220.220"); + end + else -- posix + local resolv_conf = io.open("/etc/resolv.conf"); + if resolv_conf then + for line in resolv_conf:lines() do + line = line:gsub("#.*$", "") + :match('^%s*nameserver%s+(.*)%s*$'); + if line then + line:gsub("%f[%d.](%d+%.%d+%.%d+%.%d+)%f[^%d.]", function (address) + self:addnameserver(address) + end); + end + end + end + if not self.server or #self.server == 0 then + -- TODO log warning about no nameservers, adding localhost as the default nameserver + self:addnameserver("127.0.0.1"); + end + end +end + - self.socket = self.socket or {} - self.socketset = self.socketset or {} +function resolver:getsocket(servernum) -- - - - - - - - - - - - - getsocket + self.socket = self.socket or {}; + self.socketset = self.socketset or {}; - local sock = self.socket[servernum] - if sock then return sock end + local sock = self.socket[servernum]; + if sock then return sock; end - sock = socket.udp () - if self.socket_wrapper then sock = self.socket_wrapper (sock) end - sock:settimeout (0) - -- todo: attempt to use a random port, fallback to 0 - sock:setsockname ('*', 0) - sock:setpeername (self.server[servernum], 53) - self.socket[servernum] = sock - self.socketset[sock] = sock - return sock - end + local err; + sock, err = socket.udp(); + if not sock then + return nil, err; + end + if self.socket_wrapper then sock = self.socket_wrapper(sock, self); end + sock:settimeout(0); + -- todo: attempt to use a random port, fallback to 0 + sock:setsockname('*', 0); + sock:setpeername(self.server[servernum], 53); + self.socket[servernum] = sock; + self.socketset[sock] = servernum; + return sock; +end +function resolver:voidsocket(sock) + if self.socket[sock] then + self.socketset[self.socket[sock]] = nil; + self.socket[sock] = nil; + elseif self.socketset[sock] then + self.socket[self.socketset[sock]] = nil; + self.socketset[sock] = nil; + end + sock:close(); +end -function resolver:socket_wrapper_set (func) -- - - - - - - socket_wrapper_set - self.socket_wrapper = func - end +function resolver:socket_wrapper_set(func) -- - - - - - - socket_wrapper_set + self.socket_wrapper = func; +end function resolver:closeall () -- - - - - - - - - - - - - - - - - - closeall - for i,sock in ipairs (self.socket) do self.socket[i]:close () end - self.socket = {} - end - + for i,sock in ipairs(self.socket) do + self.socket[i] = nil; + self.socketset[sock] = nil; + sock:close(); + end +end -function resolver:remember (rr, type) -- - - - - - - - - - - - - - remember - --print ('remember', type, rr.class, rr.type, rr.name) +function resolver:remember(rr, type) -- - - - - - - - - - - - - - remember + --print ('remember', type, rr.class, rr.type, rr.name) + local qname, qtype, qclass = standardize(rr.name, rr.type, rr.class); - if type ~= '*' then - type = rr.type - local all = get (self.cache, rr.class, '*', rr.name) - --print ('remember all', all) - if all then append (all, rr) end - end + if type ~= '*' then + type = qtype; + local all = get(self.cache, qclass, '*', qname); + --print('remember all', all); + if all then append(all, rr); end + end - self.cache = self.cache or setmetatable ({}, cache_metatable) - local rrs = get (self.cache, rr.class, type, rr.name) or - set (self.cache, rr.class, type, rr.name, setmetatable ({}, rrs_metatable)) - append (rrs, rr) + self.cache = self.cache or setmetatable({}, cache_metatable); + local rrs = get(self.cache, qclass, type, qname) or + set(self.cache, qclass, type, qname, setmetatable({}, rrs_metatable)); + append(rrs, rr); - if type == 'MX' then self.unsorted[rrs] = true end - end + if type == 'MX' then self.unsorted[rrs] = true; end +end -local function comp_mx (a, b) -- - - - - - - - - - - - - - - - - - - comp_mx - return (a.pref == b.pref) and (a.mx < b.mx) or (a.pref < b.pref) - end +local function comp_mx(a, b) -- - - - - - - - - - - - - - - - - - - comp_mx + return (a.pref == b.pref) and (a.mx < b.mx) or (a.pref < b.pref); +end function resolver:peek (qname, qtype, qclass) -- - - - - - - - - - - - peek - qname, qtype, qclass = standardize (qname, qtype, qclass) - local rrs = get (self.cache, qclass, qtype, qname) - if not rrs then return nil end - if prune (rrs, socket.gettime ()) and qtype == '*' or not next (rrs) then - set (self.cache, qclass, qtype, qname, nil) return nil end - if self.unsorted[rrs] then table.sort (rrs, comp_mx) end - return rrs - end - - -function resolver:purge (soft) -- - - - - - - - - - - - - - - - - - - purge - if soft == 'soft' then - self.time = socket.gettime () - for class,types in pairs (self.cache or {}) do - for type,names in pairs (types) do - for name,rrs in pairs (names) do - prune (rrs, self.time, 'soft') - end end end - else self.cache = {} end - end - - -function resolver:query (qname, qtype, qclass) -- - - - - - - - - - -- query - - qname, qtype, qclass = standardize (qname, qtype, qclass) - - if not self.server then self:adddefaultnameservers () end - - local question = encodeQuestion (qname, qtype, qclass) - local peek = self:peek (qname, qtype, qclass) - if peek then return peek end - - local header, id = encodeHeader () - --print ('query id', id, qclass, qtype, qname) - local o = { packet = header..question, - server = 1, - delay = 1, - retry = socket.gettime () + self.delays[1] } - self:getsocket (o.server):send (o.packet) - - -- remember the query - self.active[id] = self.active[id] or {} - self.active[id][question] = o - - -- remember which coroutine wants the answer - local co = coroutine.running () - if co then - set (self.wanted, qclass, qtype, qname, co, true) - --set (self.yielded, co, qclass, qtype, qname, true) - end -end - - - -function resolver:receive (rset) -- - - - - - - - - - - - - - - - - receive - - --print 'receive' print (self.socket) - self.time = socket.gettime () - rset = rset or self.socket - - local response - for i,sock in pairs (rset) do - - if self.socketset[sock] then - local packet = sock:receive () - if packet then - - response = self:decode (packet) - if response then - --print 'received response' - --self.print (response) - - for i,section in pairs { 'answer', 'authority', 'additional' } do - for j,rr in pairs (response[section]) do - self:remember (rr, response.question[1].type) end end - - -- retire the query - local queries = self.active[response.header.id] - if queries[response.question.raw] then - queries[response.question.raw] = nil end - if not next (queries) then self.active[response.header.id] = nil end - if not next (self.active) then self:closeall () end - - -- was the query on the wanted list? - local q = response.question - local cos = get (self.wanted, q.class, q.type, q.name) - if cos then - for co in pairs (cos) do - set (self.yielded, co, q.class, q.type, q.name, nil) - if coroutine.status(co) == "suspended" then coroutine.resume (co) end - end - set (self.wanted, q.class, q.type, q.name, nil) - end end end end end - - return response - end - - -function resolver:feed(sock, packet) - --print 'receive' print (self.socket) - self.time = socket.gettime () - - local response = self:decode (packet) - if response then - --print 'received response' - --self.print (response) - - for i,section in pairs { 'answer', 'authority', 'additional' } do - for j,rr in pairs (response[section]) do - self:remember (rr, response.question[1].type) - end - end - - -- retire the query - local queries = self.active[response.header.id] - if queries[response.question.raw] then - queries[response.question.raw] = nil - end - if not next (queries) then self.active[response.header.id] = nil end - if not next (self.active) then self:closeall () end - - -- was the query on the wanted list? - local q = response.question[1] - if q then - local cos = get (self.wanted, q.class, q.type, q.name) - if cos then - for co in pairs (cos) do - set (self.yielded, co, q.class, q.type, q.name, nil) - if coroutine.status(co) == "suspended" then coroutine.resume (co) end - end - set (self.wanted, q.class, q.type, q.name, nil) - end - end - end - - return response -end - -function resolver:cancel(data) - local cos = get (self.wanted, unpack(data, 1, 3)) - if cos then - cos[data[4]] = nil; + qname, qtype, qclass = standardize(qname, qtype, qclass); + local rrs = get(self.cache, qclass, qtype, qname); + if not rrs then return nil; end + if prune(rrs, socket.gettime()) and qtype == '*' or not next(rrs) then + set(self.cache, qclass, qtype, qname, nil); + return nil; end + if self.unsorted[rrs] then table.sort (rrs, comp_mx); end + return rrs; +end + + +function resolver:purge(soft) -- - - - - - - - - - - - - - - - - - - purge + if soft == 'soft' then + self.time = socket.gettime(); + for class,types in pairs(self.cache or {}) do + for type,names in pairs(types) do + for name,rrs in pairs(names) do + prune(rrs, self.time, 'soft') + end + end + end + else self.cache = setmetatable({}, cache_metatable); end end -function resolver:pulse () -- - - - - - - - - - - - - - - - - - - - - pulse - --print ':pulse' - while self:receive() do end - if not next (self.active) then return nil end +function resolver:query(qname, qtype, qclass) -- - - - - - - - - - -- query + qname, qtype, qclass = standardize(qname, qtype, qclass) + + if not self.server then self:adddefaultnameservers(); end + + local question = encodeQuestion(qname, qtype, qclass); + local peek = self:peek (qname, qtype, qclass); + if peek then return peek; end + + local header, id = encodeHeader(); + --print ('query id', id, qclass, qtype, qname) + local o = { + packet = header..question, + server = self.best_server, + delay = 1, + retry = socket.gettime() + self.delays[1] + }; + + -- remember the query + self.active[id] = self.active[id] or {}; + self.active[id][question] = o; - self.time = socket.gettime () - for id,queries in pairs (self.active) do - for question,o in pairs (queries) do - if self.time >= o.retry then + -- remember which coroutine wants the answer + local co = coroutine.running(); + if co then + set(self.wanted, qclass, qtype, qname, co, true); + --set(self.yielded, co, qclass, qtype, qname, true); + end - o.server = o.server + 1 - if o.server > #self.server then - o.server = 1 - o.delay = o.delay + 1 - end + local conn, err = self:getsocket(o.server) + if not conn then + return nil, err; + end + conn:send (o.packet) + + if timer and self.timeout then + local num_servers = #self.server; + local i = 1; + timer.add_task(self.timeout, function () + if get(self.wanted, qclass, qtype, qname, co) then + if i < num_servers then + i = i + 1; + self:servfail(conn); + o.server = self.best_server; + conn, err = self:getsocket(o.server); + if conn then + conn:send(o.packet); + return self.timeout; + end + end + -- Tried everything, failed + self:cancel(qclass, qtype, qname, co, true); + end + end) + end + return true; +end - if o.delay > #self.delays then - --print ('timeout') - queries[question] = nil - if not next (queries) then self.active[id] = nil end - if not next (self.active) then return nil end - else - --print ('retry', o.server, o.delay) - local _a = self.socket[o.server]; - if _a then _a:send (o.packet) end - o.retry = self.time + self.delays[o.delay] - end end end end +function resolver:servfail(sock) + -- Resend all queries for this server + + local num = self.socketset[sock] + + -- Socket is dead now + self:voidsocket(sock); + + -- Find all requests to the down server, and retry on the next server + self.time = socket.gettime(); + for id,queries in pairs(self.active) do + for question,o in pairs(queries) do + if o.server == num then -- This request was to the broken server + o.server = o.server + 1 -- Use next server + if o.server > #self.server then + o.server = 1; + end + + o.retries = (o.retries or 0) + 1; + if o.retries >= #self.server then + --print('timeout'); + queries[question] = nil; + else + local _a = self:getsocket(o.server); + if _a then _a:send(o.packet); end + end + end + end + if next(queries) == nil then + self.active[id] = nil; + end + end - if next (self.active) then return true end - return nil - end + if num == self.best_server then + self.best_server = self.best_server + 1; + if self.best_server > #self.server then + -- Exhausted all servers, try first again + self.best_server = 1; + end + end +end +function resolver:settimeout(seconds) + self.timeout = seconds; +end -function resolver:lookup (qname, qtype, qclass) -- - - - - - - - - - lookup - self:query (qname, qtype, qclass) - while self:pulse () do socket.select (self.socket, nil, 4) end - --print (self.cache) - return self:peek (qname, qtype, qclass) - end +function resolver:receive(rset) -- - - - - - - - - - - - - - - - - receive + --print('receive'); print(self.socket); + self.time = socket.gettime(); + rset = rset or self.socket; + + local response; + for i,sock in pairs(rset) do + + if self.socketset[sock] then + local packet = sock:receive(); + if packet then + response = self:decode(packet); + if response and self.active[response.header.id] + and self.active[response.header.id][response.question.raw] then + --print('received response'); + --self.print(response); + + for j,rr in pairs(response.answer) do + if rr.name:sub(-#response.question[1].name, -1) == response.question[1].name then + self:remember(rr, response.question[1].type) + end + end + + -- retire the query + local queries = self.active[response.header.id]; + queries[response.question.raw] = nil; + + if not next(queries) then self.active[response.header.id] = nil; end + if not next(self.active) then self:closeall(); end + + -- was the query on the wanted list? + local q = response.question[1]; + local cos = get(self.wanted, q.class, q.type, q.name); + if cos then + for co in pairs(cos) do + set(self.yielded, co, q.class, q.type, q.name, nil); + if coroutine.status(co) == "suspended" then coroutine.resume(co); end + end + set(self.wanted, q.class, q.type, q.name, nil); + end + end + + end + end + end -function resolver:lookupex (handler, qname, qtype, qclass) -- - - - - - - - - - lookup - return self:peek (qname, qtype, qclass) or self:query (qname, qtype, qclass) - end + return response; +end +function resolver:feed(sock, packet, force) + --print('receive'); print(self.socket); + self.time = socket.gettime(); + + local response = self:decode(packet, force); + if response and self.active[response.header.id] + and self.active[response.header.id][response.question.raw] then + --print('received response'); + --self.print(response); + + for j,rr in pairs(response.answer) do + self:remember(rr, response.question[1].type); + end + + -- retire the query + local queries = self.active[response.header.id]; + queries[response.question.raw] = nil; + if not next(queries) then self.active[response.header.id] = nil; end + if not next(self.active) then self:closeall(); end + + -- was the query on the wanted list? + local q = response.question[1]; + if q then + local cos = get(self.wanted, q.class, q.type, q.name); + if cos then + for co in pairs(cos) do + set(self.yielded, co, q.class, q.type, q.name, nil); + if coroutine.status(co) == "suspended" then coroutine.resume(co); end + end + set(self.wanted, q.class, q.type, q.name, nil); + end + end + end + + return response; +end + +function resolver:cancel(qclass, qtype, qname, co, call_handler) + local cos = get(self.wanted, qclass, qtype, qname); + if cos then + if call_handler then + coroutine.resume(co); + end + cos[co] = nil; + end +end + +function resolver:pulse() -- - - - - - - - - - - - - - - - - - - - - pulse + --print(':pulse'); + while self:receive() do end + if not next(self.active) then return nil; end + + self.time = socket.gettime(); + for id,queries in pairs(self.active) do + for question,o in pairs(queries) do + if self.time >= o.retry then + + o.server = o.server + 1; + if o.server > #self.server then + o.server = 1; + o.delay = o.delay + 1; + end + + if o.delay > #self.delays then + --print('timeout'); + queries[question] = nil; + if not next(queries) then self.active[id] = nil; end + if not next(self.active) then return nil; end + else + --print('retry', o.server, o.delay); + local _a = self.socket[o.server]; + if _a then _a:send(o.packet); end + o.retry = self.time + self.delays[o.delay]; + end + end + end + end + + if next(self.active) then return true; end + return nil; +end + + +function resolver:lookup(qname, qtype, qclass) -- - - - - - - - - - lookup + self:query (qname, qtype, qclass) + while self:pulse() do + local recvt = {} + for i, s in ipairs(self.socket) do + recvt[i] = s + end + socket.select(recvt, nil, 4) + end + --print(self.cache); + return self:peek(qname, qtype, qclass); +end + +function resolver:lookupex(handler, qname, qtype, qclass) -- - - - - - - - - - lookup + return self:peek(qname, qtype, qclass) or self:query(qname, qtype, qclass); +end + +function resolver:tohostname(ip) + return dns.lookup(ip:gsub("(%d+)%.(%d+)%.(%d+)%.(%d+)", "%4.%3.%2.%1.in-addr.arpa."), "PTR"); +end + --print ---------------------------------------------------------------- print local hints = { -- - - - - - - - - - - - - - - - - - - - - - - - - - - hints - qr = { [0]='query', 'response' }, - opcode = { [0]='query', 'inverse query', 'server status request' }, - aa = { [0]='non-authoritative', 'authoritative' }, - tc = { [0]='complete', 'truncated' }, - rd = { [0]='recursion not desired', 'recursion desired' }, - ra = { [0]='recursion not available', 'recursion available' }, - z = { [0]='(reserved)' }, - rcode = { [0]='no error', 'format error', 'server failure', 'name error', - 'not implemented' }, - - type = dns.type, - class = dns.class, } - - -local function hint (p, s) -- - - - - - - - - - - - - - - - - - - - - - hint - return (hints[s] and hints[s][p[s]]) or '' end - - -function resolver.print (response) -- - - - - - - - - - - - - resolver.print - - for s,s in pairs { 'id', 'qr', 'opcode', 'aa', 'tc', 'rd', 'ra', 'z', - 'rcode', 'qdcount', 'ancount', 'nscount', 'arcount' } do - print ( string.format ('%-30s', 'header.'..s), - response.header[s], hint (response.header, s) ) - end - - for i,question in ipairs (response.question) do - print (string.format ('question[%i].name ', i), question.name) - print (string.format ('question[%i].type ', i), question.type) - print (string.format ('question[%i].class ', i), question.class) - end - - local common = { name=1, type=1, class=1, ttl=1, rdlength=1, rdata=1 } - local tmp - for s,s in pairs {'answer', 'authority', 'additional'} do - for i,rr in pairs (response[s]) do - for j,t in pairs { 'name', 'type', 'class', 'ttl', 'rdlength' } do - tmp = string.format ('%s[%i].%s', s, i, t) - print (string.format ('%-30s', tmp), rr[t], hint (rr, t)) - end - for j,t in pairs (rr) do - if not common[j] then - tmp = string.format ('%s[%i].%s', s, i, j) - print (string.format ('%-30s %s', tostring(tmp), tostring(t))) - end end end end end + qr = { [0]='query', 'response' }, + opcode = { [0]='query', 'inverse query', 'server status request' }, + aa = { [0]='non-authoritative', 'authoritative' }, + tc = { [0]='complete', 'truncated' }, + rd = { [0]='recursion not desired', 'recursion desired' }, + ra = { [0]='recursion not available', 'recursion available' }, + z = { [0]='(reserved)' }, + rcode = { [0]='no error', 'format error', 'server failure', 'name error', 'not implemented' }, + + type = dns.type, + class = dns.class +}; + + +local function hint(p, s) -- - - - - - - - - - - - - - - - - - - - - - hint + return (hints[s] and hints[s][p[s]]) or ''; +end --- module api ------------------------------------------------------ module api +function resolver.print(response) -- - - - - - - - - - - - - resolver.print + for s,s in pairs { 'id', 'qr', 'opcode', 'aa', 'tc', 'rd', 'ra', 'z', + 'rcode', 'qdcount', 'ancount', 'nscount', 'arcount' } do + print( string.format('%-30s', 'header.'..s), response.header[s], hint(response.header, s) ); + end + for i,question in ipairs(response.question) do + print(string.format ('question[%i].name ', i), question.name); + print(string.format ('question[%i].type ', i), question.type); + print(string.format ('question[%i].class ', i), question.class); + end -local function resolve (func, ...) -- - - - - - - - - - - - - - resolver_get - dns._resolver = dns._resolver or dns.resolver () - return func (dns._resolver, ...) - end + local common = { name=1, type=1, class=1, ttl=1, rdlength=1, rdata=1 }; + local tmp; + for s,s in pairs({'answer', 'authority', 'additional'}) do + for i,rr in pairs(response[s]) do + for j,t in pairs({ 'name', 'type', 'class', 'ttl', 'rdlength' }) do + tmp = string.format('%s[%i].%s', s, i, t); + print(string.format('%-30s', tmp), rr[t], hint(rr, t)); + end + for j,t in pairs(rr) do + if not common[j] then + tmp = string.format('%s[%i].%s', s, i, j); + print(string.format('%-30s %s', tostring(tmp), tostring(t))); + end + end + end + end +end -function dns.resolver () -- - - - - - - - - - - - - - - - - - - - - resolver +-- module api ------------------------------------------------------ module api - -- this function seems to be redundant with resolver.new () - local r = { active = {}, cache = {}, unsorted = {}, wanted = {}, yielded = {} } - setmetatable (r, resolver) - setmetatable (r.cache, cache_metatable) - setmetatable (r.unsorted, { __mode = 'kv' }) - return r - end +function dns.resolver () -- - - - - - - - - - - - - - - - - - - - - resolver + -- this function seems to be redundant with resolver.new () + local r = { active = {}, cache = {}, unsorted = {}, wanted = {}, yielded = {}, best_server = 1 }; + setmetatable (r, resolver); + setmetatable (r.cache, cache_metatable); + setmetatable (r.unsorted, { __mode = 'kv' }); + return r; +end -function dns.lookup (...) -- - - - - - - - - - - - - - - - - - - - - lookup - return resolve (resolver.lookup, ...) end +local _resolver = dns.resolver(); +dns._resolver = _resolver; +function dns.lookup(...) -- - - - - - - - - - - - - - - - - - - - - lookup + return _resolver:lookup(...); +end -function dns.purge (...) -- - - - - - - - - - - - - - - - - - - - - - purge - return resolve (resolver.purge, ...) end +function dns.tohostname(...) + return _resolver:tohostname(...); +end -function dns.peek (...) -- - - - - - - - - - - - - - - - - - - - - - - peek - return resolve (resolver.peek, ...) end +function dns.purge(...) -- - - - - - - - - - - - - - - - - - - - - - purge + return _resolver:purge(...); +end +function dns.peek(...) -- - - - - - - - - - - - - - - - - - - - - - - peek + return _resolver:peek(...); +end -function dns.query (...) -- - - - - - - - - - - - - - - - - - - - - - query - return resolve (resolver.query, ...) end +function dns.query(...) -- - - - - - - - - - - - - - - - - - - - - - query + return _resolver:query(...); +end -function dns.feed (...) -- - - - - - - - - - - - - - - - - - - - - - feed - return resolve (resolver.feed, ...) end +function dns.feed(...) -- - - - - - - - - - - - - - - - - - - - - - - feed + return _resolver:feed(...); +end -function dns.cancel(...) -- - - - - - - - - - - - - - - - - - - - - - cancel - return resolve(resolver.cancel, ...) end +function dns.cancel(...) -- - - - - - - - - - - - - - - - - - - - - - cancel + return _resolver:cancel(...); +end -function dns:socket_wrapper_set (...) -- - - - - - - - - socket_wrapper_set - return resolve (resolver.socket_wrapper_set, ...) end +function dns.settimeout(...) + return _resolver:settimeout(...); +end +function dns.socket_wrapper_set(...) -- - - - - - - - - socket_wrapper_set + return _resolver:socket_wrapper_set(...); +end -return dns +return dns; diff --git a/net/http.lua b/net/http.lua index 9d2f9b96..3b783a41 100644 --- a/net/http.lua +++ b/net/http.lua @@ -1,124 +1,107 @@ -- Prosody IM --- Copyright (C) 2008-2009 Matthew Wild --- Copyright (C) 2008-2009 Waqas Hussain +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain -- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- - local socket = require "socket" -local mime = require "mime" +local b64 = require "util.encodings".base64.encode; local url = require "socket.url" +local httpstream_new = require "net.http.parser".new; +local util_http = require "util.http"; -local server = require "net.server" +local ssl_available = pcall(require, "ssl"); -local connlisteners_get = require "net.connlisteners".get; -local listener = connlisteners_get("httpclient") or error("No httpclient listener!"); +local server = require "net.server" local t_insert, t_concat = table.insert, table.concat; -local tonumber, tostring, pairs, xpcall, select, debug_traceback, char, format = - tonumber, tostring, pairs, xpcall, select, debug.traceback, string.char, string.format; +local pairs = pairs; +local tonumber, tostring, xpcall, select, traceback = + tonumber, tostring, xpcall, select, debug.traceback; local log = require "util.logger".init("http"); -local print = function () end module "http" -function urlencode(s) return s and (s:gsub("%W", function (c) return format("%%%02x", c:byte()); end)); end -function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return char(tonumber(c,16)); end)); end +local requests = {}; -- Open requests + +local listener = { default_port = 80, default_mode = "*a" }; -local function expectbody(reqt, code) - if reqt.method == "HEAD" then return nil end - if code == 204 or code == 304 then return nil end - if code >= 100 and code < 200 then return nil end - return 1 +function listener.onconnect(conn) + local req = requests[conn]; + -- Send the request + local request_line = { req.method or "GET", " ", req.path, " HTTP/1.1\r\n" }; + if req.query then + t_insert(request_line, 4, "?"..req.query); + end + + conn:write(t_concat(request_line)); + local t = { [2] = ": ", [4] = "\r\n" }; + for k, v in pairs(req.headers) do + t[1], t[3] = k, v; + conn:write(t_concat(t)); + end + conn:write("\r\n"); + + if req.body then + conn:write(req.body); + end end -local function request_reader(request, data, startpos) - if not data then - if request.body then - log("debug", "Connection closed, but we have data, calling callback..."); - request.callback(t_concat(request.body), request.code, request); - elseif request.state ~= "completed" then - -- Error.. connection was closed prematurely - request.callback("connection-closed", 0, request); - end - destroy_request(request); - request.body = nil; - request.state = "completed"; +function listener.onincoming(conn, data) + local request = requests[conn]; + + if not request then + log("warn", "Received response from connection %s with no request attached!", tostring(conn)); return; end - if request.state == "body" and request.state ~= "completed" then - print("Reading body...") - if not request.body then request.body = {}; request.havebodylength, request.bodylength = 0, tonumber(request.responseheaders["content-length"]); end - if startpos then - data = data:sub(startpos, -1) - end - t_insert(request.body, data); - if request.bodylength then - request.havebodylength = request.havebodylength + #data; - if request.havebodylength >= request.bodylength then - -- We have the body - log("debug", "Have full body, calling callback"); - if request.callback then - request.callback(t_concat(request.body), request.code, request); - end - request.body = nil; - request.state = "completed"; - else - print("", "Have "..request.havebodylength.." bytes out of "..request.bodylength); - end - end - elseif request.state == "headers" then - print("Reading headers...") - local pos = startpos; - local headers = request.responseheaders or {}; - for line in data:sub(startpos, -1):gmatch("(.-)\r\n") do - startpos = startpos + #line + 2; - local k, v = line:match("(%S+): (.+)"); - if k and v then - headers[k:lower()] = v; - print("Header: "..k:lower().." = "..v); - elseif #line == 0 then - request.responseheaders = headers; - break; - else - print("Unhandled header line: "..line); + + if data and request.reader then + request:reader(data); + end +end + +function listener.ondisconnect(conn, err) + local request = requests[conn]; + if request and request.conn then + request:reader(nil, err); + end + requests[conn] = nil; +end + +local function request_reader(request, data, err) + if not request.parser then + local function error_cb(reason) + if request.callback then + request.callback(reason or "connection-closed", 0, request); + request.callback = nil; end - end - -- Reached the end of the headers - request.state = "body"; - if #data > startpos then - return request_reader(request, data, startpos); - end - elseif request.state == "status" then - print("Reading status...") - local http, code, text, linelen = data:match("^HTTP/(%S+) (%d+) (.-)\r\n()", startpos); - code = tonumber(code); - if not code then - return request.callback("invalid-status-line", 0, request); + destroy_request(request); end - request.code, request.responseversion = code, http; + if not data then + error_cb(err); + return; + end - if request.onlystatus or not expectbody(request, code) then + local function success_cb(r) if request.callback then - request.callback(nil, code, request); + request.callback(r.body, r.code, r, request); + request.callback = nil; end destroy_request(request); - return; end - - request.state = "headers"; - - if #data > linelen then - return request_reader(request, data, linelen); + local function options_cb() + return request; end + request.parser = httpstream_new(success_cb, error_cb, "client", options_cb); end + request.parser:feed(data); end -local function handleerr(err) log("error", "Traceback[http]: %s: %s", tostring(err), debug_traceback()); end +local function handleerr(err) log("error", "Traceback[http]: %s", traceback(tostring(err), 2)); end function request(u, ex, callback) local req = url.parse(u); @@ -131,79 +114,78 @@ function request(u, ex, callback) req.path = "/"; end - local custom_headers, body; - local default_headers = { ["Host"] = req.host, ["User-Agent"] = "Prosody XMPP Server" } + local method, headers, body; + headers = { + ["Host"] = req.host; + ["User-Agent"] = "Prosody XMPP Server"; + }; if req.userinfo then - default_headers["Authorization"] = "Basic "..mime.b64(req.userinfo); + headers["Authorization"] = "Basic "..b64(req.userinfo); end - + if ex then - custom_headers = ex.headers; req.onlystatus = ex.onlystatus; body = ex.body; if body then - req.method = "POST "; - default_headers["Content-Length"] = tostring(#body); - default_headers["Content-Type"] = "application/x-www-form-urlencoded"; + method = "POST"; + headers["Content-Length"] = tostring(#body); + headers["Content-Type"] = "application/x-www-form-urlencoded"; + end + if ex.method then method = ex.method; end + if ex.headers then + for k, v in pairs(ex.headers) do + headers[k] = v; + end end - if ex.method then req.method = ex.method; end - end - - req.handler, req.conn = server.wrapclient(socket.tcp(), req.host, req.port or 80, listener, "*a"); - req.write = req.handler.write; - req.conn:settimeout(0); - local ok, err = req.conn:connect(req.host, req.port or 80); - if not ok and err ~= "timeout" then - callback(nil, 0, req); - return nil, err; end - local request_line = { req.method or "GET", " ", req.path, " HTTP/1.1\r\n" }; + -- Attach to request object + req.method, req.headers, req.body = method, headers, body; - if req.query then - t_insert(request_line, 4, "?"); - t_insert(request_line, 5, req.query); - end - - req.write(t_concat(request_line)); - local t = { [2] = ": ", [4] = "\r\n" }; - if custom_headers then - for k, v in pairs(custom_headers) do - t[1], t[3] = k, v; - req.write(t_concat(t)); - default_headers[k] = nil; - end + local using_https = req.scheme == "https"; + if using_https and not ssl_available then + error("SSL not available, unable to contact https URL"); end + local port = tonumber(req.port) or (using_https and 443 or 80); - for k, v in pairs(default_headers) do - t[1], t[3] = k, v; - req.write(t_concat(t)); - default_headers[k] = nil; + -- Connect the socket, and wrap it with net.server + local conn = socket.tcp(); + conn:settimeout(10); + local ok, err = conn:connect(req.host, port); + if not ok and err ~= "timeout" then + callback(nil, 0, req); + return nil, err; end - req.write("\r\n"); - if body then - req.write(body); + local sslctx = false; + if using_https then + sslctx = ex and ex.sslctx or { mode = "client", protocol = "sslv23", options = { "no_sslv2" } }; end + + req.handler, req.conn = server.wrapclient(conn, req.host, port, listener, "*a", sslctx); + req.write = function (...) return req.handler:write(...); end - req.callback = function (content, code, request) log("debug", "Calling callback, status %s", code or "---"); return select(2, xpcall(function () return callback(content, code, request) end, handleerr)); end + req.callback = function (content, code, request, response) log("debug", "Calling callback, status %s", code or "---"); return select(2, xpcall(function () return callback(content, code, request, response) end, handleerr)); end req.reader = request_reader; req.state = "status"; - - listener.register_request(req.handler, req); + requests[req.handler] = req; return req; end function destroy_request(request) if request.conn then - request.handler.close() - listener.disconnect(request.conn, "closed"); + request.conn = nil; + request.handler:close() end end -_M.urlencode = urlencode; +local urlencode, urldecode = util_http.urlencode, util_http.urldecode; +local formencode, formdecode = util_http.formencode, util_http.formdecode; + +_M.urlencode, _M.urldecode = urlencode, urldecode; +_M.formencode, _M.formdecode = formencode, formdecode; return _M; diff --git a/net/http/codes.lua b/net/http/codes.lua new file mode 100644 index 00000000..0cadd079 --- /dev/null +++ b/net/http/codes.lua @@ -0,0 +1,67 @@ + +local response_codes = { + -- Source: http://www.iana.org/assignments/http-status-codes + -- s/^\(\d*\)\s*\(.*\S\)\s*\[RFC.*\]\s*$/^I["\1"] = "\2"; + [100] = "Continue"; + [101] = "Switching Protocols"; + [102] = "Processing"; + + [200] = "OK"; + [201] = "Created"; + [202] = "Accepted"; + [203] = "Non-Authoritative Information"; + [204] = "No Content"; + [205] = "Reset Content"; + [206] = "Partial Content"; + [207] = "Multi-Status"; + [208] = "Already Reported"; + [226] = "IM Used"; + + [300] = "Multiple Choices"; + [301] = "Moved Permanently"; + [302] = "Found"; + [303] = "See Other"; + [304] = "Not Modified"; + [305] = "Use Proxy"; + -- The 306 status code was used in a previous version of [RFC2616], is no longer used, and the code is reserved. + [307] = "Temporary Redirect"; + + [400] = "Bad Request"; + [401] = "Unauthorized"; + [402] = "Payment Required"; + [403] = "Forbidden"; + [404] = "Not Found"; + [405] = "Method Not Allowed"; + [406] = "Not Acceptable"; + [407] = "Proxy Authentication Required"; + [408] = "Request Timeout"; + [409] = "Conflict"; + [410] = "Gone"; + [411] = "Length Required"; + [412] = "Precondition Failed"; + [413] = "Request Entity Too Large"; + [414] = "Request-URI Too Long"; + [415] = "Unsupported Media Type"; + [416] = "Requested Range Not Satisfiable"; + [417] = "Expectation Failed"; + [418] = "I'm a teapot"; + [422] = "Unprocessable Entity"; + [423] = "Locked"; + [424] = "Failed Dependency"; + -- The 425 status code is reserved for the WebDAV advanced collections expired proposal [RFC2817] + [426] = "Upgrade Required"; + + [500] = "Internal Server Error"; + [501] = "Not Implemented"; + [502] = "Bad Gateway"; + [503] = "Service Unavailable"; + [504] = "Gateway Timeout"; + [505] = "HTTP Version Not Supported"; + [506] = "Variant Also Negotiates"; -- Experimental + [507] = "Insufficient Storage"; + [508] = "Loop Detected"; + [510] = "Not Extended"; +}; + +for k,v in pairs(response_codes) do response_codes[k] = k.." "..v; end +return setmetatable(response_codes, { __index = function(t, k) return k.." Unassigned"; end }) diff --git a/net/http/parser.lua b/net/http/parser.lua new file mode 100644 index 00000000..f9e6cea0 --- /dev/null +++ b/net/http/parser.lua @@ -0,0 +1,160 @@ +local tonumber = tonumber; +local assert = assert; +local url_parse = require "socket.url".parse; +local urldecode = require "util.http".urldecode; + +local function preprocess_path(path) + path = urldecode((path:gsub("//+", "/"))); + if path:sub(1,1) ~= "/" then + path = "/"..path; + end + local level = 0; + for component in path:gmatch("([^/]+)/") do + if component == ".." then + level = level - 1; + elseif component ~= "." then + level = level + 1; + end + if level < 0 then + return nil; + end + end + return path; +end + +local httpstream = {}; + +function httpstream.new(success_cb, error_cb, parser_type, options_cb) + local client = true; + if not parser_type or parser_type == "server" then client = false; else assert(parser_type == "client", "Invalid parser type"); end + local buf = ""; + local chunked, chunk_size, chunk_start; + local state = nil; + local packet; + local len; + local have_body; + local error; + return { + feed = function(self, data) + if error then return nil, "parse has failed"; end + if not data then -- EOF + if state and client and not len then -- reading client body until EOF + packet.body = buf; + success_cb(packet); + elseif buf ~= "" then -- unexpected EOF + error = true; return error_cb(); + end + return; + end + buf = buf..data; + while #buf > 0 do + if state == nil then -- read request + local index = buf:find("\r\n\r\n", nil, true); + if not index then return; end -- not enough data + local method, path, httpversion, status_code, reason_phrase; + local first_line; + local headers = {}; + for line in buf:sub(1,index+1):gmatch("([^\r\n]+)\r\n") do -- parse request + if first_line then + local key, val = line:match("^([^%s:]+): *(.*)$"); + if not key then error = true; return error_cb("invalid-header-line"); end -- TODO handle multi-line and invalid headers + key = key:lower(); + headers[key] = headers[key] and headers[key]..","..val or val; + else + first_line = line; + if client then + httpversion, status_code, reason_phrase = line:match("^HTTP/(1%.[01]) (%d%d%d) (.*)$"); + status_code = tonumber(status_code); + if not status_code then error = true; return error_cb("invalid-status-line"); end + have_body = not + ( (options_cb and options_cb().method == "HEAD") + or (status_code == 204 or status_code == 304 or status_code == 301) + or (status_code >= 100 and status_code < 200) ); + else + method, path, httpversion = line:match("^(%w+) (%S+) HTTP/(1%.[01])$"); + if not method then error = true; return error_cb("invalid-status-line"); end + end + end + end + if not first_line then error = true; return error_cb("invalid-status-line"); end + chunked = have_body and headers["transfer-encoding"] == "chunked"; + len = tonumber(headers["content-length"]); -- TODO check for invalid len + if client then + -- FIXME handle '100 Continue' response (by skipping it) + if not have_body then len = 0; end + packet = { + code = status_code; + httpversion = httpversion; + headers = headers; + body = have_body and "" or nil; + -- COMPAT the properties below are deprecated + responseversion = httpversion; + responseheaders = headers; + }; + else + local parsed_url; + if path:byte() == 47 then -- starts with / + local _path, _query = path:match("([^?]*).?(.*)"); + if _query == "" then _query = nil; end + parsed_url = { path = _path, query = _query }; + else + parsed_url = url_parse(path); + if not(parsed_url and parsed_url.path) then error = true; return error_cb("invalid-url"); end + end + path = preprocess_path(parsed_url.path); + headers.host = parsed_url.host or headers.host; + + len = len or 0; + packet = { + method = method; + url = parsed_url; + path = path; + httpversion = httpversion; + headers = headers; + body = nil; + }; + end + buf = buf:sub(index + 4); + state = true; + end + if state then -- read body + if client then + if chunked then + if not buf:find("\r\n", nil, true) then + return; + end -- not enough data + if not chunk_size then + chunk_size, chunk_start = buf:match("^(%x+)[^\r\n]*\r\n()"); + chunk_size = chunk_size and tonumber(chunk_size, 16); + if not chunk_size then error = true; return error_cb("invalid-chunk-size"); end + end + if chunk_size == 0 and buf:find("\r\n\r\n", chunk_start-2, true) then + state, chunk_size = nil, nil; + buf = buf:gsub("^.-\r\n\r\n", ""); -- This ensure extensions and trailers are stripped + success_cb(packet); + elseif #buf - chunk_start + 2 >= chunk_size then -- we have a chunk + packet.body = packet.body..buf:sub(chunk_start, chunk_start + (chunk_size-1)); + buf = buf:sub(chunk_start + chunk_size + 2); + chunk_size, chunk_start = nil, nil; + else -- Partial chunk remaining + break; + end + elseif len and #buf >= len then + packet.body, buf = buf:sub(1, len), buf:sub(len + 1); + state = nil; success_cb(packet); + else + break; + end + elseif #buf >= len then + packet.body, buf = buf:sub(1, len), buf:sub(len + 1); + state = nil; success_cb(packet); + else + break; + end + end + end + end; + }; +end + +return httpstream; diff --git a/net/http/server.lua b/net/http/server.lua new file mode 100644 index 00000000..dec7da19 --- /dev/null +++ b/net/http/server.lua @@ -0,0 +1,303 @@ + +local t_insert, t_remove, t_concat = table.insert, table.remove, table.concat; +local parser_new = require "net.http.parser".new; +local events = require "util.events".new(); +local addserver = require "net.server".addserver; +local log = require "util.logger".init("http.server"); +local os_date = os.date; +local pairs = pairs; +local s_upper = string.upper; +local setmetatable = setmetatable; +local xpcall = xpcall; +local traceback = debug.traceback; +local tostring = tostring; +local codes = require "net.http.codes"; + +local _M = {}; + +local sessions = {}; +local listener = {}; +local hosts = {}; +local default_host; + +local function is_wildcard_event(event) + return event:sub(-2, -1) == "/*"; +end +local function is_wildcard_match(wildcard_event, event) + return wildcard_event:sub(1, -2) == event:sub(1, #wildcard_event-1); +end + +local recent_wildcard_events, max_cached_wildcard_events = {}, 10000; + +local event_map = events._event_map; +setmetatable(events._handlers, { + -- Called when firing an event that doesn't exist (but may match a wildcard handler) + __index = function (handlers, curr_event) + if is_wildcard_event(curr_event) then return; end -- Wildcard events cannot be fired + -- Find all handlers that could match this event, sort them + -- and then put the array into handlers[curr_event] (and return it) + local matching_handlers_set = {}; + local handlers_array = {}; + for event, handlers_set in pairs(event_map) do + if event == curr_event or + is_wildcard_event(event) and is_wildcard_match(event, curr_event) then + for handler, priority in pairs(handlers_set) do + matching_handlers_set[handler] = { (select(2, event:gsub("/", "%1"))), is_wildcard_event(event) and 0 or 1, priority }; + table.insert(handlers_array, handler); + end + end + end + if #handlers_array > 0 then + table.sort(handlers_array, function(b, a) + local a_score, b_score = matching_handlers_set[a], matching_handlers_set[b]; + for i = 1, #a_score do + if a_score[i] ~= b_score[i] then -- If equal, compare next score value + return a_score[i] < b_score[i]; + end + end + return false; + end); + else + handlers_array = false; + end + rawset(handlers, curr_event, handlers_array); + if not event_map[curr_event] then -- Only wildcard handlers match, if any + table.insert(recent_wildcard_events, curr_event); + if #recent_wildcard_events > max_cached_wildcard_events then + rawset(handlers, table.remove(recent_wildcard_events, 1), nil); + end + end + return handlers_array; + end; + __newindex = function (handlers, curr_event, handlers_array) + if handlers_array == nil + and is_wildcard_event(curr_event) then + -- Invalidate the indexes of all matching events + for event in pairs(handlers) do + if is_wildcard_match(curr_event, event) then + handlers[event] = nil; + end + end + end + rawset(handlers, curr_event, handlers_array); + end; +}); + +local handle_request; +local _1, _2, _3; +local function _handle_request() return handle_request(_1, _2, _3); end + +local last_err; +local function _traceback_handler(err) last_err = err; log("error", "Traceback[httpserver]: %s", traceback(tostring(err), 2)); end +events.add_handler("http-error", function (error) + return "Error processing request: "..codes[error.code]..". Check your error log for more information."; +end, -1); + +function listener.onconnect(conn) + local secure = conn:ssl() and true or nil; + local pending = {}; + local waiting = false; + local function process_next() + if waiting then log("debug", "can't process_next, waiting"); return; end + waiting = true; + while sessions[conn] and #pending > 0 do + local request = t_remove(pending); + --log("debug", "process_next: %s", request.path); + --handle_request(conn, request, process_next); + _1, _2, _3 = conn, request, process_next; + if not xpcall(_handle_request, _traceback_handler) then + conn:write("HTTP/1.0 500 Internal Server Error\r\n\r\n"..events.fire_event("http-error", { code = 500, private_message = last_err })); + conn:close(); + end + end + --log("debug", "ready for more"); + waiting = false; + end + local function success_cb(request) + --log("debug", "success_cb: %s", request.path); + if waiting then + log("error", "http connection handler is not reentrant: %s", request.path); + assert(false, "http connection handler is not reentrant"); + end + request.secure = secure; + t_insert(pending, request); + process_next(); + end + local function error_cb(err) + log("debug", "error_cb: %s", err or "<nil>"); + -- FIXME don't close immediately, wait until we process current stuff + -- FIXME if err, send off a bad-request response + sessions[conn] = nil; + conn:close(); + end + sessions[conn] = parser_new(success_cb, error_cb); +end + +function listener.ondisconnect(conn) + local open_response = conn._http_open_response; + if open_response and open_response.on_destroy then + open_response.finished = true; + open_response:on_destroy(); + end + sessions[conn] = nil; +end + +function listener.onincoming(conn, data) + sessions[conn]:feed(data); +end + +local headerfix = setmetatable({}, { + __index = function(t, k) + local v = "\r\n"..k:gsub("_", "-"):gsub("%f[%w].", s_upper)..": "; + t[k] = v; + return v; + end +}); + +function _M.hijack_response(response, listener) + error("TODO"); +end +function handle_request(conn, request, finish_cb) + --log("debug", "handler: %s", request.path); + local headers = {}; + for k,v in pairs(request.headers) do headers[k:gsub("-", "_")] = v; end + request.headers = headers; + request.conn = conn; + + local date_header = os_date('!%a, %d %b %Y %H:%M:%S GMT'); -- FIXME use + local conn_header = request.headers.connection; + conn_header = conn_header and ","..conn_header:gsub("[ \t]", ""):lower().."," or "" + local httpversion = request.httpversion + local persistent = conn_header:find(",Keep-Alive,", 1, true) + or (httpversion == "1.1" and not conn_header:find(",close,", 1, true)); + + local response_conn_header; + if persistent then + response_conn_header = "Keep-Alive"; + else + response_conn_header = httpversion == "1.1" and "close" or nil + end + + local response = { + request = request; + status_code = 200; + headers = { date = date_header, connection = response_conn_header }; + persistent = persistent; + conn = conn; + send = _M.send_response; + finish_cb = finish_cb; + }; + conn._http_open_response = response; + + local host = (request.headers.host or ""):match("[^:]+"); + + -- Some sanity checking + local err_code, err; + if not request.path then + err_code, err = 400, "Invalid path"; + elseif not hosts[host] then + if hosts[default_host] then + host = default_host; + elseif host then + err_code, err = 404, "Unknown host: "..host; + else + err_code, err = 400, "Missing or invalid 'Host' header"; + end + end + + if err then + response.status_code = err_code; + response:send(events.fire_event("http-error", { code = err_code, message = err })); + return; + end + + local event = request.method.." "..host..request.path:match("[^?]*"); + local payload = { request = request, response = response }; + --log("debug", "Firing event: %s", event); + local result = events.fire_event(event, payload); + if result ~= nil then + if result ~= true then + local body; + local result_type = type(result); + if result_type == "number" then + response.status_code = result; + if result >= 400 then + body = events.fire_event("http-error", { code = result }); + end + elseif result_type == "string" then + body = result; + elseif result_type == "table" then + for k, v in pairs(result) do + if k ~= "headers" then + response[k] = v; + else + for header_name, header_value in pairs(v) do + response.headers[header_name] = header_value; + end + end + end + end + response:send(body); + end + return; + end + + -- if handler not called, return 404 + response.status_code = 404; + response:send(events.fire_event("http-error", { code = 404 })); +end +function _M.send_response(response, body) + if response.finished then return; end + response.finished = true; + response.conn._http_open_response = nil; + + local status_line = "HTTP/"..response.request.httpversion.." "..(response.status or codes[response.status_code]); + local headers = response.headers; + body = body or response.body or ""; + headers.content_length = #body; + + local output = { status_line }; + for k,v in pairs(headers) do + t_insert(output, headerfix[k]..v); + end + t_insert(output, "\r\n\r\n"); + t_insert(output, body); + + response.conn:write(t_concat(output)); + if response.on_destroy then + response:on_destroy(); + response.on_destroy = nil; + end + if response.persistent then + response:finish_cb(); + else + response.conn:close(); + end +end +function _M.add_handler(event, handler, priority) + events.add_handler(event, handler, priority); +end +function _M.remove_handler(event, handler) + events.remove_handler(event, handler); +end + +function _M.listen_on(port, interface, ssl) + addserver(interface or "*", port, listener, "*a", ssl); +end +function _M.add_host(host) + hosts[host] = true; +end +function _M.remove_host(host) + hosts[host] = nil; +end +function _M.set_default_host(host) + default_host = host; +end +function _M.fire_event(event, ...) + return events.fire_event(event, ...); +end + +_M.listener = listener; +_M.codes = codes; +_M._events = events; +return _M; diff --git a/net/httpclient_listener.lua b/net/httpclient_listener.lua deleted file mode 100644 index 69b7946b..00000000 --- a/net/httpclient_listener.lua +++ /dev/null @@ -1,44 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2009 Matthew Wild --- Copyright (C) 2008-2009 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - -local log = require "util.logger".init("httpclient_listener"); - -local connlisteners_register = require "net.connlisteners".register; - -local requests = {}; -- Open requests -local buffers = {}; -- Buffers of partial lines - -local httpclient = { default_port = 80, default_mode = "*a" }; - -function httpclient.listener(conn, data) - local request = requests[conn]; - - if not request then - log("warn", "Received response from connection %s with no request attached!", tostring(conn)); - return; - end - - if data and request.reader then - request:reader(data); - end -end - -function httpclient.disconnect(conn, err) - local request = requests[conn]; - if request then - request:reader(nil); - end - requests[conn] = nil; -end - -function httpclient.register_request(conn, req) - log("debug", "Attaching request %s to connection %s", tostring(req.id or req), tostring(conn)); - requests[conn] = req; -end - -connlisteners_register("httpclient", httpclient); diff --git a/net/httpserver.lua b/net/httpserver.lua index 57c8eede..7d574788 100644 --- a/net/httpserver.lua +++ b/net/httpserver.lua @@ -1,278 +1,15 @@ --- Prosody IM --- Copyright (C) 2008-2009 Matthew Wild --- Copyright (C) 2008-2009 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - - -local socket = require "socket" -local server = require "net.server" -local url_parse = require "socket.url".parse; - -local connlisteners_start = require "net.connlisteners".start; -local connlisteners_get = require "net.connlisteners".get; -local listener; - -local t_insert, t_concat = table.insert, table.concat; -local s_match, s_gmatch = string.match, string.gmatch; -local tonumber, tostring, pairs, ipairs, type = tonumber, tostring, pairs, ipairs, type; - -local urlencode = function (s) return s and (s:gsub("%W", function (c) return string.format("%%%02x", c:byte()); end)); end - -local log = require "util.logger".init("httpserver"); - -local http_servers = {}; +-- COMPAT w/pre-0.9 +local log = require "util.logger".init("net.httpserver"); +local traceback = debug.traceback; module "httpserver" -local default_handler; - -local function expectbody(reqt) - return reqt.method == "POST"; -end - -local function send_response(request, response) - -- Write status line - local resp; - if response.body then - local body = tostring(response.body); - log("debug", "Sending response to %s", request.id); - resp = { "HTTP/1.0 ", response.status or "200 OK", "\r\n"}; - local h = response.headers; - if h then - for k, v in pairs(h) do - t_insert(resp, k); - t_insert(resp, ": "); - t_insert(resp, v); - t_insert(resp, "\r\n"); - end - end - if not (h and h["Content-Length"]) then - t_insert(resp, "Content-Length: "); - t_insert(resp, #body); - t_insert(resp, "\r\n"); - end - t_insert(resp, "\r\n"); - - if request.method ~= "HEAD" then - t_insert(resp, body); - end - else - -- Response we have is just a string (the body) - log("debug", "Sending response to %s: %s", request.id or "<none>", response or "<none>"); - - resp = { "HTTP/1.0 200 OK\r\n" }; - t_insert(resp, "Connection: close\r\n"); - t_insert(resp, "Content-Length: "); - t_insert(resp, #response); - t_insert(resp, "\r\n\r\n"); - - t_insert(resp, response); - end - request.write(t_concat(resp)); - if not request.stayopen then - request:destroy(); - end -end - -local function call_callback(request, err) - if request.handled then return; end - request.handled = true; - local callback = request.callback; - if not callback and request.path then - local path = request.url.path; - local base = path:match("^/([^/?]+)"); - if not base then - base = path:match("^http://[^/?]+/([^/?]+)"); - end - - callback = (request.server and request.server.handlers[base]) or default_handler; - if callback == default_handler then - log("debug", "Default callback for this request (base: "..tostring(base)..")") - end - end - if callback then - if err then - log("debug", "Request error: "..err); - if not callback(nil, err, request) then - destroy_request(request); - end - return; - end - - local response = callback(request.method, request.body and t_concat(request.body), request); - if response then - if response == true and not request.destroyed then - -- Keep connection open, we will reply later - log("debug", "Request %s left open, on_destroy is %s", request.id, tostring(request.on_destroy)); - elseif response ~= true then - -- Assume response - send_response(request, response); - destroy_request(request); - end - else - log("debug", "Request handler provided no response, destroying request..."); - -- No response, close connection - destroy_request(request); - end - end -end - -local function request_reader(request, data, startpos) - if not data then - if request.body then - call_callback(request); - else - -- Error.. connection was closed prematurely - call_callback(request, "connection-closed"); - end - -- Here we force a destroy... the connection is gone, so we can't reply later - destroy_request(request); - return; - end - if request.state == "body" then - log("debug", "Reading body...") - if not request.body then request.body = {}; request.havebodylength, request.bodylength = 0, tonumber(request.headers["content-length"]); end - if startpos then - data = data:sub(startpos, -1) - end - t_insert(request.body, data); - if request.bodylength then - request.havebodylength = request.havebodylength + #data; - if request.havebodylength >= request.bodylength then - -- We have the body - call_callback(request); - end - end - elseif request.state == "headers" then - log("debug", "Reading headers...") - local pos = startpos; - local headers = request.headers or {}; - for line in data:gmatch("(.-)\r\n") do - startpos = (startpos or 1) + #line + 2; - local k, v = line:match("(%S+): (.+)"); - if k and v then - headers[k:lower()] = v; --- log("debug", "Header: "..k:lower().." = "..v); - elseif #line == 0 then - request.headers = headers; - break; - else - log("debug", "Unhandled header line: "..line); - end - end - - if not expectbody(request) then - call_callback(request); - return; - end - - -- Reached the end of the headers - request.state = "body"; - if #data > startpos then - return request_reader(request, data:sub(startpos, -1)); - end - elseif request.state == "request" then - log("debug", "Reading request line...") - local method, path, http, linelen = data:match("^(%S+) (%S+) HTTP/(%S+)\r\n()", startpos); - if not method then - return call_callback(request, "invalid-status-line"); - end - - request.method, request.path, request.httpversion = method, path, http; - - request.url = url_parse(request.path); - - log("debug", method.." request for "..tostring(request.path) .. " on port "..request.handler.serverport()); - - if request.onlystatus then - if not call_callback(request) then - return; - end - end - - request.state = "headers"; - - if #data > linelen then - return request_reader(request, data:sub(linelen, -1)); - end - end -end - --- The default handler for requests -default_handler = function (method, body, request) - log("debug", method.." request for "..tostring(request.path) .. " on port "..request.handler.serverport()); - return { status = "404 Not Found", - headers = { ["Content-Type"] = "text/html" }, - body = "<html><head><title>Page Not Found</title></head><body>Not here :(</body></html>" }; -end - - -function new_request(handler) - return { handler = handler, conn = handler.socket, - write = handler.write, state = "request", - server = http_servers[handler.serverport()], - send = send_response, - destroy = destroy_request, - id = tostring{}:match("%x+$") - }; -end - -function destroy_request(request) - log("debug", "Destroying request %s", request.id); - listener = listener or connlisteners_get("httpserver"); - if not request.destroyed then - request.destroyed = true; - if request.on_destroy then - log("debug", "Request has destroy callback"); - request.on_destroy(request); - else - log("debug", "Request has no destroy callback"); - end - request.handler.close() - if request.conn then - listener.disconnect(request.conn, "closed"); - end - end -end - -function new(params) - local http_server = http_servers[params.port]; - if not http_server then - http_server = { handlers = {} }; - http_servers[params.port] = http_server; - -- We weren't already listening on this port, so start now - connlisteners_start("httpserver", params); - end - if params.base then - http_server.handlers[params.base] = params.handler; - end -end - -function new_from_config(ports, default_base, handle_request) - for _, options in ipairs(ports) do - local port, base, ssl, interface = 5280, default_base, false, nil; - if type(options) == "number" then - port = options; - elseif type(options) == "table" then - port, base, ssl, interface = options.port or 5280, options.path or default_base, options.ssl or false, options.interface; - elseif type(options) == "string" then - base = options; - end - - if ssl then - ssl.mode = "server"; - ssl.protocol = "sslv23"; - end - - new{ port = port, base = base, handler = handle_request, ssl = ssl, type = (ssl and "ssl") or "tcp" } - end +function fail() + log("error", "Attempt to use legacy HTTP API. For more info see http://prosody.im/doc/developers/legacy_http"); + log("error", "Legacy HTTP API usage, %s", traceback("", 2)); end -_M.request_reader = request_reader; -_M.send_response = send_response; -_M.urlencode = urlencode; +new, new_from_config = fail, fail; +set_default_handler = fail; return _M; diff --git a/net/httpserver_listener.lua b/net/httpserver_listener.lua deleted file mode 100644 index 455191fb..00000000 --- a/net/httpserver_listener.lua +++ /dev/null @@ -1,46 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2009 Matthew Wild --- Copyright (C) 2008-2009 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - - - -local connlisteners_register = require "net.connlisteners".register; -local new_request = require "net.httpserver".new_request; -local request_reader = require "net.httpserver".request_reader; - -local requests = {}; -- Open requests - -local httpserver = { default_port = 80, default_mode = "*a" }; - -function httpserver.listener(conn, data) - local request = requests[conn]; - - if not request then - request = new_request(conn); - requests[conn] = request; - - -- If using HTTPS, request is secure - if conn.ssl() then - request.secure = true; - end - end - - if data then - request_reader(request, data); - end -end - -function httpserver.disconnect(conn, err) - local request = requests[conn]; - if request and not request.destroyed then - request.conn = nil; - request_reader(request, nil); - end - requests[conn] = nil; -end - -connlisteners_register("httpserver", httpserver); diff --git a/net/server.lua b/net/server.lua index 966006c1..375e7081 100644 --- a/net/server.lua +++ b/net/server.lua @@ -1,893 +1,84 @@ ---
--- server.lua by blastbeat of the luadch project
--- Re-used here under the MIT/X Consortium License
---
--- Modifications (C) 2008-2009 Matthew Wild, Waqas Hussain
---
-
--- // wrapping luadch stuff // --
-
-local use = function( what )
- return _G[ what ]
-end
-local clean = function( tbl )
- for i, k in pairs( tbl ) do
- tbl[ i ] = nil
- end
-end
-
-local log, table_concat = require ("util.logger").init("socket"), table.concat;
-local out_put = function (...) return log("debug", table_concat{...}); end
-local out_error = function (...) return log("warn", table_concat{...}); end
-local mem_free = collectgarbage
-
-----------------------------------// DECLARATION //--
-
---// constants //--
-
-local STAT_UNIT = 1 -- byte
-
---// lua functions //--
-
-local type = use "type"
-local pairs = use "pairs"
-local ipairs = use "ipairs"
-local tostring = use "tostring"
-local collectgarbage = use "collectgarbage"
-
---// lua libs //--
-
-local os = use "os"
-local table = use "table"
-local string = use "string"
-local coroutine = use "coroutine"
-
---// lua lib methods //--
-
-local os_time = os.time
-local os_difftime = os.difftime
-local table_concat = table.concat
-local table_remove = table.remove
-local string_len = string.len
-local string_sub = string.sub
-local coroutine_wrap = coroutine.wrap
-local coroutine_yield = coroutine.yield
-
---// extern libs //--
-
-local luasec = select( 2, pcall( require, "ssl" ) )
-local luasocket = require "socket"
-
---// extern lib methods //--
-
-local ssl_wrap = ( luasec and luasec.wrap )
-local socket_bind = luasocket.bind
-local socket_sleep = luasocket.sleep
-local socket_select = luasocket.select
-local ssl_newcontext = ( luasec and luasec.newcontext )
-
---// functions //--
-
-local id
-local loop
-local stats
-local idfalse
-local addtimer
-local closeall
-local addserver
-local getserver
-local wrapserver
-local getsettings
-local closesocket
-local removesocket
-local removeserver
-local changetimeout
-local wrapconnection
-local changesettings
-
---// tables //--
-
-local _server
-local _readlist
-local _timerlist
-local _sendlist
-local _socketlist
-local _closelist
-local _readtimes
-local _writetimes
-
---// simple data types //--
-
-local _
-local _readlistlen
-local _sendlistlen
-local _timerlistlen
-
-local _sendtraffic
-local _readtraffic
-
-local _selecttimeout
-local _sleeptime
-
-local _starttime
-local _currenttime
-
-local _maxsendlen
-local _maxreadlen
-
-local _checkinterval
-local _sendtimeout
-local _readtimeout
-
-local _cleanqueue
-
-local _timer
-
-local _maxclientsperserver
-
-----------------------------------// DEFINITION //--
-
-_server = { } -- key = port, value = table; list of listening servers
-_readlist = { } -- array with sockets to read from
-_sendlist = { } -- arrary with sockets to write to
-_timerlist = { } -- array of timer functions
-_socketlist = { } -- key = socket, value = wrapped socket (handlers)
-_readtimes = { } -- key = handler, value = timestamp of last data reading
-_writetimes = { } -- key = handler, value = timestamp of last data writing/sending
-_closelist = { } -- handlers to close
-
-_readlistlen = 0 -- length of readlist
-_sendlistlen = 0 -- length of sendlist
-_timerlistlen = 0 -- lenght of timerlist
-
-_sendtraffic = 0 -- some stats
-_readtraffic = 0
-
-_selecttimeout = 1 -- timeout of socket.select
-_sleeptime = 0 -- time to wait at the end of every loop
-
-_maxsendlen = 51000 * 1024 -- max len of send buffer
-_maxreadlen = 25000 * 1024 -- max len of read buffer
-
-_checkinterval = 1200000 -- interval in secs to check idle clients
-_sendtimeout = 60000 -- allowed send idle time in secs
-_readtimeout = 6 * 60 * 60 -- allowed read idle time in secs
-
-_cleanqueue = false -- clean bufferqueue after using
-
-_maxclientsperserver = 1000
-
-----------------------------------// PRIVATE //--
-
-wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxconnections, startssl ) -- this function wraps a server
-
- maxconnections = maxconnections or _maxclientsperserver
-
- local connections = 0
-
- local dispatch, disconnect = listeners.incoming or listeners.listener, listeners.disconnect
-
- local err
-
- local ssl = false
-
- if sslctx then
- ssl = true
- if not ssl_newcontext then
- out_error "luasec not found"
- ssl = false
- end
- if type( sslctx ) ~= "table" then
- out_error "server.lua: wrong server sslctx"
- ssl = false
- end
- local ctx;
- ctx, err = ssl_newcontext( sslctx )
- if not ctx then
- err = err or "wrong sslctx parameters"
- local file;
- file = err:match("^error loading (.-) %(");
- if file then
- if file == "private key" then
- file = sslctx.key or "your private key";
- elseif file == "certificate" then
- file = sslctx.certificate or "your certificate file";
- end
- local reason = err:match("%((.+)%)$") or "some reason";
- if reason == "Permission denied" then
- reason = "Check that the permissions allow Prosody to read this file.";
- elseif reason == "No such file or directory" then
- reason = "Check that the path is correct, and the file exists.";
- elseif reason == "system lib" then
- reason = "Previous error (see logs), or other system error.";
- else
- reason = "Reason: "..tostring(reason or "unknown"):lower();
- end
- log("error", "SSL/TLS: Failed to load %s: %s", file, reason);
- else
- log("error", "SSL/TLS: Error initialising for port %d: %s", serverport, err );
- end
- ssl = false
- end
- sslctx = ctx;
- end
- if not ssl then
- sslctx = false;
- if startssl then
- log("error", "Failed to listen on port %d due to SSL/TLS to SSL/TLS initialisation errors (see logs)", serverport )
- return nil, "Cannot start ssl, see log for details"
- end
- end
-
- local accept = socket.accept
-
- --// public methods of the object //--
-
- local handler = { }
-
- handler.shutdown = function( ) end
-
- handler.ssl = function( )
- return ssl
- end
- handler.remove = function( )
- connections = connections - 1
- end
- handler.close = function( )
- for _, handler in pairs( _socketlist ) do
- if handler.serverport == serverport then
- handler.disconnect( handler, "server closed" )
- handler.close( true )
- end
- end
- socket:close( )
- _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
- _readlistlen = removesocket( _readlist, socket, _readlistlen )
- _socketlist[ socket ] = nil
- handler = nil
- socket = nil
- mem_free( )
- out_put "server.lua: closed server handler and removed sockets from list"
- end
- handler.ip = function( )
- return ip
- end
- handler.serverport = function( )
- return serverport
- end
- handler.socket = function( )
- return socket
- end
- handler.readbuffer = function( )
- if connections > maxconnections then
- out_put( "server.lua: refused new client connection: server full" )
- return false
- end
- local client, err = accept( socket ) -- try to accept
- if client then
- local ip, clientport = client:getpeername( )
- client:settimeout( 0 )
- local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx, startssl ) -- wrap new client socket
- if err then -- error while wrapping ssl socket
- return false
- end
- connections = connections + 1
- out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport))
- return dispatch( handler )
- elseif err then -- maybe timeout or something else
- out_put( "server.lua: error with new client connection: ", tostring(err) )
- return false
- end
- end
- return handler
-end
-
-wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx, startssl ) -- this function wraps a client to a handler object
-
- socket:settimeout( 0 )
-
- --// local import of socket methods //--
-
- local send
- local receive
- local shutdown
-
- --// private closures of the object //--
-
- local ssl
-
- local dispatch = listeners.incoming or listeners.listener
- local disconnect = listeners.disconnect
-
- local bufferqueue = { } -- buffer array
- local bufferqueuelen = 0 -- end of buffer array
-
- local toclose
- local fatalerror
- local needtls
-
- local bufferlen = 0
-
- local noread = false
- local nosend = false
-
- local sendtraffic, readtraffic = 0, 0
-
- local maxsendlen = _maxsendlen
- local maxreadlen = _maxreadlen
-
- --// public methods of the object //--
-
- local handler = bufferqueue -- saves a table ^_^
-
- handler.dispatch = function( )
- return dispatch
- end
- handler.disconnect = function( )
- return disconnect
- end
- handler.setlistener = function( listeners )
- dispatch = listeners.incoming
- disconnect = listeners.disconnect
- end
- handler.getstats = function( )
- return readtraffic, sendtraffic
- end
- handler.ssl = function( )
- return ssl
- end
- handler.send = function( _, data, i, j )
- return send( socket, data, i, j )
- end
- handler.receive = function( pattern, prefix )
- return receive( socket, pattern, prefix )
- end
- handler.shutdown = function( pattern )
- return shutdown( socket, pattern )
- end
- handler.close = function( forced )
- if not handler then return true; end
- _readlistlen = removesocket( _readlist, socket, _readlistlen )
- _readtimes[ handler ] = nil
- if bufferqueuelen ~= 0 then
- if not ( forced or fatalerror ) then
- handler.sendbuffer( )
- if bufferqueuelen ~= 0 then -- try again...
- if handler then
- handler.write = nil -- ... but no further writing allowed
- end
- toclose = true
- return false
- end
- else
- send( socket, table_concat( bufferqueue, "", 1, bufferqueuelen ), 1, bufferlen ) -- forced send
- end
- end
- _ = shutdown and shutdown( socket )
- socket:close( )
- _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
- _socketlist[ socket ] = nil
- if handler then
- _writetimes[ handler ] = nil
- _closelist[ handler ] = nil
- handler = nil
- end
- socket = nil
- mem_free( )
- if server then
- server.remove( )
- end
- out_put "server.lua: closed client handler and removed socket from list"
- return true
- end
- handler.ip = function( )
- return ip
- end
- handler.serverport = function( )
- return serverport
- end
- handler.clientport = function( )
- return clientport
- end
- local write = function( data )
- bufferlen = bufferlen + string_len( data )
- if bufferlen > maxsendlen then
- _closelist[ handler ] = "send buffer exceeded" -- cannot close the client at the moment, have to wait to the end of the cycle
- handler.write = idfalse -- dont write anymore
- return false
- elseif socket and not _sendlist[ socket ] then
- _sendlistlen = _sendlistlen + 1
- _sendlist[ _sendlistlen ] = socket
- _sendlist[ socket ] = _sendlistlen
- end
- bufferqueuelen = bufferqueuelen + 1
- bufferqueue[ bufferqueuelen ] = data
- if handler then
- _writetimes[ handler ] = _writetimes[ handler ] or _currenttime
- end
- return true
- end
- handler.write = write
- handler.bufferqueue = function( )
- return bufferqueue
- end
- handler.socket = function( )
- return socket
- end
- handler.pattern = function( new )
- pattern = new or pattern
- return pattern
- end
- handler.setsend = function ( newsend )
- send = newsend or send
- return send
- end
- handler.bufferlen = function( readlen, sendlen )
- maxsendlen = sendlen or maxsendlen
- maxreadlen = readlen or maxreadlen
- return maxreadlen, maxsendlen
- end
- handler.lock = function( switch )
- if switch == true then
- handler.write = idfalse
- local tmp = _sendlistlen
- _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
- _writetimes[ handler ] = nil
- if _sendlistlen ~= tmp then
- nosend = true
- end
- tmp = _readlistlen
- _readlistlen = removesocket( _readlist, socket, _readlistlen )
- _readtimes[ handler ] = nil
- if _readlistlen ~= tmp then
- noread = true
- end
- elseif switch == false then
- handler.write = write
- if noread then
- noread = false
- _readlistlen = _readlistlen + 1
- _readlist[ socket ] = _readlistlen
- _readlist[ _readlistlen ] = socket
- _readtimes[ handler ] = _currenttime
- end
- if nosend then
- nosend = false
- write( "" )
- end
- end
- return noread, nosend
- end
- local _readbuffer = function( ) -- this function reads data
- local buffer, err, part = receive( socket, pattern ) -- receive buffer with "pattern"
- if not err or ( err == "timeout" or err == "wantread" ) then -- received something
- local buffer = buffer or part or ""
- local len = string_len( buffer )
- if len > maxreadlen then
- disconnect( handler, "receive buffer exceeded" )
- handler.close( true )
- return false
- end
- local count = len * STAT_UNIT
- readtraffic = readtraffic + count
- _readtraffic = _readtraffic + count
- _readtimes[ handler ] = _currenttime
- --out_put( "server.lua: read data '", buffer, "', error: ", err )
- return dispatch( handler, buffer, err )
- else -- connections was closed or fatal error
- out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " error: ", tostring(err) )
- fatalerror = true
- disconnect( handler, err )
- _ = handler and handler.close( )
- return false
- end
- end
- local _sendbuffer = function( ) -- this function sends data
- local buffer = table_concat( bufferqueue, "", 1, bufferqueuelen )
- local succ, err, byte = send( socket, buffer, 1, bufferlen )
- local count = ( succ or byte or 0 ) * STAT_UNIT
- sendtraffic = sendtraffic + count
- _sendtraffic = _sendtraffic + count
- _ = _cleanqueue and clean( bufferqueue )
- --out_put( "server.lua: sended '", buffer, "', bytes: ", tostring(succ), ", error: ", tostring(err), ", part: ", tostring(byte), ", to: ", tostring(ip), ":", tostring(clientport) )
- if succ then -- sending succesful
- bufferqueuelen = 0
- bufferlen = 0
- _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) -- delete socket from writelist
- _ = needtls and handler.starttls(true)
- _writetimes[ handler ] = nil
- _ = toclose and handler.close( )
- return true
- elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write
- buffer = string_sub( buffer, byte + 1, bufferlen ) -- new buffer
- bufferqueue[ 1 ] = buffer -- insert new buffer in queue
- bufferqueuelen = 1
- bufferlen = bufferlen - byte
- _writetimes[ handler ] = _currenttime
- return true
- else -- connection was closed during sending or fatal error
- out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " error: ", tostring(err) )
- fatalerror = true
- disconnect( handler, err )
- _ = handler and handler.close( )
- return false
- end
- end
-
- if sslctx then -- ssl?
- ssl = true
- local wrote
- local read
- local handshake = coroutine_wrap( function( client ) -- create handshake coroutine
- local err
- for i = 1, 10 do -- 10 handshake attemps
- _sendlistlen = ( wrote and removesocket( _sendlist, socket, _sendlistlen ) ) or _sendlistlen
- _readlistlen = ( read and removesocket( _readlist, socket, _readlistlen ) ) or _readlistlen
- read, wrote = nil, nil
- _, err = client:dohandshake( )
- if not err then
- out_put( "server.lua: ssl handshake done" )
- handler.readbuffer = _readbuffer -- when handshake is done, replace the handshake function with regular functions
- handler.sendbuffer = _sendbuffer
- -- return dispatch( handler )
- return true
- else
- out_put( "server.lua: error during ssl handshake: ", tostring(err) )
- if err == "wantwrite" and not wrote then
- _sendlistlen = _sendlistlen + 1
- _sendlist[ _sendlistlen ] = client
- wrote = true
- elseif err == "wantread" and not read then
- _readlistlen = _readlistlen + 1
- _readlist [ _readlistlen ] = client
- read = true
- else
- break;
- end
- --coroutine_yield( handler, nil, err ) -- handshake not finished
- coroutine_yield( )
- end
- end
- disconnect( handler, "ssl handshake failed" )
- _ = handler and handler.close( true ) -- forced disconnect
- return false -- handshake failed
- end
- )
- if startssl then -- ssl now?
- --out_put("server.lua: ", "starting ssl handshake")
- local err
- socket, err = ssl_wrap( socket, sslctx ) -- wrap socket
- if err then
- out_put( "server.lua: ssl error: ", tostring(err) )
- mem_free( )
- return nil, nil, err -- fatal error
- end
- socket:settimeout( 0 )
- handler.readbuffer = handshake
- handler.sendbuffer = handshake
- handshake( socket ) -- do handshake
- if not socket then
- return nil, nil, "ssl handshake failed";
- end
- else
- -- We're not automatically doing SSL, so we're not secure (yet)
- ssl = false
- handler.starttls = function( now )
- if not now then
- --out_put "server.lua: we need to do tls, but delaying until later"
- needtls = true
- return
- end
- --out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
- local oldsocket, err = socket
- socket, err = ssl_wrap( socket, sslctx ) -- wrap socket
- --out_put( "server.lua: sslwrapped socket is " .. tostring( socket ) )
- if err then
- out_put( "server.lua: error while starting tls on client: ", tostring(err) )
- return nil, err -- fatal error
- end
-
- socket:settimeout( 0 )
-
- -- add the new socket to our system
-
- send = socket.send
- receive = socket.receive
- shutdown = id
-
- _socketlist[ socket ] = handler
- _readlistlen = _readlistlen + 1
- _readlist[ _readlistlen ] = socket
- _readlist[ socket ] = _readlistlen
-
- -- remove traces of the old socket
-
- _readlistlen = removesocket( _readlist, oldsocket, _readlistlen )
- _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen )
- _socketlist[ oldsocket ] = nil
-
- handler.starttls = nil
- needtls = nil
-
- -- Secure now
- ssl = true
-
- handler.readbuffer = handshake
- handler.sendbuffer = handshake
- handshake( socket ) -- do handshake
- end
- handler.readbuffer = _readbuffer
- handler.sendbuffer = _sendbuffer
- end
- else -- normal connection
- ssl = false
- handler.readbuffer = _readbuffer
- handler.sendbuffer = _sendbuffer
- end
-
- send = socket.send
- receive = socket.receive
- shutdown = ( ssl and id ) or socket.shutdown
-
- _socketlist[ socket ] = handler
- _readlistlen = _readlistlen + 1
- _readlist[ _readlistlen ] = socket
- _readlist[ socket ] = _readlistlen
-
- return handler, socket
-end
-
-id = function( )
-end
-
-idfalse = function( )
- return false
-end
-
-removesocket = function( list, socket, len ) -- this function removes sockets from a list ( copied from copas )
- local pos = list[ socket ]
- if pos then
- list[ socket ] = nil
- local last = list[ len ]
- list[ len ] = nil
- if last ~= socket then
- list[ last ] = pos
- list[ pos ] = last
- end
- return len - 1
- end
- return len
-end
-
-closesocket = function( socket )
- _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
- _readlistlen = removesocket( _readlist, socket, _readlistlen )
- _socketlist[ socket ] = nil
- socket:close( )
- mem_free( )
-end
-
-----------------------------------// PUBLIC //--
-
-addserver = function( listeners, port, addr, pattern, sslctx, maxconnections, startssl ) -- this function provides a way for other scripts to reg a server
- local err
- --out_put("server.lua: autossl on ", port, " is ", startssl)
- if type( listeners ) ~= "table" then
- err = "invalid listener table"
- end
- if not type( port ) == "number" or not ( port >= 0 and port <= 65535 ) then
- err = "invalid port"
- elseif _server[ port ] then
- err = "listeners on port '" .. port .. "' already exist"
- elseif sslctx and not luasec then
- err = "luasec not found"
- end
- if err then
- out_error( "server.lua, port ", port, ": ", err )
- return nil, err
- end
- addr = addr or "*"
- local server, err = socket_bind( addr, port )
- if err then
- out_error( "server.lua, port ", port, ": ", err )
- return nil, err
- end
- local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, maxconnections, startssl ) -- wrap new server socket
- if not handler then
- server:close( )
- return nil, err
- end
- server:settimeout( 0 )
- _readlistlen = _readlistlen + 1
- _readlist[ _readlistlen ] = server
- _server[ port ] = handler
- _socketlist[ server ] = handler
- out_put( "server.lua: new server listener on '", addr, ":", port, "'" )
- return handler
-end
-
-getserver = function ( port )
- return _server[ port ];
-end
-
-removeserver = function( port )
- local handler = _server[ port ]
- if not handler then
- return nil, "no server found on port '" .. tostring( port ) "'"
- end
- handler.close( )
- _server[ port ] = nil
- return true
-end
-
-closeall = function( )
- for _, handler in pairs( _socketlist ) do
- handler.close( )
- _socketlist[ _ ] = nil
- end
- _readlistlen = 0
- _sendlistlen = 0
- _timerlistlen = 0
- _server = { }
- _readlist = { }
- _sendlist = { }
- _timerlist = { }
- _socketlist = { }
- mem_free( )
-end
-
-getsettings = function( )
- return _selecttimeout, _sleeptime, _maxsendlen, _maxreadlen, _checkinterval, _sendtimeout, _readtimeout, _cleanqueue, _maxclientsperserver
-end
-
-changesettings = function( new )
- if type( new ) ~= "table" then
- return nil, "invalid settings table"
- end
- _selecttimeout = tonumber( new.timeout ) or _selecttimeout
- _sleeptime = tonumber( new.sleeptime ) or _sleeptime
- _maxsendlen = tonumber( new.maxsendlen ) or _maxsendlen
- _maxreadlen = tonumber( new.maxreadlen ) or _maxreadlen
- _checkinterval = tonumber( new.checkinterval ) or _checkinterval
- _sendtimeout = tonumber( new.sendtimeout ) or _sendtimeout
- _readtimeout = tonumber( new.readtimeout ) or _readtimeout
- _cleanqueue = new.cleanqueue
- _maxclientsperserver = new._maxclientsperserver or _maxclientsperserver
- return true
-end
-
-addtimer = function( listener )
- if type( listener ) ~= "function" then
- return nil, "invalid listener function"
- end
- _timerlistlen = _timerlistlen + 1
- _timerlist[ _timerlistlen ] = listener
- return true
-end
-
-stats = function( )
- return _readtraffic, _sendtraffic, _readlistlen, _sendlistlen, _timerlistlen
-end
-
-local dontstop = true; -- thinking about tomorrow, ...
-
-setquitting = function (quit)
- dontstop = not quit;
- return;
-end
-
-loop = function( ) -- this is the main loop of the program
- while dontstop do
- local read, write, err = socket_select( _readlist, _sendlist, _selecttimeout )
- for i, socket in ipairs( write ) do -- send data waiting in writequeues
- local handler = _socketlist[ socket ]
- if handler then
- handler.sendbuffer( )
- else
- closesocket( socket )
- out_put "server.lua: found no handler and closed socket (writelist)" -- this should not happen
- end
- end
- for i, socket in ipairs( read ) do -- receive data
- local handler = _socketlist[ socket ]
- if handler then
- handler.readbuffer( )
- else
- closesocket( socket )
- out_put "server.lua: found no handler and closed socket (readlist)" -- this can happen
- end
- end
- for handler, err in pairs( _closelist ) do
- handler.disconnect( )( handler, err )
- handler.close( true ) -- forced disconnect
- end
- clean( _closelist )
- _currenttime = os_time( )
- if os_difftime( _currenttime - _timer ) >= 1 then
- for i = 1, _timerlistlen do
- _timerlist[ i ]( ) -- fire timers
- end
- _timer = _currenttime
- end
- socket_sleep( _sleeptime ) -- wait some time
- --collectgarbage( )
- end
- return "quitting"
-end
-
---// EXPERIMENTAL //--
-
-local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx, startssl )
- local handler = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx, startssl )
- _socketlist[ socket ] = handler
- _sendlistlen = _sendlistlen + 1
- _sendlist[ _sendlistlen ] = socket
- _sendlist[ socket ] = _sendlistlen
- return handler, socket
-end
-
-local addclient = function( address, port, listeners, pattern, sslctx, startssl )
- local client, err = luasocket.tcp( )
- if err then
- return nil, err
- end
- client:settimeout( 0 )
- _, err = client:connect( address, port )
- if err then -- try again
- local handler = wrapclient( client, address, port, listeners )
- else
- wrapconnection( nil, listeners, client, address, port, "clientport", pattern, sslctx, startssl )
- end
-end
-
---// EXPERIMENTAL //--
-
-----------------------------------// BEGIN //--
-
-use "setmetatable" ( _socketlist, { __mode = "k" } )
-use "setmetatable" ( _readtimes, { __mode = "k" } )
-use "setmetatable" ( _writetimes, { __mode = "k" } )
-
-_timer = os_time( )
-_starttime = os_time( )
-
-addtimer( function( )
- local difftime = os_difftime( _currenttime - _starttime )
- if difftime > _checkinterval then
- _starttime = _currenttime
- for handler, timestamp in pairs( _writetimes ) do
- if os_difftime( _currenttime - timestamp ) > _sendtimeout then
- --_writetimes[ handler ] = nil
- handler.disconnect( )( handler, "send timeout" )
- handler.close( true ) -- forced disconnect
- end
- end
- for handler, timestamp in pairs( _readtimes ) do
- if os_difftime( _currenttime - timestamp ) > _readtimeout then
- --_readtimes[ handler ] = nil
- handler.disconnect( )( handler, "read timeout" )
- handler.close( ) -- forced disconnect?
- end
- end
- end
- end
-)
-
-----------------------------------// PUBLIC INTERFACE //--
-
-return {
-
- addclient = addclient,
- wrapclient = wrapclient,
-
- loop = loop,
- stats = stats,
- closeall = closeall,
- addtimer = addtimer,
- addserver = addserver,
- getserver = getserver,
- getsettings = getsettings,
- setquitting = setquitting,
- removeserver = removeserver,
- changesettings = changesettings,
-}
+-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local use_luaevent = prosody and require "core.configmanager".get("*", "use_libevent"); + +if use_luaevent then + use_luaevent = pcall(require, "luaevent.core"); + if not use_luaevent then + log("error", "libevent not found, falling back to select()"); + end +end + +local server; + +if use_luaevent then + server = require "net.server_event"; + + -- Overwrite signal.signal() because we need to ask libevent to + -- handle them instead + local ok, signal = pcall(require, "util.signal"); + if ok and signal then + local _signal_signal = signal.signal; + function signal.signal(signal_id, handler) + if type(signal_id) == "string" then + signal_id = signal[signal_id:upper()]; + end + if type(signal_id) ~= "number" then + return false, "invalid-signal"; + end + return server.hook_signal(signal_id, handler); + end + end +else + use_luaevent = false; + server = require "net.server_select"; +end + +if prosody then + local config_get = require "core.configmanager".get; + local defaults = {}; + for k,v in pairs(server.cfg or server.getsettings()) do + defaults[k] = v; + end + local function load_config() + local settings = config_get("*", "network_settings") or {}; + if use_luaevent then + local event_settings = { + ACCEPT_DELAY = settings.event_accept_retry_interval; + ACCEPT_QUEUE = settings.tcp_backlog; + CLEAR_DELAY = settings.event_clear_interval; + CONNECT_TIMEOUT = settings.connect_timeout; + DEBUG = settings.debug; + HANDSHAKE_TIMEOUT = settings.ssl_handshake_timeout; + MAX_CONNECTIONS = settings.max_connections; + MAX_HANDSHAKE_ATTEMPTS = settings.max_ssl_handshake_roundtrips; + MAX_READ_LENGTH = settings.max_receive_buffer_size; + MAX_SEND_LENGTH = settings.max_send_buffer_size; + READ_TIMEOUT = settings.read_timeout; + WRITE_TIMEOUT = settings.send_timeout; + }; + + for k,default in pairs(defaults) do + server.cfg[k] = event_settings[k] or default; + end + else + local select_settings = {}; + for k,default in pairs(defaults) do + select_settings[k] = settings[k] or default; + end + server.changesettings(select_settings); + end + end + load_config(); + prosody.events.add_handler("config-reloaded", load_config); +end + +-- require "net.server" shall now forever return this, +-- ie. server_select or server_event as chosen above. +return server; diff --git a/net/server_event.lua b/net/server_event.lua new file mode 100644 index 00000000..5eae95a9 --- /dev/null +++ b/net/server_event.lua @@ -0,0 +1,872 @@ +--[[ + + + server.lua based on lua/libevent by blastbeat + + notes: + -- when using luaevent, never register 2 or more EV_READ at one socket, same for EV_WRITE + -- you cant even register a new EV_READ/EV_WRITE callback inside another one + -- to do some of the above, use timeout events or something what will called from outside + -- dont let garbagecollect eventcallbacks, as long they are running + -- when using luasec, there are 4 cases of timeout errors: wantread or wantwrite during reading or writing + +--]] + +local SCRIPT_NAME = "server_event.lua" +local SCRIPT_VERSION = "0.05" +local SCRIPT_AUTHOR = "blastbeat" +local LAST_MODIFIED = "2009/11/20" + +local cfg = { + MAX_CONNECTIONS = 100000, -- max per server connections (use "ulimit -n" on *nix) + MAX_HANDSHAKE_ATTEMPTS= 1000, -- attempts to finish ssl handshake + HANDSHAKE_TIMEOUT = 60, -- timeout in seconds per handshake attempt + MAX_READ_LENGTH = 1024 * 1024 * 1024 * 1024, -- max bytes allowed to read from sockets + MAX_SEND_LENGTH = 1024 * 1024 * 1024 * 1024, -- max bytes size of write buffer (for writing on sockets) + ACCEPT_QUEUE = 128, -- might influence the length of the pending sockets queue + ACCEPT_DELAY = 10, -- seconds to wait until the next attempt of a full server to accept + READ_TIMEOUT = 60 * 60 * 6, -- timeout in seconds for read data from socket + WRITE_TIMEOUT = 180, -- timeout in seconds for write data on socket + CONNECT_TIMEOUT = 20, -- timeout in seconds for connection attempts + CLEAR_DELAY = 5, -- seconds to wait for clearing interface list (and calling ondisconnect listeners) + DEBUG = true, -- show debug messages +} + +local function use(x) return rawget(_G, x); end +local ipairs = use "ipairs" +local string = use "string" +local select = use "select" +local require = use "require" +local tostring = use "tostring" +local coroutine = use "coroutine" +local setmetatable = use "setmetatable" + +local t_insert = table.insert +local t_concat = table.concat + +local ssl = use "ssl" +local socket = use "socket" or require "socket" + +local log = require ("util.logger").init("socket") + +local function debug(...) + return log("debug", ("%s "):rep(select('#', ...)), ...) +end +local vdebug = debug; + +local bitor = ( function( ) -- thx Rici Lake + local hasbit = function( x, p ) + return x % ( p + p ) >= p + end + return function( x, y ) + local p = 1 + local z = 0 + local limit = x > y and x or y + while p <= limit do + if hasbit( x, p ) or hasbit( y, p ) then + z = z + p + end + p = p + p + end + return z + end +end )( ) + +local event = require "luaevent.core" +local base = event.new( ) +local EV_READ = event.EV_READ +local EV_WRITE = event.EV_WRITE +local EV_TIMEOUT = event.EV_TIMEOUT +local EV_SIGNAL = event.EV_SIGNAL + +local EV_READWRITE = bitor( EV_READ, EV_WRITE ) + +local interfacelist = ( function( ) -- holds the interfaces for sockets + local array = { } + local len = 0 + return function( method, arg ) + if "add" == method then + len = len + 1 + array[ len ] = arg + arg:_position( len ) + return len + elseif "delete" == method then + if len <= 0 then + return nil, "array is already empty" + end + local position = arg:_position() -- get position in array + if position ~= len then + local interface = array[ len ] -- get last interface + array[ position ] = interface -- copy it into free position + array[ len ] = nil -- free last position + interface:_position( position ) -- set new position in array + else -- free last position + array[ len ] = nil + end + len = len - 1 + return len + else + return array + end + end +end )( ) + +-- Client interface methods +local interface_mt +do + interface_mt = {}; interface_mt.__index = interface_mt; + + local addevent = base.addevent + local coroutine_wrap, coroutine_yield = coroutine.wrap,coroutine.yield + + -- Private methods + function interface_mt:_position(new_position) + self.position = new_position or self.position + return self.position; + end + function interface_mt:_close() + return self:_destroy(); + end + + function interface_mt:_start_connection(plainssl) -- should be called from addclient + local callback = function( event ) + if EV_TIMEOUT == event then -- timeout during connection + self.fatalerror = "connection timeout" + self:ontimeout() -- call timeout listener + self:_close() + debug( "new connection failed. id:", self.id, "error:", self.fatalerror ) + else + if plainssl and ssl then -- start ssl session + self:starttls(self._sslctx, true) + else -- normal connection + self:_start_session(true) + end + debug( "new connection established. id:", self.id ) + end + self.eventconnect = nil + return -1 + end + self.eventconnect = addevent( base, self.conn, EV_WRITE, callback, cfg.CONNECT_TIMEOUT ) + return true + end + function interface_mt:_start_session(call_onconnect) -- new session, for example after startssl + if self.type == "client" then + local callback = function( ) + self:_lock( false, false, false ) + --vdebug( "start listening on client socket with id:", self.id ) + self.eventread = addevent( base, self.conn, EV_READ, self.readcallback, cfg.READ_TIMEOUT ); -- register callback + if call_onconnect then + self:onconnect() + end + self.eventsession = nil + return -1 + end + self.eventsession = addevent( base, nil, EV_TIMEOUT, callback, 0 ) + else + self:_lock( false ) + --vdebug( "start listening on server socket with id:", self.id ) + self.eventread = addevent( base, self.conn, EV_READ, self.readcallback ) -- register callback + end + return true + end + function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed, therefore we have to close read/write events first + --vdebug( "starting ssl session with client id:", self.id ) + local _ + _ = self.eventread and self.eventread:close( ) -- close events; this must be called outside of the event callbacks! + _ = self.eventwrite and self.eventwrite:close( ) + self.eventread, self.eventwrite = nil, nil + local err + self.conn, err = ssl.wrap( self.conn, self._sslctx ) + if err then + self.fatalerror = err + self.conn = nil -- cannot be used anymore + if call_onconnect then + self.ondisconnect = nil -- dont call this when client isnt really connected + end + self:_close() + debug( "fatal error while ssl wrapping:", err ) + return false + end + self.conn:settimeout( 0 ) -- set non blocking + local handshakecallback = coroutine_wrap( + function( event ) + local _, err + local attempt = 0 + local maxattempt = cfg.MAX_HANDSHAKE_ATTEMPTS + while attempt < maxattempt do -- no endless loop + attempt = attempt + 1 + debug( "ssl handshake of client with id:"..tostring(self)..", attempt:"..attempt ) + if attempt > maxattempt then + self.fatalerror = "max handshake attempts exceeded" + elseif EV_TIMEOUT == event then + self.fatalerror = "timeout during handshake" + else + _, err = self.conn:dohandshake( ) + if not err then + self:_lock( false, false, false ) -- unlock the interface; sending, closing etc allowed + self.send = self.conn.send -- caching table lookups with new client object + self.receive = self.conn.receive + if not call_onconnect then -- trigger listener + self:onstatus("ssl-handshake-complete"); + end + self:_start_session( call_onconnect ) + debug( "ssl handshake done" ) + self.eventhandshake = nil + return -1 + end + if err == "wantwrite" then + event = EV_WRITE + elseif err == "wantread" then + event = EV_READ + else + debug( "ssl handshake error:", err ) + self.fatalerror = err + end + end + if self.fatalerror then + if call_onconnect then + self.ondisconnect = nil -- dont call this when client isnt really connected + end + self:_close() + debug( "handshake failed because:", self.fatalerror ) + self.eventhandshake = nil + return -1 + end + event = coroutine_yield( event, cfg.HANDSHAKE_TIMEOUT ) -- yield this monster... + end + end + ) + debug "starting handshake..." + self:_lock( false, true, true ) -- unlock read/write events, but keep interface locked + self.eventhandshake = addevent( base, self.conn, EV_READWRITE, handshakecallback, cfg.HANDSHAKE_TIMEOUT ) + return true + end + function interface_mt:_destroy() -- close this interface + events and call last listener + debug( "closing client with id:", self.id, self.fatalerror ) + self:_lock( true, true, true ) -- first of all, lock the interface to avoid further actions + local _ + _ = self.eventread and self.eventread:close( ) + if self.type == "client" then + _ = self.eventwrite and self.eventwrite:close( ) + _ = self.eventhandshake and self.eventhandshake:close( ) + _ = self.eventstarthandshake and self.eventstarthandshake:close( ) + _ = self.eventconnect and self.eventconnect:close( ) + _ = self.eventsession and self.eventsession:close( ) + _ = self.eventwritetimeout and self.eventwritetimeout:close( ) + _ = self.eventreadtimeout and self.eventreadtimeout:close( ) + _ = self.ondisconnect and self:ondisconnect( self.fatalerror ~= "client to close" and self.fatalerror) -- call ondisconnect listener (wont be the case if handshake failed on connect) + _ = self.conn and self.conn:close( ) -- close connection + _ = self._server and self._server:counter(-1); + self.eventread, self.eventwrite = nil, nil + self.eventstarthandshake, self.eventhandshake, self.eventclose = nil, nil, nil + self.readcallback, self.writecallback = nil, nil + else + self.conn:close( ) + self.eventread, self.eventclose = nil, nil + self.interface, self.readcallback = nil, nil + end + interfacelist( "delete", self ) + return true + end + + function interface_mt:_lock(nointerface, noreading, nowriting) -- lock or unlock this interface or events + self.nointerface, self.noreading, self.nowriting = nointerface, noreading, nowriting + return nointerface, noreading, nowriting + end + + --TODO: Deprecate + function interface_mt:lock_read(switch) + if switch then + return self:pause(); + else + return self:resume(); + end + end + + function interface_mt:pause() + return self:_lock(self.nointerface, true, self.nowriting); + end + + function interface_mt:resume() + self:_lock(self.nointerface, false, self.nowriting); + if not self.eventread then + self.eventread = addevent( base, self.conn, EV_READ, self.readcallback, cfg.READ_TIMEOUT ); -- register callback + end + end + + function interface_mt:counter(c) + if c then + self._connections = self._connections + c + end + return self._connections + end + + -- Public methods + function interface_mt:write(data) + if self.nowriting then return nil, "locked" end + --vdebug( "try to send data to client, id/data:", self.id, data ) + data = tostring( data ) + local len = #data + local total = len + self.writebufferlen + if total > cfg.MAX_SEND_LENGTH then -- check buffer length + local err = "send buffer exceeded" + debug( "error:", err ) -- to much, check your app + return nil, err + end + t_insert(self.writebuffer, data) -- new buffer + self.writebufferlen = total + if not self.eventwrite then -- register new write event + --vdebug( "register new write event" ) + self.eventwrite = addevent( base, self.conn, EV_WRITE, self.writecallback, cfg.WRITE_TIMEOUT ) + end + return true + end + function interface_mt:close() + if self.nointerface then return nil, "locked"; end + debug( "try to close client connection with id:", self.id ) + if self.type == "client" then + self.fatalerror = "client to close" + if self.eventwrite then -- wait for incomplete write request + self:_lock( true, true, false ) + debug "closing delayed until writebuffer is empty" + return nil, "writebuffer not empty, waiting" + else -- close now + self:_lock( true, true, true ) + self:_close() + return true + end + else + debug( "try to close server with id:", tostring(self.id)) + self.fatalerror = "server to close" + self:_lock( true ) + self:_close( 0 ) + return true + end + end + + function interface_mt:socket() + return self.conn + end + + function interface_mt:server() + return self._server or self; + end + + function interface_mt:port() + return self._port + end + + function interface_mt:serverport() + return self._serverport + end + + function interface_mt:ip() + return self._ip + end + + function interface_mt:ssl() + return self._usingssl + end + + function interface_mt:type() + return self._type or "client" + end + + function interface_mt:connections() + return self._connections + end + + function interface_mt:address() + return self.addr + end + + function interface_mt:set_sslctx(sslctx) + self._sslctx = sslctx; + if sslctx then + self.starttls = nil; -- use starttls() of interface_mt + else + self.starttls = false; -- prevent starttls() + end + end + + function interface_mt:set_mode(pattern) + if pattern then + self._pattern = pattern; + end + return self._pattern; + end + + function interface_mt:set_send(new_send) + -- No-op, we always use the underlying connection's send + end + + function interface_mt:starttls(sslctx, call_onconnect) + debug( "try to start ssl at client id:", self.id ) + local err + self._sslctx = sslctx; + if self._usingssl then -- startssl was already called + err = "ssl already active" + end + if err then + debug( "error:", err ) + return nil, err + end + self._usingssl = true + self.startsslcallback = function( ) -- we have to start the handshake outside of a read/write event + self.startsslcallback = nil + self:_start_ssl(call_onconnect); + self.eventstarthandshake = nil + return -1 + end + if not self.eventwrite then + self:_lock( true, true, true ) -- lock the interface, to not disturb the handshake + self.eventstarthandshake = addevent( base, nil, EV_TIMEOUT, self.startsslcallback, 0 ) -- add event to start handshake + else -- wait until writebuffer is empty + self:_lock( true, true, false ) + debug "ssl session delayed until writebuffer is empty..." + end + self.starttls = false; + return true + end + + function interface_mt:setoption(option, value) + if self.conn.setoption then + return self.conn:setoption(option, value); + end + return false, "setoption not implemented"; + end + + function interface_mt:setlistener(listener) + self.onconnect, self.ondisconnect, self.onincoming, self.ontimeout, self.onstatus + = listener.onconnect, listener.ondisconnect, listener.onincoming, listener.ontimeout, listener.onstatus; + end + + -- Stub handlers + function interface_mt:onconnect() + end + function interface_mt:onincoming() + end + function interface_mt:ondisconnect() + end + function interface_mt:ontimeout() + end + function interface_mt:ondrain() + end + function interface_mt:onstatus() + end +end + +-- End of client interface methods + +local handleclient; +do + local string_sub = string.sub -- caching table lookups + local addevent = base.addevent + local socket_gettime = socket.gettime + function handleclient( client, ip, port, server, pattern, listener, sslctx ) -- creates an client interface + --vdebug("creating client interfacce...") + local interface = { + type = "client"; + conn = client; + currenttime = socket_gettime( ); -- safe the origin + writebuffer = {}; -- writebuffer + writebufferlen = 0; -- length of writebuffer + send = client.send; -- caching table lookups + receive = client.receive; + onconnect = listener.onconnect; -- will be called when client disconnects + ondisconnect = listener.ondisconnect; -- will be called when client disconnects + onincoming = listener.onincoming; -- will be called when client sends data + ontimeout = listener.ontimeout; -- called when fatal socket timeout occurs + onstatus = listener.onstatus; -- called for status changes (e.g. of SSL/TLS) + eventread = false, eventwrite = false, eventclose = false, + eventhandshake = false, eventstarthandshake = false; -- event handler + eventconnect = false, eventsession = false; -- more event handler... + eventwritetimeout = false; -- even more event handler... + eventreadtimeout = false; + fatalerror = false; -- error message + writecallback = false; -- will be called on write events + readcallback = false; -- will be called on read events + nointerface = true; -- lock/unlock parameter of this interface + noreading = false, nowriting = false; -- locks of the read/writecallback + startsslcallback = false; -- starting handshake callback + position = false; -- position of client in interfacelist + + -- Properties + _ip = ip, _port = port, _server = server, _pattern = pattern, + _serverport = (server and server:port() or nil), + _sslctx = sslctx; -- parameters + _usingssl = false; -- client is using ssl; + } + if not ssl then interface.starttls = false; end + interface.id = tostring(interface):match("%x+$"); + interface.writecallback = function( event ) -- called on write events + --vdebug( "new client write event, id/ip/port:", interface, ip, port ) + if interface.nowriting or ( interface.fatalerror and ( "client to close" ~= interface.fatalerror ) ) then -- leave this event + --vdebug( "leaving this event because:", interface.nowriting or interface.fatalerror ) + interface.eventwrite = false + return -1 + end + if EV_TIMEOUT == event then -- took too long to write some data to socket -> disconnect + interface.fatalerror = "timeout during writing" + debug( "writing failed:", interface.fatalerror ) + interface:_close() + interface.eventwrite = false + return -1 + else -- can write :) + if interface._usingssl then -- handle luasec + if interface.eventreadtimeout then -- we have to read first + local ret = interface.readcallback( ) -- call readcallback + --vdebug( "tried to read in writecallback, result:", ret ) + end + if interface.eventwritetimeout then -- luasec only + interface.eventwritetimeout:close( ) -- first we have to close timeout event which where regged after a wantread error + interface.eventwritetimeout = false + end + end + interface.writebuffer = { t_concat(interface.writebuffer) } + local succ, err, byte = interface.conn:send( interface.writebuffer[1], 1, interface.writebufferlen ) + --vdebug( "write data:", interface.writebuffer, "error:", err, "part:", byte ) + if succ then -- writing succesful + interface.writebuffer[1] = nil + interface.writebufferlen = 0 + interface:ondrain(); + if interface.fatalerror then + debug "closing client after writing" + interface:_close() -- close interface if needed + elseif interface.startsslcallback then -- start ssl connection if needed + debug "starting ssl handshake after writing" + interface.eventstarthandshake = addevent( base, nil, EV_TIMEOUT, interface.startsslcallback, 0 ) + elseif interface.eventreadtimeout then + return EV_WRITE, EV_TIMEOUT + end + interface.eventwrite = nil + return -1 + elseif byte and (err == "timeout" or err == "wantwrite") then -- want write again + --vdebug( "writebuffer is not empty:", err ) + interface.writebuffer[1] = string_sub( interface.writebuffer[1], byte + 1, interface.writebufferlen ) -- new buffer + interface.writebufferlen = interface.writebufferlen - byte + if "wantread" == err then -- happens only with luasec + local callback = function( ) + interface:_close() + interface.eventwritetimeout = nil + return -1; + end + interface.eventwritetimeout = addevent( base, nil, EV_TIMEOUT, callback, cfg.WRITE_TIMEOUT ) -- reg a new timeout event + debug( "wantread during write attempt, reg it in readcallback but dont know what really happens next..." ) + -- hopefully this works with luasec; its simply not possible to use 2 different write events on a socket in luaevent + return -1 + end + return EV_WRITE, cfg.WRITE_TIMEOUT + else -- connection was closed during writing or fatal error + interface.fatalerror = err or "fatal error" + debug( "connection failed in write event:", interface.fatalerror ) + interface:_close() + interface.eventwrite = nil + return -1 + end + end + end + + interface.readcallback = function( event ) -- called on read events + --vdebug( "new client read event, id/ip/port:", tostring(interface.id), tostring(ip), tostring(port) ) + if interface.noreading or interface.fatalerror then -- leave this event + --vdebug( "leaving this event because:", tostring(interface.noreading or interface.fatalerror) ) + interface.eventread = nil + return -1 + end + if EV_TIMEOUT == event then -- took too long to get some data from client -> disconnect + interface.fatalerror = "timeout during receiving" + debug( "connection failed:", interface.fatalerror ) + interface:_close() + interface.eventread = nil + return -1 + else -- can read + if interface._usingssl then -- handle luasec + if interface.eventwritetimeout then -- ok, in the past writecallback was regged + local ret = interface.writecallback( ) -- call it + --vdebug( "tried to write in readcallback, result:", tostring(ret) ) + end + if interface.eventreadtimeout then + interface.eventreadtimeout:close( ) + interface.eventreadtimeout = nil + end + end + local buffer, err, part = interface.conn:receive( interface._pattern ) -- receive buffer with "pattern" + --vdebug( "read data:", tostring(buffer), "error:", tostring(err), "part:", tostring(part) ) + buffer = buffer or part + if buffer and #buffer > cfg.MAX_READ_LENGTH then -- check buffer length + interface.fatalerror = "receive buffer exceeded" + debug( "fatal error:", interface.fatalerror ) + interface:_close() + interface.eventread = nil + return -1 + end + if err and ( err ~= "timeout" and err ~= "wantread" ) then + if "wantwrite" == err then -- need to read on write event + if not interface.eventwrite then -- register new write event if needed + interface.eventwrite = addevent( base, interface.conn, EV_WRITE, interface.writecallback, cfg.WRITE_TIMEOUT ) + end + interface.eventreadtimeout = addevent( base, nil, EV_TIMEOUT, + function( ) + interface:_close() + end, cfg.READ_TIMEOUT + ) + debug( "wantwrite during read attempt, reg it in writecallback but dont know what really happens next..." ) + -- to be honest i dont know what happens next, if it is allowed to first read, the write etc... + else -- connection was closed or fatal error + interface.fatalerror = err + debug( "connection failed in read event:", interface.fatalerror ) + interface:_close() + interface.eventread = nil + return -1 + end + else + interface.onincoming( interface, buffer, err ) -- send new data to listener + end + if interface.noreading then + interface.eventread = nil; + return -1; + end + return EV_READ, cfg.READ_TIMEOUT + end + end + + client:settimeout( 0 ) -- set non blocking + setmetatable(interface, interface_mt) + interfacelist( "add", interface ) -- add to interfacelist + return interface + end +end + +local handleserver +do + function handleserver( server, addr, port, pattern, listener, sslctx ) -- creates an server interface + debug "creating server interface..." + local interface = { + _connections = 0; + + conn = server; + onconnect = listener.onconnect; -- will be called when new client connected + eventread = false; -- read event handler + eventclose = false; -- close event handler + readcallback = false; -- read event callback + fatalerror = false; -- error message + nointerface = true; -- lock/unlock parameter + + _ip = addr, _port = port, _pattern = pattern, + _sslctx = sslctx; + } + interface.id = tostring(interface):match("%x+$"); + interface.readcallback = function( event ) -- server handler, called on incoming connections + --vdebug( "server can accept, id/addr/port:", interface, addr, port ) + if interface.fatalerror then + --vdebug( "leaving this event because:", self.fatalerror ) + interface.eventread = nil + return -1 + end + local delay = cfg.ACCEPT_DELAY + if EV_TIMEOUT == event then + if interface._connections >= cfg.MAX_CONNECTIONS then -- check connection count + debug( "to many connections, seconds to wait for next accept:", delay ) + return EV_TIMEOUT, delay -- timeout... + else + return EV_READ -- accept again + end + end + --vdebug("max connection check ok, accepting...") + local client, err = server:accept() -- try to accept; TODO: check err + while client do + if interface._connections >= cfg.MAX_CONNECTIONS then + client:close( ) -- refuse connection + debug( "maximal connections reached, refuse client connection; accept delay:", delay ) + return EV_TIMEOUT, delay -- delay for next accept attempt + end + local client_ip, client_port = client:getpeername( ) + interface._connections = interface._connections + 1 -- increase connection count + local clientinterface = handleclient( client, client_ip, client_port, interface, pattern, listener, sslctx ) + --vdebug( "client id:", clientinterface, "startssl:", startssl ) + if ssl and sslctx then + clientinterface:starttls(sslctx, true) + else + clientinterface:_start_session( true ) + end + debug( "accepted incoming client connection from:", client_ip or "<unknown IP>", client_port or "<unknown port>", "to", port or "<unknown port>"); + + client, err = server:accept() -- try to accept again + end + return EV_READ + end + + server:settimeout( 0 ) + setmetatable(interface, interface_mt) + interfacelist( "add", interface ) + interface:_start_session() + return interface + end +end + +local addserver = ( function( ) + return function( addr, port, listener, pattern, sslcfg, startssl ) -- TODO: check arguments + --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslcfg or "nil", startssl or "nil") + local server, err = socket.bind( addr, port, cfg.ACCEPT_QUEUE ) -- create server socket + if not server then + debug( "creating server socket on "..addr.." port "..port.." failed:", err ) + return nil, err + end + local sslctx + if sslcfg then + if not ssl then + debug "fatal error: luasec not found" + return nil, "luasec not found" + end + sslctx, err = sslcfg + if err then + debug( "error while creating new ssl context for server socket:", err ) + return nil, err + end + end + local interface = handleserver( server, addr, port, pattern, listener, sslctx, startssl ) -- new server handler + debug( "new server created with id:", tostring(interface)) + return interface + end +end )( ) + +local addclient, wrapclient +do + function wrapclient( client, ip, port, listeners, pattern, sslctx ) + local interface = handleclient( client, ip, port, nil, pattern, listeners, sslctx ) + interface:_start_connection(sslctx) + return interface, client + --function handleclient( client, ip, port, server, pattern, listener, _, sslctx ) -- creates an client interface + end + + function addclient( addr, serverport, listener, pattern, localaddr, localport, sslcfg, startssl ) + local client, err = socket.tcp() -- creating new socket + if not client then + debug( "cannot create socket:", err ) + return nil, err + end + client:settimeout( 0 ) -- set nonblocking + if localaddr then + local res, err = client:bind( localaddr, localport, -1 ) + if not res then + debug( "cannot bind client:", err ) + return nil, err + end + end + local sslctx + if sslcfg then -- handle ssl/new context + if not ssl then + debug "need luasec, but not available" + return nil, "luasec not found" + end + sslctx, err = sslcfg + if err then + debug( "cannot create new ssl context:", err ) + return nil, err + end + end + local res, err = client:connect( addr, serverport ) -- connect + if res or ( err == "timeout" ) then + local ip, port = client:getsockname( ) + local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx, startssl ) + interface:_start_connection( startssl ) + debug( "new connection id:", interface.id ) + return interface, err + else + debug( "new connection failed:", err ) + return nil, err + end + end +end + + +local loop = function( ) -- starts the event loop + base:loop( ) + return "quitting"; +end + +local newevent = ( function( ) + local add = base.addevent + return function( ... ) + return add( base, ... ) + end +end )( ) + +local closeallservers = function( arg ) + for _, item in ipairs( interfacelist( ) ) do + if item.type == "server" then + item:close( arg ) + end + end +end + +local function setquitting(yes) + if yes then + -- Quit now + closeallservers(); + base:loopexit(); + end +end + +local function get_backend() + return base:method(); +end + +-- We need to hold onto the events to stop them +-- being garbage-collected +local signal_events = {}; -- [signal_num] -> event object +local function hook_signal(signal_num, handler) + local function _handler(event) + local ret = handler(); + if ret ~= false then -- Continue handling this signal? + return EV_SIGNAL; -- Yes + end + return -1; -- Close this event + end + signal_events[signal_num] = base:addevent(signal_num, EV_SIGNAL, _handler); + return signal_events[signal_num]; +end + +local function link(sender, receiver, buffersize) + local sender_locked; + + function receiver:ondrain() + if sender_locked then + sender:resume(); + sender_locked = nil; + end + end + + function sender:onincoming(data) + receiver:write(data); + if receiver.writebufferlen >= buffersize then + sender_locked = true; + sender:pause(); + end + end +end + +return { + + cfg = cfg, + base = base, + loop = loop, + link = link, + event = event, + event_base = base, + addevent = newevent, + addserver = addserver, + addclient = addclient, + wrapclient = wrapclient, + setquitting = setquitting, + closeall = closeallservers, + get_backend = get_backend, + hook_signal = hook_signal, + + __NAME = SCRIPT_NAME, + __DATE = LAST_MODIFIED, + __AUTHOR = SCRIPT_AUTHOR, + __VERSION = SCRIPT_VERSION, + +} diff --git a/net/server_select.lua b/net/server_select.lua new file mode 100644 index 00000000..7eb330a8 --- /dev/null +++ b/net/server_select.lua @@ -0,0 +1,984 @@ +-- +-- server.lua by blastbeat of the luadch project +-- Re-used here under the MIT/X Consortium License +-- +-- Modifications (C) 2008-2010 Matthew Wild, Waqas Hussain +-- + +-- // wrapping luadch stuff // -- + +local use = function( what ) + return _G[ what ] +end + +local log, table_concat = require ("util.logger").init("socket"), table.concat; +local out_put = function (...) return log("debug", table_concat{...}); end +local out_error = function (...) return log("warn", table_concat{...}); end + +----------------------------------// DECLARATION //-- + +--// constants //-- + +local STAT_UNIT = 1 -- byte + +--// lua functions //-- + +local type = use "type" +local pairs = use "pairs" +local ipairs = use "ipairs" +local tonumber = use "tonumber" +local tostring = use "tostring" + +--// lua libs //-- + +local os = use "os" +local table = use "table" +local string = use "string" +local coroutine = use "coroutine" + +--// lua lib methods //-- + +local os_difftime = os.difftime +local math_min = math.min +local math_huge = math.huge +local table_concat = table.concat +local string_sub = string.sub +local coroutine_wrap = coroutine.wrap +local coroutine_yield = coroutine.yield + +--// extern libs //-- + +local luasec = use "ssl" +local luasocket = use "socket" or require "socket" +local luasocket_gettime = luasocket.gettime + +--// extern lib methods //-- + +local ssl_wrap = ( luasec and luasec.wrap ) +local socket_bind = luasocket.bind +local socket_sleep = luasocket.sleep +local socket_select = luasocket.select + +--// functions //-- + +local id +local loop +local stats +local idfalse +local closeall +local addsocket +local addserver +local addtimer +local getserver +local wrapserver +local getsettings +local closesocket +local removesocket +local removeserver +local wrapconnection +local changesettings + +--// tables //-- + +local _server +local _readlist +local _timerlist +local _sendlist +local _socketlist +local _closelist +local _readtimes +local _writetimes + +--// simple data types //-- + +local _ +local _readlistlen +local _sendlistlen +local _timerlistlen + +local _sendtraffic +local _readtraffic + +local _selecttimeout +local _sleeptime +local _tcpbacklog + +local _starttime +local _currenttime + +local _maxsendlen +local _maxreadlen + +local _checkinterval +local _sendtimeout +local _readtimeout + +local _timer + +local _maxselectlen +local _maxfd + +local _maxsslhandshake + +----------------------------------// DEFINITION //-- + +_server = { } -- key = port, value = table; list of listening servers +_readlist = { } -- array with sockets to read from +_sendlist = { } -- arrary with sockets to write to +_timerlist = { } -- array of timer functions +_socketlist = { } -- key = socket, value = wrapped socket (handlers) +_readtimes = { } -- key = handler, value = timestamp of last data reading +_writetimes = { } -- key = handler, value = timestamp of last data writing/sending +_closelist = { } -- handlers to close + +_readlistlen = 0 -- length of readlist +_sendlistlen = 0 -- length of sendlist +_timerlistlen = 0 -- lenght of timerlist + +_sendtraffic = 0 -- some stats +_readtraffic = 0 + +_selecttimeout = 1 -- timeout of socket.select +_sleeptime = 0 -- time to wait at the end of every loop +_tcpbacklog = 128 -- some kind of hint to the OS + +_maxsendlen = 51000 * 1024 -- max len of send buffer +_maxreadlen = 25000 * 1024 -- max len of read buffer + +_checkinterval = 1200000 -- interval in secs to check idle clients +_sendtimeout = 60000 -- allowed send idle time in secs +_readtimeout = 6 * 60 * 60 -- allowed read idle time in secs + +local is_windows = package.config:sub(1,1) == "\\" -- check the directory separator, to detemine whether this is Windows +_maxfd = luasocket._SETSIZE or (is_windows and math.huge) or 1024 -- max fd number, limit to 1024 by default to prevent glibc buffer overflow, but not on Windows +_maxselectlen = luasocket._SETSIZE or 1024 -- But this still applies on Windows + +_maxsslhandshake = 30 -- max handshake round-trips + +----------------------------------// PRIVATE //-- + +wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- this function wraps a server -- FIXME Make sure FD < _maxfd + + if socket:getfd() >= _maxfd then + out_error("server.lua: Disallowed FD number: "..socket:getfd()) + socket:close() + return nil, "fd-too-large" + end + + local connections = 0 + + local dispatch, disconnect = listeners.onconnect, listeners.ondisconnect + + local accept = socket.accept + + --// public methods of the object //-- + + local handler = { } + + handler.shutdown = function( ) end + + handler.ssl = function( ) + return sslctx ~= nil + end + handler.sslctx = function( ) + return sslctx + end + handler.remove = function( ) + connections = connections - 1 + if handler then + handler.resume( ) + end + end + handler.close = function() + socket:close( ) + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) + _readlistlen = removesocket( _readlist, socket, _readlistlen ) + _server[ip..":"..serverport] = nil; + _socketlist[ socket ] = nil + handler = nil + socket = nil + --mem_free( ) + out_put "server.lua: closed server handler and removed sockets from list" + end + handler.pause = function( hard ) + if not handler.paused then + _readlistlen = removesocket( _readlist, socket, _readlistlen ) + if hard then + _socketlist[ socket ] = nil + socket:close( ) + socket = nil; + end + handler.paused = true; + end + end + handler.resume = function( ) + if handler.paused then + if not socket then + socket = socket_bind( ip, serverport, _tcpbacklog ); + socket:settimeout( 0 ) + end + _readlistlen = addsocket(_readlist, socket, _readlistlen) + _socketlist[ socket ] = handler + handler.paused = false; + end + end + handler.ip = function( ) + return ip + end + handler.serverport = function( ) + return serverport + end + handler.socket = function( ) + return socket + end + handler.readbuffer = function( ) + if _readlistlen >= _maxselectlen or _sendlistlen >= _maxselectlen then + handler.pause( ) + out_put( "server.lua: refused new client connection: server full" ) + return false + end + local client, err = accept( socket ) -- try to accept + if client then + local ip, clientport = client:getpeername( ) + local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx ) -- wrap new client socket + if err then -- error while wrapping ssl socket + return false + end + connections = connections + 1 + out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport)) + if dispatch and not sslctx then -- SSL connections will notify onconnect when handshake completes + return dispatch( handler ); + end + return; + elseif err then -- maybe timeout or something else + out_put( "server.lua: error with new client connection: ", tostring(err) ) + return false + end + end + return handler +end + +wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx ) -- this function wraps a client to a handler object + + if socket:getfd() >= _maxfd then + out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent + socket:close( ) -- Should we send some kind of error here? + server.pause( ) + return nil, nil, "fd-too-large" + end + socket:settimeout( 0 ) + + --// local import of socket methods //-- + + local send + local receive + local shutdown + + --// private closures of the object //-- + + local ssl + + local dispatch = listeners.onincoming + local status = listeners.onstatus + local disconnect = listeners.ondisconnect + local drain = listeners.ondrain + + local bufferqueue = { } -- buffer array + local bufferqueuelen = 0 -- end of buffer array + + local toclose + local fatalerror + local needtls + + local bufferlen = 0 + + local noread = false + local nosend = false + + local sendtraffic, readtraffic = 0, 0 + + local maxsendlen = _maxsendlen + local maxreadlen = _maxreadlen + + --// public methods of the object //-- + + local handler = bufferqueue -- saves a table ^_^ + + handler.dispatch = function( ) + return dispatch + end + handler.disconnect = function( ) + return disconnect + end + handler.setlistener = function( self, listeners ) + dispatch = listeners.onincoming + disconnect = listeners.ondisconnect + status = listeners.onstatus + drain = listeners.ondrain + end + handler.getstats = function( ) + return readtraffic, sendtraffic + end + handler.ssl = function( ) + return ssl + end + handler.sslctx = function ( ) + return sslctx + end + handler.send = function( _, data, i, j ) + return send( socket, data, i, j ) + end + handler.receive = function( pattern, prefix ) + return receive( socket, pattern, prefix ) + end + handler.shutdown = function( pattern ) + return shutdown( socket, pattern ) + end + handler.setoption = function (self, option, value) + if socket.setoption then + return socket:setoption(option, value); + end + return false, "setoption not implemented"; + end + handler.force_close = function ( self, err ) + if bufferqueuelen ~= 0 then + out_put("server.lua: discarding unwritten data for ", tostring(ip), ":", tostring(clientport)) + bufferqueuelen = 0; + end + return self:close(err); + end + handler.close = function( self, err ) + if not handler then return true; end + _readlistlen = removesocket( _readlist, socket, _readlistlen ) + _readtimes[ handler ] = nil + if bufferqueuelen ~= 0 then + handler.sendbuffer() -- Try now to send any outstanding data + if bufferqueuelen ~= 0 then -- Still not empty, so we'll try again later + if handler then + handler.write = nil -- ... but no further writing allowed + end + toclose = true + return false + end + end + if socket then + _ = shutdown and shutdown( socket ) + socket:close( ) + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) + _socketlist[ socket ] = nil + socket = nil + else + out_put "server.lua: socket already closed" + end + if handler then + _writetimes[ handler ] = nil + _closelist[ handler ] = nil + local _handler = handler; + handler = nil + if disconnect then + disconnect(_handler, err or false); + disconnect = nil + end + end + if server then + server.remove( ) + end + out_put "server.lua: closed client handler and removed socket from list" + return true + end + handler.ip = function( ) + return ip + end + handler.serverport = function( ) + return serverport + end + handler.clientport = function( ) + return clientport + end + local write = function( self, data ) + bufferlen = bufferlen + #data + if bufferlen > maxsendlen then + _closelist[ handler ] = "send buffer exceeded" -- cannot close the client at the moment, have to wait to the end of the cycle + handler.write = idfalse -- dont write anymore + return false + elseif socket and not _sendlist[ socket ] then + _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) + end + bufferqueuelen = bufferqueuelen + 1 + bufferqueue[ bufferqueuelen ] = data + if handler then + _writetimes[ handler ] = _writetimes[ handler ] or _currenttime + end + return true + end + handler.write = write + handler.bufferqueue = function( self ) + return bufferqueue + end + handler.socket = function( self ) + return socket + end + handler.set_mode = function( self, new ) + pattern = new or pattern + return pattern + end + handler.set_send = function ( self, newsend ) + send = newsend or send + return send + end + handler.bufferlen = function( self, readlen, sendlen ) + maxsendlen = sendlen or maxsendlen + maxreadlen = readlen or maxreadlen + return bufferlen, maxreadlen, maxsendlen + end + --TODO: Deprecate + handler.lock_read = function (self, switch) + if switch == true then + local tmp = _readlistlen + _readlistlen = removesocket( _readlist, socket, _readlistlen ) + _readtimes[ handler ] = nil + if _readlistlen ~= tmp then + noread = true + end + elseif switch == false then + if noread then + noread = false + _readlistlen = addsocket(_readlist, socket, _readlistlen) + _readtimes[ handler ] = _currenttime + end + end + return noread + end + handler.pause = function (self) + return self:lock_read(true); + end + handler.resume = function (self) + return self:lock_read(false); + end + handler.lock = function( self, switch ) + handler.lock_read (switch) + if switch == true then + handler.write = idfalse + local tmp = _sendlistlen + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) + _writetimes[ handler ] = nil + if _sendlistlen ~= tmp then + nosend = true + end + elseif switch == false then + handler.write = write + if nosend then + nosend = false + write( "" ) + end + end + return noread, nosend + end + local _readbuffer = function( ) -- this function reads data + local buffer, err, part = receive( socket, pattern ) -- receive buffer with "pattern" + if not err or (err == "wantread" or err == "timeout") then -- received something + local buffer = buffer or part or "" + local len = #buffer + if len > maxreadlen then + handler:close( "receive buffer exceeded" ) + return false + end + local count = len * STAT_UNIT + readtraffic = readtraffic + count + _readtraffic = _readtraffic + count + _readtimes[ handler ] = _currenttime + --out_put( "server.lua: read data '", buffer:gsub("[^%w%p ]", "."), "', error: ", err ) + return dispatch( handler, buffer, err ) + else -- connections was closed or fatal error + out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) ) + fatalerror = true + _ = handler and handler:force_close( err ) + return false + end + end + local _sendbuffer = function( ) -- this function sends data + local succ, err, byte, buffer, count; + if socket then + buffer = table_concat( bufferqueue, "", 1, bufferqueuelen ) + succ, err, byte = send( socket, buffer, 1, bufferlen ) + count = ( succ or byte or 0 ) * STAT_UNIT + sendtraffic = sendtraffic + count + _sendtraffic = _sendtraffic + count + for i = bufferqueuelen,1,-1 do + bufferqueue[ i ] = nil + end + --out_put( "server.lua: sended '", buffer, "', bytes: ", tostring(succ), ", error: ", tostring(err), ", part: ", tostring(byte), ", to: ", tostring(ip), ":", tostring(clientport) ) + else + succ, err, count = false, "unexpected close", 0; + end + if succ then -- sending succesful + bufferqueuelen = 0 + bufferlen = 0 + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) -- delete socket from writelist + _writetimes[ handler ] = nil + if drain then + drain(handler) + end + _ = needtls and handler:starttls(nil) + _ = toclose and handler:force_close( ) + return true + elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write + buffer = string_sub( buffer, byte + 1, bufferlen ) -- new buffer + bufferqueue[ 1 ] = buffer -- insert new buffer in queue + bufferqueuelen = 1 + bufferlen = bufferlen - byte + _writetimes[ handler ] = _currenttime + return true + else -- connection was closed during sending or fatal error + out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) ) + fatalerror = true + _ = handler and handler:force_close( err ) + return false + end + end + + -- Set the sslctx + local handshake; + function handler.set_sslctx(self, new_sslctx) + sslctx = new_sslctx; + local read, wrote + handshake = coroutine_wrap( function( client ) -- create handshake coroutine + local err + for i = 1, _maxsslhandshake do + _sendlistlen = ( wrote and removesocket( _sendlist, client, _sendlistlen ) ) or _sendlistlen + _readlistlen = ( read and removesocket( _readlist, client, _readlistlen ) ) or _readlistlen + read, wrote = nil, nil + _, err = client:dohandshake( ) + if not err then + out_put( "server.lua: ssl handshake done" ) + handler.readbuffer = _readbuffer -- when handshake is done, replace the handshake function with regular functions + handler.sendbuffer = _sendbuffer + _ = status and status( handler, "ssl-handshake-complete" ) + if self.autostart_ssl and listeners.onconnect then + listeners.onconnect(self); + end + _readlistlen = addsocket(_readlist, client, _readlistlen) + return true + else + if err == "wantwrite" then + _sendlistlen = addsocket(_sendlist, client, _sendlistlen) + wrote = true + elseif err == "wantread" then + _readlistlen = addsocket(_readlist, client, _readlistlen) + read = true + else + break; + end + err = nil; + coroutine_yield( ) -- handshake not finished + end + end + out_put( "server.lua: ssl handshake error: ", tostring(err or "handshake too long") ) + _ = handler and handler:force_close("ssl handshake failed") + return false, err -- handshake failed + end + ) + end + if luasec then + handler.starttls = function( self, _sslctx) + if _sslctx then + handler:set_sslctx(_sslctx); + end + if bufferqueuelen > 0 then + out_put "server.lua: we need to do tls, but delaying until send buffer empty" + needtls = true + return + end + out_put( "server.lua: attempting to start tls on " .. tostring( socket ) ) + local oldsocket, err = socket + socket, err = ssl_wrap( socket, sslctx ) -- wrap socket + if not socket then + out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") ) + return nil, err -- fatal error + end + + socket:settimeout( 0 ) + + -- add the new socket to our system + send = socket.send + receive = socket.receive + shutdown = id + _socketlist[ socket ] = handler + _readlistlen = addsocket(_readlist, socket, _readlistlen) + + -- remove traces of the old socket + _readlistlen = removesocket( _readlist, oldsocket, _readlistlen ) + _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen ) + _socketlist[ oldsocket ] = nil + + handler.starttls = nil + needtls = nil + + -- Secure now (if handshake fails connection will close) + ssl = true + + handler.readbuffer = handshake + handler.sendbuffer = handshake + return handshake( socket ) -- do handshake + end + end + + handler.readbuffer = _readbuffer + handler.sendbuffer = _sendbuffer + send = socket.send + receive = socket.receive + shutdown = ( ssl and id ) or socket.shutdown + + _socketlist[ socket ] = handler + _readlistlen = addsocket(_readlist, socket, _readlistlen) + + if sslctx and luasec then + out_put "server.lua: auto-starting ssl negotiation..." + handler.autostart_ssl = true; + local ok, err = handler:starttls(sslctx); + if ok == false then + return nil, nil, err + end + end + + return handler, socket +end + +id = function( ) +end + +idfalse = function( ) + return false +end + +addsocket = function( list, socket, len ) + if not list[ socket ] then + len = len + 1 + list[ len ] = socket + list[ socket ] = len + end + return len; +end + +removesocket = function( list, socket, len ) -- this function removes sockets from a list ( copied from copas ) + local pos = list[ socket ] + if pos then + list[ socket ] = nil + local last = list[ len ] + list[ len ] = nil + if last ~= socket then + list[ last ] = pos + list[ pos ] = last + end + return len - 1 + end + return len +end + +closesocket = function( socket ) + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) + _readlistlen = removesocket( _readlist, socket, _readlistlen ) + _socketlist[ socket ] = nil + socket:close( ) + --mem_free( ) +end + +local function link(sender, receiver, buffersize) + local sender_locked; + local _sendbuffer = receiver.sendbuffer; + function receiver.sendbuffer() + _sendbuffer(); + if sender_locked and receiver.bufferlen() < buffersize then + sender:lock_read(false); -- Unlock now + sender_locked = nil; + end + end + + local _readbuffer = sender.readbuffer; + function sender.readbuffer() + _readbuffer(); + if not sender_locked and receiver.bufferlen() >= buffersize then + sender_locked = true; + sender:lock_read(true); + end + end +end + +----------------------------------// PUBLIC //-- + +addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server + local err + if type( listeners ) ~= "table" then + err = "invalid listener table" + end + if type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then + err = "invalid port" + elseif _server[ addr..":"..port ] then + err = "listeners on '[" .. addr .. "]:" .. port .. "' already exist" + elseif sslctx and not luasec then + err = "luasec not found" + end + if err then + out_error( "server.lua, [", addr, "]:", port, ": ", err ) + return nil, err + end + addr = addr or "*" + local server, err = socket_bind( addr, port, _tcpbacklog ) + if err then + out_error( "server.lua, [", addr, "]:", port, ": ", err ) + return nil, err + end + local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx ) -- wrap new server socket + if not handler then + server:close( ) + return nil, err + end + server:settimeout( 0 ) + _readlistlen = addsocket(_readlist, server, _readlistlen) + _server[ addr..":"..port ] = handler + _socketlist[ server ] = handler + out_put( "server.lua: new "..(sslctx and "ssl " or "").."server listener on '[", addr, "]:", port, "'" ) + return handler +end + +getserver = function ( addr, port ) + return _server[ addr..":"..port ]; +end + +removeserver = function( addr, port ) + local handler = _server[ addr..":"..port ] + if not handler then + return nil, "no server found on '[" .. addr .. "]:" .. tostring( port ) .. "'" + end + handler:close( ) + _server[ addr..":"..port ] = nil + return true +end + +closeall = function( ) + for _, handler in pairs( _socketlist ) do + handler:close( ) + _socketlist[ _ ] = nil + end + _readlistlen = 0 + _sendlistlen = 0 + _timerlistlen = 0 + _server = { } + _readlist = { } + _sendlist = { } + _timerlist = { } + _socketlist = { } + --mem_free( ) +end + +getsettings = function( ) + return { + select_timeout = _selecttimeout; + select_sleep_time = _sleeptime; + tcp_backlog = _tcpbacklog; + max_send_buffer_size = _maxsendlen; + max_receive_buffer_size = _maxreadlen; + select_idle_check_interval = _checkinterval; + send_timeout = _sendtimeout; + read_timeout = _readtimeout; + max_connections = _maxselectlen; + max_ssl_handshake_roundtrips = _maxsslhandshake; + highest_allowed_fd = _maxfd; + } +end + +changesettings = function( new ) + if type( new ) ~= "table" then + return nil, "invalid settings table" + end + _selecttimeout = tonumber( new.select_timeout ) or _selecttimeout + _sleeptime = tonumber( new.select_sleep_time ) or _sleeptime + _maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen + _maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen + _checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval + _tcpbacklog = tonumber( new.tcp_backlog ) or _tcpbacklog + _sendtimeout = tonumber( new.send_timeout ) or _sendtimeout + _readtimeout = tonumber( new.read_timeout ) or _readtimeout + _maxselectlen = new.max_connections or _maxselectlen + _maxsslhandshake = new.max_ssl_handshake_roundtrips or _maxsslhandshake + _maxfd = new.highest_allowed_fd or _maxfd + return true +end + +addtimer = function( listener ) + if type( listener ) ~= "function" then + return nil, "invalid listener function" + end + _timerlistlen = _timerlistlen + 1 + _timerlist[ _timerlistlen ] = listener + return true +end + +stats = function( ) + return _readtraffic, _sendtraffic, _readlistlen, _sendlistlen, _timerlistlen +end + +local quitting; + +local function setquitting(quit) + quitting = not not quit; +end + +loop = function(once) -- this is the main loop of the program + if quitting then return "quitting"; end + if once then quitting = "once"; end + local next_timer_time = math_huge; + repeat + local read, write, err = socket_select( _readlist, _sendlist, math_min(_selecttimeout, next_timer_time) ) + for i, socket in ipairs( write ) do -- send data waiting in writequeues + local handler = _socketlist[ socket ] + if handler then + handler.sendbuffer( ) + else + closesocket( socket ) + out_put "server.lua: found no handler and closed socket (writelist)" -- this should not happen + end + end + for i, socket in ipairs( read ) do -- receive data + local handler = _socketlist[ socket ] + if handler then + handler.readbuffer( ) + else + closesocket( socket ) + out_put "server.lua: found no handler and closed socket (readlist)" -- this can happen + end + end + for handler, err in pairs( _closelist ) do + handler.disconnect( )( handler, err ) + handler:force_close() -- forced disconnect + _closelist[ handler ] = nil; + end + _currenttime = luasocket_gettime( ) + + -- Check for socket timeouts + local difftime = os_difftime( _currenttime - _starttime ) + if difftime > _checkinterval then + _starttime = _currenttime + for handler, timestamp in pairs( _writetimes ) do + if os_difftime( _currenttime - timestamp ) > _sendtimeout then + --_writetimes[ handler ] = nil + handler.disconnect( )( handler, "send timeout" ) + handler:force_close() -- forced disconnect + end + end + for handler, timestamp in pairs( _readtimes ) do + if os_difftime( _currenttime - timestamp ) > _readtimeout then + --_readtimes[ handler ] = nil + handler.disconnect( )( handler, "read timeout" ) + handler:close( ) -- forced disconnect? + end + end + end + + -- Fire timers + if _currenttime - _timer >= math_min(next_timer_time, 1) then + next_timer_time = math_huge; + for i = 1, _timerlistlen do + local t = _timerlist[ i ]( _currenttime ) -- fire timers + if t then next_timer_time = math_min(next_timer_time, t); end + end + _timer = _currenttime + else + next_timer_time = next_timer_time - (_currenttime - _timer); + end + + -- wait some time (0 by default) + socket_sleep( _sleeptime ) + until quitting; + if once and quitting == "once" then quitting = nil; return; end + return "quitting" +end + +local function step() + return loop(true); +end + +local function get_backend() + return "select"; +end + +--// EXPERIMENTAL //-- + +local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx ) + local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx ) + if not handler then return nil, err end + _socketlist[ socket ] = handler + if not sslctx then + _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) + if listeners.onconnect then + -- When socket is writeable, call onconnect + local _sendbuffer = handler.sendbuffer; + handler.sendbuffer = function () + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ); + handler.sendbuffer = _sendbuffer; + listeners.onconnect(handler); + -- If there was data with the incoming packet, handle it now. + if #handler:bufferqueue() > 0 then + return _sendbuffer(); + end + end + end + end + return handler, socket +end + +local addclient = function( address, port, listeners, pattern, sslctx ) + local client, err = luasocket.tcp( ) + if err then + return nil, err + end + client:settimeout( 0 ) + _, err = client:connect( address, port ) + if err then -- try again + local handler = wrapclient( client, address, port, listeners ) + else + wrapconnection( nil, listeners, client, address, port, "clientport", pattern, sslctx ) + end +end + +--// EXPERIMENTAL //-- + +----------------------------------// BEGIN //-- + +use "setmetatable" ( _socketlist, { __mode = "k" } ) +use "setmetatable" ( _readtimes, { __mode = "k" } ) +use "setmetatable" ( _writetimes, { __mode = "k" } ) + +_timer = luasocket_gettime( ) +_starttime = luasocket_gettime( ) + +local function setlogger(new_logger) + local old_logger = log; + if new_logger then + log = new_logger; + end + return old_logger; +end + +----------------------------------// PUBLIC INTERFACE //-- + +return { + _addtimer = addtimer, + + addclient = addclient, + wrapclient = wrapclient, + + loop = loop, + link = link, + step = step, + stats = stats, + closeall = closeall, + addserver = addserver, + getserver = getserver, + setlogger = setlogger, + getsettings = getsettings, + setquitting = setquitting, + removeserver = removeserver, + get_backend = get_backend, + changesettings = changesettings, +} diff --git a/net/xmppclient_listener.lua b/net/xmppclient_listener.lua deleted file mode 100644 index dcc561f3..00000000 --- a/net/xmppclient_listener.lua +++ /dev/null @@ -1,152 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2009 Matthew Wild --- Copyright (C) 2008-2009 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - - - -local logger = require "logger"; -local log = logger.init("xmppclient_listener"); -local lxp = require "lxp" -local init_xmlhandlers = require "core.xmlhandlers" -local sm_new_session = require "core.sessionmanager".new_session; - -local connlisteners_register = require "net.connlisteners".register; - -local t_insert = table.insert; -local t_concat = table.concat; -local t_concatall = function (t, sep) local tt = {}; for _, s in ipairs(t) do t_insert(tt, tostring(s)); end return t_concat(tt, sep); end -local m_random = math.random; -local format = string.format; -local sessionmanager = require "core.sessionmanager"; -local sm_new_session, sm_destroy_session = sessionmanager.new_session, sessionmanager.destroy_session; -local sm_streamopened = sessionmanager.streamopened; -local sm_streamclosed = sessionmanager.streamclosed; -local st = require "util.stanza"; - -local stream_callbacks = { stream_tag = "http://etherx.jabber.org/streams|stream", - default_ns = "jabber:client", - streamopened = sm_streamopened, streamclosed = sm_streamclosed, handlestanza = core_process_stanza }; - -function stream_callbacks.error(session, error, data) - if error == "no-stream" then - session.log("debug", "Invalid opening stream header"); - session:close("invalid-namespace"); - elseif session.close then - (session.log or log)("debug", "Client XML parse error: %s", tostring(error)); - session:close("xml-not-well-formed"); - end -end - -local function handleerr(err) log("error", "Traceback[c2s]: %s: %s", tostring(err), debug.traceback()); end -function stream_callbacks.handlestanza(a, b) - xpcall(function () core_process_stanza(a, b) end, handleerr); -end - -local sessions = {}; -local xmppclient = { default_port = 5222, default_mode = "*a" }; - --- These are session methods -- - -local function session_reset_stream(session) - -- Reset stream - local parser = lxp.new(init_xmlhandlers(session, stream_callbacks), "|"); - session.parser = parser; - - session.notopen = true; - - function session.data(conn, data) - local ok, err = parser:parse(data); - if ok then return; end - log("debug", "Received invalid XML (%s) %d bytes: %s", tostring(err), #data, data:sub(1, 300):gsub("[\r\n]+", " ")); - session:close("xml-not-well-formed"); - end - - return true; -end - - -local stream_xmlns_attr = {xmlns='urn:ietf:params:xml:ns:xmpp-streams'}; -local default_stream_attr = { ["xmlns:stream"] = stream_callbacks.stream_tag:gsub("%|[^|]+$", ""), xmlns = stream_callbacks.default_ns, version = "1.0", id = "" }; -local function session_close(session, reason) - local log = session.log or log; - if session.conn then - if session.notopen then - session.send("<?xml version='1.0'?>"); - session.send(st.stanza("stream:stream", default_stream_attr):top_tag()); - end - if reason then - if type(reason) == "string" then -- assume stream error - log("info", "Disconnecting client, <stream:error> is: %s", reason); - session.send(st.stanza("stream:error"):tag(reason, {xmlns = 'urn:ietf:params:xml:ns:xmpp-streams' })); - elseif type(reason) == "table" then - if reason.condition then - local stanza = st.stanza("stream:error"):tag(reason.condition, stream_xmlns_attr):up(); - if reason.text then - stanza:tag("text", stream_xmlns_attr):text(reason.text):up(); - end - if reason.extra then - stanza:add_child(reason.extra); - end - log("info", "Disconnecting client, <stream:error> is: %s", tostring(stanza)); - session.send(stanza); - elseif reason.name then -- a stanza - log("info", "Disconnecting client, <stream:error> is: %s", tostring(reason)); - session.send(reason); - end - end - end - session.send("</stream:stream>"); - session.conn.close(); - xmppclient.disconnect(session.conn, (reason and (reason.text or reason.condition)) or reason or "session closed"); - end -end - - --- End of session methods -- - -function xmppclient.listener(conn, data) - local session = sessions[conn]; - if not session then - session = sm_new_session(conn); - sessions[conn] = session; - - -- Logging functions -- - - local conn_name = "c2s"..tostring(conn):match("[a-f0-9]+$"); - session.log = logger.init(conn_name); - - session.log("info", "Client connected"); - - -- Client is using legacy SSL (otherwise mod_tls sets this flag) - if conn.ssl() then - session.secure = true; - end - - session.reset_stream = session_reset_stream; - session.close = session_close; - - session_reset_stream(session); -- Initialise, ready for use - - session.dispatch_stanza = stream_callbacks.handlestanza; - end - if data then - session.data(conn, data); - end -end - -function xmppclient.disconnect(conn, err) - local session = sessions[conn]; - if session then - (session.log or log)("info", "Client disconnected: %s", err); - sm_destroy_session(session, err); - sessions[conn] = nil; - session = nil; - collectgarbage("collect"); - end -end - -connlisteners_register("xmppclient", xmppclient); diff --git a/net/xmppcomponent_listener.lua b/net/xmppcomponent_listener.lua deleted file mode 100644 index bee05967..00000000 --- a/net/xmppcomponent_listener.lua +++ /dev/null @@ -1,176 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2009 Matthew Wild --- Copyright (C) 2008-2009 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - - -local hosts = _G.hosts; - -local t_concat = table.concat; - -local lxp = require "lxp"; -local logger = require "util.logger"; -local config = require "core.configmanager"; -local connlisteners = require "net.connlisteners"; -local cm_register_component = require "core.componentmanager".register_component; -local cm_deregister_component = require "core.componentmanager".deregister_component; -local uuid_gen = require "util.uuid".generate; -local sha1 = require "util.hashes".sha1; -local st = require "util.stanza"; -local init_xmlhandlers = require "core.xmlhandlers"; - -local sessions = {}; - -local log = logger.init("componentlistener"); - -local component_listener = { default_port = 5347; default_mode = "*a"; default_interface = config.get("*", "core", "component_interface") or "127.0.0.1" }; - -local xmlns_component = 'jabber:component:accept'; - ---- Callbacks/data for xmlhandlers to handle streams for us --- - -local stream_callbacks = { stream_tag = "http://etherx.jabber.org/streams|stream", default_ns = xmlns_component }; - -function stream_callbacks.error(session, error, data, data2) - log("warn", "Error processing component stream: "..tostring(error)); - if error == "no-stream" then - session:close("invalid-namespace"); - elseif error == "xml-parse-error" and data == "unexpected-element-close" then - session.log("warn", "Unexpected close of '%s' tag", data2); - session:close("xml-not-well-formed"); - else - session.log("warn", "External component %s XML parse error: %s", tostring(session.host), tostring(error)); - session:close("xml-not-well-formed"); - end -end - -function stream_callbacks.streamopened(session, attr) - if config.get(attr.to, "core", "component_module") ~= "component" then - -- Trying to act as a component domain which - -- hasn't been configured - session:close{ condition = "host-unknown", text = tostring(attr.to).." does not match any configured external components" }; - return; - end - - -- Store the original host (this is used for config, etc.) - session.user = attr.to; - -- Set the host for future reference - session.host = config.get(attr.to, "core", "component_address") or attr.to; - -- Note that we don't create the internal component - -- until after the external component auths successfully - - session.streamid = uuid_gen(); - session.notopen = nil; - - session.send(st.stanza("stream:stream", { xmlns=xmlns_component, - ["xmlns:stream"]='http://etherx.jabber.org/streams', id=session.streamid, from=session.host }):top_tag()); - -end - -function stream_callbacks.streamclosed(session) - session.send("</stream:stream>"); - session.notopen = true; -end - -local core_process_stanza = core_process_stanza; - -function stream_callbacks.handlestanza(session, stanza) - -- Namespaces are icky. - if not stanza.attr.xmlns and stanza.name == "handshake" then - stanza.attr.xmlns = xmlns_component; - end - return core_process_stanza(session, stanza); -end - ---- Closing a component connection -local stream_xmlns_attr = {xmlns='urn:ietf:params:xml:ns:xmpp-streams'}; -local default_stream_attr = { ["xmlns:stream"] = stream_callbacks.stream_tag:gsub("%|[^|]+$", ""), xmlns = stream_callbacks.default_ns, version = "1.0", id = "" }; -local function session_close(session, reason) - local log = session.log or log; - if session.conn then - if session.notopen then - session.send("<?xml version='1.0'?>"); - session.send(st.stanza("stream:stream", default_stream_attr):top_tag()); - end - if reason then - if type(reason) == "string" then -- assume stream error - log("info", "Disconnecting component, <stream:error> is: %s", reason); - session.send(st.stanza("stream:error"):tag(reason, {xmlns = 'urn:ietf:params:xml:ns:xmpp-streams' })); - elseif type(reason) == "table" then - if reason.condition then - local stanza = st.stanza("stream:error"):tag(reason.condition, stream_xmlns_attr):up(); - if reason.text then - stanza:tag("text", stream_xmlns_attr):text(reason.text):up(); - end - if reason.extra then - stanza:add_child(reason.extra); - end - log("info", "Disconnecting component, <stream:error> is: %s", tostring(stanza)); - session.send(stanza); - elseif reason.name then -- a stanza - log("info", "Disconnecting component, <stream:error> is: %s", tostring(reason)); - session.send(reason); - end - end - end - session.send("</stream:stream>"); - session.conn.close(); - component_listener.disconnect(session.conn, "stream error"); - end -end - ---- Component connlistener -function component_listener.listener(conn, data) - local session = sessions[conn]; - if not session then - local _send = conn.write; - session = { type = "component", conn = conn, send = function (data) return _send(tostring(data)); end }; - sessions[conn] = session; - - -- Logging functions -- - - local conn_name = "jcp"..tostring(conn):match("[a-f0-9]+$"); - session.log = logger.init(conn_name); - session.close = session_close; - - session.log("info", "Incoming Jabber component connection"); - - local parser = lxp.new(init_xmlhandlers(session, stream_callbacks), "|"); - session.parser = parser; - - session.notopen = true; - - function session.data(conn, data) - local ok, err = parser:parse(data); - if ok then return; end - session:close("xml-not-well-formed"); - end - - session.dispatch_stanza = stream_callbacks.handlestanza; - - end - if data then - session.data(conn, data); - end -end - -function component_listener.disconnect(conn, err) - local session = sessions[conn]; - if session then - (session.log or log)("info", "component disconnected: %s (%s)", tostring(session.host), tostring(err)); - if session.host then - log("debug", "Deregistering component"); - cm_deregister_component(session.host); - hosts[session.host].connected = nil; - end - sessions[conn] = nil; - for k in pairs(session) do session[k] = nil; end - session = nil; - collectgarbage("collect"); - end -end - -connlisteners.register('xmppcomponent', component_listener); diff --git a/net/xmppserver_listener.lua b/net/xmppserver_listener.lua deleted file mode 100644 index 1f27d841..00000000 --- a/net/xmppserver_listener.lua +++ /dev/null @@ -1,174 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2009 Matthew Wild --- Copyright (C) 2008-2009 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - - - -local logger = require "logger"; -local log = logger.init("xmppserver_listener"); -local lxp = require "lxp" -local init_xmlhandlers = require "core.xmlhandlers" -local s2s_new_incoming = require "core.s2smanager".new_incoming; -local s2s_streamopened = require "core.s2smanager".streamopened; -local s2s_streamclosed = require "core.s2smanager".streamclosed; -local s2s_destroy_session = require "core.s2smanager".destroy_session; -local s2s_attempt_connect = require "core.s2smanager".attempt_connection; -local stream_callbacks = { stream_tag = "http://etherx.jabber.org/streams|stream", - default_ns = "jabber:server", - streamopened = s2s_streamopened, streamclosed = s2s_streamclosed, handlestanza = core_process_stanza }; - -function stream_callbacks.error(session, error, data) - if error == "no-stream" then - session:close("invalid-namespace"); - else - session.log("debug", "Server-to-server XML parse error: %s", tostring(error)); - session:close("xml-not-well-formed"); - end -end - -local function handleerr(err) log("error", "Traceback[s2s]: %s: %s", tostring(err), debug.traceback()); end -function stream_callbacks.handlestanza(a, b) - xpcall(function () core_process_stanza(a, b) end, handleerr); -end - -local connlisteners_register = require "net.connlisteners".register; - -local t_insert = table.insert; -local t_concat = table.concat; -local t_concatall = function (t, sep) local tt = {}; for _, s in ipairs(t) do t_insert(tt, tostring(s)); end return t_concat(tt, sep); end -local m_random = math.random; -local format = string.format; -local sessionmanager = require "core.sessionmanager"; -local sm_new_session, sm_destroy_session = sessionmanager.new_session, sessionmanager.destroy_session; -local st = require "util.stanza"; - -local sessions = {}; -local xmppserver = { default_port = 5269, default_mode = "*a" }; - --- These are session methods -- - -local function session_reset_stream(session) - -- Reset stream - local parser = lxp.new(init_xmlhandlers(session, stream_callbacks), "|"); - session.parser = parser; - - session.notopen = true; - - function session.data(conn, data) - local ok, err = parser:parse(data); - if ok then return; end - log("debug", "Received invalid XML (%s) %d bytes: %s", tostring(err), #data, data:sub(1, 300):gsub("[\r\n]+", " ")); - session:close("xml-not-well-formed"); - end - - return true; -end - - -local stream_xmlns_attr = {xmlns='urn:ietf:params:xml:ns:xmpp-streams'}; -local default_stream_attr = { ["xmlns:stream"] = stream_callbacks.stream_tag:gsub("%|[^|]+$", ""), xmlns = stream_callbacks.default_ns, version = "1.0", id = "" }; -local function session_close(session, reason) - local log = session.log or log; - if session.conn then - if session.notopen then - session.sends2s("<?xml version='1.0'?>"); - session.sends2s(st.stanza("stream:stream", default_stream_attr):top_tag()); - end - if reason then - if type(reason) == "string" then -- assume stream error - log("info", "Disconnecting %s[%s], <stream:error> is: %s", session.host or "(unknown host)", session.type, reason); - session.sends2s(st.stanza("stream:error"):tag(reason, {xmlns = 'urn:ietf:params:xml:ns:xmpp-streams' })); - elseif type(reason) == "table" then - if reason.condition then - local stanza = st.stanza("stream:error"):tag(reason.condition, stream_xmlns_attr):up(); - if reason.text then - stanza:tag("text", stream_xmlns_attr):text(reason.text):up(); - end - if reason.extra then - stanza:add_child(reason.extra); - end - log("info", "Disconnecting %s[%s], <stream:error> is: %s", session.host or "(unknown host)", session.type, tostring(stanza)); - session.sends2s(stanza); - elseif reason.name then -- a stanza - log("info", "Disconnecting %s->%s[%s], <stream:error> is: %s", session.from_host or "(unknown host)", session.to_host or "(unknown host)", session.type, tostring(reason)); - session.sends2s(reason); - end - end - end - session.sends2s("</stream:stream>"); - session.conn.close(); - xmppserver.disconnect(session.conn, "stream error"); - end -end - - --- End of session methods -- - -function xmppserver.listener(conn, data) - local session = sessions[conn]; - if not session then - session = s2s_new_incoming(conn); - sessions[conn] = session; - - -- Logging functions -- - - - local conn_name = "s2sin"..tostring(conn):match("[a-f0-9]+$"); - session.log = logger.init(conn_name); - - session.log("info", "Incoming s2s connection"); - - session.reset_stream = session_reset_stream; - session.close = session_close; - - session_reset_stream(session); -- Initialise, ready for use - - session.dispatch_stanza = stream_callbacks.handlestanza; - end - if data then - session.data(conn, data); - end -end - -function xmppserver.disconnect(conn, err) - local session = sessions[conn]; - if session then - if err and err ~= "closed" and session.srv_hosts then - (session.log or log)("debug", "s2s connection closed unexpectedly"); - if s2s_attempt_connect(session, err) then - (session.log or log)("debug", "...so we're going to try again"); - return; -- Session lives for now - end - end - (session.log or log)("info", "s2s disconnected: %s->%s (%s)", tostring(session.from_host), tostring(session.to_host), tostring(err)); - s2s_destroy_session(session); - sessions[conn] = nil; - session = nil; - collectgarbage("collect"); - end -end - -function xmppserver.register_outgoing(conn, session) - session.direction = "outgoing"; - sessions[conn] = session; - - session.reset_stream = session_reset_stream; - session.close = session_close; - session_reset_stream(session); -- Initialise, ready for use - - --local function handleerr(err) print("Traceback:", err, debug.traceback()); end - --session.stanza_dispatch = function (stanza) return select(2, xpcall(function () return core_process_stanza(session, stanza); end, handleerr)); end -end - -connlisteners_register("xmppserver", xmppserver); - - --- We need to perform some initialisation when a connection is created --- We also need to perform that same initialisation at other points (SASL, TLS, ...) - --- ...and we need to handle data --- ...and record all sessions associated with connections |