Skip to content

Commit

Permalink
✨ update custom DI
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Oct 22, 2023
1 parent 3acccc8 commit 92b5644
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 30 deletions.
5 changes: 3 additions & 2 deletions example/plugins/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,15 +389,16 @@ async def statis_h():
sources = [cmd.meta.extra["matcher.source"] for cmd in cmds]
await statis.finish(UniMessage(f"sources: {sources}"))


alc = Alconna(
"添加教师",
Arg("name", str, Field(completion=lambda: "请输入姓名")),
Arg("phone", int, Field(completion=lambda: "请输入手机号"))
Arg("phone", int, Field(completion=lambda: "请输入手机号")),
)

cmd = on_alconna(alc, comp_config={})


@cmd.handle()
async def handle(name: str, phone: int):
await cmd.finish(f"姓名:{name}\n手机号:{phone}")
await cmd.finish(f"姓名:{name}\n手机号:{phone}")
1 change: 1 addition & 0 deletions src/nonebot_plugin_alconna/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from .params import match_path as match_path
from .uniseg import UniMessage as UniMessage
from .extension import Extension as Extension
from .extension import Interface as Interface
from .matcher import funcommand as funcommand
from .matcher import on_alconna as on_alconna
from .tools import image_fetch as image_fetch
Expand Down
29 changes: 20 additions & 9 deletions src/nonebot_plugin_alconna/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import functools
import importlib as imp
from weakref import finalize
from dataclasses import dataclass
from typing_extensions import Self
from abc import ABCMeta, abstractmethod
from typing import TYPE_CHECKING, Union, Literal, TypeVar, Any
from typing import TYPE_CHECKING, Any, Union, Generic, Literal, TypeVar

from tarina import lang
from nonebot.typing import T_State
Expand All @@ -19,11 +20,21 @@

OutputType = Literal["help", "shortcut", "completion"]
TM = TypeVar("TM", bound=Union[str, Message, UniMessage])
TE = TypeVar("TE", bound=Event)

if TYPE_CHECKING:
from .rule import AlconnaRule


@dataclass
class Interface(Generic[TE]):
event: TE
state: T_State
name: str
annotation: Any
default: Any


class Extension(metaclass=ABCMeta):
_overrides: dict[str, bool]

