From 4c9f43a09efbe6387c6efa1f7742c61cedc67945 Mon Sep 17 00:00:00 2001
From: Kim Alvefur <zash@zash.se>
Date: Mon, 17 Aug 2020 23:01:14 +0200
Subject: net.server: Backport client parts of SNI support from trunk (#409)

Partial backports of the following commits from trunk:

6c804b6b2ca2 net.http: Pass server name along for SNI (fixes #1408)
75d2874502c3 net.server_select: SNI support (#409)
9a905888b96c net.server_event: Add SNI support (#409)
adc0672b700e net.server_epoll: Add support for SNI (#409)
d4390c427a66 net.server: Handle server name (SNI) as extra argument
---
 net/http.lua          |  2 +-
 net/server_epoll.lua  | 20 +++++++++++++++-----
 net/server_event.lua  | 17 ++++++++++++-----
 net/server_select.lua | 19 ++++++++++++++-----
 4 files changed, 42 insertions(+), 16 deletions(-)

diff --git a/net/http.lua b/net/http.lua
index ae9d2974..0768cdab 100644
--- a/net/http.lua
+++ b/net/http.lua
@@ -272,7 +272,7 @@ local function request(self, u, ex, callback)
 		sslctx = ex and ex.sslctx or self.options and self.options.sslctx;
 	end
 
-	local http_service = basic_resolver.new(host, port_number);
+	local http_service = basic_resolver.new(host, port_number, "tcp", { servername = req.host });
 	connect(http_service, listener, { sslctx = sslctx }, req);
 
 	self.events.fire_event("request", { http = self, request = req, url = u });
diff --git a/net/server_epoll.lua b/net/server_epoll.lua
index 2182d56a..953bbb11 100644
--- a/net/server_epoll.lua
+++ b/net/server_epoll.lua
@@ -483,6 +483,9 @@ function interface:tlshandskake()
 		end
 		conn:settimeout(0);
 		self.conn = conn;
+		if conn.sni and self.servername then
+			conn:sni(self.servername);
+		end
 		self:on("starttls");
 		self.ondrain = nil;
 		self.onwritable = interface.tlshandskake;
@@ -512,7 +515,7 @@ function interface:tlshandskake()
 	end
 end
 
-local function wrapsocket(client, server, read_size, listeners, tls_ctx) -- luasocket object -> interface object
+local function wrapsocket(client, server, read_size, listeners, tls_ctx, extra) -- luasocket object -> interface object
 	client:settimeout(0);
 	local conn = setmetatable({
 		conn = client;
@@ -523,8 +526,15 @@ local function wrapsocket(client, server, read_size, listeners, tls_ctx) -- luas
 		writebuffer = {};
 		tls_ctx = tls_ctx or (server and server.tls_ctx);
 		tls_direct = server and server.tls_direct;
+		extra = extra;
 	}, interface_mt);
 
+	if extra then
+		if extra.servername then
+			conn.servername = extra.servername;
+		end
+	end
+
 	conn:updatenames();
 	return conn;
 end
@@ -617,8 +627,8 @@ local function addserver(addr, port, listeners, read_size, tls_ctx)
 end
 
 -- COMPAT
-local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx)
-	local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx);
+local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx, extra)
+	local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra);
 	if not client.peername then
 		client.peername, client.peerport = addr, port;
 	end
@@ -631,7 +641,7 @@ local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx)
 end
 
 -- New outgoing TCP connection
-local function addclient(addr, port, listeners, read_size, tls_ctx, typ)
+local function addclient(addr, port, listeners, read_size, tls_ctx, typ, extra)
 	local create;
 	if not typ then
 		local n = inet_pton(addr);
@@ -653,7 +663,7 @@ local function addclient(addr, port, listeners, read_size, tls_ctx, typ)
 	if not ok then return ok, err; end
 	local ok, err = conn:setpeername(addr, port);
 	if not ok and err ~= "timeout" then return ok, err; end
-	local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx)
+	local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra)
 	local ok, err = client:init();
 	if not ok then return ok, err; end
 	if tls_ctx then
diff --git a/net/server_event.lua b/net/server_event.lua
index 11bd6a29..746526ce 100644
--- a/net/server_event.lua
+++ b/net/server_event.lua
@@ -164,6 +164,11 @@ function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed
 		debug( "fatal error while ssl wrapping:", err )
 		return false
 	end
