From 9291209e822354ad532160e6c66c58aeb1a3db5a Mon Sep 17 00:00:00 2001 From: Kim Alvefur Date: Fri, 14 Sep 2018 01:34:38 +0200 Subject: net.server_epoll: Delay wrapping sockets in TLS until just before first handshake --- net/server_epoll.lua | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/net/server_epoll.lua b/net/server_epoll.lua index 9516ab4d..a9440258 100644 --- a/net/server_epoll.lua +++ b/net/server_epoll.lua @@ -440,15 +440,30 @@ end function interface:starttls(tls_ctx) if tls_ctx then self.tls_ctx = tls_ctx; end + self.starttls = false; if self.writebuffer and self.writebuffer[1] then log("debug", "Start TLS on %s after write", self); self.ondrain = interface.starttls; - self.starttls = false; self:set(nil, true); -- make sure wantwrite is set else + if self.ondrain == interface.starttls then + self.ondrain = nil; + end + self.onwritable = interface.tlshandskake; + self.onreadable = interface.tlshandskake; + self:set(true, true); + log("debug", "Prepare to start TLS on %s", self); + end +end + +function interface:tlshandskake() + self:setwritetimeout(false); + self:setreadtimeout(false); + if not self._tls then + self._tls = true; log("debug", "Start TLS on %s now", self); self:del(); - local conn, err = luasec.wrap(self.conn, tls_ctx or self.tls_ctx); + local conn, err = luasec.wrap(self.conn, self.tls_ctx); if not conn then self:on("disconnect", err); self:destroy(); @@ -456,22 +471,17 @@ function interface:starttls(tls_ctx) end conn:settimeout(0); self.conn = conn; + self:on("starttls"); self.ondrain = nil; self.onwritable = interface.tlshandskake; self.onreadable = interface.tlshandskake; return self:init(); end -end - -function interface:tlshandskake() - self:setwritetimeout(false); - self:setreadtimeout(false); local ok, err = self.conn:dohandshake(); if ok then log("debug", "TLS handshake on %s complete", self); self.onwritable = nil; self.onreadable = nil; - self._tls = true; self:on("status", "ssl-handshake-complete"); self:setwritetimeout(); self:set(true, true); @@ -529,10 +539,9 @@ function interface:onacceptable() end local client = wrapsocket(conn, self, nil, self.listeners); log("debug", "New connection %s", tostring(client)); + client:init(); if self.tls_direct then client:starttls(self.tls_ctx); - else - client:init(); end end @@ -600,10 +609,9 @@ local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx) if not client.peername then client.peername, client.peerport = addr, port; end + client:init(); if tls_ctx then client:starttls(tls_ctx); - else - client:init(); end return client; end @@ -615,10 +623,9 @@ local function addclient(addr, port, listeners, read_size, tls_ctx) conn:settimeout(0); conn:connect(addr, port); local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx) + client:init(); if tls_ctx then client:starttls(tls_ctx); - else - client:init(); end return client, conn; end -- cgit v1.2.3