-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDisjointSet.pyx
65 lines (49 loc) · 1.67 KB
/
DisjointSet.pyx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import itertools
import numpy as np
cimport numpy as np
def grouper(iterable, n, fillvalue=None):
    """Collect items of *iterable* into fixed-length tuples of size *n*.

    The final tuple is padded with *fillvalue* when the input runs out.
    Adapted from the ``grouper`` recipe in the Python itertools
    documentation: https://docs.python.org/3/library/itertools.html

    Example: grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
    """
    shared_iterator = iter(iterable)
    # The *same* iterator appears n times, so zip_longest draws n consecutive
    # items from it for every tuple it produces.
    repeated = [shared_iterator] * n
    return itertools.zip_longest(*repeated, fillvalue=fillvalue)
cdef class DisjointSets:
    """Union-find (disjoint-set forest) over the element indices ``0 .. size-1``.

    ``sets[i]`` holds the parent index of element ``i``.  A negative value
    marks ``i`` as the representative (root) of its set; the magnitude of that
    value is the number of elements in the set.  ``union`` attaches the
    smaller set under the larger one (union by size) and ``find`` compresses
    the paths it walks.
    """
    # Parent/size array; readable from Python but not reassignable from Python.
    cdef readonly np.ndarray sets

    # Special methods must be declared with ``def`` in Cython.
    def __cinit__(self, int size):
        # Every element starts as a singleton root of size 1 (stored as -1).
        self.sets = np.full(size, -1, dtype=np.int32)

    cpdef reset(self):
        """Make every element a singleton set again, keeping the same size.

        A new array is allocated, so references previously obtained from
        ``sets`` are left unchanged.
        """
        self.sets = np.full(len(self.sets), -1, dtype=np.int32)

    cpdef union_group(self, elems):
        """Merge every element of *elems* into one set.

        *elems* may be any finite iterable of element indices; ``None``
        entries are ignored.  Elements are merged pairwise, then the first
        member of each pair is merged recursively, so the recursion depth is
        logarithmic in the number of elements.
        """
        # Materialise so generators work too (len() and slicing are needed).
        elems = list(elems)
        # "<= 1" rather than "== 1": an empty input must terminate instead of
        # recursing forever on ``[][::2]``.
        if len(elems) <= 1:
            return
        # An odd trailing element is paired with None, which union() skips;
        # it still survives into elems[::2] for the next round.
        for i, j in grouper(elems, 2, None):
            self.union(i, j)
        self.union_group(elems[::2])

    # Explicit ``object`` return type: a bare ``cpdef union(...)`` would be
    # parsed as a C ``union`` declaration.  Returning ``object`` also lets
    # exceptions (e.g. IndexError) propagate, which a ``void`` cdef function
    # may not do, depending on the Cython version.
    cpdef object union(self, a, b):
        """Merge the sets containing *a* and *b*; no-op if either is ``None``."""
        if a is None or b is None:
            return None
        rep_a = self.find(a)
        rep_b = self.find(b)
        if rep_a == rep_b:
            return None
        # Sizes are stored negated, so "<=" means rep_a's set is at least as
        # large: hang the smaller tree under the larger root.
        if self.sets[rep_a] <= self.sets[rep_b]:
            self.sets[rep_a] += self.sets[rep_b]
            self.sets[rep_b] = rep_a
        else:
            self.sets[rep_b] += self.sets[rep_a]
            self.sets[rep_a] = rep_b
        return None

    cpdef find(self, elem):
        """Return the representative (root) of the set containing *elem*.

        Returns *elem* itself when it is already a root; otherwise returns
        the root index read from ``sets``.  Every element on the walked path
        is re-pointed directly at the root (path compression).
        """
        root = elem
        while self.sets[root] >= 0:
            root = self.sets[root]
        # Second pass: point each node on the path straight at the root.
        node = elem
        while node != root:
            parent = self.sets[node]
            self.sets[node] = root
            node = parent
        return root

    def __setitem__(self, key, val):
        """Overwrite ``sets[key]`` directly; union-find invariants are not checked."""
        self.sets[key] = val

    def __len__(self):
        """Return the number of elements (not the number of sets)."""
        return len(self.sets)

    cpdef copy(self):
        """Return an independent copy whose ``sets`` array is not shared."""
        # Typed so the ``readonly`` C attribute can be assigned directly.
        cdef DisjointSets ret = DisjointSets(len(self.sets))
        ret.sets = np.copy(self.sets)
        return ret