aboutsummaryrefslogtreecommitdiffstats
path: root/net/http
diff options
context:
space:
mode:
Diffstat (limited to 'net/http')
-rw-r--r--net/http/codes.lua2
-rw-r--r--net/http/errors.lua119
-rw-r--r--net/http/files.lua149
-rw-r--r--net/http/parser.lua147
-rw-r--r--net/http/server.lua179
5 files changed, 470 insertions, 126 deletions
diff --git a/net/http/codes.lua b/net/http/codes.lua
index 8098b5c3..4327f151 100644
--- a/net/http/codes.lua
+++ b/net/http/codes.lua
@@ -82,5 +82,5 @@ local response_codes = {
-- [512-599] = "Unassigned";
};
-for k,v in pairs(response_codes) do response_codes[k] = k.." "..v; end
+for k,v in pairs(response_codes) do response_codes[k] = ("%03d %s"):format(k, v); end
return setmetatable(response_codes, { __index = function(_, k) return k.." Unassigned"; end })
diff --git a/net/http/errors.lua b/net/http/errors.lua
new file mode 100644
index 00000000..1691e426
--- /dev/null
+++ b/net/http/errors.lua
@@ -0,0 +1,119 @@
+-- This module returns a table that is suitable for use as a util.error registry,
+-- and a function to return a util.error object given callback 'code' and 'body'
+-- parameters.
+
+local codes = require "net.http.codes";
+local util_error = require "util.error";
+
+local error_templates = {
+ -- This code is used by us to report a client-side or connection error.
+ -- Instead of using the code, use the supplied body text to get one of
+ -- the more detailed errors below.
+ [0] = {
+ code = 0, type = "cancel", condition = "internal-server-error";
+ text = "Connection or internal error";
+ };
+
+ -- These are net.http built-in errors, they are returned in
+ -- the body parameter when code == 0
+ ["cancelled"] = {
+ code = 0, type = "cancel", condition = "remote-server-timeout";
+ text = "Request cancelled";
+ };
+ ["connection-closed"] = {
+ code = 0, type = "wait", condition = "remote-server-timeout";
+ text = "Connection closed";
+ };
+ ["certificate-chain-invalid"] = {
+ code = 0, type = "cancel", condition = "remote-server-timeout";
+ text = "Server certificate not trusted";
+ };
+ ["certificate-verify-failed"] = {
+ code = 0, type = "cancel", condition = "remote-server-timeout";
+ text = "Server certificate invalid";
+ };
+ ["connection failed"] = {
+ code = 0, type = "cancel", condition = "remote-server-not-found";
+ text = "Connection failed";
+ };
+ ["invalid-url"] = {
+ code = 0, type = "modify", condition = "bad-request";
+ text = "Invalid URL";
+ };
+ ["unable to resolve service"] = {
+ code = 0, type = "cancel", condition = "remote-server-not-found";
+ text = "DNS resolution failed";
+ };
+
+ -- This doesn't attempt to map every single HTTP code (not all have sane mappings),
+ -- but all the common ones should be covered. XEP-0086 was used as reference for
+ -- most of these.
+ [400] = { type = "modify", condition = "bad-request" };
+ [401] = { type = "auth", condition = "not-authorized" };
+ [402] = { type = "auth", condition = "payment-required" };
+ [403] = { type = "auth", condition = "forbidden" };
+ [404] = { type = "cancel", condition = "item-not-found" };
+ [405] = { type = "cancel", condition = "not-allowed" };
+ [406] = { type = "modify", condition = "not-acceptable" };
+ [407] = { type = "auth", condition = "registration-required" };
+ [408] = { type = "wait", condition = "remote-server-timeout" };
+ [409] = { type = "cancel", condition = "conflict" };
+ [410] = { type = "cancel", condition = "gone" };
+ [411] = { type = "modify", condition = "bad-request" };
+ [412] = { type = "cancel", condition = "conflict" };
+ [413] = { type = "modify", condition = "resource-constraint" };
+ [414] = { type = "modify", condition = "resource-constraint" };
+ [415] = { type = "cancel", condition = "feature-not-implemented" };
+ [416] = { type = "modify", condition = "bad-request" };
+
+ [422] = { type = "modify", condition = "bad-request" };
+ [423] = { type = "wait", condition = "resource-constraint" };
+
+ [429] = { type = "wait", condition = "resource-constraint" };
+ [431] = { type = "modify", condition = "resource-constraint" };
+ [451] = { type = "auth", condition = "forbidden" };
+
+ [500] = { type = "wait", condition = "internal-server-error" };
+ [501] = { type = "cancel", condition = "feature-not-implemented" };
+ [502] = { type = "wait", condition = "remote-server-timeout" };
+ [503] = { type = "cancel", condition = "service-unavailable" };
+ [504] = { type = "wait", condition = "remote-server-timeout" };
+ [507] = { type = "wait", condition = "resource-constraint" };
+ [511] = { type = "auth", condition = "not-authorized" };
+};
+
+for k, v in pairs(codes) do
+ if error_templates[k] then
+ error_templates[k].code = k;
+ error_templates[k].text = v;
+ else
+ error_templates[k] = { type = "cancel", condition = "undefined-condition", text = v, code = k };
+ end
+end
+
+setmetatable(error_templates, {
+ __index = function(_, k)
+ if type(k) ~= "number" then
+ return nil;
+ end
+ return {
+ type = "cancel";
+ condition = "undefined-condition";
+ text = codes[k] or (k.." Unassigned");
+ code = k;
+ };
+ end
+});
+
+local function new(code, body, context)
+ if code == 0 then
+ return util_error.new(body, context, error_templates);
+ else
+ return util_error.new(code, context, error_templates);
+ end
+end
+
+return {
+ registry = error_templates;
+ new = new;
+};
diff --git a/net/http/files.lua b/net/http/files.lua
new file mode 100644
index 00000000..583f7514
--- /dev/null
+++ b/net/http/files.lua
@@ -0,0 +1,149 @@
+-- Prosody IM
+-- Copyright (C) 2008-2010 Matthew Wild
+-- Copyright (C) 2008-2010 Waqas Hussain
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+local server = require"net.http.server";
+local lfs = require "lfs";
+local new_cache = require "util.cache".new;
+local log = require "util.logger".init("net.http.files");
+
+local os_date = os.date;
+local open = io.open;
+local stat = lfs.attributes;
+local build_path = require"socket.url".build_path;
+local path_sep = package.config:sub(1,1);
+
+
+local forbidden_chars_pattern = "[/%z]";
+if package.config:sub(1,1) == "\\" then
+ forbidden_chars_pattern = "[/%z\001-\031\127\"*:<>?|]"
+end
+
+local urldecode = require "util.http".urldecode;
+local function sanitize_path(path) --> util.paths or util.http?
+ if not path then return end
+ local out = {};
+
+ local c = 0;
+ for component in path:gmatch("([^/]+)") do
+ component = urldecode(component);
+ if component:find(forbidden_chars_pattern) then
+ return nil;
+ elseif component == ".." then
+ if c <= 0 then
+ return nil;
+ end
+ out[c] = nil;
+ c = c - 1;
+ elseif component ~= "." then
+ c = c + 1;
+ out[c] = component;
+ end
+ end
+ if path:sub(-1,-1) == "/" then
+ out[c+1] = "";
+ end
+ return "/"..table.concat(out, "/");
+end
+
+local function serve(opts)
+ if type(opts) ~= "table" then -- assume path string
+ opts = { path = opts };
+ end
+ local mime_map = opts.mime_map or { html = "text/html" };
+ local cache = new_cache(opts.cache_size or 256);
+ local cache_max_file_size = tonumber(opts.cache_max_file_size) or 1024
+ -- luacheck: ignore 431
+ local base_path = opts.path;
+ local dir_indices = opts.index_files or { "index.html", "index.htm" };
+ local directory_index = opts.directory_index;
+ local function serve_file(event, path)
+ local request, response = event.request, event.response;
+ local sanitized_path = sanitize_path(path);
+ if path and not sanitized_path then
+ return 400;
+ end
+ path = sanitized_path;
+ local orig_path = sanitize_path(request.path);
+ local full_path = base_path .. (path or ""):gsub("/", path_sep);
+ local attr = stat(full_path:match("^.*[^\\/]")); -- Strip trailing path separator because Windows
+ if not attr then
+ return 404;
+ end
+
+ local request_headers, response_headers = request.headers, response.headers;
+
+ local last_modified = os_date('!%a, %d %b %Y %H:%M:%S GMT', attr.modification);
+ response_headers.last_modified = last_modified;
+
+ local etag = ('"%x-%x-%x"'):format(attr.change or 0, attr.size or 0, attr.modification or 0);
+ response_headers.etag = etag;
+
+ local if_none_match = request_headers.if_none_match
+ local if_modified_since = request_headers.if_modified_since;
+ if etag == if_none_match
+ or (not if_none_match and last_modified == if_modified_since) then
+ return 304;
+ end
+
+ local data = cache:get(orig_path);
+ if data and data.etag == etag then
+ response_headers.content_type = data.content_type;
+ data = data.data;
+ cache:set(orig_path, data);
+ elseif attr.mode == "directory" and path then
+ if full_path:sub(-1) ~= "/" then
+ local dir_path = { is_absolute = true, is_directory = true };
+ for dir in orig_path:gmatch("[^/]+") do dir_path[#dir_path+1]=dir; end
+ response_headers.location = build_path(dir_path);
+ return 301;
+ end
+ for i=1,#dir_indices do
+ if stat(full_path..dir_indices[i], "mode") == "file" then
+ return serve_file(event, path..dir_indices[i]);
+ end
+ end
+
+ if directory_index then
+ data = server._events.fire_event("directory-index", { path = request.path, full_path = full_path });
+ end
+ if not data then
+ return 403;
+ end
+ cache:set(orig_path, { data = data, content_type = mime_map.html; etag = etag; });
+ response_headers.content_type = mime_map.html;
+
+ else
+ local f, err = open(full_path, "rb");
+ if not f then
+ log("debug", "Could not open %s. Error was %s", full_path, err);
+ return 403;
+ end
+ local ext = full_path:match("%.([^./]+)$");
+ local content_type = ext and mime_map[ext];
+ response_headers.content_type = content_type;
+ if attr.size > cache_max_file_size then
+ response_headers.content_length = ("%d"):format(attr.size);
+ log("debug", "%d > cache_max_file_size", attr.size);
+ return response:send_file(f);
+ else
+ data = f:read("*a");
+ f:close();
+ end
+ cache:set(orig_path, { data = data; content_type = content_type; etag = etag });
+ end
+
+ return response:send(data);
+ end
+
+ return serve_file;
+end
+
+return {
+ serve = serve;
+}
+
diff --git a/net/http/parser.lua b/net/http/parser.lua
index 4e4ae9fb..96f17fdb 100644
--- a/net/http/parser.lua
+++ b/net/http/parser.lua
@@ -1,8 +1,8 @@
local tonumber = tonumber;
local assert = assert;
-local t_insert, t_concat = table.insert, table.concat;
local url_parse = require "socket.url".parse;
local urldecode = require "util.http".urldecode;
+local dbuffer = require "util.dbuffer";
local function preprocess_path(path)
path = urldecode((path:gsub("//+", "/")));
@@ -28,10 +28,13 @@ local httpstream = {};
function httpstream.new(success_cb, error_cb, parser_type, options_cb)
local client = true;
if not parser_type or parser_type == "server" then client = false; else assert(parser_type == "client", "Invalid parser type"); end
- local buf, buflen, buftable = {}, 0, true;
local bodylimit = tonumber(options_cb and options_cb().body_size_limit) or 10*1024*1024;
+ -- https://stackoverflow.com/a/686243
+ -- Indiviual headers can be up to 16k? What madness?
+ local headlimit = tonumber(options_cb and options_cb().head_size_limit) or 10*1024;
local buflimit = tonumber(options_cb and options_cb().buffer_size_limit) or bodylimit * 2;
- local chunked, chunk_size, chunk_start;
+ local buffer = dbuffer.new(buflimit);
+ local chunked;
local state = nil;
local packet;
local len;
@@ -41,32 +44,27 @@ function httpstream.new(success_cb, error_cb, parser_type, options_cb)
feed = function(_, data)
if error then return nil, "parse has failed"; end
if not data then -- EOF
- if buftable then buf, buftable = t_concat(buf), false; end
if state and client and not len then -- reading client body until EOF
- packet.body = buf;
+ buffer:collapse();
+ packet.body = buffer:read_chunk() or "";
+ packet.partial = nil;
success_cb(packet);
- elseif buf ~= "" then -- unexpected EOF
+ state = nil;
+ elseif buffer:length() ~= 0 then -- unexpected EOF
error = true; return error_cb("unexpected-eof");
end
return;
end
- if buftable then
- t_insert(buf, data);
- else
- buf = { buf, data };
- buftable = true;
- end
- buflen = buflen + #data;
- if buflen > buflimit then error = true; return error_cb("max-buffer-size-exceeded"); end
- while buflen > 0 do
+ if not buffer:write(data) then error = true; return error_cb("max-buffer-size-exceeded"); end
+ while buffer:length() > 0 do
if state == nil then -- read request
- if buftable then buf, buftable = t_concat(buf), false; end
- local index = buf:find("\r\n\r\n", nil, true);
+ local index = buffer:sub(1, headlimit):find("\r\n\r\n", nil, true);
if not index then return; end -- not enough data
- local method, path, httpversion, status_code, reason_phrase;
+ -- FIXME was reason_phrase meant to be passed on somewhere?
+ local method, path, httpversion, status_code, reason_phrase; -- luacheck: ignore reason_phrase
local first_line;
local headers = {};
- for line in buf:sub(1,index+1):gmatch("([^\r\n]+)\r\n") do -- parse request
+ for line in buffer:read(index+3):gmatch("([^\r\n]+)\r\n") do -- parse request
if first_line then
local key, val = line:match("^([^%s:]+): *(.*)$");
if not key then error = true; return error_cb("invalid-header-line"); end -- TODO handle multi-line and invalid headers
@@ -91,7 +89,6 @@ function httpstream.new(success_cb, error_cb, parser_type, options_cb)
if not first_line then error = true; return error_cb("invalid-status-line"); end
chunked = have_body and headers["transfer-encoding"] == "chunked";
len = tonumber(headers["content-length"]); -- TODO check for invalid len
- if len and len > bodylimit then error = true; return error_cb("content-length-limit-exceeded"); end
if client then
-- FIXME handle '100 Continue' response (by skipping it)
if not have_body then len = 0; end
@@ -99,7 +96,10 @@ function httpstream.new(success_cb, error_cb, parser_type, options_cb)
code = status_code;
httpversion = httpversion;
headers = headers;
- body = have_body and "" or nil;
+ body = false;
+ body_length = len;
+ chunked = chunked;
+ partial = true;
-- COMPAT the properties below are deprecated
responseversion = httpversion;
responseheaders = headers;
@@ -124,60 +124,81 @@ function httpstream.new(success_cb, error_cb, parser_type, options_cb)
path = path;
httpversion = httpversion;
headers = headers;
- body = nil;
+ body = false;
+ body_sink = nil;
+ chunked = chunked;
+ partial = true;
};
end
- buf = buf:sub(index + 4);
- buflen = #buf;
+ if len and len > bodylimit then
+ -- Early notification, for redirection
+ success_cb(packet);
+ if not packet.body_sink then error = true; return error_cb("content-length-limit-exceeded"); end
+ end
+ if chunked and not packet.body_sink then
+ success_cb(packet);
+ if not packet.body_sink then
+ packet.body_buffer = dbuffer.new(buflimit);
+ end
+ end
state = true;
end
if state then -- read body
- if client then
- if chunked then
- if chunk_start and buflen - chunk_start - 2 < chunk_size then
- return;
- end -- not enough data
- if buftable then buf, buftable = t_concat(buf), false; end
- if not buf:find("\r\n", nil, true) then
- return;
- end -- not enough data
- if not chunk_size then
- chunk_size, chunk_start = buf:match("^(%x+)[^\r\n]*\r\n()");
- chunk_size = chunk_size and tonumber(chunk_size, 16);
- if not chunk_size then error = true; return error_cb("invalid-chunk-size"); end
- end
- if chunk_size == 0 and buf:find("\r\n\r\n", chunk_start-2, true) then
- state, chunk_size = nil, nil;
- buf = buf:gsub("^.-\r\n\r\n", ""); -- This ensure extensions and trailers are stripped
- success_cb(packet);
- elseif buflen - chunk_start - 2 >= chunk_size then -- we have a chunk
- packet.body = packet.body..buf:sub(chunk_start, chunk_start + (chunk_size-1));
- buf = buf:sub(chunk_start + chunk_size + 2);
- buflen = buflen - (chunk_start + chunk_size + 2 - 1);
- chunk_size, chunk_start = nil, nil;
- else -- Partial chunk remaining
- break;
+ if chunked then
+ local chunk_header = buffer:sub(1, 512); -- XXX How large do chunk headers grow?
+ local chunk_size, chunk_start = chunk_header:match("^(%x+)[^\r\n]*\r\n()");
+ if not chunk_size then return; end
+ chunk_size = chunk_size and tonumber(chunk_size, 16);
+ if not chunk_size then error = true; return error_cb("invalid-chunk-size"); end
+ if chunk_size == 0 and chunk_header:find("\r\n\r\n", chunk_start-2, true) then
+ local body_buffer = packet.body_buffer;
+ if body_buffer then
+ packet.body_buffer = nil;
+ body_buffer:collapse();
+ packet.body = body_buffer:read_chunk() or "";
end
- elseif len and buflen >= len then
- if buftable then buf, buftable = t_concat(buf), false; end
- if packet.code == 101 then
- packet.body, buf, buflen, buftable = buf, {}, 0, true;
+
+ buffer:collapse();
+ local buf = buffer:read_chunk();
+ buf = buf:gsub("^.-\r\n\r\n", ""); -- This ensure extensions and trailers are stripped
+ buffer:write(buf);
+ state, chunked = nil, nil;
+ packet.partial = nil;
+ success_cb(packet);
+ elseif buffer:length() - chunk_start - 2 >= chunk_size then -- we have a chunk
+ buffer:discard(chunk_start - 1); -- TODO verify that it's not off-by-one
+ (packet.body_sink or packet.body_buffer):write(buffer:read(chunk_size));
+ buffer:discard(2); -- CRLF
+ else -- Partial chunk remaining
+ break;
+ end
+ elseif packet.body_sink then
+ local chunk = buffer:read_chunk(len);
+ while chunk and len > 0 do
+ if packet.body_sink:write(chunk) then
+ len = len - #chunk;
+ chunk = buffer:read_chunk(len);
else
- packet.body, buf = buf:sub(1, len), buf:sub(len + 1);
- buflen = #buf;
+ error = true;
+ return error_cb("body-sink-write-failure");
end
- state = nil; success_cb(packet);
- else
- break;
end
- elseif buflen >= len then
- if buftable then buf, buftable = t_concat(buf), false; end
- packet.body, buf = buf:sub(1, len), buf:sub(len + 1);
- buflen = #buf;
- state = nil; success_cb(packet);
+ if len == 0 then
+ state = nil;
+ packet.partial = nil;
+ success_cb(packet);
+ end
+ elseif buffer:length() >= len then
+ assert(not chunked)
+ packet.body = buffer:read(len) or "";
+ state = nil;
+ packet.partial = nil;
+ success_cb(packet);
else
break;
end
+ else
+ break;
end
end
end;
diff --git a/net/http/server.lua b/net/http/server.lua
index 3873bbe0..97e15e42 100644
--- a/net/http/server.lua
+++ b/net/http/server.lua
@@ -1,5 +1,5 @@
-local t_insert, t_remove, t_concat = table.insert, table.remove, table.concat;
+local t_insert, t_concat = table.insert, table.concat;
local parser_new = require "net.http.parser".new;
local events = require "util.events".new();
local addserver = require "net.server".addserver;
@@ -8,12 +8,12 @@ local os_date = os.date;
local pairs = pairs;
local s_upper = string.upper;
local setmetatable = setmetatable;
-local xpcall = require "util.xpcall".xpcall;
-local traceback = debug.traceback;
-local tostring = tostring;
local cache = require "util.cache";
local codes = require "net.http.codes";
+local promise = require "util.promise";
+local errors = require "util.error";
local blocksize = 2^16;
+local async = require "util.async";
local _M = {};
@@ -89,51 +89,60 @@ setmetatable(events._handlers, {
local handle_request;
-local last_err;
-local function _traceback_handler(err) last_err = err; log("error", "Traceback[httpserver]: %s", traceback(tostring(err), 2)); end
events.add_handler("http-error", function (error)
return "Error processing request: "..codes[error.code]..". Check your error log for more information.";
end, -1);
+local runner_callbacks = {};
+
+function runner_callbacks:ready()
+ self.data.conn:resume();
+end
+
+function runner_callbacks:waiting()
+ self.data.conn:pause();
+end
+
+function runner_callbacks:error(err)
+ log("error", "Traceback[httpserver]: %s", err);
+ self.data.conn:write("HTTP/1.0 500 Internal Server Error\r\n\r\n"..events.fire_event("http-error", { code = 500, private_message = err }));
+ self.data.conn:close();
+end
+
+local function noop() end
function listener.onconnect(conn)
+ local session = { conn = conn };
local secure = conn:ssl() and true or nil;
- local pending = {};
- local waiting = false;
- local function process_next()
- if waiting then return; end -- log("debug", "can't process_next, waiting");
- waiting = true;
- while sessions[conn] and #pending > 0 do
- local request = t_remove(pending);
- --log("debug", "process_next: %s", request.path);
- if not xpcall(handle_request, _traceback_handler, conn, request, process_next) then
- conn:write("HTTP/1.0 500 Internal Server Error\r\n\r\n"..events.fire_event("http-error", { code = 500, private_message = last_err }));
- conn:close();
- end
+ local ip = conn:ip();
+ session.thread = async.runner(function (request)
+ local wait, done;
+ if request.partial == true then
+ -- Have the header for a request, we want to receive the rest
+ -- when we've decided where the data should go.
+ wait, done = noop, noop;
+ else -- Got the entire request
+ -- Hold off on receiving more incoming requests until this one has been handled.
+ wait, done = async.waiter();
end
- --log("debug", "ready for more");
- waiting = false;
- end
+ handle_request(conn, request, done); wait();
+ end, runner_callbacks, session);
local function success_cb(request)
--log("debug", "success_cb: %s", request.path);
- if waiting then
- log("error", "http connection handler is not reentrant: %s", request.path);
- assert(false, "http connection handler is not reentrant");
- end
+ request.ip = ip;
request.secure = secure;
- t_insert(pending, request);
- process_next();
+ session.thread:run(request);
end
local function error_cb(err)
log("debug", "error_cb: %s", err or "<nil>");
-- FIXME don't close immediately, wait until we process current stuff
-- FIXME if err, send off a bad-request response
- sessions[conn] = nil;
conn:close();
end
local function options_cb()
return options;
end
- sessions[conn] = parser_new(success_cb, error_cb, "server", options_cb);
+ session.parser = parser_new(success_cb, error_cb, "server", options_cb);
+ sessions[conn] = session;
end
function listener.ondisconnect(conn)
@@ -152,7 +161,7 @@ function listener.ondetach(conn)
end
function listener.onincoming(conn, data)
- sessions[conn]:feed(data);
+ sessions[conn].parser:feed(data);
end
function listener.ondrain(conn)
@@ -170,6 +179,49 @@ local headerfix = setmetatable({}, {
end
});
+local function handle_result(request, response, result)
+ if result == nil then
+ result = 404;
+ end
+
+ if result == true then
+ return;
+ end
+
+ local body;
+ local result_type = type(result);
+ if result_type == "number" then
+ response.status_code = result;
+ if result >= 400 then
+ body = events.fire_event("http-error", { request = request, response = response, code = result });
+ end
+ elseif result_type == "string" then
+ body = result;
+ elseif errors.is_err(result) then
+ response.status_code = result.code or 500;
+ body = events.fire_event("http-error", { request = request, response = response, code = result.code or 500, error = result });
+ elseif promise.is_promise(result) then
+ result:next(function (ret)
+ handle_result(request, response, ret);
+ end, function (err)
+ response.status_code = 500;
+ handle_result(request, response, err or 500);
+ end);
+ return true;
+ elseif result_type == "table" then
+ for k, v in pairs(result) do
+ if k ~= "headers" then
+ response[k] = v;
+ else
+ for header_name, header_value in pairs(v) do
+ response.headers[header_name] = header_value;
+ end
+ end
+ end
+ end
+ return response:send(body);
+end
+
function _M.hijack_response(response, listener) -- luacheck: ignore
error("TODO");
end
@@ -194,13 +246,17 @@ function handle_request(conn, request, finish_cb)
response_conn_header = httpversion == "1.1" and "close" or nil
end
+ local is_head_request = request.method == "HEAD";
+
local response = {
request = request;
+ is_head_request = is_head_request;
status_code = 200;
headers = { date = date_header, connection = response_conn_header };
persistent = persistent;
conn = conn;
send = _M.send_response;
+ write_headers = _M.write_headers;
send_file = _M.send_file;
done = _M.finish_response;
finish_cb = finish_cb;
@@ -227,6 +283,11 @@ function handle_request(conn, request, finish_cb)
local payload = { request = request, response = response };
log("debug", "Firing event: %s", global_event);
local result = events.fire_event(global_event, payload);
+ if result == nil and is_head_request then
+ local global_head_event = "GET "..request.path:match("[^?]*");
+ log("debug", "Firing event: %s", global_head_event);
+ result = events.fire_event(global_head_event, payload);
+ end
if result == nil then
if not hosts[host] then
if hosts[default_host] then
@@ -247,40 +308,17 @@ function handle_request(conn, request, finish_cb)
local host_event = request.method.." "..host..request.path:match("[^?]*");
log("debug", "Firing event: %s", host_event);
result = events.fire_event(host_event, payload);
- end
- if result ~= nil then
- if result ~= true then
- local body;
- local result_type = type(result);
- if result_type == "number" then
- response.status_code = result;
- if result >= 400 then
- payload.code = result;
- body = events.fire_event("http-error", payload);
- end
- elseif result_type == "string" then
- body = result;
- elseif result_type == "table" then
- for k, v in pairs(result) do
- if k ~= "headers" then
- response[k] = v;
- else
- for header_name, header_value in pairs(v) do
- response.headers[header_name] = header_value;
- end
- end
- end
- end
- response:send(body);
+
+ if result == nil and is_head_request then
+ local host_head_event = "GET "..host..request.path:match("[^?]*");
+ log("debug", "Firing event: %s", host_head_event);
+ result = events.fire_event(host_head_event, payload);
end
- return;
end
- -- if handler not called, return 404
- response.status_code = 404;
- payload.code = 404;
- response:send(events.fire_event("http-error", payload));
+ return handle_result(request, response, result);
end
+
local function prepare_header(response)
local status_line = "HTTP/"..response.request.httpversion.." "..(response.status or codes[response.status_code]);
local headers = response.headers;
@@ -292,12 +330,25 @@ local function prepare_header(response)
return output;
end
_M.prepare_header = prepare_header;
+function _M.write_headers(response)
+ if response.finished then return; end
+ local output = prepare_header(response);
+ response.conn:write(t_concat(output));
+end
+function _M.send_head_response(response)
+ if response.finished then return; end
+ _M.write_headers(response);
+ response:done();
+end
function _M.send_response(response, body)
if response.finished then return; end
body = body or response.body or "";
-- Per RFC 7230, informational (1xx) and 204 (no content) should have no c-l header
if response.status_code > 199 and response.status_code ~= 204 then
- response.headers.content_length = #body;
+ response.headers.content_length = ("%d"):format(#body);
+ end
+ if response.is_head_request then
+ return _M.send_head_response(response)
end
local output = prepare_header(response);
t_insert(output, body);
@@ -305,6 +356,10 @@ function _M.send_response(response, body)
response:done();
end
function _M.send_file(response, f)
+ if response.is_head_request then
+ if f.close then f:close(); end
+ return _M.send_head_response(response);
+ end
if response.finished then return; end
local chunked = not response.headers.content_length;
if chunked then response.headers.transfer_encoding = "chunked"; end
@@ -331,7 +386,7 @@ function _M.send_file(response, f)
return response:done();
end
end
- response.conn:write(t_concat(prepare_header(response)));
+ _M.write_headers(response);
return true;
end
function _M.finish_response(response)