diff options
Diffstat (limited to 'net')
-rw-r--r-- | net/adns.lua | 91 | ||||
-rw-r--r-- | net/connlisteners.lua | 50 | ||||
-rw-r--r-- | net/dns.lua | 1079 | ||||
-rw-r--r-- | net/http.lua | 191 | ||||
-rw-r--r-- | net/http/codes.lua | 67 | ||||
-rw-r--r-- | net/http/parser.lua | 160 | ||||
-rw-r--r-- | net/http/server.lua | 303 | ||||
-rw-r--r-- | net/httpserver.lua | 15 | ||||
-rw-r--r-- | net/server.lua | 915 | ||||
-rw-r--r-- | net/server_event.lua | 872 | ||||
-rw-r--r-- | net/server_select.lua | 984 | ||||
-rw-r--r-- | net/xmppclient_listener.lua | 87 | ||||
-rw-r--r-- | net/xmppserver_listener.lua | 95 |
13 files changed, 3856 insertions, 1053 deletions
diff --git a/net/adns.lua b/net/adns.lua new file mode 100644 index 00000000..cd69a627 --- /dev/null +++ b/net/adns.lua @@ -0,0 +1,91 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local server = require "net.server"; +local dns = require "net.dns"; + +local log = require "util.logger".init("adns"); + +local t_insert, t_remove = table.insert, table.remove; +local coroutine, tostring, pcall = coroutine, tostring, pcall; + +local function dummy_send(sock, data, i, j) return (j-i)+1; end + +module "adns" + +function lookup(handler, qname, qtype, qclass) + return coroutine.wrap(function (peek) + if peek then + log("debug", "Records for %s already cached, using those...", qname); + handler(peek); + return; + end + log("debug", "Records for %s not in cache, sending query (%s)...", qname, tostring(coroutine.running())); + local ok, err = dns.query(qname, qtype, qclass); + if ok then + coroutine.yield({ qclass or "IN", qtype or "A", qname, coroutine.running()}); -- Wait for reply + log("debug", "Reply for %s (%s)", qname, tostring(coroutine.running())); + end + if ok then + ok, err = pcall(handler, dns.peek(qname, qtype, qclass)); + else + log("error", "Error sending DNS query: %s", err); + ok, err = pcall(handler, nil, err); + end + if not ok then + log("error", "Error in DNS response handler: %s", tostring(err)); + end + end)(dns.peek(qname, qtype, qclass)); +end + +function cancel(handle, call_handler, reason) + log("warn", "Cancelling DNS lookup for %s", tostring(handle[3])); + dns.cancel(handle[1], handle[2], handle[3], handle[4], call_handler); +end + +function new_async_socket(sock, resolver) + local peername = "<unknown>"; + local listener = {}; + local handler = {}; + function listener.onincoming(conn, data) + if data then + dns.feed(handler, data); + end + end + function listener.ondisconnect(conn, err) + if err then + log("warn", "DNS socket for %s disconnected: %s", peername, err); + local servers = resolver.server; + if resolver.socketset[conn] == resolver.best_server and resolver.best_server == #servers then + log("error", "Exhausted all %d configured DNS servers, next lookup will try %s again", #servers, servers[1]); + end + + resolver:servfail(conn); -- Let the magic commence + end + end + handler = server.wrapclient(sock, "dns", 53, listener); + if not handler then + log("warn", "handler is nil"); + end + + handler.settimeout = function () end + handler.setsockname = function (_, ...) return sock:setsockname(...); end + handler.setpeername = function (_, ...) peername = (...); local ret = sock:setpeername(...); _:set_send(dummy_send); return ret; end + handler.connect = function (_, ...) return sock:connect(...) end + --handler.send = function (_, data) _:write(data); return _.sendbuffer and _.sendbuffer(); end + handler.send = function (_, data) + local getpeername = sock.getpeername; + log("debug", "Sending DNS query to %s", (getpeername and getpeername(sock)) or "<unconnected>"); + return sock:send(data); + end + return handler; +end + +dns.socket_wrapper_set(new_async_socket); + +return _M; diff --git a/net/connlisteners.lua b/net/connlisteners.lua index 431d8717..99ddc720 100644 --- a/net/connlisteners.lua +++ b/net/connlisteners.lua @@ -1,45 +1,15 @@ +-- COMPAT w/pre-0.9 +local log = require "util.logger".init("net.connlisteners"); +local traceback = debug.traceback; -local server_add = require "net.server".add; -local log = require "util.logger".init("connlisteners"); +module "httpserver" -local dofile, pcall, error = - dofile, pcall, error - -module "connlisteners" - -local listeners = {}; - -function register(name, listener) - if listeners[name] and listeners[name] ~= listener then - log("warning", "Listener %s is already registered, not registering any more", name); - return false; - end - listeners[name] = listener; - log("info", "Registered connection listener %s", name); - return true; -end - -function deregister(name) - listeners[name] = nil; -end - -function get(name) - local h = listeners[name]; - if not h then - pcall(dofile, "net/"..name:gsub("[^%w%-]", "_").."_listener.lua"); - h = listeners[name]; - end - return h; +function fail() + log("error", "Attempt to use legacy connlisteners API. For more info see http://prosody.im/doc/developers/network"); + log("error", "Legacy connlisteners API usage, %s", traceback("", 2)); end -function start(name, udata) - local h = get(name); - if not h then - error("No such connection module: "..name, 0); - end - return server_add(h, - udata.port or h.default_port or error("Can't start listener "..name.." because no port was specified, and it has no default port", 0), - udata.interface or "*", udata.mode or h.default_mode or 1, udata.ssl ); -end +register, deregister = fail, fail; +get, start = fail, fail, epic_fail; -return _M;
\ No newline at end of file +return _M; diff --git a/net/dns.lua b/net/dns.lua new file mode 100644 index 00000000..c9c51fe8 --- /dev/null +++ b/net/dns.lua @@ -0,0 +1,1079 @@ +-- Prosody IM +-- This file is included with Prosody IM. It has modifications, +-- which are hereby placed in the public domain. + + +-- todo: quick (default) header generation +-- todo: nxdomain, error handling +-- todo: cache results of encodeName + + +-- reference: http://tools.ietf.org/html/rfc1035 +-- reference: http://tools.ietf.org/html/rfc1876 (LOC) + + +local socket = require "socket"; +local timer = require "util.timer"; + +local _, windows = pcall(require, "util.windows"); +local is_windows = (_ and windows) or os.getenv("WINDIR"); + +local coroutine, io, math, string, table = + coroutine, io, math, string, table; + +local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack, select, type= + ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack, select, type; + +local ztact = { -- public domain 20080404 lua@ztact.com + get = function(parent, ...) + local len = select('#', ...); + for i=1,len do + parent = parent[select(i, ...)]; + if parent == nil then break; end + end + return parent; + end; + set = function(parent, ...) + local len = select('#', ...); + local key, value = select(len-1, ...); + local cutpoint, cutkey; + + for i=1,len-2 do + local key = select (i, ...) + local child = parent[key] + + if value == nil then + if child == nil then + return; + elseif next(child, next(child)) then + cutpoint = nil; cutkey = nil; + elseif cutpoint == nil then + cutpoint = parent; cutkey = key; + end + elseif child == nil then + child = {}; + parent[key] = child; + end + parent = child + end + + if value == nil and cutpoint then + cutpoint[cutkey] = nil; + else + parent[key] = value; + return value; + end + end; +}; +local get, set = ztact.get, ztact.set; + +local default_timeout = 15; + +-------------------------------------------------- module dns +module('dns') +local dns = _M; + + +-- dns type & class codes ------------------------------ dns type & class codes + + +local append = table.insert + + +local function highbyte(i) -- - - - - - - - - - - - - - - - - - - highbyte + return (i-(i%0x100))/0x100; +end + + +local function augment (t) -- - - - - - - - - - - - - - - - - - - - augment + local a = {}; + for i,s in pairs(t) do + a[i] = s; + a[s] = s; + a[string.lower(s)] = s; + end + return a; +end + + +local function encode (t) -- - - - - - - - - - - - - - - - - - - - - encode + local code = {}; + for i,s in pairs(t) do + local word = string.char(highbyte(i), i%0x100); + code[i] = word; + code[s] = word; + code[string.lower(s)] = word; + end + return code; +end + + +dns.types = { + 'A', 'NS', 'MD', 'MF', 'CNAME', 'SOA', 'MB', 'MG', 'MR', 'NULL', 'WKS', + 'PTR', 'HINFO', 'MINFO', 'MX', 'TXT', + [ 28] = 'AAAA', [ 29] = 'LOC', [ 33] = 'SRV', + [252] = 'AXFR', [253] = 'MAILB', [254] = 'MAILA', [255] = '*' }; + + +dns.classes = { 'IN', 'CS', 'CH', 'HS', [255] = '*' }; + + +dns.type = augment (dns.types); +dns.class = augment (dns.classes); +dns.typecode = encode (dns.types); +dns.classcode = encode (dns.classes); + + + +local function standardize(qname, qtype, qclass) -- - - - - - - standardize + if string.byte(qname, -1) ~= 0x2E then qname = qname..'.'; end + qname = string.lower(qname); + return qname, dns.type[qtype or 'A'], dns.class[qclass or 'IN']; +end + + +local function prune(rrs, time, soft) -- - - - - - - - - - - - - - - prune + time = time or socket.gettime(); + for i,rr in pairs(rrs) do + if rr.tod then + -- rr.tod = rr.tod - 50 -- accelerated decripitude + rr.ttl = math.floor(rr.tod - time); + if rr.ttl <= 0 then + table.remove(rrs, i); + return prune(rrs, time, soft); -- Re-iterate + end + elseif soft == 'soft' then -- What is this? I forget! + assert(rr.ttl == 0); + rrs[i] = nil; + end + end +end + + +-- metatables & co. ------------------------------------------ metatables & co. + + +local resolver = {}; +resolver.__index = resolver; + +resolver.timeout = default_timeout; + +local function default_rr_tostring(rr) + local rr_val = rr.type and rr[rr.type:lower()]; + if type(rr_val) ~= "string" then + return "<UNKNOWN RDATA TYPE>"; + end + return rr_val; +end + +local special_tostrings = { + LOC = resolver.LOC_tostring; + MX = function (rr) + return string.format('%2i %s', rr.pref, rr.mx); + end; + SRV = function (rr) + local s = rr.srv; + return string.format('%5d %5d %5d %s', s.priority, s.weight, s.port, s.target); + end; +}; + +local rr_metatable = {}; -- - - - - - - - - - - - - - - - - - - rr_metatable +function rr_metatable.__tostring(rr) + local rr_string = (special_tostrings[rr.type] or default_rr_tostring)(rr); + return string.format('%2s %-5s %6i %-28s %s', rr.class, rr.type, rr.ttl, rr.name, rr_string); +end + + +local rrs_metatable = {}; -- - - - - - - - - - - - - - - - - - rrs_metatable +function rrs_metatable.__tostring(rrs) + local t = {}; + for i,rr in pairs(rrs) do + append(t, tostring(rr)..'\n'); + end + return table.concat(t); +end + + +local cache_metatable = {}; -- - - - - - - - - - - - - - - - cache_metatable +function cache_metatable.__tostring(cache) + local time = socket.gettime(); + local t = {}; + for class,types in pairs(cache) do + for type,names in pairs(types) do + for name,rrs in pairs(names) do + prune(rrs, time); + append(t, tostring(rrs)); + end + end + end + return table.concat(t); +end + + +function resolver:new() -- - - - - - - - - - - - - - - - - - - - - resolver + local r = { active = {}, cache = {}, unsorted = {} }; + setmetatable(r, resolver); + setmetatable(r.cache, cache_metatable); + setmetatable(r.unsorted, { __mode = 'kv' }); + return r; +end + + +-- packet layer -------------------------------------------------- packet layer + + +function dns.random(...) -- - - - - - - - - - - - - - - - - - - dns.random + math.randomseed(math.floor(10000*socket.gettime()) % 0x100000000); + dns.random = math.random; + return dns.random(...); +end + + +local function encodeHeader(o) -- - - - - - - - - - - - - - - encodeHeader + o = o or {}; + o.id = o.id or dns.random(0, 0xffff); -- 16b (random) id + + o.rd = o.rd or 1; -- 1b 1 recursion desired + o.tc = o.tc or 0; -- 1b 1 truncated response + o.aa = o.aa or 0; -- 1b 1 authoritative response + o.opcode = o.opcode or 0; -- 4b 0 query + -- 1 inverse query + -- 2 server status request + -- 3-15 reserved + o.qr = o.qr or 0; -- 1b 0 query, 1 response + + o.rcode = o.rcode or 0; -- 4b 0 no error + -- 1 format error + -- 2 server failure + -- 3 name error + -- 4 not implemented + -- 5 refused + -- 6-15 reserved + o.z = o.z or 0; -- 3b 0 resvered + o.ra = o.ra or 0; -- 1b 1 recursion available + + o.qdcount = o.qdcount or 1; -- 16b number of question RRs + o.ancount = o.ancount or 0; -- 16b number of answers RRs + o.nscount = o.nscount or 0; -- 16b number of nameservers RRs + o.arcount = o.arcount or 0; -- 16b number of additional RRs + + -- string.char() rounds, so prevent roundup with -0.4999 + local header = string.char( + highbyte(o.id), o.id %0x100, + o.rd + 2*o.tc + 4*o.aa + 8*o.opcode + 128*o.qr, + o.rcode + 16*o.z + 128*o.ra, + highbyte(o.qdcount), o.qdcount %0x100, + highbyte(o.ancount), o.ancount %0x100, + highbyte(o.nscount), o.nscount %0x100, + highbyte(o.arcount), o.arcount %0x100 + ); + + return header, o.id; +end + + +local function encodeName(name) -- - - - - - - - - - - - - - - - encodeName + local t = {}; + for part in string.gmatch(name, '[^.]+') do + append(t, string.char(string.len(part))); + append(t, part); + end + append(t, string.char(0)); + return table.concat(t); +end + + +local function encodeQuestion(qname, qtype, qclass) -- - - - encodeQuestion + qname = encodeName(qname); + qtype = dns.typecode[qtype or 'a']; + qclass = dns.classcode[qclass or 'in']; + return qname..qtype..qclass; +end + + +function resolver:byte(len) -- - - - - - - - - - - - - - - - - - - - - byte + len = len or 1; + local offset = self.offset; + local last = offset + len - 1; + if last > #self.packet then + error(string.format('out of bounds: %i>%i', last, #self.packet)); + end + self.offset = offset + len; + return string.byte(self.packet, offset, last); +end + + +function resolver:word() -- - - - - - - - - - - - - - - - - - - - - - word + local b1, b2 = self:byte(2); + return 0x100*b1 + b2; +end + + +function resolver:dword () -- - - - - - - - - - - - - - - - - - - - - dword + local b1, b2, b3, b4 = self:byte(4); + --print('dword', b1, b2, b3, b4); + return 0x1000000*b1 + 0x10000*b2 + 0x100*b3 + b4; +end + + +function resolver:sub(len) -- - - - - - - - - - - - - - - - - - - - - - sub + len = len or 1; + local s = string.sub(self.packet, self.offset, self.offset + len - 1); + self.offset = self.offset + len; + return s; +end + + +function resolver:header(force) -- - - - - - - - - - - - - - - - - - header + local id = self:word(); + --print(string.format(':header id %x', id)); + if not self.active[id] and not force then return nil; end + + local h = { id = id }; + + local b1, b2 = self:byte(2); + + h.rd = b1 %2; + h.tc = b1 /2%2; + h.aa = b1 /4%2; + h.opcode = b1 /8%16; + h.qr = b1 /128; + + h.rcode = b2 %16; + h.z = b2 /16%8; + h.ra = b2 /128; + + h.qdcount = self:word(); + h.ancount = self:word(); + h.nscount = self:word(); + h.arcount = self:word(); + + for k,v in pairs(h) do h[k] = v-v%1; end + + return h; +end + + +function resolver:name() -- - - - - - - - - - - - - - - - - - - - - - name + local remember, pointers = nil, 0; + local len = self:byte(); + local n = {}; + if len == 0 then return "." end -- Root label + while len > 0 do + if len >= 0xc0 then -- name is "compressed" + pointers = pointers + 1; + if pointers >= 20 then error('dns error: 20 pointers'); end; + local offset = ((len-0xc0)*0x100) + self:byte(); + remember = remember or self.offset; + self.offset = offset + 1; -- +1 for lua + else -- name is not compressed + append(n, self:sub(len)..'.'); + end + len = self:byte(); + end + self.offset = remember or self.offset; + return table.concat(n); +end + + +function resolver:question() -- - - - - - - - - - - - - - - - - - question + local q = {}; + q.name = self:name(); + q.type = dns.type[self:word()]; + q.class = dns.class[self:word()]; + return q; +end + + +function resolver:A(rr) -- - - - - - - - - - - - - - - - - - - - - - - - A + local b1, b2, b3, b4 = self:byte(4); + rr.a = string.format('%i.%i.%i.%i', b1, b2, b3, b4); +end + +function resolver:AAAA(rr) + local addr = {}; + for i = 1, rr.rdlength, 2 do + local b1, b2 = self:byte(2); + table.insert(addr, ("%02x%02x"):format(b1, b2)); + end + addr = table.concat(addr, ":"):gsub("%f[%x]0+(%x)","%1"); + local zeros = {}; + for item in addr:gmatch(":[0:]+:") do + table.insert(zeros, item) + end + if #zeros == 0 then + rr.aaaa = addr; + return + elseif #zeros > 1 then + table.sort(zeros, function(a, b) return #a > #b end); + end + rr.aaaa = addr:gsub(zeros[1], "::", 1):gsub("^0::", "::"):gsub("::0$", "::"); +end + +function resolver:CNAME(rr) -- - - - - - - - - - - - - - - - - - - - CNAME + rr.cname = self:name(); +end + + +function resolver:MX(rr) -- - - - - - - - - - - - - - - - - - - - - - - MX + rr.pref = self:word(); + rr.mx = self:name(); +end + + +function resolver:LOC_nibble_power() -- - - - - - - - - - LOC_nibble_power + local b = self:byte(); + --print('nibbles', ((b-(b%0x10))/0x10), (b%0x10)); + return ((b-(b%0x10))/0x10) * (10^(b%0x10)); +end + + +function resolver:LOC(rr) -- - - - - - - - - - - - - - - - - - - - - - LOC + rr.version = self:byte(); + if rr.version == 0 then + rr.loc = rr.loc or {}; + rr.loc.size = self:LOC_nibble_power(); + rr.loc.horiz_pre = self:LOC_nibble_power(); + rr.loc.vert_pre = self:LOC_nibble_power(); + rr.loc.latitude = self:dword(); + rr.loc.longitude = self:dword(); + rr.loc.altitude = self:dword(); + end +end + + +local function LOC_tostring_degrees(f, pos, neg) -- - - - - - - - - - - - - + f = f - 0x80000000; + if f < 0 then pos = neg; f = -f; end + local deg, min, msec; + msec = f%60000; + f = (f-msec)/60000; + min = f%60; + deg = (f-min)/60; + return string.format('%3d %2d %2.3f %s', deg, min, msec/1000, pos); +end + + +function resolver.LOC_tostring(rr) -- - - - - - - - - - - - - LOC_tostring + local t = {}; + + --[[ + for k,name in pairs { 'size', 'horiz_pre', 'vert_pre', 'latitude', 'longitude', 'altitude' } do + append(t, string.format('%4s%-10s: %12.0f\n', '', name, rr.loc[name])); + end + --]] + + append(t, string.format( + '%s %s %.2fm %.2fm %.2fm %.2fm', + LOC_tostring_degrees (rr.loc.latitude, 'N', 'S'), + LOC_tostring_degrees (rr.loc.longitude, 'E', 'W'), + (rr.loc.altitude - 10000000) / 100, + rr.loc.size / 100, + rr.loc.horiz_pre / 100, + rr.loc.vert_pre / 100 + )); + + return table.concat(t); +end + + +function resolver:NS(rr) -- - - - - - - - - - - - - - - - - - - - - - - NS + rr.ns = self:name(); +end + + +function resolver:SOA(rr) -- - - - - - - - - - - - - - - - - - - - - - SOA +end + + +function resolver:SRV(rr) -- - - - - - - - - - - - - - - - - - - - - - SRV + rr.srv = {}; + rr.srv.priority = self:word(); + rr.srv.weight = self:word(); + rr.srv.port = self:word(); + rr.srv.target = self:name(); +end + +function resolver:PTR(rr) + rr.ptr = self:name(); +end + +function resolver:TXT(rr) -- - - - - - - - - - - - - - - - - - - - - - TXT + rr.txt = self:sub (self:byte()); +end + + +function resolver:rr() -- - - - - - - - - - - - - - - - - - - - - - - - rr + local rr = {}; + setmetatable(rr, rr_metatable); + rr.name = self:name(self); + rr.type = dns.type[self:word()] or rr.type; + rr.class = dns.class[self:word()] or rr.class; + rr.ttl = 0x10000*self:word() + self:word(); + rr.rdlength = self:word(); + + if rr.ttl <= 0 then + rr.tod = self.time + 30; + else + rr.tod = self.time + rr.ttl; + end + + local remember = self.offset; + local rr_parser = self[dns.type[rr.type]]; + if rr_parser then rr_parser(self, rr); end + self.offset = remember; + rr.rdata = self:sub(rr.rdlength); + return rr; +end + + +function resolver:rrs (count) -- - - - - - - - - - - - - - - - - - - - - rrs + local rrs = {}; + for i = 1,count do append(rrs, self:rr()); end + return rrs; +end + + +function resolver:decode(packet, force) -- - - - - - - - - - - - - - decode + self.packet, self.offset = packet, 1; + local header = self:header(force); + if not header then return nil; end + local response = { header = header }; + + response.question = {}; + local offset = self.offset; + for i = 1,response.header.qdcount do + append(response.question, self:question()); + end + response.question.raw = string.sub(self.packet, offset, self.offset - 1); + + if not force then + if not self.active[response.header.id] or not self.active[response.header.id][response.question.raw] then + self.active[response.header.id] = nil; + return nil; + end + end + + response.answer = self:rrs(response.header.ancount); + response.authority = self:rrs(response.header.nscount); + response.additional = self:rrs(response.header.arcount); + + return response; +end + + +-- socket layer -------------------------------------------------- socket layer + + +resolver.delays = { 1, 3 }; + + +function resolver:addnameserver(address) -- - - - - - - - - - addnameserver + self.server = self.server or {}; + append(self.server, address); +end + + +function resolver:setnameserver(address) -- - - - - - - - - - setnameserver + self.server = {}; + self:addnameserver(address); +end + + +function resolver:adddefaultnameservers() -- - - - - adddefaultnameservers + if is_windows then + if windows and windows.get_nameservers then + for _, server in ipairs(windows.get_nameservers()) do + self:addnameserver(server); + end + end + if not self.server or #self.server == 0 then + -- TODO log warning about no nameservers, adding opendns servers as fallback + self:addnameserver("208.67.222.222"); + self:addnameserver("208.67.220.220"); + end + else -- posix + local resolv_conf = io.open("/etc/resolv.conf"); + if resolv_conf then + for line in resolv_conf:lines() do + line = line:gsub("#.*$", "") + :match('^%s*nameserver%s+(.*)%s*$'); + if line then + line:gsub("%f[%d.](%d+%.%d+%.%d+%.%d+)%f[^%d.]", function (address) + self:addnameserver(address) + end); + end + end + end + if not self.server or #self.server == 0 then + -- TODO log warning about no nameservers, adding localhost as the default nameserver + self:addnameserver("127.0.0.1"); + end + end +end + + +function resolver:getsocket(servernum) -- - - - - - - - - - - - - getsocket + self.socket = self.socket or {}; + self.socketset = self.socketset or {}; + + local sock = self.socket[servernum]; + if sock then return sock; end + + local err; + sock, err = socket.udp(); + if not sock then + return nil, err; + end + if self.socket_wrapper then sock = self.socket_wrapper(sock, self); end + sock:settimeout(0); + -- todo: attempt to use a random port, fallback to 0 + sock:setsockname('*', 0); + sock:setpeername(self.server[servernum], 53); + self.socket[servernum] = sock; + self.socketset[sock] = servernum; + return sock; +end + +function resolver:voidsocket(sock) + if self.socket[sock] then + self.socketset[self.socket[sock]] = nil; + self.socket[sock] = nil; + elseif self.socketset[sock] then + self.socket[self.socketset[sock]] = nil; + self.socketset[sock] = nil; + end + sock:close(); +end + +function resolver:socket_wrapper_set(func) -- - - - - - - socket_wrapper_set + self.socket_wrapper = func; +end + + +function resolver:closeall () -- - - - - - - - - - - - - - - - - - closeall + for i,sock in ipairs(self.socket) do + self.socket[i] = nil; + self.socketset[sock] = nil; + sock:close(); + end +end + + +function resolver:remember(rr, type) -- - - - - - - - - - - - - - remember + --print ('remember', type, rr.class, rr.type, rr.name) + local qname, qtype, qclass = standardize(rr.name, rr.type, rr.class); + + if type ~= '*' then + type = qtype; + local all = get(self.cache, qclass, '*', qname); + --print('remember all', all); + if all then append(all, rr); end + end + + self.cache = self.cache or setmetatable({}, cache_metatable); + local rrs = get(self.cache, qclass, type, qname) or + set(self.cache, qclass, type, qname, setmetatable({}, rrs_metatable)); + append(rrs, rr); + + if type == 'MX' then self.unsorted[rrs] = true; end +end + + +local function comp_mx(a, b) -- - - - - - - - - - - - - - - - - - - comp_mx + return (a.pref == b.pref) and (a.mx < b.mx) or (a.pref < b.pref); +end + + +function resolver:peek (qname, qtype, qclass) -- - - - - - - - - - - - peek + qname, qtype, qclass = standardize(qname, qtype, qclass); + local rrs = get(self.cache, qclass, qtype, qname); + if not rrs then return nil; end + if prune(rrs, socket.gettime()) and qtype == '*' or not next(rrs) then + set(self.cache, qclass, qtype, qname, nil); + return nil; + end + if self.unsorted[rrs] then table.sort (rrs, comp_mx); end + return rrs; +end + + +function resolver:purge(soft) -- - - - - - - - - - - - - - - - - - - purge + if soft == 'soft' then + self.time = socket.gettime(); + for class,types in pairs(self.cache or {}) do + for type,names in pairs(types) do + for name,rrs in pairs(names) do + prune(rrs, self.time, 'soft') + end + end + end + else self.cache = setmetatable({}, cache_metatable); end +end + + +function resolver:query(qname, qtype, qclass) -- - - - - - - - - - -- query + qname, qtype, qclass = standardize(qname, qtype, qclass) + + if not self.server then self:adddefaultnameservers(); end + + local question = encodeQuestion(qname, qtype, qclass); + local peek = self:peek (qname, qtype, qclass); + if peek then return peek; end + + local header, id = encodeHeader(); + --print ('query id', id, qclass, qtype, qname) + local o = { + packet = header..question, + server = self.best_server, + delay = 1, + retry = socket.gettime() + self.delays[1] + }; + + -- remember the query + self.active[id] = self.active[id] or {}; + self.active[id][question] = o; + + -- remember which coroutine wants the answer + local co = coroutine.running(); + if co then + set(self.wanted, qclass, qtype, qname, co, true); + --set(self.yielded, co, qclass, qtype, qname, true); + end + + local conn, err = self:getsocket(o.server) + if not conn then + return nil, err; + end + conn:send (o.packet) + + if timer and self.timeout then + local num_servers = #self.server; + local i = 1; + timer.add_task(self.timeout, function () + if get(self.wanted, qclass, qtype, qname, co) then + if i < num_servers then + i = i + 1; + self:servfail(conn); + o.server = self.best_server; + conn, err = self:getsocket(o.server); + if conn then + conn:send(o.packet); + return self.timeout; + end + end + -- Tried everything, failed + self:cancel(qclass, qtype, qname, co, true); + end + end) + end + return true; +end + +function resolver:servfail(sock) + -- Resend all queries for this server + + local num = self.socketset[sock] + + -- Socket is dead now + self:voidsocket(sock); + + -- Find all requests to the down server, and retry on the next server + self.time = socket.gettime(); + for id,queries in pairs(self.active) do + for question,o in pairs(queries) do + if o.server == num then -- This request was to the broken server + o.server = o.server + 1 -- Use next server + if o.server > #self.server then + o.server = 1; + end + + o.retries = (o.retries or 0) + 1; + if o.retries >= #self.server then + --print('timeout'); + queries[question] = nil; + else + local _a = self:getsocket(o.server); + if _a then _a:send(o.packet); end + end + end + end + if next(queries) == nil then + self.active[id] = nil; + end + end + + if num == self.best_server then + self.best_server = self.best_server + 1; + if self.best_server > #self.server then + -- Exhausted all servers, try first again + self.best_server = 1; + end + end +end + +function resolver:settimeout(seconds) + self.timeout = seconds; +end + +function resolver:receive(rset) -- - - - - - - - - - - - - - - - - receive + --print('receive'); print(self.socket); + self.time = socket.gettime(); + rset = rset or self.socket; + + local response; + for i,sock in pairs(rset) do + + if self.socketset[sock] then + local packet = sock:receive(); + if packet then + response = self:decode(packet); + if response and self.active[response.header.id] + and self.active[response.header.id][response.question.raw] then + --print('received response'); + --self.print(response); + + for j,rr in pairs(response.answer) do + if rr.name:sub(-#response.question[1].name, -1) == response.question[1].name then + self:remember(rr, response.question[1].type) + end + end + + -- retire the query + local queries = self.active[response.header.id]; + queries[response.question.raw] = nil; + + if not next(queries) then self.active[response.header.id] = nil; end + if not next(self.active) then self:closeall(); end + + -- was the query on the wanted list? + local q = response.question[1]; + local cos = get(self.wanted, q.class, q.type, q.name); + if cos then + for co in pairs(cos) do + set(self.yielded, co, q.class, q.type, q.name, nil); + if coroutine.status(co) == "suspended" then coroutine.resume(co); end + end + set(self.wanted, q.class, q.type, q.name, nil); + end + end + + end + end + end + + return response; +end + + +function resolver:feed(sock, packet, force) + --print('receive'); print(self.socket); + self.time = socket.gettime(); + + local response = self:decode(packet, force); + if response and self.active[response.header.id] + and self.active[response.header.id][response.question.raw] then + --print('received response'); + --self.print(response); + + for j,rr in pairs(response.answer) do + self:remember(rr, response.question[1].type); + end + + -- retire the query + local queries = self.active[response.header.id]; + queries[response.question.raw] = nil; + if not next(queries) then self.active[response.header.id] = nil; end + if not next(self.active) then self:closeall(); end + + -- was the query on the wanted list? + local q = response.question[1]; + if q then + local cos = get(self.wanted, q.class, q.type, q.name); + if cos then + for co in pairs(cos) do + set(self.yielded, co, q.class, q.type, q.name, nil); + if coroutine.status(co) == "suspended" then coroutine.resume(co); end + end + set(self.wanted, q.class, q.type, q.name, nil); + end + end + end + + return response; +end + +function resolver:cancel(qclass, qtype, qname, co, call_handler) + local cos = get(self.wanted, qclass, qtype, qname); + if cos then + if call_handler then + coroutine.resume(co); + end + cos[co] = nil; + end +end + +function resolver:pulse() -- - - - - - - - - - - - - - - - - - - - - pulse + --print(':pulse'); + while self:receive() do end + if not next(self.active) then return nil; end + + self.time = socket.gettime(); + for id,queries in pairs(self.active) do + for question,o in pairs(queries) do + if self.time >= o.retry then + + o.server = o.server + 1; + if o.server > #self.server then + o.server = 1; + o.delay = o.delay + 1; + end + + if o.delay > #self.delays then + --print('timeout'); + queries[question] = nil; + if not next(queries) then self.active[id] = nil; end + if not next(self.active) then return nil; end + else + --print('retry', o.server, o.delay); + local _a = self.socket[o.server]; + if _a then _a:send(o.packet); end + o.retry = self.time + self.delays[o.delay]; + end + end + end + end + + if next(self.active) then return true; end + return nil; +end + + +function resolver:lookup(qname, qtype, qclass) -- - - - - - - - - - lookup + self:query (qname, qtype, qclass) + while self:pulse() do + local recvt = {} + for i, s in ipairs(self.socket) do + recvt[i] = s + end + socket.select(recvt, nil, 4) + end + --print(self.cache); + return self:peek(qname, qtype, qclass); +end + +function resolver:lookupex(handler, qname, qtype, qclass) -- - - - - - - - - - lookup + return self:peek(qname, qtype, qclass) or self:query(qname, qtype, qclass); +end + +function resolver:tohostname(ip) + return dns.lookup(ip:gsub("(%d+)%.(%d+)%.(%d+)%.(%d+)", "%4.%3.%2.%1.in-addr.arpa."), "PTR"); +end + +--print ---------------------------------------------------------------- print + + +local hints = { -- - - - - - - - - - - - - - - - - - - - - - - - - - - hints + qr = { [0]='query', 'response' }, + opcode = { [0]='query', 'inverse query', 'server status request' }, + aa = { [0]='non-authoritative', 'authoritative' }, + tc = { [0]='complete', 'truncated' }, + rd = { [0]='recursion not desired', 'recursion desired' }, + ra = { [0]='recursion not available', 'recursion available' }, + z = { [0]='(reserved)' }, + rcode = { [0]='no error', 'format error', 'server failure', 'name error', 'not implemented' }, + + type = dns.type, + class = dns.class +}; + + +local function hint(p, s) -- - - - - - - - - - - - - - - - - - - - - - hint + return (hints[s] and hints[s][p[s]]) or ''; +end + + +function resolver.print(response) -- - - - - - - - - - - - - resolver.print + for s,s in pairs { 'id', 'qr', 'opcode', 'aa', 'tc', 'rd', 'ra', 'z', + 'rcode', 'qdcount', 'ancount', 'nscount', 'arcount' } do + print( string.format('%-30s', 'header.'..s), response.header[s], hint(response.header, s) ); + end + + for i,question in ipairs(response.question) do + print(string.format ('question[%i].name ', i), question.name); + print(string.format ('question[%i].type ', i), question.type); + print(string.format ('question[%i].class ', i), question.class); + end + + local common = { name=1, type=1, class=1, ttl=1, rdlength=1, rdata=1 }; + local tmp; + for s,s in pairs({'answer', 'authority', 'additional'}) do + for i,rr in pairs(response[s]) do + for j,t in pairs({ 'name', 'type', 'class', 'ttl', 'rdlength' }) do + tmp = string.format('%s[%i].%s', s, i, t); + print(string.format('%-30s', tmp), rr[t], hint(rr, t)); + end + for j,t in pairs(rr) do + if not common[j] then + tmp = string.format('%s[%i].%s', s, i, j); + print(string.format('%-30s %s', tostring(tmp), tostring(t))); + end + end + end + end +end + + +-- module api ------------------------------------------------------ module api + + +function dns.resolver () -- - - - - - - - - - - - - - - - - - - - - resolver + -- this function seems to be redundant with resolver.new () + + local r = { active = {}, cache = {}, unsorted = {}, wanted = {}, yielded = {}, best_server = 1 }; + setmetatable (r, resolver); + setmetatable (r.cache, cache_metatable); + setmetatable (r.unsorted, { __mode = 'kv' }); + return r; +end + +local _resolver = dns.resolver(); +dns._resolver = _resolver; + +function dns.lookup(...) -- - - - - - - - - - - - - - - - - - - - - lookup + return _resolver:lookup(...); +end + +function dns.tohostname(...) + return _resolver:tohostname(...); +end + +function dns.purge(...) -- - - - - - - - - - - - - - - - - - - - - - purge + return _resolver:purge(...); +end + +function dns.peek(...) -- - - - - - - - - - - - - - - - - - - - - - - peek + return _resolver:peek(...); +end + +function dns.query(...) -- - - - - - - - - - - - - - - - - - - - - - query + return _resolver:query(...); +end + +function dns.feed(...) -- - - - - - - - - - - - - - - - - - - - - - - feed + return _resolver:feed(...); +end + +function dns.cancel(...) -- - - - - - - - - - - - - - - - - - - - - - cancel + return _resolver:cancel(...); +end + +function dns.settimeout(...) + return _resolver:settimeout(...); +end + +function dns.socket_wrapper_set(...) -- - - - - - - - - socket_wrapper_set + return _resolver:socket_wrapper_set(...); +end + +return dns; diff --git a/net/http.lua b/net/http.lua new file mode 100644 index 00000000..3b783a41 --- /dev/null +++ b/net/http.lua @@ -0,0 +1,191 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local socket = require "socket" +local b64 = require "util.encodings".base64.encode; +local url = require "socket.url" +local httpstream_new = require "net.http.parser".new; +local util_http = require "util.http"; + +local ssl_available = pcall(require, "ssl"); + +local server = require "net.server" + +local t_insert, t_concat = table.insert, table.concat; +local pairs = pairs; +local tonumber, tostring, xpcall, select, traceback = + tonumber, tostring, xpcall, select, debug.traceback; + +local log = require "util.logger".init("http"); + +module "http" + +local requests = {}; -- Open requests + +local listener = { default_port = 80, default_mode = "*a" }; + +function listener.onconnect(conn) + local req = requests[conn]; + -- Send the request + local request_line = { req.method or "GET", " ", req.path, " HTTP/1.1\r\n" }; + if req.query then + t_insert(request_line, 4, "?"..req.query); + end + + conn:write(t_concat(request_line)); + local t = { [2] = ": ", [4] = "\r\n" }; + for k, v in pairs(req.headers) do + t[1], t[3] = k, v; + conn:write(t_concat(t)); + end + conn:write("\r\n"); + + if req.body then + conn:write(req.body); + end +end + +function listener.onincoming(conn, data) + local request = requests[conn]; + + if not request then + log("warn", "Received response from connection %s with no request attached!", tostring(conn)); + return; + end + + if data and request.reader then + request:reader(data); + end +end + +function listener.ondisconnect(conn, err) + local request = requests[conn]; + if request and request.conn then + request:reader(nil, err); + end + requests[conn] = nil; +end + +local function request_reader(request, data, err) + if not request.parser then + local function error_cb(reason) + if request.callback then + request.callback(reason or "connection-closed", 0, request); + request.callback = nil; + end + destroy_request(request); + end + + if not data then + error_cb(err); + return; + end + + local function success_cb(r) + if request.callback then + request.callback(r.body, r.code, r, request); + request.callback = nil; + end + destroy_request(request); + end + local function options_cb() + return request; + end + request.parser = httpstream_new(success_cb, error_cb, "client", options_cb); + end + request.parser:feed(data); +end + +local function handleerr(err) log("error", "Traceback[http]: %s", traceback(tostring(err), 2)); end +function request(u, ex, callback) + local req = url.parse(u); + + if not (req and req.host) then + callback(nil, 0, req); + return nil, "invalid-url"; + end + + if not req.path then + req.path = "/"; + end + + local method, headers, body; + + headers = { + ["Host"] = req.host; + ["User-Agent"] = "Prosody XMPP Server"; + }; + + if req.userinfo then + headers["Authorization"] = "Basic "..b64(req.userinfo); + end + + if ex then + req.onlystatus = ex.onlystatus; + body = ex.body; + if body then + method = "POST"; + headers["Content-Length"] = tostring(#body); + headers["Content-Type"] = "application/x-www-form-urlencoded"; + end + if ex.method then method = ex.method; end + if ex.headers then + for k, v in pairs(ex.headers) do + headers[k] = v; + end + end + end + + -- Attach to request object + req.method, req.headers, req.body = method, headers, body; + + local using_https = req.scheme == "https"; + if using_https and not ssl_available then + error("SSL not available, unable to contact https URL"); + end + local port = tonumber(req.port) or (using_https and 443 or 80); + + -- Connect the socket, and wrap it with net.server + local conn = socket.tcp(); + conn:settimeout(10); + local ok, err = conn:connect(req.host, port); + if not ok and err ~= "timeout" then + callback(nil, 0, req); + return nil, err; + end + + local sslctx = false; + if using_https then + sslctx = ex and ex.sslctx or { mode = "client", protocol = "sslv23", options = { "no_sslv2" } }; + end + + req.handler, req.conn = server.wrapclient(conn, req.host, port, listener, "*a", sslctx); + req.write = function (...) return req.handler:write(...); end + + req.callback = function (content, code, request, response) log("debug", "Calling callback, status %s", code or "---"); return select(2, xpcall(function () return callback(content, code, request, response) end, handleerr)); end + req.reader = request_reader; + req.state = "status"; + + requests[req.handler] = req; + return req; +end + +function destroy_request(request) + if request.conn then + request.conn = nil; + request.handler:close() + end +end + +local urlencode, urldecode = util_http.urlencode, util_http.urldecode; +local formencode, formdecode = util_http.formencode, util_http.formdecode; + +_M.urlencode, _M.urldecode = urlencode, urldecode; +_M.formencode, _M.formdecode = formencode, formdecode; + +return _M; diff --git a/net/http/codes.lua b/net/http/codes.lua new file mode 100644 index 00000000..0cadd079 --- /dev/null +++ b/net/http/codes.lua @@ -0,0 +1,67 @@ + +local response_codes = { + -- Source: http://www.iana.org/assignments/http-status-codes + -- s/^\(\d*\)\s*\(.*\S\)\s*\[RFC.*\]\s*$/^I["\1"] = "\2"; + [100] = "Continue"; + [101] = "Switching Protocols"; + [102] = "Processing"; + + [200] = "OK"; + [201] = "Created"; + [202] = "Accepted"; + [203] = "Non-Authoritative Information"; + [204] = "No Content"; + [205] = "Reset Content"; + [206] = "Partial Content"; + [207] = "Multi-Status"; + [208] = "Already Reported"; + [226] = "IM Used"; + + [300] = "Multiple Choices"; + [301] = "Moved Permanently"; + [302] = "Found"; + [303] = "See Other"; + [304] = "Not Modified"; + [305] = "Use Proxy"; + -- The 306 status code was used in a previous version of [RFC2616], is no longer used, and the code is reserved. + [307] = "Temporary Redirect"; + + [400] = "Bad Request"; + [401] = "Unauthorized"; + [402] = "Payment Required"; + [403] = "Forbidden"; + [404] = "Not Found"; + [405] = "Method Not Allowed"; + [406] = "Not Acceptable"; + [407] = "Proxy Authentication Required"; + [408] = "Request Timeout"; + [409] = "Conflict"; + [410] = "Gone"; + [411] = "Length Required"; + [412] = "Precondition Failed"; + [413] = "Request Entity Too Large"; + [414] = "Request-URI Too Long"; + [415] = "Unsupported Media Type"; + [416] = "Requested Range Not Satisfiable"; + [417] = "Expectation Failed"; + [418] = "I'm a teapot"; + [422] = "Unprocessable Entity"; + [423] = "Locked"; + [424] = "Failed Dependency"; + -- The 425 status code is reserved for the WebDAV advanced collections expired proposal [RFC2817] + [426] = "Upgrade Required"; + + [500] = "Internal Server Error"; + [501] = "Not Implemented"; + [502] = "Bad Gateway"; + [503] = "Service Unavailable"; + [504] = "Gateway Timeout"; + [505] = "HTTP Version Not Supported"; + [506] = "Variant Also Negotiates"; -- Experimental + [507] = "Insufficient Storage"; + [508] = "Loop Detected"; + [510] = "Not Extended"; +}; + +for k,v in pairs(response_codes) do response_codes[k] = k.." "..v; end +return setmetatable(response_codes, { __index = function(t, k) return k.." Unassigned"; end }) diff --git a/net/http/parser.lua b/net/http/parser.lua new file mode 100644 index 00000000..f9e6cea0 --- /dev/null +++ b/net/http/parser.lua @@ -0,0 +1,160 @@ +local tonumber = tonumber; +local assert = assert; +local url_parse = require "socket.url".parse; +local urldecode = require "util.http".urldecode; + +local function preprocess_path(path) + path = urldecode((path:gsub("//+", "/"))); + if path:sub(1,1) ~= "/" then + path = "/"..path; + end + local level = 0; + for component in path:gmatch("([^/]+)/") do + if component == ".." then + level = level - 1; + elseif component ~= "." then + level = level + 1; + end + if level < 0 then + return nil; + end + end + return path; +end + +local httpstream = {}; + +function httpstream.new(success_cb, error_cb, parser_type, options_cb) + local client = true; + if not parser_type or parser_type == "server" then client = false; else assert(parser_type == "client", "Invalid parser type"); end + local buf = ""; + local chunked, chunk_size, chunk_start; + local state = nil; + local packet; + local len; + local have_body; + local error; + return { + feed = function(self, data) + if error then return nil, "parse has failed"; end + if not data then -- EOF + if state and client and not len then -- reading client body until EOF + packet.body = buf; + success_cb(packet); + elseif buf ~= "" then -- unexpected EOF + error = true; return error_cb(); + end + return; + end + buf = buf..data; + while #buf > 0 do + if state == nil then -- read request + local index = buf:find("\r\n\r\n", nil, true); + if not index then return; end -- not enough data + local method, path, httpversion, status_code, reason_phrase; + local first_line; + local headers = {}; + for line in buf:sub(1,index+1):gmatch("([^\r\n]+)\r\n") do -- parse request + if first_line then + local key, val = line:match("^([^%s:]+): *(.*)$"); + if not key then error = true; return error_cb("invalid-header-line"); end -- TODO handle multi-line and invalid headers + key = key:lower(); + headers[key] = headers[key] and headers[key]..","..val or val; + else + first_line = line; + if client then + httpversion, status_code, reason_phrase = line:match("^HTTP/(1%.[01]) (%d%d%d) (.*)$"); + status_code = tonumber(status_code); + if not status_code then error = true; return error_cb("invalid-status-line"); end + have_body = not + ( (options_cb and options_cb().method == "HEAD") + or (status_code == 204 or status_code == 304 or status_code == 301) + or (status_code >= 100 and status_code < 200) ); + else + method, path, httpversion = line:match("^(%w+) (%S+) HTTP/(1%.[01])$"); + if not method then error = true; return error_cb("invalid-status-line"); end + end + end + end + if not first_line then error = true; return error_cb("invalid-status-line"); end + chunked = have_body and headers["transfer-encoding"] == "chunked"; + len = tonumber(headers["content-length"]); -- TODO check for invalid len + if client then + -- FIXME handle '100 Continue' response (by skipping it) + if not have_body then len = 0; end + packet = { + code = status_code; + httpversion = httpversion; + headers = headers; + body = have_body and "" or nil; + -- COMPAT the properties below are deprecated + responseversion = httpversion; + responseheaders = headers; + }; + else + local parsed_url; + if path:byte() == 47 then -- starts with / + local _path, _query = path:match("([^?]*).?(.*)"); + if _query == "" then _query = nil; end + parsed_url = { path = _path, query = _query }; + else + parsed_url = url_parse(path); + if not(parsed_url and parsed_url.path) then error = true; return error_cb("invalid-url"); end + end + path = preprocess_path(parsed_url.path); + headers.host = parsed_url.host or headers.host; + + len = len or 0; + packet = { + method = method; + url = parsed_url; + path = path; + httpversion = httpversion; + headers = headers; + body = nil; + }; + end + buf = buf:sub(index + 4); + state = true; + end + if state then -- read body + if client then + if chunked then + if not buf:find("\r\n", nil, true) then + return; + end -- not enough data + if not chunk_size then + chunk_size, chunk_start = buf:match("^(%x+)[^\r\n]*\r\n()"); + chunk_size = chunk_size and tonumber(chunk_size, 16); + if not chunk_size then error = true; return error_cb("invalid-chunk-size"); end + end + if chunk_size == 0 and buf:find("\r\n\r\n", chunk_start-2, true) then + state, chunk_size = nil, nil; + buf = buf:gsub("^.-\r\n\r\n", ""); -- This ensure extensions and trailers are stripped + success_cb(packet); + elseif #buf - chunk_start + 2 >= chunk_size then -- we have a chunk + packet.body = packet.body..buf:sub(chunk_start, chunk_start + (chunk_size-1)); + buf = buf:sub(chunk_start + chunk_size + 2); + chunk_size, chunk_start = nil, nil; + else -- Partial chunk remaining + break; + end + elseif len and #buf >= len then + packet.body, buf = buf:sub(1, len), buf:sub(len + 1); + state = nil; success_cb(packet); + else + break; + end + elseif #buf >= len then + packet.body, buf = buf:sub(1, len), buf:sub(len + 1); + state = nil; success_cb(packet); + else + break; + end + end + end + end; + }; +end + +return httpstream; diff --git a/net/http/server.lua b/net/http/server.lua new file mode 100644 index 00000000..dec7da19 --- /dev/null +++ b/net/http/server.lua @@ -0,0 +1,303 @@ + +local t_insert, t_remove, t_concat = table.insert, table.remove, table.concat; +local parser_new = require "net.http.parser".new; +local events = require "util.events".new(); +local addserver = require "net.server".addserver; +local log = require "util.logger".init("http.server"); +local os_date = os.date; +local pairs = pairs; +local s_upper = string.upper; +local setmetatable = setmetatable; +local xpcall = xpcall; +local traceback = debug.traceback; +local tostring = tostring; +local codes = require "net.http.codes"; + +local _M = {}; + +local sessions = {}; +local listener = {}; +local hosts = {}; +local default_host; + +local function is_wildcard_event(event) + return event:sub(-2, -1) == "/*"; +end +local function is_wildcard_match(wildcard_event, event) + return wildcard_event:sub(1, -2) == event:sub(1, #wildcard_event-1); +end + +local recent_wildcard_events, max_cached_wildcard_events = {}, 10000; + +local event_map = events._event_map; +setmetatable(events._handlers, { + -- Called when firing an event that doesn't exist (but may match a wildcard handler) + __index = function (handlers, curr_event) + if is_wildcard_event(curr_event) then return; end -- Wildcard events cannot be fired + -- Find all handlers that could match this event, sort them + -- and then put the array into handlers[curr_event] (and return it) + local matching_handlers_set = {}; + local handlers_array = {}; + for event, handlers_set in pairs(event_map) do + if event == curr_event or + is_wildcard_event(event) and is_wildcard_match(event, curr_event) then + for handler, priority in pairs(handlers_set) do + matching_handlers_set[handler] = { (select(2, event:gsub("/", "%1"))), is_wildcard_event(event) and 0 or 1, priority }; + table.insert(handlers_array, handler); + end + end + end + if #handlers_array > 0 then + table.sort(handlers_array, function(b, a) + local a_score, b_score = matching_handlers_set[a], matching_handlers_set[b]; + for i = 1, #a_score do + if a_score[i] ~= b_score[i] then -- If equal, compare next score value + return a_score[i] < b_score[i]; + end + end + return false; + end); + else + handlers_array = false; + end + rawset(handlers, curr_event, handlers_array); + if not event_map[curr_event] then -- Only wildcard handlers match, if any + table.insert(recent_wildcard_events, curr_event); + if #recent_wildcard_events > max_cached_wildcard_events then + rawset(handlers, table.remove(recent_wildcard_events, 1), nil); + end + end + return handlers_array; + end; + __newindex = function (handlers, curr_event, handlers_array) + if handlers_array == nil + and is_wildcard_event(curr_event) then + -- Invalidate the indexes of all matching events + for event in pairs(handlers) do + if is_wildcard_match(curr_event, event) then + handlers[event] = nil; + end + end + end + rawset(handlers, curr_event, handlers_array); + end; +}); + +local handle_request; +local _1, _2, _3; +local function _handle_request() return handle_request(_1, _2, _3); end + +local last_err; +local function _traceback_handler(err) last_err = err; log("error", "Traceback[httpserver]: %s", traceback(tostring(err), 2)); end +events.add_handler("http-error", function (error) + return "Error processing request: "..codes[error.code]..". Check your error log for more information."; +end, -1); + +function listener.onconnect(conn) + local secure = conn:ssl() and true or nil; + local pending = {}; + local waiting = false; + local function process_next() + if waiting then log("debug", "can't process_next, waiting"); return; end + waiting = true; + while sessions[conn] and #pending > 0 do + local request = t_remove(pending); + --log("debug", "process_next: %s", request.path); + --handle_request(conn, request, process_next); + _1, _2, _3 = conn, request, process_next; + if not xpcall(_handle_request, _traceback_handler) then + conn:write("HTTP/1.0 500 Internal Server Error\r\n\r\n"..events.fire_event("http-error", { code = 500, private_message = last_err })); + conn:close(); + end + end + --log("debug", "ready for more"); + waiting = false; + end + local function success_cb(request) + --log("debug", "success_cb: %s", request.path); + if waiting then + log("error", "http connection handler is not reentrant: %s", request.path); + assert(false, "http connection handler is not reentrant"); + end + request.secure = secure; + t_insert(pending, request); + process_next(); + end + local function error_cb(err) + log("debug", "error_cb: %s", err or "<nil>"); + -- FIXME don't close immediately, wait until we process current stuff + -- FIXME if err, send off a bad-request response + sessions[conn] = nil; + conn:close(); + end + sessions[conn] = parser_new(success_cb, error_cb); +end + +function listener.ondisconnect(conn) + local open_response = conn._http_open_response; + if open_response and open_response.on_destroy then + open_response.finished = true; + open_response:on_destroy(); + end + sessions[conn] = nil; +end + +function listener.onincoming(conn, data) + sessions[conn]:feed(data); +end + +local headerfix = setmetatable({}, { + __index = function(t, k) + local v = "\r\n"..k:gsub("_", "-"):gsub("%f[%w].", s_upper)..": "; + t[k] = v; + return v; + end +}); + +function _M.hijack_response(response, listener) + error("TODO"); +end +function handle_request(conn, request, finish_cb) + --log("debug", "handler: %s", request.path); + local headers = {}; + for k,v in pairs(request.headers) do headers[k:gsub("-", "_")] = v; end + request.headers = headers; + request.conn = conn; + + local date_header = os_date('!%a, %d %b %Y %H:%M:%S GMT'); -- FIXME use + local conn_header = request.headers.connection; + conn_header = conn_header and ","..conn_header:gsub("[ \t]", ""):lower().."," or "" + local httpversion = request.httpversion + local persistent = conn_header:find(",Keep-Alive,", 1, true) + or (httpversion == "1.1" and not conn_header:find(",close,", 1, true)); + + local response_conn_header; + if persistent then + response_conn_header = "Keep-Alive"; + else + response_conn_header = httpversion == "1.1" and "close" or nil + end + + local response = { + request = request; + status_code = 200; + headers = { date = date_header, connection = response_conn_header }; + persistent = persistent; + conn = conn; + send = _M.send_response; + finish_cb = finish_cb; + }; + conn._http_open_response = response; + + local host = (request.headers.host or ""):match("[^:]+"); + + -- Some sanity checking + local err_code, err; + if not request.path then + err_code, err = 400, "Invalid path"; + elseif not hosts[host] then + if hosts[default_host] then + host = default_host; + elseif host then + err_code, err = 404, "Unknown host: "..host; + else + err_code, err = 400, "Missing or invalid 'Host' header"; + end + end + + if err then + response.status_code = err_code; + response:send(events.fire_event("http-error", { code = err_code, message = err })); + return; + end + + local event = request.method.." "..host..request.path:match("[^?]*"); + local payload = { request = request, response = response }; + --log("debug", "Firing event: %s", event); + local result = events.fire_event(event, payload); + if result ~= nil then + if result ~= true then + local body; + local result_type = type(result); + if result_type == "number" then + response.status_code = result; + if result >= 400 then + body = events.fire_event("http-error", { code = result }); + end + elseif result_type == "string" then + body = result; + elseif result_type == "table" then + for k, v in pairs(result) do + if k ~= "headers" then + response[k] = v; + else + for header_name, header_value in pairs(v) do + response.headers[header_name] = header_value; + end + end + end + end + response:send(body); + end + return; + end + + -- if handler not called, return 404 + response.status_code = 404; + response:send(events.fire_event("http-error", { code = 404 })); +end +function _M.send_response(response, body) + if response.finished then return; end + response.finished = true; + response.conn._http_open_response = nil; + + local status_line = "HTTP/"..response.request.httpversion.." "..(response.status or codes[response.status_code]); + local headers = response.headers; + body = body or response.body or ""; + headers.content_length = #body; + + local output = { status_line }; + for k,v in pairs(headers) do + t_insert(output, headerfix[k]..v); + end + t_insert(output, "\r\n\r\n"); + t_insert(output, body); + + response.conn:write(t_concat(output)); + if response.on_destroy then + response:on_destroy(); + response.on_destroy = nil; + end + if response.persistent then + response:finish_cb(); + else + response.conn:close(); + end +end +function _M.add_handler(event, handler, priority) + events.add_handler(event, handler, priority); +end +function _M.remove_handler(event, handler) + events.remove_handler(event, handler); +end + +function _M.listen_on(port, interface, ssl) + addserver(interface or "*", port, listener, "*a", ssl); +end +function _M.add_host(host) + hosts[host] = true; +end +function _M.remove_host(host) + hosts[host] = nil; +end +function _M.set_default_host(host) + default_host = host; +end +function _M.fire_event(event, ...) + return events.fire_event(event, ...); +end + +_M.listener = listener; +_M.codes = codes; +_M._events = events; +return _M; diff --git a/net/httpserver.lua b/net/httpserver.lua new file mode 100644 index 00000000..7d574788 --- /dev/null +++ b/net/httpserver.lua @@ -0,0 +1,15 @@ +-- COMPAT w/pre-0.9 +local log = require "util.logger".init("net.httpserver"); +local traceback = debug.traceback; + +module "httpserver" + +function fail() + log("error", "Attempt to use legacy HTTP API. For more info see http://prosody.im/doc/developers/legacy_http"); + log("error", "Legacy HTTP API usage, %s", traceback("", 2)); +end + +new, new_from_config = fail, fail; +set_default_handler = fail; + +return _M; diff --git a/net/server.lua b/net/server.lua index 40cc6dc8..375e7081 100644 --- a/net/server.lua +++ b/net/server.lua @@ -1,831 +1,84 @@ ---[[
-
- server.lua by blastbeat of the luadch project
-
- re-used here under the MIT/X Consortium License
-
- - this script contains the server loop of the program
- - other scripts can reg a server here
-
-]]--
-
-----------------------------------// DECLARATION //--
-
---// constants //--
-
-local STAT_UNIT = 1 / ( 1024 * 1024 ) -- mb
-
---// lua functions //--
-
-local function use( what ) return _G[ what ] end
-
-local type = use "type"
-local pairs = use "pairs"
-local ipairs = use "ipairs"
-local tostring = use "tostring"
-local collectgarbage = use "collectgarbage"
-
---// lua libs //--
-
-local table = use "table"
-local coroutine = use "coroutine"
-
---// lua lib methods //--
-
-local table_concat = table.concat
-local table_remove = table.remove
-local string_sub = use'string'.sub
-local coroutine_wrap = coroutine.wrap
-local coroutine_yield = coroutine.yield
-local print = print;
-local out_put = function () end --print;
-local out_error = print;
-
---// extern libs //--
-
-local luasec = select(2, pcall(require, "ssl"))
-local luasocket = require "socket"
-
---// extern lib methods //--
-
-local ssl_wrap = ( luasec and luasec.wrap )
-local socket_bind = luasocket.bind
-local socket_select = luasocket.select
-local ssl_newcontext = ( luasec and luasec.newcontext )
-
---// functions //--
-
-local loop
-local stats
-local addtimer
-local closeall
-local addserver
-local firetimer
-local closesocket
-local removesocket
-local wrapserver
-local wraptcpclient
-local wrapsslclient
-
---// tables //--
-
-local listener
-local readlist
-local writelist
-local socketlist
-local timelistener
-
---// simple data types //--
-
-local _
-local readlen = 0 -- length of readlist
-local writelen = 0 -- lenght of writelist
-
-local sendstat= 0
-local receivestat = 0
-
-----------------------------------// DEFINITION //--
-
-listener = { } -- key = port, value = table
-readlist = { } -- array with sockets to read from
-writelist = { } -- arrary with sockets to write to
-socketlist = { } -- key = socket, value = wrapped socket
-timelistener = { }
-
-stats = function( )
- return receivestat, sendstat
-end
-
-wrapserver = function( listener, socket, ip, serverport, mode, sslctx ) -- this function wraps a server
-
- local dispatch, disconnect = listener.listener, listener.disconnect -- dangerous
-
- local wrapclient, err
-
- if sslctx then
- if not ssl_newcontext then
- return nil, "luasec not found"
- end
- if type( sslctx ) ~= "table" then
- out_error "server.lua: wrong server sslctx"
- return nil, "wrong server sslctx"
- end
- sslctx, err = ssl_newcontext( sslctx )
- if not sslctx then
- err = err or "wrong sslctx parameters"
- out_error( "server.lua: ", err )
- return nil, err
- end
- wrapclient = wrapsslclient
- wrapclient = wraptlsclient
- else
- wrapclient = wraptcpclient
- end
-
- local accept = socket.accept
- local close = socket.close
-
- --// public methods of the object //--
-
- local handler = { }
-
- handler.shutdown = function( ) end
-
- --[[handler.listener = function( data, err )
- return ondata( handler, data, err )
- end]]
- handler.ssl = function( )
- return sslctx and true or false
- end
- handler.close = function( closed )
- _ = not closed and close( socket )
- writelen = removesocket( writelist, socket, writelen )
- readlen = removesocket( readlist, socket, readlen )
- socketlist[ socket ] = nil
- handler = nil
- end
- handler.ip = function( )
- return ip
- end
- handler.serverport = function( )
- return serverport
- end
- handler.socket = function( )
- return socket
- end
- handler.receivedata = function( )
- local client, err = accept( socket ) -- try to accept
- if client then
- local ip, clientport = client:getpeername( )
- client:settimeout( 0 )
- local handler, client, err = wrapclient( listener, client, ip, serverport, clientport, mode, sslctx ) -- wrap new client socket
- if err then -- error while wrapping ssl socket
- return false
- end
- out_put( "server.lua: accepted new client connection from ", ip, ":", clientport )
- return dispatch( handler )
- elseif err then -- maybe timeout or something else
- out_put( "server.lua: error with new client connection: ", err )
- return false
- end
- end
- return handler
-end
-
-wrapsslclient = function( listener, socket, ip, serverport, clientport, mode, sslctx ) -- this function wraps a ssl cleint
-
- local dispatch, disconnect = listener.listener, listener.disconnect
-
- --// transform socket to ssl object //--
-
- local err
- socket, err = ssl_wrap( socket, sslctx ) -- wrap socket
- if err then
- out_put( "server.lua: ssl error: ", err )
- return nil, nil, err -- fatal error
- end
- socket:settimeout( 0 )
-
- --// private closures of the object //--
-
- local writequeue = { } -- buffer for messages to send
-
- local eol -- end of buffer
-
- local sstat, rstat = 0, 0
-
- --// local import of socket methods //--
-
- local send = socket.send
- local receive = socket.receive
- local close = socket.close
- --local shutdown = socket.shutdown
-
- --// public methods of the object //--
-
- local handler = { }
-
- handler.getstats = function( )
- return rstat, sstat
- end
-
- handler.listener = function( data, err )
- return listener( handler, data, err )
- end
- handler.ssl = function( )
- return true
- end
- handler.send = function( _, data, i, j )
- return send( socket, data, i, j )
- end
- handler.receive = function( pattern, prefix )
- return receive( socket, pattern, prefix )
- end
- handler.shutdown = function( pattern )
- --return shutdown( socket, pattern )
- end
- handler.close = function( closed )
- close( socket )
- writelen = ( eol and removesocket( writelist, socket, writelen ) ) or writelen
- readlen = removesocket( readlist, socket, readlen )
- socketlist[ socket ] = nil
- out_put "server.lua: closed handler and removed socket from list"
- end
- handler.ip = function( )
- return ip
- end
- handler.serverport = function( )
- return serverport
- end
- handler.clientport = function( )
- return clientport
- end
-
- handler.write = function( data )
- if not eol then
- writelen = writelen + 1
- writelist[ writelen ] = socket
- eol = 0
- end
- eol = eol + 1
- writequeue[ eol ] = data
- end
- handler.writequeue = function( )
- return writequeue
- end
- handler.socket = function( )
- return socket
- end
- handler.mode = function( )
- return mode
- end
- handler._receivedata = function( )
- local data, err, part = receive( socket, mode ) -- receive data in "mode"
- if not err or ( err == "timeout" or err == "wantread" ) then -- received something
- local data = data or part or ""
- local count = #data * STAT_UNIT
- rstat = rstat + count
- receivestat = receivestat + count
- out_put( "server.lua: read data '", data, "', error: ", err )
- return dispatch( handler, data, err )
- else -- connections was closed or fatal error
- out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )
- handler.close( )
- disconnect( handler, err )
- writequeue = nil
- handler = nil
- return false
- end
- end
- handler._dispatchdata = function( ) -- this function writes data to handlers
- local buffer = table_concat( writequeue, "", 1, eol )
- local succ, err, byte = send( socket, buffer )
- local count = ( succ or 0 ) * STAT_UNIT
- sstat = sstat + count
- sendstat = sendstat + count
- out_put( "server.lua: sended '", buffer, "', bytes: ", succ, ", error: ", err, ", part: ", byte, ", to: ", ip, ":", clientport )
- if succ then -- sending succesful
- --writequeue = { }
- eol = nil
- writelen = removesocket( writelist, socket, writelen ) -- delete socket from writelist
- return true
- elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write
- buffer = string_sub( buffer, byte + 1, -1 ) -- new buffer
- writequeue[ 1 ] = buffer -- insert new buffer in queue
- eol = 1
- return true
- else -- connection was closed during sending or fatal error
- out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )
- handler.close( )
- disconnect( handler, err )
- writequeue = nil
- handler = nil
- return false
- end
- end
-
- -- // COMPAT // --
-
- handler.getIp = handler.ip
- handler.getPort = handler.clientport
-
- --// handshake //--
-
- local wrote
-
- handler.handshake = coroutine_wrap( function( client )
- local err
- for i = 1, 10 do -- 10 handshake attemps
- _, err = client:dohandshake( )
- if not err then
- out_put( "server.lua: ssl handshake done" )
- writelen = ( wrote and removesocket( writelist, socket, writelen ) ) or writelen
- handler.receivedata = handler._receivedata -- when handshake is done, replace the handshake function with regular functions
- handler.dispatchdata = handler._dispatchdata
- return dispatch( handler )
- else
- out_put( "server.lua: error during ssl handshake: ", err )
- if err == "wantwrite" then
- if wrote == nil then
- writelen = writelen + 1
- writelist[ writelen ] = client
- wrote = true
- end
- end
- coroutine_yield( handler, nil, err ) -- handshake not finished
- end
- end
- _ = err ~= "closed" and close( socket )
- handler.close( )
- disconnect( handler, err )
- writequeue = nil
- handler = nil
- return false -- handshake failed
- end
- )
- handler.receivedata = handler.handshake
- handler.dispatchdata = handler.handshake
-
- handler.handshake( socket ) -- do handshake
-
- socketlist[ socket ] = handler
- readlen = readlen + 1
- readlist[ readlen ] = socket
-
- return handler, socket
-end
-
-wraptlsclient = function( listener, socket, ip, serverport, clientport, mode, sslctx ) -- this function wraps a tls cleint
-
- local dispatch, disconnect = listener.listener, listener.disconnect
-
- --// transform socket to ssl object //--
-
- local err
-
- socket:settimeout( 0 )
-
- --// private closures of the object //--
-
- local writequeue = { } -- buffer for messages to send
-
- local eol -- end of buffer
-
- local sstat, rstat = 0, 0
-
- --// local import of socket methods //--
-
- local send = socket.send
- local receive = socket.receive
- local close = socket.close
- --local shutdown = socket.shutdown
-
- --// public methods of the object //--
-
- local handler = { }
-
- handler.getstats = function( )
- return rstat, sstat
- end
-
- handler.listener = function( data, err )
- return listener( handler, data, err )
- end
- handler.ssl = function( )
- return false
- end
- handler.send = function( _, data, i, j )
- return send( socket, data, i, j )
- end
- handler.receive = function( pattern, prefix )
- return receive( socket, pattern, prefix )
- end
- handler.shutdown = function( pattern )
- --return shutdown( socket, pattern )
- end
- handler.close = function( closed )
- close( socket )
- writelen = ( eol and removesocket( writelist, socket, writelen ) ) or writelen
- readlen = removesocket( readlist, socket, readlen )
- socketlist[ socket ] = nil
- out_put "server.lua: closed handler and removed socket from list"
- end
- handler.ip = function( )
- return ip
- end
- handler.serverport = function( )
- return serverport
- end
- handler.clientport = function( )
- return clientport
- end
-
- handler.write = function( data )
- if not eol then
- writelen = writelen + 1
- writelist[ writelen ] = socket
- eol = 0
- end
- eol = eol + 1
- writequeue[ eol ] = data
- end
- handler.writequeue = function( )
- return writequeue
- end
- handler.socket = function( )
- return socket
- end
- handler.mode = function( )
- return mode
- end
- handler._receivedata = function( )
- local data, err, part = receive( socket, mode ) -- receive data in "mode"
- if not err or ( err == "timeout" or err == "wantread" ) then -- received something
- local data = data or part or ""
- local count = #data * STAT_UNIT
- rstat = rstat + count
- receivestat = receivestat + count
- --out_put( "server.lua: read data '", data, "', error: ", err )
- return dispatch( handler, data, err )
- else -- connections was closed or fatal error
- out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )
- handler.close( )
- disconnect( handler, err )
- writequeue = nil
- handler = nil
- return false
- end
- end
- handler._dispatchdata = function( ) -- this function writes data to handlers
- local buffer = table_concat( writequeue, "", 1, eol )
- local succ, err, byte = send( socket, buffer )
- local count = ( succ or 0 ) * STAT_UNIT
- sstat = sstat + count
- sendstat = sendstat + count
- out_put( "server.lua: sended '", buffer, "', bytes: ", succ, ", error: ", err, ", part: ", byte, ", to: ", ip, ":", clientport )
- if succ then -- sending succesful
- --writequeue = { }
- eol = nil
- writelen = removesocket( writelist, socket, writelen ) -- delete socket from writelist
- if handler.need_tls then
- out_put("server.lua: connection is ready for tls handshake");
- handler.starttls(true);
- if handler.need_tls then
- out_put("server.lua: uh-oh... we still want tls, something must be wrong");
- end
- end
- return true
- elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write
- buffer = string_sub( buffer, byte + 1, -1 ) -- new buffer
- writequeue[ 1 ] = buffer -- insert new buffer in queue
- eol = 1
- return true
- else -- connection was closed during sending or fatal error
- out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )
- handler.close( )
- disconnect( handler, err )
- writequeue = nil
- handler = nil
- return false
- end
- end
-
- handler.receivedata, handler.dispatchdata = handler._receivedata, handler._dispatchdata;
- -- // COMPAT // --
-
- handler.getIp = handler.ip
- handler.getPort = handler.clientport
-
- --// handshake //--
-
- local wrote, read
-
- handler.starttls = function (now)
- if not now then out_put("server.lua: we need to do tls, but delaying until later"); handler.need_tls = true; return; end
- out_put( "server.lua: attempting to start tls on "..tostring(socket) )
- socket, err = ssl_wrap( socket, sslctx ) -- wrap socket
- out_put("sslwrapped socket is "..tostring(socket));
- if err then
- out_put( "server.lua: ssl error: ", err )
- return nil, nil, err -- fatal error
- end
- socket:settimeout( 1 )
- send = socket.send
- receive = socket.receive
- close = socket.close
- handler.ssl = function( )
- return true
- end
- handler.send = function( _, data, i, j )
- return send( socket, data, i, j )
- end
- handler.receive = function( pattern, prefix )
- return receive( socket, pattern, prefix )
- end
-
- handler.handshake = coroutine_wrap( function( client )
- local err
- for i = 1, 10 do -- 10 handshake attemps
- _, err = client:dohandshake( )
- if not err then
- out_put( "server.lua: ssl handshake done" )
- writelen = ( wrote and removesocket( writelist, socket, writelen ) ) or writelen
- handler.receivedata = handler._receivedata -- when handshake is done, replace the handshake function with regular functions
- handler.dispatchdata = handler._dispatchdata
- handler.need_tls = nil
- socketlist[ client ] = handler
- readlen = readlen + 1
- readlist[ readlen ] = client
- return true;
- else
- out_put( "server.lua: error during ssl handshake: ", err )
- if err == "wantwrite" then
- if wrote == nil then
- writelen = writelen + 1
- writelist[ writelen ] = client
- wrote = true
- end
- end
- coroutine_yield( handler, nil, err ) -- handshake not finished
- end
- end
- _ = err ~= "closed" and close( socket )
- handler.close( )
- disconnect( handler, err )
- writequeue = nil
- handler = nil
- return false -- handshake failed
- end
- )
- handler.receivedata = handler.handshake
- handler.dispatchdata = handler.handshake
-
- handler.handshake( socket ) -- do handshake
- end
- socketlist[ socket ] = handler
- readlen = readlen + 1
- readlist[ readlen ] = socket
-
- return handler, socket
-end
-
-wraptcpclient = function( listener, socket, ip, serverport, clientport, mode ) -- this function wraps a socket
-
- local dispatch, disconnect = listener.listener, listener.disconnect
-
- --// private closures of the object //--
-
- local writequeue = { } -- list for messages to send
-
- local eol
-
- local rstat, sstat = 0, 0
-
- --// local import of socket methods //--
-
- local send = socket.send
- local receive = socket.receive
- local close = socket.close
- local shutdown = socket.shutdown
-
- --// public methods of the object //--
-
- local handler = { }
-
- handler.getstats = function( )
- return rstat, sstat
- end
-
- handler.listener = function( data, err )
- return listener( handler, data, err )
- end
- handler.ssl = function( )
- return false
- end
- handler.send = function( _, data, i, j )
- return send( socket, data, i, j )
- end
- handler.receive = function( pattern, prefix )
- return receive( socket, pattern, prefix )
- end
- handler.shutdown = function( pattern )
- return shutdown( socket, pattern )
- end
- handler.close = function( closed )
- _ = not closed and shutdown( socket )
- _ = not closed and close( socket )
- writelen = ( eol and removesocket( writelist, socket, writelen ) ) or writelen
- readlen = removesocket( readlist, socket, readlen )
- socketlist[ socket ] = nil
- out_put "server.lua: closed handler and removed socket from list"
- end
- handler.ip = function( )
- return ip
- end
- handler.serverport = function( )
- return serverport
- end
- handler.clientport = function( )
- return clientport
- end
- handler.write = function( data )
- if not eol then
- writelen = writelen + 1
- writelist[ writelen ] = socket
- eol = 0
- end
- eol = eol + 1
- writequeue[ eol ] = data
- end
- handler.writequeue = function( )
- return writequeue
- end
- handler.socket = function( )
- return socket
- end
- handler.mode = function( )
- return mode
- end
-
- handler.receivedata = function( )
- local data, err, part = receive( socket, mode ) -- receive data in "mode"
- if not err or ( err == "timeout" or err == "wantread" ) then -- received something
- local data = data or part or ""
- local count = #data * STAT_UNIT
- rstat = rstat + count
- receivestat = receivestat + count
- out_put( "server.lua: read data '", data, "', error: ", err )
- return dispatch( handler, data, err )
- else -- connections was closed or fatal error
- out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )
- handler.close( )
- disconnect( handler, err )
- writequeue = nil
- handler = nil
- return false
- end
- end
-
- handler.dispatchdata = function( ) -- this function writes data to handlers
- local buffer = table_concat( writequeue, "", 1, eol )
- local succ, err, byte = send( socket, buffer )
- local count = ( succ or 0 ) * STAT_UNIT
- sstat = sstat + count
- sendstat = sendstat + count
- out_put( "server.lua: sended '", buffer, "', bytes: ", succ, ", error: ", err, ", part: ", byte, ", to: ", ip, ":", clientport )
- if succ then -- sending succesful
- --writequeue = { }
- eol = nil
- writelen = removesocket( writelist, socket, writelen ) -- delete socket from writelist
- return true
- elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write
- buffer = string_sub( buffer, byte + 1, -1 ) -- new buffer
- writequeue[ 1 ] = buffer -- insert new buffer in queue
- eol = 1
- return true
- else -- connection was closed during sending or fatal error
- out_put( "server.lua: client ", ip, ":", clientport, " error: ", err )
- handler.close( )
- disconnect( handler, err )
- writequeue = nil
- handler = nil
- return false
- end
- end
-
- -- // COMPAT // --
-
- handler.getIp = handler.ip
- handler.getPort = handler.clientport
-
- socketlist[ socket ] = handler
- readlen = readlen + 1
- readlist[ readlen ] = socket
-
- return handler, socket
-end
-
-addtimer = function( listener )
- timelistener[ #timelistener + 1 ] = listener
-end
-
-firetimer = function( listener )
- for i, listener in ipairs( timelistener ) do
- listener( )
- end
-end
-
-addserver = function( listeners, port, addr, mode, sslctx ) -- this function provides a way for other scripts to reg a server
- local err
- if type( listeners ) ~= "table" then
- err = "invalid listener table"
- else
- for name, func in pairs( listeners ) do
- if type( func ) ~= "function" then
- --err = "invalid listener function"
- break
- end
- end
- end
- if not type( port ) == "number" or not ( port >= 0 and port <= 65535 ) then
- err = "invalid port"
- elseif listener[ port ] then
- err= "listeners on port '" .. port .. "' already exist"
- elseif sslctx and not luasec then
- err = "luasec not found"
- end
- if err then
- out_error( "server.lua: ", err )
- return nil, err
- end
- addr = addr or "*"
- local server, err = socket_bind( addr, port )
- if err then
- out_error( "server.lua: ", err )
- return nil, err
- end
- local handler, err = wrapserver( listeners, server, addr, port, mode, sslctx ) -- wrap new server socket
- if not handler then
- server:close( )
- return nil, err
- end
- server:settimeout( 0 )
- readlen = readlen + 1
- readlist[ readlen ] = server
- listener[ port ] = listeners
- socketlist[ server ] = handler
- out_put( "server.lua: new server listener on ", addr, ":", port )
- return true
-end
-
-removesocket = function( tbl, socket, len ) -- this function removes sockets from a list
- for i, target in ipairs( tbl ) do
- if target == socket then
- len = len - 1
- table_remove( tbl, i )
- return len
- end
- end
- return len
-end
-
-closeall = function( )
- for sock, handler in pairs( socketlist ) do
- handler.shutdown( )
- handler.close( )
- socketlist[ sock ] = nil
- end
- writelist, readlist, socketlist = { }, { }, { }
-end
-
-closesocket = function( socket )
- writelen = removesocket( writelist, socket, writelen )
- readlen = removesocket( readlist, socket, readlen )
- socketlist[ socket ] = nil
- socket:close( )
-end
-
-loop = function( ) -- this is the main loop of the program
- --signal_set( "hub", "run" )
- repeat
- --[[print(readlen, writelen)
- for _, s in ipairs(readlist) do print("R:", tostring(s)) end
- for _, s in ipairs(writelist) do print("W:", tostring(s)) end
- out_put("select()"..os.time())]]
- local read, write, err = socket_select( readlist, writelist, 1 ) -- 1 sec timeout, nice for timers
- for i, socket in ipairs( write ) do -- send data waiting in writequeues
- local handler = socketlist[ socket ]
- if handler then
- handler.dispatchdata( )
- else
- closesocket( socket )
- out_put "server.lua: found no handler and closed socket (writelist)" -- this should not happen
- end
- end
- for i, socket in ipairs( read ) do -- receive data
- local handler = socketlist[ socket ]
- if handler then
- handler.receivedata( )
- else
- closesocket( socket )
- out_put "server.lua: found no handler and closed socket (readlist)" -- this can happen
- end
- end
- firetimer( )
- until false
- return
-end
-
-----------------------------------// BEGIN //--
-
-----------------------------------// PUBLIC INTERFACE //--
-
-return {
-
- add = addserver,
- loop = loop,
- stats = stats,
- closeall = closeall,
- addtimer = addtimer,
- wraptlsclient = wraptlsclient,
-}
+-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local use_luaevent = prosody and require "core.configmanager".get("*", "use_libevent"); + +if use_luaevent then + use_luaevent = pcall(require, "luaevent.core"); + if not use_luaevent then + log("error", "libevent not found, falling back to select()"); + end +end + +local server; + +if use_luaevent then + server = require "net.server_event"; + + -- Overwrite signal.signal() because we need to ask libevent to + -- handle them instead + local ok, signal = pcall(require, "util.signal"); + if ok and signal then + local _signal_signal = signal.signal; + function signal.signal(signal_id, handler) + if type(signal_id) == "string" then + signal_id = signal[signal_id:upper()]; + end + if type(signal_id) ~= "number" then + return false, "invalid-signal"; + end + return server.hook_signal(signal_id, handler); + end + end +else + use_luaevent = false; + server = require "net.server_select"; +end + +if prosody then + local config_get = require "core.configmanager".get; + local defaults = {}; + for k,v in pairs(server.cfg or server.getsettings()) do + defaults[k] = v; + end + local function load_config() + local settings = config_get("*", "network_settings") or {}; + if use_luaevent then + local event_settings = { + ACCEPT_DELAY = settings.event_accept_retry_interval; + ACCEPT_QUEUE = settings.tcp_backlog; + CLEAR_DELAY = settings.event_clear_interval; + CONNECT_TIMEOUT = settings.connect_timeout; + DEBUG = settings.debug; + HANDSHAKE_TIMEOUT = settings.ssl_handshake_timeout; + MAX_CONNECTIONS = settings.max_connections; + MAX_HANDSHAKE_ATTEMPTS = settings.max_ssl_handshake_roundtrips; + MAX_READ_LENGTH = settings.max_receive_buffer_size; + MAX_SEND_LENGTH = settings.max_send_buffer_size; + READ_TIMEOUT = settings.read_timeout; + WRITE_TIMEOUT = settings.send_timeout; + }; + + for k,default in pairs(defaults) do + server.cfg[k] = event_settings[k] or default; + end + else + local select_settings = {}; + for k,default in pairs(defaults) do + select_settings[k] = settings[k] or default; + end + server.changesettings(select_settings); + end + end + load_config(); + prosody.events.add_handler("config-reloaded", load_config); +end + +-- require "net.server" shall now forever return this, +-- ie. server_select or server_event as chosen above. +return server; diff --git a/net/server_event.lua b/net/server_event.lua new file mode 100644 index 00000000..5eae95a9 --- /dev/null +++ b/net/server_event.lua @@ -0,0 +1,872 @@ +--[[ + + + server.lua based on lua/libevent by blastbeat + + notes: + -- when using luaevent, never register 2 or more EV_READ at one socket, same for EV_WRITE + -- you cant even register a new EV_READ/EV_WRITE callback inside another one + -- to do some of the above, use timeout events or something what will called from outside + -- dont let garbagecollect eventcallbacks, as long they are running + -- when using luasec, there are 4 cases of timeout errors: wantread or wantwrite during reading or writing + +--]] + +local SCRIPT_NAME = "server_event.lua" +local SCRIPT_VERSION = "0.05" +local SCRIPT_AUTHOR = "blastbeat" +local LAST_MODIFIED = "2009/11/20" + +local cfg = { + MAX_CONNECTIONS = 100000, -- max per server connections (use "ulimit -n" on *nix) + MAX_HANDSHAKE_ATTEMPTS= 1000, -- attempts to finish ssl handshake + HANDSHAKE_TIMEOUT = 60, -- timeout in seconds per handshake attempt + MAX_READ_LENGTH = 1024 * 1024 * 1024 * 1024, -- max bytes allowed to read from sockets + MAX_SEND_LENGTH = 1024 * 1024 * 1024 * 1024, -- max bytes size of write buffer (for writing on sockets) + ACCEPT_QUEUE = 128, -- might influence the length of the pending sockets queue + ACCEPT_DELAY = 10, -- seconds to wait until the next attempt of a full server to accept + READ_TIMEOUT = 60 * 60 * 6, -- timeout in seconds for read data from socket + WRITE_TIMEOUT = 180, -- timeout in seconds for write data on socket + CONNECT_TIMEOUT = 20, -- timeout in seconds for connection attempts + CLEAR_DELAY = 5, -- seconds to wait for clearing interface list (and calling ondisconnect listeners) + DEBUG = true, -- show debug messages +} + +local function use(x) return rawget(_G, x); end +local ipairs = use "ipairs" +local string = use "string" +local select = use "select" +local require = use "require" +local tostring = use "tostring" +local coroutine = use "coroutine" +local setmetatable = use "setmetatable" + +local t_insert = table.insert +local t_concat = table.concat + +local ssl = use "ssl" +local socket = use "socket" or require "socket" + +local log = require ("util.logger").init("socket") + +local function debug(...) + return log("debug", ("%s "):rep(select('#', ...)), ...) +end +local vdebug = debug; + +local bitor = ( function( ) -- thx Rici Lake + local hasbit = function( x, p ) + return x % ( p + p ) >= p + end + return function( x, y ) + local p = 1 + local z = 0 + local limit = x > y and x or y + while p <= limit do + if hasbit( x, p ) or hasbit( y, p ) then + z = z + p + end + p = p + p + end + return z + end +end )( ) + +local event = require "luaevent.core" +local base = event.new( ) +local EV_READ = event.EV_READ +local EV_WRITE = event.EV_WRITE +local EV_TIMEOUT = event.EV_TIMEOUT +local EV_SIGNAL = event.EV_SIGNAL + +local EV_READWRITE = bitor( EV_READ, EV_WRITE ) + +local interfacelist = ( function( ) -- holds the interfaces for sockets + local array = { } + local len = 0 + return function( method, arg ) + if "add" == method then + len = len + 1 + array[ len ] = arg + arg:_position( len ) + return len + elseif "delete" == method then + if len <= 0 then + return nil, "array is already empty" + end + local position = arg:_position() -- get position in array + if position ~= len then + local interface = array[ len ] -- get last interface + array[ position ] = interface -- copy it into free position + array[ len ] = nil -- free last position + interface:_position( position ) -- set new position in array + else -- free last position + array[ len ] = nil + end + len = len - 1 + return len + else + return array + end + end +end )( ) + +-- Client interface methods +local interface_mt +do + interface_mt = {}; interface_mt.__index = interface_mt; + + local addevent = base.addevent + local coroutine_wrap, coroutine_yield = coroutine.wrap,coroutine.yield + + -- Private methods + function interface_mt:_position(new_position) + self.position = new_position or self.position + return self.position; + end + function interface_mt:_close() + return self:_destroy(); + end + + function interface_mt:_start_connection(plainssl) -- should be called from addclient + local callback = function( event ) + if EV_TIMEOUT == event then -- timeout during connection + self.fatalerror = "connection timeout" + self:ontimeout() -- call timeout listener + self:_close() + debug( "new connection failed. id:", self.id, "error:", self.fatalerror ) + else + if plainssl and ssl then -- start ssl session + self:starttls(self._sslctx, true) + else -- normal connection + self:_start_session(true) + end + debug( "new connection established. id:", self.id ) + end + self.eventconnect = nil + return -1 + end + self.eventconnect = addevent( base, self.conn, EV_WRITE, callback, cfg.CONNECT_TIMEOUT ) + return true + end + function interface_mt:_start_session(call_onconnect) -- new session, for example after startssl + if self.type == "client" then + local callback = function( ) + self:_lock( false, false, false ) + --vdebug( "start listening on client socket with id:", self.id ) + self.eventread = addevent( base, self.conn, EV_READ, self.readcallback, cfg.READ_TIMEOUT ); -- register callback + if call_onconnect then + self:onconnect() + end + self.eventsession = nil + return -1 + end + self.eventsession = addevent( base, nil, EV_TIMEOUT, callback, 0 ) + else + self:_lock( false ) + --vdebug( "start listening on server socket with id:", self.id ) + self.eventread = addevent( base, self.conn, EV_READ, self.readcallback ) -- register callback + end + return true + end + function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed, therefore we have to close read/write events first + --vdebug( "starting ssl session with client id:", self.id ) + local _ + _ = self.eventread and self.eventread:close( ) -- close events; this must be called outside of the event callbacks! + _ = self.eventwrite and self.eventwrite:close( ) + self.eventread, self.eventwrite = nil, nil + local err + self.conn, err = ssl.wrap( self.conn, self._sslctx ) + if err then + self.fatalerror = err + self.conn = nil -- cannot be used anymore + if call_onconnect then + self.ondisconnect = nil -- dont call this when client isnt really connected + end + self:_close() + debug( "fatal error while ssl wrapping:", err ) + return false + end + self.conn:settimeout( 0 ) -- set non blocking + local handshakecallback = coroutine_wrap( + function( event ) + local _, err + local attempt = 0 + local maxattempt = cfg.MAX_HANDSHAKE_ATTEMPTS + while attempt < maxattempt do -- no endless loop + attempt = attempt + 1 + debug( "ssl handshake of client with id:"..tostring(self)..", attempt:"..attempt ) + if attempt > maxattempt then + self.fatalerror = "max handshake attempts exceeded" + elseif EV_TIMEOUT == event then + self.fatalerror = "timeout during handshake" + else + _, err = self.conn:dohandshake( ) + if not err then + self:_lock( false, false, false ) -- unlock the interface; sending, closing etc allowed + self.send = self.conn.send -- caching table lookups with new client object + self.receive = self.conn.receive + if not call_onconnect then -- trigger listener + self:onstatus("ssl-handshake-complete"); + end + self:_start_session( call_onconnect ) + debug( "ssl handshake done" ) + self.eventhandshake = nil + return -1 + end + if err == "wantwrite" then + event = EV_WRITE + elseif err == "wantread" then + event = EV_READ + else + debug( "ssl handshake error:", err ) + self.fatalerror = err + end + end + if self.fatalerror then + if call_onconnect then + self.ondisconnect = nil -- dont call this when client isnt really connected + end + self:_close() + debug( "handshake failed because:", self.fatalerror ) + self.eventhandshake = nil + return -1 + end + event = coroutine_yield( event, cfg.HANDSHAKE_TIMEOUT ) -- yield this monster... + end + end + ) + debug "starting handshake..." + self:_lock( false, true, true ) -- unlock read/write events, but keep interface locked + self.eventhandshake = addevent( base, self.conn, EV_READWRITE, handshakecallback, cfg.HANDSHAKE_TIMEOUT ) + return true + end + function interface_mt:_destroy() -- close this interface + events and call last listener + debug( "closing client with id:", self.id, self.fatalerror ) + self:_lock( true, true, true ) -- first of all, lock the interface to avoid further actions + local _ + _ = self.eventread and self.eventread:close( ) + if self.type == "client" then + _ = self.eventwrite and self.eventwrite:close( ) + _ = self.eventhandshake and self.eventhandshake:close( ) + _ = self.eventstarthandshake and self.eventstarthandshake:close( ) + _ = self.eventconnect and self.eventconnect:close( ) + _ = self.eventsession and self.eventsession:close( ) + _ = self.eventwritetimeout and self.eventwritetimeout:close( ) + _ = self.eventreadtimeout and self.eventreadtimeout:close( ) + _ = self.ondisconnect and self:ondisconnect( self.fatalerror ~= "client to close" and self.fatalerror) -- call ondisconnect listener (wont be the case if handshake failed on connect) + _ = self.conn and self.conn:close( ) -- close connection + _ = self._server and self._server:counter(-1); + self.eventread, self.eventwrite = nil, nil + self.eventstarthandshake, self.eventhandshake, self.eventclose = nil, nil, nil + self.readcallback, self.writecallback = nil, nil + else + self.conn:close( ) + self.eventread, self.eventclose = nil, nil + self.interface, self.readcallback = nil, nil + end + interfacelist( "delete", self ) + return true + end + + function interface_mt:_lock(nointerface, noreading, nowriting) -- lock or unlock this interface or events + self.nointerface, self.noreading, self.nowriting = nointerface, noreading, nowriting + return nointerface, noreading, nowriting + end + + --TODO: Deprecate + function interface_mt:lock_read(switch) + if switch then + return self:pause(); + else + return self:resume(); + end + end + + function interface_mt:pause() + return self:_lock(self.nointerface, true, self.nowriting); + end + + function interface_mt:resume() + self:_lock(self.nointerface, false, self.nowriting); + if not self.eventread then + self.eventread = addevent( base, self.conn, EV_READ, self.readcallback, cfg.READ_TIMEOUT ); -- register callback + end + end + + function interface_mt:counter(c) + if c then + self._connections = self._connections + c + end + return self._connections + end + + -- Public methods + function interface_mt:write(data) + if self.nowriting then return nil, "locked" end + --vdebug( "try to send data to client, id/data:", self.id, data ) + data = tostring( data ) + local len = #data + local total = len + self.writebufferlen + if total > cfg.MAX_SEND_LENGTH then -- check buffer length + local err = "send buffer exceeded" + debug( "error:", err ) -- to much, check your app + return nil, err + end + t_insert(self.writebuffer, data) -- new buffer + self.writebufferlen = total + if not self.eventwrite then -- register new write event + --vdebug( "register new write event" ) + self.eventwrite = addevent( base, self.conn, EV_WRITE, self.writecallback, cfg.WRITE_TIMEOUT ) + end + return true + end + function interface_mt:close() + if self.nointerface then return nil, "locked"; end + debug( "try to close client connection with id:", self.id ) + if self.type == "client" then + self.fatalerror = "client to close" + if self.eventwrite then -- wait for incomplete write request + self:_lock( true, true, false ) + debug "closing delayed until writebuffer is empty" + return nil, "writebuffer not empty, waiting" + else -- close now + self:_lock( true, true, true ) + self:_close() + return true + end + else + debug( "try to close server with id:", tostring(self.id)) + self.fatalerror = "server to close" + self:_lock( true ) + self:_close( 0 ) + return true + end + end + + function interface_mt:socket() + return self.conn + end + + function interface_mt:server() + return self._server or self; + end + + function interface_mt:port() + return self._port + end + + function interface_mt:serverport() + return self._serverport + end + + function interface_mt:ip() + return self._ip + end + + function interface_mt:ssl() + return self._usingssl + end + + function interface_mt:type() + return self._type or "client" + end + + function interface_mt:connections() + return self._connections + end + + function interface_mt:address() + return self.addr + end + + function interface_mt:set_sslctx(sslctx) + self._sslctx = sslctx; + if sslctx then + self.starttls = nil; -- use starttls() of interface_mt + else + self.starttls = false; -- prevent starttls() + end + end + + function interface_mt:set_mode(pattern) + if pattern then + self._pattern = pattern; + end + return self._pattern; + end + + function interface_mt:set_send(new_send) + -- No-op, we always use the underlying connection's send + end + + function interface_mt:starttls(sslctx, call_onconnect) + debug( "try to start ssl at client id:", self.id ) + local err + self._sslctx = sslctx; + if self._usingssl then -- startssl was already called + err = "ssl already active" + end + if err then + debug( "error:", err ) + return nil, err + end + self._usingssl = true + self.startsslcallback = function( ) -- we have to start the handshake outside of a read/write event + self.startsslcallback = nil + self:_start_ssl(call_onconnect); + self.eventstarthandshake = nil + return -1 + end + if not self.eventwrite then + self:_lock( true, true, true ) -- lock the interface, to not disturb the handshake + self.eventstarthandshake = addevent( base, nil, EV_TIMEOUT, self.startsslcallback, 0 ) -- add event to start handshake + else -- wait until writebuffer is empty + self:_lock( true, true, false ) + debug "ssl session delayed until writebuffer is empty..." + end + self.starttls = false; + return true + end + + function interface_mt:setoption(option, value) + if self.conn.setoption then + return self.conn:setoption(option, value); + end + return false, "setoption not implemented"; + end + + function interface_mt:setlistener(listener) + self.onconnect, self.ondisconnect, self.onincoming, self.ontimeout, self.onstatus + = listener.onconnect, listener.ondisconnect, listener.onincoming, listener.ontimeout, listener.onstatus; + end + + -- Stub handlers + function interface_mt:onconnect() + end + function interface_mt:onincoming() + end + function interface_mt:ondisconnect() + end + function interface_mt:ontimeout() + end + function interface_mt:ondrain() + end + function interface_mt:onstatus() + end +end + +-- End of client interface methods + +local handleclient; +do + local string_sub = string.sub -- caching table lookups + local addevent = base.addevent + local socket_gettime = socket.gettime + function handleclient( client, ip, port, server, pattern, listener, sslctx ) -- creates an client interface + --vdebug("creating client interfacce...") + local interface = { + type = "client"; + conn = client; + currenttime = socket_gettime( ); -- safe the origin + writebuffer = {}; -- writebuffer + writebufferlen = 0; -- length of writebuffer + send = client.send; -- caching table lookups + receive = client.receive; + onconnect = listener.onconnect; -- will be called when client disconnects + ondisconnect = listener.ondisconnect; -- will be called when client disconnects + onincoming = listener.onincoming; -- will be called when client sends data + ontimeout = listener.ontimeout; -- called when fatal socket timeout occurs + onstatus = listener.onstatus; -- called for status changes (e.g. of SSL/TLS) + eventread = false, eventwrite = false, eventclose = false, + eventhandshake = false, eventstarthandshake = false; -- event handler + eventconnect = false, eventsession = false; -- more event handler... + eventwritetimeout = false; -- even more event handler... + eventreadtimeout = false; + fatalerror = false; -- error message + writecallback = false; -- will be called on write events + readcallback = false; -- will be called on read events + nointerface = true; -- lock/unlock parameter of this interface + noreading = false, nowriting = false; -- locks of the read/writecallback + startsslcallback = false; -- starting handshake callback + position = false; -- position of client in interfacelist + + -- Properties + _ip = ip, _port = port, _server = server, _pattern = pattern, + _serverport = (server and server:port() or nil), + _sslctx = sslctx; -- parameters + _usingssl = false; -- client is using ssl; + } + if not ssl then interface.starttls = false; end + interface.id = tostring(interface):match("%x+$"); + interface.writecallback = function( event ) -- called on write events + --vdebug( "new client write event, id/ip/port:", interface, ip, port ) + if interface.nowriting or ( interface.fatalerror and ( "client to close" ~= interface.fatalerror ) ) then -- leave this event + --vdebug( "leaving this event because:", interface.nowriting or interface.fatalerror ) + interface.eventwrite = false + return -1 + end + if EV_TIMEOUT == event then -- took too long to write some data to socket -> disconnect + interface.fatalerror = "timeout during writing" + debug( "writing failed:", interface.fatalerror ) + interface:_close() + interface.eventwrite = false + return -1 + else -- can write :) + if interface._usingssl then -- handle luasec + if interface.eventreadtimeout then -- we have to read first + local ret = interface.readcallback( ) -- call readcallback + --vdebug( "tried to read in writecallback, result:", ret ) + end + if interface.eventwritetimeout then -- luasec only + interface.eventwritetimeout:close( ) -- first we have to close timeout event which where regged after a wantread error + interface.eventwritetimeout = false + end + end + interface.writebuffer = { t_concat(interface.writebuffer) } + local succ, err, byte = interface.conn:send( interface.writebuffer[1], 1, interface.writebufferlen ) + --vdebug( "write data:", interface.writebuffer, "error:", err, "part:", byte ) + if succ then -- writing succesful + interface.writebuffer[1] = nil + interface.writebufferlen = 0 + interface:ondrain(); + if interface.fatalerror then + debug "closing client after writing" + interface:_close() -- close interface if needed + elseif interface.startsslcallback then -- start ssl connection if needed + debug "starting ssl handshake after writing" + interface.eventstarthandshake = addevent( base, nil, EV_TIMEOUT, interface.startsslcallback, 0 ) + elseif interface.eventreadtimeout then + return EV_WRITE, EV_TIMEOUT + end + interface.eventwrite = nil + return -1 + elseif byte and (err == "timeout" or err == "wantwrite") then -- want write again + --vdebug( "writebuffer is not empty:", err ) + interface.writebuffer[1] = string_sub( interface.writebuffer[1], byte + 1, interface.writebufferlen ) -- new buffer + interface.writebufferlen = interface.writebufferlen - byte + if "wantread" == err then -- happens only with luasec + local callback = function( ) + interface:_close() + interface.eventwritetimeout = nil + return -1; + end + interface.eventwritetimeout = addevent( base, nil, EV_TIMEOUT, callback, cfg.WRITE_TIMEOUT ) -- reg a new timeout event + debug( "wantread during write attempt, reg it in readcallback but dont know what really happens next..." ) + -- hopefully this works with luasec; its simply not possible to use 2 different write events on a socket in luaevent + return -1 + end + return EV_WRITE, cfg.WRITE_TIMEOUT + else -- connection was closed during writing or fatal error + interface.fatalerror = err or "fatal error" + debug( "connection failed in write event:", interface.fatalerror ) + interface:_close() + interface.eventwrite = nil + return -1 + end + end + end + + interface.readcallback = function( event ) -- called on read events + --vdebug( "new client read event, id/ip/port:", tostring(interface.id), tostring(ip), tostring(port) ) + if interface.noreading or interface.fatalerror then -- leave this event + --vdebug( "leaving this event because:", tostring(interface.noreading or interface.fatalerror) ) + interface.eventread = nil + return -1 + end + if EV_TIMEOUT == event then -- took too long to get some data from client -> disconnect + interface.fatalerror = "timeout during receiving" + debug( "connection failed:", interface.fatalerror ) + interface:_close() + interface.eventread = nil + return -1 + else -- can read + if interface._usingssl then -- handle luasec + if interface.eventwritetimeout then -- ok, in the past writecallback was regged + local ret = interface.writecallback( ) -- call it + --vdebug( "tried to write in readcallback, result:", tostring(ret) ) + end + if interface.eventreadtimeout then + interface.eventreadtimeout:close( ) + interface.eventreadtimeout = nil + end + end + local buffer, err, part = interface.conn:receive( interface._pattern ) -- receive buffer with "pattern" + --vdebug( "read data:", tostring(buffer), "error:", tostring(err), "part:", tostring(part) ) + buffer = buffer or part + if buffer and #buffer > cfg.MAX_READ_LENGTH then -- check buffer length + interface.fatalerror = "receive buffer exceeded" + debug( "fatal error:", interface.fatalerror ) + interface:_close() + interface.eventread = nil + return -1 + end + if err and ( err ~= "timeout" and err ~= "wantread" ) then + if "wantwrite" == err then -- need to read on write event + if not interface.eventwrite then -- register new write event if needed + interface.eventwrite = addevent( base, interface.conn, EV_WRITE, interface.writecallback, cfg.WRITE_TIMEOUT ) + end + interface.eventreadtimeout = addevent( base, nil, EV_TIMEOUT, + function( ) + interface:_close() + end, cfg.READ_TIMEOUT + ) + debug( "wantwrite during read attempt, reg it in writecallback but dont know what really happens next..." ) + -- to be honest i dont know what happens next, if it is allowed to first read, the write etc... + else -- connection was closed or fatal error + interface.fatalerror = err + debug( "connection failed in read event:", interface.fatalerror ) + interface:_close() + interface.eventread = nil + return -1 + end + else + interface.onincoming( interface, buffer, err ) -- send new data to listener + end + if interface.noreading then + interface.eventread = nil; + return -1; + end + return EV_READ, cfg.READ_TIMEOUT + end + end + + client:settimeout( 0 ) -- set non blocking + setmetatable(interface, interface_mt) + interfacelist( "add", interface ) -- add to interfacelist + return interface + end +end + +local handleserver +do + function handleserver( server, addr, port, pattern, listener, sslctx ) -- creates an server interface + debug "creating server interface..." + local interface = { + _connections = 0; + + conn = server; + onconnect = listener.onconnect; -- will be called when new client connected + eventread = false; -- read event handler + eventclose = false; -- close event handler + readcallback = false; -- read event callback + fatalerror = false; -- error message + nointerface = true; -- lock/unlock parameter + + _ip = addr, _port = port, _pattern = pattern, + _sslctx = sslctx; + } + interface.id = tostring(interface):match("%x+$"); + interface.readcallback = function( event ) -- server handler, called on incoming connections + --vdebug( "server can accept, id/addr/port:", interface, addr, port ) + if interface.fatalerror then + --vdebug( "leaving this event because:", self.fatalerror ) + interface.eventread = nil + return -1 + end + local delay = cfg.ACCEPT_DELAY + if EV_TIMEOUT == event then + if interface._connections >= cfg.MAX_CONNECTIONS then -- check connection count + debug( "to many connections, seconds to wait for next accept:", delay ) + return EV_TIMEOUT, delay -- timeout... + else + return EV_READ -- accept again + end + end + --vdebug("max connection check ok, accepting...") + local client, err = server:accept() -- try to accept; TODO: check err + while client do + if interface._connections >= cfg.MAX_CONNECTIONS then + client:close( ) -- refuse connection + debug( "maximal connections reached, refuse client connection; accept delay:", delay ) + return EV_TIMEOUT, delay -- delay for next accept attempt + end + local client_ip, client_port = client:getpeername( ) + interface._connections = interface._connections + 1 -- increase connection count + local clientinterface = handleclient( client, client_ip, client_port, interface, pattern, listener, sslctx ) + --vdebug( "client id:", clientinterface, "startssl:", startssl ) + if ssl and sslctx then + clientinterface:starttls(sslctx, true) + else + clientinterface:_start_session( true ) + end + debug( "accepted incoming client connection from:", client_ip or "<unknown IP>", client_port or "<unknown port>", "to", port or "<unknown port>"); + + client, err = server:accept() -- try to accept again + end + return EV_READ + end + + server:settimeout( 0 ) + setmetatable(interface, interface_mt) + interfacelist( "add", interface ) + interface:_start_session() + return interface + end +end + +local addserver = ( function( ) + return function( addr, port, listener, pattern, sslcfg, startssl ) -- TODO: check arguments + --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslcfg or "nil", startssl or "nil") + local server, err = socket.bind( addr, port, cfg.ACCEPT_QUEUE ) -- create server socket + if not server then + debug( "creating server socket on "..addr.." port "..port.." failed:", err ) + return nil, err + end + local sslctx + if sslcfg then + if not ssl then + debug "fatal error: luasec not found" + return nil, "luasec not found" + end + sslctx, err = sslcfg + if err then + debug( "error while creating new ssl context for server socket:", err ) + return nil, err + end + end + local interface = handleserver( server, addr, port, pattern, listener, sslctx, startssl ) -- new server handler + debug( "new server created with id:", tostring(interface)) + return interface + end +end )( ) + +local addclient, wrapclient +do + function wrapclient( client, ip, port, listeners, pattern, sslctx ) + local interface = handleclient( client, ip, port, nil, pattern, listeners, sslctx ) + interface:_start_connection(sslctx) + return interface, client + --function handleclient( client, ip, port, server, pattern, listener, _, sslctx ) -- creates an client interface + end + + function addclient( addr, serverport, listener, pattern, localaddr, localport, sslcfg, startssl ) + local client, err = socket.tcp() -- creating new socket + if not client then + debug( "cannot create socket:", err ) + return nil, err + end + client:settimeout( 0 ) -- set nonblocking + if localaddr then + local res, err = client:bind( localaddr, localport, -1 ) + if not res then + debug( "cannot bind client:", err ) + return nil, err + end + end + local sslctx + if sslcfg then -- handle ssl/new context + if not ssl then + debug "need luasec, but not available" + return nil, "luasec not found" + end + sslctx, err = sslcfg + if err then + debug( "cannot create new ssl context:", err ) + return nil, err + end + end + local res, err = client:connect( addr, serverport ) -- connect + if res or ( err == "timeout" ) then + local ip, port = client:getsockname( ) + local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx, startssl ) + interface:_start_connection( startssl ) + debug( "new connection id:", interface.id ) + return interface, err + else + debug( "new connection failed:", err ) + return nil, err + end + end +end + + +local loop = function( ) -- starts the event loop + base:loop( ) + return "quitting"; +end + +local newevent = ( function( ) + local add = base.addevent + return function( ... ) + return add( base, ... ) + end +end )( ) + +local closeallservers = function( arg ) + for _, item in ipairs( interfacelist( ) ) do + if item.type == "server" then + item:close( arg ) + end + end +end + +local function setquitting(yes) + if yes then + -- Quit now + closeallservers(); + base:loopexit(); + end +end + +local function get_backend() + return base:method(); +end + +-- We need to hold onto the events to stop them +-- being garbage-collected +local signal_events = {}; -- [signal_num] -> event object +local function hook_signal(signal_num, handler) + local function _handler(event) + local ret = handler(); + if ret ~= false then -- Continue handling this signal? + return EV_SIGNAL; -- Yes + end + return -1; -- Close this event + end + signal_events[signal_num] = base:addevent(signal_num, EV_SIGNAL, _handler); + return signal_events[signal_num]; +end + +local function link(sender, receiver, buffersize) + local sender_locked; + + function receiver:ondrain() + if sender_locked then + sender:resume(); + sender_locked = nil; + end + end + + function sender:onincoming(data) + receiver:write(data); + if receiver.writebufferlen >= buffersize then + sender_locked = true; + sender:pause(); + end + end +end + +return { + + cfg = cfg, + base = base, + loop = loop, + link = link, + event = event, + event_base = base, + addevent = newevent, + addserver = addserver, + addclient = addclient, + wrapclient = wrapclient, + setquitting = setquitting, + closeall = closeallservers, + get_backend = get_backend, + hook_signal = hook_signal, + + __NAME = SCRIPT_NAME, + __DATE = LAST_MODIFIED, + __AUTHOR = SCRIPT_AUTHOR, + __VERSION = SCRIPT_VERSION, + +} diff --git a/net/server_select.lua b/net/server_select.lua new file mode 100644 index 00000000..7eb330a8 --- /dev/null +++ b/net/server_select.lua @@ -0,0 +1,984 @@ +-- +-- server.lua by blastbeat of the luadch project +-- Re-used here under the MIT/X Consortium License +-- +-- Modifications (C) 2008-2010 Matthew Wild, Waqas Hussain +-- + +-- // wrapping luadch stuff // -- + +local use = function( what ) + return _G[ what ] +end + +local log, table_concat = require ("util.logger").init("socket"), table.concat; +local out_put = function (...) return log("debug", table_concat{...}); end +local out_error = function (...) return log("warn", table_concat{...}); end + +----------------------------------// DECLARATION //-- + +--// constants //-- + +local STAT_UNIT = 1 -- byte + +--// lua functions //-- + +local type = use "type" +local pairs = use "pairs" +local ipairs = use "ipairs" +local tonumber = use "tonumber" +local tostring = use "tostring" + +--// lua libs //-- + +local os = use "os" +local table = use "table" +local string = use "string" +local coroutine = use "coroutine" + +--// lua lib methods //-- + +local os_difftime = os.difftime +local math_min = math.min +local math_huge = math.huge +local table_concat = table.concat +local string_sub = string.sub +local coroutine_wrap = coroutine.wrap +local coroutine_yield = coroutine.yield + +--// extern libs //-- + +local luasec = use "ssl" +local luasocket = use "socket" or require "socket" +local luasocket_gettime = luasocket.gettime + +--// extern lib methods //-- + +local ssl_wrap = ( luasec and luasec.wrap ) +local socket_bind = luasocket.bind +local socket_sleep = luasocket.sleep +local socket_select = luasocket.select + +--// functions //-- + +local id +local loop +local stats +local idfalse +local closeall +local addsocket +local addserver +local addtimer +local getserver +local wrapserver +local getsettings +local closesocket +local removesocket +local removeserver +local wrapconnection +local changesettings + +--// tables //-- + +local _server +local _readlist +local _timerlist +local _sendlist +local _socketlist +local _closelist +local _readtimes +local _writetimes + +--// simple data types //-- + +local _ +local _readlistlen +local _sendlistlen +local _timerlistlen + +local _sendtraffic +local _readtraffic + +local _selecttimeout +local _sleeptime +local _tcpbacklog + +local _starttime +local _currenttime + +local _maxsendlen +local _maxreadlen + +local _checkinterval +local _sendtimeout +local _readtimeout + +local _timer + +local _maxselectlen +local _maxfd + +local _maxsslhandshake + +----------------------------------// DEFINITION //-- + +_server = { } -- key = port, value = table; list of listening servers +_readlist = { } -- array with sockets to read from +_sendlist = { } -- arrary with sockets to write to +_timerlist = { } -- array of timer functions +_socketlist = { } -- key = socket, value = wrapped socket (handlers) +_readtimes = { } -- key = handler, value = timestamp of last data reading +_writetimes = { } -- key = handler, value = timestamp of last data writing/sending +_closelist = { } -- handlers to close + +_readlistlen = 0 -- length of readlist +_sendlistlen = 0 -- length of sendlist +_timerlistlen = 0 -- lenght of timerlist + +_sendtraffic = 0 -- some stats +_readtraffic = 0 + +_selecttimeout = 1 -- timeout of socket.select +_sleeptime = 0 -- time to wait at the end of every loop +_tcpbacklog = 128 -- some kind of hint to the OS + +_maxsendlen = 51000 * 1024 -- max len of send buffer +_maxreadlen = 25000 * 1024 -- max len of read buffer + +_checkinterval = 1200000 -- interval in secs to check idle clients +_sendtimeout = 60000 -- allowed send idle time in secs +_readtimeout = 6 * 60 * 60 -- allowed read idle time in secs + +local is_windows = package.config:sub(1,1) == "\\" -- check the directory separator, to detemine whether this is Windows +_maxfd = luasocket._SETSIZE or (is_windows and math.huge) or 1024 -- max fd number, limit to 1024 by default to prevent glibc buffer overflow, but not on Windows +_maxselectlen = luasocket._SETSIZE or 1024 -- But this still applies on Windows + +_maxsslhandshake = 30 -- max handshake round-trips + +----------------------------------// PRIVATE //-- + +wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- this function wraps a server -- FIXME Make sure FD < _maxfd + + if socket:getfd() >= _maxfd then + out_error("server.lua: Disallowed FD number: "..socket:getfd()) + socket:close() + return nil, "fd-too-large" + end + + local connections = 0 + + local dispatch, disconnect = listeners.onconnect, listeners.ondisconnect + + local accept = socket.accept + + --// public methods of the object //-- + + local handler = { } + + handler.shutdown = function( ) end + + handler.ssl = function( ) + return sslctx ~= nil + end + handler.sslctx = function( ) + return sslctx + end + handler.remove = function( ) + connections = connections - 1 + if handler then + handler.resume( ) + end + end + handler.close = function() + socket:close( ) + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) + _readlistlen = removesocket( _readlist, socket, _readlistlen ) + _server[ip..":"..serverport] = nil; + _socketlist[ socket ] = nil + handler = nil + socket = nil + --mem_free( ) + out_put "server.lua: closed server handler and removed sockets from list" + end + handler.pause = function( hard ) + if not handler.paused then + _readlistlen = removesocket( _readlist, socket, _readlistlen ) + if hard then + _socketlist[ socket ] = nil + socket:close( ) + socket = nil; + end + handler.paused = true; + end + end + handler.resume = function( ) + if handler.paused then + if not socket then + socket = socket_bind( ip, serverport, _tcpbacklog ); + socket:settimeout( 0 ) + end + _readlistlen = addsocket(_readlist, socket, _readlistlen) + _socketlist[ socket ] = handler + handler.paused = false; + end + end + handler.ip = function( ) + return ip + end + handler.serverport = function( ) + return serverport + end + handler.socket = function( ) + return socket + end + handler.readbuffer = function( ) + if _readlistlen >= _maxselectlen or _sendlistlen >= _maxselectlen then + handler.pause( ) + out_put( "server.lua: refused new client connection: server full" ) + return false + end + local client, err = accept( socket ) -- try to accept + if client then + local ip, clientport = client:getpeername( ) + local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx ) -- wrap new client socket + if err then -- error while wrapping ssl socket + return false + end + connections = connections + 1 + out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport)) + if dispatch and not sslctx then -- SSL connections will notify onconnect when handshake completes + return dispatch( handler ); + end + return; + elseif err then -- maybe timeout or something else + out_put( "server.lua: error with new client connection: ", tostring(err) ) + return false + end + end + return handler +end + +wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx ) -- this function wraps a client to a handler object + + if socket:getfd() >= _maxfd then + out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent + socket:close( ) -- Should we send some kind of error here? + server.pause( ) + return nil, nil, "fd-too-large" + end + socket:settimeout( 0 ) + + --// local import of socket methods //-- + + local send + local receive + local shutdown + + --// private closures of the object //-- + + local ssl + + local dispatch = listeners.onincoming + local status = listeners.onstatus + local disconnect = listeners.ondisconnect + local drain = listeners.ondrain + + local bufferqueue = { } -- buffer array + local bufferqueuelen = 0 -- end of buffer array + + local toclose + local fatalerror + local needtls + + local bufferlen = 0 + + local noread = false + local nosend = false + + local sendtraffic, readtraffic = 0, 0 + + local maxsendlen = _maxsendlen + local maxreadlen = _maxreadlen + + --// public methods of the object //-- + + local handler = bufferqueue -- saves a table ^_^ + + handler.dispatch = function( ) + return dispatch + end + handler.disconnect = function( ) + return disconnect + end + handler.setlistener = function( self, listeners ) + dispatch = listeners.onincoming + disconnect = listeners.ondisconnect + status = listeners.onstatus + drain = listeners.ondrain + end + handler.getstats = function( ) + return readtraffic, sendtraffic + end + handler.ssl = function( ) + return ssl + end + handler.sslctx = function ( ) + return sslctx + end + handler.send = function( _, data, i, j ) + return send( socket, data, i, j ) + end + handler.receive = function( pattern, prefix ) + return receive( socket, pattern, prefix ) + end + handler.shutdown = function( pattern ) + return shutdown( socket, pattern ) + end + handler.setoption = function (self, option, value) + if socket.setoption then + return socket:setoption(option, value); + end + return false, "setoption not implemented"; + end + handler.force_close = function ( self, err ) + if bufferqueuelen ~= 0 then + out_put("server.lua: discarding unwritten data for ", tostring(ip), ":", tostring(clientport)) + bufferqueuelen = 0; + end + return self:close(err); + end + handler.close = function( self, err ) + if not handler then return true; end + _readlistlen = removesocket( _readlist, socket, _readlistlen ) + _readtimes[ handler ] = nil + if bufferqueuelen ~= 0 then + handler.sendbuffer() -- Try now to send any outstanding data + if bufferqueuelen ~= 0 then -- Still not empty, so we'll try again later + if handler then + handler.write = nil -- ... but no further writing allowed + end + toclose = true + return false + end + end + if socket then + _ = shutdown and shutdown( socket ) + socket:close( ) + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) + _socketlist[ socket ] = nil + socket = nil + else + out_put "server.lua: socket already closed" + end + if handler then + _writetimes[ handler ] = nil + _closelist[ handler ] = nil + local _handler = handler; + handler = nil + if disconnect then + disconnect(_handler, err or false); + disconnect = nil + end + end + if server then + server.remove( ) + end + out_put "server.lua: closed client handler and removed socket from list" + return true + end + handler.ip = function( ) + return ip + end + handler.serverport = function( ) + return serverport + end + handler.clientport = function( ) + return clientport + end + local write = function( self, data ) + bufferlen = bufferlen + #data + if bufferlen > maxsendlen then + _closelist[ handler ] = "send buffer exceeded" -- cannot close the client at the moment, have to wait to the end of the cycle + handler.write = idfalse -- dont write anymore + return false + elseif socket and not _sendlist[ socket ] then + _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) + end + bufferqueuelen = bufferqueuelen + 1 + bufferqueue[ bufferqueuelen ] = data + if handler then + _writetimes[ handler ] = _writetimes[ handler ] or _currenttime + end + return true + end + handler.write = write + handler.bufferqueue = function( self ) + return bufferqueue + end + handler.socket = function( self ) + return socket + end + handler.set_mode = function( self, new ) + pattern = new or pattern + return pattern + end + handler.set_send = function ( self, newsend ) + send = newsend or send + return send + end + handler.bufferlen = function( self, readlen, sendlen ) + maxsendlen = sendlen or maxsendlen + maxreadlen = readlen or maxreadlen + return bufferlen, maxreadlen, maxsendlen + end + --TODO: Deprecate + handler.lock_read = function (self, switch) + if switch == true then + local tmp = _readlistlen + _readlistlen = removesocket( _readlist, socket, _readlistlen ) + _readtimes[ handler ] = nil + if _readlistlen ~= tmp then + noread = true + end + elseif switch == false then + if noread then + noread = false + _readlistlen = addsocket(_readlist, socket, _readlistlen) + _readtimes[ handler ] = _currenttime + end + end + return noread + end + handler.pause = function (self) + return self:lock_read(true); + end + handler.resume = function (self) + return self:lock_read(false); + end + handler.lock = function( self, switch ) + handler.lock_read (switch) + if switch == true then + handler.write = idfalse + local tmp = _sendlistlen + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) + _writetimes[ handler ] = nil + if _sendlistlen ~= tmp then + nosend = true + end + elseif switch == false then + handler.write = write + if nosend then + nosend = false + write( "" ) + end + end + return noread, nosend + end + local _readbuffer = function( ) -- this function reads data + local buffer, err, part = receive( socket, pattern ) -- receive buffer with "pattern" + if not err or (err == "wantread" or err == "timeout") then -- received something + local buffer = buffer or part or "" + local len = #buffer + if len > maxreadlen then + handler:close( "receive buffer exceeded" ) + return false + end + local count = len * STAT_UNIT + readtraffic = readtraffic + count + _readtraffic = _readtraffic + count + _readtimes[ handler ] = _currenttime + --out_put( "server.lua: read data '", buffer:gsub("[^%w%p ]", "."), "', error: ", err ) + return dispatch( handler, buffer, err ) + else -- connections was closed or fatal error + out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) ) + fatalerror = true + _ = handler and handler:force_close( err ) + return false + end + end + local _sendbuffer = function( ) -- this function sends data + local succ, err, byte, buffer, count; + if socket then + buffer = table_concat( bufferqueue, "", 1, bufferqueuelen ) + succ, err, byte = send( socket, buffer, 1, bufferlen ) + count = ( succ or byte or 0 ) * STAT_UNIT + sendtraffic = sendtraffic + count + _sendtraffic = _sendtraffic + count + for i = bufferqueuelen,1,-1 do + bufferqueue[ i ] = nil + end + --out_put( "server.lua: sended '", buffer, "', bytes: ", tostring(succ), ", error: ", tostring(err), ", part: ", tostring(byte), ", to: ", tostring(ip), ":", tostring(clientport) ) + else + succ, err, count = false, "unexpected close", 0; + end + if succ then -- sending succesful + bufferqueuelen = 0 + bufferlen = 0 + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) -- delete socket from writelist + _writetimes[ handler ] = nil + if drain then + drain(handler) + end + _ = needtls and handler:starttls(nil) + _ = toclose and handler:force_close( ) + return true + elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write + buffer = string_sub( buffer, byte + 1, bufferlen ) -- new buffer + bufferqueue[ 1 ] = buffer -- insert new buffer in queue + bufferqueuelen = 1 + bufferlen = bufferlen - byte + _writetimes[ handler ] = _currenttime + return true + else -- connection was closed during sending or fatal error + out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) ) + fatalerror = true + _ = handler and handler:force_close( err ) + return false + end + end + + -- Set the sslctx + local handshake; + function handler.set_sslctx(self, new_sslctx) + sslctx = new_sslctx; + local read, wrote + handshake = coroutine_wrap( function( client ) -- create handshake coroutine + local err + for i = 1, _maxsslhandshake do + _sendlistlen = ( wrote and removesocket( _sendlist, client, _sendlistlen ) ) or _sendlistlen + _readlistlen = ( read and removesocket( _readlist, client, _readlistlen ) ) or _readlistlen + read, wrote = nil, nil + _, err = client:dohandshake( ) + if not err then + out_put( "server.lua: ssl handshake done" ) + handler.readbuffer = _readbuffer -- when handshake is done, replace the handshake function with regular functions + handler.sendbuffer = _sendbuffer + _ = status and status( handler, "ssl-handshake-complete" ) + if self.autostart_ssl and listeners.onconnect then + listeners.onconnect(self); + end + _readlistlen = addsocket(_readlist, client, _readlistlen) + return true + else + if err == "wantwrite" then + _sendlistlen = addsocket(_sendlist, client, _sendlistlen) + wrote = true + elseif err == "wantread" then + _readlistlen = addsocket(_readlist, client, _readlistlen) + read = true + else + break; + end + err = nil; + coroutine_yield( ) -- handshake not finished + end + end + out_put( "server.lua: ssl handshake error: ", tostring(err or "handshake too long") ) + _ = handler and handler:force_close("ssl handshake failed") + return false, err -- handshake failed + end + ) + end + if luasec then + handler.starttls = function( self, _sslctx) + if _sslctx then + handler:set_sslctx(_sslctx); + end + if bufferqueuelen > 0 then + out_put "server.lua: we need to do tls, but delaying until send buffer empty" + needtls = true + return + end + out_put( "server.lua: attempting to start tls on " .. tostring( socket ) ) + local oldsocket, err = socket + socket, err = ssl_wrap( socket, sslctx ) -- wrap socket + if not socket then + out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") ) + return nil, err -- fatal error + end + + socket:settimeout( 0 ) + + -- add the new socket to our system + send = socket.send + receive = socket.receive + shutdown = id + _socketlist[ socket ] = handler + _readlistlen = addsocket(_readlist, socket, _readlistlen) + + -- remove traces of the old socket + _readlistlen = removesocket( _readlist, oldsocket, _readlistlen ) + _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen ) + _socketlist[ oldsocket ] = nil + + handler.starttls = nil + needtls = nil + + -- Secure now (if handshake fails connection will close) + ssl = true + + handler.readbuffer = handshake + handler.sendbuffer = handshake + return handshake( socket ) -- do handshake + end + end + + handler.readbuffer = _readbuffer + handler.sendbuffer = _sendbuffer + send = socket.send + receive = socket.receive + shutdown = ( ssl and id ) or socket.shutdown + + _socketlist[ socket ] = handler + _readlistlen = addsocket(_readlist, socket, _readlistlen) + + if sslctx and luasec then + out_put "server.lua: auto-starting ssl negotiation..." + handler.autostart_ssl = true; + local ok, err = handler:starttls(sslctx); + if ok == false then + return nil, nil, err + end + end + + return handler, socket +end + +id = function( ) +end + +idfalse = function( ) + return false +end + +addsocket = function( list, socket, len ) + if not list[ socket ] then + len = len + 1 + list[ len ] = socket + list[ socket ] = len + end + return len; +end + +removesocket = function( list, socket, len ) -- this function removes sockets from a list ( copied from copas ) + local pos = list[ socket ] + if pos then + list[ socket ] = nil + local last = list[ len ] + list[ len ] = nil + if last ~= socket then + list[ last ] = pos + list[ pos ] = last + end + return len - 1 + end + return len +end + +closesocket = function( socket ) + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) + _readlistlen = removesocket( _readlist, socket, _readlistlen ) + _socketlist[ socket ] = nil + socket:close( ) + --mem_free( ) +end + +local function link(sender, receiver, buffersize) + local sender_locked; + local _sendbuffer = receiver.sendbuffer; + function receiver.sendbuffer() + _sendbuffer(); + if sender_locked and receiver.bufferlen() < buffersize then + sender:lock_read(false); -- Unlock now + sender_locked = nil; + end + end + + local _readbuffer = sender.readbuffer; + function sender.readbuffer() + _readbuffer(); + if not sender_locked and receiver.bufferlen() >= buffersize then + sender_locked = true; + sender:lock_read(true); + end + end +end + +----------------------------------// PUBLIC //-- + +addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server + local err + if type( listeners ) ~= "table" then + err = "invalid listener table" + end + if type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then + err = "invalid port" + elseif _server[ addr..":"..port ] then + err = "listeners on '[" .. addr .. "]:" .. port .. "' already exist" + elseif sslctx and not luasec then + err = "luasec not found" + end + if err then + out_error( "server.lua, [", addr, "]:", port, ": ", err ) + return nil, err + end + addr = addr or "*" + local server, err = socket_bind( addr, port, _tcpbacklog ) + if err then + out_error( "server.lua, [", addr, "]:", port, ": ", err ) + return nil, err + end + local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx ) -- wrap new server socket + if not handler then + server:close( ) + return nil, err + end + server:settimeout( 0 ) + _readlistlen = addsocket(_readlist, server, _readlistlen) + _server[ addr..":"..port ] = handler + _socketlist[ server ] = handler + out_put( "server.lua: new "..(sslctx and "ssl " or "").."server listener on '[", addr, "]:", port, "'" ) + return handler +end + +getserver = function ( addr, port ) + return _server[ addr..":"..port ]; +end + +removeserver = function( addr, port ) + local handler = _server[ addr..":"..port ] + if not handler then + return nil, "no server found on '[" .. addr .. "]:" .. tostring( port ) .. "'" + end + handler:close( ) + _server[ addr..":"..port ] = nil + return true +end + +closeall = function( ) + for _, handler in pairs( _socketlist ) do + handler:close( ) + _socketlist[ _ ] = nil + end + _readlistlen = 0 + _sendlistlen = 0 + _timerlistlen = 0 + _server = { } + _readlist = { } + _sendlist = { } + _timerlist = { } + _socketlist = { } + --mem_free( ) +end + +getsettings = function( ) + return { + select_timeout = _selecttimeout; + select_sleep_time = _sleeptime; + tcp_backlog = _tcpbacklog; + max_send_buffer_size = _maxsendlen; + max_receive_buffer_size = _maxreadlen; + select_idle_check_interval = _checkinterval; + send_timeout = _sendtimeout; + read_timeout = _readtimeout; + max_connections = _maxselectlen; + max_ssl_handshake_roundtrips = _maxsslhandshake; + highest_allowed_fd = _maxfd; + } +end + +changesettings = function( new ) + if type( new ) ~= "table" then + return nil, "invalid settings table" + end + _selecttimeout = tonumber( new.select_timeout ) or _selecttimeout + _sleeptime = tonumber( new.select_sleep_time ) or _sleeptime + _maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen + _maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen + _checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval + _tcpbacklog = tonumber( new.tcp_backlog ) or _tcpbacklog + _sendtimeout = tonumber( new.send_timeout ) or _sendtimeout + _readtimeout = tonumber( new.read_timeout ) or _readtimeout + _maxselectlen = new.max_connections or _maxselectlen + _maxsslhandshake = new.max_ssl_handshake_roundtrips or _maxsslhandshake + _maxfd = new.highest_allowed_fd or _maxfd + return true +end + +addtimer = function( listener ) + if type( listener ) ~= "function" then + return nil, "invalid listener function" + end + _timerlistlen = _timerlistlen + 1 + _timerlist[ _timerlistlen ] = listener + return true +end + +stats = function( ) + return _readtraffic, _sendtraffic, _readlistlen, _sendlistlen, _timerlistlen +end + +local quitting; + +local function setquitting(quit) + quitting = not not quit; +end + +loop = function(once) -- this is the main loop of the program + if quitting then return "quitting"; end + if once then quitting = "once"; end + local next_timer_time = math_huge; + repeat + local read, write, err = socket_select( _readlist, _sendlist, math_min(_selecttimeout, next_timer_time) ) + for i, socket in ipairs( write ) do -- send data waiting in writequeues + local handler = _socketlist[ socket ] + if handler then + handler.sendbuffer( ) + else + closesocket( socket ) + out_put "server.lua: found no handler and closed socket (writelist)" -- this should not happen + end + end + for i, socket in ipairs( read ) do -- receive data + local handler = _socketlist[ socket ] + if handler then + handler.readbuffer( ) + else + closesocket( socket ) + out_put "server.lua: found no handler and closed socket (readlist)" -- this can happen + end + end + for handler, err in pairs( _closelist ) do + handler.disconnect( )( handler, err ) + handler:force_close() -- forced disconnect + _closelist[ handler ] = nil; + end + _currenttime = luasocket_gettime( ) + + -- Check for socket timeouts + local difftime = os_difftime( _currenttime - _starttime ) + if difftime > _checkinterval then + _starttime = _currenttime + for handler, timestamp in pairs( _writetimes ) do + if os_difftime( _currenttime - timestamp ) > _sendtimeout then + --_writetimes[ handler ] = nil + handler.disconnect( )( handler, "send timeout" ) + handler:force_close() -- forced disconnect + end + end + for handler, timestamp in pairs( _readtimes ) do + if os_difftime( _currenttime - timestamp ) > _readtimeout then + --_readtimes[ handler ] = nil + handler.disconnect( )( handler, "read timeout" ) + handler:close( ) -- forced disconnect? + end + end + end + + -- Fire timers + if _currenttime - _timer >= math_min(next_timer_time, 1) then + next_timer_time = math_huge; + for i = 1, _timerlistlen do + local t = _timerlist[ i ]( _currenttime ) -- fire timers + if t then next_timer_time = math_min(next_timer_time, t); end + end + _timer = _currenttime + else + next_timer_time = next_timer_time - (_currenttime - _timer); + end + + -- wait some time (0 by default) + socket_sleep( _sleeptime ) + until quitting; + if once and quitting == "once" then quitting = nil; return; end + return "quitting" +end + +local function step() + return loop(true); +end + +local function get_backend() + return "select"; +end + +--// EXPERIMENTAL //-- + +local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx ) + local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx ) + if not handler then return nil, err end + _socketlist[ socket ] = handler + if not sslctx then + _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) + if listeners.onconnect then + -- When socket is writeable, call onconnect + local _sendbuffer = handler.sendbuffer; + handler.sendbuffer = function () + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ); + handler.sendbuffer = _sendbuffer; + listeners.onconnect(handler); + -- If there was data with the incoming packet, handle it now. + if #handler:bufferqueue() > 0 then + return _sendbuffer(); + end + end + end + end + return handler, socket +end + +local addclient = function( address, port, listeners, pattern, sslctx ) + local client, err = luasocket.tcp( ) + if err then + return nil, err + end + client:settimeout( 0 ) + _, err = client:connect( address, port ) + if err then -- try again + local handler = wrapclient( client, address, port, listeners ) + else + wrapconnection( nil, listeners, client, address, port, "clientport", pattern, sslctx ) + end +end + +--// EXPERIMENTAL //-- + +----------------------------------// BEGIN //-- + +use "setmetatable" ( _socketlist, { __mode = "k" } ) +use "setmetatable" ( _readtimes, { __mode = "k" } ) +use "setmetatable" ( _writetimes, { __mode = "k" } ) + +_timer = luasocket_gettime( ) +_starttime = luasocket_gettime( ) + +local function setlogger(new_logger) + local old_logger = log; + if new_logger then + log = new_logger; + end + return old_logger; +end + +----------------------------------// PUBLIC INTERFACE //-- + +return { + _addtimer = addtimer, + + addclient = addclient, + wrapclient = wrapclient, + + loop = loop, + link = link, + step = step, + stats = stats, + closeall = closeall, + addserver = addserver, + getserver = getserver, + setlogger = setlogger, + getsettings = getsettings, + setquitting = setquitting, + removeserver = removeserver, + get_backend = get_backend, + changesettings = changesettings, +} diff --git a/net/xmppclient_listener.lua b/net/xmppclient_listener.lua deleted file mode 100644 index 068557c8..00000000 --- a/net/xmppclient_listener.lua +++ /dev/null @@ -1,87 +0,0 @@ - -local logger = require "logger"; -local lxp = require "lxp" -local init_xmlhandlers = require "core.xmlhandlers" -local sm_new_session = require "core.sessionmanager".new_session; - -local connlisteners_register = require "net.connlisteners".register; - -local t_insert = table.insert; -local t_concat = table.concat; -local t_concatall = function (t, sep) local tt = {}; for _, s in ipairs(t) do t_insert(tt, tostring(s)); end return t_concat(tt, sep); end -local m_random = math.random; -local format = string.format; -local sm_new_session, sm_destroy_session = sessionmanager.new_session, sessionmanager.destroy_session; --import("core.sessionmanager", "new_session", "destroy_session"); -local sm_streamopened = sessionmanager.streamopened; -local st = stanza; - -local sessions = {}; -local xmppclient = { default_port = 5222 }; - --- These are session methods -- - -local function session_reset_stream(session) - -- Reset stream - local parser = lxp.new(init_xmlhandlers(session, sm_streamopened), "|"); - session.parser = parser; - - session.notopen = true; - - function session.data(conn, data) - parser:parse(data); - end - return true; -end - --- End of session methods -- - -function xmppclient.listener(conn, data) - local session = sessions[conn]; - if not session then - session = sm_new_session(conn); - sessions[conn] = session; - - -- Logging functions -- - - local mainlog, log = log; - do - local conn_name = tostring(conn):match("[a-f0-9]+$"); - log = logger.init(conn_name); - end - local print = function (...) log("info", t_concatall({...}, "\t")); end - session.log = log; - - print("Client connected"); - - session.reset_stream = session_reset_stream; - - session_reset_stream(session); -- Initialise, ready for use - - -- TODO: Below function should be session,stanza - and xmlhandlers should use :method() notation to call, - -- this will avoid the useless indirection we have atm - -- (I'm on a mission, no time to fix now) - session.stanza_dispatch = function (stanza) return core_process_stanza(session, stanza); end - - end - if data then - session.data(conn, data); - end -end - -function xmppclient.disconnect(conn) - local session = sessions[conn]; - if session then - if session.last_presence and session.last_presence.attr.type ~= "unavailable" then - local pres = st.presence{ type = "unavailable" }; - if err == "closed" then err = "connection closed"; end - pres:tag("status"):text("Disconnected: "..err); - session.stanza_dispatch(pres); - end - sm_destroy_session(session); - sessions[conn] = nil; - session = nil; - collectgarbage("collect"); - end -end - -connlisteners_register("xmppclient", xmppclient); diff --git a/net/xmppserver_listener.lua b/net/xmppserver_listener.lua deleted file mode 100644 index 7c5b0d9c..00000000 --- a/net/xmppserver_listener.lua +++ /dev/null @@ -1,95 +0,0 @@ - -local logger = require "logger"; -local lxp = require "lxp" -local init_xmlhandlers = require "core.xmlhandlers" -local sm_new_session = require "core.sessionmanager".new_session; -local s2s_new_incoming = require "core.s2smanager".new_incoming; -local s2s_streamopened = require "core.s2smanager".streamopened; - -local connlisteners_register = require "net.connlisteners".register; - -local t_insert = table.insert; -local t_concat = table.concat; -local t_concatall = function (t, sep) local tt = {}; for _, s in ipairs(t) do t_insert(tt, tostring(s)); end return t_concat(tt, sep); end -local m_random = math.random; -local format = string.format; -local sm_new_session, sm_destroy_session = sessionmanager.new_session, sessionmanager.destroy_session; --import("core.sessionmanager", "new_session", "destroy_session"); -local st = stanza; - -local sessions = {}; -local xmppserver = { default_port = 5269 }; - --- These are session methods -- - -local function session_reset_stream(session) - -- Reset stream - local parser = lxp.new(init_xmlhandlers(session, s2s_streamopened), "|"); - session.parser = parser; - - session.notopen = true; - - function session.data(conn, data) - parser:parse(data); - end - return true; -end - --- End of session methods -- - -function xmppserver.listener(conn, data) - local session = sessions[conn]; - if not session then - session = s2s_new_incoming(conn); - sessions[conn] = session; - - -- Logging functions -- - - local mainlog, log = log; - do - local conn_name = "s2sin"..tostring(conn):match("[a-f0-9]+$"); - log = logger.init(conn_name); - end - local print = function (...) log("info", t_concatall({...}, "\t")); end - session.log = log; - - print("Incoming s2s connection"); - - session.reset_stream = session_reset_stream; - - session_reset_stream(session); -- Initialise, ready for use - - -- FIXME: Below function should be session,stanza - and xmlhandlers should use :method() notation to call, - -- this will avoid the useless indirection we have atm - -- (I'm on a mission, no time to fix now) - session.stanza_dispatch = function (stanza) return core_process_stanza(session, stanza); end - - end - if data then - session.data(conn, data); - end -end - -function xmppserver.disconnect(conn) -end - -function xmppserver.register_outgoing(conn, session) - session.direction = "outgoing"; - sessions[conn] = session; - - session.reset_stream = session_reset_stream; - session_reset_stream(session); -- Initialise, ready for use - - -- FIXME: Below function should be session,stanza - and xmlhandlers should use :method() notation to call, - -- this will avoid the useless indirection we have atm - -- (I'm on a mission, no time to fix now) - session.stanza_dispatch = function (stanza) return core_process_stanza(session, stanza); end -end - -connlisteners_register("xmppserver", xmppserver); - - --- We need to perform some initialisation when a connection is created --- We also need to perform that same initialisation at other points (SASL, TLS, ...) - --- ...and we need to handle data --- ...and record all sessions associated with connections
\ No newline at end of file |