-
Notifications
You must be signed in to change notification settings - Fork 197
/
Copy pathextract_entity_attribute_mapper.py
200 lines (173 loc) · 8.18 KB
/
extract_entity_attribute_mapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import re
from typing import Dict, List, Optional
import numpy as np
from loguru import logger
from pydantic import PositiveInt
from data_juicer.ops.base_op import OPERATORS, TAGGING_OPS, Mapper
from data_juicer.utils.constant import Fields, MetaKeys
from data_juicer.utils.model_utils import get_model, prepare_model
OP_NAME = 'extract_entity_attribute_mapper'
# TODO: LLM-based inference.
@TAGGING_OPS.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class ExtractEntityAttributeMapper(Mapper):
"""
Extract attributes for given entities from the text
"""
DEFAULT_SYSTEM_PROMPT_TEMPLATE = (
'给定一段文本,从文本中总结{entity}的{attribute},并且从原文摘录最能说明该{attribute}的代表性示例。\n'
'要求:\n'
'- 摘录的示例应该简短。\n'
'- 遵循如下的回复格式:\n'
'# {entity}\n'
'## {attribute}:\n'
'...\n'
'### 代表性示例摘录1:\n'
'```\n'
'...\n'
'```\n'
'### 代表性示例摘录2:\n'
'```\n'
'...\n'
'```\n'
'...\n')
DEFAULT_INPUT_TEMPLATE = '# 文本\n```\n{text}\n```\n'
DEFAULT_ATTR_PATTERN_TEMPLATE = r'\#\#\s*{attribute}:\s*(.*?)(?=\#\#\#|\Z)'
DEFAULT_DEMON_PATTERN = r'\#\#\#\s*代表性示例摘录(\d+):\s*```\s*(.*?)```\s*(?=\#\#\#|\Z)' # noqa: E501
def __init__(self,
api_model: str = 'gpt-4o',
query_entities: List[str] = [],
query_attributes: List[str] = [],
*,
entity_key: str = MetaKeys.main_entities,
attribute_key: str = MetaKeys.attributes,
attribute_desc_key: str = MetaKeys.attribute_descriptions,
support_text_key: str = MetaKeys.attribute_support_texts,
api_endpoint: Optional[str] = None,
response_path: Optional[str] = None,
system_prompt_template: Optional[str] = None,
input_template: Optional[str] = None,
attr_pattern_template: Optional[str] = None,
demo_pattern: Optional[str] = None,
try_num: PositiveInt = 3,
drop_text: bool = False,
model_params: Dict = {},
sampling_params: Dict = {},
**kwargs):
"""
Initialization method.
:param api_model: API model name.
:param query_entities: Entity list to be queried.
:param query_attributes: Attribute list to be queried.
:param entity_key: The key name in the meta field to store the
given main entity for attribute extraction. It's "entity" in
default.
:param entity_attribute_key: The key name in the meta field to
store the given attribute to be extracted. It's "attribute"
in default.
:param attribute_desc_key: The key name in the meta field to store
the extracted attribute description. It's
"attribute_description" in default.
:param support_text_key: The key name in the meta field to store
the attribute support text extracted from the raw text.
It's "support_text" in default.
:param api_endpoint: URL endpoint for the API.
:param response_path: Path to extract content from the API response.
Defaults to 'choices.0.message.content'.
:param system_prompt_template: System prompt template for the
task. Need to be specified by given entity and attribute.
:param input_template: Template for building the model input.
:param attr_pattern_template: Pattern for parsing the attribute from
output. Need to be specified by given attribute.
:param: demo_pattern: Pattern for parsing the demonstraction from
output to support the attribute.
:param try_num: The number of retry attempts when there is an API
call error or output parsing error.
:param drop_text: If drop the text in the output.
:param model_params: Parameters for initializing the API model.
:param sampling_params: Extra parameters passed to the API call.
e.g {'temperature': 0.9, 'top_p': 0.95}
:param kwargs: Extra keyword arguments.
"""
super().__init__(**kwargs)
self.query_entities = query_entities
self.query_attributes = query_attributes
self.entity_key = entity_key
self.attribute_key = attribute_key
self.attribute_desc_key = attribute_desc_key
self.support_text_key = support_text_key
self.system_prompt_template = system_prompt_template \
or self.DEFAULT_SYSTEM_PROMPT_TEMPLATE
self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
self.attr_pattern_template = attr_pattern_template \
or self.DEFAULT_ATTR_PATTERN_TEMPLATE
self.demo_pattern = demo_pattern or self.DEFAULT_DEMON_PATTERN
self.sampling_params = sampling_params
self.model_key = prepare_model(model_type='api',
model=api_model,
endpoint=api_endpoint,
response_path=response_path,
**model_params)
self.try_num = try_num
self.drop_text = drop_text
def parse_output(self, raw_output, attribute_name):
attribute_pattern = self.attr_pattern_template.format(
attribute=attribute_name)
pattern = re.compile(attribute_pattern, re.VERBOSE | re.DOTALL)
matches = pattern.findall(raw_output)
if matches:
attribute = matches[0].strip()
else:
attribute = ''
pattern = re.compile(self.demo_pattern, re.VERBOSE | re.DOTALL)
matches = pattern.findall(raw_output)
demos = [demo.strip() for _, demo in matches if demo.strip()]
return attribute, demos
def _process_single_text(self, text='', rank=None):
client = get_model(self.model_key, rank=rank)
entities, attributes, descs, demo_lists = [], [], [], []
for entity in self.query_entities:
for attribute in self.query_attributes:
system_prompt = self.system_prompt_template.format(
entity=entity, attribute=attribute)
input_prompt = self.input_template.format(text=text)
messages = [{
'role': 'system',
'content': system_prompt
}, {
'role': 'user',
'content': input_prompt
}]
desc, demos = '', np.array([], dtype=str)
for _ in range(self.try_num):
try:
output = client(messages, **self.sampling_params)
cur_desc, cur_demos = self.parse_output(
output, attribute)
if cur_desc and len(cur_demos) > 0:
desc = cur_desc
demos = cur_demos
break
except Exception as e:
logger.warning(f'Exception: {e}')
entities.append(entity)
attributes.append(attribute)
descs.append(desc)
demo_lists.append(demos)
return entities, attributes, descs, demo_lists
def process_single(self, sample, rank=None):
# check if it's generated already
if set([
self.entity_key, self.attribute_key, self.attribute_desc_key,
self.support_text_key
]) <= set(sample[Fields.meta].keys()):
return sample
res = self._process_single_text(sample[self.text_key], rank=rank)
entities, attributes, descs, demo_lists = res
if self.drop_text:
sample.pop(self.text_key)
sample[Fields.meta][self.entity_key] = entities
sample[Fields.meta][self.attribute_key] = attributes
sample[Fields.meta][self.attribute_desc_key] = descs
sample[Fields.meta][self.support_text_key] = demo_lists
return sample