aboutsummaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
Diffstat (limited to 'net')
-rw-r--r--net/adns.lua1
-rw-r--r--net/connect.lua89
-rw-r--r--net/connlisteners.lua4
-rw-r--r--net/cqueues.lua74
-rw-r--r--net/dns.lua112
-rw-r--r--net/http.lua156
-rw-r--r--net/http/codes.lua106
-rw-r--r--net/http/server.lua36
-rw-r--r--net/httpserver.lua5
-rw-r--r--net/resolvers/basic.lua71
-rw-r--r--net/resolvers/manual.lua25
-rw-r--r--net/resolvers/service.lua70
-rw-r--r--net/server.lua122
-rw-r--r--net/server_epoll.lua786
-rw-r--r--net/server_event.lua104
-rw-r--r--net/server_select.lua174
-rw-r--r--net/websocket.lua43
-rw-r--r--net/websocket/frames.lua8
18 files changed, 1707 insertions, 279 deletions
diff --git a/net/adns.lua b/net/adns.lua
index a19cbd59..560e4b53 100644
--- a/net/adns.lua
+++ b/net/adns.lua
@@ -17,6 +17,7 @@ local setmetatable = setmetatable;
local function dummy_send(sock, data, i, j) return (j-i)+1; end
local _ENV = nil;
+-- luacheck: std none
local async_resolver_methods = {};
local async_resolver_mt = { __index = async_resolver_methods };
diff --git a/net/connect.lua b/net/connect.lua
new file mode 100644
index 00000000..b812ffcd
--- /dev/null
+++ b/net/connect.lua
@@ -0,0 +1,89 @@
+local server = require "net.server";
+local log = require "util.logger".init("net.connect");
+local new_id = require "util.id".short;
+
+local pending_connection_methods = {};
+local pending_connection_mt = {
+ __name = "pending_connection";
+ __index = pending_connection_methods;
+ __tostring = function (p)
+ return "<pending connection "..p.id.." to "..tostring(p.target_resolver.hostname)..">";
+ end;
+};
+
+function pending_connection_methods:log(level, message, ...)
+ log(level, "[pending connection %s] "..message, self.id, ...);
+end
+
+-- pending_connections_map[conn] = pending_connection
+local pending_connections_map = {};
+
+local pending_connection_listeners = {};
+
+local function attempt_connection(p)
+ p:log("debug", "Checking for targets...");
+ if p.conn then
+ pending_connections_map[p.conn] = nil;
+ p.conn = nil;
+ end
+ p.target_resolver:next(function (conn_type, ip, port, extra)
+ if not conn_type then
+ -- No more targets to try
+ p:log("debug", "No more connection targets to try");
+ if p.listeners.onfail then
+ p.listeners.onfail(p.data, p.last_error or "unable to resolve service");
+ end
+ return;
+ end
+ p:log("debug", "Next target to try is %s:%d", ip, port);
+ local conn, err = server.addclient(ip, port, pending_connection_listeners, p.options.pattern or "*a", p.options.sslctx, conn_type, extra);
+ if not conn then
+ log("debug", "Connection attempt failed immediately: %s", tostring(err));
+ p.last_error = err or "unknown reason";
+ return attempt_connection(p);
+ end
+ p.conn = conn;
+ pending_connections_map[conn] = p;
+ end);
+end
+
+function pending_connection_listeners.onconnect(conn)
+ local p = pending_connections_map[conn];
+ if not p then
+ log("warn", "Successful connection, but unexpected! Closing.");
+ conn:close();
+ return;
+ end
+ pending_connections_map[conn] = nil;
+ p:log("debug", "Successfully connected");
+ conn:setlistener(p.listeners, p.data);
+ return p.listeners.onconnect(conn);
+end
+
+function pending_connection_listeners.ondisconnect(conn, reason)
+ local p = pending_connections_map[conn];
+ if not p then
+ log("warn", "Failed connection, but unexpected!");
+ return;
+ end
+ p.last_error = reason or "unknown reason";
+ p:log("debug", "Connection attempt failed: %s", p.last_error);
+ attempt_connection(p);
+end
+
+local function connect(target_resolver, listeners, options, data)
+ local p = setmetatable({
+ id = new_id();
+ target_resolver = target_resolver;
+ listeners = assert(listeners);
+ options = options or {};
+ data = data;
+ }, pending_connection_mt);
+
+ p:log("debug", "Starting connection process");
+ attempt_connection(p);
+end
+
+return {
+ connect = connect;
+};
diff --git a/net/connlisteners.lua b/net/connlisteners.lua
index 000bfa63..9b8f88c3 100644
--- a/net/connlisteners.lua
+++ b/net/connlisteners.lua
@@ -3,15 +3,15 @@ local log = require "util.logger".init("net.connlisteners");
local traceback = debug.traceback;
local _ENV = nil;
+-- luacheck: std none
local function fail()
- log("error", "Attempt to use legacy connlisteners API. For more info see http://prosody.im/doc/developers/network");
+ log("error", "Attempt to use legacy connlisteners API. For more info see https://prosody.im/doc/developers/network");
log("error", "Legacy connlisteners API usage, %s", traceback("", 2));
end
return {
register = fail;
- register = fail;
get = fail;
start = fail;
-- epic fail
diff --git a/net/cqueues.lua b/net/cqueues.lua
new file mode 100644
index 00000000..8c4c756f
--- /dev/null
+++ b/net/cqueues.lua
@@ -0,0 +1,74 @@
+-- Prosody IM
+-- Copyright (C) 2014 Daurnimator
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+-- This module allows you to use cqueues with a net.server mainloop
+--
+
+local server = require "net.server";
+local cqueues = require "cqueues";
+assert(cqueues.VERSION >= 20150113, "cqueues newer than 20150113 required")
+
+-- Create a single top level cqueue
+local cq;
+
+if server.cq then -- server provides cqueues object
+ cq = server.cq;
+elseif server.get_backend() == "select" and server._addtimer then -- server_select
+ cq = cqueues.new();
+ local function step()
+ assert(cq:loop(0));
+ end
+
+ -- Use wrapclient (as wrapconnection isn't exported) to get server_select to watch cq fd
+ local handler = server.wrapclient({
+ getfd = function() return cq:pollfd(); end;
+ settimeout = function() end; -- Method just needs to exist
+ close = function() end; -- Need close method for 'closeall'
+ }, nil, nil, {});
+
+ -- Only need to listen for readable; cqueues handles everything under the hood
+ -- readbuffer is called when `select` notes an fd as readable
+ handler.readbuffer = step;
+
+ -- Use server_select low lever timer facility,
+ -- this callback gets called *every* time there is a timeout in the main loop
+ server._addtimer(function(current_time)
+ -- This may end up in extra step()'s, but cqueues handles it for us.
+ step();
+ return cq:timeout();
+ end);
+elseif server.event and server.base then -- server_event
+ cq = cqueues.new();
+ -- Only need to listen for readable; cqueues handles everything under the hood
+ local EV_READ = server.event.EV_READ;
+ -- Convert a cqueues timeout to an acceptable timeout for luaevent
+ local function luaevent_safe_timeout(cq)
+ local t = cq:timeout();
+ -- if you give luaevent 0 or nil, it re-uses the previous timeout.
+ if t == 0 then
+ t = 0.000001; -- 1 microsecond is the smallest that works (goes into a `struct timeval`)
+ elseif t == nil then -- pick something big if we don't have one
+ t = 0x7FFFFFFF; -- largest 32bit int
+ end
+ return t
+ end
+ local event_handle;
+ event_handle = server.base:addevent(cq:pollfd(), EV_READ, function(e)
+ -- Need to reference event_handle or this callback will get collected
+ -- This creates a circular reference that can only be broken if event_handle is manually :close()'d
+ local _ = event_handle;
+ -- Run as many cqueues things as possible (with a timeout of 0)
+ -- If an error is thrown, it will break the libevent loop; but prosody resumes after logging a top level error
+ assert(cq:loop(0));
+ return EV_READ, luaevent_safe_timeout(cq);
+ end, luaevent_safe_timeout(cq));
+else
+ error "NYI"
+end
+
+return {
+ cq = cq;
+}
diff --git a/net/dns.lua b/net/dns.lua
index 563c81a6..af5f1216 100644
--- a/net/dns.lua
+++ b/net/dns.lua
@@ -15,6 +15,7 @@
local socket = require "socket";
local timer = require "util.timer";
local new_ip = require "util.ip".new_ip;
+local have_util_net, util_net = pcall(require, "util.net");
local _, windows = pcall(require, "util.windows");
local is_windows = (_ and windows) or os.getenv("WINDIR");
@@ -72,6 +73,7 @@ local default_timeout = 15;
-------------------------------------------------- module dns
local _ENV = nil;
+-- luacheck: std none
local dns = {};
@@ -119,11 +121,99 @@ end
dns.types = {
- 'A', 'NS', 'MD', 'MF', 'CNAME', 'SOA', 'MB', 'MG', 'MR', 'NULL', 'WKS',
- 'PTR', 'HINFO', 'MINFO', 'MX', 'TXT',
- [ 28] = 'AAAA', [ 29] = 'LOC', [ 33] = 'SRV',
- [252] = 'AXFR', [253] = 'MAILB', [254] = 'MAILA', [255] = '*' };
-
+ [1] = "A", -- a host address,[RFC1035],,
+ [2] = "NS", -- an authoritative name server,[RFC1035],,
+ [3] = "MD", -- a mail destination (OBSOLETE - use MX),[RFC1035],,
+ [4] = "MF", -- a mail forwarder (OBSOLETE - use MX),[RFC1035],,
+ [5] = "CNAME", -- the canonical name for an alias,[RFC1035],,
+ [6] = "SOA", -- marks the start of a zone of authority,[RFC1035],,
+ [7] = "MB", -- a mailbox domain name (EXPERIMENTAL),[RFC1035],,
+ [8] = "MG", -- a mail group member (EXPERIMENTAL),[RFC1035],,
+ [9] = "MR", -- a mail rename domain name (EXPERIMENTAL),[RFC1035],,
+ [10] = "NULL", -- a null RR (EXPERIMENTAL),[RFC1035],,
+ [11] = "WKS", -- a well known service description,[RFC1035],,
+ [12] = "PTR", -- a domain name pointer,[RFC1035],,
+ [13] = "HINFO", -- host information,[RFC1035],,
+ [14] = "MINFO", -- mailbox or mail list information,[RFC1035],,
+ [15] = "MX", -- mail exchange,[RFC1035],,
+ [16] = "TXT", -- text strings,[RFC1035],,
+ [17] = "RP", -- for Responsible Person,[RFC1183],,
+ [18] = "AFSDB", -- for AFS Data Base location,[RFC1183][RFC5864],,
+ [19] = "X25", -- for X.25 PSDN address,[RFC1183],,
+ [20] = "ISDN", -- for ISDN address,[RFC1183],,
+ [21] = "RT", -- for Route Through,[RFC1183],,
+ [22] = "NSAP", -- "for NSAP address, NSAP style A record",[RFC1706],,
+ [23] = "NSAP-PTR", -- "for domain name pointer, NSAP style",[RFC1348][RFC1637][RFC1706],,
+ [24] = "SIG", -- for security signature,[RFC4034][RFC3755][RFC2535][RFC2536][RFC2537][RFC2931][RFC3110][RFC3008],,
+ [25] = "KEY", -- for security key,[RFC4034][RFC3755][RFC2535][RFC2536][RFC2537][RFC2539][RFC3008][RFC3110],,
+ [26] = "PX", -- X.400 mail mapping information,[RFC2163],,
+ [27] = "GPOS", -- Geographical Position,[RFC1712],,
+ [28] = "AAAA", -- IP6 Address,[RFC3596],,
+ [29] = "LOC", -- Location Information,[RFC1876],,
+ [30] = "NXT", -- Next Domain (OBSOLETE),[RFC3755][RFC2535],,
+ [31] = "EID", -- Endpoint Identifier,[Michael_Patton][http://ana-3.lcs.mit.edu/~jnc/nimrod/dns.txt],,1995-06
+ [32] = "NIMLOC", -- Nimrod Locator,[1][Michael_Patton][http://ana-3.lcs.mit.edu/~jnc/nimrod/dns.txt],,1995-06
+ [33] = "SRV", -- Server Selection,[1][RFC2782],,
+ [34] = "ATMA", -- ATM Address,"[ ATM Forum Technical Committee, ""ATM Name System, V2.0"", Doc ID: AF-DANS-0152.000, July 2000. Available from and held in escrow by IANA.]",,
+ [35] = "NAPTR", -- Naming Authority Pointer,[RFC2915][RFC2168][RFC3403],,
+ [36] = "KX", -- Key Exchanger,[RFC2230],,
+ [37] = "CERT", -- CERT,[RFC4398],,
+ [38] = "A6", -- A6 (OBSOLETE - use AAAA),[RFC3226][RFC2874][RFC6563],,
+ [39] = "DNAME", -- DNAME,[RFC6672],,
+ [40] = "SINK", -- SINK,[Donald_E_Eastlake][http://tools.ietf.org/html/draft-eastlake-kitchen-sink],,1997-11
+ [41] = "OPT", -- OPT,[RFC6891][RFC3225],,
+ [42] = "APL", -- APL,[RFC3123],,
+ [43] = "DS", -- Delegation Signer,[RFC4034][RFC3658],,
+ [44] = "SSHFP", -- SSH Key Fingerprint,[RFC4255],,
+ [45] = "IPSECKEY", -- IPSECKEY,[RFC4025],,
+ [46] = "RRSIG", -- RRSIG,[RFC4034][RFC3755],,
+ [47] = "NSEC", -- NSEC,[RFC4034][RFC3755],,
+ [48] = "DNSKEY", -- DNSKEY,[RFC4034][RFC3755],,
+ [49] = "DHCID", -- DHCID,[RFC4701],,
+ [50] = "NSEC3", -- NSEC3,[RFC5155],,
+ [51] = "NSEC3PARAM", -- NSEC3PARAM,[RFC5155],,
+ [52] = "TLSA", -- TLSA,[RFC6698],,
+ [53] = "SMIMEA", -- S/MIME cert association,[RFC8162],SMIMEA/smimea-completed-template,2015-12-01
+ -- [54] = "Unassigned", -- ,,,
+ [55] = "HIP", -- Host Identity Protocol,[RFC8005],,
+ [56] = "NINFO", -- NINFO,[Jim_Reid],NINFO/ninfo-completed-template,2008-01-21
+ [57] = "RKEY", -- RKEY,[Jim_Reid],RKEY/rkey-completed-template,2008-01-21
+ [58] = "TALINK", -- Trust Anchor LINK,[Wouter_Wijngaards],TALINK/talink-completed-template,2010-02-17
+ [59] = "CDS", -- Child DS,[RFC7344],CDS/cds-completed-template,2011-06-06
+ [60] = "CDNSKEY", -- DNSKEY(s) the Child wants reflected in DS,[RFC7344],,2014-06-16
+ [61] = "OPENPGPKEY", -- OpenPGP Key,[RFC7929],OPENPGPKEY/openpgpkey-completed-template,2014-08-12
+ [62] = "CSYNC", -- Child-To-Parent Synchronization,[RFC7477],,2015-01-27
+ -- [63 .. 98] = "Unassigned", -- ,,,
+ [99] = "SPF", -- ,[RFC7208],,
+ [100] = "UINFO", -- ,[IANA-Reserved],,
+ [101] = "UID", -- ,[IANA-Reserved],,
+ [102] = "GID", -- ,[IANA-Reserved],,
+ [103] = "UNSPEC", -- ,[IANA-Reserved],,
+ [104] = "NID", -- ,[RFC6742],ILNP/nid-completed-template,
+ [105] = "L32", -- ,[RFC6742],ILNP/l32-completed-template,
+ [106] = "L64", -- ,[RFC6742],ILNP/l64-completed-template,
+ [107] = "LP", -- ,[RFC6742],ILNP/lp-completed-template,
+ [108] = "EUI48", -- an EUI-48 address,[RFC7043],EUI48/eui48-completed-template,2013-03-27
+ [109] = "EUI64", -- an EUI-64 address,[RFC7043],EUI64/eui64-completed-template,2013-03-27
+ -- [110 .. 248] = "Unassigned", -- ,,,
+ [249] = "TKEY", -- Transaction Key,[RFC2930],,
+ [250] = "TSIG", -- Transaction Signature,[RFC2845],,
+ [251] = "IXFR", -- incremental transfer,[RFC1995],,
+ [252] = "AXFR", -- transfer of an entire zone,[RFC1035][RFC5936],,
+ [253] = "MAILB", -- "mailbox-related RRs (MB, MG or MR)",[RFC1035],,
+ [254] = "MAILA", -- mail agent RRs (OBSOLETE - see MX),[RFC1035],,
+ [255] = "*", -- A request for all records the server/cache has available,[RFC1035][RFC6895],,
+ [256] = "URI", -- URI,[RFC7553],URI/uri-completed-template,2011-02-22
+ [257] = "CAA", -- Certification Authority Restriction,[RFC6844],CAA/caa-completed-template,2011-04-07
+ [258] = "AVC", -- Application Visibility and Control,[Wolfgang_Riedel],AVC/avc-completed-template,2016-02-26
+ [259] = "DOA", -- Digital Object Architecture,[draft-durand-doa-over-dns],DOA/doa-completed-template,2017-08-30
+ -- [260 .. 32767] = "Unassigned", -- ,,,
+ [32768] = "TA", -- DNSSEC Trust Authorities,"[Sam_Weiler][http://cameo.library.cmu.edu/][ Deploying DNSSEC Without a Signed Root. Technical Report 1999-19, Information Networking Institute, Carnegie Mellon University, April 2004.]",,2005-12-13
+ [32769] = "DLV", -- DNSSEC Lookaside Validation,[RFC4431],,
+ -- [32770 .. 65279] = "Unassigned", -- ,,,
+ -- [65280 .. 65534] = "Private use", -- ,,,
+ -- [65535] = "Reserved", -- ,,,
+}
dns.classes = { 'IN', 'CS', 'CH', 'HS', [255] = '*' };
@@ -391,6 +481,12 @@ function resolver:A(rr) -- - - - - - - - - - - - - - - - - - - - - - - - A
rr.a = string.format('%i.%i.%i.%i', b1, b2, b3, b4);
end
+if have_util_net and util_net.ntop then
+ function resolver:A(rr)
+ rr.a = util_net.ntop(self:sub(4));
+ end
+end
+
function resolver:AAAA(rr)
local addr = {};
for _ = 1, rr.rdlength, 2 do
@@ -411,6 +507,12 @@ function resolver:AAAA(rr)
rr.aaaa = addr:gsub(zeros[1], "::", 1):gsub("^0::", "::"):gsub("::0$", "::");
end
+if have_util_net and util_net.ntop then
+ function resolver:AAAA(rr)
+ rr.aaaa = util_net.ntop(self:sub(16));
+ end
+end
+
function resolver:CNAME(rr) -- - - - - - - - - - - - - - - - - - - - CNAME
rr.cname = self:name();
end
diff --git a/net/http.lua b/net/http.lua
index effb0ef5..6e5ad67c 100644
--- a/net/http.lua
+++ b/net/http.lua
@@ -13,9 +13,10 @@ local util_http = require "util.http";
local events = require "util.events";
local verify_identity = require"util.x509".verify_identity;
-local ssl_available = pcall(require, "ssl");
+local basic_resolver = require "net.resolvers.basic";
+local connect = require "net.connect".connect;
-local server = require "net.server"
+local ssl_available = pcall(require, "ssl");
local t_insert, t_concat = table.insert, table.concat;
local pairs = pairs;
@@ -27,6 +28,7 @@ local setmetatable = setmetatable;
local log = require "util.logger".init("http");
local _ENV = nil;
+-- luacheck: std none
local requests = {}; -- Open requests
@@ -34,9 +36,78 @@ local function make_id(req) return (tostring(req):match("%x+$")); end
local listener = { default_port = 80, default_mode = "*a" };
+-- Request-related helper functions
+local function handleerr(err) log("error", "Traceback[http]: %s", traceback(tostring(err), 2)); return err; end
+local function log_if_failed(req, ret, ...)
+ if not ret then
+ log("error", "Request '%s': error in callback: %s", req.id, tostring((...)));
+ if not req.suppress_errors then
+ error(...);
+ end
+ end
+ return ...;
+end
+
+local function destroy_request(request)
+ local conn = request.conn;
+ if conn then
+ request.conn = nil;
+ conn:close()
+ end
+end
+
+local function request_reader(request, data, err)
+ if not request.parser then
+ local function error_cb(reason)
+ if request.callback then
+ request.callback(reason or "connection-closed", 0, request);
+ request.callback = nil;
+ end
+ destroy_request(request);
+ end
+
+ if not data then
+ error_cb(err);
+ return;
+ end
+
+ local function success_cb(r)
+ if request.callback then
+ request.callback(r.body, r.code, r, request);
+ request.callback = nil;
+ end
+ destroy_request(request);
+ end
+ local function options_cb()
+ return request;
+ end
+ request.parser = httpstream_new(success_cb, error_cb, "client", options_cb);
+ end
+ request.parser:feed(data);
+end
+
+-- Connection listener callbacks
function listener.onconnect(conn)
local req = requests[conn];
+ -- Initialize request object
+ req.write = function (...) return req.conn:write(...); end
+ local callback = req.callback;
+ req.callback = function (content, code, response, request)
+ do
+ local event = { http = req.http, url = req.url, request = req, response = response, content = content, code = code, callback = req.callback };
+ req.http.events.fire_event("response", event);
+ content, code, response = event.content, event.code, event.response;
+ end
+
+ log("debug", "Request '%s': Calling callback, status %s", req.id, code or "---");
+ return log_if_failed(req.id, xpcall(function () return callback(content, code, response, request) end, handleerr));
+ end
+ req.reader = request_reader;
+ req.state = "status";
+
+ requests[req.conn] = req;
+
-- Validate certificate
if not req.insecure and conn:ssl() then
local sock = conn:socket();
@@ -96,58 +167,24 @@ function listener.ondisconnect(conn, err)
requests[conn] = nil;
end
-function listener.ondetach(conn)
- requests[conn] = nil;
-end
-
-local function destroy_request(request)
- if request.conn then
- request.conn = nil;
- request.handler:close()
- end
+function listener.onattach(conn, req)
+ requests[conn] = req;
+ req.conn = conn;
end
-local function request_reader(request, data, err)
- if not request.parser then
- local function error_cb(reason)
- if request.callback then
- request.callback(reason or "connection-closed", 0, request);
- request.callback = nil;
- end
- destroy_request(request);
- end
-
- if not data then
- error_cb(err);
- return;
- end
-
- local function success_cb(r)
- if request.callback then
- request.callback(r.body, r.code, r, request);
- request.callback = nil;
- end
- destroy_request(request);
- end
- local function options_cb()
- return request;
- end
- request.parser = httpstream_new(success_cb, error_cb, "client", options_cb);
- end
- request.parser:feed(data);
+function listener.ondetach(conn)
+ requests[conn] = nil;
end
-local function handleerr(err) log("error", "Traceback[http]: %s", traceback(tostring(err), 2)); end
-local function log_if_failed(id, ret, ...)
- if not ret then
- log("error", "Request '%s': error in callback: %s", id, tostring((...)));
- end
- return ...;
+function listener.onfail(req, reason)
+ req.http.events.fire_event("request-connection-error", { http = req.http, request = req, url = req.url, err = reason });
+ req.callback(reason or "connection failed", 0, req);
end
local function request(self, u, ex, callback)
local req = url.parse(u);
req.url = u;
+ req.http = self;
if not (req and req.host) then
callback("invalid-url", 0, req);
@@ -166,7 +203,7 @@ local function request(self, u, ex, callback)
if ret then
return ret;
end
- req, u, ex, callback = event.request, event.url, event.options, event.callback;
+ req, u, ex, req.callback = event.request, event.url, event.options, event.callback;
end
local method, headers, body;
@@ -204,6 +241,7 @@ local function request(self, u, ex, callback)
end
end
req.insecure = ex.insecure;
+ req.suppress_errors = ex.suppress_errors;
end
log("debug", "Making %s %s request '%s' to %s", req.scheme:upper(), method or "GET", req.id, (ex and ex.suppress_url and host_header) or u);
@@ -222,29 +260,8 @@ local function request(self, u, ex, callback)
sslctx = ex and ex.sslctx or self.options and self.options.sslctx;
end
- local handler, conn = server.addclient(host, port_number, listener, "*a", sslctx)
- if not handler then
- self.events.fire_event("request-connection-error", { http = self, request = req, url = u, err = conn });
- callback(conn, 0, req);
- return nil, conn;
- end
- req.handler, req.conn = handler, conn
- req.write = function (...) return req.handler:write(...); end
-
- req.callback = function (content, code, response, request)
- do
- local event = { http = self, url = u, request = req, response = response, content = content, code = code, callback = callback };
- self.events.fire_event("response", event);
- content, code, response = event.content, event.code, event.response;
- end
-
- log("debug", "Request '%s': Calling callback, status %s", req.id, code or "---");
- return log_if_failed(req.id, xpcall(function () return callback(content, code, response, request) end, handleerr));
- end
- req.reader = request_reader;
- req.state = "status";
-
- requests[req.handler] = req;
+ local http_service = basic_resolver.new(host, port_number);
+ connect(http_service, listener, { sslctx = sslctx }, req);
self.events.fire_event("request", { http = self, request = req, url = u });
return req;
@@ -264,6 +281,7 @@ end
local default_http = new({
sslctx = { mode = "client", protocol = "sslv23", options = { "no_sslv2", "no_sslv3" } };
+ suppress_errors = true;
});
return {
diff --git a/net/http/codes.lua b/net/http/codes.lua
index 1090e545..8098b5c3 100644
--- a/net/http/codes.lua
+++ b/net/http/codes.lua
@@ -1,73 +1,85 @@
local response_codes = {
-- Source: http://www.iana.org/assignments/http-status-codes
- -- s/^\(\d*\)\s*\(.*\S\)\s*\[RFC.*\]\s*$/^I["\1"] = "\2";
- [100] = "Continue";
- [101] = "Switching Protocols";
+
+ [100] = "Continue"; -- RFC7231, Section 6.2.1
+ [101] = "Switching Protocols"; -- RFC7231, Section 6.2.2
[102] = "Processing";
+ [103] = "Early Hints";
+ -- [104-199] = "Unassigned";
- [200] = "OK";
- [201] = "Created";
- [202] = "Accepted";
- [203] = "Non-Authoritative Information";
- [204] = "No Content";
- [205] = "Reset Content";
- [206] = "Partial Content";
+ [200] = "OK"; -- RFC7231, Section 6.3.1
+ [201] = "Created"; -- RFC7231, Section 6.3.2
+ [202] = "Accepted"; -- RFC7231, Section 6.3.3
+ [203] = "Non-Authoritative Information"; -- RFC7231, Section 6.3.4
+ [204] = "No Content"; -- RFC7231, Section 6.3.5
+ [205] = "Reset Content"; -- RFC7231, Section 6.3.6
+ [206] = "Partial Content"; -- RFC7233, Section 4.1
[207] = "Multi-Status";
[208] = "Already Reported";
+ -- [209-225] = "Unassigned";
[226] = "IM Used";
+ -- [227-299] = "Unassigned";
- [300] = "Multiple Choices";
- [301] = "Moved Permanently";
- [302] = "Found";
- [303] = "See Other";
- [304] = "Not Modified";
- [305] = "Use Proxy";
- -- The 306 status code was used in a previous version of [RFC2616], is no longer used, and the code is reserved.
- [307] = "Temporary Redirect";
+ [300] = "Multiple Choices"; -- RFC7231, Section 6.4.1
+ [301] = "Moved Permanently"; -- RFC7231, Section 6.4.2
+ [302] = "Found"; -- RFC7231, Section 6.4.3
+ [303] = "See Other"; -- RFC7231, Section 6.4.4
+ [304] = "Not Modified"; -- RFC7232, Section 4.1
+ [305] = "Use Proxy"; -- RFC7231, Section 6.4.5
+ -- [306] = "(Unused)"; -- RFC7231, Section 6.4.6
+ [307] = "Temporary Redirect"; -- RFC7231, Section 6.4.7
[308] = "Permanent Redirect";
+ -- [309-399] = "Unassigned";
- [400] = "Bad Request";
- [401] = "Unauthorized";
- [402] = "Payment Required";
- [403] = "Forbidden";
- [404] = "Not Found";
- [405] = "Method Not Allowed";
- [406] = "Not Acceptable";
- [407] = "Proxy Authentication Required";
- [408] = "Request Timeout";
- [409] = "Conflict";
- [410] = "Gone";
- [411] = "Length Required";
- [412] = "Precondition Failed";
- [413] = "Payload Too Large";
- [414] = "URI Too Long";
- [415] = "Unsupported Media Type";
- [416] = "Range Not Satisfiable";
- [417] = "Expectation Failed";
- [418] = "I'm a teapot";
- [421] = "Misdirected Request";
+ [400] = "Bad Request"; -- RFC7231, Section 6.5.1
+ [401] = "Unauthorized"; -- RFC7235, Section 3.1
+ [402] = "Payment Required"; -- RFC7231, Section 6.5.2
+ [403] = "Forbidden"; -- RFC7231, Section 6.5.3
+ [404] = "Not Found"; -- RFC7231, Section 6.5.4
+ [405] = "Method Not Allowed"; -- RFC7231, Section 6.5.5
+ [406] = "Not Acceptable"; -- RFC7231, Section 6.5.6
+ [407] = "Proxy Authentication Required"; -- RFC7235, Section 3.2
+ [408] = "Request Timeout"; -- RFC7231, Section 6.5.7
+ [409] = "Conflict"; -- RFC7231, Section 6.5.8
+ [410] = "Gone"; -- RFC7231, Section 6.5.9
+ [411] = "Length Required"; -- RFC7231, Section 6.5.10
+ [412] = "Precondition Failed"; -- RFC7232, Section 4.2
+ [413] = "Payload Too Large"; -- RFC7231, Section 6.5.11
+ [414] = "URI Too Long"; -- RFC7231, Section 6.5.12
+ [415] = "Unsupported Media Type"; -- RFC7231, Section 6.5.13
+ [416] = "Range Not Satisfiable"; -- RFC7233, Section 4.4
+ [417] = "Expectation Failed"; -- RFC7231, Section 6.5.14
+ [418] = "I'm a teapot"; -- RFC2324, Section 2.3.2
+ -- [419-420] = "Unassigned";
+ [421] = "Misdirected Request"; -- RFC7540, Section 9.1.2
[422] = "Unprocessable Entity";
[423] = "Locked";
[424] = "Failed Dependency";
- -- The 425 status code is reserved for the WebDAV advanced collections expired proposal [RFC2817]
- [426] = "Upgrade Required";
+ [425] = "Too Early";
+ [426] = "Upgrade Required"; -- RFC7231, Section 6.5.15
+ -- [427] = "Unassigned";
[428] = "Precondition Required";
[429] = "Too Many Requests";
+ -- [430] = "Unassigned";
[431] = "Request Header Fields Too Large";
+ -- [432-450] = "Unassigned";
[451] = "Unavailable For Legal Reasons";
+ -- [452-499] = "Unassigned";
- [500] = "Internal Server Error";
- [501] = "Not Implemented";
- [502] = "Bad Gateway";
- [503] = "Service Unavailable";
- [504] = "Gateway Timeout";
- [505] = "HTTP Version Not Supported";
- [506] = "Variant Also Negotiates"; -- Experimental
+ [500] = "Internal Server Error"; -- RFC7231, Section 6.6.1
+ [501] = "Not Implemented"; -- RFC7231, Section 6.6.2
+ [502] = "Bad Gateway"; -- RFC7231, Section 6.6.3
+ [503] = "Service Unavailable"; -- RFC7231, Section 6.6.4
+ [504] = "Gateway Timeout"; -- RFC7231, Section 6.6.5
+ [505] = "HTTP Version Not Supported"; -- RFC7231, Section 6.6.6
+ [506] = "Variant Also Negotiates";
[507] = "Insufficient Storage";
[508] = "Loop Detected";
+ -- [509] = "Unassigned";
[510] = "Not Extended";
[511] = "Network Authentication Required";
+ -- [512-599] = "Unassigned";
};
for k,v in pairs(response_codes) do response_codes[k] = k.." "..v; end
diff --git a/net/http/server.lua b/net/http/server.lua
index 877c7f17..b73f3385 100644
--- a/net/http/server.lua
+++ b/net/http/server.lua
@@ -217,14 +217,6 @@ function handle_request(conn, request, finish_cb)
local err_code, err;
if not request.path then
err_code, err = 400, "Invalid path";
- elseif not hosts[host] then
- if hosts[default_host] then
- host = default_host;
- elseif host then
- err_code, err = 404, "Unknown host: "..host;
- else
- err_code, err = 400, "Missing or invalid 'Host' header";
- end
end
if err then
@@ -233,10 +225,32 @@ function handle_request(conn, request, finish_cb)
return;
end
- local event = request.method.." "..host..request.path:match("[^?]*");
+ local global_event = request.method.." "..request.path:match("[^?]*");
+
local payload = { request = request, response = response };
- log("debug", "Firing event: %s", event);
- local result = events.fire_event(event, payload);
+ log("debug", "Firing event: %s", global_event);
+ local result = events.fire_event(global_event, payload);
+ if result == nil then
+ if not hosts[host] then
+ if hosts[default_host] then
+ host = default_host;
+ elseif host then
+ err_code, err = 404, "Unknown host: "..host;
+ else
+ err_code, err = 400, "Missing or invalid 'Host' header";
+ end
+ end
+ local host_event = request.method.." "..host..request.path:match("[^?]*");
+
+ if err then
+ response.status_code = err_code;
+ response:send(events.fire_event("http-error", { code = err_code, message = err, response = response }));
+ return;
+ end
+
+ log("debug", "Firing event: %s", host_event);
+ result = events.fire_event(host_event, payload);
+ end
if result ~= nil then
if result ~= true then
local body;
diff --git a/net/httpserver.lua b/net/httpserver.lua
index 6e2e31b9..6b14313b 100644
--- a/net/httpserver.lua
+++ b/net/httpserver.lua
@@ -3,9 +3,10 @@ local log = require "util.logger".init("net.httpserver");
local traceback = debug.traceback;
local _ENV = nil;
+-- luacheck: std none
-function fail()
- log("error", "Attempt to use legacy HTTP API. For more info see http://prosody.im/doc/developers/legacy_http");
+local function fail()
+ log("error", "Attempt to use legacy HTTP API. For more info see https://prosody.im/doc/developers/legacy_http");
log("error", "Legacy HTTP API usage, %s", traceback("", 2));
end
diff --git a/net/resolvers/basic.lua b/net/resolvers/basic.lua
new file mode 100644
index 00000000..9a3c9952
--- /dev/null
+++ b/net/resolvers/basic.lua
@@ -0,0 +1,71 @@
+local adns = require "net.adns";
+local inet_pton = require "util.net".pton;
+
+local methods = {};
+local resolver_mt = { __index = methods };
+
+-- Find the next target to connect to, and
+-- pass it to cb()
+function methods:next(cb)
+ if self.targets then
+ if #self.targets == 0 then
+ cb(nil);
+ return;
+ end
+ local next_target = table.remove(self.targets, 1);
+ cb(unpack(next_target, 1, 4));
+ return;
+ end
+
+ local targets = {};
+ local n = 2;
+ local function ready()
+ n = n - 1;
+ if n > 0 then return; end
+ self.targets = targets;
+ self:next(cb);
+ end
+
+ local is_ip = inet_pton(self.hostname);
+ if is_ip then
+ if #is_ip == 16 then
+ cb(self.conn_type.."6", self.hostname, self.port, self.extra);
+ elseif #is_ip == 4 then
+ cb(self.conn_type.."4", self.hostname, self.port, self.extra);
+ end
+ return;
+ end
+
+ -- Resolve DNS to target list
+ local dns_resolver = adns.resolver();
+ dns_resolver:lookup(function (answer)
+ if answer then
+ for _, record in ipairs(answer) do
+ table.insert(targets, { self.conn_type.."4", record.a, self.port, self.extra });
+ end
+ end
+ ready();
+ end, self.hostname, "A", "IN");
+
+ dns_resolver:lookup(function (answer)
+ if answer then
+ for _, record in ipairs(answer) do
+ table.insert(targets, { self.conn_type.."6", record.aaaa, self.port, self.extra });
+ end
+ end
+ ready();
+ end, self.hostname, "AAAA", "IN");
+end
+
+local function new(hostname, port, conn_type, extra)
+ return setmetatable({
+ hostname = hostname;
+ port = port;
+ conn_type = conn_type or "tcp";
+ extra = extra;
+ }, resolver_mt);
+end
+
+return {
+ new = new;
+};
diff --git a/net/resolvers/manual.lua b/net/resolvers/manual.lua
new file mode 100644
index 00000000..c0d4e5d5
--- /dev/null
+++ b/net/resolvers/manual.lua
@@ -0,0 +1,25 @@
+local methods = {};
+local resolver_mt = { __index = methods };
+
+-- Find the next target to connect to, and
+-- pass it to cb()
+function methods:next(cb)
+ if #self.targets == 0 then
+ cb(nil);
+ return;
+ end
+ local next_target = table.remove(self.targets, 1);
+ cb(unpack(next_target, 1, 4));
+end
+
+local function new(targets, conn_type, extra)
+ return setmetatable({
+ conn_type = conn_type;
+ extra = extra;
+ targets = targets or {};
+ }, resolver_mt);
+end
+
+return {
+ new = new;
+};
diff --git a/net/resolvers/service.lua b/net/resolvers/service.lua
new file mode 100644
index 00000000..b5a2d821
--- /dev/null
+++ b/net/resolvers/service.lua
@@ -0,0 +1,70 @@
+local adns = require "net.adns";
+local basic = require "net.resolvers.basic";
+
+local methods = {};
+local resolver_mt = { __index = methods };
+
+-- Find the next target to connect to, and
+-- pass it to cb()
+function methods:next(cb)
+ if self.targets then
+ if #self.targets == 0 then
+ cb(nil);
+ return;
+ end
+ local next_target = table.remove(self.targets, 1);
+ self.resolver = basic.new(unpack(next_target, 1, 4));
+ self.resolver:next(function (...)
+ if ... == nil then
+ self:next(cb);
+ else
+ cb(...);
+ end
+ end);
+ return;
+ end
+
+ local targets = {};
+ local function ready()
+ self.targets = targets;
+ self:next(cb);
+ end
+
+ -- Resolve DNS to target list
+ local dns_resolver = adns.resolver();
+ dns_resolver:lookup(function (answer)
+ if answer then
+ if #answer == 0 then
+ if self.extra and self.extra.default_port then
+ table.insert(targets, { self.hostname, self.extra.default_port, self.conn_type, self.extra });
+ end
+ ready();
+ return;
+ end
+
+ if #answer == 1 and answer[1].srv.target == "." then -- No service here
+ ready();
+ return;
+ end
+
+ table.sort(answer, function (a, b) return a.srv.priority < b.srv.priority end);
+ for _, record in ipairs(answer) do
+ table.insert(targets, { record.srv.target, record.srv.port, self.conn_type, self.extra });
+ end
+ end
+ ready();
+ end, "_" .. self.service .. "._" .. self.conn_type .. "." .. self.hostname, "SRV", "IN");
+end
+
+local function new(hostname, service, conn_type, extra)
+ return setmetatable({
+ hostname = hostname;
+ service = service;
+ conn_type = conn_type or "tcp";
+ extra = extra;
+ }, resolver_mt);
+end
+
+return {
+ new = new;
+};
diff --git a/net/server.lua b/net/server.lua
index 41e180fa..abbb421d 100644
--- a/net/server.lua
+++ b/net/server.lua
@@ -6,25 +6,83 @@
-- COPYING file in the source package for more information.
--
-local use_luaevent = prosody and require "core.configmanager".get("*", "use_libevent");
+if not (prosody and prosody.config_loaded) then
+ -- This module only supports loading inside Prosody, outside Prosody
+ -- you should directly require net.server_select or server_event, etc.
+ error(debug.traceback("Loading outside Prosody or Prosody not yet initialized"), 0);
+end
+
+local log = require "util.logger".init("net.server");
+local server_type = require "core.configmanager".get("*", "network_backend") or "select";
-if use_luaevent then
- use_luaevent = pcall(require, "luaevent.core");
- if not use_luaevent then
+if require "core.configmanager".get("*", "use_libevent") then
+ server_type = "event";
+end
+
+if server_type == "event" then
+ if not pcall(require, "luaevent.core") then
log("error", "libevent not found, falling back to select()");
+ server_type = "select"
end
end
local server;
-
-if use_luaevent then
+local set_config;
+if server_type == "event" then
server = require "net.server_event";
- -- Overwrite signal.signal() because we need to ask libevent to
- -- handle them instead
- local ok, signal = pcall(require, "util.signal");
- if ok and signal then
- local _signal_signal = signal.signal;
+ local defaults = {};
+ for k,v in pairs(server.cfg) do
+ defaults[k] = v;
+ end
+ function set_config(settings)
+ local event_settings = {
+ ACCEPT_DELAY = settings.accept_retry_interval;
+ ACCEPT_QUEUE = settings.tcp_backlog;
+ CLEAR_DELAY = settings.event_clear_interval;
+ CONNECT_TIMEOUT = settings.connect_timeout;
+ DEBUG = settings.debug;
+ HANDSHAKE_TIMEOUT = settings.ssl_handshake_timeout;
+ MAX_CONNECTIONS = settings.max_connections;
+ MAX_HANDSHAKE_ATTEMPTS = settings.max_ssl_handshake_roundtrips;
+ MAX_READ_LENGTH = settings.max_receive_buffer_size;
+ MAX_SEND_LENGTH = settings.max_send_buffer_size;
+ READ_TIMEOUT = settings.read_timeout;
+ WRITE_TIMEOUT = settings.send_timeout;
+ };
+
+ for k,default in pairs(defaults) do
+ server.cfg[k] = event_settings[k] or default;
+ end
+ end
+elseif server_type == "select" then
+ server = require "net.server_select";
+
+ local defaults = {};
+ for k,v in pairs(server.getsettings()) do
+ defaults[k] = v;
+ end
+ function set_config(settings)
+ local select_settings = {};
+ for k,default in pairs(defaults) do
+ select_settings[k] = settings[k] or default;
+ end
+ server.changesettings(select_settings);
+ end
+else
+ server = require("net.server_"..server_type);
+ set_config = server.set_config;
+ if not server.get_backend then
+ function server.get_backend()
+ return server_type;
+ end
+ end
+end
+
+-- If server.hook_signal exists, replace signal.signal()
+local has_signal, signal = pcall(require, "util.signal");
+if has_signal then
+ if server.hook_signal then
function signal.signal(signal_id, handler)
if type(signal_id) == "string" then
signal_id = signal[signal_id:upper()];
@@ -34,46 +92,22 @@ if use_luaevent then
end
return server.hook_signal(signal_id, handler);
end
+ else
+ server.hook_signal = signal.signal;
end
else
- use_luaevent = false;
- server = require "net.server_select";
+ if not server.hook_signal then
+ server.hook_signal = function()
+ return false, "signal hooking not supported"
+ end
+ end
end
-if prosody then
+if prosody and set_config then
local config_get = require "core.configmanager".get;
- local defaults = {};
- for k,v in pairs(server.cfg or server.getsettings()) do
- defaults[k] = v;
- end
local function load_config()
local settings = config_get("*", "network_settings") or {};
- if use_luaevent then
- local event_settings = {
- ACCEPT_DELAY = settings.accept_retry_interval;
- ACCEPT_QUEUE = settings.tcp_backlog;
- CLEAR_DELAY = settings.event_clear_interval;
- CONNECT_TIMEOUT = settings.connect_timeout;
- DEBUG = settings.debug;
- HANDSHAKE_TIMEOUT = settings.ssl_handshake_timeout;
- MAX_CONNECTIONS = settings.max_connections;
- MAX_HANDSHAKE_ATTEMPTS = settings.max_ssl_handshake_roundtrips;
- MAX_READ_LENGTH = settings.max_receive_buffer_size;
- MAX_SEND_LENGTH = settings.max_send_buffer_size;
- READ_TIMEOUT = settings.read_timeout;
- WRITE_TIMEOUT = settings.send_timeout;
- };
-
- for k,default in pairs(defaults) do
- server.cfg[k] = event_settings[k] or default;
- end
- else
- local select_settings = {};
- for k,default in pairs(defaults) do
- select_settings[k] = settings[k] or default;
- end
- server.changesettings(select_settings);
- end
+ return set_config(settings);
end
load_config();
prosody.events.add_handler("config-reloaded", load_config);
diff --git a/net/server_epoll.lua b/net/server_epoll.lua
new file mode 100644
index 00000000..e0189179
--- /dev/null
+++ b/net/server_epoll.lua
@@ -0,0 +1,786 @@
+-- Prosody IM
+-- Copyright (C) 2016-2018 Kim Alvefur
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+
+local t_sort = table.sort;
+local t_insert = table.insert;
+local t_remove = table.remove;
+local t_concat = table.concat;
+local setmetatable = setmetatable;
+local tostring = tostring;
+local pcall = pcall;
+local type = type;
+local next = next;
+local pairs = pairs;
+local log = require "util.logger".init("server_epoll");
+local socket = require "socket";
+local luasec = require "ssl";
+local gettime = require "util.time".now;
+local createtable = require "util.table".create;
+local inet = require "util.net";
+local inet_pton = inet.pton;
+local _SOCKETINVALID = socket._SOCKETINVALID or -1;
+
+local poll = assert(require "util.poll".new());
+
+local _ENV = nil;
+-- luacheck: std none
+
+local default_config = { __index = {
+ read_timeout = 14 * 60;
+ write_timeout = 7;
+ tcp_backlog = 128;
+ accept_retry_interval = 10;
+ read_retry_delay = 1e-06;
+ read_size = 8192;
+ connect_timeout = 20;
+ handshake_timeout = 60;
+ max_wait = 86400;
+ min_wait = 1e-06;
+}};
+local cfg = default_config.__index;
+
+local fds = createtable(10, 0); -- FD -> conn
+
+-- Timer and scheduling --
+
+local timers = {};
+
+local function noop() end
+local function closetimer(t)
+ t[1] = 0;
+ t[2] = noop;
+end
+
+-- Set to true when timers have changed
+local resort_timers = false;
+
+-- Add absolute timer
+local function at(time, f)
+ local timer = { time, f, close = closetimer };
+ t_insert(timers, timer);
+ resort_timers = true;
+ return timer;
+end
+
+-- Add relative timer
+local function addtimer(timeout, f)
+ return at(gettime() + timeout, f);
+end
+
+-- Run callbacks of expired timers
+-- Return time until next timeout
+local function runtimers(next_delay, min_wait)
+ -- Any timers at all?
+ if not timers[1] then
+ return next_delay;
+ end
+
+ if resort_timers then
+ -- Sort earliest timers to the end
+ t_sort(timers, function (a, b) return a[1] > b[1]; end);
+ resort_timers = false;
+ end
+
+ -- Iterate from the end and remove completed timers
+ for i = #timers, 1, -1 do
+ local timer = timers[i];
+ local t, f = timer[1], timer[2];
+ -- Get time for every iteration to increase accuracy
+ local now = gettime();
+ if t > now then
+ -- This timer should not fire yet
+ local diff = t - now;
+ if diff < next_delay then
+ next_delay = diff;
+ end
+ break;
+ end
+ local new_timeout = f(now);
+ if new_timeout then
+ -- Schedule for 'delay' from the time actually scheduled,
+ -- not from now, in order to prevent timer drift.
+ timer[1] = t + new_timeout;
+ resort_timers = true;
+ else
+ t_remove(timers, i);
+ end
+ end
+
+ if resort_timers or next_delay < min_wait then
+ -- Timers may be added from within a timer callback.
+ -- Those would not be considered for next_delay,
+ -- and we might sleep for too long, so instead
+ -- we return a shorter timeout so we can
+ -- properly sort all new timers.
+ next_delay = min_wait;
+ end
+
+ return next_delay;
+end
+
+-- Socket handler interface
+
+local interface = {};
+local interface_mt = { __index = interface };
+
+function interface_mt:__tostring()
+ if self.sockname and self.peername then
+ return ("FD %d (%s, %d, %s, %d)"):format(self:getfd(), self.peername, self.peerport, self.sockname, self.sockport);
+ elseif self.sockname or self.peername then
+ return ("FD %d (%s, %d)"):format(self:getfd(), self.sockname or self.peername, self.sockport or self.peerport);
+ end
+ return ("FD %d"):format(self:getfd());
+end
+
+-- Replace the listener and tell the old one
+function interface:setlistener(listeners, data)
+ self:on("detach");
+ self.listeners = listeners;
+ self:on("attach", data);
+end
+
+-- Call a listener callback
+function interface:on(what, ...)
+ if not self.listeners then
+ log("error", "%s has no listeners", self);
+ return;
+ end
+ local listener = self.listeners["on"..what];
+ if not listener then
+ -- log("debug", "Missing listener 'on%s'", what); -- uncomment for development and debugging
+ return;
+ end
+ local ok, err = pcall(listener, self, ...);
+ if not ok then
+ log("error", "Error calling on%s: %s", what, err);
+ end
+ return err;
+end
+
+-- Return the file descriptor number
+function interface:getfd()
+ if self.conn then
+ return self.conn:getfd();
+ end
+ return _SOCKETINVALID;
+end
+
+function interface:server()
+ return self._server or self;
+end
+
+-- Get IP address
+function interface:ip()
+ return self.peername or self.sockname;
+end
+
+-- Get a port number, doesn't matter which
+function interface:port()
+ return self.sockport or self.peerport;
+end
+
+-- Get local port number
+function interface:clientport()
+ return self.sockport;
+end
+
+-- Get remote port
+function interface:serverport()
+ if self.sockport then
+ return self.sockport;
+ elseif self._server then
+ self._server:port();
+ end
+end
+
+-- Return underlying socket
+function interface:socket()
+ return self.conn;
+end
+
+function interface:set_mode(new_mode)
+ self.read_size = new_mode;
+end
+
+function interface:setoption(k, v)
+ -- LuaSec doesn't expose setoption :(
+ if self.conn.setoption then
+ self.conn:setoption(k, v);
+ end
+end
+
+-- Timeout for detecting dead or idle sockets
+function interface:setreadtimeout(t)
+ if t == false then
+ if self._readtimeout then
+ self._readtimeout:close();
+ self._readtimeout = nil;
+ end
+ return
+ end
+ t = t or cfg.read_timeout;
+ if self._readtimeout then
+ self._readtimeout[1] = gettime() + t;
+ resort_timers = true;
+ else
+ self._readtimeout = addtimer(t, function ()
+ if self:on("readtimeout") then
+ return cfg.read_timeout;
+ else
+ self:on("disconnect", "read timeout");
+ self:destroy();
+ end
+ end);
+ end
+end
+
+-- Timeout for detecting dead sockets
+function interface:setwritetimeout(t)
+ if t == false then
+ if self._writetimeout then
+ self._writetimeout:close();
+ self._writetimeout = nil;
+ end
+ return
+ end
+ t = t or cfg.write_timeout;
+ if self._writetimeout then
+ self._writetimeout[1] = gettime() + t;
+ resort_timers = true;
+ else
+ self._writetimeout = addtimer(t, function ()
+ self:on("disconnect", "write timeout");
+ self:destroy();
+ end);
+ end
+end
+
+function interface:add(r, w)
+ local fd = self:getfd();
+ if fd < 0 then
+ return nil, "invalid fd";
+ end
+ if r == nil then r = self._wantread; end
+ if w == nil then w = self._wantwrite; end
+ local ok, err, errno = poll:add(fd, r, w);
+ if not ok then
+ log("error", "Could not register %s: %s(%d)", self, err, errno);
+ return ok, err;
+ end
+ self._wantread, self._wantwrite = r, w;
+ fds[fd] = self;
+ log("debug", "Watching %s", self);
+ return true;
+end
+
+function interface:set(r, w)
+ local fd = self:getfd();
+ if fd < 0 then
+ return nil, "invalid fd";
+ end
+ if r == nil then r = self._wantread; end
+ if w == nil then w = self._wantwrite; end
+ local ok, err, errno = poll:set(fd, r, w);
+ if not ok then
+ log("error", "Could not update poller state %s: %s(%d)", self, err, errno);
+ return ok, err;
+ end
+ self._wantread, self._wantwrite = r, w;
+ return true;
+end
+
+function interface:del()
+ local fd = self:getfd();
+ if fd < 0 then
+ return nil, "invalid fd";
+ end
+ if fds[fd] ~= self then
+ return nil, "unregistered fd";
+ end
+ local ok, err, errno = poll:del(fd);
+ if not ok then
+ log("error", "Could not unregister %s: %s(%d)", self, err, errno);
+ return ok, err;
+ end
+ self._wantread, self._wantwrite = nil, nil;
+ fds[fd] = nil;
+ log("debug", "Unwatched %s", self);
+ return true;
+end
+
+function interface:setflags(r, w)
+ if not(self._wantread or self._wantwrite) then
+ if not(r or w) then
+ return true; -- no change
+ end
+ return self:add(r, w);
+ end
+ if not(r or w) then
+ return self:del();
+ end
+ return self:set(r, w);
+end
+
+-- Called when socket is readable
+function interface:onreadable()
+ local data, err, partial = self.conn:receive(self.read_size or cfg.read_size);
+ if data then
+ self:onconnect();
+ self:on("incoming", data);
+ else
+ if partial and partial ~= "" then
+ self:onconnect();
+ self:on("incoming", partial, err);
+ end
+ if err == "wantread" then
+ self:set(true, nil);
+ elseif err == "wantwrite" then
+ self:set(nil, true);
+ elseif err ~= "timeout" then
+ self:on("disconnect", err);
+ self:destroy()
+ return;
+ end
+ end
+ if not self.conn then return; end
+ if self.conn:dirty() then
+ self:setreadtimeout(false);
+ self:pausefor(cfg.read_retry_delay);
+ else
+ self:setreadtimeout();
+ end
+end
+
+-- Called when socket is writable
+function interface:onwritable()
+ self:onconnect();
+ if not self.conn then return; end -- could have been closed in onconnect
+ local buffer = self.writebuffer;
+ local data = t_concat(buffer);
+ local ok, err, partial = self.conn:send(data);
+ if ok then
+ self:set(nil, false);
+ for i = #buffer, 1, -1 do
+ buffer[i] = nil;
+ end
+ self:setwritetimeout(false);
+ self:ondrain(); -- Be aware of writes in ondrain
+ return;
+ elseif partial then
+ buffer[1] = data:sub(partial+1);
+ for i = #buffer, 2, -1 do
+ buffer[i] = nil;
+ end
+ self:setwritetimeout();
+ end
+ if err == "wantwrite" or err == "timeout" then
+ self:set(nil, true);
+ elseif err == "wantread" then
+ self:set(true, nil);
+ elseif err ~= "timeout" then
+ self:on("disconnect", err);
+ self:destroy();
+ end
+end
+
+-- The write buffer has been successfully emptied
+function interface:ondrain()
+ return self:on("drain");
+end
+
+-- Add data to write buffer and set flag for wanting to write
+function interface:write(data)
+ local buffer = self.writebuffer;
+ if buffer then
+ t_insert(buffer, data);
+ else
+ self.writebuffer = { data };
+ end
+ self:setwritetimeout();
+ self:set(nil, true);
+ return #data;
+end
+interface.send = interface.write;
+
+-- Close, possibly after writing is done
+function interface:close()
+ if self.writebuffer and self.writebuffer[1] then
+ self:set(false, true); -- Flush final buffer contents
+ self.write, self.send = noop, noop; -- No more writing
+ log("debug", "Close %s after writing", self);
+ self.ondrain = interface.close;
+ else
+ log("debug", "Close %s now", self);
+ self.write, self.send = noop, noop;
+ self.close = noop;
+ self:on("disconnect");
+ self:destroy();
+ end
+end
+
+function interface:destroy()
+ self:del();
+ self:setwritetimeout(false);
+ self:setreadtimeout(false);
+ self.onreadable = noop;
+ self.onwritable = noop;
+ self.destroy = noop;
+ self.close = noop;
+ self.on = noop;
+ self.conn:close();
+ self.conn = nil;
+end
+
+function interface:ssl()
+ return self._tls;
+end
+
+function interface:starttls(tls_ctx)
+ if tls_ctx then self.tls_ctx = tls_ctx; end
+ self.starttls = false;
+ if self.writebuffer and self.writebuffer[1] then
+ log("debug", "Start TLS on %s after write", self);
+ self.ondrain = interface.starttls;
+ self:set(nil, true); -- make sure wantwrite is set
+ else
+ if self.ondrain == interface.starttls then
+ self.ondrain = nil;
+ end
+ self.onwritable = interface.tlshandskake;
+ self.onreadable = interface.tlshandskake;
+ self:set(true, true);
+ log("debug", "Prepare to start TLS on %s", self);
+ end
+end
+
+function interface:tlshandskake()
+ self:setwritetimeout(false);
+ self:setreadtimeout(false);
+ if not self._tls then
+ self._tls = true;
+ log("debug", "Start TLS on %s now", self);
+ self:del();
+ local ok, conn, err = pcall(luasec.wrap, self.conn, self.tls_ctx);
+ if not ok then
+ log("error", "Failed to initialize TLS: %s", conn);
+ conn, err = ok, conn;
+ end
+ if not conn then
+ self:on("disconnect", err);
+ self:destroy();
+ return conn, err;
+ end
+ conn:settimeout(0);
+ self.conn = conn;
+ self:on("starttls");
+ self.ondrain = nil;
+ self.onwritable = interface.tlshandskake;
+ self.onreadable = interface.tlshandskake;
+ return self:init();
+ end
+ local ok, err = self.conn:dohandshake();
+ if ok then
+ log("debug", "TLS handshake on %s complete", self);
+ self.onwritable = nil;
+ self.onreadable = nil;
+ self:on("status", "ssl-handshake-complete");
+ self:setwritetimeout();
+ self:set(true, true);
+ elseif err == "wantread" then
+ log("debug", "TLS handshake on %s to wait until readable", self);
+ self:set(true, false);
+ self:setreadtimeout(cfg.handshake_timeout);
+ elseif err == "wantwrite" then
+ log("debug", "TLS handshake on %s to wait until writable", self);
+ self:set(false, true);
+ self:setwritetimeout(cfg.handshake_timeout);
+ else
+ log("debug", "TLS handshake error on %s: %s", self, err);
+ self:on("disconnect", err);
+ self:destroy();
+ end
+end
+
+local function wrapsocket(client, server, read_size, listeners, tls_ctx) -- luasocket object -> interface object
+ client:settimeout(0);
+ local conn = setmetatable({
+ conn = client;
+ _server = server;
+ created = gettime();
+ listeners = listeners;
+ read_size = read_size or (server and server.read_size);
+ writebuffer = {};
+ tls_ctx = tls_ctx or (server and server.tls_ctx);
+ tls_direct = server and server.tls_direct;
+ }, interface_mt);
+
+ conn:updatenames();
+ return conn;
+end
+
+function interface:updatenames()
+ local conn = self.conn;
+ local ok, peername, peerport = pcall(conn.getpeername, conn);
+ if ok then
+ self.peername, self.peerport = peername, peerport;
+ end
+ local ok, sockname, sockport = pcall(conn.getsockname, conn);
+ if ok then
+ self.sockname, self.sockport = sockname, sockport;
+ end
+end
+
+-- A server interface has new incoming connections waiting
+-- This replaces the onreadable callback
+function interface:onacceptable()
+ local conn, err = self.conn:accept();
+ if not conn then
+ log("debug", "Error accepting new client: %s, server will be paused for %ds", err, cfg.accept_retry_interval);
+ self:pausefor(cfg.accept_retry_interval);
+ return;
+ end
+ local client = wrapsocket(conn, self, nil, self.listeners);
+ log("debug", "New connection %s", tostring(client));
+ client:init();
+ if self.tls_direct then
+ client:starttls(self.tls_ctx);
+ end
+end
+
+-- Initialization
+function interface:init()
+ self:setwritetimeout();
+ return self:add(true, true);
+end
+
+function interface:pause()
+ return self:set(false);
+end
+
+function interface:resume()
+ return self:set(true);
+end
+
+-- Pause connection for some time
+function interface:pausefor(t)
+ if self._pausefor then
+ self._pausefor:close();
+ end
+ if t == false then return; end
+ self:set(false);
+ self._pausefor = addtimer(t, function ()
+ self._pausefor = nil;
+ if self.conn and self.conn:dirty() then
+ self:onreadable();
+ end
+ self:set(true);
+ end);
+end
+
+-- Connected!
+function interface:onconnect()
+ if self.conn and not self.peername and self.conn.getpeername then
+ self.peername, self.peerport = self.conn:getpeername();
+ end
+ self.onconnect = noop;
+ self:on("connect");
+end
+
+local function addserver(addr, port, listeners, read_size, tls_ctx)
+ local conn, err = socket.bind(addr, port, cfg.tcp_backlog);
+ if not conn then return conn, err; end
+ conn:settimeout(0);
+ local server = setmetatable({
+ conn = conn;
+ created = gettime();
+ listeners = listeners;
+ read_size = read_size;
+ onreadable = interface.onacceptable;
+ tls_ctx = tls_ctx;
+ tls_direct = tls_ctx and true or false;
+ sockname = addr;
+ sockport = port;
+ }, interface_mt);
+ server:add(true, false);
+ return server;
+end
+
+-- COMPAT
+local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx)
+ local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx);
+ if not client.peername then
+ client.peername, client.peerport = addr, port;
+ end
+ local ok, err = client:init();
+ if not ok then return ok, err; end
+ if tls_ctx then
+ client:starttls(tls_ctx);
+ end
+ return client;
+end
+
+-- New outgoing TCP connection
+local function addclient(addr, port, listeners, read_size, tls_ctx, typ)
+ local create;
+ if not typ then
+ local n = inet_pton(addr);
+ if not n then return nil, "invalid-ip"; end
+ if #n == 16 then
+ typ = "tcp6";
+ else
+ typ = "tcp4";
+ end
+ end
+ if typ then
+ create = socket[typ];
+ end
+ if type(create) ~= "function" then
+ return nil, "invalid socket type";
+ end
+ local conn, err = create();
+ local ok, err = conn:settimeout(0);
+ if not ok then return ok, err; end
+ local ok, err = conn:setpeername(addr, port);
+ if not ok and err ~= "timeout" then return ok, err; end
+ local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx)
+ local ok, err = client:init();
+ if not ok then return ok, err; end
+ if tls_ctx then
+ client:starttls(tls_ctx);
+ end
+ return client, conn;
+end
+
+local function watchfd(fd, onreadable, onwritable)
+ local conn = setmetatable({
+ conn = fd;
+ onreadable = onreadable;
+ onwritable = onwritable;
+ close = function (self)
+ self:del();
+ end
+ }, interface_mt);
+ if type(fd) == "number" then
+ conn.getfd = function ()
+ return fd;
+ end;
+ -- Otherwise it'll need to be something LuaSocket-compatible
+ end
+ conn:add(onreadable, onwritable);
+ return conn;
+end;
+
+-- Dump all data from one connection into another
+local function link(from, to)
+ from.listeners = setmetatable({
+ onincoming = function (_, data)
+ from:pause();
+ to:write(data);
+ end,
+ }, {__index=from.listeners});
+ to.listeners = setmetatable({
+ ondrain = function ()
+ from:resume();
+ end,
+ }, {__index=to.listeners});
+ from:set(true, nil);
+ to:set(nil, true);
+end
+
+-- COMPAT
+-- net.adns calls this but then replaces :send so this can be a noop
+function interface:set_send(new_send) -- luacheck: ignore 212
+end
+
+-- Close all connections and servers
+local function closeall()
+ for fd, conn in pairs(fds) do -- luacheck: ignore 213/fd
+ conn:close();
+ end
+end
+
+local quitting = nil;
+
+-- Signal main loop about shutdown via above upvalue
+local function setquitting(quit)
+ if quit then
+ quitting = "quitting";
+ closeall();
+ else
+ quitting = nil;
+ end
+end
+
+-- Main loop
+local function loop(once)
+ repeat
+ local t = runtimers(cfg.max_wait, cfg.min_wait);
+ local fd, r, w = poll:wait(t);
+ if fd then
+ local conn = fds[fd];
+ if conn then
+ if r then
+ conn:onreadable();
+ end
+ if w then
+ conn:onwritable();
+ end
+ else
+ log("debug", "Removing unknown fd %d", fd);
+ poll:del(fd);
+ end
+ elseif r ~= "timeout" then
+ log("debug", "epoll_wait error: %s[%d]", r, w);
+ end
+ until once or (quitting and next(fds) == nil);
+ return quitting;
+end
+
+return {
+ get_backend = function () return "epoll"; end;
+ addserver = addserver;
+ addclient = addclient;
+ add_task = addtimer;
+ at = at;
+ loop = loop;
+ closeall = closeall;
+ setquitting = setquitting;
+ wrapclient = wrapclient;
+ watchfd = watchfd;
+ link = link;
+ set_config = function (newconfig)
+ cfg = setmetatable(newconfig, default_config);
+ end;
+
+ -- libevent emulation
+ event = { EV_READ = "r", EV_WRITE = "w", EV_READWRITE = "rw", EV_LEAVE = -1 };
+ addevent = function (fd, mode, callback)
+ local function onevent(self)
+ local ret = self:callback();
+ if ret == -1 then
+ self:set(false, false);
+ elseif ret then
+ self:set(mode == "r" or mode == "rw", mode == "w" or mode == "rw");
+ end
+ end
+
+ local conn = setmetatable({
+ getfd = function () return fd; end;
+ callback = callback;
+ onreadable = onevent;
+ onwritable = onevent;
+ close = function (self)
+ self:del();
+ fds[fd] = nil;
+ end;
+ }, interface_mt);
+ local ok, err = conn:add(mode == "r" or mode == "rw", mode == "w" or mode == "rw");
+ if not ok then return ok, err; end
+ return conn;
+ end;
+};
diff --git a/net/server_event.lua b/net/server_event.lua
index 3a907349..11bd6a29 100644
--- a/net/server_event.lua
+++ b/net/server_event.lua
@@ -5,9 +5,9 @@
notes:
-- when using luaevent, never register 2 or more EV_READ at one socket, same for EV_WRITE
- -- you cant even register a new EV_READ/EV_WRITE callback inside another one
+ -- you can't even register a new EV_READ/EV_WRITE callback inside another one
-- to do some of the above, use timeout events or something what will called from outside
- -- dont let garbagecollect eventcallbacks, as long they are running
+ -- don't let garbagecollect eventcallbacks, as long they are running
-- when using luasec, there are 4 cases of timeout errors: wantread or wantwrite during reading or writing
--]]
@@ -26,7 +26,7 @@ local cfg = {
MAX_SEND_LENGTH = 1024 * 1024 * 1024 * 1024, -- max bytes size of write buffer (for writing on sockets)
ACCEPT_QUEUE = 128, -- might influence the length of the pending sockets queue
ACCEPT_DELAY = 10, -- seconds to wait until the next attempt of a full server to accept
- READ_TIMEOUT = 60 * 60 * 6, -- timeout in seconds for read data from socket
+ READ_TIMEOUT = 14 * 60, -- timeout in seconds for read data from socket
WRITE_TIMEOUT = 180, -- timeout in seconds for write data on socket
CONNECT_TIMEOUT = 20, -- timeout in seconds for connection attempts
CLEAR_DELAY = 5, -- seconds to wait for clearing interface list (and calling ondisconnect listeners)
@@ -50,9 +50,10 @@ local coroutine_yield = coroutine.yield
local has_luasec, ssl = pcall ( require , "ssl" )
local socket = require "socket"
local levent = require "luaevent.core"
+local inet = require "util.net";
+local inet_pton = inet.pton;
local socket_gettime = socket.gettime
-local getaddrinfo = socket.dns.getaddrinfo
local log = require ("util.logger").init("socket")
@@ -106,6 +107,12 @@ function interface_mt:_start_connection(plainssl) -- called from wrapclient
self:_close()
debug( "new connection failed. id:", self.id, "error:", self.fatalerror )
else
+ if EV_READWRITE == event then
+ if self.readcallback(event) == -1 then
+ -- Fatal error occurred
+ return -1;
+ end
+ end
if plainssl and has_luasec then -- start ssl session
self:starttls(self._sslctx, true)
else -- normal connection
@@ -116,7 +123,7 @@ function interface_mt:_start_connection(plainssl) -- called from wrapclient
self.eventconnect = nil
return -1
end
- self.eventconnect = addevent( base, self.conn, EV_WRITE, callback, cfg.CONNECT_TIMEOUT )
+ self.eventconnect = addevent( base, self.conn, EV_READWRITE, callback, cfg.CONNECT_TIMEOUT )
return true
end
function interface_mt:_start_session(call_onconnect) -- new session, for example after startssl
@@ -151,7 +158,7 @@ function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed
self.fatalerror = err
self.conn = nil -- cannot be used anymore
if call_onconnect then
- self.ondisconnect = nil -- dont call this when client isnt really connected
+ self.ondisconnect = nil -- don't call this when client isn't really connected
end
self:_close()
debug( "fatal error while ssl wrapping:", err )
@@ -194,7 +201,7 @@ function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed
end
if self.fatalerror then
if call_onconnect then
- self.ondisconnect = nil -- dont call this when client isnt really connected
+ self.ondisconnect = nil -- don't call this when client isn't really connected
end
self:_close()
debug( "handshake failed because:", self.fatalerror )
@@ -223,7 +230,8 @@ function interface_mt:_destroy() -- close this interface + events and call last
_ = self.eventsession and self.eventsession:close( )
_ = self.eventwritetimeout and self.eventwritetimeout:close( )
_ = self.eventreadtimeout and self.eventreadtimeout:close( )
- _ = self.ondisconnect and self:ondisconnect( self.fatalerror ~= "client to close" and self.fatalerror) -- call ondisconnect listener (wont be the case if handshake failed on connect)
+ -- call ondisconnect listener (won't be the case if handshake failed on connect)
+ _ = self.ondisconnect and self:ondisconnect( self.fatalerror ~= "client to close" and self.fatalerror)
_ = self.conn and self.conn:close( ) -- close connection
_ = self._server and self._server:counter(-1);
self.eventread, self.eventwrite = nil, nil
@@ -408,7 +416,7 @@ function interface_mt:setoption(option, value)
return false, "setoption not implemented";
end
-function interface_mt:setlistener(listener)
+function interface_mt:setlistener(listener, data)
self:ondetach(); -- Notify listener that it is no longer responsible for this connection
self.onconnect = listener.onconnect;
self.ondisconnect = listener.ondisconnect;
@@ -417,7 +425,9 @@ function interface_mt:setlistener(listener)
self.onreadtimeout = listener.onreadtimeout;
self.onstatus = listener.onstatus;
self.ondetach = listener.ondetach;
+ self.onattach = listener.onattach;
self.ondrain = listener.ondrain;
+ self:onattach(data);
end
-- Stub handlers
@@ -439,6 +449,8 @@ function interface_mt:ondrain()
end
function interface_mt:ondetach()
end
+function interface_mt:onattach()
+end
function interface_mt:onstatus()
end
@@ -510,7 +522,7 @@ local function handleclient( client, ip, port, server, pattern, listener, sslctx
interface.writebuffer = { t_concat(interface.writebuffer) }
local succ, err, byte = interface.conn:send( interface.writebuffer[1], 1, interface.writebufferlen )
--vdebug( "write data:", interface.writebuffer, "error:", err, "part:", byte )
- if succ then -- writing succesful
+ if succ then -- writing successful
interface.writebuffer[1] = nil
interface.writebufferlen = 0
interface:ondrain();
@@ -539,7 +551,7 @@ local function handleclient( client, ip, port, server, pattern, listener, sslctx
return -1;
end
interface.eventwritetimeout = addevent( base, nil, EV_TIMEOUT, callback, cfg.WRITE_TIMEOUT ) -- reg a new timeout event
- debug( "wantread during write attempt, reg it in readcallback but dont know what really happens next..." )
+ debug( "wantread during write attempt, reg it in readcallback but don't know what really happens next..." )
-- hopefully this works with luasec; its simply not possible to use 2 different write events on a socket in luaevent
return -1
end
@@ -595,8 +607,8 @@ local function handleclient( client, ip, port, server, pattern, listener, sslctx
end
interface.eventreadtimeout = addevent( base, nil, EV_TIMEOUT,
function( ) interface:_close() end, cfg.READ_TIMEOUT)
- debug( "wantwrite during read attempt, reg it in writecallback but dont know what really happens next..." )
- -- to be honest i dont know what happens next, if it is allowed to first read, the write etc...
+ debug( "wantwrite during read attempt, reg it in writecallback but don't know what really happens next..." )
+ -- to be honest i don't know what happens next, if it is allowed to first read, the write etc...
else -- connection was closed or fatal error
interface.fatalerror = err
debug( "connection failed in read event:", interface.fatalerror )
@@ -717,15 +729,15 @@ local function addclient( addr, serverport, listener, pattern, sslctx, typ )
return nil, "luasec not found"
end
if not typ then
- local addrinfo, err = getaddrinfo(addr)
- if not addrinfo then return nil, err end
- if addrinfo[1] and addrinfo[1].family == "inet6" then
- typ = "tcp6"
- else
- typ = "tcp"
+ local n = inet_pton(addr);
+ if not n then return nil, "invalid-ip"; end
+ if #n == 16 then
+ typ = "tcp6";
+ elseif #n == 4 then
+ typ = "tcp4";
end
end
- local create = socket[typ]
+ local create = socket[typ];
if type( create ) ~= "function" then
return nil, "invalid socket type"
end
@@ -735,7 +747,7 @@ local function addclient( addr, serverport, listener, pattern, sslctx, typ )
return nil, err
end
client:settimeout( 0 ) -- set nonblocking
- local res, err = client:connect( addr, serverport ) -- connect
+ local res, err = client:setpeername( addr, serverport ) -- connect
if res or ( err == "timeout" ) then
local ip, port = client:getsockname( )
local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx )
@@ -767,13 +779,15 @@ end
local function setquitting(yes)
if yes then
-- Quit now
- closeallservers();
+ if yes ~= "once" then
+ closeallservers();
+ end
base:loopexit();
end
end
local function get_backend()
- return base:method();
+ return "libevent " .. base:method();
end
-- We need to hold onto the events to stop them
@@ -811,6 +825,48 @@ local function link(sender, receiver, buffersize)
sender:set_mode("*a");
end
+local function add_task(delay, callback)
+ local event_handle;
+ event_handle = base:addevent(nil, 0, function ()
+ local ret = callback(socket_gettime());
+ if ret then
+ return 0, ret;
+ elseif event_handle then
+ return -1;
+ end
+ end
+ , delay);
+ return event_handle;
+end
+
+local function watchfd(fd, onreadable, onwriteable)
+ local handle = {};
+ function handle:setflags(r,w)
+ if r ~= nil then
+ if r and not self.wantread then
+ self.wantread = base:addevent(fd, EV_READ, function ()
+ onreadable(self);
+ end);
+ elseif not r and self.wantread then
+ self.wantread:close();
+ self.wantread = nil;
+ end
+ end
+ if w ~= nil then
+ if w and not self.wantwrite then
+ self.wantwrite = base:addevent(fd, EV_WRITE, function ()
+ onwriteable(self);
+ end);
+ elseif not r and self.wantread then
+ self.wantwrite:close();
+ self.wantwrite = nil;
+ end
+ end
+ end
+ handle:setflags(onreadable, onwriteable);
+ return handle;
+end
+
return {
cfg = cfg,
base = base,
@@ -826,6 +882,8 @@ return {
closeall = closeallservers,
get_backend = get_backend,
hook_signal = hook_signal,
+ add_task = add_task,
+ watchfd = watchfd,
__NAME = SCRIPT_NAME,
__DATE = LAST_MODIFIED,
diff --git a/net/server_select.lua b/net/server_select.lua
index 12aef9d8..bc86742c 100644
--- a/net/server_select.lua
+++ b/net/server_select.lua
@@ -40,6 +40,7 @@ local coroutine = use "coroutine"
local math_min = math.min
local math_huge = math.huge
local table_concat = table.concat
+local table_insert = table.insert
local string_sub = string.sub
local coroutine_wrap = coroutine.wrap
local coroutine_yield = coroutine.yield
@@ -49,13 +50,13 @@ local coroutine_yield = coroutine.yield
local has_luasec, luasec = pcall ( require , "ssl" )
local luasocket = use "socket" or require "socket"
local luasocket_gettime = luasocket.gettime
-local getaddrinfo = luasocket.dns.getaddrinfo
+local inet = require "util.net";
+local inet_pton = inet.pton;
--// extern lib methods //--
local ssl_wrap = ( has_luasec and luasec.wrap )
local socket_bind = luasocket.bind
-local socket_sleep = luasocket.sleep
local socket_select = luasocket.select
--// functions //--
@@ -100,7 +101,6 @@ local _sendtraffic
local _readtraffic
local _selecttimeout
-local _sleeptime
local _tcpbacklog
local _accepretry
@@ -114,8 +114,6 @@ local _checkinterval
local _sendtimeout
local _readtimeout
-local _timer
-
local _maxselectlen
local _maxfd
@@ -135,13 +133,12 @@ _fullservers = { } -- servers in a paused state while there are too many clients
_readlistlen = 0 -- length of readlist
_sendlistlen = 0 -- length of sendlist
-_timerlistlen = 0 -- lenght of timerlist
+_timerlistlen = 0 -- length of timerlist
_sendtraffic = 0 -- some stats
_readtraffic = 0
_selecttimeout = 1 -- timeout of socket.select
-_sleeptime = 0 -- time to wait at the end of every loop
_tcpbacklog = 128 -- some kind of hint to the OS
_accepretry = 10 -- seconds to wait until the next attempt of a full server to accept
@@ -150,7 +147,7 @@ _maxreadlen = 25000 * 1024 -- max len of read buffer
_checkinterval = 30 -- interval in secs to check idle clients
_sendtimeout = 60000 -- allowed send idle time in secs
-_readtimeout = 6 * 60 * 60 -- allowed read idle time in secs
+_readtimeout = 14 * 60 -- allowed read idle time in secs
local is_windows = package.config:sub(1,1) == "\\" -- check the directory separator, to detemine whether this is Windows
_maxfd = (is_windows and math.huge) or luasocket._SETSIZE or 1024 -- max fd number, limit to 1024 by default to prevent glibc buffer overflow, but not on Windows
@@ -301,7 +298,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
local bufferqueuelen = 0 -- end of buffer array
local toclose
- local fatalerror
local needtls
local bufferlen = 0
@@ -326,7 +322,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
end
handler.onreadtimeout = onreadtimeout;
- handler.setlistener = function( self, listeners )
+ handler.setlistener = function( self, listeners, data )
if detach then
detach(self) -- Notify listener that it is no longer responsible for this connection
end
@@ -336,6 +332,9 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
drain = listeners.ondrain
handler.onreadtimeout = listeners.onreadtimeout
detach = listeners.ondetach
+ if listeners.onattach then
+ listeners.onattach(self, data)
+ end
end
handler.getstats = function( )
return readtraffic, sendtraffic
@@ -425,7 +424,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
bufferlen = bufferlen + #data
if bufferlen > maxsendlen then
_closelist[ handler ] = "send buffer exceeded" -- cannot close the client at the moment, have to wait to the end of the cycle
- handler.write = idfalse -- dont write anymore
+ handler.write = idfalse -- don't write anymore
return false
elseif socket and not _sendlist[ socket ] then
_sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
@@ -517,7 +516,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
return dispatch( handler, buffer, err )
else -- connections was closed or fatal error
out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) )
- fatalerror = true
_ = handler and handler:force_close( err )
return false
end
@@ -537,7 +535,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
else
succ, err, count = false, "unexpected close", 0;
end
- if succ then -- sending succesful
+ if succ then -- sending successful
bufferqueuelen = 0
bufferlen = 0
_sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) -- delete socket from writelist
@@ -557,7 +555,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
return true
else -- connection was closed during sending or fatal error
out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) )
- fatalerror = true
_ = handler and handler:force_close( err )
return false
end
@@ -806,7 +803,6 @@ end
getsettings = function( )
return {
select_timeout = _selecttimeout;
- select_sleep_time = _sleeptime;
tcp_backlog = _tcpbacklog;
max_send_buffer_size = _maxsendlen;
max_receive_buffer_size = _maxreadlen;
@@ -825,7 +821,6 @@ changesettings = function( new )
return nil, "invalid settings table"
end
_selecttimeout = tonumber( new.select_timeout ) or _selecttimeout
- _sleeptime = tonumber( new.select_sleep_time ) or _sleeptime
_maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen
_maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen
_checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval
@@ -848,6 +843,49 @@ addtimer = function( listener )
return true
end
+local add_task do
+ local data = {};
+ local new_data = {};
+
+ function add_task(delay, callback)
+ local current_time = luasocket_gettime();
+ delay = delay + current_time;
+ if delay >= current_time then
+ table_insert(new_data, {delay, callback});
+ else
+ local r = callback(current_time);
+ if r and type(r) == "number" then
+ return add_task(r, callback);
+ end
+ end
+ end
+
+ addtimer(function(current_time)
+ if #new_data > 0 then
+ for _, d in pairs(new_data) do
+ table_insert(data, d);
+ end
+ new_data = {};
+ end
+
+ local next_time = math_huge;
+ for i, d in pairs(data) do
+ local t, callback = d[1], d[2];
+ if t <= current_time then
+ data[i] = nil;
+ local r = callback(current_time);
+ if type(r) == "number" then
+ add_task(r, callback);
+ next_time = math_min(next_time, r);
+ end
+ else
+ next_time = math_min(next_time, t - current_time);
+ end
+ end
+ return next_time;
+ end);
+end
+
stats = function( )
return _readtraffic, _sendtraffic, _readlistlen, _sendlistlen, _timerlistlen
end
@@ -855,31 +893,38 @@ end
local quitting;
local function setquitting(quit)
- quitting = not not quit;
+ quitting = quit;
end
loop = function(once) -- this is the main loop of the program
if quitting then return "quitting"; end
if once then quitting = "once"; end
- local next_timer_time = math_huge;
+ _currenttime = luasocket_gettime( )
repeat
+ -- Fire timers
+ local next_timer_time = math_huge;
+ for i = 1, _timerlistlen do
+ local t = _timerlist[ i ]( _currenttime ) -- fire timers
+ if t then next_timer_time = math_min(next_timer_time, t); end
+ end
+
local read, write, err = socket_select( _readlist, _sendlist, math_min(_selecttimeout, next_timer_time) )
- for _, socket in ipairs( write ) do -- send data waiting in writequeues
+ for _, socket in ipairs( read ) do -- receive data
local handler = _socketlist[ socket ]
if handler then
- handler.sendbuffer( )
+ handler.readbuffer( )
else
closesocket( socket )
- out_put "server.lua: found no handler and closed socket (writelist)" -- this should not happen
+ out_put "server.lua: found no handler and closed socket (readlist)" -- this can happen
end
end
- for _, socket in ipairs( read ) do -- receive data
+ for _, socket in ipairs( write ) do -- send data waiting in writequeues
local handler = _socketlist[ socket ]
if handler then
- handler.readbuffer( )
+ handler.sendbuffer( )
else
closesocket( socket )
- out_put "server.lua: found no handler and closed socket (readlist)" -- this can happen
+ out_put "server.lua: found no handler and closed socket (writelist)" -- this should not happen
end
end
for handler, err in pairs( _closelist ) do
@@ -910,29 +955,14 @@ loop = function(once) -- this is the main loop of the program
end
end
- -- Fire timers
- if _currenttime - _timer >= math_min(next_timer_time, 1) then
- next_timer_time = math_huge;
- for i = 1, _timerlistlen do
- local t = _timerlist[ i ]( _currenttime ) -- fire timers
- if t then next_timer_time = math_min(next_timer_time, t); end
- end
- _timer = _currenttime
- else
- next_timer_time = next_timer_time - (_currenttime - _timer);
- end
-
for server, paused_time in pairs( _fullservers ) do
if _currenttime - paused_time > _accepretry then
_fullservers[ server ] = nil;
server.resume();
end
end
-
- -- wait some time (0 by default)
- socket_sleep( _sleeptime )
until quitting;
- if once and quitting == "once" then quitting = nil; return; end
+ if quitting == "once" then quitting = nil; return; end
closeall();
return "quitting"
end
@@ -952,6 +982,7 @@ local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx
if not handler then return nil, err end
_socketlist[ socket ] = handler
if not sslctx then
+ _readlistlen = addsocket(_readlist, socket, _readlistlen)
_sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
if listeners.onconnect then
-- When socket is writeable, call onconnect
@@ -978,15 +1009,15 @@ local addclient = function( address, port, listeners, pattern, sslctx, typ )
err = "luasec not found"
end
if not typ then
- local addrinfo, err = getaddrinfo(address)
- if not addrinfo then return nil, err end
- if addrinfo[1] and addrinfo[1].family == "inet6" then
- typ = "tcp6"
- else
- typ = "tcp"
+ local n = inet_pton(addr);
+ if not n then return nil, "invalid-ip"; end
+ if #n == 16 then
+ typ = "tcp6";
+ elseif #n == 4 then
+ typ = "tcp4";
end
end
- local create = luasocket[typ]
+ local create = luasocket[typ];
if type( create ) ~= "function" then
err = "invalid socket type"
end
@@ -1001,15 +1032,55 @@ local addclient = function( address, port, listeners, pattern, sslctx, typ )
return nil, err
end
client:settimeout( 0 )
- local ok, err = client:connect( address, port )
- if ok or err == "timeout" then
+ local ok, err = client:setpeername( address, port )
+ if ok or err == "timeout" or err == "Operation already in progress" then
return wrapclient( client, address, port, listeners, pattern, sslctx )
else
return nil, err
end
end
---// EXPERIMENTAL //--
+local closewatcher = function (handler)
+ local socket = handler.conn;
+ _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
+ _readlistlen = removesocket( _readlist, socket, _readlistlen )
+ _socketlist[ socket ] = nil
+end;
+
+local addremove = function (handler, read, send)
+ local socket = handler.conn
+ _socketlist[ socket ] = handler
+ if read ~= nil then
+ if read then
+ _readlistlen = addsocket( _readlist, socket, _readlistlen )
+ else
+ _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
+ end
+ end
+ if send ~= nil then
+ if send then
+ _sendlistlen = addsocket( _sendlist, socket, _sendlistlen )
+ else
+ _readlistlen = removesocket( _readlist, socket, _readlistlen )
+ end
+ end
+end
+
+local watchfd = function ( fd, onreadable, onwriteable )
+ local socket = fd
+ if type(fd) == "number" then
+ socket = { getfd = function () return fd; end }
+ end
+ local handler = {
+ conn = socket;
+ readbuffer = onreadable or id;
+ sendbuffer = onwriteable or id;
+ close = closewatcher;
+ setflags = addremove;
+ };
+ addremove( handler, onreadable, onwriteable )
+ return handler
+end
----------------------------------// BEGIN //--
@@ -1017,7 +1088,6 @@ use "setmetatable" ( _socketlist, { __mode = "k" } )
use "setmetatable" ( _readtimes, { __mode = "k" } )
use "setmetatable" ( _writetimes, { __mode = "k" } )
-_timer = luasocket_gettime( )
_starttime = luasocket_gettime( )
local function setlogger(new_logger)
@@ -1032,9 +1102,11 @@ end
return {
_addtimer = addtimer,
+ add_task = add_task;
addclient = addclient,
wrapclient = wrapclient,
+ watchfd = watchfd,
loop = loop,
link = link,
diff --git a/net/websocket.lua b/net/websocket.lua
index 777b894c..469c6a58 100644
--- a/net/websocket.lua
+++ b/net/websocket.lua
@@ -21,9 +21,9 @@ local close_timeout = 3; -- Seconds to wait after sending close frame until clos
local websockets = {};
local websocket_listeners = {};
-function websocket_listeners.ondisconnect(handler, err)
- local s = websockets[handler];
- websockets[handler] = nil;
+function websocket_listeners.ondisconnect(conn, err)
+ local s = websockets[conn];
+ websockets[conn] = nil;
if s.close_timer then
timer.stop(s.close_timer);
s.close_timer = nil;
@@ -33,19 +33,19 @@ function websocket_listeners.ondisconnect(handler, err)
if s.onclose then s:onclose(s.close_code, s.close_message or err); end
end
-function websocket_listeners.ondetach(handler)
- websockets[handler] = nil;
+function websocket_listeners.ondetach(conn)
+ websockets[conn] = nil;
end
local function fail(s, code, reason)
log("warn", "WebSocket connection failed, closing. %d %s", code, reason);
s:close(code, reason);
- s.handler:close();
+ s.conn:close();
return false
end
-function websocket_listeners.onincoming(handler, buffer, err) -- luacheck: ignore 212/err
- local s = websockets[handler];
+function websocket_listeners.onincoming(conn, buffer, err) -- luacheck: ignore 212/err
+ local s = websockets[conn];
s.readbuffer = s.readbuffer..buffer;
while true do
local frame, len = frames.parse(s.readbuffer);
@@ -111,7 +111,7 @@ function websocket_listeners.onincoming(handler, buffer, err) -- luacheck: ignor
elseif frame.opcode == 0x9 then -- Ping frame
frame.opcode = 0xA;
frame.MASK = true; -- RFC 6455 6.1.5: If the data is being sent by the client, the frame(s) MUST be masked
- handler:write(frames.build(frame));
+ conn:write(frames.build(frame));
elseif frame.opcode == 0xA then -- Pong frame
log("debug", "Received unexpected pong frame: " .. tostring(frame.data));
else
@@ -126,15 +126,15 @@ local websocket_methods = {};
local function close_timeout_cb(now, timerid, s) -- luacheck: ignore 212/now 212/timerid
s.close_timer = nil;
log("warn", "Close timeout waiting for server to close, closing manually.");
- s.handler:close();
+ s.conn:close();
end
function websocket_methods:close(code, reason)
if self.readyState < 2 then
code = code or 1000;
log("debug", "closing WebSocket with code %i: %s" , code , tostring(reason));
self.readyState = 2;
- local handler = self.handler;
- handler:write(frames.build_close(code, reason, true));
+ local conn = self.conn;
+ conn:write(frames.build_close(code, reason, true));
-- Do not close socket straight away, wait for acknowledgement from server.
self.close_timer = timer.add_task(close_timeout, close_timeout_cb, self);
elseif self.readyState == 2 then
@@ -144,8 +144,8 @@ function websocket_methods:close(code, reason)
timer.stop(self.close_timer);
self.close_timer = nil;
end
- local handler = self.handler;
- handler:close();
+ local conn = self.conn;
+ conn:close();
else
log("debug", "tried to close a closed WebSocket, ignoring.");
end
@@ -168,7 +168,7 @@ function websocket_methods:send(data, opcode)
data = tostring(data);
};
log("debug", "WebSocket sending frame: opcode=%0x, %i bytes", frame.opcode, #frame.data);
- return self.handler:write(frames.build(frame));
+ return self.conn:write(frames.build(frame));
end
local websocket_metatable = {
@@ -216,7 +216,7 @@ local function connect(url, ex, listeners)
local s = setmetatable({
readbuffer = "";
databuffer = nil;
- handler = nil;
+ conn = nil;
close_code = nil;
close_message = nil;
close_timer = nil;
@@ -236,6 +236,7 @@ local function connect(url, ex, listeners)
method = "GET";
headers = headers;
sslctx = ex.sslctx;
+ insecure = ex.insecure;
}, function(b, c, r, http_req)
if c ~= 101
or r.headers["connection"]:lower() ~= "upgrade"
@@ -252,16 +253,16 @@ local function connect(url, ex, listeners)
s.protocol = r.headers["sec-websocket-protocol"];
-- Take possession of socket from http
+ local conn = http_req.conn;
http_req.conn = nil;
- local handler = http_req.handler;
- s.handler = handler;
- websockets[handler] = s;
- handler:setlistener(websocket_listeners);
+ s.conn = conn;
+ websockets[conn] = s;
+ conn:setlistener(websocket_listeners);
log("debug", "WebSocket connected successfully to %s", url);
s.readyState = 1;
if s.onopen then s:onopen(); end
- websocket_listeners.onincoming(handler, b);
+ websocket_listeners.onincoming(conn, b);
end);
return s;
diff --git a/net/websocket/frames.lua b/net/websocket/frames.lua
index 5fe96d45..ba25d261 100644
--- a/net/websocket/frames.lua
+++ b/net/websocket/frames.lua
@@ -21,8 +21,8 @@ local t_concat = table.concat;
local s_byte = string.byte;
local s_char= string.char;
local s_sub = string.sub;
-local s_pack = string.pack;
-local s_unpack = string.unpack;
+local s_pack = string.pack; -- luacheck: ignore 143
+local s_unpack = string.unpack; -- luacheck: ignore 143
if not s_pack and softreq"struct" then
s_pack = softreq"struct".pack;
@@ -112,9 +112,9 @@ end
-- TODO: optimize
local function apply_mask(str, key, from, to)
from = from or 1
- if from < 0 then from = #str + from + 1 end -- negative indicies
+ if from < 0 then from = #str + from + 1 end -- negative indices
to = to or #str
- if to < 0 then to = #str + to + 1 end -- negative indicies
+ if to < 0 then to = #str + to + 1 end -- negative indices
local key_len = #key
local counter = 0;
local data = {};