diff --git a/qalib/translators/message_parsing.py b/qalib/translators/message_parsing.py index 173b750..2571e39 100644 --- a/qalib/translators/message_parsing.py +++ b/qalib/translators/message_parsing.py @@ -1,13 +1,26 @@ from __future__ import annotations -from typing import Dict, TypeVar, Iterable, List, Literal, Optional, Type, TypedDict, Union, cast, Callable +from functools import wraps +from typing import ( + Dict, + TypeVar, + Iterable, + List, + Literal, + Optional, + Type, + TypedDict, + Union, + cast, + Callable, +) import discord.partial_emoji import emoji from discord import ui, utils from typing_extensions import NotRequired, Concatenate, ParamSpec -from qalib.translators import Callback, M, N +from qalib.translators import I, Callback, M, N from qalib.translators.element.types.embed import Emoji P = ParamSpec("P") @@ -111,7 +124,9 @@ class TextInputComponent(TextInputRaw): callback: NotRequired[Callback] -def make_channel_types(channel_types: Iterable[ChannelType]) -> List[discord.ChannelType]: +def make_channel_types( + channel_types: Iterable[ChannelType], +) -> List[discord.ChannelType]: return [CHANNEL_TYPES[channel_type] for channel_type in channel_types] @@ -131,14 +146,43 @@ def make_emoji(raw_emoji: Optional[Union[str, Emoji]]) -> Optional[str]: if "id" not in raw_emoji: return emoji.emojize(":" + raw_emoji["name"] + ":") - string = f"a:{raw_emoji['name']}:" if raw_emoji.get("animated", False) else f":{raw_emoji['name']}:" + string = ( + f"a:{raw_emoji['name']}:" + if raw_emoji.get("animated", False) + else f":{raw_emoji['name']}:" + ) return string + str(raw_emoji["id"]) if "id" in raw_emoji else string +def _create_callback(callback: Callback[I]) -> Callback[I]: + """Wraps the callback to ensure that bound methods are still bound to their instance. + + meant to remedy the following + e.g.: + ```py + class Event: + async def on_button_click(self, item: I, interaction: discord.Interaction) -> None: + ... + + >>> event = Event() + >>> type(ui.Button.__name__, (ui.Button,), {"callback": event.on_button_click}) + ``` + + This would not work as the new method would lose the __self__ + + """ + + @wraps(callback) + async def wrapper(item: I, interaction: discord.Interaction) -> None: + await callback(item, interaction) + + return wrapper + + def create_button(component: ButtonComponent) -> ui.Button: button: Type[ui.Button] = ui.Button if "callback" in component: - callback = component["callback"] + callback = _create_callback(component["callback"]) button = cast( Type[ui.Button], type(ui.Button.__name__, (ui.Button,), {"callback": callback}), @@ -158,7 +202,7 @@ def create_button(component: ButtonComponent) -> ui.Button: def create_channel_select(**kwargs) -> ui.ChannelSelect: channel_select: Type[ui.ChannelSelect] = ui.ChannelSelect if kwargs.get("callback") is not None: - callback = kwargs["callback"] + callback = _create_callback(kwargs["callback"]) channel_select = cast( Type[ui.ChannelSelect], type( @@ -181,7 +225,7 @@ def create_channel_select(**kwargs) -> ui.ChannelSelect: def create_select(**kwargs) -> ui.Select: select: Type[ui.Select] = ui.Select if kwargs.get("callback") is not None: - callback = kwargs["callback"] + callback = _create_callback(kwargs["callback"]) select = cast( Type[ui.Select], type(ui.Select.__name__, (ui.Select,), {"callback": callback}), @@ -192,15 +236,19 @@ def create_select(**kwargs) -> ui.Select: placeholder=kwargs.get("placeholder"), min_values=int(kwargs.get("min_values", 1)), max_values=int(kwargs.get("max_values", 1)), - disabled=kwargs["disabled"].lower() == "true" if "disabled" in kwargs else False, + disabled=kwargs["disabled"].lower() == "true" + if "disabled" in kwargs + else False, options=kwargs.get("options", []), row=int(row) if (row := kwargs.get("row")) is not None else None, ) -def create_type_select(select: SelectTypes, **kwargs) -> Union[ui.RoleSelect, ui.UserSelect, ui.MentionableSelect]: +def create_type_select( + select: SelectTypes, **kwargs +) -> Union[ui.RoleSelect, ui.UserSelect, ui.MentionableSelect]: if kwargs.get("callback") is not None: - callback = kwargs["callback"] + callback = _create_callback(kwargs["callback"]) select = cast( SelectTypes, type(select.__name__, (select,), {"callback": callback}), @@ -219,7 +267,7 @@ def create_type_select(select: SelectTypes, **kwargs) -> Union[ui.RoleSelect, ui def create_text_input(text_input_component: TextInputComponent) -> ui.TextInput: text_input: Type[ui.TextInput] = ui.TextInput if "callback" in text_input_component: - callback = text_input_component["callback"] + callback = _create_callback(text_input_component["callback"]) text_input = cast( Type[ui.TextInput], type(