Expand Down Expand Up @@ -73,7 +84,7 @@ async def message_provider(
return None
return msg

async def receive_wrapper(self, bot: Bot, event: Event, receive: TM) -> TM:
async def receive_wrapper(self, bot: Bot, event: Event, command: Alconna, receive: TM) -> TM:
"""接收消息后的钩子函数。"""
return receive

Expand All @@ -85,7 +96,7 @@ async def send_wrapper(self, bot: Bot, event: Event, send: TM) -> TM:
"""发送消息前的钩子函数。"""
return send

async def catch(self, state: T_State, name: str, annotation: Any, default: Any, **kwargs: Any) -> Any:
async def catch(self, interface: Interface) -> Any:
"""自定义依赖注入处理函数。"""
pass

Expand Down Expand Up @@ -135,9 +146,9 @@ def __init__(
self.extensions = [
ext
for ext in self.extensions
if ext.id not in self._excludes and ext.__class__ not in self._excludes and (
not (ns := ext.namespace) or ns == rule.command.namespace
)
if ext.id not in self._excludes
and ext.__class__ not in self._excludes
and (not (ns := ext.namespace) or ns == rule.command.namespace)
]
self.context: list[Extension] = []
self._rule = rule
Expand Down Expand Up @@ -193,7 +204,7 @@ async def receive_wrapper(self, bot: Bot, event: Event, receive: TM) -> TM:
res = receive
for ext in self.context:
if ext._overrides["receive_wrapper"]:
res = await ext.receive_wrapper(bot, event, res)
res = await ext.receive_wrapper(bot, event, self._rule.command, res)
return res

async def parse_wrapper(self, bot: Bot, state: T_State, event: Event, res: Arparma) -> None:
Expand All @@ -212,10 +223,10 @@ async def send_wrapper(self, bot: Bot, event: Event, send: TM) -> TM:
res = await ext.send_wrapper(bot, event, res)
return res

async def catch(self, state: T_State, name: str, annotation: Any, default: Any, **kwargs: Any):
async def catch(self, event: Event, state: T_State, name: str, annotation: Any, default: Any):
for ext in self.context:
if ext._overrides["catch"]:
res = await ext.catch(state, name, annotation, default, **kwargs)
res = await ext.catch(Interface(event, state, name, annotation, default))
if res is None:
continue
return res
Expand Down
7 changes: 5 additions & 2 deletions src/nonebot_plugin_alconna/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
from nonebot.internal.params import DefaultParam
from tarina import lang, is_awaitable, run_always_await
from arclet.alconna.tools import AlconnaFormat, AlconnaString
from nonebot.plugin.on import store_matcher, get_matcher_source
from arclet.alconna.tools.construct import FuncMounter, MountConfig
from arclet.alconna import Arg, Args, Alconna, ShortcutArgs, command_manager
from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_PermissionChecker
from nonebot.exception import PausedException, FinishedException, RejectedException
from nonebot.plugin.on import store_matcher, get_matcher_source
from nonebot.internal.adapter import Bot, Event, Message, MessageSegment, MessageTemplate
from nonebot.matcher import Matcher, matchers, current_bot, current_event, current_matcher

Expand Down Expand Up @@ -588,7 +588,10 @@ def on_alconna(
use_cmd_sep,
)
executor = cast(ExtensionExecutor, list(_rule.checkers)[0].call.executor) # type: ignore
AlconnaMatcher.HANDLER_PARAM_TYPES = Matcher.HANDLER_PARAM_TYPES[:-1] + (AlconnaParam.new(executor), DefaultParam)
AlconnaMatcher.HANDLER_PARAM_TYPES = Matcher.HANDLER_PARAM_TYPES[:-1] + (
AlconnaParam.new(executor),
DefaultParam,
)
matcher: type[AlconnaMatcher] = AlconnaMatcher.new(
"",
rule & _rule,
Expand Down
11 changes: 6 additions & 5 deletions src/nonebot_plugin_alconna/params.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import inspect
from typing_extensions import Annotated, get_args
from typing import Any, Dict, Type, Tuple, Union, Literal, TypeVar, Optional, overload, ClassVar
from typing import Any, Dict, Type, Tuple, Union, Literal, TypeVar, ClassVar, Optional, overload

from nonebot.typing import T_State
from tarina import run_always_await
Expand All @@ -14,9 +14,9 @@
from arclet.alconna.builtin import generate_duplication
from arclet.alconna import Empty, Alconna, Arparma, Duplication

from .extension import Extension, ExtensionExecutor
from .typings import CHECK, MIDDLEWARE
from .model import T, Match, Query, CommandResult
from .extension import Extension, ExtensionExecutor
from .consts import ALCONNA_RESULT, ALCONNA_ARG_KEY, ALCONNA_EXTENSION, ALCONNA_EXEC_RESULT

T_Duplication = TypeVar("T_Duplication", bound=Duplication)
Expand Down Expand Up @@ -211,6 +211,7 @@ class AlconnaParam(Param):
本注入解析事件响应器操作 `AlconnaMatcher` 的响应函数内所需参数。
"""

executor: ClassVar[ExtensionExecutor]

def __repr__(self) -> str:
Expand Down Expand Up @@ -249,10 +250,10 @@ def _check_param(
return cls(param.default, type=Query)
return cls(param.default, name=param.name, type=param.annotation, validate=True)

async def _solve(self, matcher: Matcher, state: T_State, **kwargs: Any) -> Any:
async def _solve(self, matcher: Matcher, event: Event, state: T_State, **kwargs: Any) -> Any:
t = self.extra["type"]
if ALCONNA_RESULT not in state:
ext_res = await self.executor.catch(state, self.extra["name"], t, self.default, **kwargs)
ext_res = await self.executor.catch(event, state, self.extra["name"], t, self.default)
if ext_res is not Undefined:
return ext_res
return self.default if self.default not in (..., Empty) else Undefined
Expand Down Expand Up @@ -288,7 +289,7 @@ async def _solve(self, matcher: Matcher, state: T_State, **kwargs: Any) -> Any:
return state[key]
if self.extra["name"] in res.result.all_matched_args:
return res.result.all_matched_args[self.extra["name"]]
ext_res = await self.executor.catch(state, self.extra["name"], t, self.default, **kwargs)
ext_res = await self.executor.catch(event, state, self.extra["name"], t, self.default)
if ext_res is not Undefined:
return ext_res
return self.default if self.default not in (..., Empty) else Undefined
Expand Down
21 changes: 9 additions & 12 deletions tests/test_extension.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import pytest
from nonebug import App
from nonebot import get_adapter
from arclet.alconna import Alconna, Args
from arclet.alconna import Args, Alconna
from nonebot.adapters.onebot.v11 import Bot, Adapter, Message

from tests.fake import fake_group_message_event_v11


@pytest.mark.asyncio()
async def test_extension(app: App):
from nonebot_plugin_alconna import Extension, on_alconna
from nonebot.adapters.onebot.v11 import MessageEvent

from nonebot_plugin_alconna import Extension, Interface, on_alconna

class DemoExtension(Extension):
@property
Expand All @@ -20,19 +22,14 @@ def priority(self) -> int:
def id(self) -> str:
return "demo"

async def catch(self, state, name, annotation, default, **kwargs):
if annotation is str:
async def catch(self, interface: Interface[MessageEvent]):
if interface.annotation is str:
return {
"hello": "Hello!",
"world": "World!",
}.get(name, name)

add = on_alconna(
Alconna(
"add", Args["a", float]["b", float]
),
extensions=[DemoExtension]
)
}.get(interface.name, interface.name)

add = on_alconna(Alconna("add", Args["a", float]["b", float]), extensions=[DemoExtension])

@add.handle()
async def h(a: float, b: float, hello: str, world: str, test: str):
Expand Down

0 comments on commit 92b5644

Please sign in to comment.