Skip to content

Commit

Permalink
✨ support receive forward
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Oct 21, 2023
1 parent fed99c9 commit e85970f
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 55 deletions.
15 changes: 9 additions & 6 deletions src/nonebot_plugin_alconna/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import weakref
from weakref import ref
from types import FunctionType
from datetime import datetime, timedelta
from typing import Any, Union, Callable, ClassVar, Iterable, NoReturn, Protocol, TYPE_CHECKING
from functools import partial
from typing import TYPE_CHECKING, Any, Union, Callable, ClassVar, Iterable, NoReturn, Protocol

from nonebot.rule import Rule
from nonebot import get_driver
Expand Down Expand Up @@ -85,13 +85,13 @@ def _validate(target: Arg[Any], arg: MessageSegment):


class _method:
def __init__(self, func: Callable[..., Any]):
def __init__(self, func: FunctionType):
self.__func__ = func

def __get__(self, instance, owner):
if instance is None:
return partial(self.__func__, owner)
return partial(self.__func__, instance)
return self.__func__.__get__(owner, owner)
return self.__func__.__get__(instance, owner)


class AlconnaMatcher(Matcher):
Expand Down Expand Up @@ -137,14 +137,17 @@ def _decorator(func: T_Handler) -> T_Handler:
return _decorator

if TYPE_CHECKING:

@classmethod
def set_path_arg(cls_or_self, path: str, content: Any) -> None:
...

@classmethod
def get_path_arg(self, path: str, default: Any) -> Any:
...

else:

@_method
def set_path_arg(cls_or_self, path: str, content: Any) -> None:
"""设置一个 `got_path` 内容"""
Expand Down Expand Up @@ -680,7 +683,7 @@ async def handle(results: AlcExecResult):
if res := results.get(func.__name__):
if is_awaitable(res):
res = await res
if isinstance(res, (str, Message, MessageSegment, Segment, UniMessage)):
if isinstance(res, (str, Message, MessageSegment, Segment, UniMessage, UniMessageTemplate)):
await matcher.send(res, fallback=True)

return matcher
Expand Down
1 change: 0 additions & 1 deletion src/nonebot_plugin_alconna/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def matched(self) -> bool:


class CompConfig(TypedDict):
priority: NotRequired[int]
tab: NotRequired[str]
enter: NotRequired[str]
exit: NotRequired[str]
Expand Down
44 changes: 21 additions & 23 deletions src/nonebot_plugin_alconna/rule.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from typing import TYPE_CHECKING, List, Type, Union, Optional, cast
from typing import List, Type, Union, Optional, cast

from nonebot import get_driver
from nonebot.typing import T_State
Expand Down Expand Up @@ -89,39 +89,38 @@ def __init__(
self._session = None
self._future: asyncio.Future = asyncio.Future()
self._interface = CompSession(self.command)
self._waiter = None
self._waiter = on_message(
priority=0,
block=True,
rule=Rule(lambda: self._session is not None),
)
self._waiter.destroy()
if self.comp_config is not None:
_tab = self.comp_config.get("tab", ".tab")
_enter = self.comp_config.get("enter", ".enter")
_exit = self.comp_config.get("exit", ".exit")
_waiter = on_message(
priority=self.comp_config.get("priority", 0),
block=True,
rule=Rule(lambda: self._session is not None),
)
_waiter.destroy()

@_waiter.handle()
@self._waiter.handle()
async def _waiter_handle(_bot: Bot, _event: Event, content: Message = EventMessage()):
msg = str(content)
if msg.startswith(_exit):
if msg == _exit:
self._future.set_result(False)
await _waiter.finish()
await self._waiter.finish()
else:
self._future.set_result(None)
await _waiter.pause(
await self._waiter.pause(
lang.require("analyser", "param_unmatched").format(
target=msg.replace(_exit, "", 1)
)
)
elif msg.startswith(_enter):
if msg == _enter:
self._future.set_result(True)
await _waiter.finish()
await self._waiter.finish()
else:
self._future.set_result(None)
await _waiter.pause(
await self._waiter.pause(
lang.require("analyser", "param_unmatched").format(
target=msg.replace(_enter, "", 1)
)
Expand All @@ -132,20 +131,20 @@ async def _waiter_handle(_bot: Bot, _event: Event, content: Message = EventMessa
offset = int(offset)
except ValueError:
self._future.set_result(None)
await _waiter.pause(lang.require("analyser", "param_unmatched").format(target=offset))
await self._waiter.pause(
lang.require("analyser", "param_unmatched").format(target=offset)
)
else:
self._interface.tab(offset)
if self.comp_config is not None and self.comp_config.get("lite", False):
out = f"* {self._interface.current()}"
else:
out = "\n".join(self._interface.lines())
self._future.set_result(None)
await _waiter.pause(out)
await self._waiter.pause(out)
else:
self._future.set_result(content)
await _waiter.finish()

self._waiter = _waiter
await self._waiter.finish()

def __repr__(self) -> str:
return f"Alconna(command={self.command!r})"
Expand All @@ -165,8 +164,6 @@ async def handle(self, bot: Bot, event: Event, msg: Message):
if res:
return res
self._session = event.get_session_id()
if TYPE_CHECKING:
assert self._waiter is not None
self._waiter.permission = Permission(User.from_event(event))
matchers[self._waiter.priority].append(self._waiter)
res = Arparma(
Expand Down Expand Up @@ -248,9 +245,10 @@ async def __call__(self, event: Event, state: T_State, bot: Bot) -> bool:
exec_result = self.command.exec_result
for key, value in exec_result.items():
if is_awaitable(value):
exec_result[key] = await value
elif isinstance(value, (str, Message)):
exec_result[key] = await bot.send(event, value)
value = await value
if isinstance(value, (str, Message)):
value = await bot.send(event, value)
exec_result[key] = value
state[ALCONNA_EXEC_RESULT] = exec_result
state[ALCONNA_EXTENSION] = self.executor.context
return True
Expand Down
6 changes: 6 additions & 0 deletions src/nonebot_plugin_alconna/uniseg/adapters/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ def get_adapter(cls) -> str:
@export
async def text(self, seg: Text, bot: Bot) -> "MessageSegment":
ms = self.segment_class
if seg.style.startswith("markup"):
_style = seg.style.split(":", 1)[-1]
return ms.markup(seg.text, _style)
if seg.style.startswith("markdown"):
code_theme = seg.style.split(":", 1)[-1]
return ms.markdown(seg.text, code_theme)
return ms.text(seg.text)

@export
Expand Down
2 changes: 2 additions & 0 deletions src/nonebot_plugin_alconna/uniseg/adapters/kook.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ def get_adapter(cls) -> str:
@export
async def text(self, seg: Text, bot: Bot) -> "MessageSegment":
ms = self.segment_class
if "markdown" in seg.style:
return ms.KMarkdown(seg.text)
return ms.text(seg.text)

@export
Expand Down
2 changes: 2 additions & 0 deletions src/nonebot_plugin_alconna/uniseg/adapters/qq.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def get_message_type(self):
async def text(self, seg: Text, bot: Bot) -> "MessageSegment":
ms = self.segment_class

if seg.style == "markdown":
return ms.markdown(seg.text)
return ms.text(seg.text)

@export
Expand Down
Loading

0 comments on commit e85970f

Please sign in to comment.