aboutsummaryrefslogtreecommitdiffstats
path: root/util/set.lua
blob: bb318adf5290a207d41c3c7f0a2144acf4dd2dcd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
local ipairs, pairs, setmetatable, next, tostring = 
      ipairs, pairs, setmetatable, next, tostring;
local t_concat = table.concat;

-- Declare this chunk as the "set" module using Lua 5.1's module(): the module
-- table becomes the environment of the rest of the chunk, so every global
-- function defined below (new, union, difference, intersection) ends up as a
-- field of that table. No package.seeall is passed, so standard-library
-- globals are not visible below this line; the locals declared above capture
-- the ones this file needs.
-- NOTE(review): module() is deprecated from Lua 5.2 on and is unavailable
-- without compatibility options.
module "set"

-- Metatable given to every set built by new(): generic-for iteration
-- (__call), the +, -, / and == operators, and tostring().
local set_mt = {};
--- Iterator hook so a set can drive a generic `for` directly
-- (`for item in s do ... end`); items come out in table-traversal order.
-- @param self the set being iterated
-- @param _ invariant-state slot from the generic for (unused)
-- @param prev the item returned by the previous step, nil on the first step
-- @return the next item and its stored value, or nil when exhausted
set_mt.__call = function(self, _, prev)
	local members = self._items;
	return next(members, prev);
end
--- `a + b`: delegates to the module's union(), producing a new set.
set_mt.__add = function(left, right)
	local combine = _M.union;
	return combine(left, right);
end
--- `a - b`: delegates to the module's difference(), producing a new set.
set_mt.__sub = function(minuend, subtrahend)
	local subtract = _M.difference;
	return subtract(minuend, subtrahend);
end
--- `s / predicate`: filter a set.
-- Builds a new set holding only the items of `s` for which `predicate(item)`
-- returns a truthy value; `s` itself is left unchanged.
-- @param set the set to filter
-- @param func predicate, called once per item
-- @return a new set
function set_mt.__div(set, func)
	-- new() returns a single value; take its item table explicitly rather
	-- than declaring (and then shadowing) a second, always-nil binding.
	local new_set = _M.new();
	local new_items = new_set._items;
	for item in pairs(set._items) do
		if func(item) then
			new_items[item] = true;
		end
	end
	return new_set;
end
--- `a == b`: two sets are equal when each one contains every item of the
-- other.
-- @return true if both sets hold the same items, false otherwise
function set_mt.__eq(set1, set2)
	local left, right = set1._items, set2._items;

	-- True when every key of `inner` is present (truthy) in `outer`.
	local function covers(outer, inner)
		for item in pairs(inner) do
			if not outer[item] then
				return false;
			end
		end
		return true;
	end

	-- Check left-in-right first, then right-in-left, as before.
	return covers(right, left) and covers(left, right);
end
--- tostring(s): each item converted with tostring() and joined by ", ".
-- Item order follows table traversal, so it is unspecified.
function set_mt.__tostring(set)
	local parts, count = {}, 0;
	for item in pairs(set._items) do
		count = count + 1;
		parts[count] = tostring(item);
	end
	return t_concat(parts, ", ");
end

-- Metatable for a set's internal item table: calling the table steps through
-- its keys, so `for item in s:items() do ... end` works.
local items_mt = {
	__call = function(tbl, _, prev)
		return next(tbl, prev);
	end;
};

--- Creates a new set.
-- Each set carries its own method closures over a private item table, which
-- is also exposed as the `_items` field (items are its keys, each mapped to
-- true).
-- @param list optional array of items to add initially (read with ipairs)
-- @return the new set, with set_mt as its metatable
function new(list)
	local items = setmetatable({}, items_mt);
	local set = { _items = items };

	--- Adds `item` to the set.
	function set:add(item)
		items[item] = true;
	end

	--- Returns true if `item` is in the set, nil otherwise.
	function set:contains(item)
		return items[item];
	end

	--- Returns the live internal item table (callable as an iterator).
	function set:items()
		return items;
	end

	--- Removes `item` from the set; does nothing if it is absent.
	function set:remove(item)
		items[item] = nil;
	end

	--- Adds each element of the array `list` (ipairs: stops at the first nil).
	function set:add_list(list)
		for _, item in ipairs(list) do
			items[item] = true;
		end
	end

	-- The members of another set live in its `_items` table. Iterating the
	-- set object itself would visit its fields ("_items", "add", ...) instead
	-- of its items. A plain table without `_items` is still accepted and its
	-- keys are used, as before.
	local function members_of(otherset)
		return otherset._items or otherset;
	end

	--- Adds every item of `otherset` to this set.
	function set:include(otherset)
		for item in pairs(members_of(otherset)) do
			items[item] = true;
		end
	end

	--- Removes every item of `otherset` from this set.
	function set:exclude(otherset)
		for item in pairs(members_of(otherset)) do
			items[item] = nil;
		end
	end

	--- Returns true if the set has no items, false otherwise.
	function set:empty()
		return not next(items);
	end

	if list then
		set:add_list(list);
	end

	return setmetatable(set, set_mt);
end

--- Returns a new set holding every item that appears in `set1` or `set2`.
-- Neither argument is modified.
function union(set1, set2)
	local result = new();
	local merged = result._items;

	-- Copy every key of `source` into the result.
	local function absorb(source)
		for item in pairs(source) do
			merged[item] = true;
		end
	end

	absorb(set1._items);
	absorb(set2._items);
	return result;
end

--- Returns a new set holding the items of `set1` that are not in `set2`.
-- Neither argument is modified.
function difference(set1, set2)
	local result = new();
	local kept = result._items;
	for item in pairs(set1._items) do
		-- set2._items is looked up per item, so an empty set1 never touches set2.
		if not set2._items[item] then
			kept[item] = true;
		end
	end
	return result;
end

--- Returns a new set holding the items present in both `set1` and `set2`.
-- Neither argument is modified.
function intersection(set1, set2)
	local result = new();
	local common = result._items;
	local left, right = set1._items, set2._items;
	for item in pairs(left) do
		if right[item] then
			common[item] = true;
		end
	end
	return result;
end

-- Return the module table (the _M field set up by module() above) as the
-- chunk's result, so require("set") yields it explicitly.
return _M;