-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathminbpe.py
325 lines (291 loc) · 13.6 KB
/
minbpe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
"""
以下代码复制于 https://github.com/karpathy/minbpe
我简化了以下,放在一个文件里方便演示。
Minimal (byte-level) Byte Pair Encoding tokenizer.
Algorithmically follows along the GPT tokenizer:
https://github.com/openai/gpt-2/blob/master/src/encoder.py
Unlike BasicTokenizer:
- RegexTokenizer handles an optional regex splitting pattern.
- RegexTokenizer handles optional special tokens.
"""
import regex as re
import unicodedata
# -----------------------------------------------------------------------------
# a few helper functions useful for both BasicTokenizer and RegexTokenizer
def get_stats(ids, counts=None):
"""
Given a list of integers, return a dictionary of counts of consecutive pairs
Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
Optionally allows to update an existing dictionary of counts
"""
counts = {} if counts is None else counts
for pair in zip(ids, ids[1:]): # iterate consecutive elements
counts[pair] = counts.get(pair, 0) + 1
return counts
def merge(ids, pair, idx):
"""
In the list of integers (ids), replace all consecutive occurrences
of pair with the new integer token idx
Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
"""
newids = []
i = 0
while i < len(ids):
# if not at the very last position AND the pair matches, replace it
if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
newids.append(idx)
i += 2
else:
newids.append(ids[i])
i += 1
return newids
# first two helper functions...
def replace_control_characters(s: str) -> str:
# we don't want to print control characters
# which distort the output (e.g. \n or much worse)
# https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117
# http://www.unicode.org/reports/tr44/#GC_Values_Table
chars = []
for ch in s:
if unicodedata.category(ch)[0] != "C":
chars.append(ch) # this character is ok
else:
chars.append(f"\\u{ord(ch):04x}") # escape
return "".join(chars)
def render_token(t: bytes) -> str:
# pretty print a token, escaping control characters
s = t.decode('utf-8', errors='replace')
s = replace_control_characters(s)
return s
# -----------------------------------------------------------------------------
# the base Tokenizer class
class Tokenizer:
"""Base class for Tokenizers"""
def __init__(self):
# default: vocab size of 256 (all bytes), no merges, no patterns
self.merges = {} # (int, int) -> int
self.pattern = "" # str
self.special_tokens = {} # str -> int, e.g. {'<|endoftext|>': 100257}
self.vocab = self._build_vocab() # int -> bytes
def train(self, text, vocab_size, verbose=False):
# Tokenizer can train a vocabulary of size vocab_size from text
raise NotImplementedError
def encode(self, text):
# Tokenizer can encode a string into a list of integers
# 定义在RegexTokenizer类中
raise NotImplementedError
def decode(self, ids):
# Tokenizer can decode a list of integers into a string
# 定义在RegexTokenizer类中
raise NotImplementedError
def _build_vocab(self):
# vocab is simply and deterministically derived from merges
vocab = {idx: bytes([idx]) for idx in range(256)}
for (p0, p1), idx in self.merges.items():
vocab[idx] = vocab[p0] + vocab[p1]
for special, idx in self.special_tokens.items():
vocab[idx] = special.encode("utf-8")
return vocab
def save(self, file_prefix):
"""
Saves two files: file_prefix.vocab and file_prefix.model
This is inspired (but not equivalent to!) sentencepiece's model saving:
- model file is the critical one, intended for load()
- vocab file is just a pretty printed version for human inspection only
"""
# write the model: to be used in load() later
model_file = file_prefix + ".model"
with open(model_file, 'w') as f:
# write the version, pattern and merges, that's all that's needed
f.write("minbpe v1\n")
f.write(f"{self.pattern}\n")
# write the special tokens, first the number of them, then each one
f.write(f"{len(self.special_tokens)}\n")
for special, idx in self.special_tokens.items():
f.write(f"{special} {idx}\n")
# the merges dict
for idx1, idx2 in self.merges:
f.write(f"{idx1} {idx2}\n")
# write the vocab: for the human to look at
vocab_file = file_prefix + ".vocab"
inverted_merges = {idx: pair for pair, idx in self.merges.items()}
with open(vocab_file, "w", encoding="utf-8") as f:
for idx, token in self.vocab.items():
# note: many tokens may be partial utf-8 sequences
# and cannot be decoded into valid strings. Here we're using
# errors='replace' to replace them with the replacement char �.
# this also means that we couldn't possibly use .vocab in load()
# because decoding in this way is a lossy operation!
s = render_token(token)
# find the children of this token, if any
if idx in inverted_merges:
# if this token has children, render it nicely as a merge
idx0, idx1 = inverted_merges[idx]
s0 = render_token(self.vocab[idx0])
s1 = render_token(self.vocab[idx1])
f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
else:
# otherwise this is leaf token, just print it
# (this should just be the first 256 tokens, the bytes)
f.write(f"[{s}] {idx}\n")
def load(self, model_file):
"""Inverse of save() but only for the model file"""
assert model_file.endswith(".model")
# read the model file
merges = {}
special_tokens = {}
idx = 256
with open(model_file, 'r', encoding="utf-8") as f:
# read the version
version = f.readline().strip()
assert version == "minbpe v1"
# read the pattern
self.pattern = f.readline().strip()
# read the special tokens
num_special = int(f.readline().strip())
for _ in range(num_special):
special, special_idx = f.readline().strip().split()
special_tokens[special] = int(special_idx)
# read the merges
for line in f:
idx1, idx2 = map(int, line.split())
merges[(idx1, idx2)] = idx
idx += 1
self.merges = merges
self.special_tokens = special_tokens
self.vocab = self._build_vocab()
# the main GPT text split patterns, see
# https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
class BPETokenizer(Tokenizer):
def __init__(self, pattern=None):
"""
- pattern: optional string to override the default (GPT-4 split pattern)
- special_tokens: str -> int dictionary of special tokens
example: {'<|endoftext|>': 100257}
"""
super().__init__()
self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern
self.compiled_pattern = re.compile(self.pattern)
self.special_tokens = {}
self.inverse_special_tokens = {}
def train(self, text, vocab_size, verbose=False):
assert vocab_size >= 256
num_merges = vocab_size - 256
# split the text up into text chunks
text_chunks = re.findall(self.compiled_pattern, text)
# input text preprocessing
ids = [list(ch.encode("utf-8")) for ch in text_chunks]
# iteratively merge the most common pairs to create new tokens
merges = {} # (int, int) -> int
vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
for i in range(num_merges):
# count the number of times every consecutive pair appears
stats = {}
for chunk_ids in ids:
# passing in stats will update it in place, adding up counts
get_stats(chunk_ids, stats)
# find the pair with the highest count
pair = max(stats, key=stats.get)
# mint a new token: assign it the next available id
idx = 256 + i
# replace all occurrences of pair in ids with idx
ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]
# save the merge
merges[pair] = idx
vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
# prints
if verbose:
print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
# save class variables
self.merges = merges # used in encode()
self.vocab = vocab # used in decode()
def register_special_tokens(self, special_tokens):
# special_tokens is a dictionary of str -> int
# example: {"<|endoftext|>": 100257}
self.special_tokens = special_tokens
self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}
def decode(self, ids):
# given ids (list of integers), return Python string
part_bytes = []
for idx in ids:
if idx in self.vocab:
part_bytes.append(self.vocab[idx])
elif idx in self.inverse_special_tokens:
part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8"))
else:
raise ValueError(f"invalid token id: {idx}")
text_bytes = b"".join(part_bytes)
text = text_bytes.decode("utf-8", errors="replace")
return text
def _encode_chunk(self, text_bytes):
# return the token ids
# let's begin. first, convert all bytes to integers in range 0..255
ids = list(text_bytes)
while len(ids) >= 2:
# find the pair with the lowest merge index
stats = get_stats(ids)
pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
# subtle: if there are no more merges available, the key will
# result in an inf for every single pair, and the min will be
# just the first pair in the list, arbitrarily
# we can detect this terminating case by a membership check
if pair not in self.merges:
break # nothing else can be merged anymore
# otherwise let's merge the best pair (lowest merge index)
idx = self.merges[pair]
ids = merge(ids, pair, idx)
return ids
def encode_ordinary(self, text):
"""Encoding that ignores any special tokens."""
# split text into chunks of text by categories defined in regex pattern
text_chunks = re.findall(self.compiled_pattern, text)
# all chunks of text are encoded separately, then results are joined
ids = []
for chunk in text_chunks:
chunk_bytes = chunk.encode("utf-8") # raw bytes
chunk_ids = self._encode_chunk(chunk_bytes)
ids.extend(chunk_ids)
return ids
def encode(self, text, allowed_special="none_raise"):
"""
Unlike encode_ordinary, this function handles special tokens.
allowed_special: can be "all"|"none"|"none_raise" or a custom set of special tokens
if none_raise, then an error is raised if any special token is encountered in text
this is the default tiktoken behavior right now as well
any other behavior is either annoying, or a major footgun
"""
# decode the user desire w.r.t. handling of special tokens
special = None
if allowed_special == "all":
special = self.special_tokens
elif allowed_special == "none":
special = {}
elif allowed_special == "none_raise":
special = {}
assert all(token not in text for token in self.special_tokens)
elif isinstance(allowed_special, set):
special = {k: v for k, v in self.special_tokens.items() if k in allowed_special}
else:
raise ValueError(f"allowed_special={allowed_special} not understood")
if not special:
# shortcut: if no special tokens, just use the ordinary encoding
return self.encode_ordinary(text)
# otherwise, we have to be careful with potential special tokens in text
# we handle special tokens by splitting the text
# based on the occurrence of any exact match with any of the special tokens
# we can use re.split for this. note that surrounding the pattern with ()
# makes it into a capturing group, so the special tokens will be included
special_pattern = "(" + "|".join(re.escape(k) for k in special) + ")"
special_chunks = re.split(special_pattern, text)
# now all the special characters are separated from the rest of the text
# all chunks of text are encoded separately, then results are joined
ids = []
for part in special_chunks:
if part in special:
# this is a special token, encode it separately as a special case
ids.append(special[part])
else:
# this is an ordinary sequence, encode it normally
ids.extend(self.encode_ordinary(part))
return ids