aboutsummaryrefslogtreecommitdiffstats
path: root/net
diff options
context:
space:
mode:
authorKim Alvefur <zash@zash.se>2018-05-09 16:15:40 +0200
committerKim Alvefur <zash@zash.se>2018-05-09 16:15:40 +0200
commita247edeac984c22ac4419eb03b5bbbdcd9a2206f (patch)
treeea76a040fa49fccc3bc0669effd781d0382c34cd /net
parentd6ed959fd38c1f32aa41ba42db0061531b4703fe (diff)
downloadprosody-a247edeac984c22ac4419eb03b5bbbdcd9a2206f.tar.gz
prosody-a247edeac984c22ac4419eb03b5bbbdcd9a2206f.zip
net.server: Add watchfd, a simple API for watching file descriptors
Diffstat (limited to 'net')
-rw-r--r--net/server_epoll.lua21
-rw-r--r--net/server_event.lua29
-rw-r--r--net/server_select.lua43
3 files changed, 93 insertions, 0 deletions
diff --git a/net/server_epoll.lua b/net/server_epoll.lua
index 564444f8..0881f797 100644
--- a/net/server_epoll.lua
+++ b/net/server_epoll.lua
@@ -15,6 +15,7 @@ local t_concat = table.concat;
local setmetatable = setmetatable;
local tostring = tostring;
local pcall = pcall;
+local type = type;
local next = next;
local pairs = pairs;
local log = require "util.logger".init("server_epoll");
@@ -586,6 +587,25 @@ local function addclient(addr, port, listeners, pattern, tls)
return client, conn;
end
+local function watchfd(fd, onreadable, onwriteable)
+ local conn = setmetatable({
+ conn = fd;
+ onreadable = onreadable;
+ onwriteable = onwriteable;
+ close = function (self)
+ self:setflags(false, false);
+ end
+ }, interface_mt);
+ if type(fd) == "number" then
+ conn.getfd = function ()
+ return fd;
+ end;
+ -- Otherwise it'll need to be something LuaSocket-compatible
+ end
+ conn:setflags(onreadable, onwriteable);
+ return conn;
+end;
+
-- Dump all data from one connection into another
local function link(from, to)
from.listeners = setmetatable({
@@ -663,6 +683,7 @@ return {
closeall = closeall;
setquitting = setquitting;
wrapclient = wrapclient;
+ watchfd = watchfd;
link = link;
set_config = function (newconfig)
cfg = setmetatable(newconfig, default_config);
diff --git a/net/server_event.lua b/net/server_event.lua
index d40b388f..3e949092 100644
--- a/net/server_event.lua
+++ b/net/server_event.lua
@@ -834,6 +834,34 @@ local function add_task(delay, callback)
return event_handle;
end
+local function watchfd(fd, onreadable, onwriteable)
+ local handle = {};
+ function handle:setflags(r,w)
+ if r ~= nil then
+ if r and not self.wantread then
+ self.wantread = base:addevent(fd, EV_READ, function ()
+ onreadable(self);
+ end);
+ elseif not r and self.wantread then
+ self.wantread:close();
+ self.wantread = nil;
+ end
+ end
+ if w ~= nil then
+ if w and not self.wantwrite then
+ self.wantwrite = base:addevent(fd, EV_WRITE, function ()
+ onwriteable(self);
+ end);
+ elseif not r and self.wantread then
+ self.wantwrite:close();
+ self.wantwrite = nil;
+ end
+ end
+ end
+ handle:setflags(onreadable, onwriteable);
+ return handle;
+end
+
return {
cfg = cfg,
base = base,
@@ -850,6 +878,7 @@ return {
get_backend = get_backend,
hook_signal = hook_signal,
add_task = add_task,
+ watchfd = watchfd,
__NAME = SCRIPT_NAME,
__DATE = LAST_MODIFIED,
diff --git a/net/server_select.lua b/net/server_select.lua
index cfd08f37..3b83bb6d 100644
--- a/net/server_select.lua
+++ b/net/server_select.lua
@@ -1034,6 +1034,48 @@ local addclient = function( address, port, listeners, pattern, sslctx, typ )
end
end
+local closewatcher = function (handler)
+ local socket = handler.conn;
+ _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
+ _readlistlen = removesocket( _readlist, socket, _readlistlen )
+ _socketlist[ socket ] = nil
+end;
+
+local addremove = function (handler, read, send)
+ local socket = handler.conn
+ _socketlist[ socket ] = handler
+ if read ~= nil then
+ if read then
+ _readlistlen = addsocket( _readlist, socket, _readlistlen )
+ else
+ _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
+ end
+ end
+ if send ~= nil then
+ if send then
+ _sendlistlen = addsocket( _sendlist, socket, _sendlistlen )
+ else
+ _readlistlen = removesocket( _readlist, socket, _readlistlen )
+ end
+ end
+end
+
+local watchfd = function ( fd, onreadable, onwriteable )
+ local socket = fd
+ if type(fd) == "number" then
+ socket = { getfd = function () return fd; end }
+ end
+ local handler = {
+ conn = socket;
+ readbuffer = onreadable or id;
+ sendbuffer = onwriteable or id;
+ close = closewatcher;
+ setflags = addremove;
+ };
+ addremove( handler, onreadable, onwriteable )
+ return handler
+end
+
----------------------------------// BEGIN //--
use "setmetatable" ( _socketlist, { __mode = "k" } )
@@ -1058,6 +1100,7 @@ return {
addclient = addclient,
wrapclient = wrapclient,
+ watchfd = watchfd,
loop = loop,
link = link,