aboutsummaryrefslogtreecommitdiffstats
path: root/util/set.lua
diff options
context:
space:
mode:
Diffstat (limited to 'util/set.lua')
-rw-r--r--util/set.lua101
1 files changed, 54 insertions, 47 deletions
diff --git a/util/set.lua b/util/set.lua
index 4be39c17..c136a522 100644
--- a/util/set.lua
+++ b/util/set.lua
@@ -10,59 +10,19 @@ local ipairs, pairs, setmetatable, next, tostring =
ipairs, pairs, setmetatable, next, tostring;
local t_concat = table.concat;
-module "set"
+local _ENV = nil;
local set_mt = {};
function set_mt.__call(set, _, k)
return next(set._items, k);
end
-function set_mt.__add(set1, set2)
- return _M.union(set1, set2);
-end
-function set_mt.__sub(set1, set2)
- return _M.difference(set1, set2);
-end
-function set_mt.__div(set, func)
- local new_set = _M.new();
- local items, new_items = set._items, new_set._items;
- for item in pairs(items) do
- local new_item = func(item);
- if new_item ~= nil then
- new_items[new_item] = true;
- end
- end
- return new_set;
-end
-function set_mt.__eq(set1, set2)
- set1, set2 = set1._items, set2._items;
- for item in pairs(set1) do
- if not set2[item] then
- return false;
- end
- end
-
- for item in pairs(set2) do
- if not set1[item] then
- return false;
- end
- end
-
- return true;
-end
-function set_mt.__tostring(set)
- local s, items = { }, set._items;
- for item in pairs(items) do
- s[#s+1] = tostring(item);
- end
- return t_concat(s, ", ");
-end
local items_mt = {};
function items_mt.__call(items, _, k)
return next(items, k);
end
-function new(list)
+local function new(list)
local items = setmetatable({}, items_mt);
local set = { _items = items };
@@ -116,7 +76,7 @@ function new(list)
return setmetatable(set, set_mt);
end
-function union(set1, set2)
+local function union(set1, set2)
local set = new();
local items = set._items;
@@ -131,7 +91,7 @@ function union(set1, set2)
return set;
end
-function difference(set1, set2)
+local function difference(set1, set2)
local set = new();
local items = set._items;
@@ -142,7 +102,7 @@ function difference(set1, set2)
return set;
end
-function intersection(set1, set2)
+local function intersection(set1, set2)
local set = new();
local items = set._items;
@@ -155,8 +115,55 @@ function intersection(set1, set2)
return set;
end
-function xor(set1, set2)
+local function xor(set1, set2)
return union(set1, set2) - intersection(set1, set2);
end
-return _M;
+function set_mt.__add(set1, set2)
+ return union(set1, set2);
+end
+function set_mt.__sub(set1, set2)
+ return difference(set1, set2);
+end
+function set_mt.__div(set, func)
+ local new_set = new();
+ local items, new_items = set._items, new_set._items;
+ for item in pairs(items) do
+ local new_item = func(item);
+ if new_item ~= nil then
+ new_items[new_item] = true;
+ end
+ end
+ return new_set;
+end
+function set_mt.__eq(set1, set2)
+ set1, set2 = set1._items, set2._items;
+ for item in pairs(set1) do
+ if not set2[item] then
+ return false;
+ end
+ end
+
+ for item in pairs(set2) do
+ if not set1[item] then
+ return false;
+ end
+ end
+
+ return true;
+end
+function set_mt.__tostring(set)
+ local s, items = { }, set._items;
+ for item in pairs(items) do
+ s[#s+1] = tostring(item);
+ end
+ return t_concat(s, ", ");
+end
+
+return {
+ new = new;
+ union = union;
+ difference = difference;
+ intersection = intersection;
+ xor = xor;
+};