-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathranking_agent.py
117 lines (105 loc) · 4.7 KB
/
ranking_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import logging
from typing import Dict, List, Optional, Union
from erniebot_agent.chat_models.erniebot import BaseERNIEBot
from erniebot_agent.memory import HumanMessage, Message, SystemMessage
from erniebot_agent.prompt import PromptTemplate
from tools.utils import JsonUtil, ReportCallbackHandler
# Module-level logger for the ranking agent.
logger = logging.getLogger(__name__)
# Retry budget for RankingAgent.check_format: it gives up and raises once more
# than MAX_RETRY attempts have failed.
MAX_RETRY = 10
# Prompt-size threshold for choosing between the regular LLM and the
# long-context LLM in RankingAgent.check_format.
# NOTE(review): despite the name, it is compared against len(prompt), i.e. a
# character count, not a token count.
TOKEN_MAX_LENGTH = 4200
def get_markdown_check_prompt(report):
    """Render the prompt asking an LLM whether ``report`` is in Markdown format.

    The prompt (in Chinese) asks the model to explain its judgement and to
    answer with a JSON object holding a ``"reason"`` key and an ``"accept"``
    key, where ``"accept"`` is true for Markdown and false otherwise.

    Args:
        report: the report text substituted for the ``{{report}}`` placeholder
            (presumably a Jinja-style placeholder filled by
            ``PromptTemplate.format`` — confirm against erniebot_agent).

    Returns:
        Whatever ``PromptTemplate.format`` produces; callers in this file use
        it directly as the content of a chat message.
    """
    # Template lines stay at column 0 so the prompt text is exactly as authored.
    template_text = """
现在给你1篇报告,你需要判断报告是不是markdown格式,并给出理由。你需要输出判断理由以及判断结果,判断结果是报告是markdown形式或者报告不是markdown格式
你的输出结果应该是个json形式,包括两个键值,一个是"reason",一个是"accept",如果你认为报告是markdown形式,则"accept"取值为true,如果你认为报告不是markdown形式,则"accept"取值为false,
你需要判断报告是不是markdown格式,并给出理由
{"reason":...,"accept":...}
报告:{{report}}
"""
    return PromptTemplate(template_text, input_variables=["report"]).format(
        report=report
    )
