-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpolish_agent.py
173 lines (161 loc) · 7 KB
/
polish_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import asyncio
import logging
import time
from typing import Any, List, Optional

from erniebot_agent.chat_models.erniebot import BaseERNIEBot
from erniebot_agent.memory import HumanMessage, Message, SystemMessage
from erniebot_agent.prompt import PromptTemplate

from tools.semantic_citation_tool import SemanticCitationTool
from tools.utils import (JsonUtil, ReportCallbackHandler, add_citation,
                         write_md_to_pdf)
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
# Prompt-length threshold for switching to the long-context model. Despite the
# name, it is compared against len() of the formatted prompt string, i.e. a
# character count rather than a token count.
TOKEN_MAX_LENGTH = 4200
class PolishAgent(JsonUtil):
    """Polish a generated markdown report.

    ``run`` (through ``_run``) performs these steps:

    1. ask an LLM for an abstract and keywords (``add_abstract``);
    2. insert them below the report title and expand short body sections
       with the LLM (``polish_paragraph``);
    3. when ``summarize`` is given, rewrite the report with citations via
       ``citation_tool``;
    4. hand the final markdown to ``write_md_to_pdf`` and return its path.
    """

    DEFAULT_SYSTEM_MESSAGE = "你是一个报告润色助手,你的主要工作是报告进行内容上的润色"
    # Prompt: summarise the report into a 100-200 character abstract and at
    # most 5 keywords, answered as a JSON object {"abstract", "keywords"}.
    template_abstract = """
请你总结报告并给出报告的摘要和关键词,摘要在100-200字之间,关键词不超过5个词。
你需要输出一个json形式的字符串,内容为{"abstract":...,"keywords":...}。
现在给你报告的内容:
{{report}}"""
    # Prompt: expand and polish a section to 300-400 characters, on topic.
    template_polish = """你的任务是扩写和润色相关内容,
你需要把相关内容扩写到300-400字之间,扩写的内容必须与给出的内容相关。
下面给出内容:
{{content}}
扩写并润色内容为:"""
    # Buffered body sections longer than this many characters are kept
    # verbatim; shorter ones are expanded by the LLM.
    POLISH_SKIP_LENGTH = 300
    # Seconds to wait before the single retry of a failed polish request.
    POLISH_RETRY_DELAY = 0.5

    def __init__(
        self,
        name: str,
        llm: BaseERNIEBot,
        llm_long: BaseERNIEBot,
        citation_tool: SemanticCitationTool,
        embeddings: Any,
        citation_index_name: str,
        dir_path: str,
        report_type: str,
        build_index_function: Any,
        search_tool: Any,
        system_message: Optional[SystemMessage] = None,
        callbacks=None,
    ):
        """Store the models, tools and output settings used by the pipeline.

        Args:
            name: agent name; passed to ``citation_tool`` as ``agent_name``.
            llm: model used for polishing and for short abstract prompts.
            llm_long: model used for abstract prompts longer than
                ``TOKEN_MAX_LENGTH`` characters.
            citation_tool: awaited in ``_run`` to add citations; expected to
                return a ``(report, path)`` pair.
            embeddings: forwarded to ``add_citation``.
            citation_index_name: forwarded to ``add_citation``.
            dir_path: output directory, forwarded to ``citation_tool`` and
                ``write_md_to_pdf``.
            report_type: forwarded to ``citation_tool`` and ``write_md_to_pdf``.
            build_index_function: forwarded to ``add_citation``.
            search_tool: forwarded to ``add_citation``.
            system_message: its ``.content`` is stored as
                ``self.system_message``; defaults to ``DEFAULT_SYSTEM_MESSAGE``.
                (Not referenced by the methods of this class.)
            callbacks: callback manager; a fresh ``ReportCallbackHandler`` is
                created when None.
        """
        self.name = name
        self.llm = llm
        self.llm_long = llm_long
        self.report_type = report_type
        self.dir_path = dir_path
        self.embeddings = embeddings
        self.citation_tool = citation_tool
        self.citation_index_name = citation_index_name
        self.system_message = (
            system_message.content
            if system_message is not None
            else self.DEFAULT_SYSTEM_MESSAGE
        )
        self.prompt_template_abstract = PromptTemplate(
            template=self.template_abstract, input_variables=["report"]
        )
        self.prompt_template_polish = PromptTemplate(
            template=self.template_polish, input_variables=["content"]
        )
        self.build_index_function = build_index_function
        self.search_tool = search_tool
        if callbacks is None:
            self._callback_manager = ReportCallbackHandler()
        else:
            self._callback_manager = callbacks

    async def run(self, report: str, summarize=None):
        """Polish *report*, bracketed by on_run_start/on_run_end callbacks.

        Returns whatever ``_run`` returns: ``(final_report, path)``.
        """
        await self._callback_manager.on_run_start(agent=self, prompt=report)
        agent_resp = await self._run(report, summarize)
        await self._callback_manager.on_run_end(agent=self, response=agent_resp)
        return agent_resp

    async def add_abstract(self, report: str, max_attempts: int = 5):
        """Ask the LLM for an abstract and keywords of *report*.

        ``llm_long`` is used when the formatted prompt is longer than
        ``TOKEN_MAX_LENGTH`` characters, otherwise ``llm``. The reply is
        parsed with ``self.parse_json`` (presumably provided by ``JsonUtil``).

        Args:
            report: full markdown report text.
            max_attempts: total number of LLM calls to try; values below 1
                are treated as 1.

        Returns:
            ``(abstract, keywords)``; keywords are joined with "," when the
            model returns them as a list.

        Raises:
            The last exception from the LLM call or from parsing the reply,
            once ``max_attempts`` attempts have all failed.
        """
        content = self.prompt_template_abstract.format(report=report)
        llm = self.llm_long if len(content) > TOKEN_MAX_LENGTH else self.llm
        attempts = max(1, max_attempts)
        # Bounded retries: an unbounded loop would hang forever on a model
        # that keeps failing or keeps returning unparseable JSON.
        for attempt in range(1, attempts + 1):
            try:
                messages: List[Message] = [HumanMessage(content)]
                response = await llm.chat(messages)
                abstract_json = self.parse_json(response.content)
                abstract = abstract_json["abstract"]
                keywords = abstract_json["keywords"]
                if isinstance(keywords, list):
                    keywords = ",".join(keywords)
                return abstract, keywords
            except Exception as e:
                await self._callback_manager.on_llm_error(self, llm, e)
                if attempt == attempts:
                    raise

    async def _polish_content(self, content: str) -> str:
        """Expand/polish one body section with ``llm``, retrying once on error."""
        prompt = self.prompt_template_polish.format(content=content)
        messages: List[Message] = [HumanMessage(prompt)]
        try:
            response = await self.llm.chat(messages)
        except Exception as e:
            await self._callback_manager.on_llm_error(self, self.llm, e)
            # Non-blocking back-off: time.sleep would stall the event loop.
            await asyncio.sleep(self.POLISH_RETRY_DELAY)
            response = await self.llm.chat(messages)
        return response.content

    async def _render_section(self, content: str) -> str:
        """Return a buffered section verbatim if long, else its polished form."""
        if len(content) > self.POLISH_SKIP_LENGTH:
            return content
        return await self._polish_content(content)

    async def polish_paragraph(self, report: str, abstract: str, key: str):
        """Insert abstract/keywords and polish the body sections of *report*.

        The report is split on blank lines. Blocks containing "#" are treated
        as headings and kept as-is; consecutive non-heading blocks are
        buffered into one section that is either kept (when longer than
        ``POLISH_SKIP_LENGTH`` characters) or expanded by the LLM.

        Args:
            report: markdown report whose first block is expected to be the
                title (a block containing "#").
            abstract: text inserted after the title as "**摘要** ...".
            key: text inserted after the abstract as "**关键词** ...".

        Returns:
            The rebuilt report joined with blank lines, or *report* unchanged
            (after logging an error) when it is empty or has no title block.
        """
        report_list = [item for item in report.split("\n\n") if item.strip() != ""]
        if not report_list or "#" not in report_list[0]:
            logger.error("Report format error, unable to add abstract and keywords")
            return report

        paragraphs = [report_list[0]]
        # Abstract/keywords are inserted only when a "##" heading directly
        # follows the title.
        if len(report_list) > 1 and "##" in report_list[1]:
            paragraphs.append("**摘要** " + abstract)
            paragraphs.append("**关键词** " + key)

        section: List[str] = []
        for item in report_list[1:]:
            if "#" not in item:
                # Body text: buffer until the next heading.
                section.append(item + "\n")
                continue
            # Heading: flush the buffered section first, then keep the heading.
            if section:
                paragraphs.append(await self._render_section("".join(section)))
                section = []
            paragraphs.append(item)
        # The section after the last heading follows the same keep/polish rule.
        if section:
            paragraphs.append(await self._render_section("".join(section)))
        return "\n\n".join(paragraphs)

    async def _run(self, report, summarize=None):
        """Add abstract, polish, optionally cite, then write the PDF.

        Args:
            report: markdown report to process.
            summarize: passed to ``add_citation`` as its first argument; when
                None the citation step is skipped.

        Returns:
            ``(final_report, path)`` where ``path`` is the value returned by
            ``write_md_to_pdf``.
        """
        abstract, key = await self.add_abstract(report)
        final_report = await self.polish_paragraph(report, abstract, key)
        await self._callback_manager.on_tool_start(
            self, tool=self.citation_tool, input_args=final_report
        )
        if summarize is not None:
            citation_search = add_citation(
                summarize,
                self.citation_index_name,
                self.embeddings,
                self.build_index_function,
                self.search_tool,
            )
            # The path returned here is superseded by write_md_to_pdf below.
            final_report, _ = await self.citation_tool(
                report=final_report,
                agent_name=self.name,
                report_type=self.report_type,
                dir_path=self.dir_path,
                citation_faiss_research=citation_search,
            )
        path = write_md_to_pdf(self.report_type, self.dir_path, final_report)
        await self._callback_manager.on_tool_end(
            self, tool=self.citation_tool, response={"report": final_report}
        )
        return final_report, path