diff options
Diffstat (limited to 'net')
-rw-r--r-- | net/adns.lua | 16 | ||||
-rw-r--r-- | net/connlisteners.lua | 72 | ||||
-rw-r--r-- | net/dns.lua | 78 | ||||
-rw-r--r-- | net/http.lua | 205 | ||||
-rw-r--r-- | net/http/codes.lua | 67 | ||||
-rw-r--r-- | net/http/parser.lua | 160 | ||||
-rw-r--r-- | net/http/server.lua | 303 | ||||
-rw-r--r-- | net/httpclient_listener.lua | 44 | ||||
-rw-r--r-- | net/httpserver.lua | 227 | ||||
-rw-r--r-- | net/httpserver_listener.lua | 46 | ||||
-rw-r--r-- | net/multiplex_listener.lua | 50 | ||||
-rw-r--r-- | net/server.lua | 58 | ||||
-rw-r--r-- | net/server_event.lua | 251 | ||||
-rw-r--r-- | net/server_select.lua | 429 | ||||
-rw-r--r-- | net/xmppclient_listener.lua | 179 | ||||
-rw-r--r-- | net/xmppcomponent_listener.lua | 220 | ||||
-rw-r--r-- | net/xmppserver_listener.lua | 209 |
17 files changed, 1126 insertions, 1488 deletions
diff --git a/net/adns.lua b/net/adns.lua index 2f7b6804..08421f77 100644 --- a/net/adns.lua +++ b/net/adns.lua @@ -1,7 +1,7 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- @@ -64,21 +64,25 @@ function new_async_socket(sock, resolver) if resolver.socketset[conn] == resolver.best_server and resolver.best_server == #servers then log("error", "Exhausted all %d configured DNS servers, next lookup will try %s again", #servers, servers[1]); end - + resolver:servfail(conn); -- Let the magic commence end end - handler = server.wrapclient(sock, "dns", 53, listener); + handler, err = server.wrapclient(sock, "dns", 53, listener); if not handler then - log("warn", "handler is nil"); + return nil, err; end - + handler.settimeout = function () end handler.setsockname = function (_, ...) return sock:setsockname(...); end handler.setpeername = function (_, ...) peername = (...); local ret = sock:setpeername(...); _:set_send(dummy_send); return ret; end handler.connect = function (_, ...) return sock:connect(...) end --handler.send = function (_, data) _:write(data); return _.sendbuffer and _.sendbuffer(); end - handler.send = function (_, data) return sock:send(data); end + handler.send = function (_, data) + local getpeername = sock.getpeername; + log("debug", "Sending DNS query to %s", (getpeername and getpeername(sock)) or "<unconnected>"); + return sock:send(data); + end return handler; end diff --git a/net/connlisteners.lua b/net/connlisteners.lua index e13f85de..99ddc720 100644 --- a/net/connlisteners.lua +++ b/net/connlisteners.lua @@ -1,69 +1,15 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- +-- COMPAT w/pre-0.9 +local log = require "util.logger".init("net.connlisteners"); +local traceback = debug.traceback; +module "httpserver" - -local listeners_dir = (CFG_SOURCEDIR or ".").."/net/"; -local server = require "net.server"; -local log = require "util.logger".init("connlisteners"); -local tostring = tostring; - -local dofile, pcall, error = - dofile, pcall, error - -module "connlisteners" - -local listeners = {}; - -function register(name, listener) - if listeners[name] and listeners[name] ~= listener then - log("debug", "Listener %s is already registered, not registering any more", name); - return false; - end - listeners[name] = listener; - log("debug", "Registered connection listener %s", name); - return true; +function fail() + log("error", "Attempt to use legacy connlisteners API. For more info see http://prosody.im/doc/developers/network"); + log("error", "Legacy connlisteners API usage, %s", traceback("", 2)); end -function deregister(name) - listeners[name] = nil; -end - -function get(name) - local h = listeners[name]; - if not h then - local ok, ret = pcall(dofile, listeners_dir..name:gsub("[^%w%-]", "_").."_listener.lua"); - if not ok then - log("error", "Error while loading listener '%s': %s", tostring(name), tostring(ret)); - return nil, ret; - end - h = listeners[name]; - end - return h; -end - -function start(name, udata) - local h, err = get(name); - if not h then - error("No such connection module: "..name.. (err and (" ("..err..")") or ""), 0); - end - - local interface = (udata and udata.interface) or h.default_interface or "*"; - local port = (udata and udata.port) or h.default_port or error("Can't start listener "..name.." because no port was specified, and it has no default port", 0); - local mode = (udata and udata.mode) or h.default_mode or 1; - local ssl = (udata and udata.ssl) or nil; - local autossl = udata and udata.type == "ssl"; - - if autossl and not ssl then - return nil, "no ssl context"; - end - - return server.addserver(interface, port, h, mode, autossl and ssl or nil); -end +register, deregister = fail, fail; +get, start = fail, fail, epic_fail; return _M; diff --git a/net/dns.lua b/net/dns.lua index 61fb62e8..bd5c260e 100644 --- a/net/dns.lua +++ b/net/dns.lua @@ -14,6 +14,7 @@ local socket = require "socket"; local timer = require "util.timer"; +local new_ip = require "util.ip".new_ip; local _, windows = pcall(require, "util.windows"); local is_windows = (_ and windows) or os.getenv("WINDIR"); @@ -158,8 +159,6 @@ resolver.__index = resolver; resolver.timeout = default_timeout; -local SRV_tostring; - local function default_rr_tostring(rr) local rr_val = rr.type and rr[rr.type:lower()]; if type(rr_val) ~= "string" then @@ -170,8 +169,13 @@ end local special_tostrings = { LOC = resolver.LOC_tostring; - MX = function (rr) return string.format('%2i %s', rr.pref, rr.mx); end; - SRV = SRV_tostring; + MX = function (rr) + return string.format('%2i %s', rr.pref, rr.mx); + end; + SRV = function (rr) + local s = rr.srv; + return string.format('%5d %5d %5d %s', s.priority, s.weight, s.port, s.target); + end; }; local rr_metatable = {}; -- - - - - - - - - - - - - - - - - - - rr_metatable @@ -220,7 +224,7 @@ end function dns.random(...) -- - - - - - - - - - - - - - - - - - - dns.random - math.randomseed(math.floor(10000*socket.gettime())); + math.randomseed(math.floor(10000*socket.gettime()) % 0x100000000); dns.random = math.random; return dns.random(...); end @@ -355,6 +359,7 @@ function resolver:name() -- - - - - - - - - - - - - - - - - - - - - - name local remember, pointers = nil, 0; local len = self:byte(); local n = {}; + if len == 0 then return "." end -- Root label while len > 0 do if len >= 0xc0 then -- name is "compressed" pointers = pointers + 1; @@ -386,6 +391,25 @@ function resolver:A(rr) -- - - - - - - - - - - - - - - - - - - - - - - - A rr.a = string.format('%i.%i.%i.%i', b1, b2, b3, b4); end +function resolver:AAAA(rr) + local addr = {}; + for i = 1, rr.rdlength, 2 do + local b1, b2 = self:byte(2); + table.insert(addr, ("%02x%02x"):format(b1, b2)); + end + addr = table.concat(addr, ":"):gsub("%f[%x]0+(%x)","%1"); + local zeros = {}; + for item in addr:gmatch(":[0:]+:") do + table.insert(zeros, item) + end + if #zeros == 0 then + rr.aaaa = addr; + return + elseif #zeros > 1 then + table.sort(zeros, function(a, b) return #a > #b end); + end + rr.aaaa = addr:gsub(zeros[1], "::", 1):gsub("^0::", "::"):gsub("::0$", "::"); +end function resolver:CNAME(rr) -- - - - - - - - - - - - - - - - - - - - CNAME rr.cname = self:name(); @@ -475,14 +499,8 @@ function resolver:PTR(rr) rr.ptr = self:name(); end -function SRV_tostring(rr) -- - - - - - - - - - - - - - - - - - SRV_tostring - local s = rr.srv; - return string.format( '%5d %5d %5d %s', s.priority, s.weight, s.port, s.target ); -end - - function resolver:TXT(rr) -- - - - - - - - - - - - - - - - - - - - - - TXT - rr.txt = self:sub (rr.rdlength); + rr.txt = self:sub (self:byte()); end @@ -532,6 +550,7 @@ function resolver:decode(packet, force) -- - - - - - - - - - - - - - decode if not force then if not self.active[response.header.id] or not self.active[response.header.id][response.question.raw] then + self.active[response.header.id] = nil; return nil; end end @@ -579,11 +598,12 @@ function resolver:adddefaultnameservers() -- - - - - adddefaultnameservers if resolv_conf then for line in resolv_conf:lines() do line = line:gsub("#.*$", "") - :match('^%s*nameserver%s+(.*)%s*$'); + :match('^%s*nameserver%s+([%x:%.]*)%s*$'); if line then - line:gsub("%f[%d.](%d+%.%d+%.%d+%.%d+)%f[^%d.]", function (address) - self:addnameserver(address) - end); + local ip = new_ip(line); + if ip then + self:addnameserver(ip.addr); + end end end end @@ -603,15 +623,20 @@ function resolver:getsocket(servernum) -- - - - - - - - - - - - - getsocket if sock then return sock; end local err; - sock, err = socket.udp(); + local peer = self.server[servernum]; + if peer:find(":") then + sock, err = socket.udp6(); + else + sock, err = socket.udp(); + end + if sock and self.socket_wrapper then sock, err = self.socket_wrapper(sock, self); end if not sock then return nil, err; end - if self.socket_wrapper then sock = self.socket_wrapper(sock, self); end sock:settimeout(0); -- todo: attempt to use a random port, fallback to 0 sock:setsockname('*', 0); - sock:setpeername(self.server[servernum], 53); + sock:setpeername(peer, 53); self.socket[servernum] = sock; self.socketset[sock] = servernum; return sock; @@ -625,6 +650,7 @@ function resolver:voidsocket(sock) self.socket[self.socketset[sock]] = nil; self.socketset[sock] = nil; end + sock:close(); end function resolver:socket_wrapper_set(func) -- - - - - - - socket_wrapper_set @@ -689,7 +715,7 @@ function resolver:purge(soft) -- - - - - - - - - - - - - - - - - - - purge end end end - else self.cache = {}; end + else self.cache = setmetatable({}, cache_metatable); end end @@ -727,7 +753,7 @@ function resolver:query(qname, qtype, qclass) -- - - - - - - - - - -- query return nil, err; end conn:send (o.packet) - + if timer and self.timeout then local num_servers = #self.server; local i = 1; @@ -779,6 +805,9 @@ function resolver:servfail(sock) end end end + if next(queries) == nil then + self.active[id] = nil; + end end if num == self.best_server then @@ -820,7 +849,7 @@ function resolver:receive(rset) -- - - - - - - - - - - - - - - - - receive -- retire the query local queries = self.active[response.header.id]; queries[response.question.raw] = nil; - + if not next(queries) then self.active[response.header.id] = nil; end if not next(self.active) then self:closeall(); end @@ -835,6 +864,7 @@ function resolver:receive(rset) -- - - - - - - - - - - - - - - - - receive set(self.wanted, q.class, q.type, q.name, nil); end end + end end end @@ -1049,6 +1079,10 @@ function dns.settimeout(...) return _resolver:settimeout(...); end +function dns.cache() + return _resolver.cache; +end + function dns.socket_wrapper_set(...) -- - - - - - - - - socket_wrapper_set return _resolver:socket_wrapper_set(...); end diff --git a/net/http.lua b/net/http.lua index 6c8e0a68..5ec3163c 100644 --- a/net/http.lua +++ b/net/http.lua @@ -1,64 +1,94 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- - local socket = require "socket" -local mime = require "mime" +local b64 = require "util.encodings".base64.encode; local url = require "socket.url" -local httpstream_new = require "util.httpstream".new; +local httpstream_new = require "net.http.parser".new; +local util_http = require "util.http"; -local server = require "net.server" +local ssl_available = pcall(require, "ssl"); -local connlisteners_get = require "net.connlisteners".get; -local listener = connlisteners_get("httpclient") or error("No httpclient listener!"); +local server = require "net.server" local t_insert, t_concat = table.insert, table.concat; -local pairs, ipairs = pairs, ipairs; -local tonumber, tostring, xpcall, select, debug_traceback, char, format = - tonumber, tostring, xpcall, select, debug.traceback, string.char, string.format; +local pairs = pairs; +local tonumber, tostring, xpcall, select, traceback = + tonumber, tostring, xpcall, select, debug.traceback; local log = require "util.logger".init("http"); module "http" -function urlencode(s) return s and (s:gsub("%W", function (c) return format("%%%02x", c:byte()); end)); end -function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return char(tonumber(c,16)); end)); end +local requests = {}; -- Open requests -local function _formencodepart(s) - return s and (s:gsub("%W", function (c) - if c ~= " " then - return format("%%%02x", c:byte()); - else - return "+"; - end - end)); +local listener = { default_port = 80, default_mode = "*a" }; + +function listener.onconnect(conn) + local req = requests[conn]; + -- Send the request + local request_line = { req.method or "GET", " ", req.path, " HTTP/1.1\r\n" }; + if req.query then + t_insert(request_line, 4, "?"..req.query); + end + + conn:write(t_concat(request_line)); + local t = { [2] = ": ", [4] = "\r\n" }; + for k, v in pairs(req.headers) do + t[1], t[3] = k, v; + conn:write(t_concat(t)); + end + conn:write("\r\n"); + + if req.body then + conn:write(req.body); + end +end + +function listener.onincoming(conn, data) + local request = requests[conn]; + + if not request then + log("warn", "Received response from connection %s with no request attached!", tostring(conn)); + return; + end + + if data and request.reader then + request:reader(data); + end end -function formencode(form) - local result = {}; - for _, field in ipairs(form) do - t_insert(result, _formencodepart(field.name).."=".._formencodepart(field.value)); + +function listener.ondisconnect(conn, err) + local request = requests[conn]; + if request and request.conn then + request:reader(nil, err); end - return t_concat(result, "&"); + requests[conn] = nil; end -local function request_reader(request, data, startpos) +local function request_reader(request, data, err) if not request.parser then - local function success_cb(r) + local function error_cb(reason) if request.callback then - for k,v in pairs(r) do request[k] = v; end - request.callback(r.body, r.code, request); + request.callback(reason or "connection-closed", 0, request); request.callback = nil; end destroy_request(request); end - local function error_cb(r) + + if not data then + error_cb(err); + return; + end + + local function success_cb(r) if request.callback then - request.callback(r or "connection-closed", 0, request); + request.callback(r.body, r.code, r, request); request.callback = nil; end destroy_request(request); @@ -71,82 +101,86 @@ local function request_reader(request, data, startpos) request.parser:feed(data); end -local function handleerr(err) log("error", "Traceback[http]: %s: %s", tostring(err), debug_traceback()); end +local function handleerr(err) log("error", "Traceback[http]: %s", traceback(tostring(err), 2)); end function request(u, ex, callback) local req = url.parse(u); - + if not (req and req.host) then callback(nil, 0, req); return nil, "invalid-url"; end - + if not req.path then req.path = "/"; end - - local custom_headers, body; - local default_headers = { ["Host"] = req.host, ["User-Agent"] = "Prosody XMPP Server" } - - + + local method, headers, body; + + local host, port = req.host, req.port; + local host_header = host; + if (port == "80" and req.scheme == "http") + or (port == "443" and req.scheme == "https") then + port = nil; + elseif port then + host_header = host_header..":"..port; + end + + headers = { + ["Host"] = host_header; + ["User-Agent"] = "Prosody XMPP Server"; + }; + if req.userinfo then - default_headers["Authorization"] = "Basic "..mime.b64(req.userinfo); + headers["Authorization"] = "Basic "..b64(req.userinfo); end - + if ex then - custom_headers = ex.headers; req.onlystatus = ex.onlystatus; body = ex.body; if body then - req.method = "POST "; - default_headers["Content-Length"] = tostring(#body); - default_headers["Content-Type"] = "application/x-www-form-urlencoded"; + method = "POST"; + headers["Content-Length"] = tostring(#body); + headers["Content-Type"] = "application/x-www-form-urlencoded"; + end + if ex.method then method = ex.method; end + if ex.headers then + for k, v in pairs(ex.headers) do + headers[k] = v; + end end - if ex.method then req.method = ex.method; end end - - req.handler, req.conn = server.wrapclient(socket.tcp(), req.host, req.port or 80, listener, "*a"); - req.write = function (...) return req.handler:write(...); end - req.conn:settimeout(0); - local ok, err = req.conn:connect(req.host, req.port or 80); + + -- Attach to request object + req.method, req.headers, req.body = method, headers, body; + + local using_https = req.scheme == "https"; + if using_https and not ssl_available then + error("SSL not available, unable to contact https URL"); + end + local port_number = port and tonumber(port) or (using_https and 443 or 80); + + -- Connect the socket, and wrap it with net.server + local conn = socket.tcp(); + conn:settimeout(10); + local ok, err = conn:connect(host, port_number); if not ok and err ~= "timeout" then callback(nil, 0, req); return nil, err; end - - local request_line = { req.method or "GET", " ", req.path, " HTTP/1.1\r\n" }; - - if req.query then - t_insert(request_line, 4, "?"); - t_insert(request_line, 5, req.query); - end - - req.write(t_concat(request_line)); - local t = { [2] = ": ", [4] = "\r\n" }; - if custom_headers then - for k, v in pairs(custom_headers) do - t[1], t[3] = k, v; - req.write(t_concat(t)); - default_headers[k] = nil; - end - end - - for k, v in pairs(default_headers) do - t[1], t[3] = k, v; - req.write(t_concat(t)); - default_headers[k] = nil; - end - req.write("\r\n"); - - if body then - req.write(body); + + local sslctx = false; + if using_https then + sslctx = ex and ex.sslctx or { mode = "client", protocol = "sslv23", options = { "no_sslv2" } }; end - - req.callback = function (content, code, request) log("debug", "Calling callback, status %s", code or "---"); return select(2, xpcall(function () return callback(content, code, request) end, handleerr)); end + + req.handler, req.conn = server.wrapclient(conn, host, port_number, listener, "*a", sslctx); + req.write = function (...) return req.handler:write(...); end + + req.callback = function (content, code, request, response) log("debug", "Calling callback, status %s", code or "---"); return select(2, xpcall(function () return callback(content, code, request, response) end, handleerr)); end req.reader = request_reader; req.state = "status"; - - listener.register_request(req.handler, req); + requests[req.handler] = req; return req; end @@ -154,10 +188,13 @@ function destroy_request(request) if request.conn then request.conn = nil; request.handler:close() - listener.ondisconnect(request.handler, "closed"); end end -_M.urlencode = urlencode; +local urlencode, urldecode = util_http.urlencode, util_http.urldecode; +local formencode, formdecode = util_http.formencode, util_http.formdecode; + +_M.urlencode, _M.urldecode = urlencode, urldecode; +_M.formencode, _M.formdecode = formencode, formdecode; return _M; diff --git a/net/http/codes.lua b/net/http/codes.lua new file mode 100644 index 00000000..0cadd079 --- /dev/null +++ b/net/http/codes.lua @@ -0,0 +1,67 @@ + +local response_codes = { + -- Source: http://www.iana.org/assignments/http-status-codes + -- s/^\(\d*\)\s*\(.*\S\)\s*\[RFC.*\]\s*$/^I["\1"] = "\2"; + [100] = "Continue"; + [101] = "Switching Protocols"; + [102] = "Processing"; + + [200] = "OK"; + [201] = "Created"; + [202] = "Accepted"; + [203] = "Non-Authoritative Information"; + [204] = "No Content"; + [205] = "Reset Content"; + [206] = "Partial Content"; + [207] = "Multi-Status"; + [208] = "Already Reported"; + [226] = "IM Used"; + + [300] = "Multiple Choices"; + [301] = "Moved Permanently"; + [302] = "Found"; + [303] = "See Other"; + [304] = "Not Modified"; + [305] = "Use Proxy"; + -- The 306 status code was used in a previous version of [RFC2616], is no longer used, and the code is reserved. + [307] = "Temporary Redirect"; + + [400] = "Bad Request"; + [401] = "Unauthorized"; + [402] = "Payment Required"; + [403] = "Forbidden"; + [404] = "Not Found"; + [405] = "Method Not Allowed"; + [406] = "Not Acceptable"; + [407] = "Proxy Authentication Required"; + [408] = "Request Timeout"; + [409] = "Conflict"; + [410] = "Gone"; + [411] = "Length Required"; + [412] = "Precondition Failed"; + [413] = "Request Entity Too Large"; + [414] = "Request-URI Too Long"; + [415] = "Unsupported Media Type"; + [416] = "Requested Range Not Satisfiable"; + [417] = "Expectation Failed"; + [418] = "I'm a teapot"; + [422] = "Unprocessable Entity"; + [423] = "Locked"; + [424] = "Failed Dependency"; + -- The 425 status code is reserved for the WebDAV advanced collections expired proposal [RFC2817] + [426] = "Upgrade Required"; + + [500] = "Internal Server Error"; + [501] = "Not Implemented"; + [502] = "Bad Gateway"; + [503] = "Service Unavailable"; + [504] = "Gateway Timeout"; + [505] = "HTTP Version Not Supported"; + [506] = "Variant Also Negotiates"; -- Experimental + [507] = "Insufficient Storage"; + [508] = "Loop Detected"; + [510] = "Not Extended"; +}; + +for k,v in pairs(response_codes) do response_codes[k] = k.." "..v; end +return setmetatable(response_codes, { __index = function(t, k) return k.." Unassigned"; end }) diff --git a/net/http/parser.lua b/net/http/parser.lua new file mode 100644 index 00000000..f9e6cea0 --- /dev/null +++ b/net/http/parser.lua @@ -0,0 +1,160 @@ +local tonumber = tonumber; +local assert = assert; +local url_parse = require "socket.url".parse; +local urldecode = require "util.http".urldecode; + +local function preprocess_path(path) + path = urldecode((path:gsub("//+", "/"))); + if path:sub(1,1) ~= "/" then + path = "/"..path; + end + local level = 0; + for component in path:gmatch("([^/]+)/") do + if component == ".." then + level = level - 1; + elseif component ~= "." then + level = level + 1; + end + if level < 0 then + return nil; + end + end + return path; +end + +local httpstream = {}; + +function httpstream.new(success_cb, error_cb, parser_type, options_cb) + local client = true; + if not parser_type or parser_type == "server" then client = false; else assert(parser_type == "client", "Invalid parser type"); end + local buf = ""; + local chunked, chunk_size, chunk_start; + local state = nil; + local packet; + local len; + local have_body; + local error; + return { + feed = function(self, data) + if error then return nil, "parse has failed"; end + if not data then -- EOF + if state and client and not len then -- reading client body until EOF + packet.body = buf; + success_cb(packet); + elseif buf ~= "" then -- unexpected EOF + error = true; return error_cb(); + end + return; + end + buf = buf..data; + while #buf > 0 do + if state == nil then -- read request + local index = buf:find("\r\n\r\n", nil, true); + if not index then return; end -- not enough data + local method, path, httpversion, status_code, reason_phrase; + local first_line; + local headers = {}; + for line in buf:sub(1,index+1):gmatch("([^\r\n]+)\r\n") do -- parse request + if first_line then + local key, val = line:match("^([^%s:]+): *(.*)$"); + if not key then error = true; return error_cb("invalid-header-line"); end -- TODO handle multi-line and invalid headers + key = key:lower(); + headers[key] = headers[key] and headers[key]..","..val or val; + else + first_line = line; + if client then + httpversion, status_code, reason_phrase = line:match("^HTTP/(1%.[01]) (%d%d%d) (.*)$"); + status_code = tonumber(status_code); + if not status_code then error = true; return error_cb("invalid-status-line"); end + have_body = not + ( (options_cb and options_cb().method == "HEAD") + or (status_code == 204 or status_code == 304 or status_code == 301) + or (status_code >= 100 and status_code < 200) ); + else + method, path, httpversion = line:match("^(%w+) (%S+) HTTP/(1%.[01])$"); + if not method then error = true; return error_cb("invalid-status-line"); end + end + end + end + if not first_line then error = true; return error_cb("invalid-status-line"); end + chunked = have_body and headers["transfer-encoding"] == "chunked"; + len = tonumber(headers["content-length"]); -- TODO check for invalid len + if client then + -- FIXME handle '100 Continue' response (by skipping it) + if not have_body then len = 0; end + packet = { + code = status_code; + httpversion = httpversion; + headers = headers; + body = have_body and "" or nil; + -- COMPAT the properties below are deprecated + responseversion = httpversion; + responseheaders = headers; + }; + else + local parsed_url; + if path:byte() == 47 then -- starts with / + local _path, _query = path:match("([^?]*).?(.*)"); + if _query == "" then _query = nil; end + parsed_url = { path = _path, query = _query }; + else + parsed_url = url_parse(path); + if not(parsed_url and parsed_url.path) then error = true; return error_cb("invalid-url"); end + end + path = preprocess_path(parsed_url.path); + headers.host = parsed_url.host or headers.host; + + len = len or 0; + packet = { + method = method; + url = parsed_url; + path = path; + httpversion = httpversion; + headers = headers; + body = nil; + }; + end + buf = buf:sub(index + 4); + state = true; + end + if state then -- read body + if client then + if chunked then + if not buf:find("\r\n", nil, true) then + return; + end -- not enough data + if not chunk_size then + chunk_size, chunk_start = buf:match("^(%x+)[^\r\n]*\r\n()"); + chunk_size = chunk_size and tonumber(chunk_size, 16); + if not chunk_size then error = true; return error_cb("invalid-chunk-size"); end + end + if chunk_size == 0 and buf:find("\r\n\r\n", chunk_start-2, true) then + state, chunk_size = nil, nil; + buf = buf:gsub("^.-\r\n\r\n", ""); -- This ensure extensions and trailers are stripped + success_cb(packet); + elseif #buf - chunk_start + 2 >= chunk_size then -- we have a chunk + packet.body = packet.body..buf:sub(chunk_start, chunk_start + (chunk_size-1)); + buf = buf:sub(chunk_start + chunk_size + 2); + chunk_size, chunk_start = nil, nil; + else -- Partial chunk remaining + break; + end + elseif len and #buf >= len then + packet.body, buf = buf:sub(1, len), buf:sub(len + 1); + state = nil; success_cb(packet); + else + break; + end + elseif #buf >= len then + packet.body, buf = buf:sub(1, len), buf:sub(len + 1); + state = nil; success_cb(packet); + else + break; + end + end + end + end; + }; +end + +return httpstream; diff --git a/net/http/server.lua b/net/http/server.lua new file mode 100644 index 00000000..5961169f --- /dev/null +++ b/net/http/server.lua @@ -0,0 +1,303 @@ + +local t_insert, t_remove, t_concat = table.insert, table.remove, table.concat; +local parser_new = require "net.http.parser".new; +local events = require "util.events".new(); +local addserver = require "net.server".addserver; +local log = require "util.logger".init("http.server"); +local os_date = os.date; +local pairs = pairs; +local s_upper = string.upper; +local setmetatable = setmetatable; +local xpcall = xpcall; +local traceback = debug.traceback; +local tostring = tostring; +local codes = require "net.http.codes"; + +local _M = {}; + +local sessions = {}; +local listener = {}; +local hosts = {}; +local default_host; + +local function is_wildcard_event(event) + return event:sub(-2, -1) == "/*"; +end +local function is_wildcard_match(wildcard_event, event) + return wildcard_event:sub(1, -2) == event:sub(1, #wildcard_event-1); +end + +local recent_wildcard_events, max_cached_wildcard_events = {}, 10000; + +local event_map = events._event_map; +setmetatable(events._handlers, { + -- Called when firing an event that doesn't exist (but may match a wildcard handler) + __index = function (handlers, curr_event) + if is_wildcard_event(curr_event) then return; end -- Wildcard events cannot be fired + -- Find all handlers that could match this event, sort them + -- and then put the array into handlers[curr_event] (and return it) + local matching_handlers_set = {}; + local handlers_array = {}; + for event, handlers_set in pairs(event_map) do + if event == curr_event or + is_wildcard_event(event) and is_wildcard_match(event, curr_event) then + for handler, priority in pairs(handlers_set) do + matching_handlers_set[handler] = { (select(2, event:gsub("/", "%1"))), is_wildcard_event(event) and 0 or 1, priority }; + table.insert(handlers_array, handler); + end + end + end + if #handlers_array > 0 then + table.sort(handlers_array, function(b, a) + local a_score, b_score = matching_handlers_set[a], matching_handlers_set[b]; + for i = 1, #a_score do + if a_score[i] ~= b_score[i] then -- If equal, compare next score value + return a_score[i] < b_score[i]; + end + end + return false; + end); + else + handlers_array = false; + end + rawset(handlers, curr_event, handlers_array); + if not event_map[curr_event] then -- Only wildcard handlers match, if any + table.insert(recent_wildcard_events, curr_event); + if #recent_wildcard_events > max_cached_wildcard_events then + rawset(handlers, table.remove(recent_wildcard_events, 1), nil); + end + end + return handlers_array; + end; + __newindex = function (handlers, curr_event, handlers_array) + if handlers_array == nil + and is_wildcard_event(curr_event) then + -- Invalidate the indexes of all matching events + for event in pairs(handlers) do + if is_wildcard_match(curr_event, event) then + handlers[event] = nil; + end + end + end + rawset(handlers, curr_event, handlers_array); + end; +}); + +local handle_request; +local _1, _2, _3; +local function _handle_request() return handle_request(_1, _2, _3); end + +local last_err; +local function _traceback_handler(err) last_err = err; log("error", "Traceback[httpserver]: %s", traceback(tostring(err), 2)); end +events.add_handler("http-error", function (error) + return "Error processing request: "..codes[error.code]..". Check your error log for more information."; +end, -1); + +function listener.onconnect(conn) + local secure = conn:ssl() and true or nil; + local pending = {}; + local waiting = false; + local function process_next() + if waiting then log("debug", "can't process_next, waiting"); return; end + waiting = true; + while sessions[conn] and #pending > 0 do + local request = t_remove(pending); + --log("debug", "process_next: %s", request.path); + --handle_request(conn, request, process_next); + _1, _2, _3 = conn, request, process_next; + if not xpcall(_handle_request, _traceback_handler) then + conn:write("HTTP/1.0 500 Internal Server Error\r\n\r\n"..events.fire_event("http-error", { code = 500, private_message = last_err })); + conn:close(); + end + end + --log("debug", "ready for more"); + waiting = false; + end + local function success_cb(request) + --log("debug", "success_cb: %s", request.path); + if waiting then + log("error", "http connection handler is not reentrant: %s", request.path); + assert(false, "http connection handler is not reentrant"); + end + request.secure = secure; + t_insert(pending, request); + process_next(); + end + local function error_cb(err) + log("debug", "error_cb: %s", err or "<nil>"); + -- FIXME don't close immediately, wait until we process current stuff + -- FIXME if err, send off a bad-request response + sessions[conn] = nil; + conn:close(); + end + sessions[conn] = parser_new(success_cb, error_cb); +end + +function listener.ondisconnect(conn) + local open_response = conn._http_open_response; + if open_response and open_response.on_destroy then + open_response.finished = true; + open_response:on_destroy(); + end + sessions[conn] = nil; +end + +function listener.onincoming(conn, data) + sessions[conn]:feed(data); +end + +local headerfix = setmetatable({}, { + __index = function(t, k) + local v = "\r\n"..k:gsub("_", "-"):gsub("%f[%w].", s_upper)..": "; + t[k] = v; + return v; + end +}); + +function _M.hijack_response(response, listener) + error("TODO"); +end +function handle_request(conn, request, finish_cb) + --log("debug", "handler: %s", request.path); + local headers = {}; + for k,v in pairs(request.headers) do headers[k:gsub("-", "_")] = v; end + request.headers = headers; + request.conn = conn; + + local date_header = os_date('!%a, %d %b %Y %H:%M:%S GMT'); -- FIXME use + local conn_header = request.headers.connection; + conn_header = conn_header and ","..conn_header:gsub("[ \t]", ""):lower().."," or "" + local httpversion = request.httpversion + local persistent = conn_header:find(",keep-alive,", 1, true) + or (httpversion == "1.1" and not conn_header:find(",close,", 1, true)); + + local response_conn_header; + if persistent then + response_conn_header = "Keep-Alive"; + else + response_conn_header = httpversion == "1.1" and "close" or nil + end + + local response = { + request = request; + status_code = 200; + headers = { date = date_header, connection = response_conn_header }; + persistent = persistent; + conn = conn; + send = _M.send_response; + finish_cb = finish_cb; + }; + conn._http_open_response = response; + + local host = (request.headers.host or ""):match("[^:]+"); + + -- Some sanity checking + local err_code, err; + if not request.path then + err_code, err = 400, "Invalid path"; + elseif not hosts[host] then + if hosts[default_host] then + host = default_host; + elseif host then + err_code, err = 404, "Unknown host: "..host; + else + err_code, err = 400, "Missing or invalid 'Host' header"; + end + end + + if err then + response.status_code = err_code; + response:send(events.fire_event("http-error", { code = err_code, message = err })); + return; + end + + local event = request.method.." "..host..request.path:match("[^?]*"); + local payload = { request = request, response = response }; + --log("debug", "Firing event: %s", event); + local result = events.fire_event(event, payload); + if result ~= nil then + if result ~= true then + local body; + local result_type = type(result); + if result_type == "number" then + response.status_code = result; + if result >= 400 then + body = events.fire_event("http-error", { code = result }); + end + elseif result_type == "string" then + body = result; + elseif result_type == "table" then + for k, v in pairs(result) do + if k ~= "headers" then + response[k] = v; + else + for header_name, header_value in pairs(v) do + response.headers[header_name] = header_value; + end + end + end + end + response:send(body); + end + return; + end + + -- if handler not called, return 404 + response.status_code = 404; + response:send(events.fire_event("http-error", { code = 404 })); +end +function _M.send_response(response, body) + if response.finished then return; end + response.finished = true; + response.conn._http_open_response = nil; + + local status_line = "HTTP/"..response.request.httpversion.." "..(response.status or codes[response.status_code]); + local headers = response.headers; + body = body or response.body or ""; + headers.content_length = #body; + + local output = { status_line }; + for k,v in pairs(headers) do + t_insert(output, headerfix[k]..v); + end + t_insert(output, "\r\n\r\n"); + t_insert(output, body); + + response.conn:write(t_concat(output)); + if response.on_destroy then + response:on_destroy(); + response.on_destroy = nil; + end + if response.persistent then + response:finish_cb(); + else + response.conn:close(); + end +end +function _M.add_handler(event, handler, priority) + events.add_handler(event, handler, priority); +end +function _M.remove_handler(event, handler) + events.remove_handler(event, handler); +end + +function _M.listen_on(port, interface, ssl) + addserver(interface or "*", port, listener, "*a", ssl); +end +function _M.add_host(host) + hosts[host] = true; +end +function _M.remove_host(host) + hosts[host] = nil; +end +function _M.set_default_host(host) + default_host = host; +end +function _M.fire_event(event, ...) + return events.fire_event(event, ...); +end + +_M.listener = listener; +_M.codes = codes; +_M._events = events; +return _M; diff --git a/net/httpclient_listener.lua b/net/httpclient_listener.lua deleted file mode 100644 index dfa25062..00000000 --- a/net/httpclient_listener.lua +++ /dev/null @@ -1,44 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - -local log = require "util.logger".init("httpclient_listener"); - -local connlisteners_register = require "net.connlisteners".register; - -local requests = {}; -- Open requests -local buffers = {}; -- Buffers of partial lines - -local httpclient = { default_port = 80, default_mode = "*a" }; - -function httpclient.onincoming(conn, data) - local request = requests[conn]; - - if not request then - log("warn", "Received response from connection %s with no request attached!", tostring(conn)); - return; - end - - if data and request.reader then - request:reader(data); - end -end - -function httpclient.ondisconnect(conn, err) - local request = requests[conn]; - if request and err ~= "closed" then - request:reader(nil); - end - requests[conn] = nil; -end - -function httpclient.register_request(conn, req) - log("debug", "Attaching request %s to connection %s", tostring(req.id or req), tostring(conn)); - requests[conn] = req; -end - -connlisteners_register("httpclient", httpclient); diff --git a/net/httpserver.lua b/net/httpserver.lua index 4c1200ac..7d574788 100644 --- a/net/httpserver.lua +++ b/net/httpserver.lua @@ -1,226 +1,15 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - - -local socket = require "socket" -local server = require "net.server" -local url_parse = require "socket.url".parse; -local httpstream_new = require "util.httpstream".new; - -local connlisteners_start = require "net.connlisteners".start; -local connlisteners_get = require "net.connlisteners".get; -local listener; - -local t_insert, t_concat = table.insert, table.concat; -local s_match, s_gmatch = string.match, string.gmatch; -local tonumber, tostring, pairs, ipairs, type = tonumber, tostring, pairs, ipairs, type; - -local urlencode = function (s) return s and (s:gsub("%W", function (c) return string.format("%%%02x", c:byte()); end)); end - -local log = require "util.logger".init("httpserver"); - -local http_servers = {}; +-- COMPAT w/pre-0.9 +local log = require "util.logger".init("net.httpserver"); +local traceback = debug.traceback; module "httpserver" -local default_handler; - -local function expectbody(reqt) - return reqt.method == "POST"; -end - -local function send_response(request, response) - -- Write status line - local resp; - if response.body or response.headers then - local body = response.body and tostring(response.body); - log("debug", "Sending response to %s", request.id); - resp = { "HTTP/1.0 "..(response.status or "200 OK").."\r\n" }; - local h = response.headers; - if h then - for k, v in pairs(h) do - t_insert(resp, k..": "..v.."\r\n"); - end - end - if body and not (h and h["Content-Length"]) then - t_insert(resp, "Content-Length: "..#body.."\r\n"); - end - t_insert(resp, "\r\n"); - - if body and request.method ~= "HEAD" then - t_insert(resp, body); - end - request.write(t_concat(resp)); - else - -- Response we have is just a string (the body) - log("debug", "Sending 200 response to %s", request.id or "<none>"); - - local resp = "HTTP/1.0 200 OK\r\n" - .. "Connection: close\r\n" - .. "Content-Type: text/html\r\n" - .. "Content-Length: "..#response.."\r\n" - .. "\r\n" - .. response; - - request.write(resp); - end - if not request.stayopen then - request:destroy(); - end -end - -local function call_callback(request, err) - if request.handled then return; end - request.handled = true; - local callback = request.callback; - if not callback and request.path then - local path = request.url.path; - local base = path:match("^/([^/?]+)"); - if not base then - base = path:match("^http://[^/?]+/([^/?]+)"); - end - - callback = (request.server and request.server.handlers[base]) or default_handler; - end - if callback then - if err then - log("debug", "Request error: "..err); - if not callback(nil, err, request) then - destroy_request(request); - end - return; - end - - local response = callback(request.method, request.body and t_concat(request.body), request); - if response then - if response == true and not request.destroyed then - -- Keep connection open, we will reply later - log("debug", "Request %s left open, on_destroy is %s", request.id, tostring(request.on_destroy)); - elseif response ~= true then - -- Assume response - send_response(request, response); - destroy_request(request); - end - else - log("debug", "Request handler provided no response, destroying request..."); - -- No response, close connection - destroy_request(request); - end - end -end - -local function request_reader(request, data, startpos) - if not request.parser then - local function success_cb(r) - for k,v in pairs(r) do request[k] = v; end - request.url = url_parse(request.path); - request.body = { request.body }; - call_callback(request); - end - local function error_cb(r) - call_callback(request, r or "connection-closed"); - destroy_request(request); - end - request.parser = httpstream_new(success_cb, error_cb); - end - request.parser:feed(data); -end - --- The default handler for requests -default_handler = function (method, body, request) - log("debug", method.." request for "..tostring(request.path) .. " on port "..request.handler:serverport()); - return { status = "404 Not Found", - headers = { ["Content-Type"] = "text/html" }, - body = "<html><head><title>Page Not Found</title></head><body>Not here :(</body></html>" }; -end - - -function new_request(handler) - return { handler = handler, conn = handler, - write = function (...) return handler:write(...); end, state = "request", - server = http_servers[handler:serverport()], - send = send_response, - destroy = destroy_request, - id = tostring{}:match("%x+$") - }; -end - -function destroy_request(request) - log("debug", "Destroying request %s", request.id); - listener = listener or connlisteners_get("httpserver"); - if not request.destroyed then - request.destroyed = true; - if request.on_destroy then - log("debug", "Request has destroy callback"); - request.on_destroy(request); - else - log("debug", "Request has no destroy callback"); - end - request.handler:close() - if request.conn then - listener.ondisconnect(request.conn, "closed"); - end - end -end - -function new(params) - local http_server = http_servers[params.port]; - if not http_server then - http_server = { handlers = {} }; - http_servers[params.port] = http_server; - -- We weren't already listening on this port, so start now - connlisteners_start("httpserver", params); - end - if params.base then - http_server.handlers[params.base] = params.handler; - end -end - -function set_default_handler(handler) - default_handler = handler; -end - -function new_from_config(ports, handle_request, default_options) - if type(handle_request) == "string" then -- COMPAT with old plugins - log("warn", "Old syntax of httpserver.new_from_config being used to register %s", handle_request); - handle_request, default_options = default_options, { base = handle_request }; - end - ports = ports or {5280}; - for _, options in ipairs(ports) do - local port = default_options.port or 5280; - local base = default_options.base; - local ssl = default_options.ssl or false; - local interface = default_options.interface; - if type(options) == "number" then - port = options; - elseif type(options) == "table" then - port = options.port or port; - base = options.path or base; - ssl = options.ssl or ssl; - interface = options.interface or interface; - elseif type(options) == "string" then - base = options; - end - - if ssl then - ssl.mode = "server"; - ssl.protocol = "sslv23"; - ssl.options = "no_sslv2"; - end - - new{ port = port, interface = interface, - base = base, handler = handle_request, - ssl = ssl, type = (ssl and "ssl") or "tcp" }; - end +function fail() + log("error", "Attempt to use legacy HTTP API. For more info see http://prosody.im/doc/developers/legacy_http"); + log("error", "Legacy HTTP API usage, %s", traceback("", 2)); end -_M.request_reader = request_reader; -_M.send_response = send_response; -_M.urlencode = urlencode; +new, new_from_config = fail, fail; +set_default_handler = fail; return _M; diff --git a/net/httpserver_listener.lua b/net/httpserver_listener.lua deleted file mode 100644 index dd14b43c..00000000 --- a/net/httpserver_listener.lua +++ /dev/null @@ -1,46 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - - - -local connlisteners_register = require "net.connlisteners".register; -local new_request = require "net.httpserver".new_request; -local request_reader = require "net.httpserver".request_reader; - -local requests = {}; -- Open requests - -local httpserver = { default_port = 80, default_mode = "*a" }; - -function httpserver.onincoming(conn, data) - local request = requests[conn]; - - if not request then - request = new_request(conn); - requests[conn] = request; - - -- If using HTTPS, request is secure - if conn:ssl() then - request.secure = true; - end - end - - if data and data ~= "" then - request_reader(request, data); - end -end - -function httpserver.ondisconnect(conn, err) - local request = requests[conn]; - if request and not request.destroyed then - request.conn = nil; - request_reader(request, nil); - end - requests[conn] = nil; -end - -connlisteners_register("httpserver", httpserver); diff --git a/net/multiplex_listener.lua b/net/multiplex_listener.lua deleted file mode 100644 index b515ccce..00000000 --- a/net/multiplex_listener.lua +++ /dev/null @@ -1,50 +0,0 @@ - -local connlisteners_register = require "net.connlisteners".register; -local connlisteners_get = require "net.connlisteners".get; - -local httpserver_listener = connlisteners_get("httpserver"); -local xmppserver_listener = connlisteners_get("xmppserver"); -local xmppclient_listener = connlisteners_get("xmppclient"); -local xmppcomponent_listener = connlisteners_get("xmppcomponent"); - -local server = { default_mode = "*a" }; - -local buffer = {}; - -function server.onincoming(conn, data) - if not data then return; end - local buf = buffer[conn]; - buffer[conn] = nil; - buf = buf and buf..data or data; - if buf:match("^[a-zA-Z]") then - local listener = httpserver_listener; - conn:setlistener(listener); - local onconnect = listener.onconnect; - if onconnect then onconnect(conn) end - listener.onincoming(conn, buf); - elseif buf:match(">") then - local listener; - local xmlns = buf:match("%sxmlns%s*=%s*['\"]([^'\"]*)"); - if xmlns == "jabber:server" then - listener = xmppserver_listener; - elseif xmlns == "jabber:component:accept" then - listener = xmppcomponent_listener; - else - listener = xmppclient_listener; - end - conn:setlistener(listener); - local onconnect = listener.onconnect; - if onconnect then onconnect(conn) end - listener.onincoming(conn, buf); - elseif #buf > 1024 then - conn:close(); - else - buffer[conn] = buf; - end -end - -function server.ondisconnect(conn, err) - buffer[conn] = nil; -- warn if no buffer? -end - -connlisteners_register("multiplex", server); diff --git a/net/server.lua b/net/server.lua index 1c1a63a4..2a0b89ae 100644 --- a/net/server.lua +++ b/net/server.lua @@ -1,12 +1,12 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain --- +-- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- -local use_luaevent = prosody and require "core.configmanager".get("*", "core", "use_libevent"); +local use_luaevent = prosody and require "core.configmanager".get("*", "use_libevent"); if use_luaevent then use_luaevent = pcall(require, "luaevent.core"); @@ -19,18 +19,7 @@ local server; if use_luaevent then server = require "net.server_event"; - -- util.timer requires "net.server", so instead of having - -- Lua look for, and load us again (causing a loop) - set this here - -- (usually it isn't set until we return, look down there...) - package.loaded["net.server"] = server; - - -- Backwards compatibility for timers, addtimer - -- called a function roughly every second - local add_task = require "util.timer".add_task; - function server.addtimer(f) - return add_task(1, function (...) f(...); return 1; end); - end - + -- Overwrite signal.signal() because we need to ask libevent to -- handle them instead local ok, signal = pcall(require, "util.signal"); @@ -47,8 +36,47 @@ if use_luaevent then end end else + use_luaevent = false; server = require "net.server_select"; - package.loaded["net.server"] = server; +end + +if prosody then + local config_get = require "core.configmanager".get; + local defaults = {}; + for k,v in pairs(server.cfg or server.getsettings()) do + defaults[k] = v; + end + local function load_config() + local settings = config_get("*", "network_settings") or {}; + if use_luaevent then + local event_settings = { + ACCEPT_DELAY = settings.event_accept_retry_interval; + ACCEPT_QUEUE = settings.tcp_backlog; + CLEAR_DELAY = settings.event_clear_interval; + CONNECT_TIMEOUT = settings.connect_timeout; + DEBUG = settings.debug; + HANDSHAKE_TIMEOUT = settings.ssl_handshake_timeout; + MAX_CONNECTIONS = settings.max_connections; + MAX_HANDSHAKE_ATTEMPTS = settings.max_ssl_handshake_roundtrips; + MAX_READ_LENGTH = settings.max_receive_buffer_size; + MAX_SEND_LENGTH = settings.max_send_buffer_size; + READ_TIMEOUT = settings.read_timeout; + WRITE_TIMEOUT = settings.send_timeout; + }; + + for k,default in pairs(defaults) do + server.cfg[k] = event_settings[k] or default; + end + else + local select_settings = {}; + for k,default in pairs(defaults) do + select_settings[k] = settings[k] or default; + end + server.changesettings(select_settings); + end + end + load_config(); + prosody.events.add_handler("config-reloaded", load_config); end -- require "net.server" shall now forever return this, diff --git a/net/server_event.lua b/net/server_event.lua index 122d80fc..59217a0c 100644 --- a/net/server_event.lua +++ b/net/server_event.lua @@ -6,7 +6,6 @@ notes: -- when using luaevent, never register 2 or more EV_READ at one socket, same for EV_WRITE -- you cant even register a new EV_READ/EV_WRITE callback inside another one - -- never call eventcallback:close( ) from inside eventcallback -- to do some of the above, use timeout events or something what will called from outside -- dont let garbagecollect eventcallbacks, as long they are running -- when using luasec, there are 4 cases of timeout errors: wantread or wantwrite during reading or writing @@ -24,6 +23,7 @@ local cfg = { HANDSHAKE_TIMEOUT = 60, -- timeout in seconds per handshake attempt MAX_READ_LENGTH = 1024 * 1024 * 1024 * 1024, -- max bytes allowed to read from sockets MAX_SEND_LENGTH = 1024 * 1024 * 1024 * 1024, -- max bytes size of write buffer (for writing on sockets) + ACCEPT_QUEUE = 128, -- might influence the length of the pending sockets queue ACCEPT_DELAY = 10, -- seconds to wait until the next attempt of a full server to accept READ_TIMEOUT = 60 * 60 * 6, -- timeout in seconds for read data from socket WRITE_TIMEOUT = 180, -- timeout in seconds for write data on socket @@ -33,8 +33,6 @@ local cfg = { } local function use(x) return rawget(_G, x); end -local print = use "print" -local pcall = use "pcall" local ipairs = use "ipairs" local string = use "string" local select = use "select" @@ -43,6 +41,9 @@ local tostring = use "tostring" local coroutine = use "coroutine" local setmetatable = use "setmetatable" +local t_insert = table.insert +local t_concat = table.concat + local ssl = use "ssl" local socket = use "socket" or require "socket" @@ -114,26 +115,19 @@ end )( ) local interface_mt do interface_mt = {}; interface_mt.__index = interface_mt; - + local addevent = base.addevent local coroutine_wrap, coroutine_yield = coroutine.wrap,coroutine.yield - local string_len = string.len - + -- Private methods function interface_mt:_position(new_position) self.position = new_position or self.position return self.position; end - function interface_mt:_close() -- regs event to start self:_destroy() - local callback = function( ) - self:_destroy(); - self.eventclose = nil - return -1 - end - self.eventclose = addevent( base, nil, EV_TIMEOUT, callback, 0 ) - return true + function interface_mt:_close() + return self:_destroy(); end - + function interface_mt:_start_connection(plainssl) -- should be called from addclient local callback = function( event ) if EV_TIMEOUT == event then -- timeout during connection @@ -143,7 +137,7 @@ do debug( "new connection failed. id:", self.id, "error:", self.fatalerror ) else if plainssl and ssl then -- start ssl session - self:starttls(nil, true) + self:starttls(self._sslctx, true) else -- normal connection self:_start_session(true) end @@ -212,7 +206,6 @@ do self:_lock( false, false, false ) -- unlock the interface; sending, closing etc allowed self.send = self.conn.send -- caching table lookups with new client object self.receive = self.conn.receive - local onsomething if not call_onconnect then -- trigger listener self:onstatus("ssl-handshake-complete"); end @@ -221,12 +214,12 @@ do self.eventhandshake = nil return -1 end - debug( "error during ssl handshake:", err ) if err == "wantwrite" then event = EV_WRITE elseif err == "wantread" then event = EV_READ else + debug( "ssl handshake error:", err ) self.fatalerror = err end end @@ -249,10 +242,10 @@ do return true end function interface_mt:_destroy() -- close this interface + events and call last listener - debug( "closing client with id:", self.id ) + debug( "closing client with id:", self.id, self.fatalerror ) self:_lock( true, true, true ) -- first of all, lock the interface to avoid further actions local _ - _ = self.eventread and self.eventread:close( ) -- close events; this must be called outside of the event callbacks! + _ = self.eventread and self.eventread:close( ) if self.type == "client" then _ = self.eventwrite and self.eventwrite:close( ) _ = self.eventhandshake and self.eventhandshake:close( ) @@ -262,7 +255,7 @@ do _ = self.eventwritetimeout and self.eventwritetimeout:close( ) _ = self.eventreadtimeout and self.eventreadtimeout:close( ) _ = self.ondisconnect and self:ondisconnect( self.fatalerror ~= "client to close" and self.fatalerror) -- call ondisconnect listener (wont be the case if handshake failed on connect) - _ = self.conn and self.conn:close( ) -- close connection, must also be called outside of any socket registered events! + _ = self.conn and self.conn:close( ) -- close connection _ = self._server and self._server:counter(-1); self.eventread, self.eventwrite = nil, nil self.eventstarthandshake, self.eventhandshake, self.eventclose = nil, nil, nil @@ -275,12 +268,12 @@ do interfacelist( "delete", self ) return true end - + function interface_mt:_lock(nointerface, noreading, nowriting) -- lock or unlock this interface or events self.nointerface, self.noreading, self.nowriting = nointerface, noreading, nowriting return nointerface, noreading, nowriting end - + --TODO: Deprecate function interface_mt:lock_read(switch) if switch then @@ -295,7 +288,10 @@ do end function interface_mt:resume() - return self:_lock(self.nointerface, false, self.nowriting); + self:_lock(self.nointerface, false, self.nowriting); + if not self.eventread then + self.eventread = addevent( base, self.conn, EV_READ, self.readcallback, cfg.READ_TIMEOUT ); -- register callback + end end function interface_mt:counter(c) @@ -304,20 +300,20 @@ do end return self._connections end - + -- Public methods function interface_mt:write(data) if self.nowriting then return nil, "locked" end --vdebug( "try to send data to client, id/data:", self.id, data ) data = tostring( data ) - local len = string_len( data ) + local len = #data local total = len + self.writebufferlen if total > cfg.MAX_SEND_LENGTH then -- check buffer length local err = "send buffer exceeded" debug( "error:", err ) -- to much, check your app return nil, err end - self.writebuffer = self.writebuffer .. data -- new buffer + t_insert(self.writebuffer, data) -- new buffer self.writebufferlen = total if not self.eventwrite then -- register new write event --vdebug( "register new write event" ) @@ -325,62 +321,49 @@ do end return true end - function interface_mt:close(now) + function interface_mt:close() if self.nointerface then return nil, "locked"; end debug( "try to close client connection with id:", self.id ) if self.type == "client" then self.fatalerror = "client to close" - if ( not self.eventwrite ) or now then -- try to close immediately - self:_lock( true, true, true ) - self:_close() - return true - else -- wait for incomplete write request + if self.eventwrite then -- wait for incomplete write request self:_lock( true, true, false ) debug "closing delayed until writebuffer is empty" return nil, "writebuffer not empty, waiting" + else -- close now + self:_lock( true, true, true ) + self:_close() + return true end else - debug( "try to close server with id:", self.id, "args:", now ) + debug( "try to close server with id:", tostring(self.id)) self.fatalerror = "server to close" self:_lock( true ) - local count = 0 - for _, item in ipairs( interfacelist( ) ) do - if ( item.type ~= "server" ) and ( item._server == self ) then -- client/server match - if item:close( now ) then -- writebuffer was empty - count = count + 1 - end - end - end - local timeout = 0 -- dont wait for unfinished writebuffers of clients... - if not now then - timeout = cfg.WRITE_TIMEOUT -- ...or wait for it - end - self:_close( timeout ) -- add new event to remove the server interface - debug( "seconds remained until server is closed:", timeout ) - return count -- returns finished clients with empty writebuffer + self:_close( 0 ) + return true end end - + function interface_mt:socket() return self.conn end - + function interface_mt:server() return self._server or self; end - + function interface_mt:port() return self._port end - + function interface_mt:serverport() return self._serverport end - + function interface_mt:ip() return self._ip end - + function interface_mt:ssl() return self._usingssl end @@ -388,15 +371,15 @@ do function interface_mt:type() return self._type or "client" end - + function interface_mt:connections() return self._connections end - + function interface_mt:address() return self.addr end - + function interface_mt:set_sslctx(sslctx) self._sslctx = sslctx; if sslctx then @@ -412,11 +395,11 @@ do end return self._pattern; end - + function interface_mt:set_send(new_send) -- No-op, we always use the underlying connection's send end - + function interface_mt:starttls(sslctx, call_onconnect) debug( "try to start ssl at client id:", self.id ) local err @@ -445,22 +428,22 @@ do self.starttls = false; return true end - + function interface_mt:setoption(option, value) if self.conn.setoption then return self.conn:setoption(option, value); end return false, "setoption not implemented"; end - + function interface_mt:setlistener(listener) - self.onconnect, self.ondisconnect, self.onincoming, self.ontimeout, self.onstatus - = listener.onconnect, listener.ondisconnect, listener.onincoming, listener.ontimeout, listener.onstatus; + self.onconnect, self.ondisconnect, self.onincoming, self.ontimeout, self.onreadtimeout, self.onstatus + = listener.onconnect, listener.ondisconnect, listener.onincoming, + listener.ontimeout, listener.onreadtimeout, listener.onstatus; end - + -- Stub handlers function interface_mt:onconnect() - return self:onincoming(nil); end function interface_mt:onincoming() end @@ -468,6 +451,12 @@ do end function interface_mt:ontimeout() end + function interface_mt:onreadtimeout() + self.fatalerror = "timeout during receiving" + debug( "connection failed:", self.fatalerror ) + self:_close() + self.eventread = nil + end function interface_mt:ondrain() end function interface_mt:onstatus() @@ -479,18 +468,15 @@ end local handleclient; do local string_sub = string.sub -- caching table lookups - local string_len = string.len local addevent = base.addevent - local coroutine_wrap = coroutine.wrap local socket_gettime = socket.gettime - local coroutine_yield = coroutine.yield - function handleclient( client, ip, port, server, pattern, listener, _, sslctx ) -- creates an client interface + function handleclient( client, ip, port, server, pattern, listener, sslctx ) -- creates an client interface --vdebug("creating client interfacce...") local interface = { type = "client"; conn = client; currenttime = socket_gettime( ); -- safe the origin - writebuffer = ""; -- writebuffer + writebuffer = {}; -- writebuffer writebufferlen = 0; -- length of writebuffer send = client.send; -- caching table lookups receive = client.receive; @@ -498,6 +484,8 @@ do ondisconnect = listener.ondisconnect; -- will be called when client disconnects onincoming = listener.onincoming; -- will be called when client sends data ontimeout = listener.ontimeout; -- called when fatal socket timeout occurs + onreadtimeout = listener.onreadtimeout; -- called when socket inactivity timeout occurs + ondrain = listener.ondrain; -- called when writebuffer is empty onstatus = listener.onstatus; -- called for status changes (e.g. of SSL/TLS) eventread = false, eventwrite = false, eventclose = false, eventhandshake = false, eventstarthandshake = false; -- event handler @@ -511,7 +499,7 @@ do noreading = false, nowriting = false; -- locks of the read/writecallback startsslcallback = false; -- starting handshake callback position = false; -- position of client in interfacelist - + -- Properties _ip = ip, _port = port, _server = server, _pattern = pattern, _serverport = (server and server:port() or nil), @@ -544,10 +532,11 @@ do interface.eventwritetimeout = false end end - local succ, err, byte = interface.conn:send( interface.writebuffer, 1, interface.writebufferlen ) + interface.writebuffer = { t_concat(interface.writebuffer) } + local succ, err, byte = interface.conn:send( interface.writebuffer[1], 1, interface.writebufferlen ) --vdebug( "write data:", interface.writebuffer, "error:", err, "part:", byte ) if succ then -- writing succesful - interface.writebuffer = "" + interface.writebuffer[1] = nil interface.writebufferlen = 0 interface:ondrain(); if interface.fatalerror then @@ -563,7 +552,7 @@ do return -1 elseif byte and (err == "timeout" or err == "wantwrite") then -- want write again --vdebug( "writebuffer is not empty:", err ) - interface.writebuffer = string_sub( interface.writebuffer, byte + 1, interface.writebufferlen ) -- new buffer + interface.writebuffer[1] = string_sub( interface.writebuffer[1], byte + 1, interface.writebufferlen ) -- new buffer interface.writebufferlen = interface.writebufferlen - byte if "wantread" == err then -- happens only with luasec local callback = function( ) @@ -586,7 +575,7 @@ do end end end - + interface.readcallback = function( event ) -- called on read events --vdebug( "new client read event, id/ip/port:", tostring(interface.id), tostring(ip), tostring(port) ) if interface.noreading or interface.fatalerror then -- leave this event @@ -594,57 +583,56 @@ do interface.eventread = nil return -1 end - if EV_TIMEOUT == event then -- took too long to get some data from client -> disconnect - interface.fatalerror = "timeout during receiving" - debug( "connection failed:", interface.fatalerror ) + if EV_TIMEOUT == event and interface:onreadtimeout() ~= true then + return -1 -- took too long to get some data from client -> disconnect + end + if interface._usingssl then -- handle luasec + if interface.eventwritetimeout then -- ok, in the past writecallback was regged + local ret = interface.writecallback( ) -- call it + --vdebug( "tried to write in readcallback, result:", tostring(ret) ) + end + if interface.eventreadtimeout then + interface.eventreadtimeout:close( ) + interface.eventreadtimeout = nil + end + end + local buffer, err, part = interface.conn:receive( interface._pattern ) -- receive buffer with "pattern" + --vdebug( "read data:", tostring(buffer), "error:", tostring(err), "part:", tostring(part) ) + buffer = buffer or part + if buffer and #buffer > cfg.MAX_READ_LENGTH then -- check buffer length + interface.fatalerror = "receive buffer exceeded" + debug( "fatal error:", interface.fatalerror ) interface:_close() interface.eventread = nil return -1 - else -- can read - if interface._usingssl then -- handle luasec - if interface.eventwritetimeout then -- ok, in the past writecallback was regged - local ret = interface.writecallback( ) -- call it - --vdebug( "tried to write in readcallback, result:", tostring(ret) ) - end - if interface.eventreadtimeout then - interface.eventreadtimeout:close( ) - interface.eventreadtimeout = nil + end + if err and ( err ~= "timeout" and err ~= "wantread" ) then + if "wantwrite" == err then -- need to read on write event + if not interface.eventwrite then -- register new write event if needed + interface.eventwrite = addevent( base, interface.conn, EV_WRITE, interface.writecallback, cfg.WRITE_TIMEOUT ) end - end - local buffer, err, part = interface.conn:receive( interface._pattern ) -- receive buffer with "pattern" - --vdebug( "read data:", tostring(buffer), "error:", tostring(err), "part:", tostring(part) ) - buffer = buffer or part or "" - local len = string_len( buffer ) - if len > cfg.MAX_READ_LENGTH then -- check buffer length - interface.fatalerror = "receive buffer exceeded" - debug( "fatal error:", interface.fatalerror ) + interface.eventreadtimeout = addevent( base, nil, EV_TIMEOUT, + function( ) + interface:_close() + end, cfg.READ_TIMEOUT + ) + debug( "wantwrite during read attempt, reg it in writecallback but dont know what really happens next..." ) + -- to be honest i dont know what happens next, if it is allowed to first read, the write etc... + else -- connection was closed or fatal error + interface.fatalerror = err + debug( "connection failed in read event:", interface.fatalerror ) interface:_close() interface.eventread = nil return -1 end + else interface.onincoming( interface, buffer, err ) -- send new data to listener - if err and ( err ~= "timeout" and err ~= "wantread" ) then - if "wantwrite" == err then -- need to read on write event - if not interface.eventwrite then -- register new write event if needed - interface.eventwrite = addevent( base, interface.conn, EV_WRITE, interface.writecallback, cfg.WRITE_TIMEOUT ) - end - interface.eventreadtimeout = addevent( base, nil, EV_TIMEOUT, - function( ) - interface:_close() - end, cfg.READ_TIMEOUT - ) - debug( "wantwrite during read attempt, reg it in writecallback but dont know what really happens next..." ) - -- to be honest i dont know what happens next, if it is allowed to first read, the write etc... - else -- connection was closed or fatal error - interface.fatalerror = err - debug( "connection failed in read event:", interface.fatalerror ) - interface:_close() - interface.eventread = nil - return -1 - end - end - return EV_READ, cfg.READ_TIMEOUT end + if interface.noreading then + interface.eventread = nil; + return -1; + end + return EV_READ, cfg.READ_TIMEOUT end client:settimeout( 0 ) -- set non blocking @@ -660,7 +648,7 @@ do debug "creating server interface..." local interface = { _connections = 0; - + conn = server; onconnect = listener.onconnect; -- will be called when new client connected eventread = false; -- read event handler @@ -668,7 +656,7 @@ do readcallback = false; -- read event callback fatalerror = false; -- error message nointerface = true; -- lock/unlock parameter - + _ip = addr, _port = port, _pattern = pattern, _sslctx = sslctx; } @@ -699,7 +687,7 @@ do end local client_ip, client_port = client:getpeername( ) interface._connections = interface._connections + 1 -- increase connection count - local clientinterface = handleclient( client, client_ip, client_port, interface, pattern, listener, nil, sslctx ) + local clientinterface = handleclient( client, client_ip, client_port, interface, pattern, listener, sslctx ) --vdebug( "client id:", clientinterface, "startssl:", startssl ) if ssl and sslctx then clientinterface:starttls(sslctx, true) @@ -707,12 +695,12 @@ do clientinterface:_start_session( true ) end debug( "accepted incoming client connection from:", client_ip or "<unknown IP>", client_port or "<unknown port>", "to", port or "<unknown port>"); - + client, err = server:accept() -- try to accept again end return EV_READ end - + server:settimeout( 0 ) setmetatable(interface, interface_mt) interfacelist( "add", interface ) @@ -726,7 +714,7 @@ local addserver = ( function( ) --vdebug( "creating new tcp server with following parameters:", addr or "nil", port or "nil", sslcfg or "nil", startssl or "nil") local server, err = socket.bind( addr, port, cfg.ACCEPT_QUEUE ) -- create server socket if not server then - debug( "creating server socket failed because:", err ) + debug( "creating server socket on "..addr.." port "..port.." failed:", err ) return nil, err end local sslctx @@ -749,13 +737,13 @@ end )( ) local addclient, wrapclient do - function wrapclient( client, ip, port, listeners, pattern, sslctx, startssl ) + function wrapclient( client, ip, port, listeners, pattern, sslctx ) local interface = handleclient( client, ip, port, nil, pattern, listeners, sslctx ) - interface:_start_session() + interface:_start_connection(sslctx) return interface, client --function handleclient( client, ip, port, server, pattern, listener, _, sslctx ) -- creates an client interface end - + function addclient( addr, serverport, listener, pattern, localaddr, localport, sslcfg, startssl ) local client, err = socket.tcp() -- creating new socket if not client then @@ -785,9 +773,6 @@ do local res, err = client:connect( addr, serverport ) -- connect if res or ( err == "timeout" ) then local ip, port = client:getsockname( ) - local server = function( ) - return nil, "this is a dummy server interface" - end local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx, startssl ) interface:_start_connection( startssl ) debug( "new connection id:", interface.id ) @@ -828,14 +813,14 @@ local function setquitting(yes) end end -function get_backend() +local function get_backend() return base:method(); end -- We need to hold onto the events to stop them -- being garbage-collected local signal_events = {}; -- [signal_num] -> event object -function hook_signal(signal_num, handler) +local function hook_signal(signal_num, handler) local function _handler(event) local ret = handler(); if ret ~= false then -- Continue handling this signal? @@ -849,14 +834,14 @@ end local function link(sender, receiver, buffersize) local sender_locked; - + function receiver:ondrain() if sender_locked then sender:resume(); sender_locked = nil; end end - + function sender:onincoming(data) receiver:write(data); if receiver.writebufferlen >= buffersize then diff --git a/net/server_select.lua b/net/server_select.lua index cfd7f3cd..ca55d2d5 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -1,7 +1,7 @@ --- +-- -- server.lua by blastbeat of the luadch project -- Re-used here under the MIT/X Consortium License --- +-- -- Modifications (C) 2008-2010 Matthew Wild, Waqas Hussain -- @@ -10,16 +10,10 @@ local use = function( what ) return _G[ what ] end -local clean = function( tbl ) - for i, k in pairs( tbl ) do - tbl[ i ] = nil - end -end local log, table_concat = require ("util.logger").init("socket"), table.concat; local out_put = function (...) return log("debug", table_concat{...}); end local out_error = function (...) return log("warn", table_concat{...}); end -local mem_free = collectgarbage ----------------------------------// DECLARATION //-- @@ -34,7 +28,6 @@ local pairs = use "pairs" local ipairs = use "ipairs" local tonumber = use "tonumber" local tostring = use "tostring" -local collectgarbage = use "collectgarbage" --// lua libs //-- @@ -49,8 +42,6 @@ local os_difftime = os.difftime local math_min = math.min local math_huge = math.huge local table_concat = table.concat -local table_remove = table.remove -local string_len = string.len local string_sub = string.sub local coroutine_wrap = coroutine.wrap local coroutine_yield = coroutine.yield @@ -67,7 +58,6 @@ local ssl_wrap = ( luasec and luasec.wrap ) local socket_bind = luasocket.bind local socket_sleep = luasocket.sleep local socket_select = luasocket.select -local ssl_newcontext = ( luasec and luasec.newcontext ) --// functions //-- @@ -75,17 +65,16 @@ local id local loop local stats local idfalse -local addtimer local closeall local addsocket local addserver +local addtimer local getserver local wrapserver local getsettings local closesocket local removesocket local removeserver -local changetimeout local wrapconnection local changesettings @@ -112,6 +101,7 @@ local _readtraffic local _selecttimeout local _sleeptime +local _tcpbacklog local _starttime local _currenttime @@ -123,11 +113,10 @@ local _checkinterval local _sendtimeout local _readtimeout -local _cleanqueue - local _timer -local _maxclientsperserver +local _maxselectlen +local _maxfd local _maxsslhandshake @@ -151,29 +140,34 @@ _readtraffic = 0 _selecttimeout = 1 -- timeout of socket.select _sleeptime = 0 -- time to wait at the end of every loop +_tcpbacklog = 128 -- some kind of hint to the OS _maxsendlen = 51000 * 1024 -- max len of send buffer _maxreadlen = 25000 * 1024 -- max len of read buffer -_checkinterval = 1200000 -- interval in secs to check idle clients +_checkinterval = 30 -- interval in secs to check idle clients _sendtimeout = 60000 -- allowed send idle time in secs _readtimeout = 6 * 60 * 60 -- allowed read idle time in secs -_cleanqueue = false -- clean bufferqueue after using - -_maxclientsperserver = 1000 +local is_windows = package.config:sub(1,1) == "\\" -- check the directory separator, to detemine whether this is Windows +_maxfd = (is_windows and math.huge) or luasocket._SETSIZE or 1024 -- max fd number, limit to 1024 by default to prevent glibc buffer overflow, but not on Windows +_maxselectlen = luasocket._SETSIZE or 1024 -- But this still applies on Windows _maxsslhandshake = 30 -- max handshake round-trips ----------------------------------// PRIVATE //-- -wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxconnections ) -- this function wraps a server +wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- this function wraps a server -- FIXME Make sure FD < _maxfd - maxconnections = maxconnections or _maxclientsperserver + if socket:getfd() >= _maxfd then + out_error("server.lua: Disallowed FD number: "..socket:getfd()) + socket:close() + return nil, "fd-too-large" + end local connections = 0 - local dispatch, disconnect = listeners.onconnect or listeners.onincoming, listeners.ondisconnect + local dispatch, disconnect = listeners.onconnect, listeners.ondisconnect local accept = socket.accept @@ -191,23 +185,43 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco end handler.remove = function( ) connections = connections - 1 - end - handler.close = function( ) - for _, handler in pairs( _socketlist ) do - if handler.serverport == serverport then - handler.disconnect( handler, "server closed" ) - handler:close( true ) - end + if handler then + handler.resume( ) end + end + handler.close = function() socket:close( ) _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) _readlistlen = removesocket( _readlist, socket, _readlistlen ) + _server[ip..":"..serverport] = nil; _socketlist[ socket ] = nil handler = nil socket = nil --mem_free( ) out_put "server.lua: closed server handler and removed sockets from list" end + handler.pause = function( hard ) + if not handler.paused then + _readlistlen = removesocket( _readlist, socket, _readlistlen ) + if hard then + _socketlist[ socket ] = nil + socket:close( ) + socket = nil; + end + handler.paused = true; + end + end + handler.resume = function( ) + if handler.paused then + if not socket then + socket = socket_bind( ip, serverport, _tcpbacklog ); + socket:settimeout( 0 ) + end + _readlistlen = addsocket(_readlist, socket, _readlistlen) + _socketlist[ socket ] = handler + handler.paused = false; + end + end handler.ip = function( ) return ip end @@ -218,21 +232,24 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco return socket end handler.readbuffer = function( ) - if connections > maxconnections then + if _readlistlen >= _maxselectlen or _sendlistlen >= _maxselectlen then + handler.pause( ) out_put( "server.lua: refused new client connection: server full" ) return false end local client, err = accept( socket ) -- try to accept if client then local ip, clientport = client:getpeername( ) - client:settimeout( 0 ) local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx ) -- wrap new client socket if err then -- error while wrapping ssl socket return false end connections = connections + 1 out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport)) - return dispatch( handler ) + if dispatch and not sslctx then -- SSL connections will notify onconnect when handshake completes + return dispatch( handler ); + end + return; elseif err then -- maybe timeout or something else out_put( "server.lua: error with new client connection: ", tostring(err) ) return false @@ -243,6 +260,14 @@ end wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx ) -- this function wraps a client to a handler object + if socket:getfd() >= _maxfd then + out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent + socket:close( ) -- Should we send some kind of error here? + if server then + server.pause( ) + end + return nil, nil, "fd-too-large" + end socket:settimeout( 0 ) --// local import of socket methods //-- @@ -317,22 +342,25 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport end return false, "setoption not implemented"; end - handler.close = function( self, forced ) + handler.force_close = function ( self, err ) + if bufferqueuelen ~= 0 then + out_put("server.lua: discarding unwritten data for ", tostring(ip), ":", tostring(clientport)) + bufferqueuelen = 0; + end + return self:close(err); + end + handler.close = function( self, err ) if not handler then return true; end _readlistlen = removesocket( _readlist, socket, _readlistlen ) _readtimes[ handler ] = nil if bufferqueuelen ~= 0 then - if not ( forced or fatalerror ) then - handler.sendbuffer( ) - if bufferqueuelen ~= 0 then -- try again... - if handler then - handler.write = nil -- ... but no further writing allowed - end - toclose = true - return false + handler.sendbuffer() -- Try now to send any outstanding data + if bufferqueuelen ~= 0 then -- Still not empty, so we'll try again later + if handler then + handler.write = nil -- ... but no further writing allowed end - else - send( socket, table_concat( bufferqueue, "", 1, bufferqueuelen ), 1, bufferlen ) -- forced send + toclose = true + return false end end if socket then @@ -347,7 +375,12 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport if handler then _writetimes[ handler ] = nil _closelist[ handler ] = nil + local _handler = handler; handler = nil + if disconnect then + disconnect(_handler, err or false); + disconnect = nil + end end if server then server.remove( ) @@ -365,7 +398,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport return clientport end local write = function( self, data ) - bufferlen = bufferlen + string_len( data ) + bufferlen = bufferlen + #data if bufferlen > maxsendlen then _closelist[ handler ] = "send buffer exceeded" -- cannot close the client at the moment, have to wait to the end of the cycle handler.write = idfalse -- dont write anymore @@ -447,10 +480,9 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport local buffer, err, part = receive( socket, pattern ) -- receive buffer with "pattern" if not err or (err == "wantread" or err == "timeout") then -- received something local buffer = buffer or part or "" - local len = string_len( buffer ) + local len = #buffer if len > maxreadlen then - disconnect( handler, "receive buffer exceeded" ) - handler:close( true ) + handler:close( "receive buffer exceeded" ) return false end local count = len * STAT_UNIT @@ -462,24 +494,24 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport else -- connections was closed or fatal error out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) ) fatalerror = true - disconnect( handler, err ) - _ = handler and handler:close( ) + _ = handler and handler:force_close( err ) return false end end local _sendbuffer = function( ) -- this function sends data local succ, err, byte, buffer, count; - local count; if socket then buffer = table_concat( bufferqueue, "", 1, bufferqueuelen ) succ, err, byte = send( socket, buffer, 1, bufferlen ) count = ( succ or byte or 0 ) * STAT_UNIT sendtraffic = sendtraffic + count _sendtraffic = _sendtraffic + count - _ = _cleanqueue and clean( bufferqueue ) + for i = bufferqueuelen,1,-1 do + bufferqueue[ i ] = nil + end --out_put( "server.lua: sended '", buffer, "', bytes: ", tostring(succ), ", error: ", tostring(err), ", part: ", tostring(byte), ", to: ", tostring(ip), ":", tostring(clientport) ) else - succ, err, count = false, "closed", 0; + succ, err, count = false, "unexpected close", 0; end if succ then -- sending succesful bufferqueuelen = 0 @@ -490,7 +522,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport drain(handler) end _ = needtls and handler:starttls(nil) - _ = toclose and handler:close( ) + _ = toclose and handler:force_close( ) return true elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write buffer = string_sub( buffer, byte + 1, bufferlen ) -- new buffer @@ -502,8 +534,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport else -- connection was closed during sending or fatal error out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) ) fatalerror = true - disconnect( handler, err ) - _ = handler and handler:close( ) + _ = handler and handler:force_close( err ) return false end end @@ -511,10 +542,8 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport -- Set the sslctx local handshake; function handler.set_sslctx(self, new_sslctx) - ssl = true sslctx = new_sslctx; - local wrote - local read + local read, wrote handshake = coroutine_wrap( function( client ) -- create handshake coroutine local err for i = 1, _maxsslhandshake do @@ -527,108 +556,93 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport handler.readbuffer = _readbuffer -- when handshake is done, replace the handshake function with regular functions handler.sendbuffer = _sendbuffer _ = status and status( handler, "ssl-handshake-complete" ) + if self.autostart_ssl and listeners.onconnect then + listeners.onconnect(self); + end _readlistlen = addsocket(_readlist, client, _readlistlen) return true else - out_put( "server.lua: error during ssl handshake: ", tostring(err) ) - if err == "wantwrite" and not wrote then + if err == "wantwrite" then _sendlistlen = addsocket(_sendlist, client, _sendlistlen) wrote = true - elseif err == "wantread" and not read then + elseif err == "wantread" then _readlistlen = addsocket(_readlist, client, _readlistlen) read = true else break; end - --coroutine_yield( handler, nil, err ) -- handshake not finished - coroutine_yield( ) + err = nil; + coroutine_yield( ) -- handshake not finished end end - disconnect( handler, "ssl handshake failed" ) - _ = handler and handler:close( true ) -- forced disconnect - return false -- handshake failed + out_put( "server.lua: ssl handshake error: ", tostring(err or "handshake too long") ) + _ = handler and handler:force_close("ssl handshake failed") + return false, err -- handshake failed end ) end if luasec then - if sslctx then -- ssl? - handler:set_sslctx(sslctx); - out_put("server.lua: ", "starting ssl handshake") - local err - socket, err = ssl_wrap( socket, sslctx ) -- wrap socket - if err then - out_put( "server.lua: ssl error: ", tostring(err) ) - --mem_free( ) - return nil, nil, err -- fatal error + handler.starttls = function( self, _sslctx) + if _sslctx then + handler:set_sslctx(_sslctx); end - socket:settimeout( 0 ) - handler.readbuffer = handshake - handler.sendbuffer = handshake - handshake( socket ) -- do handshake + if bufferqueuelen > 0 then + out_put "server.lua: we need to do tls, but delaying until send buffer empty" + needtls = true + return + end + out_put( "server.lua: attempting to start tls on " .. tostring( socket ) ) + local oldsocket, err = socket + socket, err = ssl_wrap( socket, sslctx ) -- wrap socket if not socket then - return nil, nil, "ssl handshake failed"; + out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") ) + return nil, err -- fatal error end - else - local sslctx; - handler.starttls = function( self, _sslctx) - if _sslctx then - sslctx = _sslctx; - handler:set_sslctx(sslctx); - end - if bufferqueuelen > 0 then - out_put "server.lua: we need to do tls, but delaying until send buffer empty" - needtls = true - return - end - out_put( "server.lua: attempting to start tls on " .. tostring( socket ) ) - local oldsocket, err = socket - socket, err = ssl_wrap( socket, sslctx ) -- wrap socket - --out_put( "server.lua: sslwrapped socket is " .. tostring( socket ) ) - if err then - out_put( "server.lua: error while starting tls on client: ", tostring(err) ) - return nil, err -- fatal error - end - socket:settimeout( 0 ) - - -- add the new socket to our system - - send = socket.send - receive = socket.receive - shutdown = id - - _socketlist[ socket ] = handler - _readlistlen = addsocket(_readlist, socket, _readlistlen) + socket:settimeout( 0 ) - -- remove traces of the old socket + -- add the new socket to our system + send = socket.send + receive = socket.receive + shutdown = id + _socketlist[ socket ] = handler + _readlistlen = addsocket(_readlist, socket, _readlistlen) - _readlistlen = removesocket( _readlist, oldsocket, _readlistlen ) - _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen ) - _socketlist[ oldsocket ] = nil + -- remove traces of the old socket + _readlistlen = removesocket( _readlist, oldsocket, _readlistlen ) + _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen ) + _socketlist[ oldsocket ] = nil - handler.starttls = nil - needtls = nil + handler.starttls = nil + needtls = nil - -- Secure now - ssl = true + -- Secure now (if handshake fails connection will close) + ssl = true - handler.readbuffer = handshake - handler.sendbuffer = handshake - handshake( socket ) -- do handshake - end - handler.readbuffer = _readbuffer - handler.sendbuffer = _sendbuffer + handler.readbuffer = handshake + handler.sendbuffer = handshake + return handshake( socket ) -- do handshake end - else - handler.readbuffer = _readbuffer - handler.sendbuffer = _sendbuffer end + + handler.readbuffer = _readbuffer + handler.sendbuffer = _sendbuffer send = socket.send receive = socket.receive shutdown = ( ssl and id ) or socket.shutdown _socketlist[ socket ] = handler _readlistlen = addsocket(_readlist, socket, _readlistlen) + + if sslctx and luasec then + out_put "server.lua: auto-starting ssl negotiation..." + handler.autostart_ssl = true; + local ok, err = handler:starttls(sslctx); + if ok == false then + return nil, nil, err + end + end + return handler, socket end @@ -681,7 +695,7 @@ local function link(sender, receiver, buffersize) sender_locked = nil; end end - + local _readbuffer = sender.readbuffer; function sender.readbuffer() _readbuffer(); @@ -701,45 +715,45 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function end if type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then err = "invalid port" - elseif _server[ port ] then - err = "listeners on port '" .. port .. "' already exist" + elseif _server[ addr..":"..port ] then + err = "listeners on '[" .. addr .. "]:" .. port .. "' already exist" elseif sslctx and not luasec then err = "luasec not found" end if err then - out_error( "server.lua, port ", port, ": ", err ) + out_error( "server.lua, [", addr, "]:", port, ": ", err ) return nil, err end addr = addr or "*" - local server, err = socket_bind( addr, port ) + local server, err = socket_bind( addr, port, _tcpbacklog ) if err then - out_error( "server.lua, port ", port, ": ", err ) + out_error( "server.lua, [", addr, "]:", port, ": ", err ) return nil, err end - local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, _maxclientsperserver ) -- wrap new server socket + local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx ) -- wrap new server socket if not handler then server:close( ) return nil, err end server:settimeout( 0 ) _readlistlen = addsocket(_readlist, server, _readlistlen) - _server[ port ] = handler + _server[ addr..":"..port ] = handler _socketlist[ server ] = handler - out_put( "server.lua: new "..(sslctx and "ssl " or "").."server listener on '", addr, ":", port, "'" ) + out_put( "server.lua: new "..(sslctx and "ssl " or "").."server listener on '[", addr, "]:", port, "'" ) return handler end -getserver = function ( port ) - return _server[ port ]; +getserver = function ( addr, port ) + return _server[ addr..":"..port ]; end -removeserver = function( port ) - local handler = _server[ port ] +removeserver = function( addr, port ) + local handler = _server[ addr..":"..port ] if not handler then - return nil, "no server found on port '" .. tostring( port ) .. "'" + return nil, "no server found on '[" .. addr .. "]:" .. tostring( port ) .. "'" end handler:close( ) - _server[ port ] = nil + _server[ addr..":"..port ] = nil return true end @@ -760,23 +774,36 @@ closeall = function( ) end getsettings = function( ) - return _selecttimeout, _sleeptime, _maxsendlen, _maxreadlen, _checkinterval, _sendtimeout, _readtimeout, _cleanqueue, _maxclientsperserver, _maxsslhandshake + return { + select_timeout = _selecttimeout; + select_sleep_time = _sleeptime; + tcp_backlog = _tcpbacklog; + max_send_buffer_size = _maxsendlen; + max_receive_buffer_size = _maxreadlen; + select_idle_check_interval = _checkinterval; + send_timeout = _sendtimeout; + read_timeout = _readtimeout; + max_connections = _maxselectlen; + max_ssl_handshake_roundtrips = _maxsslhandshake; + highest_allowed_fd = _maxfd; + } end changesettings = function( new ) if type( new ) ~= "table" then return nil, "invalid settings table" end - _selecttimeout = tonumber( new.timeout ) or _selecttimeout - _sleeptime = tonumber( new.sleeptime ) or _sleeptime - _maxsendlen = tonumber( new.maxsendlen ) or _maxsendlen - _maxreadlen = tonumber( new.maxreadlen ) or _maxreadlen - _checkinterval = tonumber( new.checkinterval ) or _checkinterval - _sendtimeout = tonumber( new.sendtimeout ) or _sendtimeout - _readtimeout = tonumber( new.readtimeout ) or _readtimeout - _cleanqueue = new.cleanqueue - _maxclientsperserver = new._maxclientsperserver or _maxclientsperserver - _maxsslhandshake = new._maxsslhandshake or _maxsslhandshake + _selecttimeout = tonumber( new.select_timeout ) or _selecttimeout + _sleeptime = tonumber( new.select_sleep_time ) or _sleeptime + _maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen + _maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen + _checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval + _tcpbacklog = tonumber( new.tcp_backlog ) or _tcpbacklog + _sendtimeout = tonumber( new.send_timeout ) or _sendtimeout + _readtimeout = tonumber( new.read_timeout ) or _readtimeout + _maxselectlen = new.max_connections or _maxselectlen + _maxsslhandshake = new.max_ssl_handshake_roundtrips or _maxsslhandshake + _maxfd = new.highest_allowed_fd or _maxfd return true end @@ -795,7 +822,7 @@ end local quitting; -setquitting = function (quit) +local function setquitting(quit) quitting = not not quit; end @@ -825,10 +852,32 @@ loop = function(once) -- this is the main loop of the program end for handler, err in pairs( _closelist ) do handler.disconnect( )( handler, err ) - handler:close( true ) -- forced disconnect + handler:force_close() -- forced disconnect + _closelist[ handler ] = nil; end - clean( _closelist ) _currenttime = luasocket_gettime( ) + + -- Check for socket timeouts + local difftime = os_difftime( _currenttime - _starttime ) + if difftime > _checkinterval then + _starttime = _currenttime + for handler, timestamp in pairs( _writetimes ) do + if os_difftime( _currenttime - timestamp ) > _sendtimeout then + handler.disconnect( )( handler, "send timeout" ) + handler:force_close() -- forced disconnect + end + end + for handler, timestamp in pairs( _readtimes ) do + if os_difftime( _currenttime - timestamp ) > _readtimeout then + if not(handler.onreadtimeout) or handler:onreadtimeout() ~= true then + handler.disconnect( )( handler, "read timeout" ) + handler:close( ) -- forced disconnect? + end + end + end + end + + -- Fire timers if _currenttime - _timer >= math_min(next_timer_time, 1) then next_timer_time = math_huge; for i = 1, _timerlistlen do @@ -839,14 +888,15 @@ loop = function(once) -- this is the main loop of the program else next_timer_time = next_timer_time - (_currenttime - _timer); end - socket_sleep( _sleeptime ) -- wait some time - --collectgarbage( ) + + -- wait some time (0 by default) + socket_sleep( _sleeptime ) until quitting; if once and quitting == "once" then quitting = nil; return; end return "quitting" end -step = function () +local function step() return loop(true); end @@ -857,18 +907,22 @@ end --// EXPERIMENTAL //-- local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx ) - local handler = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx ) + local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx ) + if not handler then return nil, err end _socketlist[ socket ] = handler - _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) - if listeners.onconnect then - -- When socket is writeable, call onconnect - local _sendbuffer = handler.sendbuffer; - handler.sendbuffer = function () - handler.sendbuffer = _sendbuffer; - listeners.onconnect(handler); - -- If there was data with the incoming packet, handle it now. - if #handler:bufferqueue() > 0 then - return _sendbuffer(); + if not sslctx then + _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) + if listeners.onconnect then + -- When socket is writeable, call onconnect + local _sendbuffer = handler.sendbuffer; + handler.sendbuffer = function () + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ); + handler.sendbuffer = _sendbuffer; + listeners.onconnect(handler); + -- If there was data with the incoming packet, handle it now. + if #handler:bufferqueue() > 0 then + return _sendbuffer(); + end end end end @@ -883,9 +937,9 @@ local addclient = function( address, port, listeners, pattern, sslctx ) client:settimeout( 0 ) _, err = client:connect( address, port ) if err then -- try again - local handler = wrapclient( client, address, port, listeners ) + return wrapclient( client, address, port, listeners, pattern, sslctx ) else - wrapconnection( nil, listeners, client, address, port, "clientport", pattern, sslctx ) + return wrapconnection( nil, listeners, client, address, port, "clientport", pattern, sslctx ) end end @@ -900,28 +954,6 @@ use "setmetatable" ( _writetimes, { __mode = "k" } ) _timer = luasocket_gettime( ) _starttime = luasocket_gettime( ) -addtimer( function( ) - local difftime = os_difftime( _currenttime - _starttime ) - if difftime > _checkinterval then - _starttime = _currenttime - for handler, timestamp in pairs( _writetimes ) do - if os_difftime( _currenttime - timestamp ) > _sendtimeout then - --_writetimes[ handler ] = nil - handler.disconnect( )( handler, "send timeout" ) - handler:close( true ) -- forced disconnect - end - end - for handler, timestamp in pairs( _readtimes ) do - if os_difftime( _currenttime - timestamp ) > _readtimeout then - --_readtimes[ handler ] = nil - handler.disconnect( )( handler, "read timeout" ) - handler:close( ) -- forced disconnect? - end - end - end - end -) - local function setlogger(new_logger) local old_logger = log; if new_logger then @@ -933,15 +965,16 @@ end ----------------------------------// PUBLIC INTERFACE //-- return { + _addtimer = addtimer, addclient = addclient, wrapclient = wrapclient, - + loop = loop, link = link, + step = step, stats = stats, closeall = closeall, - addtimer = addtimer, addserver = addserver, getserver = getserver, setlogger = setlogger, diff --git a/net/xmppclient_listener.lua b/net/xmppclient_listener.lua deleted file mode 100644 index 4cc90cbf..00000000 --- a/net/xmppclient_listener.lua +++ /dev/null @@ -1,179 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - - - -local logger = require "logger"; -local log = logger.init("xmppclient_listener"); -local new_xmpp_stream = require "util.xmppstream".new; - -local connlisteners_register = require "net.connlisteners".register; - -local sessionmanager = require "core.sessionmanager"; -local sm_new_session, sm_destroy_session = sessionmanager.new_session, sessionmanager.destroy_session; -local sm_streamopened = sessionmanager.streamopened; -local sm_streamclosed = sessionmanager.streamclosed; -local st = require "util.stanza"; -local xpcall = xpcall; -local tostring = tostring; -local type = type; -local traceback = debug.traceback; - -local config = require "core.configmanager"; -local opt_keepalives = config.get("*", "core", "tcp_keepalives"); - -local stream_callbacks = { default_ns = "jabber:client", - streamopened = sm_streamopened, streamclosed = sm_streamclosed, handlestanza = core_process_stanza }; - -local xmlns_xmpp_streams = "urn:ietf:params:xml:ns:xmpp-streams"; - -function stream_callbacks.error(session, error, data) - if error == "no-stream" then - session.log("debug", "Invalid opening stream header"); - session:close("invalid-namespace"); - elseif error == "parse-error" then - (session.log or log)("debug", "Client XML parse error: %s", tostring(data)); - session:close("not-well-formed"); - elseif error == "stream-error" then - local condition, text = "undefined-condition"; - for child in data:children() do - if child.attr.xmlns == xmlns_xmpp_streams then - if child.name ~= "text" then - condition = child.name; - else - text = child:get_text(); - end - if condition ~= "undefined-condition" and text then - break; - end - end - end - text = condition .. (text and (" ("..text..")") or ""); - session.log("info", "Session closed by remote with error: %s", text); - session:close(nil, text); - end -end - -local function handleerr(err) log("error", "Traceback[c2s]: %s: %s", tostring(err), traceback()); end -function stream_callbacks.handlestanza(session, stanza) - stanza = session.filter("stanzas/in", stanza); - if stanza then - return xpcall(function () return core_process_stanza(session, stanza) end, handleerr); - end -end - -local sessions = {}; -local xmppclient = { default_port = 5222, default_mode = "*a" }; - --- These are session methods -- - -local stream_xmlns_attr = {xmlns='urn:ietf:params:xml:ns:xmpp-streams'}; -local default_stream_attr = { ["xmlns:stream"] = "http://etherx.jabber.org/streams", xmlns = stream_callbacks.default_ns, version = "1.0", id = "" }; -local function session_close(session, reason) - local log = session.log or log; - if session.conn then - if session.notopen then - session.send("<?xml version='1.0'?>"); - session.send(st.stanza("stream:stream", default_stream_attr):top_tag()); - end - if reason then - if type(reason) == "string" then -- assume stream error - log("info", "Disconnecting client, <stream:error> is: %s", reason); - session.send(st.stanza("stream:error"):tag(reason, {xmlns = 'urn:ietf:params:xml:ns:xmpp-streams' })); - elseif type(reason) == "table" then - if reason.condition then - local stanza = st.stanza("stream:error"):tag(reason.condition, stream_xmlns_attr):up(); - if reason.text then - stanza:tag("text", stream_xmlns_attr):text(reason.text):up(); - end - if reason.extra then - stanza:add_child(reason.extra); - end - log("info", "Disconnecting client, <stream:error> is: %s", tostring(stanza)); - session.send(stanza); - elseif reason.name then -- a stanza - log("info", "Disconnecting client, <stream:error> is: %s", tostring(reason)); - session.send(reason); - end - end - end - session.send("</stream:stream>"); - session.conn:close(); - xmppclient.ondisconnect(session.conn, (reason and (reason.text or reason.condition)) or reason or "session closed"); - end -end - - --- End of session methods -- - -function xmppclient.onconnect(conn) - local session = sm_new_session(conn); - sessions[conn] = session; - - session.log("info", "Client connected"); - - -- Client is using legacy SSL (otherwise mod_tls sets this flag) - if conn:ssl() then - session.secure = true; - end - - if opt_keepalives ~= nil then - conn:setoption("keepalive", opt_keepalives); - end - - session.close = session_close; - - local stream = new_xmpp_stream(session, stream_callbacks); - session.stream = stream; - - session.notopen = true; - - function session.reset_stream() - session.notopen = true; - session.stream:reset(); - end - - local filter = session.filter; - function session.data(data) - data = filter("bytes/in", data); - if data then - local ok, err = stream:feed(data); - if ok then return; end - log("debug", "Received invalid XML (%s) %d bytes: %s", tostring(err), #data, data:sub(1, 300):gsub("[\r\n]+", " "):gsub("[%z\1-\31]", "_")); - session:close("not-well-formed"); - end - end - - local handlestanza = stream_callbacks.handlestanza; - function session.dispatch_stanza(session, stanza) - return handlestanza(session, stanza); - end -end - -function xmppclient.onincoming(conn, data) - local session = sessions[conn]; - if session then - session.data(data); - end -end - -function xmppclient.ondisconnect(conn, err) - local session = sessions[conn]; - if session then - (session.log or log)("info", "Client disconnected: %s", err); - sm_destroy_session(session, err); - sessions[conn] = nil; - session = nil; - end -end - -function xmppclient.associate_session(conn, session) - sessions[conn] = session; -end - -connlisteners_register("xmppclient", xmppclient); diff --git a/net/xmppcomponent_listener.lua b/net/xmppcomponent_listener.lua deleted file mode 100644 index 90293559..00000000 --- a/net/xmppcomponent_listener.lua +++ /dev/null @@ -1,220 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - - -local hosts = _G.hosts; - -local t_concat = table.concat; -local tostring = tostring; -local type = type; -local pairs = pairs; - -local lxp = require "lxp"; -local logger = require "util.logger"; -local config = require "core.configmanager"; -local connlisteners = require "net.connlisteners"; -local uuid_gen = require "util.uuid".generate; -local jid_split = require "util.jid".split; -local sha1 = require "util.hashes".sha1; -local st = require "util.stanza"; -local new_xmpp_stream = require "util.xmppstream".new; - -local sessions = {}; - -local log = logger.init("componentlistener"); - -local component_listener = { default_port = 5347; default_mode = "*a"; default_interface = config.get("*", "core", "component_interface") or "127.0.0.1" }; - -local xmlns_component = 'jabber:component:accept'; - ---- Callbacks/data for xmppstream to handle streams for us --- - -local stream_callbacks = { default_ns = xmlns_component }; - -local xmlns_xmpp_streams = "urn:ietf:params:xml:ns:xmpp-streams"; - -function stream_callbacks.error(session, error, data, data2) - if session.destroyed then return; end - log("warn", "Error processing component stream: "..tostring(error)); - if error == "no-stream" then - session:close("invalid-namespace"); - elseif error == "parse-error" then - session.log("warn", "External component %s XML parse error: %s", tostring(session.host), tostring(data)); - session:close("not-well-formed"); - elseif error == "stream-error" then - local condition, text = "undefined-condition"; - for child in data:children() do - if child.attr.xmlns == xmlns_xmpp_streams then - if child.name ~= "text" then - condition = child.name; - else - text = child:get_text(); - end - if condition ~= "undefined-condition" and text then - break; - end - end - end - text = condition .. (text and (" ("..text..")") or ""); - session.log("info", "Session closed by remote with error: %s", text); - session:close(nil, text); - end -end - -function stream_callbacks.streamopened(session, attr) - if config.get(attr.to, "core", "component_module") ~= "component" then - -- Trying to act as a component domain which - -- hasn't been configured - session:close{ condition = "host-unknown", text = tostring(attr.to).." does not match any configured external components" }; - return; - end - - -- Note that we don't create the internal component - -- until after the external component auths successfully - - session.host = attr.to; - session.streamid = uuid_gen(); - session.notopen = nil; - - session.send(st.stanza("stream:stream", { xmlns=xmlns_component, - ["xmlns:stream"]='http://etherx.jabber.org/streams', id=session.streamid, from=session.host }):top_tag()); - -end - -function stream_callbacks.streamclosed(session) - session.log("debug", "Received </stream:stream>"); - session:close(); -end - -local core_process_stanza = core_process_stanza; - -function stream_callbacks.handlestanza(session, stanza) - -- Namespaces are icky. - if not stanza.attr.xmlns and stanza.name == "handshake" then - stanza.attr.xmlns = xmlns_component; - end - if not stanza.attr.xmlns or stanza.attr.xmlns == "jabber:client" then - local from = stanza.attr.from; - if from then - if session.component_validate_from then - local _, domain = jid_split(stanza.attr.from); - if domain ~= session.host then - -- Return error - session.log("warn", "Component sent stanza with missing or invalid 'from' address"); - session:close{ - condition = "invalid-from"; - text = "Component tried to send from address <"..tostring(from) - .."> which is not in domain <"..tostring(session.host)..">"; - }; - return; - end - end - else - stanza.attr.from = session.host; - end - if not stanza.attr.to then - session.log("warn", "Rejecting stanza with no 'to' address"); - session.send(st.error_reply(stanza, "modify", "bad-request", "Components MUST specify a 'to' address on stanzas")); - return; - end - end - return core_process_stanza(session, stanza); -end - ---- Closing a component connection -local stream_xmlns_attr = {xmlns='urn:ietf:params:xml:ns:xmpp-streams'}; -local default_stream_attr = { ["xmlns:stream"] = "http://etherx.jabber.org/streams", xmlns = stream_callbacks.default_ns, version = "1.0", id = "" }; -local function session_close(session, reason) - if session.destroyed then return; end - local log = session.log or log; - if session.conn then - if session.notopen then - session.send("<?xml version='1.0'?>"); - session.send(st.stanza("stream:stream", default_stream_attr):top_tag()); - end - if reason then - if type(reason) == "string" then -- assume stream error - log("info", "Disconnecting component, <stream:error> is: %s", reason); - session.send(st.stanza("stream:error"):tag(reason, {xmlns = 'urn:ietf:params:xml:ns:xmpp-streams' })); - elseif type(reason) == "table" then - if reason.condition then - local stanza = st.stanza("stream:error"):tag(reason.condition, stream_xmlns_attr):up(); - if reason.text then - stanza:tag("text", stream_xmlns_attr):text(reason.text):up(); - end - if reason.extra then - stanza:add_child(reason.extra); - end - log("info", "Disconnecting component, <stream:error> is: %s", tostring(stanza)); - session.send(stanza); - elseif reason.name then -- a stanza - log("info", "Disconnecting component, <stream:error> is: %s", tostring(reason)); - session.send(reason); - end - end - end - session.send("</stream:stream>"); - session.conn:close(); - component_listener.ondisconnect(session.conn, "stream error"); - end -end - ---- Component connlistener -function component_listener.onconnect(conn) - local _send = conn.write; - local session = { type = "component", conn = conn, send = function (data) return _send(conn, tostring(data)); end }; - - -- Logging functions -- - local conn_name = "jcp"..tostring(conn):match("[a-f0-9]+$"); - session.log = logger.init(conn_name); - session.close = session_close; - - session.log("info", "Incoming Jabber component connection"); - - local stream = new_xmpp_stream(session, stream_callbacks); - session.stream = stream; - - session.notopen = true; - - function session.reset_stream() - session.notopen = true; - session.stream:reset(); - end - - function session.data(conn, data) - local ok, err = stream:feed(data); - if ok then return; end - log("debug", "Received invalid XML (%s) %d bytes: %s", tostring(err), #data, data:sub(1, 300):gsub("[\r\n]+", " "):gsub("[%z\1-\31]", "_")); - session:close("not-well-formed"); - end - - session.dispatch_stanza = stream_callbacks.handlestanza; - - sessions[conn] = session; -end -function component_listener.onincoming(conn, data) - local session = sessions[conn]; - session.data(conn, data); -end -function component_listener.ondisconnect(conn, err) - local session = sessions[conn]; - if session then - (session.log or log)("info", "component disconnected: %s (%s)", tostring(session.host), tostring(err)); - if session.on_destroy then session:on_destroy(err); end - sessions[conn] = nil; - for k in pairs(session) do - if k ~= "log" and k ~= "close" then - session[k] = nil; - end - end - session.destroyed = true; - session = nil; - end -end - -connlisteners.register('xmppcomponent', component_listener); diff --git a/net/xmppserver_listener.lua b/net/xmppserver_listener.lua deleted file mode 100644 index 3af0b962..00000000 --- a/net/xmppserver_listener.lua +++ /dev/null @@ -1,209 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - - -local tostring = tostring; -local type = type; -local xpcall = xpcall; -local s_format = string.format; -local traceback = debug.traceback; - -local logger = require "logger"; -local log = logger.init("xmppserver_listener"); -local st = require "util.stanza"; -local connlisteners_register = require "net.connlisteners".register; -local new_xmpp_stream = require "util.xmppstream".new; -local s2s_new_incoming = require "core.s2smanager".new_incoming; -local s2s_streamopened = require "core.s2smanager".streamopened; -local s2s_streamclosed = require "core.s2smanager".streamclosed; -local s2s_destroy_session = require "core.s2smanager".destroy_session; -local s2s_attempt_connect = require "core.s2smanager".attempt_connection; -local stream_callbacks = { default_ns = "jabber:server", - streamopened = s2s_streamopened, streamclosed = s2s_streamclosed, handlestanza = core_process_stanza }; - -local xmlns_xmpp_streams = "urn:ietf:params:xml:ns:xmpp-streams"; - -function stream_callbacks.error(session, error, data) - if error == "no-stream" then - session:close("invalid-namespace"); - elseif error == "parse-error" then - session.log("debug", "Server-to-server XML parse error: %s", tostring(error)); - session:close("not-well-formed"); - elseif error == "stream-error" then - local condition, text = "undefined-condition"; - for child in data:children() do - if child.attr.xmlns == xmlns_xmpp_streams then - if child.name ~= "text" then - condition = child.name; - else - text = child:get_text(); - end - if condition ~= "undefined-condition" and text then - break; - end - end - end - text = condition .. (text and (" ("..text..")") or ""); - session.log("info", "Session closed by remote with error: %s", text); - session:close(nil, text); - end -end - -local function handleerr(err) log("error", "Traceback[s2s]: %s: %s", tostring(err), traceback()); end -function stream_callbacks.handlestanza(session, stanza) - if stanza.attr.xmlns == "jabber:client" then --COMPAT: Prosody pre-0.6.2 may send jabber:client - stanza.attr.xmlns = nil; - end - stanza = session.filter("stanzas/in", stanza); - if stanza then - return xpcall(function () return core_process_stanza(session, stanza) end, handleerr); - end -end - -local sessions = {}; -local xmppserver = { default_port = 5269, default_mode = "*a" }; - --- These are session methods -- - -local stream_xmlns_attr = {xmlns='urn:ietf:params:xml:ns:xmpp-streams'}; -local default_stream_attr = { ["xmlns:stream"] = "http://etherx.jabber.org/streams", xmlns = stream_callbacks.default_ns, version = "1.0", id = "" }; -local function session_close(session, reason, remote_reason) - local log = session.log or log; - if session.conn then - if session.notopen then - session.sends2s("<?xml version='1.0'?>"); - session.sends2s(st.stanza("stream:stream", default_stream_attr):top_tag()); - end - if reason then - if type(reason) == "string" then -- assume stream error - log("info", "Disconnecting %s[%s], <stream:error> is: %s", session.host or "(unknown host)", session.type, reason); - session.sends2s(st.stanza("stream:error"):tag(reason, {xmlns = 'urn:ietf:params:xml:ns:xmpp-streams' })); - elseif type(reason) == "table" then - if reason.condition then - local stanza = st.stanza("stream:error"):tag(reason.condition, stream_xmlns_attr):up(); - if reason.text then - stanza:tag("text", stream_xmlns_attr):text(reason.text):up(); - end - if reason.extra then - stanza:add_child(reason.extra); - end - log("info", "Disconnecting %s[%s], <stream:error> is: %s", session.host or "(unknown host)", session.type, tostring(stanza)); - session.sends2s(stanza); - elseif reason.name then -- a stanza - log("info", "Disconnecting %s->%s[%s], <stream:error> is: %s", session.from_host or "(unknown host)", session.to_host or "(unknown host)", session.type, tostring(reason)); - session.sends2s(reason); - end - end - end - session.sends2s("</stream:stream>"); - if session.notopen or not session.conn:close() then - session.conn:close(true); -- Force FIXME: timer? - end - session.conn:close(); - xmppserver.ondisconnect(session.conn, remote_reason or (reason and (reason.text or reason.condition)) or reason or "stream closed"); - end -end - - --- End of session methods -- - -local function initialize_session(session) - local stream = new_xmpp_stream(session, stream_callbacks); - session.stream = stream; - - session.notopen = true; - - function session.reset_stream() - session.notopen = true; - session.stream:reset(); - end - - local filter = session.filter; - function session.data(data) - data = filter("bytes/in", data); - if data then - local ok, err = stream:feed(data); - if ok then return; end - (session.log or log)("warn", "Received invalid XML: %s", data); - (session.log or log)("warn", "Problem was: %s", err); - session:close("not-well-formed"); - end - end - - session.close = session_close; - local handlestanza = stream_callbacks.handlestanza; - function session.dispatch_stanza(session, stanza) - return handlestanza(session, stanza); - end -end - -function xmppserver.onconnect(conn) - if not sessions[conn] then -- May be an existing outgoing session - local session = s2s_new_incoming(conn); - sessions[conn] = session; - - -- Logging functions -- - local conn_name = "s2sin"..tostring(conn):match("[a-f0-9]+$"); - session.log = logger.init(conn_name); - - session.log("info", "Incoming s2s connection"); - - initialize_session(session); - end -end - -function xmppserver.onincoming(conn, data) - local session = sessions[conn]; - if session then - session.data(data); - end -end - -function xmppserver.onstatus(conn, status) - if status == "ssl-handshake-complete" then - local session = sessions[conn]; - if session and session.direction == "outgoing" then - local to_host, from_host = session.to_host, session.from_host; - session.log("debug", "Sending stream header..."); - session.sends2s(s_format([[<stream:stream xmlns='jabber:server' xmlns:db='jabber:server:dialback' xmlns:stream='http://etherx.jabber.org/streams' from='%s' to='%s' version='1.0'>]], from_host, to_host)); - end - end -end - -function xmppserver.ondisconnect(conn, err) - local session = sessions[conn]; - if session then - if err and err ~= "closed" and session.srv_hosts then - (session.log or log)("debug", "s2s connection attempt failed: %s", err); - if s2s_attempt_connect(session, err) then - (session.log or log)("debug", "...so we're going to try another target"); - return; -- Session lives for now - end - end - (session.log or log)("info", "s2s disconnected: %s->%s (%s)", tostring(session.from_host), tostring(session.to_host), tostring(err or "closed")); - s2s_destroy_session(session, err); - sessions[conn] = nil; - session = nil; - end -end - -function xmppserver.register_outgoing(conn, session) - session.direction = "outgoing"; - sessions[conn] = session; - - initialize_session(session); -end - -connlisteners_register("xmppserver", xmppserver); - - --- We need to perform some initialisation when a connection is created --- We also need to perform that same initialisation at other points (SASL, TLS, ...) - --- ...and we need to handle data --- ...and record all sessions associated with connections |