class RankingAgent(JsonUtil):
    """Agent that keeps only Markdown-formatted reports and ranks them for a query.

    Each candidate report is checked by an LLM using the prompt from
    ``get_markdown_check_prompt``. Reports judged to be Markdown are passed,
    together with the query, to ``ranking_tool``. Lifecycle events go through a
    callback manager, which is a new ``ReportCallbackHandler`` unless one is
    supplied. ``self.parse_json`` (used on the LLM output) is not defined here
    and is therefore inherited via ``JsonUtil``.
    """

    # "You are a ranking assistant; your task is to rank the given content by
    # its relevance to the query."
    DEFAULT_SYSTEM_MESSAGE = """你是一个排序助手,你的任务就是对给定的内容和query的相关性进行排序."""

    def __init__(
        self,
        name: str,
        ranking_tool,
        llm: BaseERNIEBot,
        llm_long: BaseERNIEBot,
        system_message: Optional[SystemMessage] = None,
        callbacks=None,
        is_reset=False,
    ) -> None:
        """Configure the agent.

        Args:
            name: agent name, reported to the callback manager on run start.
            ranking_tool: awaitable callable invoked as
                ``await ranking_tool(reports, query)``; its result is returned
                as the ranking response.
            llm: chat model used for the Markdown check on shorter prompts.
            llm_long: chat model used when the prompt length reaches
                ``TOKEN_MAX_LENGTH`` characters.
            system_message: optional system message; only its ``content`` is
                kept. Falls back to ``DEFAULT_SYSTEM_MESSAGE``.
            callbacks: callback manager exposing async ``on_run_start``,
                ``on_run_end``, ``on_tool_start``, ``on_tool_end`` and
                ``on_tool_error``. Defaults to a new ``ReportCallbackHandler``.
            is_reset: if True and no report passes the Markdown check, ``_run``
                returns ``([], None)`` instead of ranking every candidate.
        """
        self.name = name
        # Only the text is stored. NOTE(review): no method in this class reads
        # it; presumably consumed elsewhere.
        self.system_message = (
            system_message.content
            if system_message is not None
            else self.DEFAULT_SYSTEM_MESSAGE
        )
        self.llm = llm
        self.llm_long = llm_long
        self.ranking_tool = ranking_tool
        self.is_reset = is_reset
        self._callback_manager = (
            ReportCallbackHandler() if callbacks is None else callbacks
        )

    async def run(self, list_reports: List[Union[str, Dict]], query: str):
        """Filter and rank ``list_reports``, wrapped in run-start/run-end callbacks.

        Args:
            list_reports: candidate reports. Each is either the report text or a
                dict holding the text under the ``"report"`` key.
            query: the query the reports are ranked against.

        Returns:
            The ``(reports, response)`` tuple produced by ``_run``.
        """
        await self._callback_manager.on_run_start(
            agent=self, agent_name=self.name, prompt=query
        )
        agent_resp = await self._run(query=query, list_reports=list_reports)
        await self._callback_manager.on_run_end(agent=self, response=agent_resp)
        return agent_resp

    async def _run(self, query: str, list_reports: List[Union[str, dict]]):
        """Keep the Markdown-formatted reports and hand them to ``ranking_tool``.

        Returns:
            ``(reports, response)``, where ``reports`` are the items that were
            ranked and ``response`` is the ranking tool's result. Returns
            ``([], None)`` when nothing passes the check and ``is_reset`` is set.
        """
        await self._callback_manager.on_tool_start(
            agent=self, tool=self.ranking_tool, input_args=list_reports
        )
        reports = []
        for item in list_reports:
            # Dict items carry their text under "report"; plain items are text.
            text = item["report"] if isinstance(item, dict) else item
            if await self.check_format(text):
                reports.append(item)
        if not reports:
            if self.is_reset:
                logger.info("所有的report都不是markdown格式,重新生成report")
                # NOTE(review): this early return never calls on_tool_end, so
                # the on_tool_start above is left unmatched.
                return [], None
            # Nothing passed and no reset requested: rank every candidate.
            reports = list_reports
        response = await self.ranking_tool(reports, query)
        await self._callback_manager.on_tool_end(
            agent=self, tool=self.ranking_tool, response=response
        )
        return reports, response

    @staticmethod
    def _parse_accept(value) -> Optional[bool]:
        """Map the LLM's ``"accept"`` value to a bool, or ``None`` if unrecognised.

        Accepts real booleans and the strings ``"true"``/``"false"`` in any
        letter case, ignoring surrounding whitespace.
        """
        if isinstance(value, bool):
            return value
        if isinstance(value, str):
            normalized = value.strip().lower()
            if normalized == "true":
                return True
            if normalized == "false":
                return False
        return None

    async def check_format(self, report: str) -> bool:
        """Ask an LLM whether ``report`` is in Markdown format.

        Every failed attempt is reported via ``on_tool_error``, logged, and
        retried. Failures include a chat error, unparsable JSON, a missing
        ``"accept"`` key, or an unrecognised verdict. Previously an
        unrecognised verdict retried forever without counting.

        Returns:
            True if the model judged the report to be Markdown, else False.

        Raises:
            Exception: after ``MAX_RETRY + 1`` failed attempts (one try plus
                ``MAX_RETRY`` retries), chained to the last underlying error.
        """
        last_error: Optional[Exception] = None
        for _ in range(MAX_RETRY + 1):
            try:
                content = get_markdown_check_prompt(report)
                messages: List[Message] = [HumanMessage(content=content)]
                # Long prompts go to the long-context model. The threshold is
                # compared against a character count, not a token count.
                llm = self.llm if len(content) < TOKEN_MAX_LENGTH else self.llm_long
                # Very low temperature so the verdict is as repeatable as possible.
                response = await llm.chat(messages=messages, temperature=0.001)
                result = response.content
                logger.info("check report format result: %s", result)
                result_dict = self.parse_json(result)
                accepted = self._parse_accept(result_dict["accept"])
                if accepted is None:
                    # Count this as a failed attempt instead of looping forever.
                    raise ValueError(
                        f"Unrecognised 'accept' value in LLM output: {result_dict!r}"
                    )
                return accepted
            except Exception as e:
                await self._callback_manager.on_tool_error(
                    self, tool=self.ranking_tool, error=e
                )
                logger.error(e)
                last_error = e
        raise Exception(
            f"Failed to check report format after {MAX_RETRY} times."
        ) from last_error