From 8a62a14e5da8c27202cc11a37243569ab8962f76 Mon Sep 17 00:00:00 2001
From: Kim Alvefur <zash@zash.se>
Date: Thu, 11 Oct 2018 15:48:30 +0200
Subject: net.server: Require IP address as argument to addclient (no DNS
 names)

The net.connect API should be used to resolve DNS names first
---
 net/server_epoll.lua  | 19 +++++++++++++++----
 net/server_event.lua  | 17 +++++++++--------
 net/server_select.lua | 17 ++++++++++-------
 3 files changed, 34 insertions(+), 19 deletions(-)

(limited to 'net')

diff --git a/net/server_epoll.lua b/net/server_epoll.lua
index 7a860878..b5053d22 100644
--- a/net/server_epoll.lua
+++ b/net/server_epoll.lua
@@ -21,6 +21,8 @@ local socket = require "socket";
 local luasec = require "ssl";
 local gettime = require "util.time".now;
 local createtable = require "util.table".create;
+local inet = require "util.net";
+local inet_pton = inet.pton;
 local _SOCKETINVALID = socket._SOCKETINVALID or -1;
 
 local poll = require "util.poll".new();
@@ -614,7 +616,8 @@ local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx)
 	if not client.peername then
 		client.peername, client.peerport = addr, port;
 	end
-	client:init();
+	local ok, err = client:init();
+	if not ok then return ok, err; end
 	if tls_ctx then
 		client:starttls(tls_ctx);
 	end
@@ -623,12 +626,20 @@ end
 
 -- New outgoing TCP connection
 local function addclient(addr, port, listeners, read_size, tls_ctx)
-	local conn, err = socket.tcp();
-	if not conn then return conn, err; end
+	local n = inet_pton(addr);
+	if not n then return nil, "invalid-ip"; end
+	local create
+	if #n == 16 then
+		create = socket.tcp6 or socket.tcp;
+	else
+		create = socket.tcp4 or socket.tcp;
+	end
+	local conn, err = create();
 	conn:settimeout(0);
 	conn:connect(addr, port);
 	local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx)
-	client:init();
+	local ok, err = client:init();
+	if not ok then return ok, err; end
 	if tls_ctx then
 		client:starttls(tls_ctx);
 	end
diff --git a/net/server_event.lua b/net/server_event.lua
index 9fde558e..5a4dc1a4 100644
--- a/net/server_event.lua
+++ b/net/server_event.lua
@@ -50,9 +50,10 @@ local coroutine_yield = coroutine.yield
 local has_luasec, ssl = pcall ( require , "ssl" )
 local socket = require "socket"
 local levent = require "luaevent.core"
+local inet = require "util.net";
+local inet_pton = inet.pton;
 
 local socket_gettime = socket.gettime
-local getaddrinfo = socket.dns.getaddrinfo
 
 local log = require ("util.logger").init("socket")
 
@@ -728,15 +729,15 @@ local function addclient( addr, serverport, listener, pattern, sslctx, typ )
 		return nil, "luasec not found"
 	end
 	if not typ then
-		local addrinfo, err = getaddrinfo(addr)
-		if not addrinfo then return nil, err end
-		if addrinfo[1] and addrinfo[1].family == "inet6" then
-			typ = "tcp6"
-		else
-			typ = "tcp"
+		local n = inet_pton(addr);
+		if not n then return nil, "invalid-ip"; end
+		if #n == 16 then
+			typ = "tcp6";
+		elseif #n == 4 then
+			typ = "tcp4";
 		end
 	end
-	local create = socket[typ]
+	local create = socket[typ] or socket.tcp;
 	if type( create ) ~= "function"  then
 		return nil, "invalid socket type"
 	end
diff --git a/net/server_select.lua b/net/server_select.lua
index cd97f0a6..c5f9004c 100644
--- a/net/server_select.lua
+++ b/net/server_select.lua
@@ -50,7 +50,8 @@ local coroutine_yield = coroutine.yield
 local has_luasec, luasec = pcall ( require , "ssl" )
 local luasocket = use "socket" or require "socket"
 local luasocket_gettime = luasocket.gettime
-local getaddrinfo = luasocket.dns.getaddrinfo
+local inet = require "util.net";
+local inet_pton = inet.pton;
 
 --// extern lib methods //--
 
@@ -1007,14 +1008,16 @@ local addclient = function( address, port, listeners, pattern, sslctx, typ )
 	elseif sslctx and not has_luasec then
 		err = "luasec not found"
 	end
-	if getaddrinfo and not typ then
-		local addrinfo, err = getaddrinfo(address)
-		if not addrinfo then return nil, err end
-		if addrinfo[1] and addrinfo[1].family == "inet6" then
-			typ = "tcp6"
+	if not typ then
+		local n = inet_pton(addr);
+		if not n then return nil, "invalid-ip"; end
+		if #n == 16 then
+			typ = "tcp6";
+		elseif #n == 4 then
+			typ = "tcp4";
 		end
 	end
-	local create = luasocket[typ or "tcp"]
+	local create = luasocket[typ] or luasocket.tcp;
 	if type( create ) ~= "function"  then
 		err = "invalid socket type"
 	end
-- 
cgit v1.2.3