+
+	if self.conn.sni and self.servername then
+		self.conn:sni(self.servername);
+	end
+
 	self.conn:settimeout( 0 )  -- set non blocking
 	local handshakecallback = coroutine_wrap(function( event )
 		local _, err
@@ -456,7 +461,7 @@ end
 
 -- End of client interface methods
 
-local function handleclient( client, ip, port, server, pattern, listener, sslctx )  -- creates an client interface
+local function handleclient( client, ip, port, server, pattern, listener, sslctx, extra )  -- creates an client interface
 	--vdebug("creating client interfacce...")
 	local interface = {
 		type = "client";
@@ -492,6 +497,8 @@ local function handleclient( client, ip, port, server, pattern, listener, sslctx
 		_serverport = (server and server:port() or nil),
 		_sslctx = sslctx; -- parameters
 		_usingssl = false;  -- client is using ssl;
+		extra = extra;
+		servername = extra and extra.servername;
 	}
 	if not has_luasec then interface.starttls = false; end
 	interface.id = tostring(interface):match("%x+$");
@@ -716,14 +723,14 @@ local function addserver( addr, port, listener, pattern, sslctx, startssl )  --
 	return interface
 end
 
-local function wrapclient( client, ip, port, listeners, pattern, sslctx )
-	local interface = handleclient( client, ip, port, nil, pattern, listeners, sslctx )
+local function wrapclient( client, ip, port, listeners, pattern, sslctx, extra )
+	local interface = handleclient( client, ip, port, nil, pattern, listeners, sslctx, extra )
 	interface:_start_connection(sslctx)
 	return interface, client
 	--function handleclient( client, ip, port, server, pattern, listener, _, sslctx )  -- creates an client interface
 end
 
-local function addclient( addr, serverport, listener, pattern, sslctx, typ )
+local function addclient( addr, serverport, listener, pattern, sslctx, typ, extra )
 	if sslctx and not has_luasec then
 		debug "need luasec, but not available"
 		return nil, "luasec not found"
@@ -750,7 +757,7 @@ local function addclient( addr, serverport, listener, pattern, sslctx, typ )
 	local res, err = client:setpeername( addr, serverport )  -- connect
 	if res or ( err == "timeout" ) then
 		local ip, port = client:getsockname( )
-		local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx )
+		local interface = wrapclient( client, ip, serverport, listener, pattern, sslctx, extra )
 		debug( "new connection id:", interface.id )
 		return interface, err
 	else
diff --git a/net/server_select.lua b/net/server_select.lua
index 1a40a6d3..deb8fe48 100644
--- a/net/server_select.lua
+++ b/net/server_select.lua
@@ -264,7 +264,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- t
 	return handler
 end
 
-wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx ) -- this function wraps a client to a handler object
+wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx, extra ) -- this function wraps a client to a handler object
 
 	if socket:getfd() >= _maxfd then
 		out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent
@@ -314,6 +314,11 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
 
 	local handler = bufferqueue -- saves a table ^_^
 
+	handler.extra = extra
+	if extra then
+		handler.servername = extra.servername
+	end
+
 	handler.dispatch = function( )
 		return dispatch
 	end
@@ -624,6 +629,10 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
 				return nil, err -- fatal error
 			end
 
+			if socket.sni and self.servername then
+				socket:sni(self.servername);
+			end
+
 			socket:settimeout( 0 )
 
 			-- add the new socket to our system
@@ -977,8 +986,8 @@ end
 
 --// EXPERIMENTAL //--
 
-local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx )
-	local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx )
+local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx, extra )
+	local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx, extra)
 	if not handler then return nil, err end
 	_socketlist[ socket ] = handler
 	if not sslctx then
@@ -997,7 +1006,7 @@ local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx
 	return handler, socket
 end
 
-local addclient = function( address, port, listeners, pattern, sslctx, typ )
+local addclient = function( address, port, listeners, pattern, sslctx, typ, extra )
 	local err
 	if type( listeners ) ~= "table" then
 		err = "invalid listener table"
@@ -1034,7 +1043,7 @@ local addclient = function( address, port, listeners, pattern, sslctx, typ )
 	client:settimeout( 0 )
 	local ok, err = client:setpeername( address, port )
 	if ok or err == "timeout" or err == "Operation already in progress" then
-		return wrapclient( client, address, port, listeners, pattern, sslctx )
+		return wrapclient( client, address, port, listeners, pattern, sslctx, extra )
 	else
 		return nil, err
 	end
-- 
cgit v1.2.3