generated from A-kirami/nonebot-plugin-template
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcontext.py
176 lines (143 loc) · 6.03 KB
/
context.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import ast
import inspect
from copy import deepcopy
from typing import Any, ClassVar, Self, cast
import anyio
import anyio.abc
import nonebot
from nonebot.adapters import Bot, Event, Message
from nonebot.internal.matcher import current_bot, current_event
from nonebot.utils import escape_tag
from nonebot_plugin_alconna.uniseg import Image, UniMessage
from nonebot_plugin_session import Session, SessionIdType, extract_session
from .exception import BotEventMismatch, SessionNotInitialized
from .interface import API, Buffer, default_context, get_api_class
from .typings import T_Context, T_Executor
logger = nonebot.logger.opt(colors=True)
EXECUTOR_FUNCTION = """\
last_exc, __exception__ = __exception__, (None, None)
async def __executor__():
try:
...
except BaseException as e:
global __exception__
__exception__ = (e, __import__("traceback").format_exc())
finally:
globals().update({
k: v for k, v in dict(locals()).items()
if not k.startswith("__") and not k.endswith("__")
})
"""
class Context:
    """Per-user execution context for running user-submitted Python code.

    Each context owns a dict (``ctx``) used as the globals of every execution,
    so variables persist between runs; a lock that serialises executions; and
    the cancel scope of the execution currently in progress, if any.
    """

    # (user id, bot type) -> copy of the Session seen for that user, so that a
    # later lookup by Event or by user-id string can recover the full Session.
    __ua2session: ClassVar[dict[tuple[str, str], Session]] = {}
    # uin -> the single Context instance for that uin.
    __contexts: ClassVar[dict[str, Self]] = {}

    uin: str  # user id derived from the session, spaces replaced by "_"
    ctx: T_Context  # dict the user code is exec'd in (globals and locals)
    lock: anyio.Lock  # held for the whole duration of one execution
    cancel_scope: anyio.CancelScope | None  # scope of the running execution

    def __init__(self, uin: str) -> None:
        self.uin = uin
        # Deep copy so contexts never share mutable state via default_context.
        self.ctx = deepcopy(default_context)
        self.lock = anyio.Lock()
        self.cancel_scope = None

    @classmethod
    def _session2uin(cls, session: Session | Event | str) -> str:
        """Resolve ``session`` to a uin string.

        ``session`` may be a Session, the Event currently being handled, or a
        user-id string. The latter two are resolved to a Session previously
        recorded for that user and the current bot's type.

        Raises:
            BotEventMismatch: ``session`` is an Event other than the current one.
            SessionNotInitialized: no Session has been recorded for the user.
        """
        if isinstance(session, Event):
            if current_event.get() is not session:
                raise BotEventMismatch
            key = (session.get_user_id(), current_bot.get().type)
            if key not in cls.__ua2session:
                raise SessionNotInitialized(key=key)
            session = cls.__ua2session[key]
        elif isinstance(session, str):
            key = (session, current_bot.get().type)
            if key not in cls.__ua2session:
                raise SessionNotInitialized(key=key)
            session = cls.__ua2session[key]

        # Remember this Session so it can later be found by Event or user id.
        key = (session.id1 or "", session.bot_type)
        if key not in cls.__ua2session:
            cls.__ua2session[key] = session.model_copy()
        return session.get_id(SessionIdType.USER).replace(" ", "_")

    @classmethod
    def get_context(cls, session: Session | Event | str) -> Self:
        """Return the Context for ``session``'s user, creating it on first use.

        Accepts the same inputs and raises the same errors as ``_session2uin``.
        """
        uin = cls._session2uin(session)
        if uin not in cls.__contexts:
            # uin is escaped because the logger interprets colour markup.
            logger.debug(f"为用户 <y>{escape_tag(uin)}</y> 创建 Context")
            cls.__contexts[uin] = cls(uin)
        return cls.__contexts[uin]

    def _solve_code(self, raw_code: str, api: API) -> T_Executor:
        """Compile ``raw_code`` into an async executor bound to this context.

        The user's statements are spliced into the ``try`` body of
        ``__executor__`` from ``EXECUTOR_FUNCTION``. They are preceded by a
        ``global`` declaration of the names already in ``self.ctx``, so
        assignments to them update the context. If the code contains
        ``yield`` (making it an async generator), each yielded value is sent
        back through ``api.feedback`` as its ``repr``.

        Must be called while ``self.lock`` is held.

        Raises:
            SyntaxError: ``raw_code`` cannot be parsed.
        """
        import keyword
        import traceback
        from contextlib import aclosing

        assert self.lock.locked(), "`Context._solve_code` called without lock"
        parsed = ast.parse(EXECUTOR_FUNCTION, mode="exec")
        func_def = cast(ast.AsyncFunctionDef, parsed.body[1])

        # Only real, non-keyword identifiers may appear in a `global`
        # statement. The context dict is the user's globals, so arbitrary keys
        # (e.g. globals()["a b"] = 1) can end up in it; declaring those would
        # make every later execution fail to compile. Such names cannot be
        # referenced as bare names anyway, so skipping them is harmless.
        global_names = [
            name
            for name in self.ctx
            if isinstance(name, str)
            and name.isidentifier()
            and not keyword.iskeyword(name)
        ]
        try_body: list[ast.stmt] = []
        if global_names:
            try_body.append(ast.Global(names=global_names))
        try_body.extend(ast.parse(raw_code, mode="exec").body)
        # An empty `try:` body is a syntax error (e.g. empty code, empty ctx).
        cast(ast.Try, func_def.body[0]).body[:] = try_body or [ast.Pass()]

        solved = ast.unparse(parsed)
        code = compile(solved, f"<executor_{self.uin}>", "exec")
        # Running the module rotates `last_exc`/`__exception__` and defines
        # `__executor__` inside the context dict; take the function back out.
        exec(code, self.ctx, self.ctx)  # noqa: S102
        executor = self.ctx.pop(func_def.name)

        if inspect.isasyncgenfunction(executor):
            # Wrap the async generator into a plain async function that
            # reports every yielded value to the user.
            agen_function = executor

            async def executor() -> None:
                try:
                    # aclosing() finalises the generator here if feedback
                    # fails. Otherwise the template's `except`/`finally`
                    # (which write `__exception__` and locals into the
                    # context) would run only when the generator is finalised
                    # later, possibly during a subsequent execution.
                    async with aclosing(agen_function()) as agen:
                        async for value in agen:
                            await api.feedback(repr(value))
                except BaseException as err:
                    # `last_exc` already holds the previous run's exception
                    # (set by the template prelude); only record this failure.
                    self.ctx["__exception__"] = (err, traceback.format_exc())

        return executor

    @classmethod
    async def execute(cls, bot: Bot, event: Event, code: str) -> None:
        """Run ``code`` in the context of the user behind ``event``.

        After the code has run, buffered output is flushed to the chat and a
        non-None return value is sent as its ``repr``. Finally, any exception
        recorded by the executor in ``__exception__`` is re-raised.

        Raises:
            SyntaxError: ``code`` cannot be parsed.
            BaseException: whatever the user code raised, re-raised at the end.
        """
        session = extract_session(bot, event)
        self = cls.get_context(session)
        uin = self.uin
        api_class = get_api_class(bot)
        colored_uin = f"<y>{escape_tag(uin)}</y>"

        # Hold the lock while executing so separate snippets never read and
        # write the context variables concurrently.
        async with self.lock, api_class(bot, event, session, self.ctx) as api:
            executor = self._solve_code(code, api)
            escaped = escape_tag(repr(executor))
            logger.debug(f"为用户 {colored_uin} 创建 executor: {escaped}")

            # Defensive: keeps `result` bound even if the cancel scope absorbs
            # an exception before the executor returns.
            result = None
            self.cancel_scope = anyio.CancelScope()
            try:
                with self.cancel_scope:
                    result = await executor()
            finally:
                # Never leave a stale scope behind for `cancel()` to target.
                self.cancel_scope = None

            if buf := Buffer.get(uin).read().rstrip("\n"):
                logger.debug(f"用户 {colored_uin} 清空缓冲:")
                logger.opt(raw=True).debug(buf)
                await UniMessage.text(buf).send()

            if result is not None:
                result = repr(result)
                logger.debug(f"用户 {colored_uin} 输出返回值: {escape_tag(result)}")
                await UniMessage.text(result).send()

        # Re-raise the exception recorded by the executor, if any.
        if exc := self.ctx.setdefault("__exception__", (None, None))[0]:
            raise cast(Exception, exc)

    def cancel(self) -> bool:
        """Request cancellation of the execution currently running here.

        Returns:
            True if a cancel scope was active and was cancelled, else False.
        """
        if self.cancel_scope is None:
            return False
        self.cancel_scope.cancel()
        return True

    def set_value(self, varname: str, value: Any) -> None:
        """Set context variable ``varname`` to ``value``; ``None`` removes it."""
        if value is not None:
            self.ctx[varname] = value
        elif varname in self.ctx:
            del self.ctx[varname]

    def set_gem(self, msg: Message) -> None:
        """Store ``msg`` as the context variable ``gem``."""
        self.set_value("gem", msg)

    def set_gurl(self, msg: UniMessage[Image] | Image) -> None:
        """Store an image URL from ``msg`` as the context variable ``gurl``.

        Uses the first Image segment of a UniMessage, or the Image itself. If
        no image (or no URL) is found, ``gurl`` is removed from the context.
        """
        url: str | None = None
        if isinstance(msg, UniMessage) and msg.has(Image):
            url = msg[Image, 0].url
        elif isinstance(msg, Image):
            url = msg.url
        self.set_value("gurl", url)

    def __getitem__(self, key: str, /) -> Any:
        return self.ctx[key]

    def __setitem__(self, key: str, value: Any, /) -> None:
        self.ctx[key] = value

    def __delitem__(self, key: str, /) -> None:
        del self.ctx[key]