aboutsummaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
Diffstat (limited to 'net')
-rw-r--r--net/connect.lua45
-rw-r--r--net/resolvers/basic.lua160
-rw-r--r--net/resolvers/service.lua81
-rw-r--r--net/server.lua7
-rw-r--r--net/server_epoll.lua41
-rw-r--r--net/server_event.lua34
-rw-r--r--net/server_select.lua26
-rw-r--r--net/tls_luasec.lua89
8 files changed, 364 insertions, 119 deletions
diff --git a/net/connect.lua b/net/connect.lua
index 4b602be4..d85afcff 100644
--- a/net/connect.lua
+++ b/net/connect.lua
@@ -1,6 +1,7 @@
local server = require "net.server";
local log = require "util.logger".init("net.connect");
local new_id = require "util.id".short;
+local timer = require "util.timer";
-- TODO #1246 Happy Eyeballs
-- FIXME RFC 6724
@@ -28,16 +29,17 @@ local pending_connection_listeners = {};
local function attempt_connection(p)
p:log("debug", "Checking for targets...");
- if p.conn then
- pending_connections_map[p.conn] = nil;
- p.conn = nil;
- end
- p.target_resolver:next(function (conn_type, ip, port, extra)
+ p.target_resolver:next(function (conn_type, ip, port, extra, more_targets_available)
if not conn_type then
-- No more targets to try
p:log("debug", "No more connection targets to try", p.target_resolver.last_error);
- if p.listeners.onfail then
- p.listeners.onfail(p.data, p.last_error or p.target_resolver.last_error or "unable to resolve service");
+ if next(p.conns) == nil then
+ p:log("debug", "No more targets, no pending connections. Connection failed.");
+ if p.listeners.onfail then
+ p.listeners.onfail(p.data, p.last_error or p.target_resolver.last_error or "unable to resolve service");
+ end
+ else
+ p:log("debug", "One or more connection attempts are still pending. Waiting for now.");
end
return;
end
@@ -49,8 +51,16 @@ local function attempt_connection(p)
p.last_error = err or "unknown reason";
return attempt_connection(p);
end
- p.conn = conn;
+ p.conns[conn] = true;
pending_connections_map[conn] = p;
+ if more_targets_available then
+ timer.add_task(0.250, function ()
+ if not p.connected then
+ p:log("debug", "Still not connected, making parallel connection attempt...");
+ attempt_connection(p);
+ end
+ end);
+ end
end);
end
@@ -62,6 +72,13 @@ function pending_connection_listeners.onconnect(conn)
return;
end
pending_connections_map[conn] = nil;
+ if p.connected then
+ -- We already succeeded in connecting
+ p.conns[conn] = nil;
+ conn:close();
+ return;
+ end
+ p.connected = true;
p:log("debug", "Successfully connected");
conn:setlistener(p.listeners, p.data);
return p.listeners.onconnect(conn);
@@ -73,9 +90,18 @@ function pending_connection_listeners.ondisconnect(conn, reason)
log("warn", "Failed connection, but unexpected!");
return;
end
+ p.conns[conn] = nil;
+ pending_connections_map[conn] = nil;
p.last_error = reason or "unknown reason";
p:log("debug", "Connection attempt failed: %s", p.last_error);
- attempt_connection(p);
+ if p.connected then
+ p:log("debug", "Connection already established, ignoring failure");
+ elseif next(p.conns) == nil then
+ p:log("debug", "No pending connection attempts, and not yet connected");
+ attempt_connection(p);
+ else
+ p:log("debug", "Other attempts are still pending, ignoring failure");
+ end
end
local function connect(target_resolver, listeners, options, data)
@@ -85,6 +111,7 @@ local function connect(target_resolver, listeners, options, data)
listeners = assert(listeners);
options = options or {};
data = data;
+ conns = {};
}, pending_connection_mt);
p:log("debug", "Starting connection process");
diff --git a/net/resolvers/basic.lua b/net/resolvers/basic.lua
index 305bce76..15338ff4 100644
--- a/net/resolvers/basic.lua
+++ b/net/resolvers/basic.lua
@@ -2,13 +2,59 @@ local adns = require "net.adns";
local inet_pton = require "util.net".pton;
local inet_ntop = require "util.net".ntop;
local idna_to_ascii = require "util.encodings".idna.to_ascii;
-local unpack = table.unpack or unpack; -- luacheck: ignore 113
+local promise = require "util.promise";
+local t_move = require "util.table".move;
local methods = {};
local resolver_mt = { __index = methods };
-- FIXME RFC 6724
+local function do_dns_lookup(self, dns_resolver, record_type, name)
+ return promise.new(function (resolve, reject)
+ local ipv = (record_type == "A" and "4") or (record_type == "AAAA" and "6") or nil;
+ if ipv and self.extra["use_ipv"..ipv] == false then
+ return reject(("IPv%s disabled - %s lookup skipped"):format(ipv, record_type));
+ elseif record_type == "TLSA" and self.extra.use_dane ~= true then
+ return reject("DANE disabled - TLSA lookup skipped");
+ end
+ dns_resolver:lookup(function (answer, err)
+ if not answer then
+ return reject(err);
+ elseif answer.bogus then
+ return reject(("Validation error in %s lookup"):format(record_type));
+ elseif answer.status and #answer == 0 then
+ return reject(("%s in %s lookup"):format(answer.status, record_type));
+ end
+
+ local targets = { secure = answer.secure };
+ for _, record in ipairs(answer) do
+ if ipv then
+ table.insert(targets, { self.conn_type..ipv, record[record_type:lower()], self.port, self.extra });
+ else
+ table.insert(targets, record[record_type:lower()]);
+ end
+ end
+ return resolve(targets);
+ end, name, record_type, "IN");
+ end);
+end
+
+local function merge_targets(ipv4_targets, ipv6_targets)
+ local result = { secure = ipv4_targets.secure and ipv6_targets.secure };
+ local common_length = math.min(#ipv4_targets, #ipv6_targets);
+ for i = 1, common_length do
+ table.insert(result, ipv6_targets[i]);
+ table.insert(result, ipv4_targets[i]);
+ end
+ if common_length < #ipv4_targets then
+ t_move(ipv4_targets, common_length+1, #ipv4_targets, common_length+1, result);
+ elseif common_length < #ipv6_targets then
+ t_move(ipv6_targets, common_length+1, #ipv6_targets, common_length+1, result);
+ end
+ return result;
+end
+
-- Find the next target to connect to, and
-- pass it to cb()
function methods:next(cb)
@@ -18,7 +64,7 @@ function methods:next(cb)
return;
end
local next_target = table.remove(self.targets, 1);
- cb(unpack(next_target, 1, 4));
+ cb(next_target[1], next_target[2], next_target[3], next_target[4], not not self.targets[1]);
return;
end
@@ -28,91 +74,45 @@ function methods:next(cb)
return;
end
- local secure = true;
- local tlsa = {};
- local targets = {};
- local n = 3;
- local function ready()
- n = n - 1;
- if n > 0 then return; end
- self.targets = targets;
+ -- Resolve DNS to target list
+ local dns_resolver = adns.resolver();
+
+ local dns_lookups = {
+ ipv4 = do_dns_lookup(self, dns_resolver, "A", self.hostname);
+ ipv6 = do_dns_lookup(self, dns_resolver, "AAAA", self.hostname);
+ tlsa = do_dns_lookup(self, dns_resolver, "TLSA", ("_%d._%s.%s"):format(self.port, self.conn_type, self.hostname));
+ };
+
+ promise.all_settled(dns_lookups):next(function (dns_results)
+ -- Combine targets, assign to self.targets, self:next(cb)
+ local have_ipv4 = dns_results.ipv4.status == "fulfilled";
+ local have_ipv6 = dns_results.ipv6.status == "fulfilled";
+
+ if have_ipv4 and have_ipv6 then
+ self.targets = merge_targets(dns_results.ipv4.value, dns_results.ipv6.value);
+ elseif have_ipv4 then
+ self.targets = dns_results.ipv4.value;
+ elseif have_ipv6 then
+ self.targets = dns_results.ipv6.value;
+ else
+ self.targets = {};
+ end
+
if self.extra and self.extra.use_dane then
- if secure and tlsa[1] then
- self.extra.tlsa = tlsa;
+ if self.targets.secure and dns_results.tlsa.status == "fulfilled" then
+ self.extra.tlsa = dns_results.tlsa.value;
self.extra.dane_hostname = self.hostname;
else
self.extra.tlsa = nil;
self.extra.dane_hostname = nil;
end
end
- self:next(cb);
- end
- -- Resolve DNS to target list
- local dns_resolver = adns.resolver();
-
- if not self.extra or self.extra.use_ipv4 ~= false then
- dns_resolver:lookup(function (answer, err)
- if answer then
- secure = secure and answer.secure;
- for _, record in ipairs(answer) do
- table.insert(targets, { self.conn_type.."4", record.a, self.port, self.extra });
- end
- if answer.bogus then
- self.last_error = "Validation error in A lookup";
- elseif answer.status then
- self.last_error = answer.status .. " in A lookup";
- end
- else
- self.last_error = err;
- end
- ready();
- end, self.hostname, "A", "IN");
- else
- ready();
- end
-
- if not self.extra or self.extra.use_ipv6 ~= false then
- dns_resolver:lookup(function (answer, err)
- if answer then
- secure = secure and answer.secure;
- for _, record in ipairs(answer) do
- table.insert(targets, { self.conn_type.."6", record.aaaa, self.port, self.extra });
- end
- if answer.bogus then
- self.last_error = "Validation error in AAAA lookup";
- elseif answer.status then
- self.last_error = answer.status .. " in AAAA lookup";
- end
- else
- self.last_error = err;
- end
- ready();
- end, self.hostname, "AAAA", "IN");
- else
- ready();
- end
-
- if self.extra and self.extra.use_dane == true then
- dns_resolver:lookup(function (answer, err)
- if answer then
- secure = secure and answer.secure;
- for _, record in ipairs(answer) do
- table.insert(tlsa, record.tlsa);
- end
- if answer.bogus then
- self.last_error = "Validation error in TLSA lookup";
- elseif answer.status then
- self.last_error = answer.status .. " in TLSA lookup";
- end
- else
- self.last_error = err;
- end
- ready();
- end, ("_%d._tcp.%s"):format(self.port, self.hostname), "TLSA", "IN");
- else
- ready();
- end
+ self:next(cb);
+ end):catch(function (err)
+ self.last_error = err;
+ self.targets = {};
+ end);
end
local function new(hostname, port, conn_type, extra)
@@ -137,7 +137,7 @@ local function new(hostname, port, conn_type, extra)
hostname = ascii_host;
port = port;
conn_type = conn_type;
- extra = extra;
+ extra = extra or {};
targets = targets;
}, resolver_mt);
end
diff --git a/net/resolvers/service.lua b/net/resolvers/service.lua
index 3810cac8..a7ce76a3 100644
--- a/net/resolvers/service.lua
+++ b/net/resolvers/service.lua
@@ -2,23 +2,78 @@ local adns = require "net.adns";
local basic = require "net.resolvers.basic";
local inet_pton = require "util.net".pton;
local idna_to_ascii = require "util.encodings".idna.to_ascii;
-local unpack = table.unpack or unpack; -- luacheck: ignore 113
local methods = {};
local resolver_mt = { __index = methods };
+local function new_target_selector(rrset)
+ local rr_count = rrset and #rrset;
+ if not rr_count or rr_count == 0 then
+ rrset = nil;
+ else
+ table.sort(rrset, function (a, b) return a.srv.priority < b.srv.priority end);
+ end
+ local rrset_pos = 1;
+ local priority_bucket, bucket_total_weight, bucket_len, bucket_used;
+ return function ()
+ if not rrset then return; end
+
+ if not priority_bucket or bucket_used >= bucket_len then
+ if rrset_pos > rr_count then return; end -- Used up all records
+
+ -- Going to start on a new priority now. Gather up all the next
+ -- records with the same priority and add them to priority_bucket
+ priority_bucket, bucket_total_weight, bucket_len, bucket_used = {}, 0, 0, 0;
+ local current_priority;
+ repeat
+ local curr_record = rrset[rrset_pos].srv;
+ if not current_priority then
+ current_priority = curr_record.priority;
+ elseif current_priority ~= curr_record.priority then
+ break;
+ end
+ table.insert(priority_bucket, curr_record);
+ bucket_total_weight = bucket_total_weight + curr_record.weight;
+ bucket_len = bucket_len + 1;
+ rrset_pos = rrset_pos + 1;
+ until rrset_pos > rr_count;
+ end
+
+ bucket_used = bucket_used + 1;
+ local n, running_total = math.random(0, bucket_total_weight), 0;
+ local target_record;
+ for i = 1, bucket_len do
+ local candidate = priority_bucket[i];
+ if candidate then
+ running_total = running_total + candidate.weight;
+ if running_total >= n then
+ target_record = candidate;
+ bucket_total_weight = bucket_total_weight - candidate.weight;
+ priority_bucket[i] = nil;
+ break;
+ end
+ end
+ end
+ return target_record;
+ end;
+end
+
-- Find the next target to connect to, and
-- pass it to cb()
function methods:next(cb)
- if self.targets then
- if not self.resolver then
- if #self.targets == 0 then
+ if self.resolver or self._get_next_target then
+ if not self.resolver then -- Do we have a basic resolver currently?
+ -- We don't, so fetch a new SRV target, create a new basic resolver for it
+ local next_srv_target = self._get_next_target and self._get_next_target();
+ if not next_srv_target then
+ -- No more SRV targets left
cb(nil);
return;
end
- local next_target = table.remove(self.targets, 1);
- self.resolver = basic.new(unpack(next_target, 1, 4));
+ -- Create a new basic resolver for this SRV target
+ self.resolver = basic.new(next_srv_target.target, next_srv_target.port, self.conn_type, self.extra);
end
+ -- Look up the next (basic) target from the current target's resolver
self.resolver:next(function (...)
if self.resolver then
self.last_error = self.resolver.last_error;
@@ -31,6 +86,9 @@ function methods:next(cb)
end
end);
return;
+ elseif self.in_progress then
+ cb(nil);
+ return;
end
if not self.hostname then
@@ -39,9 +97,9 @@ function methods:next(cb)
return;
end
- local targets = {};
+ self.in_progress = true;
+
local function ready()
- self.targets = targets;
self:next(cb);
end
@@ -63,7 +121,7 @@ function methods:next(cb)
if #answer == 0 then
if self.extra and self.extra.default_port then
- table.insert(targets, { self.hostname, self.extra.default_port, self.conn_type, self.extra });
+ self.resolver = basic.new(self.hostname, self.extra.default_port, self.conn_type, self.extra);
else
self.last_error = "zero SRV records found";
end
@@ -77,10 +135,7 @@ function methods:next(cb)
return;
end
- table.sort(answer, function (a, b) return a.srv.priority < b.srv.priority end);
- for _, record in ipairs(answer) do
- table.insert(targets, { record.srv.target, record.srv.port, self.conn_type, self.extra });
- end
+ self._get_next_target = new_target_selector(answer);
else
self.last_error = err;
end
diff --git a/net/server.lua b/net/server.lua
index 0696fd52..72272bef 100644
--- a/net/server.lua
+++ b/net/server.lua
@@ -118,6 +118,13 @@ if prosody and set_config then
prosody.events.add_handler("config-reloaded", load_config);
end
+local tls_builder = server.tls_builder;
+-- resolving the basedir here avoids util.sslconfig depending on
+-- prosody.paths.config
+function server.tls_builder()
+ return tls_builder(prosody.paths.config or "")
+end
+
-- require "net.server" shall now forever return this,
-- ie. server_select or server_event as chosen above.
return server;
diff --git a/net/server_epoll.lua b/net/server_epoll.lua
index fa275d71..8e75e072 100644
--- a/net/server_epoll.lua
+++ b/net/server_epoll.lua
@@ -18,7 +18,6 @@ local traceback = debug.traceback;
local logger = require "util.logger";
local log = logger.init("server_epoll");
local socket = require "socket";
-local luasec = require "ssl";
local realtime = require "util.time".now;
local monotonic = require "util.time".monotonic;
local indexedbheap = require "util.indexedbheap";
@@ -28,6 +27,8 @@ local inet_pton = inet.pton;
local _SOCKETINVALID = socket._SOCKETINVALID or -1;
local new_id = require "util.id".short;
local xpcall = require "util.xpcall".xpcall;
+local sslconfig = require "util.sslconfig";
+local tls_impl = require "net.tls_luasec";
local poller = require "util.poll"
local EEXIST = poller.EEXIST;
@@ -614,6 +615,30 @@ function interface:set_sslctx(sslctx)
self._sslctx = sslctx;
end
+function interface:sslctx()
+ return self.tls_ctx
+end
+
+function interface:ssl_info()
+ local sock = self.conn;
+ return sock.info and sock:info();
+end
+
+function interface:ssl_peercertificate()
+ local sock = self.conn;
+ return sock.getpeercertificate and sock:getpeercertificate();
+end
+
+function interface:ssl_peerverification()
+ local sock = self.conn;
+ return sock.getpeerverification and sock:getpeerverification();
+end
+
+function interface:ssl_peerfinished()
+ local sock = self.conn;
+ return sock.getpeerfinished and sock:getpeerfinished();
+end
+
function interface:starttls(tls_ctx)
if tls_ctx then self.tls_ctx = tls_ctx; end
self.starttls = false;
@@ -641,11 +666,7 @@ function interface:inittls(tls_ctx, now)
self.starttls = false;
self:debug("Starting TLS now");
self:updatenames(); -- Can't getpeer/sockname after wrap()
- local ok, conn, err = pcall(luasec.wrap, self.conn, self.tls_ctx);
- if not ok then
- conn, err = ok, conn;
- self:debug("Failed to initialize TLS: %s", err);
- end
+ local conn, err = self.tls_ctx:wrap(self.conn);
if not conn then
self:on("disconnect", err);
self:destroy();
@@ -656,8 +677,8 @@ function interface:inittls(tls_ctx, now)
if conn.sni then
if self.servername then
conn:sni(self.servername);
- elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then
- conn:sni(self._server.hosts, true);
+ elseif next(self.tls_ctx._sni_contexts) ~= nil then
+ conn:sni(self.tls_ctx._sni_contexts, true);
end
end
if self.extra and self.extra.tlsa and conn.settlsa then
@@ -1085,6 +1106,10 @@ return {
cfg = setmetatable(newconfig, default_config);
end;
+ tls_builder = function(basedir)
+ return sslconfig._new(tls_impl.new_context, basedir)
+ end,
+
-- libevent emulation
event = { EV_READ = "r", EV_WRITE = "w", EV_READWRITE = "rw", EV_LEAVE = -1 };
addevent = function (fd, mode, callback)
diff --git a/net/server_event.lua b/net/server_event.lua
index c30181b8..313ba981 100644
--- a/net/server_event.lua
+++ b/net/server_event.lua
@@ -47,11 +47,13 @@ local s_sub = string.sub
local coroutine_wrap = coroutine.wrap
local coroutine_yield = coroutine.yield
-local has_luasec, ssl = pcall ( require , "ssl" )
+local has_luasec = pcall ( require , "ssl" )
local socket = require "socket"
local levent = require "luaevent.core"
local inet = require "util.net";
local inet_pton = inet.pton;
+local sslconfig = require "util.sslconfig";
+local tls_impl = require "net.tls_luasec";
local socket_gettime = socket.gettime
@@ -153,7 +155,7 @@ function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed
_ = self.eventwrite and self.eventwrite:close( )
self.eventread, self.eventwrite = nil, nil
local err
- self.conn, err = ssl.wrap( self.conn, self._sslctx )
+ self.conn, err = self._sslctx:wrap(self.conn)
if err then
self.fatalerror = err
self.conn = nil -- cannot be used anymore
@@ -168,8 +170,8 @@ function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed
if self.conn.sni then
if self.servername then
self.conn:sni(self.servername);
- elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then
- self.conn:sni(self._server.hosts, true);
+ elseif next(self._sslctx._sni_contexts) ~= nil then
+ self.conn:sni(self._sslctx._sni_contexts, true);
end
end
@@ -274,6 +276,26 @@ function interface_mt:pause()
return self:_lock(self.nointerface, true, self.nowriting);
end
+function interface_mt:sslctx()
+ return self._sslctx
+end
+
+function interface_mt:ssl_info()
+ return self.conn.info and self.conn:info()
+end
+
+function interface_mt:ssl_peercertificate()
+ return self.conn.getpeercertificate and self.conn:getpeercertificate()
+end
+
+function interface_mt:ssl_peerverification()
+ return self.conn.getpeerverification and self.conn:getpeerverification()
+end
+
+function interface_mt:ssl_peerfinished()
+ return self.conn.getpeerfinished and self.conn:getpeerfinished()
+end
+
function interface_mt:resume()
self:_lock(self.nointerface, false, self.nowriting);
if self.readcallback and not self.eventread then
@@ -924,6 +946,10 @@ return {
add_task = add_task,
watchfd = watchfd,
+ tls_builder = function(basedir)
+ return sslconfig._new(tls_impl.new_context, basedir)
+ end,
+
__NAME = SCRIPT_NAME,
__DATE = LAST_MODIFIED,
__AUTHOR = SCRIPT_AUTHOR,
diff --git a/net/server_select.lua b/net/server_select.lua
index eea850ce..80f5f590 100644
--- a/net/server_select.lua
+++ b/net/server_select.lua
@@ -47,15 +47,15 @@ local coroutine_yield = coroutine.yield
--// extern libs //--
-local has_luasec, luasec = pcall ( require , "ssl" )
local luasocket = use "socket" or require "socket"
local luasocket_gettime = luasocket.gettime
local inet = require "util.net";
local inet_pton = inet.pton;
+local sslconfig = require "util.sslconfig";
+local has_luasec, tls_impl = pcall(require, "net.tls_luasec");
--// extern lib methods //--
-local ssl_wrap = ( has_luasec and luasec.wrap )
local socket_bind = luasocket.bind
local socket_select = luasocket.select
@@ -359,6 +359,18 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
handler.sslctx = function ( )
return sslctx
end
+ handler.ssl_info = function( )
+ return socket.info and socket:info()
+ end
+ handler.ssl_peercertificate = function( )
+ return socket.getpeercertificate and socket:getpeercertificate()
+ end
+ handler.ssl_peerverification = function( )
+ return socket.getpeerverification and socket:getpeerverification()
+ end
+ handler.ssl_peerfinished = function( )
+ return socket.getpeerfinished and socket:getpeerfinished()
+ end
handler.send = function( _, data, i, j )
return send( socket, data, i, j )
end
@@ -652,7 +664,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
end
out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
local oldsocket, err = socket
- socket, err = ssl_wrap( socket, sslctx ) -- wrap socket
+ socket, err = sslctx:wrap(socket) -- wrap socket
if not socket then
out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") )
@@ -662,8 +674,8 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
if socket.sni then
if self.servername then
socket:sni(self.servername);
- elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then
- socket:sni(self.server().hosts, true);
+ elseif next(sslctx._sni_contexts) ~= nil then
+ socket:sni(sslctx._sni_contexts, true);
end
end
@@ -1169,4 +1181,8 @@ return {
removeserver = removeserver,
get_backend = get_backend,
changesettings = changesettings,
+
+ tls_builder = function(basedir)
+ return sslconfig._new(tls_impl.new_context, basedir)
+ end,
}
diff --git a/net/tls_luasec.lua b/net/tls_luasec.lua
new file mode 100644
index 00000000..2bedb5ab
--- /dev/null
+++ b/net/tls_luasec.lua
@@ -0,0 +1,89 @@
+-- Prosody IM
+-- Copyright (C) 2021 Prosody folks
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+--[[
+This file provides a shim abstraction over LuaSec, consolidating some code
+which was previously spread between net.server backends, portmanager and
+certmanager.
+
+The goal is to provide a more or less well-defined API on top of LuaSec which
+abstracts away some of the things which are not needed and simplifies usage of
+commonly used things (such as SNI contexts). Eventually, network backends
+which do not rely on LuaSocket+LuaSec should be able to provide *this* API
+instead of having to mimic LuaSec.
+]]
+local ssl = require "ssl";
+local ssl_newcontext = ssl.newcontext;
+local ssl_context = ssl.context or require "ssl.context";
+local io_open = io.open;
+
+local context_api = {};
+local context_mt = {__index = context_api};
+
+function context_api:set_sni_host(host, cert, key)
+ local ctx, err = self._builder:clone():apply({
+ certificate = cert,
+ key = key,
+ }):build();
+ if not ctx then
+ return false, err
+ end
+
+ self._sni_contexts[host] = ctx._inner
+
+ return true, nil
+end
+
+function context_api:remove_sni_host(host)
+ self._sni_contexts[host] = nil
+end
+
+function context_api:wrap(sock)
+ local ok, conn, err = pcall(ssl.wrap, sock, self._inner);
+ if not ok then
+ return nil, err
+ end
+ return conn, nil
+end
+
+local function new_context(cfg, builder)
+ -- LuaSec expects dhparam to be a callback that takes two arguments.
+ -- We ignore those because it is mostly used for having a separate
+ -- set of params for EXPORT ciphers, which we don't have by default.
+ if type(cfg.dhparam) == "string" then
+ local f, err = io_open(cfg.dhparam);
+ if not f then return nil, "Could not open DH parameters: "..err end
+ local dhparam = f:read("*a");
+ f:close();
+ cfg.dhparam = function() return dhparam; end
+ end
+
+ local inner, err = ssl_newcontext(cfg);
+ if not inner then
+ return nil, err
+ end
+
+ -- COMPAT Older LuaSec ignores the cipher list from the config, so we have to take care
+ -- of it ourselves (W/A for #x)
+ if inner and cfg.ciphers then
+ local success;
+ success, err = ssl_context.setcipher(inner, cfg.ciphers);
+ if not success then
+ return nil, err
+ end
+ end
+
+ return setmetatable({
+ _inner = inner,
+ _builder = builder,
+ _sni_contexts = {},
+ }, context_mt), nil
+end
+
+return {
+ new_context = new_context,
+};