Skip to content

Commit

Permalink
🐛 fix bound method as callback
Browse files Browse the repository at this point in the history
  • Loading branch information
YousefEZ committed Sep 25, 2024
1 parent 5086bef commit c6e9a63
Showing 1 changed file with 59 additions and 11 deletions.
70 changes: 59 additions & 11 deletions qalib/translators/message_parsing.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,26 @@
from __future__ import annotations

from typing import Dict, TypeVar, Iterable, List, Literal, Optional, Type, TypedDict, Union, cast, Callable
from functools import wraps
from typing import (
Dict,
TypeVar,
Iterable,
List,
Literal,
Optional,
Type,
TypedDict,
Union,
cast,
Callable,
)

import discord.partial_emoji
import emoji
from discord import ui, utils
from typing_extensions import NotRequired, Concatenate, ParamSpec

from qalib.translators import Callback, M, N
from qalib.translators import I, Callback, M, N
from qalib.translators.element.types.embed import Emoji

P = ParamSpec("P")
Expand Down Expand Up @@ -111,7 +124,9 @@ class TextInputComponent(TextInputRaw):
callback: NotRequired[Callback]


def make_channel_types(channel_types: Iterable[ChannelType]) -> List[discord.ChannelType]:
def make_channel_types(
channel_types: Iterable[ChannelType],
) -> List[discord.ChannelType]:
return [CHANNEL_TYPES[channel_type] for channel_type in channel_types]


Expand All @@ -131,14 +146,43 @@ def make_emoji(raw_emoji: Optional[Union[str, Emoji]]) -> Optional[str]:
if "id" not in raw_emoji:
return emoji.emojize(":" + raw_emoji["name"] + ":")

string = f"a:{raw_emoji['name']}:" if raw_emoji.get("animated", False) else f":{raw_emoji['name']}:"
string = (
f"a:{raw_emoji['name']}:"
if raw_emoji.get("animated", False)
else f":{raw_emoji['name']}:"
)
return string + str(raw_emoji["id"]) if "id" in raw_emoji else string


def _create_callback(callback: Callback[I]) -> Callback[I]:
"""Wraps the callback to ensure that bound methods are still bound to their instance.
meant to remedy the following
e.g.:
```py
class Event:
async def on_button_click(self, item: I, interaction: discord.Interaction) -> None:
...
>>> event = Event()
>>> type(ui.Button.__name__, (ui.Button,), {"callback": event.on_button_click})
```
This would not work as the new method would lose the __self__
"""

@wraps(callback)
async def wrapper(item: I, interaction: discord.Interaction) -> None:
await callback(item, interaction)

return wrapper


def create_button(component: ButtonComponent) -> ui.Button:
button: Type[ui.Button] = ui.Button
if "callback" in component:
callback = component["callback"]
callback = _create_callback(component["callback"])
button = cast(
Type[ui.Button],
type(ui.Button.__name__, (ui.Button,), {"callback": callback}),
Expand All @@ -158,7 +202,7 @@ def create_button(component: ButtonComponent) -> ui.Button:
def create_channel_select(**kwargs) -> ui.ChannelSelect:
channel_select: Type[ui.ChannelSelect] = ui.ChannelSelect
if kwargs.get("callback") is not None:
callback = kwargs["callback"]
callback = _create_callback(kwargs["callback"])
channel_select = cast(
Type[ui.ChannelSelect],
type(
Expand All @@ -181,7 +225,7 @@ def create_channel_select(**kwargs) -> ui.ChannelSelect:
def create_select(**kwargs) -> ui.Select:
select: Type[ui.Select] = ui.Select
if kwargs.get("callback") is not None:
callback = kwargs["callback"]
callback = _create_callback(kwargs["callback"])
select = cast(
Type[ui.Select],
type(ui.Select.__name__, (ui.Select,), {"callback": callback}),
Expand All @@ -192,15 +236,19 @@ def create_select(**kwargs) -> ui.Select:
placeholder=kwargs.get("placeholder"),
min_values=int(kwargs.get("min_values", 1)),
max_values=int(kwargs.get("max_values", 1)),
disabled=kwargs["disabled"].lower() == "true" if "disabled" in kwargs else False,
disabled=kwargs["disabled"].lower() == "true"
if "disabled" in kwargs
else False,
options=kwargs.get("options", []),
row=int(row) if (row := kwargs.get("row")) is not None else None,
)


def create_type_select(select: SelectTypes, **kwargs) -> Union[ui.RoleSelect, ui.UserSelect, ui.MentionableSelect]:
def create_type_select(
select: SelectTypes, **kwargs
) -> Union[ui.RoleSelect, ui.UserSelect, ui.MentionableSelect]:
if kwargs.get("callback") is not None:
callback = kwargs["callback"]
callback = _create_callback(kwargs["callback"])
select = cast(
SelectTypes,
type(select.__name__, (select,), {"callback": callback}),
Expand All @@ -219,7 +267,7 @@ def create_type_select(select: SelectTypes, **kwargs) -> Union[ui.RoleSelect, ui
def create_text_input(text_input_component: TextInputComponent) -> ui.TextInput:
text_input: Type[ui.TextInput] = ui.TextInput
if "callback" in text_input_component:
callback = text_input_component["callback"]
callback = _create_callback(text_input_component["callback"])
text_input = cast(
Type[ui.TextInput],
type(
Expand Down

0 comments on commit c6e9a63

Please sign in to comment.