aboutsummaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
Diffstat (limited to 'net')
-rw-r--r--net/server_epoll.lua35
1 files changed, 21 insertions, 14 deletions
diff --git a/net/server_epoll.lua b/net/server_epoll.lua
index 9516ab4d..a9440258 100644
--- a/net/server_epoll.lua
+++ b/net/server_epoll.lua
@@ -440,15 +440,30 @@ end
function interface:starttls(tls_ctx)
if tls_ctx then self.tls_ctx = tls_ctx; end
+ self.starttls = false;
if self.writebuffer and self.writebuffer[1] then
log("debug", "Start TLS on %s after write", self);
self.ondrain = interface.starttls;
- self.starttls = false;
self:set(nil, true); -- make sure wantwrite is set
else
+ if self.ondrain == interface.starttls then
+ self.ondrain = nil;
+ end
+ self.onwritable = interface.tlshandskake;
+ self.onreadable = interface.tlshandskake;
+ self:set(true, true);
+ log("debug", "Prepare to start TLS on %s", self);
+ end
+end
+
+function interface:tlshandskake()
+ self:setwritetimeout(false);
+ self:setreadtimeout(false);
+ if not self._tls then
+ self._tls = true;
log("debug", "Start TLS on %s now", self);
self:del();
- local conn, err = luasec.wrap(self.conn, tls_ctx or self.tls_ctx);
+ local conn, err = luasec.wrap(self.conn, self.tls_ctx);
if not conn then
self:on("disconnect", err);
self:destroy();
@@ -456,22 +471,17 @@ function interface:starttls(tls_ctx)
end
conn:settimeout(0);
self.conn = conn;
+ self:on("starttls");
self.ondrain = nil;
self.onwritable = interface.tlshandskake;
self.onreadable = interface.tlshandskake;
return self:init();
end
-end
-
-function interface:tlshandskake()
- self:setwritetimeout(false);
- self:setreadtimeout(false);
local ok, err = self.conn:dohandshake();
if ok then
log("debug", "TLS handshake on %s complete", self);
self.onwritable = nil;
self.onreadable = nil;
- self._tls = true;
self:on("status", "ssl-handshake-complete");
self:setwritetimeout();
self:set(true, true);
@@ -529,10 +539,9 @@ function interface:onacceptable()
end
local client = wrapsocket(conn, self, nil, self.listeners);
log("debug", "New connection %s", tostring(client));
+ client:init();
if self.tls_direct then
client:starttls(self.tls_ctx);
- else
- client:init();
end
end
@@ -600,10 +609,9 @@ local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx)
if not client.peername then
client.peername, client.peerport = addr, port;
end
+ client:init();
if tls_ctx then
client:starttls(tls_ctx);
- else
- client:init();
end
return client;
end
@@ -615,10 +623,9 @@ local function addclient(addr, port, listeners, read_size, tls_ctx)
conn:settimeout(0);
conn:connect(addr, port);
local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx)
+ client:init();
if tls_ctx then
client:starttls(tls_ctx);
- else
- client:init();
end
return client, conn;
end