aboutsummaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
authorKim Alvefur <zash@zash.se>2018-10-11 15:48:30 +0200
committerKim Alvefur <zash@zash.se>2018-10-11 15:48:30 +0200
commit7953747a89caaf00c4587a3d238329a2ce992390 (patch)
tree1f23a6726a51d1b6c2dc854dd34812f3e6ed656c /net
parent9a1027d82045261d48515b8aa5de606cb02b49b8 (diff)
downloadprosody-7953747a89caaf00c4587a3d238329a2ce992390.tar.gz
prosody-7953747a89caaf00c4587a3d238329a2ce992390.zip
net.server: Require IP address as argument to addclient (no DNS names)
The net.connect API should be used to resolve DNS names first
Diffstat (limited to 'net')
-rw-r--r--net/server_epoll.lua19
-rw-r--r--net/server_event.lua17
-rw-r--r--net/server_select.lua17
3 files changed, 34 insertions, 19 deletions
diff --git a/net/server_epoll.lua b/net/server_epoll.lua
index 7a860878..b5053d22 100644
--- a/net/server_epoll.lua
+++ b/net/server_epoll.lua
@@ -21,6 +21,8 @@ local socket = require "socket";
local luasec = require "ssl";
local gettime = require "util.time".now;
local createtable = require "util.table".create;
+local inet = require "util.net";
+local inet_pton = inet.pton;
local _SOCKETINVALID = socket._SOCKETINVALID or -1;
local poll = require "util.poll".new();
@@ -614,7 +616,8 @@ local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx)
if not client.peername then
client.peername, client.peerport = addr, port;
end
- client:init();
+ local ok, err = client:init();
+ if not ok then return ok, err; end
if tls_ctx then
client:starttls(tls_ctx);
end
@@ -623,12 +626,20 @@ end
-- New outgoing TCP connection
local function addclient(addr, port, listeners, read_size, tls_ctx)
- local conn, err = socket.tcp();
- if not conn then return conn, err; end
+ local n = inet_pton(addr);
+ if not n then return nil, "invalid-ip"; end
+ local create
+ if #n == 16 then
+ create = socket.tcp6 or socket.tcp;
+ else
+ create = socket.tcp4 or socket.tcp;
+ end
+ local conn, err = create();
conn:settimeout(0);
conn:connect(addr, port);
local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx)
- client:init();
+ local ok, err = client:init();
+ if not ok then return ok, err; end
if tls_ctx then
client:starttls(tls_ctx);
end
diff --git a/net/server_event.lua b/net/server_event.lua
index 9fde558e..5a4dc1a4 100644
--- a/net/server_event.lua
+++ b/net/server_event.lua
@@ -50,9 +50,10 @@ local coroutine_yield = coroutine.yield
local has_luasec, ssl = pcall ( require , "ssl" )
local socket = require "socket"
local levent = require "luaevent.core"
+local inet = require "util.net";
+local inet_pton = inet.pton;
local socket_gettime = socket.gettime
-local getaddrinfo = socket.dns.getaddrinfo
local log = require ("util.logger").init("socket")
@@ -728,15 +729,15 @@ local function addclient( addr, serverport, listener, pattern, sslctx, typ )
return nil, "luasec not found"
end
if not typ then
- local addrinfo, err = getaddrinfo(addr)
- if not addrinfo then return nil, err end
- if addrinfo[1] and addrinfo[1].family == "inet6" then
- typ = "tcp6"
- else
- typ = "tcp"
+ local n = inet_pton(addr);
+ if not n then return nil, "invalid-ip"; end
+ if #n == 16 then
+ typ = "tcp6";
+ elseif #n == 4 then
+ typ = "tcp4";
end
end
- local create = socket[typ]
+ local create = socket[typ] or socket.tcp;
if type( create ) ~= "function" then
return nil, "invalid socket type"
end
diff --git a/net/server_select.lua b/net/server_select.lua
index cd97f0a6..c5f9004c 100644
--- a/net/server_select.lua
+++ b/net/server_select.lua
@@ -50,7 +50,8 @@ local coroutine_yield = coroutine.yield
local has_luasec, luasec = pcall ( require , "ssl" )
local luasocket = use "socket" or require "socket"
local luasocket_gettime = luasocket.gettime
-local getaddrinfo = luasocket.dns.getaddrinfo
+local inet = require "util.net";
+local inet_pton = inet.pton;
--// extern lib methods //--
@@ -1007,14 +1008,16 @@ local addclient = function( address, port, listeners, pattern, sslctx, typ )
elseif sslctx and not has_luasec then
err = "luasec not found"
end
- if getaddrinfo and not typ then
- local addrinfo, err = getaddrinfo(address)
- if not addrinfo then return nil, err end
- if addrinfo[1] and addrinfo[1].family == "inet6" then
- typ = "tcp6"
+ if not typ then
+ local n = inet_pton(addr);
+ if not n then return nil, "invalid-ip"; end
+ if #n == 16 then
+ typ = "tcp6";
+ elseif #n == 4 then
+ typ = "tcp4";
end
end
- local create = luasocket[typ or "tcp"]
+ local create = luasocket[typ] or luasocket.tcp;
if type( create ) ~= "function" then
err = "invalid socket type"
end