-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrewriter.py
61 lines (46 loc) · 1.93 KB
/
rewriter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
import re
from functools import lru_cache
from typing import Dict, Tuple
from litellm import completion
@lru_cache(maxsize=1)
def load_prompts(prompts_dir: str = "prompts") -> Dict[str, str]:
    """Load every ``*.md`` prompt template found directly in *prompts_dir*.

    The default ``"prompts"`` is a relative path, so it is resolved against the
    current working directory.

    Args:
        prompts_dir: directory to scan (not recursive).

    Returns:
        A dict mapping each file's stem (name without ``.md``) to its UTF-8
        text, inserted in sorted filename order. The result is memoized by
        ``lru_cache``: callers receive the cached dict object itself, so they
        must not mutate it. Call ``load_prompts.cache_clear()`` after changing
        files on disk.

    Raises:
        FileNotFoundError / NotADirectoryError: if *prompts_dir* is missing or
        not a directory (propagated from ``os.listdir``).
    """
    prompts: Dict[str, str] = {}
    # os.listdir returns entries in arbitrary order; sort so the resulting
    # dict (and anything that lists its keys) is stable across runs/platforms.
    for filename in sorted(os.listdir(prompts_dir)):
        file_path = os.path.join(prompts_dir, filename)
        # Skip non-markdown entries, and directories that happen to end in
        # ".md" (open() would fail on them).
        if not filename.endswith(".md") or not os.path.isfile(file_path):
            continue
        service_type = filename[: -len(".md")]
        with open(file_path, "r", encoding="utf-8") as file:
            prompts[service_type] = file.read()
    return prompts
def remove_xml_tags(text: str) -> str:
    """Strip XML/HTML-style tags from *text*, keeping the text between them.

    A span is treated as a tag only if it is tag-shaped: ``<``, an optional
    ``/``, ``!`` or ``?``, then a letter, then anything up to the next ``>``
    without crossing another ``<``. Examples are ``<output>``, ``</b>``,
    ``<br/>``, ``<tag a="1">`` and ``<?xml ...?>``.

    The previous pattern ``<[^>]+>`` also deleted ordinary text such as the
    middle of ``x < 5 and y > 3``, or everything in ``1<2 and <b>`` up to the
    first ``>``. Those spans are now left intact.
    """
    return re.sub(r"<[/!?]?[A-Za-z][^<>]*>", "", text)
def rewrite_text(
    text: str, prompt: str, model: str, max_tokens: int = 4000
) -> Tuple[str, str, str]:
    """Rewrite *text* with an LLM, using *prompt* as the system message.

    The reply is expected to contain an ``<output>...</output>`` section and an
    ``<explanation>...</explanation>`` section.

    Args:
        text: content sent as the user message.
        prompt: system prompt.
        model: model identifier passed to ``litellm.completion``.
        max_tokens: completion token cap (defaults to the previous fixed 4000).

    Returns:
        ``(output, explanation, full_response)``:
        - ``output`` is the text inside a complete ``<output>`` section, with
          tags removed, or ``"未提供输出"`` ("no output provided") if absent.
        - ``explanation`` is the text inside ``<explanation>``, with tags
          removed, or ``"未提供解释"`` ("no explanation provided") if absent.
        - ``full_response`` is the whole reply with surrounding whitespace
          stripped.

    Errors raised by ``completion`` propagate unchanged.
    """
    response = completion(
        model=model,
        messages=[
            {"role": "system", "content": prompt},
            {"role": "user", "content": text},
        ],
        max_tokens=max_tokens,
    )
    # The message content can be None (e.g. refusals or tool-call replies).
    # Treat that as an empty reply so the placeholders below are returned
    # instead of crashing on .strip().
    full_response = (response.choices[0].message.content or "").strip()

    output_match = re.search(r"<output>(.*?)</output>", full_response, re.DOTALL)
    # The explanation may be unterminated (e.g. cut off by max_tokens). In that
    # case the section runs to the end of the reply; `$` matches there because
    # full_response was stripped of trailing newlines.
    explanation_match = re.search(
        r"<explanation>(.*?)(?:</explanation>|$)", full_response, re.DOTALL
    )

    output = (
        remove_xml_tags(output_match.group(1).strip()) if output_match else "未提供输出"
    )
    explanation = (
        remove_xml_tags(explanation_match.group(1).strip())
        if explanation_match
        else "未提供解释"
    )
    return output, explanation, full_response
def save_prompt(prompt_name: str, content: str, prompts_dir: str = "prompts") -> None:
    """Write *content* to ``<prompts_dir>/<prompt_name>.md`` and reset the prompt cache.

    Args:
        prompt_name: bare file stem with no directory components. It is the
            key that ``load_prompts`` derives from the file name.
        content: text to store, written as UTF-8. Overwrites any existing file.
        prompts_dir: target directory, created if it does not exist yet.

    Raises:
        ValueError: if *prompt_name* is empty or contains a path component.
            Such names (``"../x"``, ``"/abs/x"``, ``"sub/x"``) would write
            outside *prompts_dir* or into a subdirectory ``load_prompts``
            never scans.
    """
    # basename() drops any directory part, so a mismatch means the name is not
    # a plain file stem. Validate before touching the filesystem.
    if not prompt_name or os.path.basename(prompt_name) != prompt_name:
        raise ValueError(f"无效的 Prompt 名称: {prompt_name!r}")
    # Create the directory on first save instead of failing with FileNotFoundError.
    os.makedirs(prompts_dir, exist_ok=True)
    file_path = os.path.join(prompts_dir, f"{prompt_name}.md")
    with open(file_path, "w", encoding="utf-8") as file:
        file.write(content)
    # load_prompts is lru_cached; clear it so the next call re-reads the directory.
    load_prompts.cache_clear()
def delete_prompt(prompt_name: str, prompts_dir: str = "prompts") -> None:
    """Delete ``<prompts_dir>/<prompt_name>.md`` and reset the prompt cache.

    Args:
        prompt_name: bare file stem with no directory components.
        prompts_dir: directory holding the prompt files.

    Raises:
        ValueError: if *prompt_name* is empty or contains a path component.
            Such names could otherwise delete files outside *prompts_dir*.
        FileNotFoundError: if the prompt file does not exist.
    """
    if not prompt_name or os.path.basename(prompt_name) != prompt_name:
        raise ValueError(f"无效的 Prompt 名称: {prompt_name!r}")
    file_path = os.path.join(prompts_dir, f"{prompt_name}.md")
    # EAFP: removing directly avoids the exists()/remove() race where the file
    # disappears between the check and the delete.
    try:
        os.remove(file_path)
    except FileNotFoundError:
        raise FileNotFoundError(f"Prompt 文件 {prompt_name}.md 不存在") from None
    # load_prompts is lru_cached; clear it so the deleted prompt disappears.
    load_prompts.cache_clear()