aboutsummaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
Diffstat (limited to 'net')
-rw-r--r--net/adns.lua78
-rw-r--r--net/connlisteners.lua68
-rw-r--r--net/dns.lua1574
-rw-r--r--net/http.lua246
-rw-r--r--net/http/codes.lua67
-rw-r--r--net/http/parser.lua160
-rw-r--r--net/http/server.lua303
-rw-r--r--net/httpclient_listener.lua44
-rw-r--r--net/httpserver.lua279
-rw-r--r--net/httpserver_listener.lua46
-rw-r--r--net/server.lua977
-rw-r--r--net/server_event.lua872
-rw-r--r--net/server_select.lua984
-rw-r--r--net/xmppclient_listener.lua152
-rw-r--r--net/xmppcomponent_listener.lua176
-rw-r--r--net/xmppserver_listener.lua174
16 files changed, 3549 insertions, 2651 deletions
diff --git a/net/adns.lua b/net/adns.lua
index 34ef5d77..cd69a627 100644
--- a/net/adns.lua
+++ b/net/adns.lua
@@ -1,6 +1,6 @@
-- Prosody IM
--- Copyright (C) 2008-2009 Matthew Wild
--- Copyright (C) 2008-2009 Waqas Hussain
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
@@ -11,8 +11,11 @@ local dns = require "net.dns";
local log = require "util.logger".init("adns");
+local t_insert, t_remove = table.insert, table.remove;
local coroutine, tostring, pcall = coroutine, tostring, pcall;
+local function dummy_send(sock, data, i, j) return (j-i)+1; end
+
module "adns"
function lookup(handler, qname, qtype, qclass)
@@ -23,41 +26,66 @@ function lookup(handler, qname, qtype, qclass)
return;
end
log("debug", "Records for %s not in cache, sending query (%s)...", qname, tostring(coroutine.running()));
- dns.query(qname, qtype, qclass);
- coroutine.yield({ qclass or "IN", qtype or "A", qname, coroutine.running()}); -- Wait for reply
- log("debug", "Reply for %s (%s)", qname, tostring(coroutine.running()));
- local ok, err = pcall(handler, dns.peek(qname, qtype, qclass));
+ local ok, err = dns.query(qname, qtype, qclass);
+ if ok then
+ coroutine.yield({ qclass or "IN", qtype or "A", qname, coroutine.running()}); -- Wait for reply
+ log("debug", "Reply for %s (%s)", qname, tostring(coroutine.running()));
+ end
+ if ok then
+ ok, err = pcall(handler, dns.peek(qname, qtype, qclass));
+ else
+ log("error", "Error sending DNS query: %s", err);
+ ok, err = pcall(handler, nil, err);
+ end
if not ok then
- log("debug", "Error in DNS response handler: %s", tostring(err));
+ log("error", "Error in DNS response handler: %s", tostring(err));
end
end)(dns.peek(qname, qtype, qclass));
end
-function cancel(handle, call_handler)
+function cancel(handle, call_handler, reason)
log("warn", "Cancelling DNS lookup for %s", tostring(handle[3]));
- dns.cancel(handle);
- if call_handler then
- coroutine.resume(handle[4]);
- end
+ dns.cancel(handle[1], handle[2], handle[3], handle[4], call_handler);
end
-function new_async_socket(sock)
- local newconn = {};
+function new_async_socket(sock, resolver)
+ local peername = "<unknown>";
local listener = {};
- function listener.incoming(conn, data)
- dns.feed(sock, data);
+ local handler = {};
+ function listener.onincoming(conn, data)
+ if data then
+ dns.feed(handler, data);
+ end
+ end
+ function listener.ondisconnect(conn, err)
+ if err then
+ log("warn", "DNS socket for %s disconnected: %s", peername, err);
+ local servers = resolver.server;
+ if resolver.socketset[conn] == resolver.best_server and resolver.best_server == #servers then
+ log("error", "Exhausted all %d configured DNS servers, next lookup will try %s again", #servers, servers[1]);
+ end
+
+ resolver:servfail(conn); -- Let the magic commence
+ end
+ end
+ handler = server.wrapclient(sock, "dns", 53, listener);
+ if not handler then
+ log("warn", "handler is nil");
end
- function listener.disconnect()
+
+ handler.settimeout = function () end
+ handler.setsockname = function (_, ...) return sock:setsockname(...); end
+ handler.setpeername = function (_, ...) peername = (...); local ret = sock:setpeername(...); _:set_send(dummy_send); return ret; end
+ handler.connect = function (_, ...) return sock:connect(...) end
+ --handler.send = function (_, data) _:write(data); return _.sendbuffer and _.sendbuffer(); end
+ handler.send = function (_, data)
+ local getpeername = sock.getpeername;
+ log("debug", "Sending DNS query to %s", (getpeername and getpeername(sock)) or "<unconnected>");
+ return sock:send(data);
end
- newconn.handler, newconn._socket = server.wrapclient(sock, "dns", 53, listener);
- newconn.handler.settimeout = function () end
- newconn.handler.setsockname = function (_, ...) return sock:setsockname(...); end
- newconn.handler.setpeername = function (_, ...) local ret = sock:setpeername(...); _.setsend(sock.send); return ret; end
- newconn.handler.connect = function (_, ...) return sock:connect(...) end
- newconn.handler.send = function (_, data) _.write(data); return _.sendbuffer(); end
- return newconn.handler;
+ return handler;
end
-dns:socket_wrapper_set(new_async_socket);
+dns.socket_wrapper_set(new_async_socket);
return _M;
diff --git a/net/connlisteners.lua b/net/connlisteners.lua
index ebb3cc18..99ddc720 100644
--- a/net/connlisteners.lua
+++ b/net/connlisteners.lua
@@ -1,65 +1,15 @@
--- Prosody IM
--- Copyright (C) 2008-2009 Matthew Wild
--- Copyright (C) 2008-2009 Waqas Hussain
---
--- This project is MIT/X11 licensed. Please see the
--- COPYING file in the source package for more information.
---
+-- COMPAT w/pre-0.9
+local log = require "util.logger".init("net.connlisteners");
+local traceback = debug.traceback;
+module "httpserver"
-
-local listeners_dir = (CFG_SOURCEDIR or ".").."/net/";
-local server = require "net.server";
-local log = require "util.logger".init("connlisteners");
-
-local dofile, pcall, error =
- dofile, pcall, error
-
-module "connlisteners"
-
-local listeners = {};
-
-function register(name, listener)
- if listeners[name] and listeners[name] ~= listener then
- log("debug", "Listener %s is already registered, not registering any more", name);
- return false;
- end
- listeners[name] = listener;
- log("debug", "Registered connection listener %s", name);
- return true;
+function fail()
+ log("error", "Attempt to use legacy connlisteners API. For more info see http://prosody.im/doc/developers/network");
+ log("error", "Legacy connlisteners API usage, %s", traceback("", 2));
end
-function deregister(name)
- listeners[name] = nil;
-end
-
-function get(name)
- local h = listeners[name];
- if not h then
- local ok, ret = pcall(dofile, listeners_dir..name:gsub("[^%w%-]", "_").."_listener.lua");
- if not ok then return nil, ret; end
- h = listeners[name];
- end
- return h;
-end
-
-function start(name, udata)
- local h, err = get(name);
- if not h then
- error("No such connection module: "..name.. (err and (" ("..err..")") or ""), 0);
- end
-
- if udata then
- if (udata.type == "ssl" or udata.type == "tls") and not udata.ssl then
- error("No SSL context supplied for a "..tostring(udata.type):upper().." connection!", 0);
- elseif udata.ssl and udata.type == "tcp" then
- error("SSL context supplied for a TCP connection!", 0);
- end
- end
-
- return server.addserver(h,
- (udata and udata.port) or h.default_port or error("Can't start listener "..name.." because no port was specified, and it has no default port", 0),
- (udata and udata.interface) or h.default_interface or "*", (udata and udata.mode) or h.default_mode or 1, (udata and udata.ssl) or nil, 99999999, udata and udata.type == "ssl");
-end
+register, deregister = fail, fail;
+get, start = fail, fail, epic_fail;
return _M;
diff --git a/net/dns.lua b/net/dns.lua
index 48c08218..c9c51fe8 100644
--- a/net/dns.lua
+++ b/net/dns.lua
@@ -2,8 +2,6 @@
-- This file is included with Prosody IM. It has modifications,
-- which are hereby placed in the public domain.
--- public domain 20080404 lua@ztact.com
-
-- todo: quick (default) header generation
-- todo: nxdomain, error handling
@@ -14,21 +12,65 @@
-- reference: http://tools.ietf.org/html/rfc1876 (LOC)
-require 'socket'
-local ztact = require 'util.ztact'
-local require = require
-
-local coroutine, io, math, socket, string, table =
- coroutine, io, math, socket, string, table
-
-local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack =
- ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack
-
-local get, set = ztact.get, ztact.set
-
+local socket = require "socket";
+local timer = require "util.timer";
+
+local _, windows = pcall(require, "util.windows");
+local is_windows = (_ and windows) or os.getenv("WINDIR");
+
+local coroutine, io, math, string, table =
+ coroutine, io, math, string, table;
+
+local ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack, select, type=
+ ipairs, next, pairs, print, setmetatable, tostring, assert, error, unpack, select, type;
+
+local ztact = { -- public domain 20080404 lua@ztact.com
+ get = function(parent, ...)
+ local len = select('#', ...);
+ for i=1,len do
+ parent = parent[select(i, ...)];
+ if parent == nil then break; end
+ end
+ return parent;
+ end;
+ set = function(parent, ...)
+ local len = select('#', ...);
+ local key, value = select(len-1, ...);
+ local cutpoint, cutkey;
+
+ for i=1,len-2 do
+ local key = select (i, ...)
+ local child = parent[key]
+
+ if value == nil then
+ if child == nil then
+ return;
+ elseif next(child, next(child)) then
+ cutpoint = nil; cutkey = nil;
+ elseif cutpoint == nil then
+ cutpoint = parent; cutkey = key;
+ end
+ elseif child == nil then
+ child = {};
+ parent[key] = child;
+ end
+ parent = child
+ end
+
+ if value == nil and cutpoint then
+ cutpoint[cutkey] = nil;
+ else
+ parent[key] = value;
+ return value;
+ end
+ end;
+};
+local get, set = ztact.get, ztact.set;
+
+local default_timeout = 15;
-------------------------------------------------- module dns
-module ('dns')
+module('dns')
local dns = _M;
@@ -38,826 +80,1000 @@ local dns = _M;
local append = table.insert
-local function highbyte (i) -- - - - - - - - - - - - - - - - - - - highbyte
- return (i-(i%0x100))/0x100
- end
+local function highbyte(i) -- - - - - - - - - - - - - - - - - - - highbyte
+ return (i-(i%0x100))/0x100;
+end
local function augment (t) -- - - - - - - - - - - - - - - - - - - - augment
- local a = {}
- for i,s in pairs (t) do a[i] = s a[s] = s a[string.lower (s)] = s end
- return a
- end
+ local a = {};
+ for i,s in pairs(t) do
+ a[i] = s;
+ a[s] = s;
+ a[string.lower(s)] = s;
+ end
+ return a;
+end
local function encode (t) -- - - - - - - - - - - - - - - - - - - - - encode
- local code = {}
- for i,s in pairs (t) do
- local word = string.char (highbyte (i), i %0x100)
- code[i] = word
- code[s] = word
- code[string.lower (s)] = word
- end
- return code
- end
+ local code = {};
+ for i,s in pairs(t) do
+ local word = string.char(highbyte(i), i%0x100);
+ code[i] = word;
+ code[s] = word;
+ code[string.lower(s)] = word;
+ end
+ return code;
+end
dns.types = {
- 'A', 'NS', 'MD', 'MF', 'CNAME', 'SOA', 'MB', 'MG', 'MR', 'NULL', 'WKS',
- 'PTR', 'HINFO', 'MINFO', 'MX', 'TXT',
- [ 28] = 'AAAA', [ 29] = 'LOC', [ 33] = 'SRV',
- [252] = 'AXFR', [253] = 'MAILB', [254] = 'MAILA', [255] = '*' }
+ 'A', 'NS', 'MD', 'MF', 'CNAME', 'SOA', 'MB', 'MG', 'MR', 'NULL', 'WKS',
+ 'PTR', 'HINFO', 'MINFO', 'MX', 'TXT',
+ [ 28] = 'AAAA', [ 29] = 'LOC', [ 33] = 'SRV',
+ [252] = 'AXFR', [253] = 'MAILB', [254] = 'MAILA', [255] = '*' };
-dns.classes = { 'IN', 'CS', 'CH', 'HS', [255] = '*' }
+dns.classes = { 'IN', 'CS', 'CH', 'HS', [255] = '*' };
-dns.type = augment (dns.types)
-dns.class = augment (dns.classes)
-dns.typecode = encode (dns.types)
-dns.classcode = encode (dns.classes)
+dns.type = augment (dns.types);
+dns.class = augment (dns.classes);
+dns.typecode = encode (dns.types);
+dns.classcode = encode (dns.classes);
-local function standardize (qname, qtype, qclass) -- - - - - - - standardize
- if string.byte (qname, -1) ~= 0x2E then qname = qname..'.' end
- qname = string.lower (qname)
- return qname, dns.type[qtype or 'A'], dns.class[qclass or 'IN']
- end
-
-
-local function prune (rrs, time, soft) -- - - - - - - - - - - - - - - prune
-
- time = time or socket.gettime ()
- for i,rr in pairs (rrs) do
+local function standardize(qname, qtype, qclass) -- - - - - - - standardize
+ if string.byte(qname, -1) ~= 0x2E then qname = qname..'.'; end
+ qname = string.lower(qname);
+ return qname, dns.type[qtype or 'A'], dns.class[qclass or 'IN'];
+end
- if rr.tod then
- -- rr.tod = rr.tod - 50 -- accelerated decripitude
- rr.ttl = math.floor (rr.tod - time)
- if rr.ttl <= 0 then rrs[i] = nil end
- elseif soft == 'soft' then -- What is this? I forget!
- assert (rr.ttl == 0)
- rrs[i] = nil
- end end end
+local function prune(rrs, time, soft) -- - - - - - - - - - - - - - - prune
+ time = time or socket.gettime();
+ for i,rr in pairs(rrs) do
+ if rr.tod then
+ -- rr.tod = rr.tod - 50 -- accelerated decripitude
+ rr.ttl = math.floor(rr.tod - time);
+ if rr.ttl <= 0 then
+ table.remove(rrs, i);
+ return prune(rrs, time, soft); -- Re-iterate
+ end
+ elseif soft == 'soft' then -- What is this? I forget!
+ assert(rr.ttl == 0);
+ rrs[i] = nil;
+ end
+ end
+end
-- metatables & co. ------------------------------------------ metatables & co.
-local resolver = {}
-resolver.__index = resolver
-
+local resolver = {};
+resolver.__index = resolver;
-local SRV_tostring
+resolver.timeout = default_timeout;
+local function default_rr_tostring(rr)
+ local rr_val = rr.type and rr[rr.type:lower()];
+ if type(rr_val) ~= "string" then
+ return "<UNKNOWN RDATA TYPE>";
+ end
+ return rr_val;
+end
-local rr_metatable = {} -- - - - - - - - - - - - - - - - - - - rr_metatable
-function rr_metatable.__tostring (rr)
- local s0 = string.format (
- '%2s %-5s %6i %-28s', rr.class, rr.type, rr.ttl, rr.name )
- local s1 = ''
- if rr.type == 'A' then s1 = ' '..rr.a
- elseif rr.type == 'MX' then
- s1 = string.format (' %2i %s', rr.pref, rr.mx)
- elseif rr.type == 'CNAME' then s1 = ' '..rr.cname
- elseif rr.type == 'LOC' then s1 = ' '..resolver.LOC_tostring (rr)
- elseif rr.type == 'NS' then s1 = ' '..rr.ns
- elseif rr.type == 'SRV' then s1 = ' '..SRV_tostring (rr)
- elseif rr.type == 'TXT' then s1 = ' '..rr.txt
- else s1 = ' <UNKNOWN RDATA TYPE>' end
- return s0..s1
- end
+local special_tostrings = {
+ LOC = resolver.LOC_tostring;
+ MX = function (rr)
+ return string.format('%2i %s', rr.pref, rr.mx);
+ end;
+ SRV = function (rr)
+ local s = rr.srv;
+ return string.format('%5d %5d %5d %s', s.priority, s.weight, s.port, s.target);
+ end;
+};
+
+local rr_metatable = {}; -- - - - - - - - - - - - - - - - - - - rr_metatable
+function rr_metatable.__tostring(rr)
+ local rr_string = (special_tostrings[rr.type] or default_rr_tostring)(rr);
+ return string.format('%2s %-5s %6i %-28s %s', rr.class, rr.type, rr.ttl, rr.name, rr_string);
+end
-local rrs_metatable = {} -- - - - - - - - - - - - - - - - - - rrs_metatable
-function rrs_metatable.__tostring (rrs)
- local t = {}
- for i,rr in pairs (rrs) do append (t, tostring (rr)..'\n') end
- return table.concat (t)
- end
+local rrs_metatable = {}; -- - - - - - - - - - - - - - - - - - rrs_metatable
+function rrs_metatable.__tostring(rrs)
+ local t = {};
+ for i,rr in pairs(rrs) do
+ append(t, tostring(rr)..'\n');
+ end
+ return table.concat(t);
+end
-local cache_metatable = {} -- - - - - - - - - - - - - - - - cache_metatable
-function cache_metatable.__tostring (cache)
- local time = socket.gettime ()
- local t = {}
- for class,types in pairs (cache) do
- for type,names in pairs (types) do
- for name,rrs in pairs (names) do
- prune (rrs, time)
- append (t, tostring (rrs)) end end end
- return table.concat (t)
- end
+local cache_metatable = {}; -- - - - - - - - - - - - - - - - cache_metatable
+function cache_metatable.__tostring(cache)
+ local time = socket.gettime();
+ local t = {};
+ for class,types in pairs(cache) do
+ for type,names in pairs(types) do
+ for name,rrs in pairs(names) do
+ prune(rrs, time);
+ append(t, tostring(rrs));
+ end
+ end
+ end
+ return table.concat(t);
+end
-function resolver:new () -- - - - - - - - - - - - - - - - - - - - - resolver
- local r = { active = {}, cache = {}, unsorted = {} }
- setmetatable (r, resolver)
- setmetatable (r.cache, cache_metatable)
- setmetatable (r.unsorted, { __mode = 'kv' })
- return r
- end
+function resolver:new() -- - - - - - - - - - - - - - - - - - - - - resolver
+ local r = { active = {}, cache = {}, unsorted = {} };
+ setmetatable(r, resolver);
+ setmetatable(r.cache, cache_metatable);
+ setmetatable(r.unsorted, { __mode = 'kv' });
+ return r;
+end
-- packet layer -------------------------------------------------- packet layer
-function dns.random (...) -- - - - - - - - - - - - - - - - - - - dns.random
- math.randomseed (10000*socket.gettime ())
- dns.random = math.random
- return dns.random (...)
- end
-
-
-local function encodeHeader (o) -- - - - - - - - - - - - - - - encodeHeader
+function dns.random(...) -- - - - - - - - - - - - - - - - - - - dns.random
+ math.randomseed(math.floor(10000*socket.gettime()) % 0x100000000);
+ dns.random = math.random;
+ return dns.random(...);
+end
- o = o or {}
- o.id = o.id or -- 16b (random) id
- dns.random (0, 0xffff)
+local function encodeHeader(o) -- - - - - - - - - - - - - - - encodeHeader
+ o = o or {};
+ o.id = o.id or dns.random(0, 0xffff); -- 16b (random) id
- o.rd = o.rd or 1 -- 1b 1 recursion desired
- o.tc = o.tc or 0 -- 1b 1 truncated response
- o.aa = o.aa or 0 -- 1b 1 authoritative response
- o.opcode = o.opcode or 0 -- 4b 0 query
- -- 1 inverse query
+ o.rd = o.rd or 1; -- 1b 1 recursion desired
+ o.tc = o.tc or 0; -- 1b 1 truncated response
+ o.aa = o.aa or 0; -- 1b 1 authoritative response
+ o.opcode = o.opcode or 0; -- 4b 0 query
+ -- 1 inverse query
-- 2 server status request
-- 3-15 reserved
- o.qr = o.qr or 0 -- 1b 0 query, 1 response
+ o.qr = o.qr or 0; -- 1b 0 query, 1 response
- o.rcode = o.rcode or 0 -- 4b 0 no error
+ o.rcode = o.rcode or 0; -- 4b 0 no error
-- 1 format error
-- 2 server failure
-- 3 name error
-- 4 not implemented
-- 5 refused
-- 6-15 reserved
- o.z = o.z or 0 -- 3b 0 resvered
- o.ra = o.ra or 0 -- 1b 1 recursion available
-
- o.qdcount = o.qdcount or 1 -- 16b number of question RRs
- o.ancount = o.ancount or 0 -- 16b number of answers RRs
- o.nscount = o.nscount or 0 -- 16b number of nameservers RRs
- o.arcount = o.arcount or 0 -- 16b number of additional RRs
-
- -- string.char() rounds, so prevent roundup with -0.4999
- local header = string.char (
- highbyte (o.id), o.id %0x100,
- o.rd + 2*o.tc + 4*o.aa + 8*o.opcode + 128*o.qr,
- o.rcode + 16*o.z + 128*o.ra,
- highbyte (o.qdcount), o.qdcount %0x100,
- highbyte (o.ancount), o.ancount %0x100,
- highbyte (o.nscount), o.nscount %0x100,
- highbyte (o.arcount), o.arcount %0x100 )
-
- return header, o.id
- end
-
-
-local function encodeName (name) -- - - - - - - - - - - - - - - - encodeName
- local t = {}
- for part in string.gmatch (name, '[^.]+') do
- append (t, string.char (string.len (part)))
- append (t, part)
- end
- append (t, string.char (0))
- return table.concat (t)
- end
-
-
-local function encodeQuestion (qname, qtype, qclass) -- - - - encodeQuestion
- qname = encodeName (qname)
- qtype = dns.typecode[qtype or 'a']
- qclass = dns.classcode[qclass or 'in']
- return qname..qtype..qclass;
- end
-
-
-function resolver:byte (len) -- - - - - - - - - - - - - - - - - - - - - byte
- len = len or 1
- local offset = self.offset
- local last = offset + len - 1
- if last > #self.packet then
- error (string.format ('out of bounds: %i>%i', last, #self.packet)) end
- self.offset = offset + len
- return string.byte (self.packet, offset, last)
- end
-
-
-function resolver:word () -- - - - - - - - - - - - - - - - - - - - - - word
- local b1, b2 = self:byte (2)
- return 0x100*b1 + b2
- end
+ o.z = o.z or 0; -- 3b 0 resvered
+ o.ra = o.ra or 0; -- 1b 1 recursion available
+
+ o.qdcount = o.qdcount or 1; -- 16b number of question RRs
+ o.ancount = o.ancount or 0; -- 16b number of answers RRs
+ o.nscount = o.nscount or 0; -- 16b number of nameservers RRs
+ o.arcount = o.arcount or 0; -- 16b number of additional RRs
+
+ -- string.char() rounds, so prevent roundup with -0.4999
+ local header = string.char(
+ highbyte(o.id), o.id %0x100,
+ o.rd + 2*o.tc + 4*o.aa + 8*o.opcode + 128*o.qr,
+ o.rcode + 16*o.z + 128*o.ra,
+ highbyte(o.qdcount), o.qdcount %0x100,
+ highbyte(o.ancount), o.ancount %0x100,
+ highbyte(o.nscount), o.nscount %0x100,
+ highbyte(o.arcount), o.arcount %0x100
+ );
+
+ return header, o.id;
+end
+
+
+local function encodeName(name) -- - - - - - - - - - - - - - - - encodeName
+ local t = {};
+ for part in string.gmatch(name, '[^.]+') do
+ append(t, string.char(string.len(part)));
+ append(t, part);
+ end
+ append(t, string.char(0));
+ return table.concat(t);
+end
+
+
+local function encodeQuestion(qname, qtype, qclass) -- - - - encodeQuestion
+ qname = encodeName(qname);
+ qtype = dns.typecode[qtype or 'a'];
+ qclass = dns.classcode[qclass or 'in'];
+ return qname..qtype..qclass;
+end
+
+
+function resolver:byte(len) -- - - - - - - - - - - - - - - - - - - - - byte
+ len = len or 1;
+ local offset = self.offset;
+ local last = offset + len - 1;
+ if last > #self.packet then
+ error(string.format('out of bounds: %i>%i', last, #self.packet));
+ end
+ self.offset = offset + len;
+ return string.byte(self.packet, offset, last);
+end
+
+
+function resolver:word() -- - - - - - - - - - - - - - - - - - - - - - word
+ local b1, b2 = self:byte(2);
+ return 0x100*b1 + b2;
+end
function resolver:dword () -- - - - - - - - - - - - - - - - - - - - - dword
- local b1, b2, b3, b4 = self:byte (4)
- --print ('dword', b1, b2, b3, b4)
- return 0x1000000*b1 + 0x10000*b2 + 0x100*b3 + b4
- end
+ local b1, b2, b3, b4 = self:byte(4);
+ --print('dword', b1, b2, b3, b4);
+ return 0x1000000*b1 + 0x10000*b2 + 0x100*b3 + b4;
+end
-function resolver:sub (len) -- - - - - - - - - - - - - - - - - - - - - - sub
- len = len or 1
- local s = string.sub (self.packet, self.offset, self.offset + len - 1)
- self.offset = self.offset + len
- return s
- end
+function resolver:sub(len) -- - - - - - - - - - - - - - - - - - - - - - sub
+ len = len or 1;
+ local s = string.sub(self.packet, self.offset, self.offset + len - 1);
+ self.offset = self.offset + len;
+ return s;
+end
-function resolver:header (force) -- - - - - - - - - - - - - - - - - - header
-
- local id = self:word ()
- --print (string.format (':header id %x', id))
- if not self.active[id] and not force then return nil end
-
- local h = { id = id }
-
- local b1, b2 = self:byte (2)
-
- h.rd = b1 %2
- h.tc = b1 /2%2
- h.aa = b1 /4%2
- h.opcode = b1 /8%16
- h.qr = b1 /128
+function resolver:header(force) -- - - - - - - - - - - - - - - - - - header
+ local id = self:word();
+ --print(string.format(':header id %x', id));
+ if not self.active[id] and not force then return nil; end
- h.rcode = b2 %16
- h.z = b2 /16%8
- h.ra = b2 /128
-
- h.qdcount = self:word ()
- h.ancount = self:word ()
- h.nscount = self:word ()
- h.arcount = self:word ()
-
- for k,v in pairs (h) do h[k] = v-v%1 end
-
- return h
- end
-
-
-function resolver:name () -- - - - - - - - - - - - - - - - - - - - - - name
- local remember, pointers = nil, 0
- local len = self:byte ()
- local n = {}
- while len > 0 do
- if len >= 0xc0 then -- name is "compressed"
- pointers = pointers + 1
- if pointers >= 20 then error ('dns error: 20 pointers') end
- local offset = ((len-0xc0)*0x100) + self:byte ()
- remember = remember or self.offset
- self.offset = offset + 1 -- +1 for lua
- else -- name is not compressed
- append (n, self:sub (len)..'.')
- end
- len = self:byte ()
- end
- self.offset = remember or self.offset
- return table.concat (n)
- end
+ local h = { id = id };
+ local b1, b2 = self:byte(2);
-function resolver:question () -- - - - - - - - - - - - - - - - - - question
- local q = {}
- q.name = self:name ()
- q.type = dns.type[self:word ()]
- q.class = dns.class[self:word ()]
- return q
- end
+ h.rd = b1 %2;
+ h.tc = b1 /2%2;
+ h.aa = b1 /4%2;
+ h.opcode = b1 /8%16;
+ h.qr = b1 /128;
+ h.rcode = b2 %16;
+ h.z = b2 /16%8;
+ h.ra = b2 /128;
-function resolver:A (rr) -- - - - - - - - - - - - - - - - - - - - - - - - A
- local b1, b2, b3, b4 = self:byte (4)
- rr.a = string.format ('%i.%i.%i.%i', b1, b2, b3, b4)
- end
+ h.qdcount = self:word();
+ h.ancount = self:word();
+ h.nscount = self:word();
+ h.arcount = self:word();
+ for k,v in pairs(h) do h[k] = v-v%1; end
-function resolver:CNAME (rr) -- - - - - - - - - - - - - - - - - - - - CNAME
- rr.cname = self:name ()
- end
+ return h;
+end
-function resolver:MX (rr) -- - - - - - - - - - - - - - - - - - - - - - - MX
- rr.pref = self:word ()
- rr.mx = self:name ()
- end
+function resolver:name() -- - - - - - - - - - - - - - - - - - - - - - name
+ local remember, pointers = nil, 0;
+ local len = self:byte();
+ local n = {};
+ if len == 0 then return "." end -- Root label
+ while len > 0 do
+ if len >= 0xc0 then -- name is "compressed"
+ pointers = pointers + 1;
+ if pointers >= 20 then error('dns error: 20 pointers'); end;
+ local offset = ((len-0xc0)*0x100) + self:byte();
+ remember = remember or self.offset;
+ self.offset = offset + 1; -- +1 for lua
+ else -- name is not compressed
+ append(n, self:sub(len)..'.');
+ end
+ len = self:byte();
+ end
+ self.offset = remember or self.offset;
+ return table.concat(n);
+end
-function resolver:LOC_nibble_power () -- - - - - - - - - - LOC_nibble_power
- local b = self:byte ()
- --print ('nibbles', ((b-(b%0x10))/0x10), (b%0x10))
- return ((b-(b%0x10))/0x10) * (10^(b%0x10))
- end
+function resolver:question() -- - - - - - - - - - - - - - - - - - question
+ local q = {};
+ q.name = self:name();
+ q.type = dns.type[self:word()];
+ q.class = dns.class[self:word()];
+ return q;
+end
-function resolver:LOC (rr) -- - - - - - - - - - - - - - - - - - - - - - LOC
- rr.version = self:byte ()
- if rr.version == 0 then
- rr.loc = rr.loc or {}
- rr.loc.size = self:LOC_nibble_power ()
- rr.loc.horiz_pre = self:LOC_nibble_power ()
- rr.loc.vert_pre = self:LOC_nibble_power ()
- rr.loc.latitude = self:dword ()
- rr.loc.longitude = self:dword ()
- rr.loc.altitude = self:dword ()
- end end
+function resolver:A(rr) -- - - - - - - - - - - - - - - - - - - - - - - - A
+ local b1, b2, b3, b4 = self:byte(4);
+ rr.a = string.format('%i.%i.%i.%i', b1, b2, b3, b4);
+end
+function resolver:AAAA(rr)
+ local addr = {};
+ for i = 1, rr.rdlength, 2 do
+ local b1, b2 = self:byte(2);
+ table.insert(addr, ("%02x%02x"):format(b1, b2));
+ end
+ addr = table.concat(addr, ":"):gsub("%f[%x]0+(%x)","%1");
+ local zeros = {};
+ for item in addr:gmatch(":[0:]+:") do
+ table.insert(zeros, item)
+ end
+ if #zeros == 0 then
+ rr.aaaa = addr;
+ return
+ elseif #zeros > 1 then
+ table.sort(zeros, function(a, b) return #a > #b end);
+ end
+ rr.aaaa = addr:gsub(zeros[1], "::", 1):gsub("^0::", "::"):gsub("::0$", "::");
+end
-local function LOC_tostring_degrees (f, pos, neg) -- - - - - - - - - - - - -
- f = f - 0x80000000
- if f < 0 then pos = neg f = -f end
- local deg, min, msec
- msec = f%60000
- f = (f-msec)/60000
- min = f%60
- deg = (f-min)/60
- return string.format ('%3d %2d %2.3f %s', deg, min, msec/1000, pos)
- end
-
-
-function resolver.LOC_tostring (rr) -- - - - - - - - - - - - - LOC_tostring
-
- local t = {}
-
- --[[
- for k,name in pairs { 'size', 'horiz_pre', 'vert_pre',
- 'latitude', 'longitude', 'altitude' } do
- append (t, string.format ('%4s%-10s: %12.0f\n', '', name, rr.loc[name]))
- end
- --]]
-
- append ( t, string.format (
- '%s %s %.2fm %.2fm %.2fm %.2fm',
- LOC_tostring_degrees (rr.loc.latitude, 'N', 'S'),
- LOC_tostring_degrees (rr.loc.longitude, 'E', 'W'),
- (rr.loc.altitude - 10000000) / 100,
- rr.loc.size / 100,
- rr.loc.horiz_pre / 100,
- rr.loc.vert_pre / 100 ) )
-
- return table.concat (t)
- end
-
-
-function resolver:NS (rr) -- - - - - - - - - - - - - - - - - - - - - - - NS
- rr.ns = self:name ()
- end
-
+function resolver:CNAME(rr) -- - - - - - - - - - - - - - - - - - - - CNAME
+ rr.cname = self:name();
+end
-function resolver:SOA (rr) -- - - - - - - - - - - - - - - - - - - - - - SOA
- end
+function resolver:MX(rr) -- - - - - - - - - - - - - - - - - - - - - - - MX
+ rr.pref = self:word();
+ rr.mx = self:name();
+end
-function resolver:SRV (rr) -- - - - - - - - - - - - - - - - - - - - - - SRV
- rr.srv = {}
- rr.srv.priority = self:word ()
- rr.srv.weight = self:word ()
- rr.srv.port = self:word ()
- rr.srv.target = self:name ()
- end
+function resolver:LOC_nibble_power() -- - - - - - - - - - LOC_nibble_power
+ local b = self:byte();
+ --print('nibbles', ((b-(b%0x10))/0x10), (b%0x10));
+ return ((b-(b%0x10))/0x10) * (10^(b%0x10));
+end
-function SRV_tostring (rr) -- - - - - - - - - - - - - - - - - - SRV_tostring
- local s = rr.srv
- return string.format ( '%5d %5d %5d %s',
- s.priority, s.weight, s.port, s.target )
- end
+function resolver:LOC(rr) -- - - - - - - - - - - - - - - - - - - - - - LOC
+ rr.version = self:byte();
+ if rr.version == 0 then
+ rr.loc = rr.loc or {};
+ rr.loc.size = self:LOC_nibble_power();
+ rr.loc.horiz_pre = self:LOC_nibble_power();
+ rr.loc.vert_pre = self:LOC_nibble_power();
+ rr.loc.latitude = self:dword();
+ rr.loc.longitude = self:dword();
+ rr.loc.altitude = self:dword();
+ end
+end
-function resolver:TXT (rr) -- - - - - - - - - - - - - - - - - - - - - - TXT
- rr.txt = self:sub (rr.rdlength)
- end
+local function LOC_tostring_degrees(f, pos, neg) -- - - - - - - - - - - - -
+ f = f - 0x80000000;
+ if f < 0 then pos = neg; f = -f; end
+ local deg, min, msec;
+ msec = f%60000;
+ f = (f-msec)/60000;
+ min = f%60;
+ deg = (f-min)/60;
+ return string.format('%3d %2d %2.3f %s', deg, min, msec/1000, pos);
+end
-function resolver:rr () -- - - - - - - - - - - - - - - - - - - - - - - - rr
- local rr = {}
- setmetatable (rr, rr_metatable)
- rr.name = self:name (self)
- rr.type = dns.type[self:word ()] or rr.type
- rr.class = dns.class[self:word ()] or rr.class
- rr.ttl = 0x10000*self:word () + self:word ()
- rr.rdlength = self:word ()
- if rr.ttl == 0 then -- pass
- else rr.tod = self.time + rr.ttl end
+function resolver.LOC_tostring(rr) -- - - - - - - - - - - - - LOC_tostring
+ local t = {};
- local remember = self.offset
- local rr_parser = self[dns.type[rr.type]]
- if rr_parser then rr_parser (self, rr) end
- self.offset = remember
- rr.rdata = self:sub (rr.rdlength)
- return rr
- end
+ --[[
+ for k,name in pairs { 'size', 'horiz_pre', 'vert_pre', 'latitude', 'longitude', 'altitude' } do
+ append(t, string.format('%4s%-10s: %12.0f\n', '', name, rr.loc[name]));
+ end
+ --]]
+
+ append(t, string.format(
+ '%s %s %.2fm %.2fm %.2fm %.2fm',
+ LOC_tostring_degrees (rr.loc.latitude, 'N', 'S'),
+ LOC_tostring_degrees (rr.loc.longitude, 'E', 'W'),
+ (rr.loc.altitude - 10000000) / 100,
+ rr.loc.size / 100,
+ rr.loc.horiz_pre / 100,
+ rr.loc.vert_pre / 100
+ ));
+
+ return table.concat(t);
+end
-function resolver:rrs (count) -- - - - - - - - - - - - - - - - - - - - - rrs
- local rrs = {}
- for i = 1,count do append (rrs, self:rr ()) end
- return rrs
- end
+function resolver:NS(rr) -- - - - - - - - - - - - - - - - - - - - - - - NS
+ rr.ns = self:name();
+end
-function resolver:decode (packet, force) -- - - - - - - - - - - - - - decode
+function resolver:SOA(rr) -- - - - - - - - - - - - - - - - - - - - - - SOA
+end
- self.packet, self.offset = packet, 1
- local header = self:header (force)
- if not header then return nil end
- local response = { header = header }
- response.question = {}
- local offset = self.offset
- for i = 1,response.header.qdcount do
- append (response.question, self:question ()) end
- response.question.raw = string.sub (self.packet, offset, self.offset - 1)
+function resolver:SRV(rr) -- - - - - - - - - - - - - - - - - - - - - - SRV
+ rr.srv = {};
+ rr.srv.priority = self:word();
+ rr.srv.weight = self:word();
+ rr.srv.port = self:word();
+ rr.srv.target = self:name();
+end
- if not force then
- if not self.active[response.header.id] or
- not self.active[response.header.id][response.question.raw] then
- return nil end end
+function resolver:PTR(rr)
+ rr.ptr = self:name();
+end
- response.answer = self:rrs (response.header.ancount)
- response.authority = self:rrs (response.header.nscount)
- response.additional = self:rrs (response.header.arcount)
+function resolver:TXT(rr) -- - - - - - - - - - - - - - - - - - - - - - TXT
+ rr.txt = self:sub (self:byte());
+end
- return response
- end
+function resolver:rr() -- - - - - - - - - - - - - - - - - - - - - - - - rr
+ local rr = {};
+ setmetatable(rr, rr_metatable);
+ rr.name = self:name(self);
+ rr.type = dns.type[self:word()] or rr.type;
+ rr.class = dns.class[self:word()] or rr.class;
+ rr.ttl = 0x10000*self:word() + self:word();
+ rr.rdlength = self:word();
--- socket layer -------------------------------------------------- socket layer
+ if rr.ttl <= 0 then
+ rr.tod = self.time + 30;
+ else
+ rr.tod = self.time + rr.ttl;
+ end
+
+ local remember = self.offset;
+ local rr_parser = self[dns.type[rr.type]];
+ if rr_parser then rr_parser(self, rr); end
+ self.offset = remember;
+ rr.rdata = self:sub(rr.rdlength);
+ return rr;
+end
+
+
+function resolver:rrs (count) -- - - - - - - - - - - - - - - - - - - - - rrs
+ local rrs = {};
+ for i = 1,count do append(rrs, self:rr()); end
+ return rrs;
+end
+
+
+function resolver:decode(packet, force) -- - - - - - - - - - - - - - decode
+ self.packet, self.offset = packet, 1;
+ local header = self:header(force);
+ if not header then return nil; end
+ local response = { header = header };
+
+ response.question = {};
+ local offset = self.offset;
+ for i = 1,response.header.qdcount do
+ append(response.question, self:question());
+ end
+ response.question.raw = string.sub(self.packet, offset, self.offset - 1);
+
+ if not force then
+ if not self.active[response.header.id] or not self.active[response.header.id][response.question.raw] then
+ self.active[response.header.id] = nil;
+ return nil;
+ end
+ end
+
+ response.answer = self:rrs(response.header.ancount);
+ response.authority = self:rrs(response.header.nscount);
+ response.additional = self:rrs(response.header.arcount);
+
+ return response;
+end
-resolver.delays = { 1, 3, 11, 45 }
+-- socket layer -------------------------------------------------- socket layer
-function resolver:addnameserver (address) -- - - - - - - - - - addnameserver
- self.server = self.server or {}
- append (self.server, address)
- end
+resolver.delays = { 1, 3 };
-function resolver:setnameserver (address) -- - - - - - - - - - setnameserver
- self.server = {}
- self:addnameserver (address)
- end
+function resolver:addnameserver(address) -- - - - - - - - - - addnameserver
+ self.server = self.server or {};
+ append(self.server, address);
+end
-function resolver:adddefaultnameservers () -- - - - - adddefaultnameservers
- local resolv_conf = io.open("/etc/resolv.conf");
- if resolv_conf then
- for line in resolv_conf:lines() do
- local address = string.match (line, 'nameserver%s+(%d+%.%d+%.%d+%.%d+)')
- if address then self:addnameserver (address) end
- end
- else -- FIXME correct for windows, using opendns nameservers for now
- self:addnameserver ("208.67.222.222")
- self:addnameserver ("208.67.220.220")
- end
+function resolver:setnameserver(address) -- - - - - - - - - - setnameserver
+ self.server = {};
+ self:addnameserver(address);
end
-function resolver:getsocket (servernum) -- - - - - - - - - - - - - getsocket
+function resolver:adddefaultnameservers() -- - - - - adddefaultnameservers
+ if is_windows then
+ if windows and windows.get_nameservers then
+ for _, server in ipairs(windows.get_nameservers()) do
+ self:addnameserver(server);
+ end
+ end
+ if not self.server or #self.server == 0 then
+ -- TODO log warning about no nameservers, adding opendns servers as fallback
+ self:addnameserver("208.67.222.222");
+ self:addnameserver("208.67.220.220");
+ end
+ else -- posix
+ local resolv_conf = io.open("/etc/resolv.conf");
+ if resolv_conf then
+ for line in resolv_conf:lines() do
+ line = line:gsub("#.*$", "")
+ :match('^%s*nameserver%s+(.*)%s*$');
+ if line then
+ line:gsub("%f[%d.](%d+%.%d+%.%d+%.%d+)%f[^%d.]", function (address)
+ self:addnameserver(address)
+ end);
+ end
+ end
+ end
+ if not self.server or #self.server == 0 then
+ -- TODO log warning about no nameservers, adding localhost as the default nameserver
+ self:addnameserver("127.0.0.1");
+ end
+ end
+end
+
- self.socket = self.socket or {}
- self.socketset = self.socketset or {}
+function resolver:getsocket(servernum) -- - - - - - - - - - - - - getsocket
+ self.socket = self.socket or {};
+ self.socketset = self.socketset or {};
- local sock = self.socket[servernum]
- if sock then return sock end
+ local sock = self.socket[servernum];
+ if sock then return sock; end
- sock = socket.udp ()
- if self.socket_wrapper then sock = self.socket_wrapper (sock) end
- sock:settimeout (0)
- -- todo: attempt to use a random port, fallback to 0
- sock:setsockname ('*', 0)
- sock:setpeername (self.server[servernum], 53)
- self.socket[servernum] = sock
- self.socketset[sock] = sock
- return sock
- end
+ local err;
+ sock, err = socket.udp();
+ if not sock then
+ return nil, err;
+ end
+ if self.socket_wrapper then sock = self.socket_wrapper(sock, self); end
+ sock:settimeout(0);
+ -- todo: attempt to use a random port, fallback to 0
+ sock:setsockname('*', 0);
+ sock:setpeername(self.server[servernum], 53);
+ self.socket[servernum] = sock;
+ self.socketset[sock] = servernum;
+ return sock;
+end
+function resolver:voidsocket(sock)
+ if self.socket[sock] then
+ self.socketset[self.socket[sock]] = nil;
+ self.socket[sock] = nil;
+ elseif self.socketset[sock] then
+ self.socket[self.socketset[sock]] = nil;
+ self.socketset[sock] = nil;
+ end
+ sock:close();
+end
-function resolver:socket_wrapper_set (func) -- - - - - - - socket_wrapper_set
- self.socket_wrapper = func
- end
+function resolver:socket_wrapper_set(func) -- - - - - - - socket_wrapper_set
+ self.socket_wrapper = func;
+end
function resolver:closeall () -- - - - - - - - - - - - - - - - - - closeall
- for i,sock in ipairs (self.socket) do self.socket[i]:close () end
- self.socket = {}
- end
-
+ for i,sock in ipairs(self.socket) do
+ self.socket[i] = nil;
+ self.socketset[sock] = nil;
+ sock:close();
+ end
+end
-function resolver:remember (rr, type) -- - - - - - - - - - - - - - remember
- --print ('remember', type, rr.class, rr.type, rr.name)
+function resolver:remember(rr, type) -- - - - - - - - - - - - - - remember
+ --print ('remember', type, rr.class, rr.type, rr.name)
+ local qname, qtype, qclass = standardize(rr.name, rr.type, rr.class);
- if type ~= '*' then
- type = rr.type
- local all = get (self.cache, rr.class, '*', rr.name)
- --print ('remember all', all)
- if all then append (all, rr) end
- end
+ if type ~= '*' then
+ type = qtype;
+ local all = get(self.cache, qclass, '*', qname);
+ --print('remember all', all);
+ if all then append(all, rr); end
+ end
- self.cache = self.cache or setmetatable ({}, cache_metatable)
- local rrs = get (self.cache, rr.class, type, rr.name) or
- set (self.cache, rr.class, type, rr.name, setmetatable ({}, rrs_metatable))
- append (rrs, rr)
+ self.cache = self.cache or setmetatable({}, cache_metatable);
+ local rrs = get(self.cache, qclass, type, qname) or
+ set(self.cache, qclass, type, qname, setmetatable({}, rrs_metatable));
+ append(rrs, rr);
- if type == 'MX' then self.unsorted[rrs] = true end
- end
+ if type == 'MX' then self.unsorted[rrs] = true; end
+end
-local function comp_mx (a, b) -- - - - - - - - - - - - - - - - - - - comp_mx
- return (a.pref == b.pref) and (a.mx < b.mx) or (a.pref < b.pref)
- end
+local function comp_mx(a, b) -- - - - - - - - - - - - - - - - - - - comp_mx
+ return (a.pref == b.pref) and (a.mx < b.mx) or (a.pref < b.pref);
+end
function resolver:peek (qname, qtype, qclass) -- - - - - - - - - - - - peek
- qname, qtype, qclass = standardize (qname, qtype, qclass)
- local rrs = get (self.cache, qclass, qtype, qname)
- if not rrs then return nil end
- if prune (rrs, socket.gettime ()) and qtype == '*' or not next (rrs) then
- set (self.cache, qclass, qtype, qname, nil) return nil end
- if self.unsorted[rrs] then table.sort (rrs, comp_mx) end
- return rrs
- end
-
-
-function resolver:purge (soft) -- - - - - - - - - - - - - - - - - - - purge
- if soft == 'soft' then
- self.time = socket.gettime ()
- for class,types in pairs (self.cache or {}) do
- for type,names in pairs (types) do
- for name,rrs in pairs (names) do
- prune (rrs, self.time, 'soft')
- end end end
- else self.cache = {} end
- end
-
-
-function resolver:query (qname, qtype, qclass) -- - - - - - - - - - -- query
-
- qname, qtype, qclass = standardize (qname, qtype, qclass)
-
- if not self.server then self:adddefaultnameservers () end
-
- local question = encodeQuestion (qname, qtype, qclass)
- local peek = self:peek (qname, qtype, qclass)
- if peek then return peek end
-
- local header, id = encodeHeader ()
- --print ('query id', id, qclass, qtype, qname)
- local o = { packet = header..question,
- server = 1,
- delay = 1,
- retry = socket.gettime () + self.delays[1] }
- self:getsocket (o.server):send (o.packet)
-
- -- remember the query
- self.active[id] = self.active[id] or {}
- self.active[id][question] = o
-
- -- remember which coroutine wants the answer
- local co = coroutine.running ()
- if co then
- set (self.wanted, qclass, qtype, qname, co, true)
- --set (self.yielded, co, qclass, qtype, qname, true)
- end
-end
-
-
-
-function resolver:receive (rset) -- - - - - - - - - - - - - - - - - receive
-
- --print 'receive' print (self.socket)
- self.time = socket.gettime ()
- rset = rset or self.socket
-
- local response
- for i,sock in pairs (rset) do
-
- if self.socketset[sock] then
- local packet = sock:receive ()
- if packet then
-
- response = self:decode (packet)
- if response then
- --print 'received response'
- --self.print (response)
-
- for i,section in pairs { 'answer', 'authority', 'additional' } do
- for j,rr in pairs (response[section]) do
- self:remember (rr, response.question[1].type) end end
-
- -- retire the query
- local queries = self.active[response.header.id]
- if queries[response.question.raw] then
- queries[response.question.raw] = nil end
- if not next (queries) then self.active[response.header.id] = nil end
- if not next (self.active) then self:closeall () end
-
- -- was the query on the wanted list?
- local q = response.question
- local cos = get (self.wanted, q.class, q.type, q.name)
- if cos then
- for co in pairs (cos) do
- set (self.yielded, co, q.class, q.type, q.name, nil)
- if coroutine.status(co) == "suspended" then coroutine.resume (co) end
- end
- set (self.wanted, q.class, q.type, q.name, nil)
- end end end end end
-
- return response
- end
-
-
-function resolver:feed(sock, packet)
- --print 'receive' print (self.socket)
- self.time = socket.gettime ()
-
- local response = self:decode (packet)
- if response then
- --print 'received response'
- --self.print (response)
-
- for i,section in pairs { 'answer', 'authority', 'additional' } do
- for j,rr in pairs (response[section]) do
- self:remember (rr, response.question[1].type)
- end
- end
-
- -- retire the query
- local queries = self.active[response.header.id]
- if queries[response.question.raw] then
- queries[response.question.raw] = nil
- end
- if not next (queries) then self.active[response.header.id] = nil end
- if not next (self.active) then self:closeall () end
-
- -- was the query on the wanted list?
- local q = response.question[1]
- if q then
- local cos = get (self.wanted, q.class, q.type, q.name)
- if cos then
- for co in pairs (cos) do
- set (self.yielded, co, q.class, q.type, q.name, nil)
- if coroutine.status(co) == "suspended" then coroutine.resume (co) end
- end
- set (self.wanted, q.class, q.type, q.name, nil)
- end
- end
- end
-
- return response
-end
-
-function resolver:cancel(data)
- local cos = get (self.wanted, unpack(data, 1, 3))
- if cos then
- cos[data[4]] = nil;
+ qname, qtype, qclass = standardize(qname, qtype, qclass);
+ local rrs = get(self.cache, qclass, qtype, qname);
+ if not rrs then return nil; end
+ if prune(rrs, socket.gettime()) and qtype == '*' or not next(rrs) then
+ set(self.cache, qclass, qtype, qname, nil);
+ return nil;
end
+ if self.unsorted[rrs] then table.sort (rrs, comp_mx); end
+ return rrs;
+end
+
+
+function resolver:purge(soft) -- - - - - - - - - - - - - - - - - - - purge
+ if soft == 'soft' then
+ self.time = socket.gettime();
+ for class,types in pairs(self.cache or {}) do
+ for type,names in pairs(types) do
+ for name,rrs in pairs(names) do
+ prune(rrs, self.time, 'soft')
+ end
+ end
+ end
+ else self.cache = setmetatable({}, cache_metatable); end
end
-function resolver:pulse () -- - - - - - - - - - - - - - - - - - - - - pulse
- --print ':pulse'
- while self:receive() do end
- if not next (self.active) then return nil end
+function resolver:query(qname, qtype, qclass) -- - - - - - - - - - -- query
+ qname, qtype, qclass = standardize(qname, qtype, qclass)
+
+ if not self.server then self:adddefaultnameservers(); end
+
+ local question = encodeQuestion(qname, qtype, qclass);
+ local peek = self:peek (qname, qtype, qclass);
+ if peek then return peek; end
+
+ local header, id = encodeHeader();
+ --print ('query id', id, qclass, qtype, qname)
+ local o = {
+ packet = header..question,
+ server = self.best_server,
+ delay = 1,
+ retry = socket.gettime() + self.delays[1]
+ };
+
+ -- remember the query
+ self.active[id] = self.active[id] or {};
+ self.active[id][question] = o;
- self.time = socket.gettime ()
- for id,queries in pairs (self.active) do
- for question,o in pairs (queries) do
- if self.time >= o.retry then
+ -- remember which coroutine wants the answer
+ local co = coroutine.running();
+ if co then
+ set(self.wanted, qclass, qtype, qname, co, true);
+ --set(self.yielded, co, qclass, qtype, qname, true);
+ end
- o.server = o.server + 1
- if o.server > #self.server then
- o.server = 1
- o.delay = o.delay + 1
- end
+ local conn, err = self:getsocket(o.server)
+ if not conn then
+ return nil, err;
+ end
+ conn:send (o.packet)
+
+ if timer and self.timeout then
+ local num_servers = #self.server;
+ local i = 1;
+ timer.add_task(self.timeout, function ()
+ if get(self.wanted, qclass, qtype, qname, co) then
+ if i < num_servers then
+ i = i + 1;
+ self:servfail(conn);
+ o.server = self.best_server;
+ conn, err = self:getsocket(o.server);
+ if conn then
+ conn:send(o.packet);
+ return self.timeout;
+ end
+ end
+ -- Tried everything, failed
+ self:cancel(qclass, qtype, qname, co, true);
+ end
+ end)
+ end
+ return true;
+end
- if o.delay > #self.delays then
- --print ('timeout')
- queries[question] = nil
- if not next (queries) then self.active[id] = nil end
- if not next (self.active) then return nil end
- else
- --print ('retry', o.server, o.delay)
- local _a = self.socket[o.server];
- if _a then _a:send (o.packet) end
- o.retry = self.time + self.delays[o.delay]
- end end end end
+function resolver:servfail(sock)
+ -- Resend all queries for this server
+
+ local num = self.socketset[sock]
+
+ -- Socket is dead now
+ self:voidsocket(sock);
+
+ -- Find all requests to the down server, and retry on the next server
+ self.time = socket.gettime();
+ for id,queries in pairs(self.active) do
+ for question,o in pairs(queries) do
+ if o.server == num then -- This request was to the broken server
+ o.server = o.server + 1 -- Use next server
+ if o.server > #self.server then
+ o.server = 1;
+ end
+
+ o.retries = (o.retries or 0) + 1;
+ if o.retries >= #self.server then
+ --print('timeout');
+ queries[question] = nil;
+ else
+ local _a = self:getsocket(o.server);
+ if _a then _a:send(o.packet); end
+ end
+ end
+ end
+ if next(queries) == nil then
+ self.active[id] = nil;
+ end
+ end
- if next (self.active) then return true end
- return nil
- end
+ if num == self.best_server then
+ self.best_server = self.best_server + 1;
+ if self.best_server > #self.server then
+ -- Exhausted all servers, try first again
+ self.best_server = 1;
+ end
+ end
+end
+function resolver:settimeout(seconds)
+ self.timeout = seconds;
+end
-function resolver:lookup (qname, qtype, qclass) -- - - - - - - - - - lookup
- self:query (qname, qtype, qclass)
- while self:pulse () do socket.select (self.socket, nil, 4) end
- --print (self.cache)
- return self:peek (qname, qtype, qclass)
- end
+function resolver:receive(rset) -- - - - - - - - - - - - - - - - - receive
+ --print('receive'); print(self.socket);
+ self.time = socket.gettime();
+ rset = rset or self.socket;
+
+ local response;
+ for i,sock in pairs(rset) do
+
+ if self.socketset[sock] then
+ local packet = sock:receive();
+ if packet then
+ response = self:decode(packet);
+ if response and self.active[response.header.id]
+ and self.active[response.header.id][response.question.raw] then
+ --print('received response');
+ --self.print(response);
+
+ for j,rr in pairs(response.answer) do
+ if rr.name:sub(-#response.question[1].name, -1) == response.question[1].name then
+ self:remember(rr, response.question[1].type)
+ end
+ end
+
+ -- retire the query
+ local queries = self.active[response.header.id];
+ queries[response.question.raw] = nil;
+
+ if not next(queries) then self.active[response.header.id] = nil; end
+ if not next(self.active) then self:closeall(); end
+
+ -- was the query on the wanted list?
+ local q = response.question[1];
+ local cos = get(self.wanted, q.class, q.type, q.name);
+ if cos then
+ for co in pairs(cos) do
+ set(self.yielded, co, q.class, q.type, q.name, nil);
+ if coroutine.status(co) == "suspended" then coroutine.resume(co); end
+ end
+ set(self.wanted, q.class, q.type, q.name, nil);
+ end
+ end
+
+ end
+ end
+ end
-function resolver:lookupex (handler, qname, qtype, qclass) -- - - - - - - - - - lookup
- return self:peek (qname, qtype, qclass) or self:query (qname, qtype, qclass)
- end
+ return response;
+end
+function resolver:feed(sock, packet, force)
+ --print('receive'); print(self.socket);
+ self.time = socket.gettime();
+
+ local response = self:decode(packet, force);
+ if response and self.active[response.header.id]
+ and self.active[response.header.id][response.question.raw] then
+ --print('received response');
+ --self.print(response);
+
+ for j,rr in pairs(response.answer) do
+ self:remember(rr, response.question[1].type);
+ end
+
+ -- retire the query
+ local queries = self.active[response.header.id];
+ queries[response.question.raw] = nil;
+ if not next(queries) then self.active[response.header.id] = nil; end
+ if not next(self.active) then self:closeall(); end
+
+ -- was the query on the wanted list?
+ local q = response.question[1];
+ if q then
+ local cos = get(self.wanted, q.class, q.type, q.name);
+ if cos then
+ for co in pairs(cos) do
+ set(self.yielded, co, q.class, q.type, q.name, nil);
+ if coroutine.status(co) == "suspended" then coroutine.resume(co); end
+ end
+ set(self.wanted, q.class, q.type, q.name, nil);
+ end
+ end
+ end
+
+ return response;
+end
+
+function resolver:cancel(qclass, qtype, qname, co, call_handler)
+ local cos = get(self.wanted, qclass, qtype, qname);
+ if cos then
+ if call_handler then
+ coroutine.resume(co);
+ end
+ cos[co] = nil;
+ end
+end
+
+function resolver:pulse() -- - - - - - - - - - - - - - - - - - - - - pulse
+ --print(':pulse');
+ while self:receive() do end
+ if not next(self.active) then return nil; end
+
+ self.time = socket.gettime();
+ for id,queries in pairs(self.active) do
+ for question,o in pairs(queries) do
+ if self.time >= o.retry then
+
+ o.server = o.server + 1;
+ if o.server > #self.server then
+ o.server = 1;
+ o.delay = o.delay + 1;
+ end
+
+ if o.delay > #self.delays then
+ --print('timeout');
+ queries[question] = nil;
+ if not next(queries) then self.active[id] = nil; end
+ if not next(self.active) then return nil; end
+ else
+ --print('retry', o.server, o.delay);
+ local _a = self.socket[o.server];
+ if _a then _a:send(o.packet); end
+ o.retry = self.time + self.delays[o.delay];
+ end
+ end
+ end
+ end
+
+ if next(self.active) then return true; end
+ return nil;
+end
+
+
+function resolver:lookup(qname, qtype, qclass) -- - - - - - - - - - lookup
+ self:query (qname, qtype, qclass)
+ while self:pulse() do
+ local recvt = {}
+ for i, s in ipairs(self.socket) do
+ recvt[i] = s
+ end
+ socket.select(recvt, nil, 4)
+ end
+ --print(self.cache);
+ return self:peek(qname, qtype, qclass);
+end
+
+function resolver:lookupex(handler, qname, qtype, qclass) -- - - - - - - - - - lookup
+ return self:peek(qname, qtype, qclass) or self:query(qname, qtype, qclass);
+end
+
+function resolver:tohostname(ip)
+ return dns.lookup(ip:gsub("(%d+)%.(%d+)%.(%d+)%.(%d+)", "%4.%3.%2.%1.in-addr.arpa."), "PTR");
+end
+
--print ---------------------------------------------------------------- print
local hints = { -- - - - - - - - - - - - - - - - - - - - - - - - - - - hints
- qr = { [0]='query', 'response' },
- opcode = { [0]='query', 'inverse query', 'server status request' },
- aa = { [0]='non-authoritative', 'authoritative' },
- tc = { [0]='complete', 'truncated' },
- rd = { [0]='recursion not desired', 'recursion desired' },
- ra = { [0]='recursion not available', 'recursion available' },
- z = { [0]='(reserved)' },
- rcode = { [0]='no error', 'format error', 'server failure', 'name error',
- 'not implemented' },
-
- type = dns.type,
- class = dns.class, }
-
-
-local function hint (p, s) -- - - - - - - - - - - - - - - - - - - - - - hint
- return (hints[s] and hints[s][p[s]]) or '' end
-
-
-function resolver.print (response) -- - - - - - - - - - - - - resolver.print
-
- for s,s in pairs { 'id', 'qr', 'opcode', 'aa', 'tc', 'rd', 'ra', 'z',
- 'rcode', 'qdcount', 'ancount', 'nscount', 'arcount' } do
- print ( string.format ('%-30s', 'header.'..s),
- response.header[s], hint (response.header, s) )
- end
-
- for i,question in ipairs (response.question) do
- print (string.format ('question[%i].name ', i), question.name)
- print (string.format ('question[%i].type ', i), question.type)
- print (string.format ('question[%i].class ', i), question.class)
- end
-
- local common = { name=1, type=1, class=1, ttl=1, rdlength=1, rdata=1 }
- local tmp
- for s,s in pairs {'answer', 'authority', 'additional'} do
- for i,rr in pairs (response[s]) do
- for j,t in pairs { 'name', 'type', 'class', 'ttl', 'rdlength' } do
- tmp = string.format ('%s[%i].%s', s, i, t)
- print (string.format ('%-30s', tmp), rr[t], hint (rr, t))
- end
- for j,t in pairs (rr) do
- if not common[j] then
- tmp = string.format ('%s[%i].%s', s, i, j)
- print (string.format ('%-30s %s', tostring(tmp), tostring(t)))
- end end end end end
+ qr = { [0]='query', 'response' },
+ opcode = { [0]='query', 'inverse query', 'server status request' },
+ aa = { [0]='non-authoritative', 'authoritative' },
+ tc = { [0]='complete', 'truncated' },
+ rd = { [0]='recursion not desired', 'recursion desired' },
+ ra = { [0]='recursion not available', 'recursion available' },
+ z = { [0]='(reserved)' },
+ rcode = { [0]='no error', 'format error', 'server failure', 'name error', 'not implemented' },
+
+ type = dns.type,
+ class = dns.class
+};
+
+
+local function hint(p, s) -- - - - - - - - - - - - - - - - - - - - - - hint
+ return (hints[s] and hints[s][p[s]]) or '';
+end
--- module api ------------------------------------------------------ module api
+function resolver.print(response) -- - - - - - - - - - - - - resolver.print
+ for s,s in pairs { 'id', 'qr', 'opcode', 'aa', 'tc', 'rd', 'ra', 'z',
+ 'rcode', 'qdcount', 'ancount', 'nscount', 'arcount' } do
+ print( string.format('%-30s', 'header.'..s), response.header[s], hint(response.header, s) );
+ end
+ for i,question in ipairs(response.question) do
+ print(string.format ('question[%i].name ', i), question.name);
+ print(string.format ('question[%i].type ', i), question.type);
+ print(string.format ('question[%i].class ', i), question.class);
+ end
-local function resolve (func, ...) -- - - - - - - - - - - - - - resolver_get
- dns._resolver = dns._resolver or dns.resolver ()
- return func (dns._resolver, ...)
- end
+ local common = { name=1, type=1, class=1, ttl=1, rdlength=1, rdata=1 };
+ local tmp;
+ for s,s in pairs({'answer', 'authority', 'additional'}) do
+ for i,rr in pairs(response[s]) do
+ for j,t in pairs({ 'name', 'type', 'class', 'ttl', 'rdlength' }) do
+ tmp = string.format('%s[%i].%s', s, i, t);
+ print(string.format('%-30s', tmp), rr[t], hint(rr, t));
+ end
+ for j,t in pairs(rr) do
+ if not common[j] then
+ tmp = string.format('%s[%i].%s', s, i, j);
+ print(string.format('%-30s %s', tostring(tmp), tostring(t)));
+ end
+ end
+ end
+ end
+end
-function dns.resolver () -- - - - - - - - - - - - - - - - - - - - - resolver
+-- module api ------------------------------------------------------ module api
- -- this function seems to be redundant with resolver.new ()
- local r = { active = {}, cache = {}, unsorted = {}, wanted = {}, yielded = {} }
- setmetatable (r, resolver)
- setmetatable (r.cache, cache_metatable)
- setmetatable (r.unsorted, { __mode = 'kv' })
- return r
- end
+function dns.resolver () -- - - - - - - - - - - - - - - - - - - - - resolver
+ -- this function seems to be redundant with resolver.new ()
+ local r = { active = {}, cache = {}, unsorted = {}, wanted = {}, yielded = {}, best_server = 1 };
+ setmetatable (r, resolver);
+ setmetatable (r.cache, cache_metatable);
+ setmetatable (r.unsorted, { __mode = 'kv' });
+ return r;
+end
-function dns.lookup (...) -- - - - - - - - - - - - - - - - - - - - - lookup
- return resolve (resolver.lookup, ...) end
+local _resolver = dns.resolver();
+dns._resolver = _resolver;
+function dns.lookup(...) -- - - - - - - - - - - - - - - - - - - - - lookup
+ return _resolver:lookup(...);
+end
-function dns.purge (...) -- - - - - - - - - - - - - - - - - - - - - - purge
- return resolve (resolver.purge, ...) end
+function dns.tohostname(...)
+ return _resolver:tohostname(...);
+end
-function dns.peek (...) -- - - - - - - - - - - - - - - - - - - - - - - peek
- return resolve (resolver.peek, ...) end
+function dns.purge(...) -- - - - - - - - - - - - - - - - - - - - - - purge
+ return _resolver:purge(...);
+end
+function dns.peek(...) -- - - - - - - - - - - - - - - - - - - - - - - peek
+ return _resolver:peek(...);
+end
-function dns.query (...) -- - - - - - - - - - - - - - - - - - - - - - query
- return resolve (resolver.query, ...) end
+function dns.query(...) -- - - - - - - - - - - - - - - - - - - - - - query
+ return _resolver:query(...);
+end
-function dns.feed (...) -- - - - - - - - - - - - - - - - - - - - - - feed
- return resolve (resolver.feed, ...) end
+function dns.feed(...) -- - - - - - - - - - - - - - - - - - - - - - - feed
+ return _resolver:feed(...);
+end
-function dns.cancel(...) -- - - - - - - - - - - - - - - - - - - - - - cancel
- return resolve(resolver.cancel, ...) end
+function dns.cancel(...) -- - - - - - - - - - - - - - - - - - - - - - cancel
+ return _resolver:cancel(...);
+end
-function dns:socket_wrapper_set (...) -- - - - - - - - - socket_wrapper_set
- return resolve (resolver.socket_wrapper_set, ...) end
+function dns.settimeout(...)
+ return _resolver:settimeout(...);
+end
+function dns.socket_wrapper_set(...) -- - - - - - - - - socket_wrapper_set
+ return _resolver:socket_wrapper_set(...);
+end
-return dns
+return dns;
diff --git a/net/http.lua b/net/http.lua
index 9d2f9b96..3b783a41 100644
--- a/net/http.lua
+++ b/net/http.lua
@@ -1,124 +1,107 @@
-- Prosody IM
--- Copyright (C) 2008-2009 Matthew Wild
--- Copyright (C) 2008-2009 Waqas Hussain
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
--
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
-
local socket = require "socket"
-local mime = require "mime"
+local b64 = require "util.encodings".base64.encode;
local url = require "socket.url"
+local httpstream_new = require "net.http.parser".new;
+local util_http = require "util.http";
-local server = require "net.server"
+local ssl_available = pcall(require, "ssl");
-local connlisteners_get = require "net.connlisteners".get;
-local listener = connlisteners_get("httpclient") or error("No httpclient listener!");
+local server = require "net.server"
local t_insert, t_concat = table.insert, table.concat;
-local tonumber, tostring, pairs, xpcall, select, debug_traceback, char, format =
- tonumber, tostring, pairs, xpcall, select, debug.traceback, string.char, string.format;
+local pairs = pairs;
+local tonumber, tostring, xpcall, select, traceback =
+ tonumber, tostring, xpcall, select, debug.traceback;
local log = require "util.logger".init("http");
-local print = function () end
module "http"
-function urlencode(s) return s and (s:gsub("%W", function (c) return format("%%%02x", c:byte()); end)); end
-function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return char(tonumber(c,16)); end)); end
+local requests = {}; -- Open requests
+
+local listener = { default_port = 80, default_mode = "*a" };
-local function expectbody(reqt, code)
- if reqt.method == "HEAD" then return nil end
- if code == 204 or code == 304 then return nil end
- if code >= 100 and code < 200 then return nil end
- return 1
+function listener.onconnect(conn)
+ local req = requests[conn];
+ -- Send the request
+ local request_line = { req.method or "GET", " ", req.path, " HTTP/1.1\r\n" };
+ if req.query then
+ t_insert(request_line, 4, "?"..req.query);
+ end
+
+ conn:write(t_concat(request_line));
+ local t = { [2] = ": ", [4] = "\r\n" };
+ for k, v in pairs(req.headers) do
+ t[1], t[3] = k, v;
+ conn:write(t_concat(t));
+ end
+ conn:write("\r\n");
+
+ if req.body then
+ conn:write(req.body);
+ end
end
-local function request_reader(request, data, startpos)
- if not data then
- if request.body then
- log("debug", "Connection closed, but we have data, calling callback...");
- request.callback(t_concat(request.body), request.code, request);
- elseif request.state ~= "completed" then
- -- Error.. connection was closed prematurely
- request.callback("connection-closed", 0, request);
- end
- destroy_request(request);
- request.body = nil;
- request.state = "completed";
+function listener.onincoming(conn, data)
+ local request = requests[conn];
+
+ if not request then
+ log("warn", "Received response from connection %s with no request attached!", tostring(conn));
return;
end
- if request.state == "body" and request.state ~= "completed" then
- print("Reading body...")
- if not request.body then request.body = {}; request.havebodylength, request.bodylength = 0, tonumber(request.responseheaders["content-length"]); end
- if startpos then
- data = data:sub(startpos, -1)
- end
- t_insert(request.body, data);
- if request.bodylength then
- request.havebodylength = request.havebodylength + #data;
- if request.havebodylength >= request.bodylength then
- -- We have the body
- log("debug", "Have full body, calling callback");
- if request.callback then
- request.callback(t_concat(request.body), request.code, request);
- end
- request.body = nil;
- request.state = "completed";
- else
- print("", "Have "..request.havebodylength.." bytes out of "..request.bodylength);
- end
- end
- elseif request.state == "headers" then
- print("Reading headers...")
- local pos = startpos;
- local headers = request.responseheaders or {};
- for line in data:sub(startpos, -1):gmatch("(.-)\r\n") do
- startpos = startpos + #line + 2;
- local k, v = line:match("(%S+): (.+)");
- if k and v then
- headers[k:lower()] = v;
- print("Header: "..k:lower().." = "..v);
- elseif #line == 0 then
- request.responseheaders = headers;
- break;
- else
- print("Unhandled header line: "..line);
+
+ if data and request.reader then
+ request:reader(data);
+ end
+end
+
+function listener.ondisconnect(conn, err)
+ local request = requests[conn];
+ if request and request.conn then
+ request:reader(nil, err);
+ end
+ requests[conn] = nil;
+end
+
+local function request_reader(request, data, err)
+ if not request.parser then
+ local function error_cb(reason)
+ if request.callback then
+ request.callback(reason or "connection-closed", 0, request);
+ request.callback = nil;
end
- end
- -- Reached the end of the headers
- request.state = "body";
- if #data > startpos then
- return request_reader(request, data, startpos);
- end
- elseif request.state == "status" then
- print("Reading status...")
- local http, code, text, linelen = data:match("^HTTP/(%S+) (%d+) (.-)\r\n()", startpos);
- code = tonumber(code);
- if not code then
- return request.callback("invalid-status-line", 0, request);
+ destroy_request(request);
end
- request.code, request.responseversion = code, http;
+ if not data then
+ error_cb(err);
+ return;
+ end
- if request.onlystatus or not expectbody(request, code) then
+ local function success_cb(r)
if request.callback then
- request.callback(nil, code, request);
+ request.callback(r.body, r.code, r, request);
+ request.callback = nil;
end
destroy_request(request);
- return;
end
-
- request.state = "headers";
-
- if #data > linelen then
- return request_reader(request, data, linelen);
+ local function options_cb()
+ return request;
end
+ request.parser = httpstream_new(success_cb, error_cb, "client", options_cb);
end
+ request.parser:feed(data);
end
-local function handleerr(err) log("error", "Traceback[http]: %s: %s", tostring(err), debug_traceback()); end
+local function handleerr(err) log("error", "Traceback[http]: %s", traceback(tostring(err), 2)); end
function request(u, ex, callback)
local req = url.parse(u);
@@ -131,79 +114,78 @@ function request(u, ex, callback)
req.path = "/";
end
- local custom_headers, body;
- local default_headers = { ["Host"] = req.host, ["User-Agent"] = "Prosody XMPP Server" }
+ local method, headers, body;
+ headers = {
+ ["Host"] = req.host;
+ ["User-Agent"] = "Prosody XMPP Server";
+ };
if req.userinfo then
- default_headers["Authorization"] = "Basic "..mime.b64(req.userinfo);
+ headers["Authorization"] = "Basic "..b64(req.userinfo);
end
-
+
if ex then
- custom_headers = ex.headers;
req.onlystatus = ex.onlystatus;
body = ex.body;
if body then
- req.method = "POST ";
- default_headers["Content-Length"] = tostring(#body);
- default_headers["Content-Type"] = "application/x-www-form-urlencoded";
+ method = "POST";
+ headers["Content-Length"] = tostring(#body);
+ headers["Content-Type"] = "application/x-www-form-urlencoded";
+ end
+ if ex.method then method = ex.method; end
+ if ex.headers then
+ for k, v in pairs(ex.headers) do
+ headers[k] = v;
+ end
end
- if ex.method then req.method = ex.method; end
- end
-
- req.handler, req.conn = server.wrapclient(socket.tcp(), req.host, req.port or 80, listener, "*a");
- req.write = req.handler.write;
- req.conn:settimeout(0);
- local ok, err = req.conn:connect(req.host, req.port or 80);
- if not ok and err ~= "timeout" then
- callback(nil, 0, req);
- return nil, err;
end
- local request_line = { req.method or "GET", " ", req.path, " HTTP/1.1\r\n" };
+ -- Attach to request object
+ req.method, req.headers, req.body = method, headers, body;
- if req.query then
- t_insert(request_line, 4, "?");
- t_insert(request_line, 5, req.query);
- end
-
- req.write(t_concat(request_line));
- local t = { [2] = ": ", [4] = "\r\n" };
- if custom_headers then
- for k, v in pairs(custom_headers) do
- t[1], t[3] = k, v;
- req.write(t_concat(t));
- default_headers[k] = nil;
- end
+ local using_https = req.scheme == "https";
+ if using_https and not ssl_available then
+ error("SSL not available, unable to contact https URL");
end
+ local port = tonumber(req.port) or (using_https and 443 or 80);
- for k, v in pairs(default_headers) do
- t[1], t[3] = k, v;
- req.write(t_concat(t));
- default_headers[k] = nil;
+ -- Connect the socket, and wrap it with net.server
+ local conn = socket.tcp();
+ conn:settimeout(10);
+ local ok, err = conn:connect(req.host, port);
+ if not ok and err ~= "timeout" then
+ callback(nil, 0, req);
+ return nil, err;
end
- req.write("\r\n");
- if body then
- req.write(body);
+ local sslctx = false;
+ if using_https then
+ sslctx = ex and ex.sslctx or { mode = "client", protocol = "sslv23", options = { "no_sslv2" } };
end
+
+ req.handler, req.conn = server.wrapclient(conn, req.host, port, listener, "*a", sslctx);
+ req.write = function (...) return req.handler:write(...); end
- req.callback = function (content, code, request) log("debug", "Calling callback, status %s", code or "---"); return select(2, xpcall(function () return callback(content, code, request) end, handleerr)); end
+ req.callback = function (content, code, request, response) log("debug", "Calling callback, status %s", code or "---"); return select(2, xpcall(function () return callback(content, code, request, response) end, handleerr)); end
req.reader = request_reader;
req.state = "status";
-
- listener.register_request(req.handler, req);
+ requests[req.handler] = req;
return req;
end
function destroy_request(request)
if request.conn then
- request.handler.close()
- listener.disconnect(request.conn, "closed");
+ request.conn = nil;
+ request.handler:close()
end
end
-_M.urlencode = urlencode;
+local urlencode, urldecode = util_http.urlencode, util_http.urldecode;
+local formencode, formdecode = util_http.formencode, util_http.formdecode;
+
+_M.urlencode, _M.urldecode = urlencode, urldecode;
+_M.formencode, _M.formdecode = formencode, formdecode;
return _M;
diff --git a/net/http/codes.lua b/net/http/codes.lua
new file mode 100644
index 00000000..0cadd079
--- /dev/null
+++ b/net/http/codes.lua
@@ -0,0 +1,67 @@
+
+local response_codes = {
+ -- Source: http://www.iana.org/assignments/http-status-codes
+ -- s/^\(\d*\)\s*\(.*\S\)\s*\[RFC.*\]\s*$/^I["\1"] = "\2";
+ [100] = "Continue";
+ [101] = "Switching Protocols";
+ [102] = "Processing";
+
+ [200] = "OK";
+ [201] = "Created";
+ [202] = "Accepted";
+ [203] = "Non-Authoritative Information";
+ [204] = "No Content";
+ [205] = "Reset Content";
+ [206] = "Partial Content";
+ [207] = "Multi-Status";
+ [208] = "Already Reported";
+ [226] = "IM Used";
+
+ [300] = "Multiple Choices";
+ [301] = "Moved Permanently";
+ [302] = "Found";
+ [303] = "See Other";
+ [304] = "Not Modified";
+ [305] = "Use Proxy";
+ -- The 306 status code was used in a previous version of [RFC2616], is no longer used, and the code is reserved.
+ [307] = "Temporary Redirect";
+
+ [400] = "Bad Request";
+ [401] = "Unauthorized";
+ [402] = "Payment Required";
+ [403] = "Forbidden";
+ [404] = "Not Found";
+ [405] = "Method Not Allowed";
+ [406] = "Not Acceptable";
+ [407] = "Proxy Authentication Required";
+ [408] = "Request Timeout";
+ [409] = "Conflict";
+ [410] = "Gone";
+ [411] = "Length Required";
+ [412] = "Precondition Failed";
+ [413] = "Request Entity Too Large";
+ [414] = "Request-URI Too Long";
+ [415] = "Unsupported Media Type";
+ [416] = "Requested Range Not Satisfiable";
+ [417] = "Expectation Failed";
+ [418] = "I'm a teapot";
+ [422] = "Unprocessable Entity";
+ [423] = "Locked";
+ [424] = "Failed Dependency";
+ -- The 425 status code is reserved for the WebDAV advanced collections expired proposal [RFC2817]
+ [426] = "Upgrade Required";
+
+ [500] = "Internal Server Error";
+ [501] = "Not Implemented";
+ [502] = "Bad Gateway";
+ [503] = "Service Unavailable";
+ [504] = "Gateway Timeout";
+ [505] = "HTTP Version Not Supported";
+ [506] = "Variant Also Negotiates"; -- Experimental
+ [507] = "Insufficient Storage";
+ [508] = "Loop Detected";
+ [510] = "Not Extended";
+};
+
+for k,v in pairs(response_codes) do response_codes[k] = k.." "..v; end
+return setmetatable(response_codes, { __index = function(t, k) return k.." Unassigned"; end })
diff --git a/net/http/parser.lua b/net/http/parser.lua
new file mode 100644
index 00000000..f9e6cea0
--- /dev/null
+++ b/net/http/parser.lua
@@ -0,0 +1,160 @@
+local tonumber = tonumber;
+local assert = assert;
+local url_parse = require "socket.url".parse;
+local urldecode = require "util.http".urldecode;
+
+local function preprocess_path(path)
+ path = urldecode((path:gsub("//+", "/")));
+ if path:sub(1,1) ~= "/" then
+ path = "/"..path;
+ end
+ local level = 0;
+ for component in path:gmatch("([^/]+)/") do
+ if component == ".." then
+ level = level - 1;
+ elseif component ~= "." then
+ level = level + 1;
+ end
+ if level < 0 then
+ return nil;
+ end
+ end
+ return path;
+end
+
+local httpstream = {};
+
+function httpstream.new(success_cb, error_cb, parser_type, options_cb)
+ local client = true;
+ if not parser_type or parser_type == "server" then client = false; else assert(parser_type == "client", "Invalid parser type"); end
+ local buf = "";
+ local chunked, chunk_size, chunk_start;
+ local state = nil;
+ local packet;
+ local len;
+ local have_body;
+ local error;
+ return {
+ feed = function(self, data)
+ if error then return nil, "parse has failed"; end
+ if not data then -- EOF
+ if state and client and not len then -- reading client body until EOF
+ packet.body = buf;
+ success_cb(packet);
+ elseif buf ~= "" then -- unexpected EOF
+ error = true; return error_cb();
+ end
+ return;
+ end
+ buf = buf..data;
+ while #buf > 0 do
+ if state == nil then -- read request
+ local index = buf:find("\r\n\r\n", nil, true);
+ if not index then return; end -- not enough data
+ local method, path, httpversion, status_code, reason_phrase;
+ local first_line;
+ local headers = {};
+ for line in buf:sub(1,index+1):gmatch("([^\r\n]+)\r\n") do -- parse request
+ if first_line then
+ local key, val = line:match("^([^%s:]+): *(.*)$");
+ if not key then error = true; return error_cb("invalid-header-line"); end -- TODO handle multi-line and invalid headers
+ key = key:lower();
+ headers[key] = headers[key] and headers[key]..","..val or val;
+ else
+ first_line = line;
+ if client then
+ httpversion, status_code, reason_phrase = line:match("^HTTP/(1%.[01]) (%d%d%d) (.*)$");
+ status_code = tonumber(status_code);
+ if not status_code then error = true; return error_cb("invalid-status-line"); end
+ have_body = not
+ ( (options_cb and options_cb().method == "HEAD")
+ or (status_code == 204 or status_code == 304 or status_code == 301)
+ or (status_code >= 100 and status_code < 200) );
+ else
+ method, path, httpversion = line:match("^(%w+) (%S+) HTTP/(1%.[01])$");
+ if not method then error = true; return error_cb("invalid-status-line"); end
+ end
+ end
+ end
+ if not first_line then error = true; return error_cb("invalid-status-line"); end
+ chunked = have_body and headers["transfer-encoding"] == "chunked";
+ len = tonumber(headers["content-length"]); -- TODO check for invalid len
+ if client then
+ -- FIXME handle '100 Continue' response (by skipping it)
+ if not have_body then len = 0; end
+ packet = {
+ code = status_code;
+ httpversion = httpversion;
+ headers = headers;
+ body = have_body and "" or nil;
+ -- COMPAT the properties below are deprecated
+ responseversion = httpversion;
+ responseheaders = headers;
+ };
+ else
+ local parsed_url;
+ if path:byte() == 47 then -- starts with /
+ local _path, _query = path:match("([^?]*).?(.*)");
+ if _query == "" then _query = nil; end
+ parsed_url = { path = _path, query = _query };
+ else
+ parsed_url = url_parse(path);
+ if not(parsed_url and parsed_url.path) then error = true; return error_cb("invalid-url"); end
+ end
+ path = preprocess_path(parsed_url.path);
+ headers.host = parsed_url.host or headers.host;
+
+ len = len or 0;
+ packet = {
+ method = method;
+ url = parsed_url;
+ path = path;
+ httpversion = httpversion;
+ headers = headers;
+ body = nil;
+ };
+ end
+ buf = buf:sub(index + 4);
+ state = true;
+ end
+ if state then -- read body
+ if client then
+ if chunked then
+ if not buf:find("\r\n", nil, true) then
+ return;
+ end -- not enough data
+ if not chunk_size then
+ chunk_size, chunk_start = buf:match("^(%x+)[^\r\n]*\r\n()");
+ chunk_size = chunk_size and tonumber(chunk_size, 16);
+ if not chunk_size then error = true; return error_cb("invalid-chunk-size"); end
+ end
+ if chunk_size == 0 and buf:find("\r\n\r\n", chunk_start-2, true) then
+ state, chunk_size = nil, nil;
+ buf = buf:gsub("^.-\r\n\r\n", ""); -- This ensure extensions and trailers are stripped
+ success_cb(packet);
+ elseif #buf - chunk_start + 2 >= chunk_size then -- we have a chunk
+ packet.body = packet.body..buf:sub(chunk_start, chunk_start + (chunk_size-1));
+ buf = buf:sub(chunk_start + chunk_size + 2);
+ chunk_size, chunk_start = nil, nil;
+ else -- Partial chunk remaining
+ break;
+ end
+ elseif len and #buf >= len then
+ packet.body, buf = buf:sub(1, len), buf:sub(len + 1);
+ state = nil; success_cb(packet);
+ else
+ break;
+ end
+ elseif #buf >= len then
+ packet.body, buf = buf:sub(1, len), buf:sub(len + 1);
+ state = nil; success_cb(packet);
+ else
+ break;
+ end
+ end
+ end
+ end;
+ };
+end
+
+return httpstream;
diff --git a/net/http/server.lua b/net/http/server.lua
new file mode 100644
index 00000000..dec7da19
--- /dev/null
+++ b/net/http/server.lua
@@ -0,0 +1,303 @@
+
+local t_insert, t_remove, t_concat = table.insert, table.remove, table.concat;
+local parser_new = require "net.http.parser".new;
+local events = require "util.events".new();
+local addserver = require "net.server".addserver;
+local log = require "util.logger".init("http.server");
+local os_date = os.date;
+local pairs = pairs;
+local s_upper = string.upper;
+local setmetatable = setmetatable;
+local xpcall = xpcall;
+local traceback = debug.traceback;
+local tostring = tostring;
+local codes = require "net.http.codes";
+
+local _M = {};
+
+local sessions = {};
+local listener = {};
+local hosts = {};
+local default_host;
+
+local function is_wildcard_event(event)
+ return event:sub(-2, -1) == "/*";
+end
+local function is_wildcard_match(wildcard_event, event)
+ return wildcard_event:sub(1, -2) == event:sub(1, #wildcard_event-1);
+end
+
+local recent_wildcard_events, max_cached_wildcard_events = {}, 10000;
+
+local event_map = events._event_map;
+setmetatable(events._handlers, {
+ -- Called when firing an event that doesn't exist (but may match a wildcard handler)
+ __index = function (handlers, curr_event)
+ if is_wildcard_event(curr_event) then return; end -- Wildcard events cannot be fired
+ -- Find all handlers that could match this event, sort them
+ -- and then put the array into handlers[curr_event] (and return it)
+ local matching_handlers_set = {};
+ local handlers_array = {};
+ for event, handlers_set in pairs(event_map) do
+ if event == curr_event or
+ is_wildcard_event(event) and is_wildcard_match(event, curr_event) then
+ for handler, priority in pairs(handlers_set) do
+ matching_handlers_set[handler] = { (select(2, event:gsub("/", "%1"))), is_wildcard_event(event) and 0 or 1, priority };
+ table.insert(handlers_array, handler);
+ end
+ end
+ end
+ if #handlers_array > 0 then
+ table.sort(handlers_array, function(b, a)
+ local a_score, b_score = matching_handlers_set[a], matching_handlers_set[b];
+ for i = 1, #a_score do
+ if a_score[i] ~= b_score[i] then -- If equal, compare next score value
+ return a_score[i] < b_score[i];
+ end
+ end
+ return false;
+ end);
+ else
+ handlers_array = false;
+ end
+ rawset(handlers, curr_event, handlers_array);
+ if not event_map[curr_event] then -- Only wildcard handlers match, if any
+ table.insert(recent_wildcard_events, curr_event);
+ if #recent_wildcard_events > max_cached_wildcard_events then
+ rawset(handlers, table.remove(recent_wildcard_events, 1), nil);
+ end
+ end
+ return handlers_array;
+ end;
+ __newindex = function (handlers, curr_event, handlers_array)
+ if handlers_array == nil
+ and is_wildcard_event(curr_event) then
+ -- Invalidate the indexes of all matching events
+ for event in pairs(handlers) do
+ if is_wildcard_match(curr_event, event) then
+ handlers[event] = nil;
+ end
+ end
+ end
+ rawset(handlers, curr_event, handlers_array);
+ end;
+});
+
+local handle_request;
+local _1, _2, _3;
+local function _handle_request() return handle_request(_1, _2, _3); end
+
+local last_err;
+local function _traceback_handler(err) last_err = err; log("error", "Traceback[httpserver]: %s", traceback(tostring(err), 2)); end
+events.add_handler("http-error", function (error)
+ return "Error processing request: "..codes[error.code]..". Check your error log for more information.";
+end, -1);
+
+function listener.onconnect(conn)
+ local secure = conn:ssl() and true or nil;
+ local pending = {};
+ local waiting = false;
+ local function process_next()
+ if waiting then log("debug", "can't process_next, waiting"); return; end
+ waiting = true;
+ while sessions[conn] and #pending > 0 do
+ local request = t_remove(pending);
+ --log("debug", "process_next: %s", request.path);
+ --handle_request(conn, request, process_next);
+ _1, _2, _3 = conn, request, process_next;
+ if not xpcall(_handle_request, _traceback_handler) then
+ conn:write("HTTP/1.0 500 Internal Server Error\r\n\r\n"..events.fire_event("http-error", { code = 500, private_message = last_err }));
+ conn:close();
+ end
+ end
+ --log("debug", "ready for more");
+ waiting = false;
+ end
+ local function success_cb(request)
+ --log("debug", "success_cb: %s", request.path);
+ if waiting then
+ log("error", "http connection handler is not reentrant: %s", request.path);
+ assert(false, "http connection handler is not reentrant");
+ end
+ request.secure = secure;
+ t_insert(pending, request);
+ process_next();
+ end
+ local function error_cb(err)
+ log("debug", "error_cb: %s", err or "<nil>");
+ -- FIXME don't close immediately, wait until we process current stuff
+ -- FIXME if err, send off a bad-request response
+ sessions[conn] = nil;
+ conn:close();
+ end
+ sessions[conn] = parser_new(success_cb, error_cb);
+end
+
+function listener.ondisconnect(conn)
+ local open_response = conn._http_open_response;
+ if open_response and open_response.on_destroy then
+ open_response.finished = true;
+ open_response:on_destroy();
+ end
+ sessions[conn] = nil;
+end
+
+function listener.onincoming(conn, data)
+ sessions[conn]:feed(data);
+end
+
+local headerfix = setmetatable({}, {
+ __index = function(t, k)
+ local v = "\r\n"..k:gsub("_", "-"):gsub("%f[%w].", s_upper)..": ";
+ t[k] = v;
+ return v;
+ end
+});
+
+function _M.hijack_response(response, listener)
+ error("TODO");
+end
+function handle_request(conn, request, finish_cb)
+ --log("debug", "handler: %s", request.path);
+ local headers = {};
+ for k,v in pairs(request.headers) do headers[k:gsub("-", "_")] = v; end
+ request.headers = headers;
+ request.conn = conn;
+
+ local date_header = os_date('!%a, %d %b %Y %H:%M:%S GMT'); -- FIXME use
+ local conn_header = request.headers.connection;
+ conn_header = conn_header and ","..conn_header:gsub("[ \t]", ""):lower().."," or ""
+ local httpversion = request.httpversion
+ local persistent = conn_header:find(",Keep-Alive,", 1, true)
+ or (httpversion == "1.1" and not conn_header:find(",close,", 1, true));
+
+ local response_conn_header;
+ if persistent then
+ response_conn_header = "Keep-Alive";
+ else
+ response_conn_header = httpversion == "1.1" and "close" or nil
+ end
+
+ local response = {
+ request = request;
+ status_code = 200;
+ headers = { date = date_header, connection = response_conn_header };
+ persistent = persistent;
+ conn = conn;
+ send = _M.send_response;
+ finish_cb = finish_cb;
+ };
+ conn._http_open_response = response;
+
+ local host = (request.headers.host or ""):match("[^:]+");
+
+ -- Some sanity checking
+ local err_code, err;
+ if not request.path then
+ err_code, err = 400, "Invalid path";
+ elseif not hosts[host] then
+ if hosts[default_host] then
+ host = default_host;
+ elseif host then
+ err_code, err = 404, "Unknown host: "..host;
+ else
+ err_code, err = 400, "Missing or invalid 'Host' header";
+ end
+ end
+
+ if err then
+ response.status_code = err_code;
+ response:send(events.fire_event("http-error", { code = err_code, message = err }));
+ return;
+ end
+
+ local event = request.method.." "..host..request.path:match("[^?]*");
+ local payload = { request = request, response = response };
+ --log("debug", "Firing event: %s", event);
+ local result = events.fire_event(event, payload);
+ if result ~= nil then
+ if result ~= true then
+ local body;
+ local result_type = type(result);
+ if result_type == "number" then
+ response.status_code = result;
+ if result >= 400 then
+ body = events.fire_event("http-error", { code = result });
+ end
+ elseif result_type == "string" then
+ body = result;
+ elseif result_type == "table" then
+ for k, v in pairs(result) do
+ if k ~= "headers" then
+ response[k] = v;
+ else
+ for header_name, header_value in pairs(v) do
+ response.headers[header_name] = header_value;
+ end
+ end
+ end
+ end
+ response:send(body);
+ end
+ return;
+ end
+
+ -- if handler not called, return 404
+ response.status_code = 404;
+ response:send(events.fire_event("http-error", { code = 404 }));
+end
+function _M.send_response(response, body)
+ if response.finished then return; end
+ response.finished = true;
+ response.conn._http_open_response = nil;
+
+ local status_line = "HTTP/"..response.request.httpversion.." "..(response.status or codes[response.status_code]);
+ local headers = response.headers;
+ body = body or response.body or "";
+ headers.content_length = #body;
+
+ local output = { status_line };
+ for k,v in pairs(headers) do
+ t_insert(output, headerfix[k]..v);
+ end
+ t_insert(output, "\r\n\r\n");
+ t_insert(output, body);
+
+ response.conn:write(t_concat(output));
+ if response.on_destroy then
+ response:on_destroy();
+ response.on_destroy = nil;
+ end
+ if response.persistent then
+ response:finish_cb();
+ else
+ response.conn:close();
+ end
+end
+function _M.add_handler(event, handler, priority)
+ events.add_handler(event, handler, priority);
+end
+function _M.remove_handler(event, handler)
+ events.remove_handler(event, handler);
+end
+
+function _M.listen_on(port, interface, ssl)
+ addserver(interface or "*", port, listener, "*a", ssl);
+end
+function _M.add_host(host)
+ hosts[host] = true;
+end
+function _M.remove_host(host)
+ hosts[host] = nil;
+end
+function _M.set_default_host(host)
+ default_host = host;
+end
+function _M.fire_event(event, ...)
+ return events.fire_event(event, ...);
+end
+
+_M.listener = listener;
+_M.codes = codes;
+_M._events = events;
+return _M;
diff --git a/net/httpclient_listener.lua b/net/httpclient_listener.lua
deleted file mode 100644
index 69b7946b..00000000
--- a/net/httpclient_listener.lua
+++ /dev/null
@@ -1,44 +0,0 @@
--- Prosody IM
--- Copyright (C) 2008-2009 Matthew Wild
--- Copyright (C) 2008-2009 Waqas Hussain
---
--- This project is MIT/X11 licensed. Please see the
--- COPYING file in the source package for more information.
---
-
-local log = require "util.logger".init("httpclient_listener");
-
-local connlisteners_register = require "net.connlisteners".register;
-
-local requests = {}; -- Open requests
-local buffers = {}; -- Buffers of partial lines
-
-local httpclient = { default_port = 80, default_mode = "*a" };
-
-function httpclient.listener(conn, data)
- local request = requests[conn];
-
- if not request then
- log("warn", "Received response from connection %s with no request attached!", tostring(conn));
- return;
- end
-
- if data and request.reader then
- request:reader(data);
- end
-end
-
-function httpclient.disconnect(conn, err)
- local request = requests[conn];
- if request then
- request:reader(nil);
- end
- requests[conn] = nil;
-end
-
-function httpclient.register_request(conn, req)
- log("debug", "Attaching request %s to connection %s", tostring(req.id or req), tostring(conn));
- requests[conn] = req;
-end
-
-connlisteners_register("httpclient", httpclient);
diff --git a/net/httpserver.lua b/net/httpserver.lua
index 57c8eede..7d574788 100644
--- a/net/httpserver.lua
+++ b/net/httpserver.lua
@@ -1,278 +1,15 @@
--- Prosody IM
--- Copyright (C) 2008-2009 Matthew Wild
--- Copyright (C) 2008-2009 Waqas Hussain
---
--- This project is MIT/X11 licensed. Please see the
--- COPYING file in the source package for more information.
---
-
-
-local socket = require "socket"
-local server = require "net.server"
-local url_parse = require "socket.url".parse;
-
-local connlisteners_start = require "net.connlisteners".start;
-local connlisteners_get = require "net.connlisteners".get;
-local listener;
-
-local t_insert, t_concat = table.insert, table.concat;
-local s_match, s_gmatch = string.match, string.gmatch;
-local tonumber, tostring, pairs, ipairs, type = tonumber, tostring, pairs, ipairs, type;
-
-local urlencode = function (s) return s and (s:gsub("%W", function (c) return string.format("%%%02x", c:byte()); end)); end
-
-local log = require "util.logger".init("httpserver");
-
-local http_servers = {};
+-- COMPAT w/pre-0.9
+local log = require "util.logger".init("net.httpserver");
+local traceback = debug.traceback;
module "httpserver"
-local default_handler;
-
-local function expectbody(reqt)
- return reqt.method == "POST";
-end
-
-local function send_response(request, response)
- -- Write status line
- local resp;
- if response.body then
- local body = tostring(response.body);
- log("debug", "Sending response to %s", request.id);
- resp = { "HTTP/1.0 ", response.status or "200 OK", "\r\n"};
- local h = response.headers;
- if h then
- for k, v in pairs(h) do
- t_insert(resp, k);
- t_insert(resp, ": ");
- t_insert(resp, v);
- t_insert(resp, "\r\n");
- end
- end
- if not (h and h["Content-Length"]) then
- t_insert(resp, "Content-Length: ");
- t_insert(resp, #body);
- t_insert(resp, "\r\n");
- end
- t_insert(resp, "\r\n");
-
- if request.method ~= "HEAD" then
- t_insert(resp, body);
- end
- else
- -- Response we have is just a string (the body)
- log("debug", "Sending response to %s: %s", request.id or "<none>", response or "<none>");
-
- resp = { "HTTP/1.0 200 OK\r\n" };
- t_insert(resp, "Connection: close\r\n");
- t_insert(resp, "Content-Length: ");
- t_insert(resp, #response);
- t_insert(resp, "\r\n\r\n");
-
- t_insert(resp, response);
- end
- request.write(t_concat(resp));
- if not request.stayopen then
- request:destroy();
- end
-end
-
-local function call_callback(request, err)
- if request.handled then return; end
- request.handled = true;
- local callback = request.callback;
- if not callback and request.path then
- local path = request.url.path;
- local base = path:match("^/([^/?]+)");
- if not base then
- base = path:match("^http://[^/?]+/([^/?]+)");
- end
-
- callback = (request.server and request.server.handlers[base]) or default_handler;
- if callback == default_handler then
- log("debug", "Default callback for this request (base: "..tostring(base)..")")
- end
- end
- if callback then
- if err then
- log("debug", "Request error: "..err);
- if not callback(nil, err, request) then
- destroy_request(request);
- end
- return;
- end
-
- local response = callback(request.method, request.body and t_concat(request.body), request);
- if response then
- if response == true and not request.destroyed then
- -- Keep connection open, we will reply later
- log("debug", "Request %s left open, on_destroy is %s", request.id, tostring(request.on_destroy));
- elseif response ~= true then
- -- Assume response
- send_response(request, response);
- destroy_request(request);
- end
- else
- log("debug", "Request handler provided no response, destroying request...");
- -- No response, close connection
- destroy_request(request);
- end
- end
-end
-
-local function request_reader(request, data, startpos)
- if not data then
- if request.body then
- call_callback(request);
- else
- -- Error.. connection was closed prematurely
- call_callback(request, "connection-closed");
- end
- -- Here we force a destroy... the connection is gone, so we can't reply later
- destroy_request(request);
- return;
- end
- if request.state == "body" then
- log("debug", "Reading body...")
- if not request.body then request.body = {}; request.havebodylength, request.bodylength = 0, tonumber(request.headers["content-length"]); end
- if startpos then
- data = data:sub(startpos, -1)
- end
- t_insert(request.body, data);
- if request.bodylength then
- request.havebodylength = request.havebodylength + #data;
- if request.havebodylength >= request.bodylength then
- -- We have the body
- call_callback(request);
- end
- end
- elseif request.state == "headers" then
- log("debug", "Reading headers...")
- local pos = startpos;
- local headers = request.headers or {};
- for line in data:gmatch("(.-)\r\n") do
- startpos = (startpos or 1) + #line + 2;
- local k, v = line:match("(%S+): (.+)");
- if k and v then
- headers[k:lower()] = v;
--- log("debug", "Header: "..k:lower().." = "..v);
- elseif #line == 0 then
- request.headers = headers;
- break;
- else
- log("debug", "Unhandled header line: "..line);
- end
- end
-
- if not expectbody(request) then
- call_callback(request);
- return;
- end
-
- -- Reached the end of the headers
- request.state = "body";
- if #data > startpos then
- return request_reader(request, data:sub(startpos, -1));
- end
- elseif request.state == "request" then
- log("debug", "Reading request line...")
- local method, path, http, linelen = data:match("^(%S+) (%S+) HTTP/(%S+)\r\n()", startpos);
- if not method then
- return call_callback(request, "invalid-status-line");
- end
-
- request.method, request.path, request.httpversion = method, path, http;
-
- request.url = url_parse(request.path);
-
- log("debug", method.." request for "..tostring(request.path) .. " on port "..request.handler.serverport());
-
- if request.onlystatus then
- if not call_callback(request) then
- return;
- end
- end
-
- request.state = "headers";
-
- if #data > linelen then
- return request_reader(request, data:sub(linelen, -1));
- end
- end
-end
-
--- The default handler for requests
-default_handler = function (method, body, request)
- log("debug", method.." request for "..tostring(request.path) .. " on port "..request.handler.serverport());
- return { status = "404 Not Found",
- headers = { ["Content-Type"] = "text/html" },
- body = "<html><head><title>Page Not Found</title></head><body>Not here :(</body></html>" };
-end
-
-
-function new_request(handler)
- return { handler = handler, conn = handler.socket,
- write = handler.write, state = "request",
- server = http_servers[handler.serverport()],
- send = send_response,
- destroy = destroy_request,
- id = tostring{}:match("%x+$")
- };
-end
-
-function destroy_request(request)
- log("debug", "Destroying request %s", request.id);
- listener = listener or connlisteners_get("httpserver");
- if not request.destroyed then
- request.destroyed = true;
- if request.on_destroy then
- log("debug", "Request has destroy callback");
- request.on_destroy(request);
- else
- log("debug", "Request has no destroy callback");
- end
- request.handler.close()
- if request.conn then
- listener.disconnect(request.conn, "closed");
- end
- end
-end
-
-function new(params)
- local http_server = http_servers[params.port];
- if not http_server then
- http_server = { handlers = {} };
- http_servers[params.port] = http_server;
- -- We weren't already listening on this port, so start now
- connlisteners_start("httpserver", params);
- end
- if params.base then
- http_server.handlers[params.base] = params.handler;
- end
-end
-
-function new_from_config(ports, default_base, handle_request)
- for _, options in ipairs(ports) do
- local port, base, ssl, interface = 5280, default_base, false, nil;
- if type(options) == "number" then
- port = options;
- elseif type(options) == "table" then
- port, base, ssl, interface = options.port or 5280, options.path or default_base, options.ssl or false, options.interface;
- elseif type(options) == "string" then
- base = options;
- end
-
- if ssl then
- ssl.mode = "server";
- ssl.protocol = "sslv23";
- end
-
- new{ port = port, base = base, handler = handle_request, ssl = ssl, type = (ssl and "ssl") or "tcp" }
- end
+function fail()
+ log("error", "Attempt to use legacy HTTP API. For more info see http://prosody.im/doc/developers/legacy_http");
+ log("error", "Legacy HTTP API usage, %s", traceback("", 2));
end
-_M.request_reader = request_reader;
-_M.send_response = send_response;
-_M.urlencode = urlencode;
+new, new_from_config = fail, fail;
+set_default_handler = fail;
return _M;
diff --git a/net/httpserver_listener.lua b/net/httpserver_listener.lua
deleted file mode 100644
index 455191fb..00000000
--- a/net/httpserver_listener.lua
+++ /dev/null
@@ -1,46 +0,0 @@
--- Prosody IM
--- Copyright (C) 2008-2009 Matthew Wild
--- Copyright (C) 2008-2009 Waqas Hussain
---
--- This project is MIT/X11 licensed. Please see the
--- COPYING file in the source package for more information.
---
-
-
-
-local connlisteners_register = require "net.connlisteners".register;
-local new_request = require "net.httpserver".new_request;
-local request_reader = require "net.httpserver".request_reader;
-
-local requests = {}; -- Open requests
-
-local httpserver = { default_port = 80, default_mode = "*a" };
-
-function httpserver.listener(conn, data)
- local request = requests[conn];
-
- if not request then
- request = new_request(conn);
- requests[conn] = request;
-
- -- If using HTTPS, request is secure
- if conn.ssl() then
- request.secure = true;
- end
- end
-
- if data then
- request_reader(request, data);
- end
-end
-
-function httpserver.disconnect(conn, err)
- local request = requests[conn];
- if request and not request.destroyed then
- request.conn = nil;
- request_reader(request, nil);
- end
- requests[conn] = nil;
-end
-
-connlisteners_register("httpserver", httpserver);
diff --git a/net/server.lua b/net/server.lua
index 966006c1..375e7081 100644
--- a/net/server.lua
+++ b/net/server.lua
@@ -1,893 +1,84 @@
---
--- server.lua by blastbeat of the luadch project
--- Re-used here under the MIT/X Consortium License
---
--- Modifications (C) 2008-2009 Matthew Wild, Waqas Hussain
---
-
--- // wrapping luadch stuff // --
-
-local use = function( what )
- return _G[ what ]
-end
-local clean = function( tbl )
- for i, k in pairs( tbl ) do
- tbl[ i ] = nil
- end
-end
-
-local log, table_concat = require ("util.logger").init("socket"), table.concat;
-local out_put = function (...) return log("debug", table_concat{...}); end
-local out_error = function (...) return log("warn", table_concat{...}); end
-local mem_free = collectgarbage
-
-----------------------------------// DECLARATION //--
-
---// constants //--
-
-local STAT_UNIT = 1 -- byte
-
---// lua functions //--
-
-local type = use "type"
-local pairs = use "pairs"
-local ipairs = use "ipairs"
-local tostring = use "tostring"
-local collectgarbage = use "collectgarbage"
-
---// lua libs //--
-
-local os = use "os"
-local table = use "table"
-local string = use "string"
-local coroutine = use "coroutine"
-
---// lua lib methods //--
-
-local os_time = os.time
-local os_difftime = os.difftime
-local table_concat = table.concat
-local table_remove = table.remove
-local string_len = string.len
-local string_sub = string.sub
-local coroutine_wrap = coroutine.wrap
-local coroutine_yield = coroutine.yield
-
---// extern libs //--
-
-local luasec = select( 2, pcall( require, "ssl" ) )
-local luasocket = require "socket"
-
---// extern lib methods //--
-
-local ssl_wrap = ( luasec and luasec.wrap )
-local socket_bind = luasocket.bind
-local socket_sleep = luasocket.sleep
-local socket_select = luasocket.select
-local ssl_newcontext = ( luasec and luasec.newcontext )
-
---// functions //--
-
-local id
-local loop
-local stats
-local idfalse
-local addtimer
-local closeall
-local addserver
-local getserver
-local wrapserver
-local getsettings
-local closesocket
-local removesocket
-local removeserver
-local changetimeout
-local wrapconnection
-local changesettings
-
---// tables //--
-
-local _server
-local _readlist
-local _timerlist
-local _sendlist
-local _socketlist
-local _closelist
-local _readtimes
-local _writetimes
-
---// simple data types //--
-
-local _
-local _readlistlen
-local _sendlistlen
-local _timerlistlen
-
-local _sendtraffic
-local _readtraffic
-
-local _selecttimeout
-local _sleeptime
-
-local _starttime
-local _currenttime
-
-local _maxsendlen
-local _maxreadlen
-
-local _checkinterval
-local _sendtimeout
-local _readtimeout
-
-local _cleanqueue
-
-local _timer
-
-local _maxclientsperserver
-
-----------------------------------// DEFINITION //--
-
-_server = { } -- key = port, value = table; list of listening servers
-_readlist = { } -- array with sockets to read from
-_sendlist = { } -- arrary with sockets to write to
-_timerlist = { } -- array of timer functions
-_socketlist = { } -- key = socket, value = wrapped socket (handlers)
-_readtimes = { } -- key = handler, value = timestamp of last data reading
-_writetimes = { } -- key = handler, value = timestamp of last data writing/sending
-_closelist = { } -- handlers to close
-
-_readlistlen = 0 -- length of readlist
-_sendlistlen = 0 -- length of sendlist
-_timerlistlen = 0 -- lenght of timerlist
-
-_sendtraffic = 0 -- some stats
-_readtraffic = 0
-
-_selecttimeout = 1 -- timeout of socket.select
-_sleeptime = 0 -- time to wait at the end of every loop
-
-_maxsendlen = 51000 * 1024 -- max len of send buffer
-_maxreadlen = 25000 * 1024 -- max len of read buffer
-
-_checkinterval = 1200000 -- interval in secs to check idle clients
-_sendtimeout = 60000 -- allowed send idle time in secs
-_readtimeout = 6 * 60 * 60 -- allowed read idle time in secs
-
-_cleanqueue = false -- clean bufferqueue after using
-
-_maxclientsperserver = 1000
-
-----------------------------------// PRIVATE //--
-
-wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxconnections, startssl ) -- this function wraps a server
-
- maxconnections = maxconnections or _maxclientsperserver
-
- local connections = 0
-
- local dispatch, disconnect = listeners.incoming or listeners.listener, listeners.disconnect
-
- local err
-
- local ssl = false
-
- if sslctx then
- ssl = true
- if not ssl_newcontext then
- out_error "luasec not found"
- ssl = false
- end
- if type( sslctx ) ~= "table" then
- out_error "server.lua: wrong server sslctx"
- ssl = false
- end
- local ctx;
- ctx, err = ssl_newcontext( sslctx )
- if not ctx then
- err = err or "wrong sslctx parameters"
- local file;
- file = err:match("^error loading (.-) %(");
- if file then
- if file == "private key" then
- file = sslctx.key or "your private key";
- elseif file == "certificate" then
- file = sslctx.certificate or "your certificate file";
- end
- local reason = err:match("%((.+)%)$") or "some reason";
- if reason == "Permission denied" then
- reason = "Check that the permissions allow Prosody to read this file.";
- elseif reason == "No such file or directory" then
- reason = "Check that the path is correct, and the file exists.";
- elseif reason == "system lib" then
- reason = "Previous error (see logs), or other system error.";
- else
- reason = "Reason: "..tostring(reason or "unknown"):lower();
- end
- log("error", "SSL/TLS: Failed to load %s: %s", file, reason);
- else
- log("error", "SSL/TLS: Error initialising for port %d: %s", serverport, err );
- end
- ssl = false
- end
- sslctx = ctx;
- end
- if not ssl then
- sslctx = false;
- if startssl then
- log("error", "Failed to listen on port %d due to SSL/TLS to SSL/TLS initialisation errors (see logs)", serverport )
- return nil, "Cannot start ssl, see log for details"
- end
- end
-
- local accept = socket.accept
-
- --// public methods of the object //--
-
- local handler = { }
-
- handler.shutdown = function( ) end
-
- handler.ssl = function( )
- return ssl
- end
- handler.remove = function( )
- connections = connections - 1
- end
- handler.close = function( )
- for _, handler in pairs( _socketlist ) do
- if handler.serverport == serverport then
- handler.disconnect( handler, "server closed" )
- handler.close( true )
- end
- end
- socket:close( )
- _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
- _readlistlen = removesocket( _readlist, socket, _readlistlen )
- _socketlist[ socket ] = nil
- handler = nil
- socket = nil
- mem_free( )
- out_put "server.lua: closed server handler and removed sockets from list"
- end
- handler.ip = function( )
- return ip
- end
- handler.serverport = function( )
- return serverport
- end
- handler.socket = function( )
- return socket
- end
- handler.readbuffer = function( )
- if connections > maxconnections then
- out_put( "server.lua: refused new client connection: server full" )
- return false
- end
- local client, err = accept( socket ) -- try to accept
- if client then
- local ip, clientport = client:getpeername( )
- client:settimeout( 0 )
- local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx, startssl ) -- wrap new client socket
- if err then -- error while wrapping ssl socket
- return false
- end
- connections = connections + 1
- out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport))
- return dispatch( handler )
- elseif err then -- maybe timeout or something else
- out_put( "server.lua: error with new client connection: ", tostring(err) )
- return false
- end
- end
- return handler
-end
-
-wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx, startssl ) -- this function wraps a client to a handler object
-
- socket:settimeout( 0 )
-
- --// local import of socket methods //--
-
- local send
- local receive
- local shutdown
-
- --// private closures of the object //--
-
- local ssl
-
- local dispatch = listeners.incoming or listeners.listener
- local disconnect = listeners.disconnect
-
- local bufferqueue = { } -- buffer array
- local bufferqueuelen = 0 -- end of buffer array
-
- local toclose
- local fatalerror
- local needtls
-
- local bufferlen = 0
-
- local noread = false
- local nosend = false
-
- local sendtraffic, readtraffic = 0, 0
-
- local maxsendlen = _maxsendlen
- local maxreadlen = _maxreadlen
-
- --// public methods of the object //--
-
- local handler = bufferqueue -- saves a table ^_^
-
- handler.dispatch = function( )
- return dispatch
- end
- handler.disconnect = function( )
- return disconnect
- end
- handler.setlistener = function( listeners )
- dispatch = listeners.incoming
- disconnect = listeners.disconnect
- end
- handler.getstats = function( )
- return readtraffic, sendtraffic
- end
- handler.ssl = function( )
- return ssl
- end
- handler.send = function( _, data, i, j )
- return send( socket, data, i, j )
- end
- handler.receive = function( pattern, prefix )
- return receive( socket, pattern, prefix )
- end
- handler.shutdown = function( pattern )
- return shutdown( socket, pattern )
- end
- handler.close = function( forced )
- if not handler then return true; end
- _readlistlen = removesocket( _readlist, socket, _readlistlen )
- _readtimes[ handler ] = nil
- if bufferqueuelen ~= 0 then
- if not ( forced or fatalerror ) then
- handler.sendbuffer( )
- if bufferqueuelen ~= 0 then -- try again...
- if handler then
- handler.write = nil -- ... but no further writing allowed
- end
- toclose = true
- return false
- end
- else
- send( socket, table_concat( bufferqueue, "", 1, bufferqueuelen ), 1, bufferlen ) -- forced send
- end
- end
- _ = shutdown and shutdown( socket )
- socket:close( )
- _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
- _socketlist[ socket ] = nil
- if handler then
- _writetimes[ handler ] = nil
- _closelist[ handler ] = nil
- handler = nil
- end
- socket = nil
- mem_free( )
- if server then
- server.remove( )
- end
- out_put "server.lua: closed client handler and removed socket from list"
- return true
- end
- handler.ip = function( )
- return ip
- end
- handler.serverport = function( )
- return serverport
- end
- handler.clientport = function( )
- return clientport
- end
- local write = function( data )
- bufferlen = bufferlen + string_len( data )
- if bufferlen > maxsendlen then
- _closelist[ handler ] = "send buffer exceeded" -- cannot close the client at the moment, have to wait to the end of the cycle
- handler.write = idfalse -- dont write anymore
- return false
- elseif socket and not _sendlist[ socket ] then
- _sendlistlen = _sendlistlen + 1
- _sendlist[ _sendlistlen ] = socket
- _sendlist[ socket ] = _sendlistlen
- end
- bufferqueuelen = bufferqueuelen + 1
- bufferqueue[ bufferqueuelen ] = data
- if handler then
- _writetimes[ handler ] = _writetimes[ handler ] or _currenttime
- end
- return true
- end
- handler.write = write
- handler.bufferqueue = function( )
- return bufferqueue
- end
- handler.socket = function( )
- return socket
- end
- handler.pattern = function( new )
- pattern = new or pattern
- return pattern
- end
- handler.setsend = function ( newsend )
- send = newsend or send
- return send
- end
- handler.bufferlen = function( readlen, sendlen )
- maxsendlen = sendlen or maxsendlen
- maxreadlen = readlen or maxreadlen
- return maxreadlen, maxsendlen
- end
- handler.lock = function( switch )
- if switch == true then
- handler.write = idfalse
- local tmp = _sendlistlen
- _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
- _writetimes[ handler ] = nil
- if _sendlistlen ~= tmp then
- nosend = true
- end
- tmp = _readlistlen
- _readlistlen = removesocket( _readlist, socket, _readlistlen )
- _readtimes[ handler ] = nil
- if _readlistlen ~= tmp then
- noread = true
- end
- elseif switch == false then
- handler.write = write
- if noread then
- noread = false
- _readlistlen = _readlistlen + 1
- _readlist[ socket ] = _readlistlen
- _readlist[ _readlistlen ] = socket
- _readtimes[ handler ] = _currenttime
- end
- if nosend then
- nosend = false
- write( "" )
- end
- end
- return noread, nosend
- end
- local _readbuffer = function( ) -- this function reads data
- local buffer, err, part = receive( socket, pattern ) -- receive buffer with "pattern"
- if not err or ( err == "timeout" or err == "wantread" ) then -- received something
- local buffer = buffer or part or ""
- local len = string_len( buffer )
- if len > maxreadlen then
- disconnect( handler, "receive buffer exceeded" )
- handler.close( true )
- return false
- end
- local count = len * STAT_UNIT
- readtraffic = readtraffic + count
- _readtraffic = _readtraffic + count
- _readtimes[ handler ] = _currenttime
- --out_put( "server.lua: read data '", buffer, "', error: ", err )
- return dispatch( handler, buffer, err )
- else -- connections was closed or fatal error
- out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " error: ", tostring(err) )
- fatalerror = true
- disconnect( handler, err )
- _ = handler and handler.close( )
- return false
- end
- end
- local _sendbuffer = function( ) -- this function sends data
- local buffer = table_concat( bufferqueue, "", 1, bufferqueuelen )
- local succ, err, byte = send( socket, buffer, 1, bufferlen )
- local count = ( succ or byte or 0 ) * STAT_UNIT
- sendtraffic = sendtraffic + count
- _sendtraffic = _sendtraffic + count
- _ = _cleanqueue and clean( bufferqueue )
- --out_put( "server.lua: sended '", buffer, "', bytes: ", tostring(succ), ", error: ", tostring(err), ", part: ", tostring(byte), ", to: ", tostring(ip), ":", tostring(clientport) )
- if succ then -- sending succesful
- bufferqueuelen = 0
- bufferlen = 0
- _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) -- delete socket from writelist
- _ = needtls and handler.starttls(true)
- _writetimes[ handler ] = nil
- _ = toclose and handler.close( )
- return true
- elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write
- buffer = string_sub( buffer, byte + 1, bufferlen ) -- new buffer
- bufferqueue[ 1 ] = buffer -- insert new buffer in queue
- bufferqueuelen = 1
- bufferlen = bufferlen - byte
- _writetimes[ handler ] = _currenttime
- return true
- else -- connection was closed during sending or fatal error
- out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " error: ", tostring(err) )
- fatalerror = true
- disconnect( handler, err )
- _ = handler and handler.close( )
- return false
- end
- end
-
- if sslctx then -- ssl?
- ssl = true
- local wrote
- local read
- local handshake = coroutine_wrap( function( client ) -- create handshake coroutine
- local err
- for i = 1, 10 do -- 10 handshake attemps
- _sendlistlen = ( wrote and removesocket( _sendlist, socket, _sendlistlen ) ) or _sendlistlen
- _readlistlen = ( read and removesocket( _readlist, socket, _readlistlen ) ) or _readlistlen
- read, wrote = nil, nil
- _, err = client:dohandshake( )
- if not err then
- out_put( "server.lua: ssl handshake done" )
- handler.readbuffer = _readbuffer -- when handshake is done, replace the handshake function with regular functions
- handler.sendbuffer = _sendbuffer
- -- return dispatch( handler )
- return true
- else
- out_put( "server.lua: error during ssl handshake: ", tostring(err) )
- if err == "wantwrite" and not wrote then
- _sendlistlen = _sendlistlen + 1
- _sendlist[ _sendlistlen ] = client
- wrote = true
- elseif err == "wantread" and not read then
- _readlistlen = _readlistlen + 1
- _readlist [ _readlistlen ] = client
- read = true
- else
- break;
- end
- --coroutine_yield( handler, nil, err ) -- handshake not finished
- coroutine_yield( )
- end
- end
- disconnect( handler, "ssl handshake failed" )
- _ = handler and handler.close( true ) -- forced disconnect
- return false -- handshake failed
- end
- )
- if startssl then -- ssl now?
- --out_put("server.lua: ", "starting ssl handshake")
- local err
- socket, err = ssl_wrap( socket, sslctx ) -- wrap socket
- if err then
- out_put( "server.lua: ssl error: ", tostring(err) )
- mem_free( )
- return nil, nil, err -- fatal error
- end
- socket:settimeout( 0 )
- handler.readbuffer = handshake
- handler.sendbuffer = handshake
- handshake( socket ) -- do handshake
- if not socket then
- return nil, nil, "ssl handshake failed";
- end
- else
- -- We're not automatically doing SSL, so we're not secure (yet)
- ssl = false
- handler.starttls = function( now )
- if not now then
- --out_put "server.lua: we need to do tls, but delaying until later"
- needtls = true
- return
- end
- --out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
- local oldsocket, err = socket
- socket, err = ssl_wrap( socket, sslctx ) -- wrap socket
- --out_put( "server.lua: sslwrapped socket is " .. tostring( socket ) )
- if err then
- out_put( "server.lua: error while starting tls on client: ", tostring(err) )
- return nil, err -- fatal error
- end
-
- socket:settimeout( 0 )
-
- -- add the new socket to our system
-
- send = socket.send
- receive = socket.receive
- shutdown = id
-
- _socketlist[ socket ] = handler
- _readlistlen = _readlistlen + 1
- _readlist[ _readlistlen ] = socket
- _readlist[ socket ] = _readlistlen
-
- -- remove traces of the old socket
-
- _readlistlen = removesocket( _readlist, oldsocket, _readlistlen )
- _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen )
- _socketlist[ oldsocket ] = nil
-
- handler.starttls = nil
- needtls = nil
-
- -- Secure now
- ssl = true
-
- handler.readbuffer = handshake
- handler.sendbuffer = handshake
- handshake( socket ) -- do handshake
- end
- handler.readbuffer = _readbuffer
- handler.sendbuffer = _sendbuffer
- end
- else -- normal connection
- ssl = false
- handler.readbuffer = _readbuffer
- handler.sendbuffer = _sendbuffer
- end
-
- send = socket.send
- receive = socket.receive
- shutdown = ( ssl and id ) or socket.shutdown
-
- _socketlist[ socket ] = handler
- _readlistlen = _readlistlen + 1
- _readlist[ _readlistlen ] = socket
- _readlist[ socket ] = _readlistlen
-
- return handler, socket
-end
-
-id = function( )
-end
-
-idfalse = function( )
- return false
-end
-
-removesocket = function( list, socket, len ) -- this function removes sockets from a list ( copied from copas )
- local pos = list[ socket ]
- if pos then
- list[ socket ] = nil
- local last = list[ len ]
- list[ len ] = nil
- if last ~= socket then
- list[ last ] = pos
- list[ pos ] = last
- end
- return len - 1
- end
- return len
-end
-
-closesocket = function( socket )
- _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
- _readlistlen = removesocket( _readlist, socket, _readlistlen )
- _socketlist[ socket ] = nil
- socket:close( )
- mem_free( )
-end
-
-----------------------------------// PUBLIC //--
-
-addserver = function( listeners, port, addr, pattern, sslctx, maxconnections, startssl ) -- this function provides a way for other scripts to reg a server
- local err
- --out_put("server.lua: autossl on ", port, " is ", startssl)
- if type( listeners ) ~= "table" then
- err = "invalid listener table"
- end
- if not type( port ) == "number" or not ( port >= 0 and port <= 65535 ) then
- err = "invalid port"
- elseif _server[ port ] then
- err = "listeners on port '" .. port .. "' already exist"
- elseif sslctx and not luasec then
- err = "luasec not found"
- end
- if err then
- out_error( "server.lua, port ", port, ": ", err )
- return nil, err
- end
- addr = addr or "*"
- local server, err = socket_bind( addr, port )
- if err then
- out_error( "server.lua, port ", port, ": ", err )
- return nil, err
- end
- local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, maxconnections, startssl ) -- wrap new server socket
- if not handler then
- server:close( )
- return nil, err
- end
- server:settimeout( 0 )
- _readlistlen = _readlistlen + 1
- _readlist[ _readlistlen ] = server
- _server[ port ] = handler
- _socketlist[ server ] = handler
- out_put( "server.lua: new server listener on '", addr, ":", port, "'" )
- return handler
-end
-
-getserver = function ( port )
- return _server[ port ];
-end
-
-removeserver = function( port )
- local handler = _server[ port ]
- if not handler then
- return nil, "no server found on port '" .. tostring( port ) "'"
- end
- handler.close( )
- _server[ port ] = nil
- return true
-end
-
-closeall = function( )
- for _, handler in pairs( _socketlist ) do
- handler.close( )
- _socketlist[ _ ] = nil
- end
- _readlistlen = 0
- _sendlistlen = 0
- _timerlistlen = 0
- _server = { }
- _readlist = { }
- _sendlist = { }
- _timerlist = { }
- _socketlist = { }
- mem_free( )
-end
-
-getsettings = function( )
- return _selecttimeout, _sleeptime, _maxsendlen, _maxreadlen, _checkinterval, _sendtimeout, _readtimeout, _cleanqueue, _maxclientsperserver
-end
-
-changesettings = function( new )
- if type( new ) ~= "table" then
- return nil, "invalid settings table"
- end
- _selecttimeout = tonumber( new.timeout ) or _selecttimeout
- _sleeptime = tonumber( new.sleeptime ) or _sleeptime
- _maxsendlen = tonumber( new.maxsendlen ) or _maxsendlen
- _maxreadlen = tonumber( new.maxreadlen ) or _maxreadlen
- _checkinterval = tonumber( new.checkinterval ) or _checkinterval
- _sendtimeout = tonumber( new.sendtimeout ) or _sendtimeout
- _readtimeout = tonumber( new.readtimeout ) or _readtimeout
- _cleanqueue = new.cleanqueue
- _maxclientsperserver = new._maxclientsperserver or _maxclientsperserver
- return true
-end
-
-addtimer = function( listener )
- if type( listener ) ~= "function" then
- return nil, "invalid listener function"
- end
- _timerlistlen = _timerlistlen + 1
- _timerlist[ _timerlistlen ] = listener
- return true
-end
-
-stats = function( )
- return _readtraffic, _sendtraffic, _readlistlen, _sendlistlen, _timerlistlen
-end
-
-local dontstop = true; -- thinking about tomorrow, ...
-
-setquitting = function (quit)
- dontstop = not quit;
- return;
-end
-
-loop = function( ) -- this is the main loop of the program
- while dontstop do
- local read, write, err = socket_select( _readlist, _sendlist, _selecttimeout )
- for i, socket in ipairs( write ) do -- send data waiting in writequeues
- local handler = _socketlist[ socket ]
- if handler then
- handler.sendbuffer( )
- else
- closesocket( socket )
- out_put "server.lua: found no handler and closed socket (writelist)" -- this should not happen
- end
- end
- for i, socket in ipairs( read ) do -- receive data
- local handler = _socketlist[ socket ]
- if handler then
- handler.readbuffer( )
- else
- closesocket( socket )
- out_put "server.lua: found no handler and closed socket (readlist)" -- this can happen
- end
- end
- for handler, err in pairs( _closelist ) do
- handler.disconnect( )( handler, err )
- handler.close( true ) -- forced disconnect
- end
- clean( _closelist )
- _currenttime = os_time( )
- if os_difftime( _currenttime - _timer ) >= 1 then
- for i = 1, _timerlistlen do
- _timerlist[ i ]( ) -- fire timers
- end
- _timer = _currenttime
- end
- socket_sleep( _sleeptime ) -- wait some time
- --collectgarbage( )
- end
- return "quitting"
-end
-
---// EXPERIMENTAL //--
-
-local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx, startssl )
- local handler = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx, startssl )
- _socketlist[ socket ] = handler
- _sendlistlen = _sendlistlen + 1
- _sendlist[ _sendlistlen ] = socket
- _sendlist[ socket ] = _sendlistlen
- return handler, socket
-end
-
-local addclient = function( address, port, listeners, pattern, sslctx, startssl )
- local client, err = luasocket.tcp( )
- if err then
- return nil, err
- end
- client:settimeout( 0 )
- _, err = client:connect( address, port )
- if err then -- try again
- local handler = wrapclient( client, address, port, listeners )
- else
- wrapconnection( nil, listeners, client, address, port, "clientport", pattern, sslctx, startssl )
- end
-end
-
---// EXPERIMENTAL //--
-
-----------------------------------// BEGIN //--
-
-use "setmetatable" ( _socketlist, { __mode = "k" } )
-use "setmetatable" ( _readtimes, { __mode = "k" } )
-use "setmetatable" ( _writetimes, { __mode = "k" } )
-
-_timer = os_time( )
-_starttime = os_time( )
-
-addtimer( function( )
- local difftime = os_difftime( _currenttime - _starttime )
- if difftime > _checkinterval then
- _starttime = _currenttime
- for handler, timestamp in pairs( _writetimes ) do
- if os_difftime( _currenttime - timestamp ) > _sendtimeout then
- --_writetimes[ handler ] = nil
- handler.disconnect( )( handler, "send timeout" )
- handler.close( true ) -- forced disconnect
- end
- end
- for handler, timestamp in pairs( _readtimes ) do
- if os_difftime( _currenttime - timestamp ) > _readtimeout then
- --_readtimes[ handler ] = nil
- handler.disconnect( )( handler, "read timeout" )
- handler.close( ) -- forced disconnect?
- end
- end
- end
- end
-)
-
-----------------------------------// PUBLIC INTERFACE //--
-
-return {
-
- addclient = addclient,
- wrapclient = wrapclient,
-
- loop = loop,
- stats = stats,
- closeall = closeall,
- addtimer = addtimer,
- addserver = addserver,
- getserver = getserver,
- getsettings = getsettings,
- setquitting = setquitting,
- removeserver = removeserver,
- changesettings = changesettings,
-}
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+local use_luaevent = prosody and require "core.configmanager".get("*", "use_libevent");
+
+if use_luaevent then
+ use_luaevent = pcall(require, "luaevent.core");
+ if not use_luaevent then
+ log("error", "libevent not found, falling back to select()");
+ end
+end
+
+local server;
+
+if use_luaevent then
+ server = require "net.server_event";
+
+ -- Overwrite signal.signal() because we need to ask libevent to
+ -- handle them instead
+ local ok, signal = pcall(require, "util.signal");
+ if ok and signal then
+ local _signal_signal = signal.signal;
+ function signal.signal(signal_id, handler)
+ if type(signal_id) == "string" then
+ signal_id = signal[signal_id:upper()];
+ end
+ if type(signal_id) ~= "number" then
+ return false, "invalid-signal";
+ end
+ return server.hook_signal(signal_id, handler);
+ end
+ end
+else
+ use_luaevent = false;
+ server = require "net.server_select";
+end
+
+if prosody then
+ local config_get = require "core.configmanager".get;
+ local defaults = {};
+ for k,v in pairs(server.cfg or server.getsettings()) do
+ defaults[k] = v;
+ end
+ local function load_config()
+ local settings = config_get("*", "network_settings") or {};
+ if use_luaevent then
+ local event_settings = {
+ ACCEPT_DELAY = settings.event_accept_retry_interval;
+ ACCEPT_QUEUE = settings.tcp_backlog;
+ CLEAR_DELAY = settings.event_clear_interval;
+ CONNECT_TIMEOUT = settings.connect_timeout;
+ DEBUG = settings.debug;
+ HANDSHAKE_TIMEOUT = settings.ssl_handshake_timeout;
+ MAX_CONNECTIONS = settings.max_connections;
+ MAX_HANDSHAKE_ATTEMPTS = settings.max_ssl_handshake_roundtrips;
+ MAX_READ_LENGTH = settings.max_receive_buffer_size;
+ MAX_SEND_LENGTH = settings.max_send_buffer_size;
+ READ_TIMEOUT = settings.read_timeout;
+ WRITE_TIMEOUT = settings.send_timeout;
+ };
+
+ for k,default in pairs(defaults) do
+ server.cfg[k] = event_settings[k] or default;
+ end
+ else
+ local select_settings = {};
+ for k,default in pairs(defaults) do
+ select_settings[k] = settings[k] or default;
+ end
+ server.changesettings(select_settings);
+ end
+ end
+ load_config();
+ prosody.events.add_handler("config-reloaded", load_config);
+end
+
+-- require "net.server" shall now forever return this,
+-- ie. server_select or server_event as chosen above.
+return server;
diff --git a/net/server_event.lua b/net/server_event.lua
new file mode 100644
index 00000000..5eae95a9
--- /dev/null
+++ b/net/server_event.lua
@@ -0,0 +1,872 @@
+--[[
+
+
+ server.lua based on lua/libevent by blastbeat
+
+ notes:
+ -- when using luaevent, never register 2 or more EV_READ at one socket, same for EV_WRITE
+ -- you cant even register a new EV_READ/EV_WRITE callback inside another one
+ -- to do some of the above, use timeout events or something what will called from outside
+ -- dont let garbagecollect eventcallbacks, as long they are running
+ -- when using luasec, there are 4 cases of timeout errors: wantread or wantwrite during reading or writing
+
+--]]
+
+local SCRIPT_NAME = "server_event.lua"
+local SCRIPT_VERSION = "0.05"
+local SCRIPT_AUTHOR = "blastbeat"
+local LAST_MODIFIED = "2009/11/20"
+
+local cfg = {
+ MAX_CONNECTIONS = 100000, -- max per server connections (use "ulimit -n" on *nix)
+ MAX_HANDSHAKE_ATTEMPTS= 1000, -- attempts to finish ssl handshake
+ HANDSHAKE_TIMEOUT = 60, -- timeout in seconds per handshake attempt
+ MAX_READ_LENGTH = 1024 * 1024 * 1024 * 1024, -- max bytes allowed to read from sockets
+ MAX_SEND_LENGTH = 1024 * 1024 * 1024 * 1024, -- max bytes size of write buffer (for writing on sockets)
+ ACCEPT_QUEUE = 128, -- might influence the length of the pending sockets queue
+ ACCEPT_DELAY = 10, -- seconds to wait until the next attempt of a full server to accept
+ READ_TIMEOUT = 60 * 60 * 6, -- timeout in seconds for read data from socket
+ WRITE_TIMEOUT = 180, -- timeout in seconds for write data on socket
+ CONNECT_TIMEOUT = 20, -- timeout in seconds for connection attempts
+ CLEAR_DELAY = 5, -- seconds to wait for clearing interface list (and calling ondisconnect listeners)
+ DEBUG = true, -- show debug messages
+}
+
+local function use(x) return rawget(_G, x); end
+local ipairs = use "ipairs"
+local string = use "string"
+local select = use "select"
+local require = use "require"
+local tostring = use "tostring"
+local coroutine = use "coroutine"
+local setmetatable = use "setmetatable"
+
+local t_insert = table.insert
+local t_concat = table.concat
+
+local ssl = use "ssl"
+local socket = use "socket" or require "socket"
+
+local log = require ("util.logger").init("socket")
+
+local function debug(...)
+ return log("debug", ("%s "):rep(select('#', ...)), ...)
+end
+local vdebug = debug;
+
+local bitor = ( function( ) -- thx Rici Lake
+ local hasbit = function( x, p )
+ return x % ( p + p ) >= p
+ end
+ return function( x, y )
+ local p = 1
+ local z = 0
+ local limit = x > y and x or y
+ while p <= limit do
+ if hasbit( x, p ) or hasbit( y, p ) then
+ z = z + p
+ end
+ p = p + p
+ end
+ return z
+ end
+end )( )
+
+local event = require "luaevent.core"
+local base = event.new( )
+local EV_READ = event.EV_READ
+local EV_WRITE = event.EV_WRITE
+local EV_TIMEOUT = event.EV_TIMEOUT
+local EV_SIGNAL = event.EV_SIGNAL
+
+local EV_READWRITE = bitor( EV_READ, EV_WRITE )
+
+local interfacelist = ( function( ) -- holds the interfaces for sockets
+ local array = { }
+ local len = 0
+ return function( method, arg )
+ if "add" == method then
+ len = len + 1
+ array[ len ] = arg
+ arg:_position( len )
+ return len
+ elseif "delete" == method then
+ if len <= 0 then
+ return nil, "array is already empty"
+ end
+ local position = arg:_position() -- get position in array
+ if position ~= len then
+ local interface = array[ len ] -- get last interface
+ array[ position ] = interface -- copy it into free position
+ array[ len ] = nil -- free last position
+ interface:_position( position ) -- set new position in array
+ else -- free last position
+ array[ len ] = nil
+ end
+ len = len - 1
+ return len
+ else
+ return array
+ end
+ end
+end )( )
+
+-- Client interface methods
+local interface_mt
+do
+ interface_mt = {}; interface_mt.__index = interface_mt;
+
+ local addevent = base.addevent
+ local coroutine_wrap, coroutine_yield = coroutine.wrap,coroutine.yield
+
+ -- Private methods
+ function interface_mt:_position(new_position)
+ self.position = new_position or self.position
+ return self.position;
+ end
+ function interface_mt:_close()
+ return self:_destroy();
+ end
+
+ function interface_mt:_start_connection(plainssl) -- should be called from addclient
+ local callback = function( event )
+ if EV_TIMEOUT == event then -- timeout during connection
+ self.fatalerror = "connection timeout"
+ self:ontimeout() -- call timeout listener
+ self:_close()
+ debug( "new connection failed. id:", self.id, "error:", self.fatalerror )
+ else
+ if plainssl and ssl then -- start ssl session
+ self:starttls(self._sslctx, true)
+ else -- normal connection
+ self:_start_session(true)
+ end
+ debug( "new connection established. id:", self.id )
+ end
+ self.eventconnect = nil
+ return -1
+ end
+ self.eventconnect = addevent( base, self.conn, EV_WRITE, callback, cfg.CONNECT_TIMEOUT )
+ return true
+ end
+ function interface_mt:_start_session(call_onconnect) -- new session, for example after startssl
+ if self.type == "client" then
+ local callback = function( )
+ self:_lock( false, false, false )
+ --vdebug( "start listening on client socket with id:", self.id )
+ self.eventread = addevent( base, self.conn, EV_READ, self.readcallback, cfg.READ_TIMEOUT ); -- register callback
+ if call_onconnect then
+ self:onconnect()
+ end
+ self.eventsession = nil
+ return -1
+ end
+ self.eventsession = addevent( base, nil, EV_TIMEOUT, callback, 0 )
+ else
+ self:_lock( false )
+ --vdebug( "start listening on server socket with id:", self.id )
+ self.eventread = addevent( base, self.conn, EV_READ, self.readcallback ) -- register callback
+ end
+ return true
+ end
+ function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed, therefore we have to close read/write events first
+ --vdebug( "starting ssl session with client id:", self.id )
+ local _
+ _ = self.eventread and self.eventread:close( ) -- close events; this must be called outside of the event callbacks!
+ _ = self.eventwrite and self.eventwrite:close( )
+ self.eventread, self.eventwrite = nil, nil
+ local err
+ self.conn, err = ssl.wrap( self.conn, self._sslctx )
+ if err then
+ self.fatalerror = err
+ self.conn = nil -- cannot be used anymore
+ if call_onconnect then
+ self.ondisconnect = nil -- dont call this when client isnt really connected
+ end
+ self:_close()
+ debug( "fatal error while ssl wrapping:", err )
+ return false
+ end
+ self.conn:settimeout( 0 ) -- set non blocking
+ local handshakecallback = coroutine_wrap(
+ function( event )
+ local _, err
+ local attempt = 0
+ local maxattempt = cfg.MAX_HANDSHAKE_ATTEMPTS
+ while attempt < maxattempt do -- no endless loop
+ attempt = attempt + 1
+ debug( "ssl handshake of client with id:"..tostring(self)..", attempt:"..attempt )
+ if attempt > maxattempt then
+ self.fatalerror = "max handshake attempts exceeded"
+ elseif EV_TIMEOUT == event then
+ self.fatalerror = "timeout during handshake"
+ else
+ _, err = self.conn:dohandshake( )
+ if not err then
+ self:_lock( false, false, false ) -- unlock the interface; sending, closing etc allowed
+ self.send = self.conn.send -- caching table lookups with new client object
+ self.receive = self.conn.receive
+ if not call_onconnect then -- trigger listener
+ self:onstatus("ssl-handshake-complete");
+ end
+ self:_start_session( call_onconnect )
+ debug( "ssl handshake done" )
+ self.eventhandshake = nil
+ return -1
+ end
+ if err == "wantwrite" then
+ event = EV_WRITE
+ elseif err == "wantread" then
+ event = EV_READ
+ else
+ debug( "ssl handshake error:", err )
+ self.fatalerror = err
+ end
+ end
+ if self.fatalerror then
+ if call_onconnect then
+ self.ondisconnect = nil -- dont call this when client isnt really connected
+ end
+ self:_close()
+ debug( "handshake failed because:", self.fatalerror )
+ self.eventhandshake = nil
+ return -1
+ end
+ event = coroutine_yield( event, cfg.HANDSHAKE_TIMEOUT ) -- yield this monster...
+ end
+ end
+ )
+ debug "starting handshake..."
+ self:_lock( false, true, true ) -- unlock read/write events, but keep interface locked
+ self.eventhandshake = addevent( base, self.conn, EV_READWRITE, handshakecallback, cfg.HANDSHAKE_TIMEOUT )
+ return true
+ end
+ function interface_mt:_destroy() -- close this interface + events and call last listener
+ debug( "closing client with id:", self.id, self.fatalerror )
+ self:_lock( true, true, true ) -- first of all, lock the interface to avoid further actions
+ local _
+ _ = self.eventread and self.eventread:close( )
+ if self.type == "client" then
+ _ = self.eventwrite and self.eventwrite:close( )
+ _ = self.eventhandshake and self.eventhandshake:close( )
+ _ = self.eventstarthandshake and self.eventstarthandshake:close( )
+ _ = self.eventconnect and self.eventconnect:close( )
+ _ = self.eventsession and self.eventsession:close( )
+ _ = self.eventwritetimeout and self.eventwritetimeout:close( )
+ _ = self.eventreadtimeout and self.eventreadtimeout:close( )
+ _ = self.ondisconnect and self:ondisconnect( self.fatalerror ~= "client to close" and self.fatalerror) -- call ondisconnect listener (wont be the case if handshake failed on connect)
+ _ = self.conn and self.conn:close( ) -- close connection
+ _ = self._server and self._server:counter(-1);
+ self.eventread, self.eventwrite = nil, nil
+ self.eventstarthandshake, self.eventhandshake, self.eventclose = nil, nil, nil
+ self.readcallback, self.writecallback = nil, nil
+ else
+ self.conn:close( )
+ self.eventread, self.eventclose = nil, nil
+ self.interface, self.readcallback = nil, nil
+ end
+ interfacelist( "delete", self )
+ return true
+ end
+
+ function interface_mt:_lock(nointerface, noreading, nowriting) -- lock or unlock this interface or events
+ self.nointerface, self.noreading, self.nowriting = nointerface, noreading, nowriting
+ return nointerface, noreading, nowriting
+ end
+
+ --TODO: Deprecate
+ function interface_mt:lock_read(switch)
+ if switch then
+ return self:pause();
+ else
+ return self:resume();
+ end
+ end
+
+ function interface_mt:pause()
+ return self:_lock(self.nointerface, true, self.nowriting);
+ end
+
+ function interface_mt:resume()
+ self:_lock(self.nointerface, false, self.nowriting);
+ if not self.eventread then
+ self.eventread = addevent( base, self.conn, EV_READ, self.readcallback, cfg.READ_TIMEOUT ); -- register callback
+ end
+ end
+
+ function interface_mt:counter(c)
+ if c then
+ self._connections = self._connections + c
+ end
+ return self._connections
+ end
+
+ -- Public methods
+ function interface_mt:write(data)
+ if self.nowriting then return nil, "locked" end
+ --vdebug( "try to send data to client, id/data:", self.id, data )
+ data = tostring( data )
+ local len = #data
+ local total = len + self.writebufferlen
+ if total > cfg.MAX_SEND_LENGTH then -- check buffer length
+ local err = "send buffer exceeded"
+ debug( "error:", err ) -- to much, check your app
+ return nil, err
+ end
+ t_insert(self.writebuffer, data) -- new buffer
+ self.writebufferlen = total
+ if not self.eventwrite then -- register new write event
+ --vdebug( "register new write event" )
+ self.eventwrite = addevent( base, self.conn, EV_WRITE, self.writecallback, cfg.WRITE_TIMEOUT )
+ end
+ return true
+ end
+ function interface_mt:close()
+ if self.nointerface then return nil, "locked"; end
+ debug( "try to close client connection with id:", self.id )
+ if self.type == "client" then
+ self.fatalerror = "client to close"
+ if self.eventwrite then -- wait for incomplete write request
+ self:_lock( true, true, false )
+ debug "closing delayed until writebuffer is empty"
+ return nil, "writebuffer not empty, waiting"
+ else -- close now
+ self:_lock( true, true, true )
+ self:_close()
+ return true
+ end
+ else
+ debug( "try to close server with id:", tostring(self.id))
+ self.fatalerror = "server to close"
+ self:_lock( true )
+ self:_close( 0 )
+ return true
+ end
+ end
+
+ function interface_mt:socket()
+ return self.conn
+ end
+
+ function interface_mt:server()
+ return self._server or self;
+ end
+
+ function interface_mt:port()
+ return self._port
+ end
+
+ function interface_mt:serverport()
+ return self._serverport
+ end
+
+ function interface_mt:ip()
+ return self._ip
+ end
+
+ function interface_mt:ssl()
+ return self._usingssl
+ end
+
+ function interface_mt:type()
+ return self._type or "client"
+ end
+
+ function interface_mt:connections()
+ return self._connections
+ end
+
+ function interface_mt:address()
+ return self.addr
+ end
+
+ function interface_mt:set_sslctx(sslctx)
+ self._sslctx = sslctx;
+ if sslctx then
+ self.starttls = nil; -- use starttls() of interface_mt
+ else
+ self.starttls = false; -- prevent starttls()
+ end
+ end
+
+ function interface_mt:set_mode(pattern)
+ if pattern then
+ self._pattern = pattern;
+ end
+ return self._pattern;
+ end
+
+ function interface_mt:set_send(new_send)
+ -- No-op, we always use the underlying connection's send
+ end
+
+ function interface_mt:starttls(sslctx, call_onconnect)
+ debug( "try to start ssl at client id:", self.id )
+ local err
+ self._sslctx = sslctx;
+ if self._usingssl then -- startssl was already called
+ err = "ssl already active"
+ end
+ if err then
+ debug( "error:", err )
+ return nil, err
+ end
+ self._usingssl = true
+ self.startsslcallback = function( ) -- we have to start the handshake outside of a read/write event
+ self.startsslcallback = nil
+ self:_start_ssl(call_onconnect);
+ self.eventstarthandshake = nil
+ return -1
+ end
+ if not self.eventwrite then
+ self:_lock( true, true, true ) -- lock the interface, to not disturb the handshake
+ self.eventstarthandshake = addevent( base, nil, EV_TIMEOUT, self.startsslcallback, 0 ) -- add event to start handshake
+ else -- wait until writebuffer is empty
+ self:_lock( true, true, false )
+ debug "ssl session delayed until writebuffer is empty..."
+ end
+ self.starttls = false;
+ return true
+ end
+
+ function interface_mt:setoption(option, value)
+ if self.conn.setoption then
+ return self.conn:setoption(option, value);
+ end
+ return false, "setoption not implemented";
+ end
+
+ function interface_mt:setlistener(listener)
+ self.onconnect, self.ondisconnect, self.onincoming, self.ontimeout, self.onstatus
+ = listener.onconnect, listener.ondisconnect, listener.onincoming, listener.ontimeout, listener.onstatus;
+ end
+
+ -- Stub handlers
+ function interface_mt:onconnect()
+ end
+ function interface_mt:onincoming()
+ end
+ function interface_mt:ondisconnect()
+ end
+ function interface_mt:ontimeout()
+ end
+ function interface_mt:ondrain()
+ end
+ function interface_mt:onstatus()
+ end
+end
+
+-- End of client interface methods
+
+local handleclient;
+do
+ local string_sub = string.sub -- caching table lookups
+ local addevent = base.addevent
+ local socket_gettime = socket.gettime
+ function handleclient( client, ip, port, server, pattern, listener, sslctx ) -- creates an client interface
+ --vdebug("creating client interfacce...")
+ local interface = {
+ type = "client";
+ conn = client;
+ currenttime = socket_gettime( ); -- safe the origin
+ writebuffer = {}; -- writebuffer
+ writebufferlen = 0; -- length of writebuffer
+ send = client.send; -- caching table lookups
+ receive = client.receive;
+ onconnect = listener.onconnect; -- will be called when client disconnects
+ ondisconnect = listener.ondisconnect; -- will be called when client disconnects
+ onincoming = listener.onincoming; -- will be called when client sends data
+ ontimeout = listener.ontimeout; -- called when fatal socket timeout occurs
+ onstatus = listener.onstatus; -- called for status changes (e.g. of SSL/TLS)
+ eventread = false, eventwrite = false, eventclose = false,
+ eventhandshake = false, eventstarthandshake = false; -- event handler
+ eventconnect = false, eventsession = false; -- more event handler...
+ eventwritetimeout = false; -- even more event handler...
+ eventreadtimeout = false;
+ fatalerror = false; -- error message
+ writecallback = false; -- will be called on write events
+ readcallback = false; -- will be called on read events
+ nointerface = true; -- lock/unlock parameter of this interface
+ noreading = false, nowriting = false; -- locks of the read/writecallback
+ startsslcallback = false; -- starting handshake callback
+ position = false; -- position of client in interfacelist
+
+ -- Properties
+ _ip = ip, _port = port, _server = server, _pattern = pattern,
+ _serverport = (server and server:port() or nil),
+ _sslctx = sslctx; -- parameters
+ _usingssl = false; -- client is using ssl;
+ }
+ if not ssl then interface.starttls = false; end
+ interface.id = tostring(interface):match("%x+$");
+ interface.writecallback = function( event ) -- called on write events
+ --vdebug( "new client write event, id/ip/port:", interface, ip, port )
+ if interface.nowriting or ( interface.fatalerror and ( "client to close" ~= interface.fatalerror ) ) then -- leave this event
+ --vdebug( "leaving this event because:", interface.nowriting or interface.fatalerror )
+ interface.eventwrite = false
+ return -1
+ end
+ if EV_TIMEOUT == event then -- took too long to write some data to socket -> disconnect
+ interface.fatalerror = "timeout during writing"
+ debug( "writing failed:", interface.fatalerror )
+ interface:_close()
+ interface.eventwrite = false
+ return -1
+ else -- can write :)
+ if interface._usingssl then -- handle luasec
+ if interface.eventreadtimeout then -- we have to read first
+ local ret = interface.readcallback( ) -- call readcallback
+ --vdebug( "tried to read in writecallback, result:", ret )
+ end
+ if interface.eventwritetimeout then -- luasec only
+ interface.eventwritetimeout:close( ) -- first we have to close timeout event which where regged after a wantread error
+ interface.eventwritetimeout = false
+ end
+ end
+ interface.writebuffer = { t_concat(interface.writebuffer) }
+ local succ, err, byte = interface.conn:send( interface.writebuffer[1], 1, interface.writebufferlen )
+ --vdebug( "write data:", interface.writebuffer, "error:", err, "part:", byte )
+ if succ then -- writing succesful
+ interface.writebuffer[1] = nil
+ interface.writebufferlen = 0
+ interface:ondrain();
+ if interface.fatalerror then
+ debug "closing client after writing"
+ interface:_close() -- close interface if needed
+ elseif interface.startsslcallback then -- start ssl connection if needed
+ debug "starting ssl handshake after writing"
+ interface.eventstarthandshake = addevent( base, nil, EV_TIMEOUT, interface.startsslcallback, 0 )
+ elseif interface.eventreadtimeout then
+ return EV_WRITE, EV_TIMEOUT
+ end
+ interface.eventwrite = nil
+ return -1
+ elseif byte and (err == "timeout" or err == "wantwrite") then -- want write again
+ --vdebug( "writebuffer is not empty:", err )
+ interface.writebuffer[1] = string_sub( interface.writebuffer[1], byte + 1, interface.writebufferlen ) -- new buffer
+ interface.writebufferlen = interface.writebufferlen - byte
+ if "wantread" == err then -- happens only with luasec
+ local callback = function( )
+ interface:_close()
+ interface.eventwritetimeout = nil
+ return -1;
+ end
+ interface.eventwritetimeout = addevent( base, nil, EV_TIMEOUT, callback, cfg.WRITE_TIMEOUT ) -- reg a new timeout event
+ debug( "wantread during write attempt, reg it in readcallback but dont know what really happens next..." )
+ -- hopefully this works with luasec; its simply not possible to use 2 different write events on a socket in luaevent
+ return -1
+ end
+ return EV_WRITE, cfg.WRITE_TIMEOUT
+ else -- connection was closed during writing or fatal error
+ interface.fatalerror = err or "fatal error"
+ debug( "connection failed in write event:", interface.fatalerror )
+ interface:_close()
+ interface.eventwrite = nil
+ return -1
+ end
+ end
+ end
+
+ interface.readcallback = function( event ) -- called on read events
+ --vdebug( "new client read event, id/ip/port:", tostring(interface.id), tostring(ip), tostring(port) )
+ if interface.noreading or interface.fatalerror then -- leave this event
+ --vdebug( "leaving this event because:", tostring(interface.noreading or interface.fatalerror) )
+ interface.eventread = nil
+ return -1
+ end
+ if EV_TIMEOUT == event then -- took too long to get some data from client -> disconnect
+ interface.fatalerror = "timeout during receiving"
+ debug( "connection failed:", interface.fatalerror )
+ interface:_close()
+ interface.eventread = nil
+ return -1
+ else -- can read
+ if interface._usingssl then -- handle luasec
+ if interface.eventwritetimeout then -- ok, in the past writecallback was regged
+ local ret = interface.writecallback( ) -- call it
+ --vdebug( "tried to write in readcallback, result:", tostring(ret) )
+ end
+ if interface.eventreadtimeout then
+ interface.eventreadtimeout:close( )
+ interface.eventreadtimeout = nil
+ end
+ end
+ local buffer, err, part = interface.conn:receive( interface._pattern ) -- receive buffer with "pattern"
+ --vdebug( "read data:", tostring(buffer), "error:", tostring(err), "part:", tostring(part) )
+ buffer = buffer or part
+ if buffer and #buffer > cfg.MAX_READ_LENGTH then -- check buffer length
+ interface.fatalerror = "receive buffer exceeded"
+ debug( "fatal error:", interface.fatalerror )
+ interface:_close()
+ interface.eventread = nil
+ return -1
+ end
+ if err and ( err ~= "timeout" and err ~= "wantread" ) then
+ if "wantwrite" == err then -- need to read on write event
+ if not interface.eventwrite then -- register new write event if needed
+ interface.eventwrite = addevent( base, interface.conn, EV_WRITE, interface.writecallback, cfg.WRITE_TIMEOUT )
+ end
+ interface.eventreadtimeout = addevent( base, nil, EV_TIMEOUT,
+ function( )
+ interface:_close()
+ end, cfg.READ_TIMEOUT
+ )
+ debug( "wantwrite during read attempt, reg it in writecallback but dont know what really happens next..." )
+ -- to be honest i dont know what happens next, if it is allowed to first read, the write etc...
+ else -- connection was closed or fatal error
+ interface.fatalerror = err
+ debug( "connection failed in read event:", interface.fatalerror )
+ interface:_close()
+ interface.eventread = nil
+ return -1
+ end
+ else
+ interface.onincoming( interface, buffer, err ) -- send new data to listener
+ end
+ if interface.noreading then
+ interface.eventread = nil;
+ return -1;
+ end
+ return EV_READ, cfg.READ_TIMEOUT
+ end
+ end
+
+ client:settimeout( 0 ) -- set non blocking
+ setmetatable(interface, interface_mt)
+ interfacelist( "add", interface ) -- add to interfacelist
+ return interface
+ end
+end
+
+local handleserver
+do
+ function handleserver( server, addr, port, pattern, listener, sslctx ) -- creates an server interface
+ debug "creating server interface..."
+ local interface = {
+ _connections = 0;
+
+ conn = server;
+ onconnect = listener.onconnect; -- will be called when new client connected
+ eventread = false; -- read event handler
+ eventclose = false; -- close event handler
+ readcallback = false; -- read event callback
+ fatalerror = false; -- error message
+ nointerface = true; -- lock/unlock parameter
+
+ _ip = addr, _port = port, _pattern = pattern,
+ _sslctx = sslctx;
+ }
+ interface.id = tostring(interface):match("%x+$");
+ interface.readcallback = function( event ) -- server handler, called on incoming connections
+ --vdebug( "server can accept, id/addr/port:", interface, addr, port )
+ if interface.fatalerror then
+ --vdebug( "leaving this event because:", self.fatalerror )
+ interface.eventread = nil
+ return -1
+ end
+ local delay = cfg.ACCEPT_DELAY
+ if EV_TIMEOUT == event then
+ if interface._connections >= cfg.MAX_CONNECTIONS then -- check connection count
+ debug( "to many connections, seconds to wait for next accept:", delay )
+ return EV_TIMEOUT, delay -- timeout...
+ else
+ return EV_READ -- accept again
+ end
+ end
+ --vdebug("max connection check ok, accepting...")
+ local client, err = server:accept() -- try to accept; TODO: check err
+ while client do
+ if interface._connections >= cfg.MAX_CONNECTIONS then
+ client:close( ) -- refuse connection
+ debug( "maximal connections reached, refuse client connection; accept delay:", delay )
+ return EV_TIMEOUT, delay -- delay for next accept attempt
+ end
+ local client_ip, client_port = client:getpeername( )
+ interface._connections = interface._connections + 1 -- increase connection count
+ local clientinterface = handleclient( client, client_ip, client_port, interface, pattern, listener, sslctx )
+ --vdebug( "client id:", clientinterface, "startssl:", startssl )
+ if ssl and sslctx then
+ clientinterface:starttls(sslctx, true)
+ else
+ clientinterface:_start_session( true )
+ end
+ debug( "accepted incoming client connection from:", client_ip or "<unknown IP>", client_port or "<unknown port>", "to", port or "<unknown port>");
+
+ client, err = server:accept() -- try to accept again
+ end
+ return EV_READ
+ end
+
+ server:settimeout( 0 )
+ setmetatable(interface, interface_mt)
+ interfacelist( "add", interface )
+ interface:_start_session()
+ return interface
+ end
+end
+
+local addserver = ( function( )
+ return function( addr, port, listener, pattern, sslcfg, startssl ) -- TODO: check arguments
+ --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslcfg or "nil", startssl or "nil")
+ local server, err = socket.bind( addr, port, cfg.ACCEPT_QUEUE ) -- create server socket
+ if not server then
+ debug( "creating server socket on "..addr.." port "..port.." failed:", err )
+ return nil, err
+ end
+ local sslctx
+ if sslcfg then
+ if not ssl then
+ debug "fatal error: luasec not found"
+ return nil, "luasec not found"
+ end
+ sslctx, err = sslcfg
+ if err then
+ debug( "error while creating new ssl context for server socket:", err )
+ return nil, err
+ end
+ end
+ local interface = handleserver( server, addr, port, pattern, listener, sslctx, startssl ) -- new server handler
+ debug( "new server created with id:", tostring(interface))
+ return interface
+ end
+end )( )
+
+local addclient, wrapclient
+do
+ function wrapclient( client, ip, port, listeners, pattern, sslctx )
+ local interface = handleclient( client, ip, port, nil, pattern, listeners, sslctx )
+ interface:_start_connection(sslctx)
+ return interface, client
+ --function handleclient( client, ip, port, server, pattern, listener, _, sslctx ) -- creates an client interface
+ end
+
+ function addclient( addr, serverport, listener, pattern, localaddr, localport, sslcfg, startssl )
+ local client, err = socket.tcp() -- creating new socket
+ if not client then
+ debug( "cannot create socket:", err )
+ return nil, err
+ end
+ client:settimeout( 0 ) -- set nonblocking
+ if localaddr then
+ local res, err = client:bind( localaddr, localport, -1 )
+ if not res then
+ debug( "cannot bind client:", err )
+ return nil, err
+ end
+ end
+ local sslctx
+ if sslcfg then -- handle ssl/new context
+ if not ssl then
+ debug "need luasec, but not available"
+ return nil, "luasec not found"
+ end
+ sslctx, err = sslcfg
+ if err then
+ debug( "cannot create new ssl context:", err )
+ return nil, err
+ end
+ end
+ local res, err = client:connect( addr, serverport ) -- connect
+ if res or ( err == "timeout" ) then
+ local ip, port = client:getsockname( )
+ local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx, startssl )
+ interface:_start_connection( startssl )
+ debug( "new connection id:", interface.id )
+ return interface, err
+ else
+ debug( "new connection failed:", err )
+ return nil, err
+ end
+ end
+end
+
+
+local loop = function( ) -- starts the event loop
+ base:loop( )
+ return "quitting";
+end
+
+local newevent = ( function( )
+ local add = base.addevent
+ return function( ... )
+ return add( base, ... )
+ end
+end )( )
+
+local closeallservers = function( arg )
+ for _, item in ipairs( interfacelist( ) ) do
+ if item.type == "server" then
+ item:close( arg )
+ end
+ end
+end
+
+local function setquitting(yes)
+ if yes then
+ -- Quit now
+ closeallservers();
+ base:loopexit();
+ end
+end
+
+local function get_backend()
+ return base:method();
+end
+
+-- We need to hold onto the events to stop them
+-- being garbage-collected
+local signal_events = {}; -- [signal_num] -> event object
+local function hook_signal(signal_num, handler)
+ local function _handler(event)
+ local ret = handler();
+ if ret ~= false then -- Continue handling this signal?
+ return EV_SIGNAL; -- Yes
+ end
+ return -1; -- Close this event
+ end
+ signal_events[signal_num] = base:addevent(signal_num, EV_SIGNAL, _handler);
+ return signal_events[signal_num];
+end
+
+local function link(sender, receiver, buffersize)
+ local sender_locked;
+
+ function receiver:ondrain()
+ if sender_locked then
+ sender:resume();
+ sender_locked = nil;
+ end
+ end
+
+ function sender:onincoming(data)
+ receiver:write(data);
+ if receiver.writebufferlen >= buffersize then
+ sender_locked = true;
+ sender:pause();
+ end
+ end
+end
+
+return {
+
+ cfg = cfg,
+ base = base,
+ loop = loop,
+ link = link,
+ event = event,
+ event_base = base,
+ addevent = newevent,
+ addserver = addserver,
+ addclient = addclient,
+ wrapclient = wrapclient,
+ setquitting = setquitting,
+ closeall = closeallservers,
+ get_backend = get_backend,
+ hook_signal = hook_signal,
+
+ __NAME = SCRIPT_NAME,
+ __DATE = LAST_MODIFIED,
+ __AUTHOR = SCRIPT_AUTHOR,
+ __VERSION = SCRIPT_VERSION,
+
+}
diff --git a/net/server_select.lua b/net/server_select.lua
new file mode 100644
index 00000000..7eb330a8
--- /dev/null
+++ b/net/server_select.lua
@@ -0,0 +1,984 @@
+--
+-- server.lua by blastbeat of the luadch project
+-- Re-used here under the MIT/X Consortium License
+--
+-- Modifications (C) 2008-2010 Matthew Wild, Waqas Hussain
+--
+
+-- // wrapping luadch stuff // --
+
+local use = function( what )
+ return _G[ what ]
+end
+
+local log, table_concat = require ("util.logger").init("socket"), table.concat;
+local out_put = function (...) return log("debug", table_concat{...}); end
+local out_error = function (...) return log("warn", table_concat{...}); end
+
+----------------------------------// DECLARATION //--
+
+--// constants //--
+
+local STAT_UNIT = 1 -- byte
+
+--// lua functions //--
+
+local type = use "type"
+local pairs = use "pairs"
+local ipairs = use "ipairs"
+local tonumber = use "tonumber"
+local tostring = use "tostring"
+
+--// lua libs //--
+
+local os = use "os"
+local table = use "table"
+local string = use "string"
+local coroutine = use "coroutine"
+
+--// lua lib methods //--
+
+local os_difftime = os.difftime
+local math_min = math.min
+local math_huge = math.huge
+local table_concat = table.concat
+local string_sub = string.sub
+local coroutine_wrap = coroutine.wrap
+local coroutine_yield = coroutine.yield
+
+--// extern libs //--
+
+local luasec = use "ssl"
+local luasocket = use "socket" or require "socket"
+local luasocket_gettime = luasocket.gettime
+
+--// extern lib methods //--
+
+local ssl_wrap = ( luasec and luasec.wrap )
+local socket_bind = luasocket.bind
+local socket_sleep = luasocket.sleep
+local socket_select = luasocket.select
+
+--// functions //--
+
+local id
+local loop
+local stats
+local idfalse
+local closeall
+local addsocket
+local addserver
+local addtimer
+local getserver
+local wrapserver
+local getsettings
+local closesocket
+local removesocket
+local removeserver
+local wrapconnection
+local changesettings
+
+--// tables //--
+
+local _server
+local _readlist
+local _timerlist
+local _sendlist
+local _socketlist
+local _closelist
+local _readtimes
+local _writetimes
+
+--// simple data types //--
+
+local _
+local _readlistlen
+local _sendlistlen
+local _timerlistlen
+
+local _sendtraffic
+local _readtraffic
+
+local _selecttimeout
+local _sleeptime
+local _tcpbacklog
+
+local _starttime
+local _currenttime
+
+local _maxsendlen
+local _maxreadlen
+
+local _checkinterval
+local _sendtimeout
+local _readtimeout
+
+local _timer
+
+local _maxselectlen
+local _maxfd
+
+local _maxsslhandshake
+
+----------------------------------// DEFINITION //--
+
+_server = { } -- key = port, value = table; list of listening servers
+_readlist = { } -- array with sockets to read from
+_sendlist = { } -- arrary with sockets to write to
+_timerlist = { } -- array of timer functions
+_socketlist = { } -- key = socket, value = wrapped socket (handlers)
+_readtimes = { } -- key = handler, value = timestamp of last data reading
+_writetimes = { } -- key = handler, value = timestamp of last data writing/sending
+_closelist = { } -- handlers to close
+
+_readlistlen = 0 -- length of readlist
+_sendlistlen = 0 -- length of sendlist
+_timerlistlen = 0 -- lenght of timerlist
+
+_sendtraffic = 0 -- some stats
+_readtraffic = 0
+
+_selecttimeout = 1 -- timeout of socket.select
+_sleeptime = 0 -- time to wait at the end of every loop
+_tcpbacklog = 128 -- some kind of hint to the OS
+
+_maxsendlen = 51000 * 1024 -- max len of send buffer
+_maxreadlen = 25000 * 1024 -- max len of read buffer
+
+_checkinterval = 1200000 -- interval in secs to check idle clients
+_sendtimeout = 60000 -- allowed send idle time in secs
+_readtimeout = 6 * 60 * 60 -- allowed read idle time in secs
+
+local is_windows = package.config:sub(1,1) == "\\" -- check the directory separator, to detemine whether this is Windows
+_maxfd = luasocket._SETSIZE or (is_windows and math.huge) or 1024 -- max fd number, limit to 1024 by default to prevent glibc buffer overflow, but not on Windows
+_maxselectlen = luasocket._SETSIZE or 1024 -- But this still applies on Windows
+
+_maxsslhandshake = 30 -- max handshake round-trips
+
+----------------------------------// PRIVATE //--
+
+wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- this function wraps a server -- FIXME Make sure FD < _maxfd
+
+ if socket:getfd() >= _maxfd then
+ out_error("server.lua: Disallowed FD number: "..socket:getfd())
+ socket:close()
+ return nil, "fd-too-large"
+ end
+
+ local connections = 0
+
+ local dispatch, disconnect = listeners.onconnect, listeners.ondisconnect
+
+ local accept = socket.accept
+
+ --// public methods of the object //--
+
+ local handler = { }
+
+ handler.shutdown = function( ) end
+
+ handler.ssl = function( )
+ return sslctx ~= nil
+ end
+ handler.sslctx = function( )
+ return sslctx
+ end
+ handler.remove = function( )
+ connections = connections - 1
+ if handler then
+ handler.resume( )
+ end
+ end
+ handler.close = function()
+ socket:close( )
+ _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
+ _readlistlen = removesocket( _readlist, socket, _readlistlen )
+ _server[ip..":"..serverport] = nil;
+ _socketlist[ socket ] = nil
+ handler = nil
+ socket = nil
+ --mem_free( )
+ out_put "server.lua: closed server handler and removed sockets from list"
+ end
+ handler.pause = function( hard )
+ if not handler.paused then
+ _readlistlen = removesocket( _readlist, socket, _readlistlen )
+ if hard then
+ _socketlist[ socket ] = nil
+ socket:close( )
+ socket = nil;
+ end
+ handler.paused = true;
+ end
+ end
+ handler.resume = function( )
+ if handler.paused then
+ if not socket then
+ socket = socket_bind( ip, serverport, _tcpbacklog );
+ socket:settimeout( 0 )
+ end
+ _readlistlen = addsocket(_readlist, socket, _readlistlen)
+ _socketlist[ socket ] = handler
+ handler.paused = false;
+ end
+ end
+ handler.ip = function( )
+ return ip
+ end
+ handler.serverport = function( )
+ return serverport
+ end
+ handler.socket = function( )
+ return socket
+ end
+ handler.readbuffer = function( )
+ if _readlistlen >= _maxselectlen or _sendlistlen >= _maxselectlen then
+ handler.pause( )
+ out_put( "server.lua: refused new client connection: server full" )
+ return false
+ end
+ local client, err = accept( socket ) -- try to accept
+ if client then
+ local ip, clientport = client:getpeername( )
+ local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx ) -- wrap new client socket
+ if err then -- error while wrapping ssl socket
+ return false
+ end
+ connections = connections + 1
+ out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport))
+ if dispatch and not sslctx then -- SSL connections will notify onconnect when handshake completes
+ return dispatch( handler );
+ end
+ return;
+ elseif err then -- maybe timeout or something else
+ out_put( "server.lua: error with new client connection: ", tostring(err) )
+ return false
+ end
+ end
+ return handler
+end
+
+wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx ) -- this function wraps a client to a handler object
+
+ if socket:getfd() >= _maxfd then
+ out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent
+ socket:close( ) -- Should we send some kind of error here?
+ server.pause( )
+ return nil, nil, "fd-too-large"
+ end
+ socket:settimeout( 0 )
+
+ --// local import of socket methods //--
+
+ local send
+ local receive
+ local shutdown
+
+ --// private closures of the object //--
+
+ local ssl
+
+ local dispatch = listeners.onincoming
+ local status = listeners.onstatus
+ local disconnect = listeners.ondisconnect
+ local drain = listeners.ondrain
+
+ local bufferqueue = { } -- buffer array
+ local bufferqueuelen = 0 -- end of buffer array
+
+ local toclose
+ local fatalerror
+ local needtls
+
+ local bufferlen = 0
+
+ local noread = false
+ local nosend = false
+
+ local sendtraffic, readtraffic = 0, 0
+
+ local maxsendlen = _maxsendlen
+ local maxreadlen = _maxreadlen
+
+ --// public methods of the object //--
+
+ local handler = bufferqueue -- saves a table ^_^
+
+ handler.dispatch = function( )
+ return dispatch
+ end
+ handler.disconnect = function( )
+ return disconnect
+ end
+ handler.setlistener = function( self, listeners )
+ dispatch = listeners.onincoming
+ disconnect = listeners.ondisconnect
+ status = listeners.onstatus
+ drain = listeners.ondrain
+ end
+ handler.getstats = function( )
+ return readtraffic, sendtraffic
+ end
+ handler.ssl = function( )
+ return ssl
+ end
+ handler.sslctx = function ( )
+ return sslctx
+ end
+ handler.send = function( _, data, i, j )
+ return send( socket, data, i, j )
+ end
+ handler.receive = function( pattern, prefix )
+ return receive( socket, pattern, prefix )
+ end
+ handler.shutdown = function( pattern )
+ return shutdown( socket, pattern )
+ end
+ handler.setoption = function (self, option, value)
+ if socket.setoption then
+ return socket:setoption(option, value);
+ end
+ return false, "setoption not implemented";
+ end
+ handler.force_close = function ( self, err )
+ if bufferqueuelen ~= 0 then
+ out_put("server.lua: discarding unwritten data for ", tostring(ip), ":", tostring(clientport))
+ bufferqueuelen = 0;
+ end
+ return self:close(err);
+ end
+ handler.close = function( self, err )
+ if not handler then return true; end
+ _readlistlen = removesocket( _readlist, socket, _readlistlen )
+ _readtimes[ handler ] = nil
+ if bufferqueuelen ~= 0 then
+ handler.sendbuffer() -- Try now to send any outstanding data
+ if bufferqueuelen ~= 0 then -- Still not empty, so we'll try again later
+ if handler then
+ handler.write = nil -- ... but no further writing allowed
+ end
+ toclose = true
+ return false
+ end
+ end
+ if socket then
+ _ = shutdown and shutdown( socket )
+ socket:close( )
+ _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
+ _socketlist[ socket ] = nil
+ socket = nil
+ else
+ out_put "server.lua: socket already closed"
+ end
+ if handler then
+ _writetimes[ handler ] = nil
+ _closelist[ handler ] = nil
+ local _handler = handler;
+ handler = nil
+ if disconnect then
+ disconnect(_handler, err or false);
+ disconnect = nil
+ end
+ end
+ if server then
+ server.remove( )
+ end
+ out_put "server.lua: closed client handler and removed socket from list"
+ return true
+ end
+ handler.ip = function( )
+ return ip
+ end
+ handler.serverport = function( )
+ return serverport
+ end
+ handler.clientport = function( )
+ return clientport
+ end
+ local write = function( self, data )
+ bufferlen = bufferlen + #data
+ if bufferlen > maxsendlen then
+ _closelist[ handler ] = "send buffer exceeded" -- cannot close the client at the moment, have to wait to the end of the cycle
+ handler.write = idfalse -- dont write anymore
+ return false
+ elseif socket and not _sendlist[ socket ] then
+ _sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
+ end
+ bufferqueuelen = bufferqueuelen + 1
+ bufferqueue[ bufferqueuelen ] = data
+ if handler then
+ _writetimes[ handler ] = _writetimes[ handler ] or _currenttime
+ end
+ return true
+ end
+ handler.write = write
+ handler.bufferqueue = function( self )
+ return bufferqueue
+ end
+ handler.socket = function( self )
+ return socket
+ end
+ handler.set_mode = function( self, new )
+ pattern = new or pattern
+ return pattern
+ end
+ handler.set_send = function ( self, newsend )
+ send = newsend or send
+ return send
+ end
+ handler.bufferlen = function( self, readlen, sendlen )
+ maxsendlen = sendlen or maxsendlen
+ maxreadlen = readlen or maxreadlen
+ return bufferlen, maxreadlen, maxsendlen
+ end
+ --TODO: Deprecate
+ handler.lock_read = function (self, switch)
+ if switch == true then
+ local tmp = _readlistlen
+ _readlistlen = removesocket( _readlist, socket, _readlistlen )
+ _readtimes[ handler ] = nil
+ if _readlistlen ~= tmp then
+ noread = true
+ end
+ elseif switch == false then
+ if noread then
+ noread = false
+ _readlistlen = addsocket(_readlist, socket, _readlistlen)
+ _readtimes[ handler ] = _currenttime
+ end
+ end
+ return noread
+ end
+ handler.pause = function (self)
+ return self:lock_read(true);
+ end
+ handler.resume = function (self)
+ return self:lock_read(false);
+ end
+ handler.lock = function( self, switch )
+ handler.lock_read (switch)
+ if switch == true then
+ handler.write = idfalse
+ local tmp = _sendlistlen
+ _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
+ _writetimes[ handler ] = nil
+ if _sendlistlen ~= tmp then
+ nosend = true
+ end
+ elseif switch == false then
+ handler.write = write
+ if nosend then
+ nosend = false
+ write( "" )
+ end
+ end
+ return noread, nosend
+ end
+ local _readbuffer = function( ) -- this function reads data
+ local buffer, err, part = receive( socket, pattern ) -- receive buffer with "pattern"
+ if not err or (err == "wantread" or err == "timeout") then -- received something
+ local buffer = buffer or part or ""
+ local len = #buffer
+ if len > maxreadlen then
+ handler:close( "receive buffer exceeded" )
+ return false
+ end
+ local count = len * STAT_UNIT
+ readtraffic = readtraffic + count
+ _readtraffic = _readtraffic + count
+ _readtimes[ handler ] = _currenttime
+ --out_put( "server.lua: read data '", buffer:gsub("[^%w%p ]", "."), "', error: ", err )
+ return dispatch( handler, buffer, err )
+ else -- connections was closed or fatal error
+ out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) )
+ fatalerror = true
+ _ = handler and handler:force_close( err )
+ return false
+ end
+ end
+ local _sendbuffer = function( ) -- this function sends data
+ local succ, err, byte, buffer, count;
+ if socket then
+ buffer = table_concat( bufferqueue, "", 1, bufferqueuelen )
+ succ, err, byte = send( socket, buffer, 1, bufferlen )
+ count = ( succ or byte or 0 ) * STAT_UNIT
+ sendtraffic = sendtraffic + count
+ _sendtraffic = _sendtraffic + count
+ for i = bufferqueuelen,1,-1 do
+ bufferqueue[ i ] = nil
+ end
+ --out_put( "server.lua: sended '", buffer, "', bytes: ", tostring(succ), ", error: ", tostring(err), ", part: ", tostring(byte), ", to: ", tostring(ip), ":", tostring(clientport) )
+ else
+ succ, err, count = false, "unexpected close", 0;
+ end
+ if succ then -- sending succesful
+ bufferqueuelen = 0
+ bufferlen = 0
+ _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) -- delete socket from writelist
+ _writetimes[ handler ] = nil
+ if drain then
+ drain(handler)
+ end
+ _ = needtls and handler:starttls(nil)
+ _ = toclose and handler:force_close( )
+ return true
+ elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write
+ buffer = string_sub( buffer, byte + 1, bufferlen ) -- new buffer
+ bufferqueue[ 1 ] = buffer -- insert new buffer in queue
+ bufferqueuelen = 1
+ bufferlen = bufferlen - byte
+ _writetimes[ handler ] = _currenttime
+ return true
+ else -- connection was closed during sending or fatal error
+ out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) )
+ fatalerror = true
+ _ = handler and handler:force_close( err )
+ return false
+ end
+ end
+
+ -- Set the sslctx
+ local handshake;
+ function handler.set_sslctx(self, new_sslctx)
+ sslctx = new_sslctx;
+ local read, wrote
+ handshake = coroutine_wrap( function( client ) -- create handshake coroutine
+ local err
+ for i = 1, _maxsslhandshake do
+ _sendlistlen = ( wrote and removesocket( _sendlist, client, _sendlistlen ) ) or _sendlistlen
+ _readlistlen = ( read and removesocket( _readlist, client, _readlistlen ) ) or _readlistlen
+ read, wrote = nil, nil
+ _, err = client:dohandshake( )
+ if not err then
+ out_put( "server.lua: ssl handshake done" )
+ handler.readbuffer = _readbuffer -- when handshake is done, replace the handshake function with regular functions
+ handler.sendbuffer = _sendbuffer
+ _ = status and status( handler, "ssl-handshake-complete" )
+ if self.autostart_ssl and listeners.onconnect then
+ listeners.onconnect(self);
+ end
+ _readlistlen = addsocket(_readlist, client, _readlistlen)
+ return true
+ else
+ if err == "wantwrite" then
+ _sendlistlen = addsocket(_sendlist, client, _sendlistlen)
+ wrote = true
+ elseif err == "wantread" then
+ _readlistlen = addsocket(_readlist, client, _readlistlen)
+ read = true
+ else
+ break;
+ end
+ err = nil;
+ coroutine_yield( ) -- handshake not finished
+ end
+ end
+ out_put( "server.lua: ssl handshake error: ", tostring(err or "handshake too long") )
+ _ = handler and handler:force_close("ssl handshake failed")
+ return false, err -- handshake failed
+ end
+ )
+ end
+ if luasec then
+ handler.starttls = function( self, _sslctx)
+ if _sslctx then
+ handler:set_sslctx(_sslctx);
+ end
+ if bufferqueuelen > 0 then
+ out_put "server.lua: we need to do tls, but delaying until send buffer empty"
+ needtls = true
+ return
+ end
+ out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
+ local oldsocket, err = socket
+ socket, err = ssl_wrap( socket, sslctx ) -- wrap socket
+ if not socket then
+ out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") )
+ return nil, err -- fatal error
+ end
+
+ socket:settimeout( 0 )
+
+ -- add the new socket to our system
+ send = socket.send
+ receive = socket.receive
+ shutdown = id
+ _socketlist[ socket ] = handler
+ _readlistlen = addsocket(_readlist, socket, _readlistlen)
+
+ -- remove traces of the old socket
+ _readlistlen = removesocket( _readlist, oldsocket, _readlistlen )
+ _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen )
+ _socketlist[ oldsocket ] = nil
+
+ handler.starttls = nil
+ needtls = nil
+
+ -- Secure now (if handshake fails connection will close)
+ ssl = true
+
+ handler.readbuffer = handshake
+ handler.sendbuffer = handshake
+ return handshake( socket ) -- do handshake
+ end
+ end
+
+ handler.readbuffer = _readbuffer
+ handler.sendbuffer = _sendbuffer
+ send = socket.send
+ receive = socket.receive
+ shutdown = ( ssl and id ) or socket.shutdown
+
+ _socketlist[ socket ] = handler
+ _readlistlen = addsocket(_readlist, socket, _readlistlen)
+
+ if sslctx and luasec then
+ out_put "server.lua: auto-starting ssl negotiation..."
+ handler.autostart_ssl = true;
+ local ok, err = handler:starttls(sslctx);
+ if ok == false then
+ return nil, nil, err
+ end
+ end
+
+ return handler, socket
+end
+
+id = function( )
+end
+
+idfalse = function( )
+ return false
+end
+
+addsocket = function( list, socket, len )
+ if not list[ socket ] then
+ len = len + 1
+ list[ len ] = socket
+ list[ socket ] = len
+ end
+ return len;
+end
+
+removesocket = function( list, socket, len ) -- this function removes sockets from a list ( copied from copas )
+ local pos = list[ socket ]
+ if pos then
+ list[ socket ] = nil
+ local last = list[ len ]
+ list[ len ] = nil
+ if last ~= socket then
+ list[ last ] = pos
+ list[ pos ] = last
+ end
+ return len - 1
+ end
+ return len
+end
+
+closesocket = function( socket )
+ _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
+ _readlistlen = removesocket( _readlist, socket, _readlistlen )
+ _socketlist[ socket ] = nil
+ socket:close( )
+ --mem_free( )
+end
+
+local function link(sender, receiver, buffersize)
+ local sender_locked;
+ local _sendbuffer = receiver.sendbuffer;
+ function receiver.sendbuffer()
+ _sendbuffer();
+ if sender_locked and receiver.bufferlen() < buffersize then
+ sender:lock_read(false); -- Unlock now
+ sender_locked = nil;
+ end
+ end
+
+ local _readbuffer = sender.readbuffer;
+ function sender.readbuffer()
+ _readbuffer();
+ if not sender_locked and receiver.bufferlen() >= buffersize then
+ sender_locked = true;
+ sender:lock_read(true);
+ end
+ end
+end
+
+----------------------------------// PUBLIC //--
+
+addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server
+ local err
+ if type( listeners ) ~= "table" then
+ err = "invalid listener table"
+ end
+ if type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then
+ err = "invalid port"
+ elseif _server[ addr..":"..port ] then
+ err = "listeners on '[" .. addr .. "]:" .. port .. "' already exist"
+ elseif sslctx and not luasec then
+ err = "luasec not found"
+ end
+ if err then
+ out_error( "server.lua, [", addr, "]:", port, ": ", err )
+ return nil, err
+ end
+ addr = addr or "*"
+ local server, err = socket_bind( addr, port, _tcpbacklog )
+ if err then
+ out_error( "server.lua, [", addr, "]:", port, ": ", err )
+ return nil, err
+ end
+ local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx ) -- wrap new server socket
+ if not handler then
+ server:close( )
+ return nil, err
+ end
+ server:settimeout( 0 )
+ _readlistlen = addsocket(_readlist, server, _readlistlen)
+ _server[ addr..":"..port ] = handler
+ _socketlist[ server ] = handler
+ out_put( "server.lua: new "..(sslctx and "ssl " or "").."server listener on '[", addr, "]:", port, "'" )
+ return handler
+end
+
+getserver = function ( addr, port )
+ return _server[ addr..":"..port ];
+end
+
+removeserver = function( addr, port )
+ local handler = _server[ addr..":"..port ]
+ if not handler then
+ return nil, "no server found on '[" .. addr .. "]:" .. tostring( port ) .. "'"
+ end
+ handler:close( )
+ _server[ addr..":"..port ] = nil
+ return true
+end
+
+closeall = function( )
+ for _, handler in pairs( _socketlist ) do
+ handler:close( )
+ _socketlist[ _ ] = nil
+ end
+ _readlistlen = 0
+ _sendlistlen = 0
+ _timerlistlen = 0
+ _server = { }
+ _readlist = { }
+ _sendlist = { }
+ _timerlist = { }
+ _socketlist = { }
+ --mem_free( )
+end
+
+getsettings = function( )
+ return {
+ select_timeout = _selecttimeout;
+ select_sleep_time = _sleeptime;
+ tcp_backlog = _tcpbacklog;
+ max_send_buffer_size = _maxsendlen;
+ max_receive_buffer_size = _maxreadlen;
+ select_idle_check_interval = _checkinterval;
+ send_timeout = _sendtimeout;
+ read_timeout = _readtimeout;
+ max_connections = _maxselectlen;
+ max_ssl_handshake_roundtrips = _maxsslhandshake;
+ highest_allowed_fd = _maxfd;
+ }
+end
+
+changesettings = function( new )
+ if type( new ) ~= "table" then
+ return nil, "invalid settings table"
+ end
+ _selecttimeout = tonumber( new.select_timeout ) or _selecttimeout
+ _sleeptime = tonumber( new.select_sleep_time ) or _sleeptime
+ _maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen
+ _maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen
+ _checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval
+ _tcpbacklog = tonumber( new.tcp_backlog ) or _tcpbacklog
+ _sendtimeout = tonumber( new.send_timeout ) or _sendtimeout
+ _readtimeout = tonumber( new.read_timeout ) or _readtimeout
+ _maxselectlen = new.max_connections or _maxselectlen
+ _maxsslhandshake = new.max_ssl_handshake_roundtrips or _maxsslhandshake
+ _maxfd = new.highest_allowed_fd or _maxfd
+ return true
+end
+
+addtimer = function( listener )
+ if type( listener ) ~= "function" then
+ return nil, "invalid listener function"
+ end
+ _timerlistlen = _timerlistlen + 1
+ _timerlist[ _timerlistlen ] = listener
+ return true
+end
+
+stats = function( )
+ return _readtraffic, _sendtraffic, _readlistlen, _sendlistlen, _timerlistlen
+end
+
+local quitting;
+
+local function setquitting(quit)
+ quitting = not not quit;
+end
+
+loop = function(once) -- this is the main loop of the program
+ if quitting then return "quitting"; end
+ if once then quitting = "once"; end
+ local next_timer_time = math_huge;
+ repeat
+ local read, write, err = socket_select( _readlist, _sendlist, math_min(_selecttimeout, next_timer_time) )
+ for i, socket in ipairs( write ) do -- send data waiting in writequeues
+ local handler = _socketlist[ socket ]
+ if handler then
+ handler.sendbuffer( )
+ else
+ closesocket( socket )
+ out_put "server.lua: found no handler and closed socket (writelist)" -- this should not happen
+ end
+ end
+ for i, socket in ipairs( read ) do -- receive data
+ local handler = _socketlist[ socket ]
+ if handler then
+ handler.readbuffer( )
+ else
+ closesocket( socket )
+ out_put "server.lua: found no handler and closed socket (readlist)" -- this can happen
+ end
+ end
+ for handler, err in pairs( _closelist ) do
+ handler.disconnect( )( handler, err )
+ handler:force_close() -- forced disconnect
+ _closelist[ handler ] = nil;
+ end
+ _currenttime = luasocket_gettime( )
+
+ -- Check for socket timeouts
+ local difftime = os_difftime( _currenttime - _starttime )
+ if difftime > _checkinterval then
+ _starttime = _currenttime
+ for handler, timestamp in pairs( _writetimes ) do
+ if os_difftime( _currenttime - timestamp ) > _sendtimeout then
+ --_writetimes[ handler ] = nil
+ handler.disconnect( )( handler, "send timeout" )
+ handler:force_close() -- forced disconnect
+ end
+ end
+ for handler, timestamp in pairs( _readtimes ) do
+ if os_difftime( _currenttime - timestamp ) > _readtimeout then
+ --_readtimes[ handler ] = nil
+ handler.disconnect( )( handler, "read timeout" )
+ handler:close( ) -- forced disconnect?
+ end
+ end
+ end
+
+ -- Fire timers
+ if _currenttime - _timer >= math_min(next_timer_time, 1) then
+ next_timer_time = math_huge;
+ for i = 1, _timerlistlen do
+ local t = _timerlist[ i ]( _currenttime ) -- fire timers
+ if t then next_timer_time = math_min(next_timer_time, t); end
+ end
+ _timer = _currenttime
+ else
+ next_timer_time = next_timer_time - (_currenttime - _timer);
+ end
+
+ -- wait some time (0 by default)
+ socket_sleep( _sleeptime )
+ until quitting;
+ if once and quitting == "once" then quitting = nil; return; end
+ return "quitting"
+end
+
+local function step()
+ return loop(true);
+end
+
+local function get_backend()
+ return "select";
+end
+
+--// EXPERIMENTAL //--
+
+local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx )
+ local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx )
+ if not handler then return nil, err end
+ _socketlist[ socket ] = handler
+ if not sslctx then
+ _sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
+ if listeners.onconnect then
+ -- When socket is writeable, call onconnect
+ local _sendbuffer = handler.sendbuffer;
+ handler.sendbuffer = function ()
+ _sendlistlen = removesocket( _sendlist, socket, _sendlistlen );
+ handler.sendbuffer = _sendbuffer;
+ listeners.onconnect(handler);
+ -- If there was data with the incoming packet, handle it now.
+ if #handler:bufferqueue() > 0 then
+ return _sendbuffer();
+ end
+ end
+ end
+ end
+ return handler, socket
+end
+
+local addclient = function( address, port, listeners, pattern, sslctx )
+ local client, err = luasocket.tcp( )
+ if err then
+ return nil, err
+ end
+ client:settimeout( 0 )
+ _, err = client:connect( address, port )
+ if err then -- try again
+ local handler = wrapclient( client, address, port, listeners )
+ else
+ wrapconnection( nil, listeners, client, address, port, "clientport", pattern, sslctx )
+ end
+end
+
+--// EXPERIMENTAL //--
+
+----------------------------------// BEGIN //--
+
+use "setmetatable" ( _socketlist, { __mode = "k" } )
+use "setmetatable" ( _readtimes, { __mode = "k" } )
+use "setmetatable" ( _writetimes, { __mode = "k" } )
+
+_timer = luasocket_gettime( )
+_starttime = luasocket_gettime( )
+
+local function setlogger(new_logger)
+ local old_logger = log;
+ if new_logger then
+ log = new_logger;
+ end
+ return old_logger;
+end
+
+----------------------------------// PUBLIC INTERFACE //--
+
+return {
+ _addtimer = addtimer,
+
+ addclient = addclient,
+ wrapclient = wrapclient,
+
+ loop = loop,
+ link = link,
+ step = step,
+ stats = stats,
+ closeall = closeall,
+ addserver = addserver,
+ getserver = getserver,
+ setlogger = setlogger,
+ getsettings = getsettings,
+ setquitting = setquitting,
+ removeserver = removeserver,
+ get_backend = get_backend,
+ changesettings = changesettings,
+}
diff --git a/net/xmppclient_listener.lua b/net/xmppclient_listener.lua
deleted file mode 100644
index dcc561f3..00000000
--- a/net/xmppclient_listener.lua
+++ /dev/null
@@ -1,152 +0,0 @@
--- Prosody IM
--- Copyright (C) 2008-2009 Matthew Wild
--- Copyright (C) 2008-2009 Waqas Hussain
---
--- This project is MIT/X11 licensed. Please see the
--- COPYING file in the source package for more information.
---
-
-
-
-local logger = require "logger";
-local log = logger.init("xmppclient_listener");
-local lxp = require "lxp"
-local init_xmlhandlers = require "core.xmlhandlers"
-local sm_new_session = require "core.sessionmanager".new_session;
-
-local connlisteners_register = require "net.connlisteners".register;
-
-local t_insert = table.insert;
-local t_concat = table.concat;
-local t_concatall = function (t, sep) local tt = {}; for _, s in ipairs(t) do t_insert(tt, tostring(s)); end return t_concat(tt, sep); end
-local m_random = math.random;
-local format = string.format;
-local sessionmanager = require "core.sessionmanager";
-local sm_new_session, sm_destroy_session = sessionmanager.new_session, sessionmanager.destroy_session;
-local sm_streamopened = sessionmanager.streamopened;
-local sm_streamclosed = sessionmanager.streamclosed;
-local st = require "util.stanza";
-
-local stream_callbacks = { stream_tag = "http://etherx.jabber.org/streams|stream",
- default_ns = "jabber:client",
- streamopened = sm_streamopened, streamclosed = sm_streamclosed, handlestanza = core_process_stanza };
-
-function stream_callbacks.error(session, error, data)
- if error == "no-stream" then
- session.log("debug", "Invalid opening stream header");
- session:close("invalid-namespace");
- elseif session.close then
- (session.log or log)("debug", "Client XML parse error: %s", tostring(error));
- session:close("xml-not-well-formed");
- end
-end
-
-local function handleerr(err) log("error", "Traceback[c2s]: %s: %s", tostring(err), debug.traceback()); end
-function stream_callbacks.handlestanza(a, b)
- xpcall(function () core_process_stanza(a, b) end, handleerr);
-end
-
-local sessions = {};
-local xmppclient = { default_port = 5222, default_mode = "*a" };
-
--- These are session methods --
-
-local function session_reset_stream(session)
- -- Reset stream
- local parser = lxp.new(init_xmlhandlers(session, stream_callbacks), "|");
- session.parser = parser;
-
- session.notopen = true;
-
- function session.data(conn, data)
- local ok, err = parser:parse(data);
- if ok then return; end
- log("debug", "Received invalid XML (%s) %d bytes: %s", tostring(err), #data, data:sub(1, 300):gsub("[\r\n]+", " "));
- session:close("xml-not-well-formed");
- end
-
- return true;
-end
-
-
-local stream_xmlns_attr = {xmlns='urn:ietf:params:xml:ns:xmpp-streams'};
-local default_stream_attr = { ["xmlns:stream"] = stream_callbacks.stream_tag:gsub("%|[^|]+$", ""), xmlns = stream_callbacks.default_ns, version = "1.0", id = "" };
-local function session_close(session, reason)
- local log = session.log or log;
- if session.conn then
- if session.notopen then
- session.send("<?xml version='1.0'?>");
- session.send(st.stanza("stream:stream", default_stream_attr):top_tag());
- end
- if reason then
- if type(reason) == "string" then -- assume stream error
- log("info", "Disconnecting client, <stream:error> is: %s", reason);
- session.send(st.stanza("stream:error"):tag(reason, {xmlns = 'urn:ietf:params:xml:ns:xmpp-streams' }));
- elseif type(reason) == "table" then
- if reason.condition then
- local stanza = st.stanza("stream:error"):tag(reason.condition, stream_xmlns_attr):up();
- if reason.text then
- stanza:tag("text", stream_xmlns_attr):text(reason.text):up();
- end
- if reason.extra then
- stanza:add_child(reason.extra);
- end
- log("info", "Disconnecting client, <stream:error> is: %s", tostring(stanza));
- session.send(stanza);
- elseif reason.name then -- a stanza
- log("info", "Disconnecting client, <stream:error> is: %s", tostring(reason));
- session.send(reason);
- end
- end
- end
- session.send("</stream:stream>");
- session.conn.close();
- xmppclient.disconnect(session.conn, (reason and (reason.text or reason.condition)) or reason or "session closed");
- end
-end
-
-
--- End of session methods --
-
-function xmppclient.listener(conn, data)
- local session = sessions[conn];
- if not session then
- session = sm_new_session(conn);
- sessions[conn] = session;
-
- -- Logging functions --
-
- local conn_name = "c2s"..tostring(conn):match("[a-f0-9]+$");
- session.log = logger.init(conn_name);
-
- session.log("info", "Client connected");
-
- -- Client is using legacy SSL (otherwise mod_tls sets this flag)
- if conn.ssl() then
- session.secure = true;
- end
-
- session.reset_stream = session_reset_stream;
- session.close = session_close;
-
- session_reset_stream(session); -- Initialise, ready for use
-
- session.dispatch_stanza = stream_callbacks.handlestanza;
- end
- if data then
- session.data(conn, data);
- end
-end
-
-function xmppclient.disconnect(conn, err)
- local session = sessions[conn];
- if session then
- (session.log or log)("info", "Client disconnected: %s", err);
- sm_destroy_session(session, err);
- sessions[conn] = nil;
- session = nil;
- collectgarbage("collect");
- end
-end
-
-connlisteners_register("xmppclient", xmppclient);
diff --git a/net/xmppcomponent_listener.lua b/net/xmppcomponent_listener.lua
deleted file mode 100644
index bee05967..00000000
--- a/net/xmppcomponent_listener.lua
+++ /dev/null
@@ -1,176 +0,0 @@
--- Prosody IM
--- Copyright (C) 2008-2009 Matthew Wild
--- Copyright (C) 2008-2009 Waqas Hussain
---
--- This project is MIT/X11 licensed. Please see the
--- COPYING file in the source package for more information.
---
-
-
-local hosts = _G.hosts;
-
-local t_concat = table.concat;
-
-local lxp = require "lxp";
-local logger = require "util.logger";
-local config = require "core.configmanager";
-local connlisteners = require "net.connlisteners";
-local cm_register_component = require "core.componentmanager".register_component;
-local cm_deregister_component = require "core.componentmanager".deregister_component;
-local uuid_gen = require "util.uuid".generate;
-local sha1 = require "util.hashes".sha1;
-local st = require "util.stanza";
-local init_xmlhandlers = require "core.xmlhandlers";
-
-local sessions = {};
-
-local log = logger.init("componentlistener");
-
-local component_listener = { default_port = 5347; default_mode = "*a"; default_interface = config.get("*", "core", "component_interface") or "127.0.0.1" };
-
-local xmlns_component = 'jabber:component:accept';
-
---- Callbacks/data for xmlhandlers to handle streams for us ---
-
-local stream_callbacks = { stream_tag = "http://etherx.jabber.org/streams|stream", default_ns = xmlns_component };
-
-function stream_callbacks.error(session, error, data, data2)
- log("warn", "Error processing component stream: "..tostring(error));
- if error == "no-stream" then
- session:close("invalid-namespace");
- elseif error == "xml-parse-error" and data == "unexpected-element-close" then
- session.log("warn", "Unexpected close of '%s' tag", data2);
- session:close("xml-not-well-formed");
- else
- session.log("warn", "External component %s XML parse error: %s", tostring(session.host), tostring(error));
- session:close("xml-not-well-formed");
- end
-end
-
-function stream_callbacks.streamopened(session, attr)
- if config.get(attr.to, "core", "component_module") ~= "component" then
- -- Trying to act as a component domain which
- -- hasn't been configured
- session:close{ condition = "host-unknown", text = tostring(attr.to).." does not match any configured external components" };
- return;
- end
-
- -- Store the original host (this is used for config, etc.)
- session.user = attr.to;
- -- Set the host for future reference
- session.host = config.get(attr.to, "core", "component_address") or attr.to;
- -- Note that we don't create the internal component
- -- until after the external component auths successfully
-
- session.streamid = uuid_gen();
- session.notopen = nil;
-
- session.send(st.stanza("stream:stream", { xmlns=xmlns_component,
- ["xmlns:stream"]='http://etherx.jabber.org/streams', id=session.streamid, from=session.host }):top_tag());
-
-end
-
-function stream_callbacks.streamclosed(session)
- session.send("</stream:stream>");
- session.notopen = true;
-end
-
-local core_process_stanza = core_process_stanza;
-
-function stream_callbacks.handlestanza(session, stanza)
- -- Namespaces are icky.
- if not stanza.attr.xmlns and stanza.name == "handshake" then
- stanza.attr.xmlns = xmlns_component;
- end
- return core_process_stanza(session, stanza);
-end
-
---- Closing a component connection
-local stream_xmlns_attr = {xmlns='urn:ietf:params:xml:ns:xmpp-streams'};
-local default_stream_attr = { ["xmlns:stream"] = stream_callbacks.stream_tag:gsub("%|[^|]+$", ""), xmlns = stream_callbacks.default_ns, version = "1.0", id = "" };
-local function session_close(session, reason)
- local log = session.log or log;
- if session.conn then
- if session.notopen then
- session.send("<?xml version='1.0'?>");
- session.send(st.stanza("stream:stream", default_stream_attr):top_tag());
- end
- if reason then
- if type(reason) == "string" then -- assume stream error
- log("info", "Disconnecting component, <stream:error> is: %s", reason);
- session.send(st.stanza("stream:error"):tag(reason, {xmlns = 'urn:ietf:params:xml:ns:xmpp-streams' }));
- elseif type(reason) == "table" then
- if reason.condition then
- local stanza = st.stanza("stream:error"):tag(reason.condition, stream_xmlns_attr):up();
- if reason.text then
- stanza:tag("text", stream_xmlns_attr):text(reason.text):up();
- end
- if reason.extra then
- stanza:add_child(reason.extra);
- end
- log("info", "Disconnecting component, <stream:error> is: %s", tostring(stanza));
- session.send(stanza);
- elseif reason.name then -- a stanza
- log("info", "Disconnecting component, <stream:error> is: %s", tostring(reason));
- session.send(reason);
- end
- end
- end
- session.send("</stream:stream>");
- session.conn.close();
- component_listener.disconnect(session.conn, "stream error");
- end
-end
-
---- Component connlistener
-function component_listener.listener(conn, data)
- local session = sessions[conn];
- if not session then
- local _send = conn.write;
- session = { type = "component", conn = conn, send = function (data) return _send(tostring(data)); end };
- sessions[conn] = session;
-
- -- Logging functions --
-
- local conn_name = "jcp"..tostring(conn):match("[a-f0-9]+$");
- session.log = logger.init(conn_name);
- session.close = session_close;
-
- session.log("info", "Incoming Jabber component connection");
-
- local parser = lxp.new(init_xmlhandlers(session, stream_callbacks), "|");
- session.parser = parser;
-
- session.notopen = true;
-
- function session.data(conn, data)
- local ok, err = parser:parse(data);
- if ok then return; end
- session:close("xml-not-well-formed");
- end
-
- session.dispatch_stanza = stream_callbacks.handlestanza;
-
- end
- if data then
- session.data(conn, data);
- end
-end
-
-function component_listener.disconnect(conn, err)
- local session = sessions[conn];
- if session then
- (session.log or log)("info", "component disconnected: %s (%s)", tostring(session.host), tostring(err));
- if session.host then
- log("debug", "Deregistering component");
- cm_deregister_component(session.host);
- hosts[session.host].connected = nil;
- end
- sessions[conn] = nil;
- for k in pairs(session) do session[k] = nil; end
- session = nil;
- collectgarbage("collect");
- end
-end
-
-connlisteners.register('xmppcomponent', component_listener);
diff --git a/net/xmppserver_listener.lua b/net/xmppserver_listener.lua
deleted file mode 100644
index 1f27d841..00000000
--- a/net/xmppserver_listener.lua
+++ /dev/null
@@ -1,174 +0,0 @@
--- Prosody IM
--- Copyright (C) 2008-2009 Matthew Wild
--- Copyright (C) 2008-2009 Waqas Hussain
---
--- This project is MIT/X11 licensed. Please see the
--- COPYING file in the source package for more information.
---
-
-
-
-local logger = require "logger";
-local log = logger.init("xmppserver_listener");
-local lxp = require "lxp"
-local init_xmlhandlers = require "core.xmlhandlers"
-local s2s_new_incoming = require "core.s2smanager".new_incoming;
-local s2s_streamopened = require "core.s2smanager".streamopened;
-local s2s_streamclosed = require "core.s2smanager".streamclosed;
-local s2s_destroy_session = require "core.s2smanager".destroy_session;
-local s2s_attempt_connect = require "core.s2smanager".attempt_connection;
-local stream_callbacks = { stream_tag = "http://etherx.jabber.org/streams|stream",
- default_ns = "jabber:server",
- streamopened = s2s_streamopened, streamclosed = s2s_streamclosed, handlestanza = core_process_stanza };
-
-function stream_callbacks.error(session, error, data)
- if error == "no-stream" then
- session:close("invalid-namespace");
- else
- session.log("debug", "Server-to-server XML parse error: %s", tostring(error));
- session:close("xml-not-well-formed");
- end
-end
-
-local function handleerr(err) log("error", "Traceback[s2s]: %s: %s", tostring(err), debug.traceback()); end
-function stream_callbacks.handlestanza(a, b)
- xpcall(function () core_process_stanza(a, b) end, handleerr);
-end
-
-local connlisteners_register = require "net.connlisteners".register;
-
-local t_insert = table.insert;
-local t_concat = table.concat;
-local t_concatall = function (t, sep) local tt = {}; for _, s in ipairs(t) do t_insert(tt, tostring(s)); end return t_concat(tt, sep); end
-local m_random = math.random;
-local format = string.format;
-local sessionmanager = require "core.sessionmanager";
-local sm_new_session, sm_destroy_session = sessionmanager.new_session, sessionmanager.destroy_session;
-local st = require "util.stanza";
-
-local sessions = {};
-local xmppserver = { default_port = 5269, default_mode = "*a" };
-
--- These are session methods --
-
-local function session_reset_stream(session)
- -- Reset stream
- local parser = lxp.new(init_xmlhandlers(session, stream_callbacks), "|");
- session.parser = parser;
-
- session.notopen = true;
-
- function session.data(conn, data)
- local ok, err = parser:parse(data);
- if ok then return; end
- log("debug", "Received invalid XML (%s) %d bytes: %s", tostring(err), #data, data:sub(1, 300):gsub("[\r\n]+", " "));
- session:close("xml-not-well-formed");
- end
-
- return true;
-end
-
-
-local stream_xmlns_attr = {xmlns='urn:ietf:params:xml:ns:xmpp-streams'};
-local default_stream_attr = { ["xmlns:stream"] = stream_callbacks.stream_tag:gsub("%|[^|]+$", ""), xmlns = stream_callbacks.default_ns, version = "1.0", id = "" };
-local function session_close(session, reason)
- local log = session.log or log;
- if session.conn then
- if session.notopen then
- session.sends2s("<?xml version='1.0'?>");
- session.sends2s(st.stanza("stream:stream", default_stream_attr):top_tag());
- end
- if reason then
- if type(reason) == "string" then -- assume stream error
- log("info", "Disconnecting %s[%s], <stream:error> is: %s", session.host or "(unknown host)", session.type, reason);
- session.sends2s(st.stanza("stream:error"):tag(reason, {xmlns = 'urn:ietf:params:xml:ns:xmpp-streams' }));
- elseif type(reason) == "table" then
- if reason.condition then
- local stanza = st.stanza("stream:error"):tag(reason.condition, stream_xmlns_attr):up();
- if reason.text then
- stanza:tag("text", stream_xmlns_attr):text(reason.text):up();
- end
- if reason.extra then
- stanza:add_child(reason.extra);
- end
- log("info", "Disconnecting %s[%s], <stream:error> is: %s", session.host or "(unknown host)", session.type, tostring(stanza));
- session.sends2s(stanza);
- elseif reason.name then -- a stanza
- log("info", "Disconnecting %s->%s[%s], <stream:error> is: %s", session.from_host or "(unknown host)", session.to_host or "(unknown host)", session.type, tostring(reason));
- session.sends2s(reason);
- end
- end
- end
- session.sends2s("</stream:stream>");
- session.conn.close();
- xmppserver.disconnect(session.conn, "stream error");
- end
-end
-
-
--- End of session methods --
-
-function xmppserver.listener(conn, data)
- local session = sessions[conn];
- if not session then
- session = s2s_new_incoming(conn);
- sessions[conn] = session;
-
- -- Logging functions --
-
-
- local conn_name = "s2sin"..tostring(conn):match("[a-f0-9]+$");
- session.log = logger.init(conn_name);
-
- session.log("info", "Incoming s2s connection");
-
- session.reset_stream = session_reset_stream;
- session.close = session_close;
-
- session_reset_stream(session); -- Initialise, ready for use
-
- session.dispatch_stanza = stream_callbacks.handlestanza;
- end
- if data then
- session.data(conn, data);
- end
-end
-
-function xmppserver.disconnect(conn, err)
- local session = sessions[conn];
- if session then
- if err and err ~= "closed" and session.srv_hosts then
- (session.log or log)("debug", "s2s connection closed unexpectedly");
- if s2s_attempt_connect(session, err) then
- (session.log or log)("debug", "...so we're going to try again");
- return; -- Session lives for now
- end
- end
- (session.log or log)("info", "s2s disconnected: %s->%s (%s)", tostring(session.from_host), tostring(session.to_host), tostring(err));
- s2s_destroy_session(session);
- sessions[conn] = nil;
- session = nil;
- collectgarbage("collect");
- end
-end
-
-function xmppserver.register_outgoing(conn, session)
- session.direction = "outgoing";
- sessions[conn] = session;
-
- session.reset_stream = session_reset_stream;
- session.close = session_close;
- session_reset_stream(session); -- Initialise, ready for use
-
- --local function handleerr(err) print("Traceback:", err, debug.traceback()); end
- --session.stanza_dispatch = function (stanza) return select(2, xpcall(function () return core_process_stanza(session, stanza); end, handleerr)); end
-end
-
-connlisteners_register("xmppserver", xmppserver);
-
-
--- We need to perform some initialisation when a connection is created
--- We also need to perform that same initialisation at other points (SASL, TLS, ...)
-
--- ...and we need to handle data
--- ...and record all sessions associated with connections