diff --git a/changelog/914.bugfix.rst b/changelog/914.bugfix.rst new file mode 100644 index 0000000000..6dd6dcc4bf --- /dev/null +++ b/changelog/914.bugfix.rst @@ -0,0 +1 @@ +Fix :class:`ui.Modal` timeout issues with long-running callbacks, and multiple modals with the same user and ``custom_id``. diff --git a/disnake/ui/modal.py b/disnake/ui/modal.py index adf21ffa9c..7f0192c3b8 100644 --- a/disnake/ui/modal.py +++ b/disnake/ui/modal.py @@ -6,7 +6,8 @@ import os import sys import traceback -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, TypeVar, Union +from functools import partial +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, TypeVar, Union from ..enums import TextInputStyle from ..utils import MISSING @@ -38,14 +39,32 @@ class Modal: components: |components_type| The components to display in the modal. Up to 5 action rows. custom_id: :class:`str` - The custom ID of the modal. + The custom ID of the modal. This is usually not required. + If not given, then a unique one is generated for you. + + .. note:: + :class:`Modal`\\s are identified based on the user ID that triggered the + modal, and this ``custom_id``. + This can result in collisions when a user opens a modal with the same ``custom_id`` on + two separate devices, for example. + + To avoid such issues, consider not specifying a ``custom_id`` to use an automatically generated one, + or include a unique value in the custom ID (e.g. the original interaction ID). + timeout: :class:`float` The time to wait until the modal is removed from cache, if no interaction is made. Modals without timeouts are not supported, since there's no event for when a modal is closed. Defaults to 600 seconds. """ - __slots__ = ("title", "custom_id", "components", "timeout") + __slots__ = ( + "title", + "custom_id", + "components", + "timeout", + "__remove_callback", + "__timeout_handle", + ) def __init__( self, @@ -67,6 +86,11 @@ def __init__( self.components: List[ActionRow] = rows self.timeout: float = timeout + # function for the modal to remove itself from the store, if any + self.__remove_callback: Optional[Callable[[Modal], None]] = None + # timer handle for the scheduled timeout + self.__timeout_handle: Optional[asyncio.TimerHandle] = None + def __repr__(self) -> str: return ( f" None: except Exception as e: await self.on_error(e, interaction) finally: - # if the interaction was responded to (no matter if in the callback or error handler), - # the modal closed for the user and therefore can be removed from the store - if interaction.response._response_type is not None: - interaction._state._modal_store.remove_modal( - interaction.author.id, interaction.custom_id - ) + if interaction.response._response_type is None: + # If the interaction was not successfully responded to, the modal didn't close for the user. + # Since the timeout was already stopped at this point, restart it. + self._start_listening(self.__remove_callback) + else: + # Otherwise, the modal closed for the user; remove it from the store. + self._stop_listening() + + def _start_listening(self, remove_callback: Optional[Callable[[Modal], None]]) -> None: + self.__remove_callback = remove_callback + + loop = asyncio.get_running_loop() + if self.__timeout_handle is not None: + # shouldn't get here, but handled just in case + self.__timeout_handle.cancel() + + # start timeout + self.__timeout_handle = loop.call_later(self.timeout, self._dispatch_timeout) + + def _stop_listening(self) -> None: + # cancel timeout + if self.__timeout_handle is not None: + self.__timeout_handle.cancel() + self.__timeout_handle = None + + # remove modal from store + if self.__remove_callback is not None: + self.__remove_callback(self) + self.__remove_callback = None + + def _dispatch_timeout(self) -> None: + self._stop_listening() + asyncio.create_task(self.on_timeout(), name=f"disnake-ui-modal-timeout-{self.custom_id}") def dispatch(self, interaction: ModalInteraction) -> None: + # stop the timeout, but don't remove the modal from the store yet in case the + # response fails and the modal stays open + if self.__timeout_handle is not None: + self.__timeout_handle.cancel() + asyncio.create_task( self._scheduled_task(interaction), name=f"disnake-ui-modal-dispatch-{self.custom_id}" ) @@ -232,28 +288,22 @@ def __init__(self, state: ConnectionState) -> None: self._modals: Dict[Tuple[int, str], Modal] = {} def add_modal(self, user_id: int, modal: Modal) -> None: - loop = asyncio.get_running_loop() - self._modals[(user_id, modal.custom_id)] = modal - loop.create_task(self.handle_timeout(user_id, modal.custom_id, modal.timeout)) + key = (user_id, modal.custom_id) - def remove_modal(self, user_id: int, modal_custom_id: str) -> Modal: - return self._modals.pop((user_id, modal_custom_id)) + # if another modal with the same user+custom_id already exists, + # stop its timeout to avoid overlaps/collisions + if (existing := self._modals.get(key)) is not None: + existing._stop_listening() - async def handle_timeout(self, user_id: int, modal_custom_id: str, timeout: float) -> None: - # Waits for the timeout and then removes the modal from cache, this is done just in case - # the user closed the modal, as there isn't an event for that. + # start timeout, store modal + remove_callback = partial(self.remove_modal, user_id) + modal._start_listening(remove_callback) + self._modals[key] = modal - await asyncio.sleep(timeout) - try: - modal = self.remove_modal(user_id, modal_custom_id) - except KeyError: - # The modal has already been removed. - pass - else: - await modal.on_timeout() + def remove_modal(self, user_id: int, modal: Modal) -> None: + self._modals.pop((user_id, modal.custom_id), None) def dispatch(self, interaction: ModalInteraction) -> None: key = (interaction.author.id, interaction.custom_id) - modal = self._modals.get(key) - if modal is not None: + if (modal := self._modals.get(key)) is not None: modal.dispatch(interaction) diff --git a/examples/interactions/modal.py b/examples/interactions/modal.py index f271c82f4c..311b1d7d46 100644 --- a/examples/interactions/modal.py +++ b/examples/interactions/modal.py @@ -43,7 +43,7 @@ def __init__(self) -> None: max_length=1024, ), ] - super().__init__(title="Create Tag", custom_id="create_tag", components=components) + super().__init__(title="Create Tag", components=components) async def callback(self, inter: disnake.ModalInteraction) -> None: tag_name = inter.text_values["name"] diff --git a/test_bot/cogs/modals.py b/test_bot/cogs/modals.py index c5d514a25c..e988c88284 100644 --- a/test_bot/cogs/modals.py +++ b/test_bot/cogs/modals.py @@ -22,7 +22,7 @@ def __init__(self) -> None: style=TextInputStyle.paragraph, ), ] - super().__init__(title="Create Tag", custom_id="create_tag", components=components) + super().__init__(title="Create Tag", components=components) async def callback(self, inter: disnake.ModalInteraction[commands.Bot]) -> None: embed = disnake.Embed(title="Tag Creation")