-
Notifications
You must be signed in to change notification settings - Fork 197
/
Copy pathremove_repeat_sentences_mapper.py
73 lines (60 loc) · 2.78 KB
/
remove_repeat_sentences_mapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import regex as re
from ..base_op import OPERATORS, Mapper
def split_sentence(text):
    """
    Split text into sentences at sentence-ending punctuation.

    A boundary is inserted after a terminator (ASCII ``.``, ``!``, ``?``,
    the CJK full stop, or the fullwidth exclamation/question mark), after
    six ASCII dots, and after a doubled ellipsis character, provided the
    next character is not a closing quote. When a terminator is directly
    followed by a closing quote, the quote stays with the sentence and the
    boundary is placed after the quote instead.

    :param text: the text to split; it is expected to be a single line,
        because the newline character is used as the internal separator
    :return: list of sentence strings, punctuation and surrounding
        whitespace preserved; an empty input yields ``['']``
    """
    # Fullwidth "!" and "?" are written as \uff01 / \uff1f so they cannot be
    # silently folded into their ASCII forms by width normalization.
    text = re.sub(r'([.。!?\uff01\uff1f])([^’”])', r'\1\n\2', text)
    text = re.sub(r'(\.{6})([^’”])', r'\1\n\2', text)
    text = re.sub(r'(…{2})([^’”])', r'\1\n\2', text)
    # Quantifiers have no meaning inside a character class, so only the
    # single terminator characters are listed here; otherwise "{", "}",
    # "6" and "2" would be treated as sentence terminators before a quote.
    text = re.sub(r'([.。!?\uff01\uff1f…][’”])([^’”])', r'\1\n\2', text)
    return text.split('\n')
@OPERATORS.register_module('remove_repeat_sentences_mapper')
class RemoveRepeatSentencesMapper(Mapper):
    """Mapper that drops sentences already seen earlier in the same text
    sample, keeping the first occurrence of each."""

    _batched_op = True

    def __init__(self,
                 lowercase: bool = False,
                 ignore_special_character: bool = True,
                 min_repeat_sentence_length: int = 2,
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param lowercase: Whether to lower-case a sentence before comparing
            it with the sentences seen so far. The emitted text keeps its
            original casing.
        :param ignore_special_character: Whether to strip special
            characters before comparing sentences. Special characters are
            everything other than ASCII letters, digits, characters in the
            range U+4E00-U+9FA5, newline, tab and space.
        :param min_repeat_sentence_length: Sentences whose comparison form
            is shorter than this length are always kept. If
            ignore_special_character is True, special characters do not
            count towards this length.
        :param args: extra args
        :param kwargs: extra args
        """
        super().__init__(*args, **kwargs)
        self.lowercase = lowercase
        self.min_repeat_sentence_length = min_repeat_sentence_length
        if ignore_special_character:
            self.remove_regex = re.compile(
                r'[^a-zA-Z0-9\u4e00-\u9fa5\n\t ]')
        else:
            self.remove_regex = None

    def process_batched(self, samples):
        """
        Remove repeated sentences from every text in the batch.

        Each text is handled independently: it is split on newlines, each
        non-empty line is split into sentences, and a sentence is dropped
        when its comparison form has already been seen in that text.
        Empty lines are preserved so the line structure is unchanged.

        :param samples: batch mapping; the list stored under
            ``self.text_key`` is rewritten in place
            (NOTE(review): ``text_key`` is not set in this class,
            presumably provided by the base class - confirm)
        :return: the same ``samples`` object
        """
        for idx, text in enumerate(samples[self.text_key]):
            # Comparison forms already emitted for this text.
            seen = set()
            rebuilt_lines = []
            for line in text.split('\n'):
                kept = []
                if line:
                    for sentence in split_sentence(line):
                        key = sentence.strip()
                        if self.lowercase:
                            key = key.lower()
                        if self.remove_regex:
                            key = self.remove_regex.sub('', key)
                        if len(key) >= self.min_repeat_sentence_length:
                            if key in seen:
                                # Repeat of an earlier sentence: drop it.
                                continue
                            seen.add(key)
                        # Too-short sentences are kept without being
                        # recorded, so they are never deduplicated.
                        kept.append(sentence)
                rebuilt_lines.append(''.join(kept))
            samples[self.text_key][idx] = '\n'.join(rebuilt_lines)
        return samples