From 5a6f4b9e1c1c7fade620a9c703f9efbac44eb31f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bryan=E4=B8=8D=E5=8F=AF=E6=80=9D=E8=AE=AE?= Date: Thu, 11 Jan 2024 11:52:07 +0800 Subject: [PATCH] =?UTF-8?q?:sparkles:=20Feature:=20=E5=B8=A6=E5=8F=82?= =?UTF-8?q?=E6=95=B0=E7=9A=84=20`RegexStr()`=20(#2499)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot/params.py | 45 ++++++++++++++++++++++++++---- tests/plugins/param/param_state.py | 9 ++++-- tests/test_param.py | 4 ++- 3 files changed, 50 insertions(+), 8 deletions(-) diff --git a/nonebot/params.py b/nonebot/params.py index fc5990b9b248..9f3c60cda8e3 100644 --- a/nonebot/params.py +++ b/nonebot/params.py @@ -5,7 +5,18 @@ description: nonebot.params 模块 """ -from typing import Any, Dict, List, Match, Tuple, Union, Optional +from typing import ( + Any, + Dict, + List, + Match, + Tuple, + Union, + Literal, + Callable, + Optional, + overload, +) from nonebot.typing import T_State from nonebot.matcher import Matcher @@ -147,13 +158,37 @@ def RegexMatched() -> Match[str]: return Depends(_regex_matched, use_cache=False) -def _regex_str(state: T_State) -> str: - return _regex_matched(state).group() +def _regex_str( + groups: Tuple[Union[str, int], ...] +) -> Callable[[T_State], Union[str, Tuple[Union[str, Any], ...], Any]]: + def _regex_str_dependency( + state: T_State, + ) -> Union[str, Tuple[Union[str, Any], ...], Any]: + return _regex_matched(state).group(*groups) + + return _regex_str_dependency + + +@overload +def RegexStr(__group: Literal[0] = 0) -> str: + ... + + +@overload +def RegexStr(__group: Union[str, int]) -> Union[str, Any]: + ... + + +@overload +def RegexStr( + __group1: Union[str, int], __group2: Union[str, int], *groups: Union[str, int] +) -> Tuple[Union[str, Any], ...]: + ... -def RegexStr() -> str: +def RegexStr(*groups: Union[str, int]) -> Union[str, Tuple[Union[str, Any], ...], Any]: """正则匹配结果文本""" - return Depends(_regex_str, use_cache=False) + return Depends(_regex_str(groups), use_cache=False) def _regex_group(state: T_State) -> Tuple[Any, ...]: diff --git a/tests/plugins/param/param_state.py b/tests/plugins/param/param_state.py index 06731adac5bd..f513ecd67045 100644 --- a/tests/plugins/param/param_state.py +++ b/tests/plugins/param/param_state.py @@ -77,8 +77,13 @@ async def regex_matched(regex_matched: Match[str] = RegexMatched()) -> Match[str return regex_matched -async def regex_str(regex_str: str = RegexStr()) -> str: - return regex_str +async def regex_str( + entire: str = RegexStr(), + type_: str = RegexStr("type"), + second: str = RegexStr(2), + groups: Tuple[str, ...] = RegexStr(1, "arg"), +) -> Tuple[str, str, str, Tuple[str, ...]]: + return entire, type_, second, groups async def startswith(startswith: str = Startswith()) -> str: diff --git a/tests/test_param.py b/tests/test_param.py index 3bbf70aec81f..8a5323d26577 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -361,7 +361,9 @@ async def test_state(app: App): regex_str, allow_types=[StateParam, DependParam] ) as ctx: ctx.pass_params(state=fake_state) - ctx.should_return("[cq:test,arg=value]") + ctx.should_return( + ("[cq:test,arg=value]", "test", "arg=value", ("test", "arg=value")) + ) async with app.test_dependent( regex_group, allow_types=[StateParam, DependParam]