From 4ed151c5d0f7af5a3af41628bfc4340b8a4b5a47 Mon Sep 17 00:00:00 2001 From: rf_tar_railt <3165388245@qq.com> Date: Sun, 7 Jul 2024 20:23:01 +0800 Subject: [PATCH] :bookmark: version 0.49.0 --- src/nonebot_plugin_alconna/pattern.py | 142 +++++++++---------- src/nonebot_plugin_alconna/uniseg/message.py | 5 +- tests/test_satori.py | 4 +- 3 files changed, 70 insertions(+), 81 deletions(-) diff --git a/src/nonebot_plugin_alconna/pattern.py b/src/nonebot_plugin_alconna/pattern.py index 7517ae5..1f4d15b 100644 --- a/src/nonebot_plugin_alconna/pattern.py +++ b/src/nonebot_plugin_alconna/pattern.py @@ -1,6 +1,8 @@ -from typing import Any, Union, Literal, Optional, overload +from typing_extensions import deprecated +from typing import Any, Union, Generic, Literal, TypeVar, Callable, Optional -from nepattern import MatchMode, BasePattern, func +from tarina import lang +from nepattern import MatchMode, BasePattern, MatchFailed, func from .uniseg import segment @@ -19,24 +21,60 @@ Reference = BasePattern.of(segment.Reference) -@overload -def select( - seg: Union[type[segment.TS], BasePattern[segment.TS, segment.Segment, Any]], -) -> BasePattern[list[segment.TS], segment.Segment, Literal[MatchMode.TYPE_CONVERT]]: ... +TS = TypeVar("TS", bound=segment.Segment) +TS1 = TypeVar("TS1", bound=segment.Segment) +TS2 = TypeVar("TS2", bound=segment.Segment) -@overload -def select( - seg: Union[type[segment.TS], BasePattern[segment.TS, segment.Segment, Any]], index: int = 0 -) -> BasePattern[segment.TS, segment.Segment, Literal[MatchMode.TYPE_CONVERT]]: ... +class SelectPattern(BasePattern[list[TS], TS2, Literal[MatchMode.TYPE_CONVERT]], Generic[TS, TS2]): + def __init__( + self, + target: type[TS], + converter: Callable[[Any, TS2], Optional[list[TS]]], + ): + super().__init__( + mode=MatchMode.TYPE_CONVERT, + origin=list[target], + converter=converter, + alias=f"select({target.__name__})", + ) + self._accepts = (segment.Segment,) + + def match(self, input_: TS2): + if not isinstance(input_, self._accepts): + raise MatchFailed( + lang.require("nepattern", "type_error").format( + type=input_.__class__, target=input_, expected=self.alias + ) + ) + if (res := self.converter(self, input_)) is None: + raise MatchFailed(lang.require("nepattern", "content_error").format(target=input_, expected=self.alias)) + return res # type: ignore + + def nth(self, index: int): + return func.Index(self, index) + + @property + def first(self): + return func.Index(self, 0) + + @property + def last(self): + return func.Index(self, -1) + + def from_(self, seg: Union[type[TS1], BasePattern[TS1, segment.Segment, Any]]) -> "SelectPattern[TS, TS1]": + _self = self.copy() + if isinstance(seg, BasePattern): + _type = seg.origin + else: + _type = seg + _self._accepts = (_type,) + return _self # type: ignore def select( - seg: Union[type[segment.TS], BasePattern[segment.TS, segment.Segment, Any]], index: Optional[int] = None -) -> Union[ - BasePattern[list[segment.TS], segment.Segment, Literal[MatchMode.TYPE_CONVERT]], - BasePattern[segment.TS, segment.Segment, Literal[MatchMode.TYPE_CONVERT]], -]: + seg: Union[type[TS], BasePattern[TS, segment.Segment, Any]], +) -> SelectPattern[TS, segment.Segment]: if isinstance(seg, BasePattern): _type = seg.origin @@ -47,27 +85,7 @@ def query(segs: list[segment.Segment]): yield res.value() yield from query(s.children) - if index is None: - - def converter(self, _seg: segment.Segment): - results = [] - _res = seg.validate(_seg) - if _res.success: - results.append(_res.value()) - results.extend(query(_seg.children)) - if not results: - return None - return results - - return BasePattern( - mode=MatchMode.TYPE_CONVERT, - origin=list[segment.TS], - converter=converter, - accepts=segment.Segment, - alias=f"select({_type.__name__})", - ) - - def converter1(self, _seg: segment.Segment): + def converter(self, _seg: segment.Segment): results = [] _res = seg.validate(_seg) if _res.success: @@ -75,18 +93,10 @@ def converter1(self, _seg: segment.Segment): results.extend(query(_seg.children)) if not results: return None - return results[index] - - return BasePattern( - mode=MatchMode.TYPE_CONVERT, - origin=_type, - converter=converter1, - accepts=segment.Segment, - alias=f"select({_type.__name__})[{index}]", - ) + return results else: - _type = seg + _type: type[TS] = seg def query1(segs: list[segment.Segment]): for s in segs: @@ -94,53 +104,33 @@ def query1(segs: list[segment.Segment]): yield s yield from query1(s.children) - if index is None: - - def converter(self, _seg: segment.Segment): - results = [] - if isinstance(_seg, _type): - results.append(_seg) - results.extend(query1(_seg.children)) - if not results: - return None - return results - - return BasePattern( - mode=MatchMode.TYPE_CONVERT, - origin=list[segment.TS], - converter=converter, - accepts=segment.Segment, - alias=f"select({_type.__name__})", - ) - - def converter2(self, _seg: segment.Segment): + def converter(self, _seg: segment.Segment): results = [] if isinstance(_seg, _type): results.append(_seg) results.extend(query1(_seg.children)) if not results: return None - return results[index] + return results - return BasePattern( - mode=MatchMode.TYPE_CONVERT, - origin=_type, - converter=converter2, - accepts=segment.Segment, - alias=f"select({_type.__name__})[{index}]", - ) + return SelectPattern( + target=_type, + converter=converter, + ) +@deprecated("Use `select().first` instead.") def select_first( seg: Union[type[segment.TS], BasePattern[segment.TS, segment.Segment, Any]] ) -> BasePattern[segment.TS, segment.Segment, Literal[MatchMode.TYPE_CONVERT]]: - return select(seg, 0) + return select(seg).first +@deprecated("Use `select().last` instead.") def select_last( seg: Union[type[segment.TS], BasePattern[segment.TS, segment.Segment, Any]] ) -> BasePattern[segment.TS, segment.Segment, Literal[MatchMode.TYPE_CONVERT]]: - return select(seg, -1) + return select(seg).last patterns = { diff --git a/src/nonebot_plugin_alconna/uniseg/message.py b/src/nonebot_plugin_alconna/uniseg/message.py index a96e407..60d3df2 100644 --- a/src/nonebot_plugin_alconna/uniseg/message.py +++ b/src/nonebot_plugin_alconna/uniseg/message.py @@ -961,7 +961,7 @@ async def export( return await fn.export(self, bot, fallback) raise SerializeFailed(lang.require("nbp-uniseg", "unsupported").format(adapter=adapter)) except SerializeFailed: - if fallback: + if fallback and fallback != FallbackStrategy.forbid: return FallbackMessage(str(self)) raise @@ -974,10 +974,9 @@ def export_sync( """(实验性)同步方法地将 UniMessage 转换为指定适配器下的 Message""" coro = self.export(bot, fallback, adapter) try: - coro.send(None) + return coro.send(None) except StopIteration as e: return e.args[0] - raise SerializeFailed(lang.require("nbp-uniseg", "unsupported").format(adapter=adapter)) async def send( self, diff --git a/tests/test_satori.py b/tests/test_satori.py index a09618e..d69fb5c 100644 --- a/tests/test_satori.py +++ b/tests/test_satori.py @@ -10,7 +10,7 @@ def test_message_rollback(): - from nonebot_plugin_alconna import Image, select_first + from nonebot_plugin_alconna import Image, select text = """\ 捏 @@ -23,7 +23,7 @@ def test_message_rollback(): msg1 = Message.from_satori_element(parse(text1)) - alc = Alconna("捏", Args["img", Dot(select_first(Image), str, "url")]) + alc = Alconna("捏", Args["img", Dot(select(Image).first, str, "url")]) res = alc.parse(msg, {"$adapter.name": "Satori"}) assert res.matched