-- Source: plugins/mod_net_multiplex.lua
-- Blob: 3f5ee54d850d8a7dd11e97be3e740fd2d1823626
-- Mark this module as global through the module API.
module:set_global();

local array = require "prosody.util.array";
-- Value of the "multiplex_buffer_size" option. It is used further down both as
-- the multiplex listener's default_mode and as the buffered-byte threshold
-- past which an unrecognised connection is closed.
-- NOTE(review): 1024 and 1 look like the default and the lower bound; confirm
-- against module:get_option_integer.
local max_buffer_len = module:get_option_integer("multiplex_buffer_size", 1024, 1);
-- Value of the "network_default_read_size" option. It is passed to
-- conn:set_mode() when the listener taking over a connection does not define
-- a default_mode of its own.
local default_mode = module:get_option_integer("network_default_read_size", 4096, 0);

local portmanager = require "prosody.core.portmanager";

-- service table -> multiplex pattern, for services that declare a pattern.
local available_services = {};
-- multiplex protocol name -> service table, for services that declare a protocol.
local service_by_protocol = {};
-- Protocol names collected from registered services; this array is what the
-- alpn callback of the "multiplex_ssl" service returns.
local available_protocols = array();

-- Record a service in the multiplexer's lookup tables.
--
-- A service may declare `multiplex.protocol` (indexed in service_by_protocol
-- and appended to available_protocols), `multiplex.pattern` (stored in
-- available_services), both, or neither. A service with neither is only
-- logged and otherwise ignored.
--
-- @param service service table; reads `service.name` and `service.multiplex`
local function add_service(service)
	local mux = service.multiplex;
	local proto = mux and mux.protocol;
	local pattern = mux and mux.pattern;

	-- Nothing to index when the service declares neither routing hint.
	if not (proto or pattern) then
		module:log("debug", "Service %q is not multiplex-capable", service.name);
		return;
	end

	if proto then
		module:log("debug", "Adding multiplex service %q with protocol %q", service.name, proto);
		service_by_protocol[proto] = service;
		available_protocols:push(proto);
	end

	if pattern then
		module:log("debug", "Adding multiplex service %q with pattern %q", service.name, pattern);
		available_services[service] = pattern;
	end
end
-- Index every service announced through the "service-added" event.
module:hook("service-added", function (event) add_service(event.service); end);

-- Forget a service when the "service-removed" event announces its removal.
module:hook("service-removed", function (event)
	local service = event.service;
	available_services[service] = nil;

	local protocol_name = service.multiplex and service.multiplex.protocol;
	-- Only withdraw the protocol when the removed service is the one currently
	-- routed for it. If another service has since registered the same protocol,
	-- service_by_protocol points at that service, and clearing the entry here
	-- would break routing for a service that is still registered.
	if protocol_name and service_by_protocol[protocol_name] == service then
		service_by_protocol[protocol_name] = nil;
		available_protocols:filter(function (p) return p ~= protocol_name end);
	end
end);

-- Index the services portmanager already knows about at load time; services
-- announced later are handled by the "service-added" hook.
do
	local registered = portmanager.get_registered_services();
	for _, service_list in pairs(registered) do
		for _, registered_service in ipairs(service_list) do
			add_service(registered_service);
		end
	end
end

-- Data received so far on connections that have not yet been matched to a
-- service, keyed by connection.
local buffers = {};

-- Listener table handed to module:provides() for the multiplex services.
local listener = {};
listener.default_mode = max_buffer_len;

-- Connection-established callback.
--
-- When the underlying socket exposes `getalpn` and the protocol it reports is
-- present in service_by_protocol, the connection is handed to that service's
-- listener straight away: the listener is swapped, the read mode is set, and
-- the new listener's own onconnect (if any) is invoked. In every other case
-- the connection stays on this listener.
--
-- @param conn connection object
-- @return whatever the target listener's onconnect returns, otherwise nothing
function listener.onconnect(conn)
	local sock = conn:socket();
	if not sock.getalpn then
		return;
	end

	local selected_proto = sock:getalpn();
	local target = service_by_protocol[selected_proto];
	if not target then
		return;
	end

	module:log("debug", "Routing incoming connection to %s based on ALPN %q", target.name, selected_proto);

	local handler = target.listener;
	conn:setlistener(handler);
	-- Fall back to the configured read size when the handler does not set one.
	conn:set_mode(handler.default_mode or default_mode);

	local handler_onconnect = handler.onconnect;
	if handler_onconnect then
		return handler_onconnect(conn);
	end
end

-- Incoming-data callback: accumulate data for the connection and test it
-- against every registered multiplex pattern.
--
-- On the first match the connection is handed to the matching service's
-- listener, whose onconnect (if any) is called, and everything buffered so far
-- is replayed through its onincoming. If nothing matches, the data is kept for
-- the next call, unless it has grown beyond max_buffer_len, in which case the
-- connection is closed.
--
-- @param conn connection object
-- @param data newly received data; a false/nil value is ignored
-- @return whatever the target listener's onincoming returns, otherwise nothing
function listener.onincoming(conn, data)
	if not data then return; end

	-- Append to anything already held for this connection.
	local buf = buffers[conn];
	if buf then
		buf = buf .. data;
	else
		buf = data;
	end

	for candidate, pattern in pairs(available_services) do
		if buf:match(pattern) then
			module:log("debug", "Routing incoming connection to %s", candidate.name);
			local handler = candidate.listener;
			conn:setlistener(handler);
			conn:set_mode(handler.default_mode or default_mode);
			local handler_onconnect = handler.onconnect;
			if handler_onconnect then
				handler_onconnect(conn);
			end
			return handler.onincoming(conn, buf);
		end
	end

	-- No service recognised the data yet.
	if #buf > max_buffer_len then
		-- Give up
		conn:close();
		return;
	end
	buffers[conn] = buf;
end

-- Drop whatever has been buffered for `conn`.
local function discard_buffer(conn)
	buffers[conn] = nil; -- warn if no buffer?
end

-- The same cleanup serves both callbacks.
-- NOTE(review): ondetach is presumably invoked when the connection is handed
-- to another listener; confirm against the network backend.
listener.ondisconnect = discard_buffer;
listener.ondetach = discard_buffer;

-- "multiplex" net service: declares no encryption field and uses an empty
-- config prefix.
local multiplex_service = {
	name = "multiplex",
	config_prefix = "",
	listener = listener,
};
module:provides("net", multiplex_service);

-- Returns the array of protocol names collected from registered services.
-- Supplied as a function so the array is looked up at call time rather than
-- when the service is declared.
local function advertised_protocols()
	return available_protocols;
end

-- "multiplex_ssl" net service: same listener, with encryption set to "ssl" and
-- the protocol list exposed through ssl_config.alpn.
module:provides("net", {
	name = "multiplex_ssl",
	config_prefix = "ssl",
	encryption = "ssl",
	ssl_config = { alpn = advertised_protocols },
	listener = listener,
});