-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbktree.py
121 lines (107 loc) · 3.46 KB
/
bktree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
class BKTree:
    """Burkhard-Keller tree for distance-bounded (approximate) lookups.

    Every node is a ``(word, children)`` tuple, where ``children`` is a dict
    mapping a distance ``k`` to the child node whose word is at distance
    exactly ``k`` from ``word``.  ``self.tree`` holds the root node, or
    ``None`` when the tree was built from an empty iterable.

    Arguments:
        distfn: a binary function that returns the distance between two
            words as a non-negative integer.  It must be a metric (in
            particular it must satisfy the triangle inequality), otherwise
            `query` can silently miss matches.
        words: an iterable producing values that can be passed to
            `distfn`.  Any iterable works (list, tuple, generator, file).
    """

    def __init__(self, distfn, words):
        self.distfn = distfn
        self.tree = None
        # Iterate rather than index so one-shot iterables are accepted and
        # an empty input yields an empty tree instead of an IndexError.
        for word in words:
            self.add(word)

    def add(self, word):
        """Insert `word` into the tree (a no-op if it is already present)."""
        if self.tree is None:
            self.tree = (word, {})
        else:
            self._add_word(self.tree, word)

    def _add_word(self, parent, word):
        """Insert `word` into the subtree rooted at node `parent`.

        Walks down iteratively (so tree depth is not bounded by the
        recursion limit): at each node the child keyed by
        ``distfn(word, node_word)`` is followed, and a new leaf is created
        where that key is missing.  Exact duplicates are ignored so they do
        not pile up as a chain of distance-0 children.
        """
        node = parent
        while True:
            pword, children = node
            d = self.distfn(word, pword)
            if d == 0 and word == pword:
                return
            child = children.get(d)
            if child is None:
                children[d] = (word, {})
                return
            node = child

    def query(self, word, n):
        """Return all words in the tree within distance `n` of `word`.

        Arguments:
            word: a word to query on.
            n: a non-negative integer that specifies the allowed distance
                from the query word.

        Return value is a list of tuples (distance, word), sorted in
        ascending order of distance (ties broken by word).
        """
        if self.tree is None:
            return []
        results = []
        pending = [self.tree]
        while pending:
            pword, children = pending.pop()
            d = self.distfn(word, pword)
            if d <= n:
                results.append((d, pword))
            # Triangle inequality: a word stored under key k can be within
            # n of `word` only if |k - d| <= n, so prune all other subtrees.
            lo, hi = d - n, d + n
            pending.extend(
                child for k, child in children.items() if lo <= k <= hi
            )
        results.sort()
        return results
def brute_query(word, words, distfn, n):
    """Linear-scan reference implementation of a distance query.

    Arguments:
        word: the word to query for.
        words: an iterable that produces the candidate words to test.
        distfn: a binary function called as ``distfn(candidate, word)``
            that returns the distance between a candidate and `word`.
        n: an integer; candidates at distance <= n match.

    Returns the matching candidates (words only, no distances) in the
    order `words` produced them.
    """
    matches = []
    for candidate in words:
        if distfn(candidate, word) <= n:
            matches.append(candidate)
    return matches
def maxdepth(tree, c=0):
    """Return the depth of the deepest node in a ``(word, children)`` tree.

    Arguments:
        tree: a node tuple ``(word, children)`` where `children` is a dict
            whose values are nodes of the same shape.
        c: the depth assigned to `tree` itself (0 for a root).

    A leaf node therefore yields `c`; each level below adds one.
    """
    _, kids = tree
    return max((maxdepth(kid, c + 1) for kid in kids.values()), default=c)
def levenshtein(s, t):
    """Return the Levenshtein (edit) distance between sequences `s` and `t`.

    This is the minimum number of single-element insertions, deletions and
    substitutions (each costing 1) needed to turn `s` into `t`.

    Uses the Wagner-Fischer dynamic programme, but keeps only the previous
    row of the table, so memory is O(len(t)) instead of O(len(s) * len(t)).
    """
    # prev[j] = distance between the first i elements of s and first j of t.
    prev = list(range(len(t) + 1))
    for i, a in enumerate(s):
        curr = [i + 1]
        for j, b in enumerate(t):
            cost = 0 if a == b else 1
            curr.append(min(prev[j + 1] + 1,   # deletion
                            curr[j] + 1,       # insertion
                            prev[j] + cost))   # substitution
        prev = curr
    return prev[-1]
def list_words(dictfile, encoding="utf-8"):
    """Read a word list from `dictfile`, one word per line.

    Arguments:
        dictfile: path of the text file to read.
        encoding: text encoding of the file (default UTF-8).

    Leading/trailing whitespace is stripped from every line, and blank or
    whitespace-only lines are skipped.  Kept, they would become empty-string
    "words" that match any query of length <= the allowed distance.
    Returns the words as a list, in file order.
    """
    with open(dictfile, "r", encoding=encoding) as f:
        # Iterate the file lazily instead of materializing readlines().
        return [word for word in (line.strip() for line in f) if word]
def timeof(fn, *args):
    """Call ``fn(*args)``, print the elapsed time in seconds, return the result.

    Uses ``time.perf_counter`` (monotonic, highest available resolution)
    rather than ``time.time``, which is a wall clock that can jump when the
    system clock is adjusted and so can report wrong or negative durations.
    """
    import time
    start = time.perf_counter()
    res = fn(*args)
    print("time: ", (time.perf_counter() - start))
    return res
if __name__ == "__main__":
    # Demo: index the vocabulary in a BK-tree keyed on edit distance and
    # print, for each sample query, the set of (distance, word) hits within
    # distance 1.
    tree = BKTree(levenshtein, list_words('vocab.txt'))
    max_dist = 1
    probes = ("mil", "1mil", "1stqth", "houu", '1cle', '11:11')
    for probe in probes:
        hits = set(tree.query(probe, max_dist))
        print(hits)