diff options
Diffstat (limited to 'net')
-rw-r--r-- | net/adns.lua | 1 | ||||
-rw-r--r-- | net/connect.lua | 92 | ||||
-rw-r--r-- | net/connlisteners.lua | 4 | ||||
-rw-r--r-- | net/cqueues.lua | 74 | ||||
-rw-r--r-- | net/dns.lua | 14 | ||||
-rw-r--r-- | net/http.lua | 156 | ||||
-rw-r--r-- | net/httpserver.lua | 5 | ||||
-rw-r--r-- | net/resolvers/basic.lua | 71 | ||||
-rw-r--r-- | net/resolvers/manual.lua | 25 | ||||
-rw-r--r-- | net/server.lua | 115 | ||||
-rw-r--r-- | net/server_epoll.lua | 718 | ||||
-rw-r--r-- | net/server_event.lua | 77 | ||||
-rw-r--r-- | net/server_select.lua | 154 | ||||
-rw-r--r-- | net/websocket.lua | 43 | ||||
-rw-r--r-- | net/websocket/frames.lua | 8 |
15 files changed, 1359 insertions, 198 deletions
diff --git a/net/adns.lua b/net/adns.lua index a19cbd59..560e4b53 100644 --- a/net/adns.lua +++ b/net/adns.lua @@ -17,6 +17,7 @@ local setmetatable = setmetatable; local function dummy_send(sock, data, i, j) return (j-i)+1; end local _ENV = nil; +-- luacheck: std none local async_resolver_methods = {}; local async_resolver_mt = { __index = async_resolver_methods }; diff --git a/net/connect.lua b/net/connect.lua new file mode 100644 index 00000000..02ab4cc0 --- /dev/null +++ b/net/connect.lua @@ -0,0 +1,92 @@ +local server = require "net.server"; +local log = require "util.logger".init("net.connect"); +local new_id = require "util.id".short; + +local pending_connection_methods = {}; +local pending_connection_mt = { + __name = "pending_connection"; + __index = pending_connection_methods; + __tostring = function (p) + return "<pending connection "..p.id.." to "..tostring(p.target_resolver.hostname)..">"; + end; +}; + +function pending_connection_methods:log(level, message, ...) + log(level, "[pending connection %s] "..message, self.id, ...); +end + +-- pending_connections_map[conn] = pending_connection +local pending_connections_map = {}; + +local pending_connection_listeners = {}; + +local function attempt_connection(p) + p:log("debug", "Checking for targets..."); + if p.conn then + pending_connections_map[p.conn] = nil; + p.conn = nil; + end + p.target_resolver:next(function (conn_type, ip, port, extra) + if not conn_type then + -- No more targets to try + p:log("debug", "No more connection targets to try"); + if p.listeners.onfail then + p.listeners.onfail(p.data, p.last_error or "unable to resolve service"); + end + return; + end + p:log("debug", "Next target to try is %s:%d", ip, port); + local conn, err = server.addclient(ip, port, pending_connection_listeners, p.options.pattern or "*a", p.options.sslctx, conn_type, extra); + if not conn then + log("debug", "Connection attempt failed immediately: %s", tostring(err)); + p.last_error = err or "unknown reason"; + return attempt_connection(p); + end + p.conn = conn; + pending_connections_map[conn] = p; + end); +end + +function pending_connection_listeners.onconnect(conn) + local p = pending_connections_map[conn]; + if not p then + log("warn", "Successful connection, but unexpected! Closing."); + conn:close(); + return; + end + pending_connections_map[conn] = nil; + p:log("debug", "Successfully connected"); + if p.listeners.onattach then + p.listeners.onattach(conn, p.data); + end + conn:setlistener(p.listeners); + return p.listeners.onconnect(conn); +end + +function pending_connection_listeners.ondisconnect(conn, reason) + local p = pending_connections_map[conn]; + if not p then + log("warn", "Failed connection, but unexpected!"); + return; + end + p.last_error = reason or "unknown reason"; + p:log("debug", "Connection attempt failed: %s", p.last_error); + attempt_connection(p); +end + +local function connect(target_resolver, listeners, options, data) + local p = setmetatable({ + id = new_id(); + target_resolver = target_resolver; + listeners = assert(listeners); + options = options or {}; + data = data; + }, pending_connection_mt); + + p:log("debug", "Starting connection process"); + attempt_connection(p); +end + +return { + connect = connect; +}; diff --git a/net/connlisteners.lua b/net/connlisteners.lua index 000bfa63..9b8f88c3 100644 --- a/net/connlisteners.lua +++ b/net/connlisteners.lua @@ -3,15 +3,15 @@ local log = require "util.logger".init("net.connlisteners"); local traceback = debug.traceback; local _ENV = nil; +-- luacheck: std none local function fail() - log("error", "Attempt to use legacy connlisteners API. For more info see http://prosody.im/doc/developers/network"); + log("error", "Attempt to use legacy connlisteners API. For more info see https://prosody.im/doc/developers/network"); log("error", "Legacy connlisteners API usage, %s", traceback("", 2)); end return { register = fail; - register = fail; get = fail; start = fail; -- epic fail diff --git a/net/cqueues.lua b/net/cqueues.lua new file mode 100644 index 00000000..8c4c756f --- /dev/null +++ b/net/cqueues.lua @@ -0,0 +1,74 @@ +-- Prosody IM +-- Copyright (C) 2014 Daurnimator +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- +-- This module allows you to use cqueues with a net.server mainloop +-- + +local server = require "net.server"; +local cqueues = require "cqueues"; +assert(cqueues.VERSION >= 20150113, "cqueues newer than 20150113 required") + +-- Create a single top level cqueue +local cq; + +if server.cq then -- server provides cqueues object + cq = server.cq; +elseif server.get_backend() == "select" and server._addtimer then -- server_select + cq = cqueues.new(); + local function step() + assert(cq:loop(0)); + end + + -- Use wrapclient (as wrapconnection isn't exported) to get server_select to watch cq fd + local handler = server.wrapclient({ + getfd = function() return cq:pollfd(); end; + settimeout = function() end; -- Method just needs to exist + close = function() end; -- Need close method for 'closeall' + }, nil, nil, {}); + + -- Only need to listen for readable; cqueues handles everything under the hood + -- readbuffer is called when `select` notes an fd as readable + handler.readbuffer = step; + + -- Use server_select low lever timer facility, + -- this callback gets called *every* time there is a timeout in the main loop + server._addtimer(function(current_time) + -- This may end up in extra step()'s, but cqueues handles it for us. + step(); + return cq:timeout(); + end); +elseif server.event and server.base then -- server_event + cq = cqueues.new(); + -- Only need to listen for readable; cqueues handles everything under the hood + local EV_READ = server.event.EV_READ; + -- Convert a cqueues timeout to an acceptable timeout for luaevent + local function luaevent_safe_timeout(cq) + local t = cq:timeout(); + -- if you give luaevent 0 or nil, it re-uses the previous timeout. + if t == 0 then + t = 0.000001; -- 1 microsecond is the smallest that works (goes into a `struct timeval`) + elseif t == nil then -- pick something big if we don't have one + t = 0x7FFFFFFF; -- largest 32bit int + end + return t + end + local event_handle; + event_handle = server.base:addevent(cq:pollfd(), EV_READ, function(e) + -- Need to reference event_handle or this callback will get collected + -- This creates a circular reference that can only be broken if event_handle is manually :close()'d + local _ = event_handle; + -- Run as many cqueues things as possible (with a timeout of 0) + -- If an error is thrown, it will break the libevent loop; but prosody resumes after logging a top level error + assert(cq:loop(0)); + return EV_READ, luaevent_safe_timeout(cq); + end, luaevent_safe_timeout(cq)); +else + error "NYI" +end + +return { + cq = cq; +} diff --git a/net/dns.lua b/net/dns.lua index 0d2cce01..83d7ad1e 100644 --- a/net/dns.lua +++ b/net/dns.lua @@ -15,6 +15,7 @@ local socket = require "socket"; local timer = require "util.timer"; local new_ip = require "util.ip".new_ip; +local have_util_net, util_net = pcall(require, "util.net"); local _, windows = pcall(require, "util.windows"); local is_windows = (_ and windows) or os.getenv("WINDIR"); @@ -72,6 +73,7 @@ local default_timeout = 15; -------------------------------------------------- module dns local _ENV = nil; +-- luacheck: std none local dns = {}; @@ -383,6 +385,12 @@ function resolver:A(rr) -- - - - - - - - - - - - - - - - - - - - - - - - A rr.a = string.format('%i.%i.%i.%i', b1, b2, b3, b4); end +if have_util_net and util_net.ntop then + function resolver:A(rr) + rr.a = util_net.ntop(self:sub(4)); + end +end + function resolver:AAAA(rr) local addr = {}; for _ = 1, rr.rdlength, 2 do @@ -403,6 +411,12 @@ function resolver:AAAA(rr) rr.aaaa = addr:gsub(zeros[1], "::", 1):gsub("^0::", "::"):gsub("::0$", "::"); end +if have_util_net and util_net.ntop then + function resolver:AAAA(rr) + rr.aaaa = util_net.ntop(self:sub(16)); + end +end + function resolver:CNAME(rr) -- - - - - - - - - - - - - - - - - - - - CNAME rr.cname = self:name(); end diff --git a/net/http.lua b/net/http.lua index effb0ef5..6e5ad67c 100644 --- a/net/http.lua +++ b/net/http.lua @@ -13,9 +13,10 @@ local util_http = require "util.http"; local events = require "util.events"; local verify_identity = require"util.x509".verify_identity; -local ssl_available = pcall(require, "ssl"); +local basic_resolver = require "net.resolvers.basic"; +local connect = require "net.connect".connect; -local server = require "net.server" +local ssl_available = pcall(require, "ssl"); local t_insert, t_concat = table.insert, table.concat; local pairs = pairs; @@ -27,6 +28,7 @@ local setmetatable = setmetatable; local log = require "util.logger".init("http"); local _ENV = nil; +-- luacheck: std none local requests = {}; -- Open requests @@ -34,9 +36,78 @@ local function make_id(req) return (tostring(req):match("%x+$")); end local listener = { default_port = 80, default_mode = "*a" }; +-- Request-related helper functions +local function handleerr(err) log("error", "Traceback[http]: %s", traceback(tostring(err), 2)); return err; end +local function log_if_failed(req, ret, ...) + if not ret then + log("error", "Request '%s': error in callback: %s", req.id, tostring((...))); + if not req.suppress_errors then + error(...); + end + end + return ...; +end + +local function destroy_request(request) + local conn = request.conn; + if conn then + request.conn = nil; + conn:close() + end +end + +local function request_reader(request, data, err) + if not request.parser then + local function error_cb(reason) + if request.callback then + request.callback(reason or "connection-closed", 0, request); + request.callback = nil; + end + destroy_request(request); + end + + if not data then + error_cb(err); + return; + end + + local function success_cb(r) + if request.callback then + request.callback(r.body, r.code, r, request); + request.callback = nil; + end + destroy_request(request); + end + local function options_cb() + return request; + end + request.parser = httpstream_new(success_cb, error_cb, "client", options_cb); + end + request.parser:feed(data); +end + +-- Connection listener callbacks function listener.onconnect(conn) local req = requests[conn]; + -- Initialize request object + req.write = function (...) return req.conn:write(...); end + local callback = req.callback; + req.callback = function (content, code, response, request) + do + local event = { http = req.http, url = req.url, request = req, response = response, content = content, code = code, callback = req.callback }; + req.http.events.fire_event("response", event); + content, code, response = event.content, event.code, event.response; + end + + log("debug", "Request '%s': Calling callback, status %s", req.id, code or "---"); + return log_if_failed(req.id, xpcall(function () return callback(content, code, response, request) end, handleerr)); + end + req.reader = request_reader; + req.state = "status"; + + requests[req.conn] = req; + -- Validate certificate if not req.insecure and conn:ssl() then local sock = conn:socket(); @@ -96,58 +167,24 @@ function listener.ondisconnect(conn, err) requests[conn] = nil; end -function listener.ondetach(conn) - requests[conn] = nil; -end - -local function destroy_request(request) - if request.conn then - request.conn = nil; - request.handler:close() - end +function listener.onattach(conn, req) + requests[conn] = req; + req.conn = conn; end -local function request_reader(request, data, err) - if not request.parser then - local function error_cb(reason) - if request.callback then - request.callback(reason or "connection-closed", 0, request); - request.callback = nil; - end - destroy_request(request); - end - - if not data then - error_cb(err); - return; - end - - local function success_cb(r) - if request.callback then - request.callback(r.body, r.code, r, request); - request.callback = nil; - end - destroy_request(request); - end - local function options_cb() - return request; - end - request.parser = httpstream_new(success_cb, error_cb, "client", options_cb); - end - request.parser:feed(data); +function listener.ondetach(conn) + requests[conn] = nil; end -local function handleerr(err) log("error", "Traceback[http]: %s", traceback(tostring(err), 2)); end -local function log_if_failed(id, ret, ...) - if not ret then - log("error", "Request '%s': error in callback: %s", id, tostring((...))); - end - return ...; +function listener.onfail(req, reason) + req.http.events.fire_event("request-connection-error", { http = req.http, request = req, url = req.url, err = reason }); + req.callback(reason or "connection failed", 0, req); end local function request(self, u, ex, callback) local req = url.parse(u); req.url = u; + req.http = self; if not (req and req.host) then callback("invalid-url", 0, req); @@ -166,7 +203,7 @@ local function request(self, u, ex, callback) if ret then return ret; end - req, u, ex, callback = event.request, event.url, event.options, event.callback; + req, u, ex, req.callback = event.request, event.url, event.options, event.callback; end local method, headers, body; @@ -204,6 +241,7 @@ local function request(self, u, ex, callback) end end req.insecure = ex.insecure; + req.suppress_errors = ex.suppress_errors; end log("debug", "Making %s %s request '%s' to %s", req.scheme:upper(), method or "GET", req.id, (ex and ex.suppress_url and host_header) or u); @@ -222,29 +260,8 @@ local function request(self, u, ex, callback) sslctx = ex and ex.sslctx or self.options and self.options.sslctx; end - local handler, conn = server.addclient(host, port_number, listener, "*a", sslctx) - if not handler then - self.events.fire_event("request-connection-error", { http = self, request = req, url = u, err = conn }); - callback(conn, 0, req); - return nil, conn; - end - req.handler, req.conn = handler, conn - req.write = function (...) return req.handler:write(...); end - - req.callback = function (content, code, response, request) - do - local event = { http = self, url = u, request = req, response = response, content = content, code = code, callback = callback }; - self.events.fire_event("response", event); - content, code, response = event.content, event.code, event.response; - end - - log("debug", "Request '%s': Calling callback, status %s", req.id, code or "---"); - return log_if_failed(req.id, xpcall(function () return callback(content, code, response, request) end, handleerr)); - end - req.reader = request_reader; - req.state = "status"; - - requests[req.handler] = req; + local http_service = basic_resolver.new(host, port_number); + connect(http_service, listener, { sslctx = sslctx }, req); self.events.fire_event("request", { http = self, request = req, url = u }); return req; @@ -264,6 +281,7 @@ end local default_http = new({ sslctx = { mode = "client", protocol = "sslv23", options = { "no_sslv2", "no_sslv3" } }; + suppress_errors = true; }); return { diff --git a/net/httpserver.lua b/net/httpserver.lua index 6e2e31b9..6b14313b 100644 --- a/net/httpserver.lua +++ b/net/httpserver.lua @@ -3,9 +3,10 @@ local log = require "util.logger".init("net.httpserver"); local traceback = debug.traceback; local _ENV = nil; +-- luacheck: std none -function fail() - log("error", "Attempt to use legacy HTTP API. For more info see http://prosody.im/doc/developers/legacy_http"); +local function fail() + log("error", "Attempt to use legacy HTTP API. For more info see https://prosody.im/doc/developers/legacy_http"); log("error", "Legacy HTTP API usage, %s", traceback("", 2)); end diff --git a/net/resolvers/basic.lua b/net/resolvers/basic.lua new file mode 100644 index 00000000..c2fd9260 --- /dev/null +++ b/net/resolvers/basic.lua @@ -0,0 +1,71 @@ +local adns = require "net.adns"; +local inet_pton = require "util.net".pton; + +local methods = {}; +local resolver_mt = { __index = methods }; + +-- Find the next target to connect to, and +-- pass it to cb() +function methods:next(cb) + if self.targets then + if #self.targets == 0 then + cb(nil); + return; + end + local next_target = table.remove(self.targets, 1); + cb(unpack(next_target, 1, 4)); + return; + end + + local targets = {}; + local n = 2; + local function ready() + n = n - 1; + if n > 0 then return; end + self.targets = targets; + self:next(cb); + end + + local is_ip = inet_pton(self.hostname); + if is_ip then + if #is_ip == 16 then + cb(self.conn_type.."6", self.hostname, self.port, self.extra); + elseif #is_ip == 4 then + cb(self.conn_type, self.hostname, self.port, self.extra); + end + return; + end + + -- Resolve DNS to target list + local dns_resolver = adns.resolver(); + dns_resolver:lookup(function (answer) + if answer then + for _, record in ipairs(answer) do + table.insert(targets, { self.conn_type, record.a, self.port, self.extra }); + end + end + ready(); + end, self.hostname, "A", "IN"); + + dns_resolver:lookup(function (answer) + if answer then + for _, record in ipairs(answer) do + table.insert(targets, { self.conn_type.."6", record.aaaa, self.port, self.extra }); + end + end + ready(); + end, self.hostname, "AAAA", "IN"); +end + +local function new(hostname, port, conn_type, extra) + return setmetatable({ + hostname = hostname; + port = port; + conn_type = conn_type or "tcp"; + extra = extra; + }, resolver_mt); +end + +return { + new = new; +}; diff --git a/net/resolvers/manual.lua b/net/resolvers/manual.lua new file mode 100644 index 00000000..c0d4e5d5 --- /dev/null +++ b/net/resolvers/manual.lua @@ -0,0 +1,25 @@ +local methods = {}; +local resolver_mt = { __index = methods }; + +-- Find the next target to connect to, and +-- pass it to cb() +function methods:next(cb) + if #self.targets == 0 then + cb(nil); + return; + end + local next_target = table.remove(self.targets, 1); + cb(unpack(next_target, 1, 4)); +end + +local function new(targets, conn_type, extra) + return setmetatable({ + conn_type = conn_type; + extra = extra; + targets = targets or {}; + }, resolver_mt); +end + +return { + new = new; +}; diff --git a/net/server.lua b/net/server.lua index 41e180fa..d8f24847 100644 --- a/net/server.lua +++ b/net/server.lua @@ -6,25 +6,76 @@ -- COPYING file in the source package for more information. -- -local use_luaevent = prosody and require "core.configmanager".get("*", "use_libevent"); +local log = require "util.logger".init("net.server"); +local server_type = prosody and require "core.configmanager".get("*", "network_backend") or "select"; +if prosody and require "core.configmanager".get("*", "use_libevent") then + server_type = "event"; +end -if use_luaevent then - use_luaevent = pcall(require, "luaevent.core"); - if not use_luaevent then +if server_type == "event" then + if not pcall(require, "luaevent.core") then log("error", "libevent not found, falling back to select()"); + server_type = "select" end end local server; - -if use_luaevent then +local set_config; +if server_type == "event" then server = require "net.server_event"; - -- Overwrite signal.signal() because we need to ask libevent to - -- handle them instead - local ok, signal = pcall(require, "util.signal"); - if ok and signal then - local _signal_signal = signal.signal; + local defaults = {}; + for k,v in pairs(server.cfg) do + defaults[k] = v; + end + function set_config(settings) + local event_settings = { + ACCEPT_DELAY = settings.accept_retry_interval; + ACCEPT_QUEUE = settings.tcp_backlog; + CLEAR_DELAY = settings.event_clear_interval; + CONNECT_TIMEOUT = settings.connect_timeout; + DEBUG = settings.debug; + HANDSHAKE_TIMEOUT = settings.ssl_handshake_timeout; + MAX_CONNECTIONS = settings.max_connections; + MAX_HANDSHAKE_ATTEMPTS = settings.max_ssl_handshake_roundtrips; + MAX_READ_LENGTH = settings.max_receive_buffer_size; + MAX_SEND_LENGTH = settings.max_send_buffer_size; + READ_TIMEOUT = settings.read_timeout; + WRITE_TIMEOUT = settings.send_timeout; + }; + + for k,default in pairs(defaults) do + server.cfg[k] = event_settings[k] or default; + end + end +elseif server_type == "select" then + server = require "net.server_select"; + + local defaults = {}; + for k,v in pairs(server.getsettings()) do + defaults[k] = v; + end + function set_config(settings) + local select_settings = {}; + for k,default in pairs(defaults) do + select_settings[k] = settings[k] or default; + end + server.changesettings(select_settings); + end +else + server = require("net.server_"..server_type); + set_config = server.set_config; + if not server.get_backend then + function server.get_backend() + return server_type; + end + end +end + +-- If server.hook_signal exists, replace signal.signal() +local has_signal, signal = pcall(require, "util.signal"); +if has_signal then + if server.hook_signal then function signal.signal(signal_id, handler) if type(signal_id) == "string" then signal_id = signal[signal_id:upper()]; @@ -34,46 +85,22 @@ if use_luaevent then end return server.hook_signal(signal_id, handler); end + else + server.hook_signal = signal.signal; end else - use_luaevent = false; - server = require "net.server_select"; + if not server.hook_signal then + server.hook_signal = function() + return false, "signal hooking not supported" + end + end end -if prosody then +if prosody and set_config then local config_get = require "core.configmanager".get; - local defaults = {}; - for k,v in pairs(server.cfg or server.getsettings()) do - defaults[k] = v; - end local function load_config() local settings = config_get("*", "network_settings") or {}; - if use_luaevent then - local event_settings = { - ACCEPT_DELAY = settings.accept_retry_interval; - ACCEPT_QUEUE = settings.tcp_backlog; - CLEAR_DELAY = settings.event_clear_interval; - CONNECT_TIMEOUT = settings.connect_timeout; - DEBUG = settings.debug; - HANDSHAKE_TIMEOUT = settings.ssl_handshake_timeout; - MAX_CONNECTIONS = settings.max_connections; - MAX_HANDSHAKE_ATTEMPTS = settings.max_ssl_handshake_roundtrips; - MAX_READ_LENGTH = settings.max_receive_buffer_size; - MAX_SEND_LENGTH = settings.max_send_buffer_size; - READ_TIMEOUT = settings.read_timeout; - WRITE_TIMEOUT = settings.send_timeout; - }; - - for k,default in pairs(defaults) do - server.cfg[k] = event_settings[k] or default; - end - else - local select_settings = {}; - for k,default in pairs(defaults) do - select_settings[k] = settings[k] or default; - end - server.changesettings(select_settings); - end + return set_config(settings); end load_config(); prosody.events.add_handler("config-reloaded", load_config); diff --git a/net/server_epoll.lua b/net/server_epoll.lua new file mode 100644 index 00000000..0881f797 --- /dev/null +++ b/net/server_epoll.lua @@ -0,0 +1,718 @@ +-- Prosody IM +-- Copyright (C) 2016 Kim Alvefur +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +-- server_epoll +-- Server backend based on https://luarocks.org/modules/zash/lua-epoll + +local t_sort = table.sort; +local t_insert = table.insert; +local t_remove = table.remove; +local t_concat = table.concat; +local setmetatable = setmetatable; +local tostring = tostring; +local pcall = pcall; +local type = type; +local next = next; +local pairs = pairs; +local log = require "util.logger".init("server_epoll"); +local epoll = require "epoll"; +local socket = require "socket"; +local luasec = require "ssl"; +local gettime = require "util.time".now; +local createtable = require "util.table".create; +local _SOCKETINVALID = socket._SOCKETINVALID or -1; + +assert(socket.tcp6 and socket.tcp4, "Incompatible LuaSocket version"); + +local _ENV = nil; +-- luacheck: std none + +local default_config = { __index = { + read_timeout = 900; + write_timeout = 7; + tcp_backlog = 128; + accept_retry_interval = 10; + read_retry_delay = 1e-06; + connect_timeout = 20; + handshake_timeout = 60; + max_wait = 86400; + min_wait = 1e-06; +}}; +local cfg = default_config.__index; + +local fds = createtable(10, 0); -- FD -> conn + +-- Timer and scheduling -- + +local timers = {}; + +local function noop() end +local function closetimer(t) + t[1] = 0; + t[2] = noop; +end + +-- Set to true when timers have changed +local resort_timers = false; + +-- Add absolute timer +local function at(time, f) + local timer = { time, f, close = closetimer }; + t_insert(timers, timer); + resort_timers = true; + return timer; +end + +-- Add relative timer +local function addtimer(timeout, f) + return at(gettime() + timeout, f); +end + +-- Run callbacks of expired timers +-- Return time until next timeout +local function runtimers(next_delay, min_wait) + -- Any timers at all? + if not timers[1] then + return next_delay; + end + + if resort_timers then + -- Sort earliest timers to the end + t_sort(timers, function (a, b) return a[1] > b[1]; end); + resort_timers = false; + end + + -- Iterate from the end and remove completed timers + for i = #timers, 1, -1 do + local timer = timers[i]; + local t, f = timer[1], timer[2]; + -- Get time for every iteration to increase accuracy + local now = gettime(); + if t > now then + -- This timer should not fire yet + local diff = t - now; + if diff < next_delay then + next_delay = diff; + end + break; + end + local new_timeout = f(now); + if new_timeout then + -- Schedule for 'delay' from the time actually scheduled, + -- not from now, in order to prevent timer drift. + timer[1] = t + new_timeout; + resort_timers = true; + else + t_remove(timers, i); + end + end + + if resort_timers or next_delay < min_wait then + -- Timers may be added from within a timer callback. + -- Those would not be considered for next_delay, + -- and we might sleep for too long, so instead + -- we return a shorter timeout so we can + -- properly sort all new timers. + next_delay = min_wait; + end + + return next_delay; +end + +-- Socket handler interface + +local interface = {}; +local interface_mt = { __index = interface }; + +function interface_mt:__tostring() + if self.sockname and self.peername then + return ("FD %d (%s, %d, %s, %d)"):format(self:getfd(), self.peername, self.peerport, self.sockname, self.sockport); + elseif self.sockname or self.peername then + return ("FD %d (%s, %d)"):format(self:getfd(), self.sockname or self.peername, self.sockport or self.peerport); + end + return ("%s FD %d"):format(tostring(self.conn), self:getfd()); +end + +-- Replace the listener and tell the old one +function interface:setlistener(listeners) + self:on("detach"); + self.listeners = listeners; +end + +-- Call a listener callback +function interface:on(what, ...) + if not self.listeners then + log("error", "%s has no listeners", self); + return; + end + local listener = self.listeners["on"..what]; + if not listener then + -- log("debug", "Missing listener 'on%s'", what); -- uncomment for development and debugging + return; + end + local ok, err = pcall(listener, self, ...); + if not ok then + log("error", "Error calling on%s: %s", what, err); + end + return err; +end + +-- Return the file descriptor number +function interface:getfd() + if self.conn then + return self.conn:getfd(); + end + return _SOCKETINVALID; +end + +function interface:server() + return self._server or self; +end + +-- Get IP address +function interface:ip() + return self.peername or self.sockname; +end + +-- Get a port number, doesn't matter which +function interface:port() + return self.sockport or self.peerport; +end + +-- Get local port number +function interface:clientport() + return self.sockport; +end + +-- Get remote port +function interface:serverport() + if self.sockport then + return self.sockport; + elseif self._server then + self._server:port(); + end +end + +-- Return underlying socket +function interface:socket() + return self.conn; +end + +function interface:set_mode(new_mode) + self._pattern = new_mode; +end + +function interface:setoption(k, v) + -- LuaSec doesn't expose setoption :( + if self.conn.setoption then + self.conn:setoption(k, v); + end +end + +-- Timeout for detecting dead or idle sockets +function interface:setreadtimeout(t) + if t == false then + if self._readtimeout then + self._readtimeout:close(); + self._readtimeout = nil; + end + return + end + t = t or cfg.read_timeout; + if self._readtimeout then + self._readtimeout[1] = gettime() + t; + resort_timers = true; + else + self._readtimeout = addtimer(t, function () + if self:on("readtimeout") then + return cfg.read_timeout; + else + self:on("disconnect", "read timeout"); + self:destroy(); + end + end); + end +end + +-- Timeout for detecting dead sockets +function interface:setwritetimeout(t) + if t == false then + if self._writetimeout then + self._writetimeout:close(); + self._writetimeout = nil; + end + return + end + t = t or cfg.write_timeout; + if self._writetimeout then + self._writetimeout[1] = gettime() + t; + resort_timers = true; + else + self._writetimeout = addtimer(t, function () + self:on("disconnect", "write timeout"); + self:destroy(); + end); + end +end + +-- lua-epoll flag for currently requested poll state +function interface:flags() + if self._wantread then + if self._wantwrite then + return "rw"; + end + return "r"; + elseif self._wantwrite then + return "w"; + end +end + +-- Add or remove sockets or modify epoll flags +function interface:setflags(r, w) + if r ~= nil then self._wantread = r; end + if w ~= nil then self._wantwrite = w; end + local flags = self:flags(); + local currentflags = self._flags; + if flags == currentflags then + return true; + end + local fd = self:getfd(); + if fd < 0 then + self._wantread, self._wantwrite = nil, nil; + return nil, "invalid fd"; + end + local op = "mod"; + if not flags then + op = "del"; + elseif not currentflags then + op = "add"; + end + local ok, err = epoll.ctl(op, fd, flags); +-- log("debug", "epoll_ctl(%q, %d, %q) -> %s" .. (err and ", %q" or ""), +-- op, fd, flags or "", tostring(ok), err); + if not ok then return ok, err end + if op == "add" then + fds[fd] = self; + elseif op == "del" then + fds[fd] = nil; + end + self._flags = flags; + return true; +end + +-- Called when socket is readable +function interface:onreadable() + local data, err, partial = self.conn:receive(self._pattern); + if data then + self:onconnect(); + self:on("incoming", data); + else + if partial and partial ~= "" then + self:onconnect(); + self:on("incoming", partial, err); + end + if err == "wantread" then + self:setflags(true, nil); + elseif err == "wantwrite" then + self:setflags(nil, true); + elseif err ~= "timeout" then + self:on("disconnect", err); + self:destroy() + return; + end + end + if not self.conn then return; end + if self.conn:dirty() then + self:setreadtimeout(false); + self:pausefor(cfg.read_retry_delay); + else + self:setreadtimeout(); + end +end + +-- Called when socket is writable +function interface:onwritable() + self:onconnect(); + if not self.conn then return; end -- could have been closed in onconnect + local buffer = self.writebuffer; + local data = t_concat(buffer); + local ok, err, partial = self.conn:send(data); + if ok then + self:setflags(nil, false); + for i = #buffer, 1, -1 do + buffer[i] = nil; + end + self:setwritetimeout(false); + self:ondrain(); -- Be aware of writes in ondrain + return; + elseif partial then + buffer[1] = data:sub(partial+1); + for i = #buffer, 2, -1 do + buffer[i] = nil; + end + self:setwritetimeout(); + end + if err == "wantwrite" or err == "timeout" then + self:setflags(nil, true); + elseif err == "wantread" then + self:setflags(true, nil); + elseif err ~= "timeout" then + self:on("disconnect", err); + self:destroy(); + end +end + +-- The write buffer has been successfully emptied +function interface:ondrain() + return self:on("drain"); +end + +-- Add data to write buffer and set flag for wanting to write +function interface:write(data) + local buffer = self.writebuffer; + if buffer then + t_insert(buffer, data); + else + self.writebuffer = { data }; + end + self:setwritetimeout(); + self:setflags(nil, true); + return #data; +end +interface.send = interface.write; + +-- Close, possibly after writing is done +function interface:close() + if self.writebuffer and self.writebuffer[1] then + self:setflags(false, true); -- Flush final buffer contents + self.write, self.send = noop, noop; -- No more writing + log("debug", "Close %s after writing", tostring(self)); + self.ondrain = interface.close; + else + log("debug", "Close %s now", tostring(self)); + self.write, self.send = noop, noop; + self.close = noop; + self:on("disconnect"); + self:destroy(); + end +end + +function interface:destroy() + self:setflags(false, false); + self:setwritetimeout(false); + self:setreadtimeout(false); + self.onreadable = noop; + self.onwritable = noop; + self.destroy = noop; + self.close = noop; + self.on = noop; + self.conn:close(); + self.conn = nil; +end + +function interface:ssl() + return self._tls; +end + +function interface:starttls(ctx) + if ctx then self.tls = ctx; end + if self.writebuffer and self.writebuffer[1] then + log("debug", "Start TLS on %s after write", tostring(self)); + self.ondrain = interface.starttls; + self.starttls = false; + self:setflags(nil, true); -- make sure wantwrite is set + else + log("debug", "Start TLS on %s now", tostring(self)); + self:setflags(false, false); + local conn, err = luasec.wrap(self.conn, ctx or self.tls); + if not conn then + self:on("disconnect", err); + self:destroy(); + return conn, err; + end + conn:settimeout(0); + self.conn = conn; + self.ondrain = nil; + self.onwritable = interface.tlshandskake; + self.onreadable = interface.tlshandskake; + self:setflags(true, true); + self:setwritetimeout(cfg.handshake_timeout); + end +end + +function interface:tlshandskake() + self:setwritetimeout(false); + self:setreadtimeout(false); + local ok, err = self.conn:dohandshake(); + if ok then + log("debug", "TLS handshake on %s complete", tostring(self)); + self.onwritable = nil; + self.onreadable = nil; + self._tls = true; + self:on("status", "ssl-handshake-complete"); + self:init(); + elseif err == "wantread" then + log("debug", "TLS handshake on %s to wait until readable", tostring(self)); + self:setflags(true, false); + self:setreadtimeout(cfg.handshake_timeout); + elseif err == "wantwrite" then + log("debug", "TLS handshake on %s to wait until writable", tostring(self)); + self:setflags(false, true); + self:setwritetimeout(cfg.handshake_timeout); + else + log("debug", "TLS handshake error on %s: %s", tostring(self), err); + self:on("disconnect", err); + self:destroy(); + end +end + +local function wrapsocket(client, server, pattern, listeners, tls) -- luasocket object -> interface object + client:settimeout(0); + local conn = setmetatable({ + conn = client; + _server = server; + created = gettime(); + listeners = listeners; + _pattern = pattern or (server and server._pattern); + writebuffer = {}; + tls = tls; + }, interface_mt); + + if client.getpeername then + conn.peername, conn.peerport = client:getpeername(); + end + if client.getsockname then + conn.sockname, conn.sockport = client:getsockname(); + end + return conn; +end + +-- A server interface has new incoming connections waiting +-- This replaces the onreadable callback +function interface:onacceptable() + local conn, err = self.conn:accept(); + if not conn then + log("debug", "Error accepting new client: %s, server will be paused for %ds", err, cfg.accept_retry_interval); + self:pausefor(cfg.accept_retry_interval); + return; + end + local client = wrapsocket(conn, self, nil, self.listeners, self.tls); + log("debug", "New connection %s", tostring(client)); + client:init(); +end + +-- Initialization +function interface:init() + if self.tls and not self._tls then + return self:starttls(); + else + self:setwritetimeout(); + return self:setflags(true, true); + end +end + +function interface:pause() + return self:setflags(false); +end + +function interface:resume() + return self:setflags(true); +end + +-- Pause connection for some time +function interface:pausefor(t) + if self._pausefor then + self._pausefor:close(); + end + if t == false then return; end + self:setflags(false); + self._pausefor = addtimer(t, function () + self._pausefor = nil; + if self.conn and self.conn:dirty() then + self:onreadable(); + end + self:setflags(true); + end); +end + +-- Connected! +function interface:onconnect() + if self.conn and not self.peername and self.conn.getpeername then + self.peername, self.peerport = self.conn:getpeername(); + end + self.onconnect = noop; + self:on("connect"); +end + +local function addserver(addr, port, listeners, pattern, tls) + local conn, err = socket.bind(addr, port, cfg.tcp_backlog); + if not conn then return conn, err; end + conn:settimeout(0); + local server = setmetatable({ + conn = conn; + created = gettime(); + listeners = listeners; + _pattern = pattern; + onreadable = interface.onacceptable; + tls = tls; + sockname = addr; + sockport = port; + }, interface_mt); + server:setflags(true, false); + return server; +end + +-- COMPAT +local function wrapclient(conn, addr, port, listeners, pattern, tls) + local client = wrapsocket(conn, nil, pattern, listeners, tls); + if not client.peername then + client.peername, client.peerport = addr, port; + end + client:init(); + return client; +end + +-- New outgoing TCP connection +local function addclient(addr, port, listeners, pattern, tls) + local conn, err = socket.tcp(); + if not conn then return conn, err; end + conn:settimeout(0); + conn:connect(addr, port); + local client = wrapsocket(conn, nil, pattern, listeners, tls) + client:init(); + return client, conn; +end + +local function watchfd(fd, onreadable, onwriteable) + local conn = setmetatable({ + conn = fd; + onreadable = onreadable; + onwriteable = onwriteable; + close = function (self) + self:setflags(false, false); + end + }, interface_mt); + if type(fd) == "number" then + conn.getfd = function () + return fd; + end; + -- Otherwise it'll need to be something LuaSocket-compatible + end + conn:setflags(onreadable, onwriteable); + return conn; +end; + +-- Dump all data from one connection into another +local function link(from, to) + from.listeners = setmetatable({ + onincoming = function (_, data) + from:pause(); + to:write(data); + end, + }, {__index=from.listeners}); + to.listeners = setmetatable({ + ondrain = function () + from:resume(); + end, + }, {__index=to.listeners}); + from:setflags(true, nil); + to:setflags(nil, true); +end + +-- XXX What uses this? +-- net.adns +function interface:set_send(new_send) + self.send = new_send; +end + +-- Close all connections and servers +local function closeall() + for fd, conn in pairs(fds) do -- luacheck: ignore 213/fd + conn:close(); + end +end + +local quitting = nil; + +-- Signal main loop about shutdown via above upvalue +local function setquitting(quit) + if quit then + quitting = "quitting"; + closeall(); + else + quitting = nil; + end +end + +-- Main loop +local function loop(once) + repeat + local t = runtimers(cfg.max_wait, cfg.min_wait); + local fd, r, w = epoll.wait(t); + if fd then + local conn = fds[fd]; + if conn then + if r then + conn:onreadable(); + end + if w then + conn:onwritable(); + end + else + log("debug", "Removing unknown fd %d", fd); + epoll.ctl("del", fd); + end + elseif r ~= "timeout" then + log("debug", "epoll_wait error: %s", tostring(r)); + end + until once or (quitting and next(fds) == nil); + return quitting; +end + +return { + get_backend = function () return "epoll"; end; + addserver = addserver; + addclient = addclient; + add_task = addtimer; + at = at; + loop = loop; + closeall = closeall; + setquitting = setquitting; + wrapclient = wrapclient; + watchfd = watchfd; + link = link; + set_config = function (newconfig) + cfg = setmetatable(newconfig, default_config); + end; + + -- libevent emulation + event = { EV_READ = "r", EV_WRITE = "w", EV_READWRITE = "rw", EV_LEAVE = -1 }; + addevent = function (fd, mode, callback) + local function onevent(self) + local ret = self:callback(); + if ret == -1 then + self:setflags(false, false); + elseif ret then + self:setflags(mode == "r" or mode == "rw", mode == "w" or mode == "rw"); + end + end + + local conn = setmetatable({ + getfd = function () return fd; end; + callback = callback; + onreadable = onevent; + onwritable = onevent; + close = function (self) + self:setflags(false, false); + fds[fd] = nil; + end; + }, interface_mt); + local ok, err = conn:setflags(mode == "r" or mode == "rw", mode == "w" or mode == "rw"); + if not ok then return ok, err; end + return conn; + end; +}; diff --git a/net/server_event.lua b/net/server_event.lua index 3a907349..3e949092 100644 --- a/net/server_event.lua +++ b/net/server_event.lua @@ -5,9 +5,9 @@ notes: -- when using luaevent, never register 2 or more EV_READ at one socket, same for EV_WRITE - -- you cant even register a new EV_READ/EV_WRITE callback inside another one + -- you can't even register a new EV_READ/EV_WRITE callback inside another one -- to do some of the above, use timeout events or something what will called from outside - -- dont let garbagecollect eventcallbacks, as long they are running + -- don't let garbagecollect eventcallbacks, as long they are running -- when using luasec, there are 4 cases of timeout errors: wantread or wantwrite during reading or writing --]] @@ -106,6 +106,12 @@ function interface_mt:_start_connection(plainssl) -- called from wrapclient self:_close() debug( "new connection failed. id:", self.id, "error:", self.fatalerror ) else + if EV_READWRITE == event then + if self.readcallback(event) == -1 then + -- Fatal error occurred + return -1; + end + end if plainssl and has_luasec then -- start ssl session self:starttls(self._sslctx, true) else -- normal connection @@ -116,7 +122,7 @@ function interface_mt:_start_connection(plainssl) -- called from wrapclient self.eventconnect = nil return -1 end - self.eventconnect = addevent( base, self.conn, EV_WRITE, callback, cfg.CONNECT_TIMEOUT ) + self.eventconnect = addevent( base, self.conn, EV_READWRITE, callback, cfg.CONNECT_TIMEOUT ) return true end function interface_mt:_start_session(call_onconnect) -- new session, for example after startssl @@ -151,7 +157,7 @@ function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed self.fatalerror = err self.conn = nil -- cannot be used anymore if call_onconnect then - self.ondisconnect = nil -- dont call this when client isnt really connected + self.ondisconnect = nil -- don't call this when client isn't really connected end self:_close() debug( "fatal error while ssl wrapping:", err ) @@ -194,7 +200,7 @@ function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed end if self.fatalerror then if call_onconnect then - self.ondisconnect = nil -- dont call this when client isnt really connected + self.ondisconnect = nil -- don't call this when client isn't really connected end self:_close() debug( "handshake failed because:", self.fatalerror ) @@ -223,7 +229,8 @@ function interface_mt:_destroy() -- close this interface + events and call last _ = self.eventsession and self.eventsession:close( ) _ = self.eventwritetimeout and self.eventwritetimeout:close( ) _ = self.eventreadtimeout and self.eventreadtimeout:close( ) - _ = self.ondisconnect and self:ondisconnect( self.fatalerror ~= "client to close" and self.fatalerror) -- call ondisconnect listener (wont be the case if handshake failed on connect) + -- call ondisconnect listener (won't be the case if handshake failed on connect) + _ = self.ondisconnect and self:ondisconnect( self.fatalerror ~= "client to close" and self.fatalerror) _ = self.conn and self.conn:close( ) -- close connection _ = self._server and self._server:counter(-1); self.eventread, self.eventwrite = nil, nil @@ -510,7 +517,7 @@ local function handleclient( client, ip, port, server, pattern, listener, sslctx interface.writebuffer = { t_concat(interface.writebuffer) } local succ, err, byte = interface.conn:send( interface.writebuffer[1], 1, interface.writebufferlen ) --vdebug( "write data:", interface.writebuffer, "error:", err, "part:", byte ) - if succ then -- writing succesful + if succ then -- writing successful interface.writebuffer[1] = nil interface.writebufferlen = 0 interface:ondrain(); @@ -539,7 +546,7 @@ local function handleclient( client, ip, port, server, pattern, listener, sslctx return -1; end interface.eventwritetimeout = addevent( base, nil, EV_TIMEOUT, callback, cfg.WRITE_TIMEOUT ) -- reg a new timeout event - debug( "wantread during write attempt, reg it in readcallback but dont know what really happens next..." ) + debug( "wantread during write attempt, reg it in readcallback but don't know what really happens next..." ) -- hopefully this works with luasec; its simply not possible to use 2 different write events on a socket in luaevent return -1 end @@ -595,8 +602,8 @@ local function handleclient( client, ip, port, server, pattern, listener, sslctx end interface.eventreadtimeout = addevent( base, nil, EV_TIMEOUT, function( ) interface:_close() end, cfg.READ_TIMEOUT) - debug( "wantwrite during read attempt, reg it in writecallback but dont know what really happens next..." ) - -- to be honest i dont know what happens next, if it is allowed to first read, the write etc... + debug( "wantwrite during read attempt, reg it in writecallback but don't know what really happens next..." ) + -- to be honest i don't know what happens next, if it is allowed to first read, the write etc... else -- connection was closed or fatal error interface.fatalerror = err debug( "connection failed in read event:", interface.fatalerror ) @@ -767,13 +774,15 @@ end local function setquitting(yes) if yes then -- Quit now - closeallservers(); + if yes ~= "once" then + closeallservers(); + end base:loopexit(); end end local function get_backend() - return base:method(); + return "libevent " .. base:method(); end -- We need to hold onto the events to stop them @@ -811,6 +820,48 @@ local function link(sender, receiver, buffersize) sender:set_mode("*a"); end +local function add_task(delay, callback) + local event_handle; + event_handle = base:addevent(nil, 0, function () + local ret = callback(socket_gettime()); + if ret then + return 0, ret; + elseif event_handle then + return -1; + end + end + , delay); + return event_handle; +end + +local function watchfd(fd, onreadable, onwriteable) + local handle = {}; + function handle:setflags(r,w) + if r ~= nil then + if r and not self.wantread then + self.wantread = base:addevent(fd, EV_READ, function () + onreadable(self); + end); + elseif not r and self.wantread then + self.wantread:close(); + self.wantread = nil; + end + end + if w ~= nil then + if w and not self.wantwrite then + self.wantwrite = base:addevent(fd, EV_WRITE, function () + onwriteable(self); + end); + elseif not r and self.wantread then + self.wantwrite:close(); + self.wantwrite = nil; + end + end + end + handle:setflags(onreadable, onwriteable); + return handle; +end + return { cfg = cfg, base = base, @@ -826,6 +877,8 @@ return { closeall = closeallservers, get_backend = get_backend, hook_signal = hook_signal, + add_task = add_task, + watchfd = watchfd, __NAME = SCRIPT_NAME, __DATE = LAST_MODIFIED, diff --git a/net/server_select.lua b/net/server_select.lua index 12aef9d8..3b83bb6d 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -40,6 +40,7 @@ local coroutine = use "coroutine" local math_min = math.min local math_huge = math.huge local table_concat = table.concat +local table_insert = table.insert local string_sub = string.sub local coroutine_wrap = coroutine.wrap local coroutine_yield = coroutine.yield @@ -55,7 +56,6 @@ local getaddrinfo = luasocket.dns.getaddrinfo local ssl_wrap = ( has_luasec and luasec.wrap ) local socket_bind = luasocket.bind -local socket_sleep = luasocket.sleep local socket_select = luasocket.select --// functions //-- @@ -100,7 +100,6 @@ local _sendtraffic local _readtraffic local _selecttimeout -local _sleeptime local _tcpbacklog local _accepretry @@ -114,8 +113,6 @@ local _checkinterval local _sendtimeout local _readtimeout -local _timer - local _maxselectlen local _maxfd @@ -135,13 +132,12 @@ _fullservers = { } -- servers in a paused state while there are too many clients _readlistlen = 0 -- length of readlist _sendlistlen = 0 -- length of sendlist -_timerlistlen = 0 -- lenght of timerlist +_timerlistlen = 0 -- length of timerlist _sendtraffic = 0 -- some stats _readtraffic = 0 _selecttimeout = 1 -- timeout of socket.select -_sleeptime = 0 -- time to wait at the end of every loop _tcpbacklog = 128 -- some kind of hint to the OS _accepretry = 10 -- seconds to wait until the next attempt of a full server to accept @@ -301,7 +297,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport local bufferqueuelen = 0 -- end of buffer array local toclose - local fatalerror local needtls local bufferlen = 0 @@ -425,7 +420,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport bufferlen = bufferlen + #data if bufferlen > maxsendlen then _closelist[ handler ] = "send buffer exceeded" -- cannot close the client at the moment, have to wait to the end of the cycle - handler.write = idfalse -- dont write anymore + handler.write = idfalse -- don't write anymore return false elseif socket and not _sendlist[ socket ] then _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) @@ -517,7 +512,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport return dispatch( handler, buffer, err ) else -- connections was closed or fatal error out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) ) - fatalerror = true _ = handler and handler:force_close( err ) return false end @@ -537,7 +531,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport else succ, err, count = false, "unexpected close", 0; end - if succ then -- sending succesful + if succ then -- sending successful bufferqueuelen = 0 bufferlen = 0 _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) -- delete socket from writelist @@ -557,7 +551,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport return true else -- connection was closed during sending or fatal error out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) ) - fatalerror = true _ = handler and handler:force_close( err ) return false end @@ -806,7 +799,6 @@ end getsettings = function( ) return { select_timeout = _selecttimeout; - select_sleep_time = _sleeptime; tcp_backlog = _tcpbacklog; max_send_buffer_size = _maxsendlen; max_receive_buffer_size = _maxreadlen; @@ -825,7 +817,6 @@ changesettings = function( new ) return nil, "invalid settings table" end _selecttimeout = tonumber( new.select_timeout ) or _selecttimeout - _sleeptime = tonumber( new.select_sleep_time ) or _sleeptime _maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen _maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen _checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval @@ -848,6 +839,49 @@ addtimer = function( listener ) return true end +local add_task do + local data = {}; + local new_data = {}; + + function add_task(delay, callback) + local current_time = luasocket_gettime(); + delay = delay + current_time; + if delay >= current_time then + table_insert(new_data, {delay, callback}); + else + local r = callback(current_time); + if r and type(r) == "number" then + return add_task(r, callback); + end + end + end + + addtimer(function(current_time) + if #new_data > 0 then + for _, d in pairs(new_data) do + table_insert(data, d); + end + new_data = {}; + end + + local next_time = math_huge; + for i, d in pairs(data) do + local t, callback = d[1], d[2]; + if t <= current_time then + data[i] = nil; + local r = callback(current_time); + if type(r) == "number" then + add_task(r, callback); + next_time = math_min(next_time, r); + end + else + next_time = math_min(next_time, t - current_time); + end + end + return next_time; + end); +end + stats = function( ) return _readtraffic, _sendtraffic, _readlistlen, _sendlistlen, _timerlistlen end @@ -855,31 +889,38 @@ end local quitting; local function setquitting(quit) - quitting = not not quit; + quitting = quit; end loop = function(once) -- this is the main loop of the program if quitting then return "quitting"; end if once then quitting = "once"; end - local next_timer_time = math_huge; + _currenttime = luasocket_gettime( ) repeat + -- Fire timers + local next_timer_time = math_huge; + for i = 1, _timerlistlen do + local t = _timerlist[ i ]( _currenttime ) -- fire timers + if t then next_timer_time = math_min(next_timer_time, t); end + end + local read, write, err = socket_select( _readlist, _sendlist, math_min(_selecttimeout, next_timer_time) ) - for _, socket in ipairs( write ) do -- send data waiting in writequeues + for _, socket in ipairs( read ) do -- receive data local handler = _socketlist[ socket ] if handler then - handler.sendbuffer( ) + handler.readbuffer( ) else closesocket( socket ) - out_put "server.lua: found no handler and closed socket (writelist)" -- this should not happen + out_put "server.lua: found no handler and closed socket (readlist)" -- this can happen end end - for _, socket in ipairs( read ) do -- receive data + for _, socket in ipairs( write ) do -- send data waiting in writequeues local handler = _socketlist[ socket ] if handler then - handler.readbuffer( ) + handler.sendbuffer( ) else closesocket( socket ) - out_put "server.lua: found no handler and closed socket (readlist)" -- this can happen + out_put "server.lua: found no handler and closed socket (writelist)" -- this should not happen end end for handler, err in pairs( _closelist ) do @@ -910,29 +951,14 @@ loop = function(once) -- this is the main loop of the program end end - -- Fire timers - if _currenttime - _timer >= math_min(next_timer_time, 1) then - next_timer_time = math_huge; - for i = 1, _timerlistlen do - local t = _timerlist[ i ]( _currenttime ) -- fire timers - if t then next_timer_time = math_min(next_timer_time, t); end - end - _timer = _currenttime - else - next_timer_time = next_timer_time - (_currenttime - _timer); - end - for server, paused_time in pairs( _fullservers ) do if _currenttime - paused_time > _accepretry then _fullservers[ server ] = nil; server.resume(); end end - - -- wait some time (0 by default) - socket_sleep( _sleeptime ) until quitting; - if once and quitting == "once" then quitting = nil; return; end + if quitting == "once" then quitting = nil; return; end closeall(); return "quitting" end @@ -952,6 +978,7 @@ local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx if not handler then return nil, err end _socketlist[ socket ] = handler if not sslctx then + _readlistlen = addsocket(_readlist, socket, _readlistlen) _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) if listeners.onconnect then -- When socket is writeable, call onconnect @@ -977,16 +1004,14 @@ local addclient = function( address, port, listeners, pattern, sslctx, typ ) elseif sslctx and not has_luasec then err = "luasec not found" end - if not typ then + if getaddrinfo and not typ then local addrinfo, err = getaddrinfo(address) if not addrinfo then return nil, err end if addrinfo[1] and addrinfo[1].family == "inet6" then typ = "tcp6" - else - typ = "tcp" end end - local create = luasocket[typ] + local create = luasocket[typ or "tcp"] if type( create ) ~= "function" then err = "invalid socket type" end @@ -1002,14 +1027,54 @@ local addclient = function( address, port, listeners, pattern, sslctx, typ ) end client:settimeout( 0 ) local ok, err = client:connect( address, port ) - if ok or err == "timeout" then + if ok or err == "timeout" or err == "Operation already in progress" then return wrapclient( client, address, port, listeners, pattern, sslctx ) else return nil, err end end ---// EXPERIMENTAL //-- +local closewatcher = function (handler) + local socket = handler.conn; + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) + _readlistlen = removesocket( _readlist, socket, _readlistlen ) + _socketlist[ socket ] = nil +end; + +local addremove = function (handler, read, send) + local socket = handler.conn + _socketlist[ socket ] = handler + if read ~= nil then + if read then + _readlistlen = addsocket( _readlist, socket, _readlistlen ) + else + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) + end + end + if send ~= nil then + if send then + _sendlistlen = addsocket( _sendlist, socket, _sendlistlen ) + else + _readlistlen = removesocket( _readlist, socket, _readlistlen ) + end + end +end + +local watchfd = function ( fd, onreadable, onwriteable ) + local socket = fd + if type(fd) == "number" then + socket = { getfd = function () return fd; end } + end + local handler = { + conn = socket; + readbuffer = onreadable or id; + sendbuffer = onwriteable or id; + close = closewatcher; + setflags = addremove; + }; + addremove( handler, onreadable, onwriteable ) + return handler +end ----------------------------------// BEGIN //-- @@ -1017,7 +1082,6 @@ use "setmetatable" ( _socketlist, { __mode = "k" } ) use "setmetatable" ( _readtimes, { __mode = "k" } ) use "setmetatable" ( _writetimes, { __mode = "k" } ) -_timer = luasocket_gettime( ) _starttime = luasocket_gettime( ) local function setlogger(new_logger) @@ -1032,9 +1096,11 @@ end return { _addtimer = addtimer, + add_task = add_task; addclient = addclient, wrapclient = wrapclient, + watchfd = watchfd, loop = loop, link = link, diff --git a/net/websocket.lua b/net/websocket.lua index 777b894c..469c6a58 100644 --- a/net/websocket.lua +++ b/net/websocket.lua @@ -21,9 +21,9 @@ local close_timeout = 3; -- Seconds to wait after sending close frame until clos local websockets = {}; local websocket_listeners = {}; -function websocket_listeners.ondisconnect(handler, err) - local s = websockets[handler]; - websockets[handler] = nil; +function websocket_listeners.ondisconnect(conn, err) + local s = websockets[conn]; + websockets[conn] = nil; if s.close_timer then timer.stop(s.close_timer); s.close_timer = nil; @@ -33,19 +33,19 @@ function websocket_listeners.ondisconnect(handler, err) if s.onclose then s:onclose(s.close_code, s.close_message or err); end end -function websocket_listeners.ondetach(handler) - websockets[handler] = nil; +function websocket_listeners.ondetach(conn) + websockets[conn] = nil; end local function fail(s, code, reason) log("warn", "WebSocket connection failed, closing. %d %s", code, reason); s:close(code, reason); - s.handler:close(); + s.conn:close(); return false end -function websocket_listeners.onincoming(handler, buffer, err) -- luacheck: ignore 212/err - local s = websockets[handler]; +function websocket_listeners.onincoming(conn, buffer, err) -- luacheck: ignore 212/err + local s = websockets[conn]; s.readbuffer = s.readbuffer..buffer; while true do local frame, len = frames.parse(s.readbuffer); @@ -111,7 +111,7 @@ function websocket_listeners.onincoming(handler, buffer, err) -- luacheck: ignor elseif frame.opcode == 0x9 then -- Ping frame frame.opcode = 0xA; frame.MASK = true; -- RFC 6455 6.1.5: If the data is being sent by the client, the frame(s) MUST be masked - handler:write(frames.build(frame)); + conn:write(frames.build(frame)); elseif frame.opcode == 0xA then -- Pong frame log("debug", "Received unexpected pong frame: " .. tostring(frame.data)); else @@ -126,15 +126,15 @@ local websocket_methods = {}; local function close_timeout_cb(now, timerid, s) -- luacheck: ignore 212/now 212/timerid s.close_timer = nil; log("warn", "Close timeout waiting for server to close, closing manually."); - s.handler:close(); + s.conn:close(); end function websocket_methods:close(code, reason) if self.readyState < 2 then code = code or 1000; log("debug", "closing WebSocket with code %i: %s" , code , tostring(reason)); self.readyState = 2; - local handler = self.handler; - handler:write(frames.build_close(code, reason, true)); + local conn = self.conn; + conn:write(frames.build_close(code, reason, true)); -- Do not close socket straight away, wait for acknowledgement from server. self.close_timer = timer.add_task(close_timeout, close_timeout_cb, self); elseif self.readyState == 2 then @@ -144,8 +144,8 @@ function websocket_methods:close(code, reason) timer.stop(self.close_timer); self.close_timer = nil; end - local handler = self.handler; - handler:close(); + local conn = self.conn; + conn:close(); else log("debug", "tried to close a closed WebSocket, ignoring."); end @@ -168,7 +168,7 @@ function websocket_methods:send(data, opcode) data = tostring(data); }; log("debug", "WebSocket sending frame: opcode=%0x, %i bytes", frame.opcode, #frame.data); - return self.handler:write(frames.build(frame)); + return self.conn:write(frames.build(frame)); end local websocket_metatable = { @@ -216,7 +216,7 @@ local function connect(url, ex, listeners) local s = setmetatable({ readbuffer = ""; databuffer = nil; - handler = nil; + conn = nil; close_code = nil; close_message = nil; close_timer = nil; @@ -236,6 +236,7 @@ local function connect(url, ex, listeners) method = "GET"; headers = headers; sslctx = ex.sslctx; + insecure = ex.insecure; }, function(b, c, r, http_req) if c ~= 101 or r.headers["connection"]:lower() ~= "upgrade" @@ -252,16 +253,16 @@ local function connect(url, ex, listeners) s.protocol = r.headers["sec-websocket-protocol"]; -- Take possession of socket from http + local conn = http_req.conn; http_req.conn = nil; - local handler = http_req.handler; - s.handler = handler; - websockets[handler] = s; - handler:setlistener(websocket_listeners); + s.conn = conn; + websockets[conn] = s; + conn:setlistener(websocket_listeners); log("debug", "WebSocket connected successfully to %s", url); s.readyState = 1; if s.onopen then s:onopen(); end - websocket_listeners.onincoming(handler, b); + websocket_listeners.onincoming(conn, b); end); return s; diff --git a/net/websocket/frames.lua b/net/websocket/frames.lua index 5fe96d45..ba25d261 100644 --- a/net/websocket/frames.lua +++ b/net/websocket/frames.lua @@ -21,8 +21,8 @@ local t_concat = table.concat; local s_byte = string.byte; local s_char= string.char; local s_sub = string.sub; -local s_pack = string.pack; -local s_unpack = string.unpack; +local s_pack = string.pack; -- luacheck: ignore 143 +local s_unpack = string.unpack; -- luacheck: ignore 143 if not s_pack and softreq"struct" then s_pack = softreq"struct".pack; @@ -112,9 +112,9 @@ end -- TODO: optimize local function apply_mask(str, key, from, to) from = from or 1 - if from < 0 then from = #str + from + 1 end -- negative indicies + if from < 0 then from = #str + from + 1 end -- negative indices to = to or #str - if to < 0 then to = #str + to + 1 end -- negative indicies + if to < 0 then to = #str + to + 1 end -- negative indices local key_len = #key local counter = 0; local data = {}; |