Skip to content

Commit

Permalink
🔖 version 0.49.0
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Jul 7, 2024
1 parent 4839746 commit 4ed151c
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 81 deletions.
142 changes: 66 additions & 76 deletions src/nonebot_plugin_alconna/pattern.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Any, Union, Literal, Optional, overload
from typing_extensions import deprecated
from typing import Any, Union, Generic, Literal, TypeVar, Callable, Optional

from nepattern import MatchMode, BasePattern, func
from tarina import lang
from nepattern import MatchMode, BasePattern, MatchFailed, func

from .uniseg import segment

Expand All @@ -19,24 +21,60 @@
Reference = BasePattern.of(segment.Reference)


@overload
def select(
seg: Union[type[segment.TS], BasePattern[segment.TS, segment.Segment, Any]],
) -> BasePattern[list[segment.TS], segment.Segment, Literal[MatchMode.TYPE_CONVERT]]: ...
TS = TypeVar("TS", bound=segment.Segment)
TS1 = TypeVar("TS1", bound=segment.Segment)
TS2 = TypeVar("TS2", bound=segment.Segment)


@overload
def select(
seg: Union[type[segment.TS], BasePattern[segment.TS, segment.Segment, Any]], index: int = 0
) -> BasePattern[segment.TS, segment.Segment, Literal[MatchMode.TYPE_CONVERT]]: ...
class SelectPattern(BasePattern[list[TS], TS2, Literal[MatchMode.TYPE_CONVERT]], Generic[TS, TS2]):
def __init__(
self,
target: type[TS],
converter: Callable[[Any, TS2], Optional[list[TS]]],
):
super().__init__(
mode=MatchMode.TYPE_CONVERT,
origin=list[target],
converter=converter,
alias=f"select({target.__name__})",
)
self._accepts = (segment.Segment,)

def match(self, input_: TS2):
if not isinstance(input_, self._accepts):
raise MatchFailed(
lang.require("nepattern", "type_error").format(
type=input_.__class__, target=input_, expected=self.alias
)
)
if (res := self.converter(self, input_)) is None:
raise MatchFailed(lang.require("nepattern", "content_error").format(target=input_, expected=self.alias))
return res # type: ignore

def nth(self, index: int):
return func.Index(self, index)

@property
def first(self):
return func.Index(self, 0)

@property
def last(self):
return func.Index(self, -1)

def from_(self, seg: Union[type[TS1], BasePattern[TS1, segment.Segment, Any]]) -> "SelectPattern[TS, TS1]":
_self = self.copy()
if isinstance(seg, BasePattern):
_type = seg.origin
else:
_type = seg
_self._accepts = (_type,)
return _self # type: ignore


def select(
seg: Union[type[segment.TS], BasePattern[segment.TS, segment.Segment, Any]], index: Optional[int] = None
) -> Union[
BasePattern[list[segment.TS], segment.Segment, Literal[MatchMode.TYPE_CONVERT]],
BasePattern[segment.TS, segment.Segment, Literal[MatchMode.TYPE_CONVERT]],
]:
seg: Union[type[TS], BasePattern[TS, segment.Segment, Any]],
) -> SelectPattern[TS, segment.Segment]:
if isinstance(seg, BasePattern):
_type = seg.origin

Expand All @@ -47,100 +85,52 @@ def query(segs: list[segment.Segment]):
yield res.value()
yield from query(s.children)

if index is None:

def converter(self, _seg: segment.Segment):
results = []
_res = seg.validate(_seg)
if _res.success:
results.append(_res.value())
results.extend(query(_seg.children))
if not results:
return None
return results

return BasePattern(
mode=MatchMode.TYPE_CONVERT,
origin=list[segment.TS],
converter=converter,
accepts=segment.Segment,
alias=f"select({_type.__name__})",
)

def converter1(self, _seg: segment.Segment):
def converter(self, _seg: segment.Segment):
results = []
_res = seg.validate(_seg)
if _res.success:
results.append(_res.value())
results.extend(query(_seg.children))
if not results:
return None
return results[index]

return BasePattern(
mode=MatchMode.TYPE_CONVERT,
origin=_type,
converter=converter1,
accepts=segment.Segment,
alias=f"select({_type.__name__})[{index}]",
)
return results

else:
_type = seg
_type: type[TS] = seg

def query1(segs: list[segment.Segment]):
for s in segs:
if isinstance(s, _type):
yield s
yield from query1(s.children)

if index is None:

def converter(self, _seg: segment.Segment):
results = []
if isinstance(_seg, _type):
results.append(_seg)
results.extend(query1(_seg.children))
if not results:
return None
return results

return BasePattern(
mode=MatchMode.TYPE_CONVERT,
origin=list[segment.TS],
converter=converter,
accepts=segment.Segment,
alias=f"select({_type.__name__})",
)

def converter2(self, _seg: segment.Segment):
def converter(self, _seg: segment.Segment):
results = []
if isinstance(_seg, _type):
results.append(_seg)
results.extend(query1(_seg.children))
if not results:
return None
return results[index]
return results

return BasePattern(
mode=MatchMode.TYPE_CONVERT,
origin=_type,
converter=converter2,
accepts=segment.Segment,
alias=f"select({_type.__name__})[{index}]",
)
return SelectPattern(
target=_type,
converter=converter,
)


@deprecated("Use `select().first` instead.")
def select_first(
seg: Union[type[segment.TS], BasePattern[segment.TS, segment.Segment, Any]]
) -> BasePattern[segment.TS, segment.Segment, Literal[MatchMode.TYPE_CONVERT]]:
return select(seg, 0)
return select(seg).first


@deprecated("Use `select().last` instead.")
def select_last(
seg: Union[type[segment.TS], BasePattern[segment.TS, segment.Segment, Any]]
) -> BasePattern[segment.TS, segment.Segment, Literal[MatchMode.TYPE_CONVERT]]:
return select(seg, -1)
return select(seg).last


patterns = {
Expand Down
5 changes: 2 additions & 3 deletions src/nonebot_plugin_alconna/uniseg/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,7 @@ async def export(
return await fn.export(self, bot, fallback)
raise SerializeFailed(lang.require("nbp-uniseg", "unsupported").format(adapter=adapter))
except SerializeFailed:
if fallback:
if fallback and fallback != FallbackStrategy.forbid:
return FallbackMessage(str(self))
raise

Expand All @@ -974,10 +974,9 @@ def export_sync(
"""(实验性)同步方法地将 UniMessage 转换为指定适配器下的 Message"""
coro = self.export(bot, fallback, adapter)
try:
coro.send(None)
return coro.send(None)
except StopIteration as e:
return e.args[0]
raise SerializeFailed(lang.require("nbp-uniseg", "unsupported").format(adapter=adapter))

async def send(
self,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_satori.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


def test_message_rollback():
from nonebot_plugin_alconna import Image, select_first
from nonebot_plugin_alconna import Image, select

text = """\
捏<chronocat:marketface tab-id="237834" face-id="a651cf5813ba41587b22d273682e01ae" key="e08787120cade0a5">
Expand All @@ -23,7 +23,7 @@ def test_message_rollback():

msg1 = Message.from_satori_element(parse(text1))

alc = Alconna("捏", Args["img", Dot(select_first(Image), str, "url")])
alc = Alconna("捏", Args["img", Dot(select(Image).first, str, "url")])

res = alc.parse(msg, {"$adapter.name": "Satori"})
assert res.matched
Expand Down

0 comments on commit 4ed151c

Please sign in to comment.