diff options
-rw-r--r-- | plugins/mod_external_services.lua | 44 |
1 files changed, 21 insertions, 23 deletions
diff --git a/plugins/mod_external_services.lua b/plugins/mod_external_services.lua index 6fafdb1f..871c275f 100644 --- a/plugins/mod_external_services.lua +++ b/plugins/mod_external_services.lua @@ -5,6 +5,7 @@ local hashes = require "util.hashes"; local st = require "util.stanza"; local jid = require "util.jid"; local array = require "util.array"; +local set = require "util.set"; local default_host = module:get_option_string("external_service_host", module.host); local default_port = module:get_option_number("external_service_port"); @@ -186,22 +187,18 @@ local function handle_credentials(event) return item.restricted; end) - local requested_credentials = {}; + local requested_credentials = set.new(); for service in action:childtags("service") do if not service.attr.type or not service.attr.host then origin.send(st.error_reply(stanza, "modify", "bad-request")); return true; end - table.insert(requested_credentials, { - type = service.attr.type; - host = service.attr.host; - port = tonumber(service.attr.port); - }); + requested_credentials:add(string.format("%s:%s:%d", service.attr.type, service.attr.host, + tonumber(service.attr.port) or 0)); end setmetatable(services, services_mt); - setmetatable(requested_credentials, services_mt); module:fire_event("external_service/credentials", { origin = origin; @@ -211,22 +208,23 @@ local function handle_credentials(event) services = services; }); - for req_srv in action:childtags("service") do - for _, srv in ipairs(services) do - if srv.type == req_srv.attr.type and srv.host == req_srv.attr.host - and not req_srv.attr.port or srv.port == tonumber(req_srv.attr.port) then - reply:tag("service", { - type = srv.type; - transport = srv.transport; - host = srv.host; - port = srv.port and string.format("%d", srv.port) or nil; - username = srv.username; - password = srv.password; - expires = srv.expires and dt.datetime(srv.expires) or nil; - restricted = srv.restricted and "1" or nil; - }):up(); - end - end + services:filter(function (srv) + local port_key = string.format("%s:%s:%d", srv.type, srv.host, srv.port or 0); + local portless_key = string.format("%s:%s:%d", srv.type, srv.host, 0); + return requested_credentials:contains(port_key) or requested_credentials:contains(portless_key); + end); + + for _, srv in ipairs(services) do + reply:tag("service", { + type = srv.type; + transport = srv.transport; + host = srv.host; + port = srv.port and string.format("%d", srv.port) or nil; + username = srv.username; + password = srv.password; + expires = srv.expires and dt.datetime(srv.expires) or nil; + restricted = srv.restricted and "1" or nil; + }):up(); end origin.send(reply); |