diff --git a/disnake/abc.py b/disnake/abc.py index 605bb725aa..b9a60f3ee5 100644 --- a/disnake/abc.py +++ b/disnake/abc.py @@ -390,7 +390,7 @@ async def _edit( if p_id is not None and (parent := self.guild.get_channel(p_id)): overwrites_payload = [c._asdict() for c in parent._overwrites] - if overwrites is not MISSING and overwrites is not None: + if overwrites not in (MISSING, None): overwrites_payload = [] for target, perm in overwrites.items(): if not isinstance(perm, PermissionOverwrite): diff --git a/disnake/activity.py b/disnake/activity.py index 92460cd35d..3c290edd17 100644 --- a/disnake/activity.py +++ b/disnake/activity.py @@ -921,7 +921,7 @@ def create_activity( elif game_type is ActivityType.listening and "sync_id" in data and "session_id" in data: activity = Spotify(**data) else: - activity = Activity(**data) + activity = Activity(**data) # type: ignore if isinstance(activity, (Activity, CustomActivity)) and activity.emoji and state: activity.emoji._state = state diff --git a/disnake/asset.py b/disnake/asset.py index fad72c79ce..bc7b505697 100644 --- a/disnake/asset.py +++ b/disnake/asset.py @@ -24,7 +24,7 @@ ValidAssetFormatTypes = Literal["webp", "jpeg", "jpg", "png", "gif"] AnyState = Union[ConnectionState, _WebhookState[BaseWebhook]] -AssetBytes = Union[bytes, "AssetMixin"] +AssetBytes = Union[utils._BytesLike, "AssetMixin"] VALID_STATIC_FORMATS = frozenset({"jpeg", "jpg", "webp", "png"}) VALID_ASSET_FORMATS = VALID_STATIC_FORMATS | {"gif"} diff --git a/disnake/audit_logs.py b/disnake/audit_logs.py index e8ab022edf..256aaa04dc 100644 --- a/disnake/audit_logs.py +++ b/disnake/audit_logs.py @@ -245,7 +245,7 @@ def _transform_datetime(entry: AuditLogEntry, data: Optional[str]) -> Optional[d def _transform_privacy_level( - entry: AuditLogEntry, data: int + entry: AuditLogEntry, data: Optional[int] ) -> Optional[Union[enums.StagePrivacyLevel, enums.GuildScheduledEventPrivacyLevel]]: if data is None: return None diff --git a/disnake/channel.py b/disnake/channel.py index 7eef52b942..ffb11f2d2c 100644 --- a/disnake/channel.py +++ b/disnake/channel.py @@ -473,7 +473,7 @@ async def edit( overwrites=overwrites, flags=flags, reason=reason, - **kwargs, + **kwargs, # type: ignore ) if payload is not None: # the payload will always be the proper channel payload @@ -1628,7 +1628,7 @@ async def edit( slowmode_delay=slowmode_delay, flags=flags, reason=reason, - **kwargs, + **kwargs, # type: ignore ) if payload is not None: # the payload will always be the proper channel payload @@ -2453,7 +2453,7 @@ async def edit( flags=flags, slowmode_delay=slowmode_delay, reason=reason, - **kwargs, + **kwargs, # type: ignore ) if payload is not None: # the payload will always be the proper channel payload @@ -2946,7 +2946,7 @@ async def edit( overwrites=overwrites, flags=flags, reason=reason, - **kwargs, + **kwargs, # type: ignore ) if payload is not None: # the payload will always be the proper channel payload @@ -3619,7 +3619,7 @@ async def edit( default_sort_order=default_sort_order, default_layout=default_layout, reason=reason, - **kwargs, + **kwargs, # type: ignore ) if payload is not None: # the payload will always be the proper channel payload @@ -3994,7 +3994,7 @@ async def create_thread( stickers=stickers, ) - if auto_archive_duration is not None: + if auto_archive_duration not in (MISSING, None): auto_archive_duration = cast( "ThreadArchiveDurationLiteral", try_enum_to_int(auto_archive_duration) ) diff --git a/disnake/client.py b/disnake/client.py index f71842c7b3..b25b44cbd9 100644 --- a/disnake/client.py +++ b/disnake/client.py @@ -25,6 +25,7 @@ Optional, Sequence, Tuple, + TypedDict, TypeVar, Union, overload, @@ -79,6 +80,8 @@ from .widget import Widget if TYPE_CHECKING: + from typing_extensions import NotRequired + from .abc import GuildChannel, PrivateChannel, Snowflake, SnowflakeTime from .app_commands import APIApplicationCommand from .asset import AssetBytes @@ -207,6 +210,17 @@ class GatewayParams(NamedTuple): zlib: bool = True +# used for typing the ws parameter dict in the connect() loop +class _WebSocketParams(TypedDict): + initial: bool + shard_id: Optional[int] + gateway: Optional[str] + + sequence: NotRequired[Optional[int]] + resume: NotRequired[bool] + session: NotRequired[Optional[str]] + + class Client: """Represents a client connection that connects to Discord. This class is used to interact with the Discord WebSocket and API. @@ -1080,7 +1094,7 @@ async def connect( if not ignore_session_start_limit and self.session_start_limit.remaining == 0: raise SessionStartLimitReached(self.session_start_limit) - ws_params = { + ws_params: _WebSocketParams = { "initial": True, "shard_id": self.shard_id, "gateway": initial_gateway, @@ -1104,6 +1118,7 @@ async def connect( while True: await self.ws.poll_event() + except ReconnectWebSocket as e: _log.info("Got a request to %s the websocket.", e.op) self.dispatch("disconnect") @@ -1116,6 +1131,7 @@ async def connect( gateway=self.ws.resume_gateway if e.resume else initial_gateway, ) continue + except ( OSError, HTTPException, @@ -1196,7 +1212,8 @@ async def close(self) -> None: # if an error happens during disconnects, disregard it. pass - if self.ws is not None and self.ws.open: + # can be None if not connected + if self.ws is not None and self.ws.open: # pyright: ignore[reportUnnecessaryComparison] await self.ws.close(code=1000) await self.http.close() @@ -1874,16 +1891,15 @@ async def change_presence( await self.ws.change_presence(activity=activity, status=status_str) + activities = () if activity is None else (activity,) for guild in self._connection.guilds: me = guild.me - if me is None: + if me is None: # pyright: ignore[reportUnnecessaryComparison] + # may happen if guild is unavailable continue - if activity is not None: - me.activities = (activity,) # type: ignore - else: - me.activities = () - + # Member.activities is typehinted as Tuple[ActivityType, ...], we may be setting it as Tuple[BaseActivity, ...] + me.activities = activities # type: ignore me.status = status # Guild stuff diff --git a/disnake/components.py b/disnake/components.py index e6f3d14904..7614fd424b 100644 --- a/disnake/components.py +++ b/disnake/components.py @@ -9,6 +9,7 @@ Dict, Generic, List, + Literal, Optional, Tuple, Type, @@ -22,11 +23,12 @@ from .utils import MISSING, assert_never, get_slots if TYPE_CHECKING: - from typing_extensions import Self + from typing_extensions import Self, TypeAlias from .emoji import Emoji from .types.components import ( ActionRow as ActionRowPayload, + AnySelectMenu as AnySelectMenuPayload, BaseSelectMenu as BaseSelectMenuPayload, ButtonComponent as ButtonComponentPayload, ChannelSelectMenu as ChannelSelectMenuPayload, @@ -63,12 +65,16 @@ "MentionableSelectMenu", "ChannelSelectMenu", ] -MessageComponent = Union["Button", "AnySelectMenu"] -if TYPE_CHECKING: # TODO: remove when we add modal select support - from typing_extensions import TypeAlias +SelectMenuType = Literal[ + ComponentType.string_select, + ComponentType.user_select, + ComponentType.role_select, + ComponentType.mentionable_select, + ComponentType.channel_select, +] -# ModalComponent = Union["TextInput", "AnySelectMenu"] +MessageComponent = Union["Button", "AnySelectMenu"] ModalComponent: TypeAlias = "TextInput" NestedComponent = Union[MessageComponent, ModalComponent] @@ -131,8 +137,6 @@ class ActionRow(Component, Generic[ComponentT]): Attributes ---------- - type: :class:`ComponentType` - The type of component. children: List[Union[:class:`Button`, :class:`BaseSelectMenu`, :class:`TextInput`]] The children components that this holds, if any. """ @@ -142,10 +146,9 @@ class ActionRow(Component, Generic[ComponentT]): __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ def __init__(self, data: ActionRowPayload) -> None: - self.type: ComponentType = try_enum(ComponentType, data["type"]) - self.children: List[ComponentT] = [ - _component_factory(d) for d in data.get("components", []) - ] + self.type: Literal[ComponentType.action_row] = ComponentType.action_row + children = [_component_factory(d) for d in data.get("components", [])] + self.children: List[ComponentT] = children # type: ignore def to_dict(self) -> ActionRowPayload: return { @@ -195,7 +198,7 @@ class Button(Component): __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ def __init__(self, data: ButtonComponentPayload) -> None: - self.type: ComponentType = try_enum(ComponentType, data["type"]) + self.type: Literal[ComponentType.button] = ComponentType.button self.style: ButtonStyle = try_enum(ButtonStyle, data["style"]) self.custom_id: Optional[str] = data.get("custom_id") self.url: Optional[str] = data.get("url") @@ -209,7 +212,7 @@ def __init__(self, data: ButtonComponentPayload) -> None: def to_dict(self) -> ButtonComponentPayload: payload: ButtonComponentPayload = { - "type": 2, + "type": self.type.value, "style": self.style.value, "disabled": self.disabled, } @@ -273,8 +276,13 @@ class BaseSelectMenu(Component): __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ - def __init__(self, data: BaseSelectMenuPayload) -> None: - self.type: ComponentType = try_enum(ComponentType, data["type"]) + # n.b: ideally this would be `BaseSelectMenuPayload`, + # but pyright made TypedDict keys invariant and doesn't + # fully support readonly items yet (which would help avoid this) + def __init__(self, data: AnySelectMenuPayload) -> None: + component_type = try_enum(ComponentType, data["type"]) + self.type: SelectMenuType = component_type # type: ignore + self.custom_id: str = data["custom_id"] self.placeholder: Optional[str] = data.get("placeholder") self.min_values: int = data.get("min_values", 1) @@ -329,6 +337,7 @@ class StringSelectMenu(BaseSelectMenu): __slots__: Tuple[str, ...] = ("options",) __repr_info__: ClassVar[Tuple[str, ...]] = BaseSelectMenu.__repr_info__ + __slots__ + type: Literal[ComponentType.string_select] def __init__(self, data: StringSelectMenuPayload) -> None: super().__init__(data) @@ -372,6 +381,8 @@ class UserSelectMenu(BaseSelectMenu): __slots__: Tuple[str, ...] = () + type: Literal[ComponentType.user_select] + if TYPE_CHECKING: def to_dict(self) -> UserSelectMenuPayload: @@ -405,6 +416,8 @@ class RoleSelectMenu(BaseSelectMenu): __slots__: Tuple[str, ...] = () + type: Literal[ComponentType.role_select] + if TYPE_CHECKING: def to_dict(self) -> RoleSelectMenuPayload: @@ -438,6 +451,8 @@ class MentionableSelectMenu(BaseSelectMenu): __slots__: Tuple[str, ...] = () + type: Literal[ComponentType.mentionable_select] + if TYPE_CHECKING: def to_dict(self) -> MentionableSelectMenuPayload: @@ -475,6 +490,7 @@ class ChannelSelectMenu(BaseSelectMenu): __slots__: Tuple[str, ...] = ("channel_types",) __repr_info__: ClassVar[Tuple[str, ...]] = BaseSelectMenu.__repr_info__ + __slots__ + type: Literal[ComponentType.channel_select] def __init__(self, data: ChannelSelectMenuPayload) -> None: super().__init__(data) @@ -643,7 +659,7 @@ class TextInput(Component): def __init__(self, data: TextInputPayload) -> None: style = data.get("style", TextInputStyle.short.value) - self.type: ComponentType = try_enum(ComponentType, data["type"]) + self.type: Literal[ComponentType.text_input] = ComponentType.text_input self.custom_id: str = data["custom_id"] self.style: TextInputStyle = try_enum(TextInputStyle, style) self.label: Optional[str] = data.get("label") diff --git a/disnake/emoji.py b/disnake/emoji.py index 0f3d02c27d..2a24877b07 100644 --- a/disnake/emoji.py +++ b/disnake/emoji.py @@ -151,7 +151,7 @@ def roles(self) -> List[Role]: and count towards a separate limit of 25 emojis. """ guild = self.guild - if guild is None: + if guild is None: # pyright: ignore[reportUnnecessaryComparison] return [] return [role for role in guild.roles if self._roles.has(role.id)] diff --git a/disnake/enums.py b/disnake/enums.py index b4bf3d994d..cb603c5425 100644 --- a/disnake/enums.py +++ b/disnake/enums.py @@ -466,7 +466,7 @@ def category(self) -> Optional[AuditLogActionCategory]: @property def target_type(self) -> Optional[str]: v = self.value - if v == -1: + if v == -1: # pyright: ignore[reportUnnecessaryComparison] return "all" elif v < 10: return "guild" @@ -627,7 +627,7 @@ class ComponentType(Enum): action_row = 1 button = 2 string_select = 3 - select = string_select # backwards compatibility + select = 3 # backwards compatibility text_input = 4 user_select = 5 role_select = 6 diff --git a/disnake/ext/commands/base_core.py b/disnake/ext/commands/base_core.py index 7198394be8..3599ea0908 100644 --- a/disnake/ext/commands/base_core.py +++ b/disnake/ext/commands/base_core.py @@ -303,7 +303,7 @@ def _prepare_cooldowns(self, inter: ApplicationCommandInteraction) -> None: dt = inter.created_at current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() bucket = self._buckets.get_bucket(inter, current) # type: ignore - if bucket is not None: + if bucket is not None: # pyright: ignore[reportUnnecessaryComparison] retry_after = bucket.update_rate_limit(current) if retry_after: raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore diff --git a/disnake/ext/commands/bot_base.py b/disnake/ext/commands/bot_base.py index d55dc63490..1bba906c82 100644 --- a/disnake/ext/commands/bot_base.py +++ b/disnake/ext/commands/bot_base.py @@ -10,18 +10,7 @@ import sys import traceback import warnings -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Iterable, - List, - Optional, - Type, - TypeVar, - Union, - cast, -) +from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, Type, TypeVar, Union import disnake @@ -414,7 +403,7 @@ def _remove_module_references(self, name: str) -> None: super()._remove_module_references(name) # remove all the commands from the module for cmd in self.all_commands.copy().values(): - if cmd.module is not None and _is_submodule(name, cmd.module): + if cmd.module and _is_submodule(name, cmd.module): if isinstance(cmd, GroupMixin): cmd.recursively_remove_all_commands() self.remove_command(cmd.name) @@ -513,7 +502,7 @@ class be provided, it must be similar enough to :class:`.Context`\'s ``cls`` parameter. """ view = StringView(message.content) - ctx = cast("CXT", cls(prefix=None, view=view, bot=self, message=message)) + ctx = cls(prefix=None, view=view, bot=self, message=message) if message.author.id == self.user.id: # type: ignore return ctx diff --git a/disnake/ext/commands/converter.py b/disnake/ext/commands/converter.py index 8bca2bd6dd..29672b2e54 100644 --- a/disnake/ext/commands/converter.py +++ b/disnake/ext/commands/converter.py @@ -1133,7 +1133,7 @@ def __class_getitem__(cls, params: Union[Tuple[T], T]) -> Greedy[T]: raise TypeError("Greedy[...] expects a type or a Converter instance.") if converter in (str, type(None)) or origin is Greedy: - raise TypeError(f"Greedy[{converter.__name__}] is invalid.") # type: ignore + raise TypeError(f"Greedy[{converter.__name__}] is invalid.") if origin is Union and type(None) in args: raise TypeError(f"Greedy[{converter!r}] is invalid.") @@ -1161,7 +1161,7 @@ def get_converter(param: inspect.Parameter) -> Any: return converter -_GenericAlias = type(List[T]) +_GenericAlias = type(List[Any]) def is_generic_type(tp: Any, *, _GenericAlias: Type = _GenericAlias) -> bool: @@ -1222,7 +1222,7 @@ async def _actual_conversion( raise ConversionError(converter, exc) from exc try: - return converter(argument) + return converter(argument) # type: ignore except CommandError: raise except Exception as exc: diff --git a/disnake/ext/commands/cooldowns.py b/disnake/ext/commands/cooldowns.py index 4268f76fff..354754550a 100644 --- a/disnake/ext/commands/cooldowns.py +++ b/disnake/ext/commands/cooldowns.py @@ -228,7 +228,7 @@ def get_bucket(self, message: Message, current: Optional[float] = None) -> Coold key = self._bucket_key(message) if key not in self._cache: bucket = self.create_bucket(message) - if bucket is not None: + if bucket is not None: # pyright: ignore[reportUnnecessaryComparison] self._cache[key] = bucket else: bucket = self._cache[key] diff --git a/disnake/ext/commands/core.py b/disnake/ext/commands/core.py index fda34b5a95..2ddcb10075 100644 --- a/disnake/ext/commands/core.py +++ b/disnake/ext/commands/core.py @@ -755,7 +755,7 @@ def _prepare_cooldowns(self, ctx: Context) -> None: dt = ctx.message.edited_at or ctx.message.created_at current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() bucket = self._buckets.get_bucket(ctx.message, current) - if bucket is not None: + if bucket is not None: # pyright: ignore[reportUnnecessaryComparison] retry_after = bucket.update_rate_limit(current) if retry_after: raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore @@ -1718,7 +1718,7 @@ def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]: decorator.predicate = predicate else: - @functools.wraps(predicate) + @functools.wraps(predicate) # type: ignore async def wrapper(ctx): return predicate(ctx) # type: ignore diff --git a/disnake/ext/commands/help.py b/disnake/ext/commands/help.py index 483d4f4bd2..ecd3988b86 100644 --- a/disnake/ext/commands/help.py +++ b/disnake/ext/commands/help.py @@ -368,7 +368,11 @@ def invoked_with(self): """ command_name = self._command_impl.name ctx = self.context - if ctx is None or ctx.command is None or ctx.command.qualified_name != command_name: + if ( + ctx is disnake.utils.MISSING + or ctx.command is None + or ctx.command.qualified_name != command_name + ): return command_name return ctx.invoked_with diff --git a/disnake/ext/commands/params.py b/disnake/ext/commands/params.py index 95679ed802..9114b8b353 100644 --- a/disnake/ext/commands/params.py +++ b/disnake/ext/commands/params.py @@ -85,9 +85,6 @@ if sys.version_info >= (3, 10): from types import EllipsisType, UnionType -elif TYPE_CHECKING: - UnionType = object() - EllipsisType = ellipsis # noqa: F821 else: UnionType = object() EllipsisType = type(Ellipsis) @@ -543,7 +540,7 @@ def __init__( self.max_length = max_length self.large = large - def copy(self) -> ParamInfo: + def copy(self) -> Self: # n. b. this method needs to be manually updated when a new attribute is added. cls = self.__class__ ins = cls.__new__(cls) @@ -1339,7 +1336,7 @@ def option_enum( choices = choices or kwargs first, *_ = choices.values() - return Enum("", choices, type=type(first)) + return Enum("", choices, type=type(first)) # type: ignore class ConverterMethod(classmethod): diff --git a/disnake/ext/commands/slash_core.py b/disnake/ext/commands/slash_core.py index 1b318a21d0..4652c552f8 100644 --- a/disnake/ext/commands/slash_core.py +++ b/disnake/ext/commands/slash_core.py @@ -666,7 +666,7 @@ async def _call_relevant_autocompleter(self, inter: ApplicationCommandInteractio group = self.children.get(chain[0]) if not isinstance(group, SubCommandGroup): raise AssertionError("the first subcommand is not a SubCommandGroup instance") - subcmd = group.children.get(chain[1]) if group is not None else None + subcmd = group.children.get(chain[1]) else: raise ValueError("Command chain is too long") @@ -695,7 +695,7 @@ async def invoke_children(self, inter: ApplicationCommandInteraction) -> None: group = self.children.get(chain[0]) if not isinstance(group, SubCommandGroup): raise AssertionError("the first subcommand is not a SubCommandGroup instance") - subcmd = group.children.get(chain[1]) if group is not None else None + subcmd = group.children.get(chain[1]) else: raise ValueError("Command chain is too long") diff --git a/disnake/ext/tasks/__init__.py b/disnake/ext/tasks/__init__.py index 1c23e0e912..6532c3d088 100644 --- a/disnake/ext/tasks/__init__.py +++ b/disnake/ext/tasks/__init__.py @@ -708,7 +708,7 @@ class Object(Protocol[T_co, P]): def __new__(cls) -> T_co: ... - def __init__(*args: P.args, **kwargs: P.kwargs) -> None: + def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None: ... @@ -734,7 +734,7 @@ def loop( def loop( - cls: Type[Object[L_co, Concatenate[LF, P]]] = Loop[LF], + cls: Type[Object[L_co, Concatenate[LF, P]]] = Loop[Any], **kwargs: Any, ) -> Callable[[LF], L_co]: """A decorator that schedules a task in the background for you with diff --git a/disnake/gateway.py b/disnake/gateway.py index 2081493509..cd0cb6d44a 100644 --- a/disnake/gateway.py +++ b/disnake/gateway.py @@ -274,7 +274,7 @@ async def close(self, *, code: int = 4000, message: bytes = b"") -> bool: class HeartbeatWebSocket(Protocol): - HEARTBEAT: Final[Literal[1, 3]] # type: ignore + HEARTBEAT: Final[Literal[1, 3]] thread_id: int loop: asyncio.AbstractEventLoop diff --git a/disnake/guild.py b/disnake/guild.py index 3927992fb5..ba140f2298 100644 --- a/disnake/guild.py +++ b/disnake/guild.py @@ -3136,10 +3136,6 @@ async def integrations(self) -> List[Integration]: def convert(d): factory, _ = _integration_factory(d["type"]) - if factory is None: - raise InvalidData( - "Unknown integration type {type!r} for integration ID {id}".format_map(d) - ) return factory(guild=self, data=d) return [convert(d) for d in data] diff --git a/disnake/http.py b/disnake/http.py index f8c4b44694..06b3801861 100644 --- a/disnake/http.py +++ b/disnake/http.py @@ -248,19 +248,18 @@ def recreate(self) -> None: ) async def ws_connect(self, url: str, *, compress: int = 0) -> aiohttp.ClientWebSocketResponse: - kwargs = { - "proxy_auth": self.proxy_auth, - "proxy": self.proxy, - "max_msg_size": 0, - "timeout": 30.0, - "autoclose": False, - "headers": { + return await self.__session.ws_connect( + url, + proxy_auth=self.proxy_auth, + proxy=self.proxy, + max_msg_size=0, + timeout=30.0, + autoclose=False, + headers={ "User-Agent": self.user_agent, }, - "compress": compress, - } - - return await self.__session.ws_connect(url, **kwargs) + compress=compress, + ) async def request( self, @@ -276,9 +275,7 @@ async def request( lock = self._locks.get(bucket) if lock is None: - lock = asyncio.Lock() - if bucket is not None: - self._locks[bucket] = lock + self._locks[bucket] = lock = asyncio.Lock() # header creation headers: Dict[str, str] = { diff --git a/disnake/i18n.py b/disnake/i18n.py index 344787ad5b..c2781a9eb8 100644 --- a/disnake/i18n.py +++ b/disnake/i18n.py @@ -409,7 +409,7 @@ def _load_file(self, path: Path) -> None: except Exception as e: raise RuntimeError(f"Unable to load '{path}': {e}") from e - def _load_dict(self, data: Dict[str, str], locale: str) -> None: + def _load_dict(self, data: Dict[str, Optional[str]], locale: str) -> None: if not isinstance(data, dict) or not all( o is None or isinstance(o, str) for o in data.values() ): diff --git a/disnake/interactions/base.py b/disnake/interactions/base.py index bdcbe3cae2..01637be96a 100644 --- a/disnake/interactions/base.py +++ b/disnake/interactions/base.py @@ -1855,7 +1855,7 @@ def __init__( guild and guild.get_channel_or_thread(channel_id) or factory( - guild=guild_fallback, # type: ignore + guild=guild_fallback, state=state, data=channel, # type: ignore ) diff --git a/disnake/iterators.py b/disnake/iterators.py index ea8347effd..f7d694598a 100644 --- a/disnake/iterators.py +++ b/disnake/iterators.py @@ -106,7 +106,7 @@ def chunk(self, max_size: int) -> _ChunkedAsyncIterator[T]: def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]: return _MappedAsyncIterator(self, func) - def filter(self, predicate: _Func[T, bool]) -> _FilteredAsyncIterator[T]: + def filter(self, predicate: Optional[_Func[T, bool]]) -> _FilteredAsyncIterator[T]: return _FilteredAsyncIterator(self, predicate) async def flatten(self) -> List[T]: @@ -152,11 +152,11 @@ async def next(self) -> OT: class _FilteredAsyncIterator(_AsyncIterator[T]): - def __init__(self, iterator: _AsyncIterator[T], predicate: _Func[T, bool]) -> None: + def __init__(self, iterator: _AsyncIterator[T], predicate: Optional[_Func[T, bool]]) -> None: self.iterator = iterator if predicate is None: - predicate = lambda x: bool(x) + predicate = bool # similar to the `filter` builtin, a `None` filter drops falsy items self.predicate: _Func[T, bool] = predicate @@ -626,8 +626,8 @@ async def _fill(self) -> None: } for element in entries: - # TODO: remove this if statement later - if element["action_type"] is None: + # https://github.com/discord/discord-api-docs/issues/5055#issuecomment-1266363766 + if element["action_type"] is None: # pyright: ignore[reportUnnecessaryComparison] continue await self.entries.put( diff --git a/disnake/message.py b/disnake/message.py index 92aba532c7..e3967e1160 100644 --- a/disnake/message.py +++ b/disnake/message.py @@ -658,13 +658,14 @@ def __repr__(self) -> str: return f"" def to_dict(self) -> MessageReferencePayload: - result: MessageReferencePayload = {"channel_id": self.channel_id} + result: MessageReferencePayload = { + "channel_id": self.channel_id, + "fail_if_not_exists": self.fail_if_not_exists, + } if self.message_id is not None: result["message_id"] = self.message_id if self.guild_id is not None: result["guild_id"] = self.guild_id - if self.fail_if_not_exists is not None: - result["fail_if_not_exists"] = self.fail_if_not_exists return result to_message_reference_dict = to_dict diff --git a/disnake/shard.py b/disnake/shard.py index 102c66e4ae..a82ae13efd 100644 --- a/disnake/shard.py +++ b/disnake/shard.py @@ -589,7 +589,8 @@ async def change_presence( activities = () if activity is None else (activity,) for guild in guilds: me = guild.me - if me is None: + if me is None: # pyright: ignore[reportUnnecessaryComparison] + # may happen if guild is unavailable continue # Member.activities is typehinted as Tuple[ActivityType, ...], we may be setting it as Tuple[BaseActivity, ...] diff --git a/disnake/state.py b/disnake/state.py index ca915aa33f..714a92759b 100644 --- a/disnake/state.py +++ b/disnake/state.py @@ -25,7 +25,6 @@ Tuple, TypeVar, Union, - cast, overload, ) @@ -600,7 +599,6 @@ def _get_guild_channel( if channel is None: if "author" in data: # MessagePayload - data = cast("MessagePayload", data) user_id = int(data["author"]["id"]) else: # TypingStartEvent @@ -637,8 +635,6 @@ async def query_members( ): guild_id = guild.id ws = self._get_websocket(guild_id) - if ws is None: - raise RuntimeError("Somehow do not have a websocket for this guild_id") request = ChunkRequest(guild.id, self.loop, self._get_guild, cache=cache) self._chunk_requests[request.nonce] = request @@ -1796,6 +1792,8 @@ def parse_voice_server_update(self, data: gateway.VoiceServerUpdateEvent) -> Non logging_coroutine(coro, info="Voice Protocol voice server update handler") ) + # FIXME: this should be refactored. The `GroupChannel` path will never be hit, + # `raw.timestamp` exists so no need to parse it twice, and `.get_user` should be used before falling back def parse_typing_start(self, data: gateway.TypingStartEvent) -> None: channel, guild = self._get_guild_channel(data) raw = RawTypingEvent(data) @@ -1810,7 +1808,7 @@ def parse_typing_start(self, data: gateway.TypingStartEvent) -> None: self.dispatch("raw_typing", raw) - if channel is not None: + if channel is not None: # pyright: ignore[reportUnnecessaryComparison] member = None if raw.member is not None: member = raw.member diff --git a/disnake/types/audit_log.py b/disnake/types/audit_log.py index d3b3a5484f..f9640b3ad9 100644 --- a/disnake/types/audit_log.py +++ b/disnake/types/audit_log.py @@ -103,8 +103,8 @@ class _AuditLogChange_Str(TypedDict): "permissions", "tags", ] - new_value: str - old_value: str + new_value: NotRequired[str] + old_value: NotRequired[str] class _AuditLogChange_AssetHash(TypedDict): @@ -116,8 +116,8 @@ class _AuditLogChange_AssetHash(TypedDict): "avatar_hash", "asset", ] - new_value: str - old_value: str + new_value: NotRequired[str] + old_value: NotRequired[str] class _AuditLogChange_Snowflake(TypedDict): @@ -134,8 +134,8 @@ class _AuditLogChange_Snowflake(TypedDict): "inviter_id", "guild_id", ] - new_value: Snowflake - old_value: Snowflake + new_value: NotRequired[Snowflake] + old_value: NotRequired[Snowflake] class _AuditLogChange_Bool(TypedDict): @@ -157,8 +157,8 @@ class _AuditLogChange_Bool(TypedDict): "premium_progress_bar_enabled", "enabled", ] - new_value: bool - old_value: bool + new_value: NotRequired[bool] + old_value: NotRequired[bool] class _AuditLogChange_Int(TypedDict): @@ -175,104 +175,104 @@ class _AuditLogChange_Int(TypedDict): "auto_archive_duration", "default_auto_archive_duration", ] - new_value: int - old_value: int + new_value: NotRequired[int] + old_value: NotRequired[int] class _AuditLogChange_ListSnowflake(TypedDict): key: Literal["exempt_roles", "exempt_channels"] - new_value: List[Snowflake] - old_value: List[Snowflake] + new_value: NotRequired[List[Snowflake]] + old_value: NotRequired[List[Snowflake]] class _AuditLogChange_ListRole(TypedDict): key: Literal["$add", "$remove"] - new_value: List[Role] - old_value: List[Role] + new_value: NotRequired[List[Role]] + old_value: NotRequired[List[Role]] class _AuditLogChange_MFALevel(TypedDict): key: Literal["mfa_level"] - new_value: MFALevel - old_value: MFALevel + new_value: NotRequired[MFALevel] + old_value: NotRequired[MFALevel] class _AuditLogChange_VerificationLevel(TypedDict): key: Literal["verification_level"] - new_value: VerificationLevel - old_value: VerificationLevel + new_value: NotRequired[VerificationLevel] + old_value: NotRequired[VerificationLevel] class _AuditLogChange_ExplicitContentFilter(TypedDict): key: Literal["explicit_content_filter"] - new_value: ExplicitContentFilterLevel - old_value: ExplicitContentFilterLevel + new_value: NotRequired[ExplicitContentFilterLevel] + old_value: NotRequired[ExplicitContentFilterLevel] class _AuditLogChange_DefaultMessageNotificationLevel(TypedDict): key: Literal["default_message_notifications"] - new_value: DefaultMessageNotificationLevel - old_value: DefaultMessageNotificationLevel + new_value: NotRequired[DefaultMessageNotificationLevel] + old_value: NotRequired[DefaultMessageNotificationLevel] class _AuditLogChange_ChannelType(TypedDict): key: Literal["type"] - new_value: ChannelType - old_value: ChannelType + new_value: NotRequired[ChannelType] + old_value: NotRequired[ChannelType] class _AuditLogChange_IntegrationExpireBehaviour(TypedDict): key: Literal["expire_behavior"] - new_value: IntegrationExpireBehavior - old_value: IntegrationExpireBehavior + new_value: NotRequired[IntegrationExpireBehavior] + old_value: NotRequired[IntegrationExpireBehavior] class _AuditLogChange_VideoQualityMode(TypedDict): key: Literal["video_quality_mode"] - new_value: VideoQualityMode - old_value: VideoQualityMode + new_value: NotRequired[VideoQualityMode] + old_value: NotRequired[VideoQualityMode] class _AuditLogChange_Overwrites(TypedDict): key: Literal["permission_overwrites"] - new_value: List[PermissionOverwrite] - old_value: List[PermissionOverwrite] + new_value: NotRequired[List[PermissionOverwrite]] + old_value: NotRequired[List[PermissionOverwrite]] class _AuditLogChange_Datetime(TypedDict): key: Literal["communication_disabled_until"] - new_value: datetime.datetime - old_value: datetime.datetime + new_value: NotRequired[datetime.datetime] + old_value: NotRequired[datetime.datetime] class _AuditLogChange_ApplicationCommandPermissions(TypedDict): key: str - new_value: ApplicationCommandPermissions - old_value: ApplicationCommandPermissions + new_value: NotRequired[ApplicationCommandPermissions] + old_value: NotRequired[ApplicationCommandPermissions] class _AuditLogChange_AutoModTriggerType(TypedDict): key: Literal["trigger_type"] - new_value: AutoModTriggerType - old_value: AutoModTriggerType + new_value: NotRequired[AutoModTriggerType] + old_value: NotRequired[AutoModTriggerType] class _AuditLogChange_AutoModEventType(TypedDict): key: Literal["event_type"] - new_value: AutoModEventType - old_value: AutoModEventType + new_value: NotRequired[AutoModEventType] + old_value: NotRequired[AutoModEventType] class _AuditLogChange_AutoModActions(TypedDict): key: Literal["actions"] - new_value: List[AutoModAction] - old_value: List[AutoModAction] + new_value: NotRequired[List[AutoModAction]] + old_value: NotRequired[List[AutoModAction]] class _AuditLogChange_AutoModTriggerMetadata(TypedDict): key: Literal["trigger_metadata"] - new_value: AutoModTriggerMetadata - old_value: AutoModTriggerMetadata + new_value: NotRequired[AutoModTriggerMetadata] + old_value: NotRequired[AutoModTriggerMetadata] AuditLogChange = Union[ diff --git a/disnake/types/automod.py b/disnake/types/automod.py index 156952d092..f7ac372e5e 100644 --- a/disnake/types/automod.py +++ b/disnake/types/automod.py @@ -8,9 +8,9 @@ from .snowflake import Snowflake, SnowflakeList -AutoModTriggerType = Literal[1, 2, 3, 4, 5] +AutoModTriggerType = Literal[1, 3, 4, 5] AutoModEventType = Literal[1] -AutoModActionType = Literal[1, 2] +AutoModActionType = Literal[1, 2, 3] AutoModPresetType = Literal[1, 2, 3] diff --git a/disnake/types/template.py b/disnake/types/template.py index ddb2c26cb7..e0008659aa 100644 --- a/disnake/types/template.py +++ b/disnake/types/template.py @@ -20,7 +20,7 @@ class Template(TypedDict): description: Optional[str] usage_count: int creator_id: Snowflake - creator: User + creator: Optional[User] # unsure when this can be null, but the spec says so created_at: str updated_at: str source_guild_id: Snowflake diff --git a/disnake/ui/action_row.py b/disnake/ui/action_row.py index fe7244a776..21ea01cb74 100644 --- a/disnake/ui/action_row.py +++ b/disnake/ui/action_row.py @@ -159,7 +159,8 @@ def __init__(self: ActionRow[ModalUIComponent], *components: ModalUIComponent) - def __init__(self: ActionRow[StrictUIComponentT], *components: StrictUIComponentT) -> None: ... - def __init__(self, *components: UIComponentT) -> None: + # n.b. this should be `*components: UIComponentT`, but pyright does not like it + def __init__(self, *components: Union[MessageUIComponent, ModalUIComponent]) -> None: self._children: List[UIComponentT] = [] for component in components: @@ -167,7 +168,7 @@ def __init__(self, *components: UIComponentT) -> None: raise TypeError( f"components should be of type WrappedComponent, got {type(component).__name__}." ) - self.append_item(component) + self.append_item(component) # type: ignore def __repr__(self) -> str: return f"" diff --git a/disnake/ui/button.py b/disnake/ui/button.py index d5e1fc7708..a961ba29ab 100644 --- a/disnake/ui/button.py +++ b/disnake/ui/button.py @@ -275,7 +275,7 @@ def button( def button( - cls: Type[Object[B_co, P]] = Button[Any], **kwargs: Any + cls: Type[Object[B_co, ...]] = Button[Any], **kwargs: Any ) -> Callable[[ItemCallbackType[B_co]], DecoratedItem[B_co]]: """A decorator that attaches a button to a component. diff --git a/disnake/ui/item.py b/disnake/ui/item.py index 971ca8dcb3..464eb4d588 100644 --- a/disnake/ui/item.py +++ b/disnake/ui/item.py @@ -184,5 +184,5 @@ class Object(Protocol[T_co, P]): def __new__(cls) -> T_co: ... - def __init__(*args: P.args, **kwargs: P.kwargs) -> None: + def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None: ... diff --git a/disnake/ui/modal.py b/disnake/ui/modal.py index a7a5503a28..adf21ffa9c 100644 --- a/disnake/ui/modal.py +++ b/disnake/ui/modal.py @@ -55,7 +55,7 @@ def __init__( custom_id: str = MISSING, timeout: float = 600, ) -> None: - if timeout is None: + if timeout is None: # pyright: ignore[reportUnnecessaryComparison] raise ValueError("Timeout may not be None") rows = components_to_rows(components) diff --git a/disnake/ui/select/channel.py b/disnake/ui/select/channel.py index a98472b547..57dd9cfbe9 100644 --- a/disnake/ui/select/channel.py +++ b/disnake/ui/select/channel.py @@ -168,7 +168,7 @@ def channel_select( def channel_select( - cls: Type[Object[S_co, P]] = ChannelSelect[Any], **kwargs: Any + cls: Type[Object[S_co, ...]] = ChannelSelect[Any], **kwargs: Any ) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: """A decorator that attaches a channel select menu to a component. diff --git a/disnake/ui/select/mentionable.py b/disnake/ui/select/mentionable.py index 4f0d591201..860903f7f1 100644 --- a/disnake/ui/select/mentionable.py +++ b/disnake/ui/select/mentionable.py @@ -144,7 +144,7 @@ def mentionable_select( def mentionable_select( - cls: Type[Object[S_co, P]] = MentionableSelect[Any], **kwargs: Any + cls: Type[Object[S_co, ...]] = MentionableSelect[Any], **kwargs: Any ) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: """A decorator that attaches a mentionable (user/member/role) select menu to a component. diff --git a/disnake/ui/select/role.py b/disnake/ui/select/role.py index 69b1bcaa57..fe2da2f97a 100644 --- a/disnake/ui/select/role.py +++ b/disnake/ui/select/role.py @@ -142,7 +142,7 @@ def role_select( def role_select( - cls: Type[Object[S_co, P]] = RoleSelect[Any], **kwargs: Any + cls: Type[Object[S_co, ...]] = RoleSelect[Any], **kwargs: Any ) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: """A decorator that attaches a role select menu to a component. diff --git a/disnake/ui/select/string.py b/disnake/ui/select/string.py index d38c9ea6ba..3eeedc1f22 100644 --- a/disnake/ui/select/string.py +++ b/disnake/ui/select/string.py @@ -268,7 +268,7 @@ def string_select( def string_select( - cls: Type[Object[S_co, P]] = StringSelect[Any], **kwargs: Any + cls: Type[Object[S_co, ...]] = StringSelect[Any], **kwargs: Any ) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: """A decorator that attaches a string select menu to a component. diff --git a/disnake/ui/select/user.py b/disnake/ui/select/user.py index 179b9d6c74..4868894a83 100644 --- a/disnake/ui/select/user.py +++ b/disnake/ui/select/user.py @@ -143,7 +143,7 @@ def user_select( def user_select( - cls: Type[Object[S_co, P]] = UserSelect[Any], **kwargs: Any + cls: Type[Object[S_co, ...]] = UserSelect[Any], **kwargs: Any ) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: """A decorator that attaches a user select menu to a component. diff --git a/disnake/utils.py b/disnake/utils.py index a74d50ab94..6fa6ae82d7 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -134,6 +134,7 @@ class _RequestLike(Protocol): V = TypeVar("V") T_co = TypeVar("T_co", covariant=True) _Iter = Union[Iterator[T], AsyncIterator[T]] +_BytesLike = Union[bytes, bytearray, memoryview] class CachedSlotProperty(Generic[T, T_co]): @@ -489,7 +490,7 @@ def _maybe_cast(value: V, converter: Callable[[V], T], default: T = None) -> Opt } -def _get_mime_type_for_image(data: bytes) -> str: +def _get_mime_type_for_image(data: _BytesLike) -> str: if data[0:8] == b"\x89\x50\x4E\x47\x0D\x0A\x1A\x0A": return "image/png" elif data[0:3] == b"\xff\xd8\xff" or data[6:10] in (b"JFIF", b"Exif"): @@ -502,14 +503,14 @@ def _get_mime_type_for_image(data: bytes) -> str: raise ValueError("Unsupported image type given") -def _bytes_to_base64_data(data: bytes) -> str: +def _bytes_to_base64_data(data: _BytesLike) -> str: fmt = "data:{mime};base64,{data}" mime = _get_mime_type_for_image(data) b64 = b64encode(data).decode("ascii") return fmt.format(mime=mime, data=b64) -def _get_extension_for_image(data: bytes) -> Optional[str]: +def _get_extension_for_image(data: _BytesLike) -> Optional[str]: try: mime_type = _get_mime_type_for_image(data) except ValueError: @@ -538,7 +539,7 @@ async def _assetbytes_to_base64_data(data: Optional[AssetBytes]) -> Optional[str if HAS_ORJSON: def _to_json(obj: Any) -> str: - return orjson.dumps(obj).decode("utf-8") + return orjson.dumps(obj).decode("utf-8") # type: ignore _from_json = orjson.loads # type: ignore @@ -571,7 +572,8 @@ async def maybe_coroutine( return value # type: ignore # typeguard doesn't narrow in the negative case -async def async_all(gen: Iterable[Union[Awaitable[bool], bool]], *, check=_isawaitable) -> bool: +async def async_all(gen: Iterable[Union[Awaitable[bool], bool]]) -> bool: + check = _isawaitable for elem in gen: if check(elem): elem = await elem diff --git a/disnake/voice_client.py b/disnake/voice_client.py index 52750ecebd..a6cc13e0ba 100644 --- a/disnake/voice_client.py +++ b/disnake/voice_client.py @@ -279,7 +279,7 @@ async def on_voice_server_update(self, data: VoiceServerUpdateEvent) -> None: self.server_id = int(data["guild_id"]) endpoint = data.get("endpoint") - if endpoint is None or self.token is None: + if endpoint is None or not self.token: _log.warning( "Awaiting endpoint... This requires waiting. " "If timeout occurred considering raising the timeout and reconnecting." diff --git a/docs/extensions/builder.py b/docs/extensions/builder.py index 5133af0f85..61e366d2ca 100644 --- a/docs/extensions/builder.py +++ b/docs/extensions/builder.py @@ -65,7 +65,7 @@ def disable_mathjax(app: Sphinx, config: Config) -> None: # inspired by https://github.com/readthedocs/sphinx-hoverxref/blob/003b84fee48262f1a969c8143e63c177bd98aa26/hoverxref/extension.py#L151 for listener in app.events.listeners.get("html-page-context", []): - module_name = inspect.getmodule(listener.handler).__name__ # type: ignore + module_name = inspect.getmodule(listener.handler).__name__ if module_name == "sphinx.ext.mathjax": app.disconnect(listener.id) diff --git a/examples/basic_voice.py b/examples/basic_voice.py index 6d224b21e5..45046c780f 100644 --- a/examples/basic_voice.py +++ b/examples/basic_voice.py @@ -33,8 +33,6 @@ "source_address": "0.0.0.0", # bind to ipv4 since ipv6 addresses cause issues sometimes } -ffmpeg_options = {"options": "-vn"} - ytdl = youtube_dl.YoutubeDL(ytdl_format_options) @@ -59,7 +57,7 @@ async def from_url( filename = data["url"] if stream else ytdl.prepare_filename(data) - return cls(disnake.FFmpegPCMAudio(filename, **ffmpeg_options), data=data) + return cls(disnake.FFmpegPCMAudio(filename, options="-vn"), data=data) class Music(commands.Cog): diff --git a/examples/interactions/injections.py b/examples/interactions/injections.py index 27576d60bc..30c7554dd6 100644 --- a/examples/interactions/injections.py +++ b/examples/interactions/injections.py @@ -114,7 +114,7 @@ async def get_game_user( if user is None: return await db.get_game_user(id=inter.author.id) - game_user: GameUser = await db.search_game_user(username=user, server=server) + game_user: Optional[GameUser] = await db.search_game_user(username=user, server=server) if game_user is None: raise commands.CommandError(f"User with username {user!r} could not be found") diff --git a/examples/interactions/modal.py b/examples/interactions/modal.py index 00b2364789..f271c82f4c 100644 --- a/examples/interactions/modal.py +++ b/examples/interactions/modal.py @@ -2,6 +2,8 @@ """An example demonstrating two methods of sending modals and handling modal responses.""" +# pyright: reportUnknownLambdaType=false + import asyncio import os diff --git a/pyproject.toml b/pyproject.toml index 4756d55c4a..73f3ad1b9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,7 +84,7 @@ codemod = [ ] typing = [ # this is not pyright itself, but the python wrapper - "pyright==1.1.291", + "pyright==1.1.336", "typing-extensions~=4.8.0", # only used for type-checking, version does not matter "pytz", diff --git a/test_bot/cogs/modals.py b/test_bot/cogs/modals.py index 13c84bddf2..c5d514a25c 100644 --- a/test_bot/cogs/modals.py +++ b/test_bot/cogs/modals.py @@ -65,7 +65,7 @@ async def create_tag_low(self, inter: disnake.AppCmdInter[commands.Bot]) -> None modal_inter: disnake.ModalInteraction = await self.bot.wait_for( "modal_submit", - check=lambda i: i.custom_id == "create_tag2" and i.author.id == inter.author.id, + check=lambda i: i.custom_id == "create_tag2" and i.author.id == inter.author.id, # type: ignore # unknown parameter type ) embed = disnake.Embed(title="Tag Creation") diff --git a/tests/ui/test_decorators.py b/tests/ui/test_decorators.py index 5fab1bb787..e9c3680873 100644 --- a/tests/ui/test_decorators.py +++ b/tests/ui/test_decorators.py @@ -30,16 +30,16 @@ def __init__(self, *, param: float = 42.0) -> None: class TestDecorator: def test_default(self) -> None: - with create_callback(ui.Button) as func: + with create_callback(ui.Button[ui.View]) as func: res = ui.button(custom_id="123")(func) - assert_type(res, ui.item.DecoratedItem[ui.Button]) + assert_type(res, ui.item.DecoratedItem[ui.Button[ui.View]]) assert func.__discord_ui_model_type__ is ui.Button assert func.__discord_ui_model_kwargs__ == {"custom_id": "123"} - with create_callback(ui.StringSelect) as func: + with create_callback(ui.StringSelect[ui.View]) as func: res = ui.string_select(custom_id="123")(func) - assert_type(res, ui.item.DecoratedItem[ui.StringSelect]) + assert_type(res, ui.item.DecoratedItem[ui.StringSelect[ui.View]]) assert func.__discord_ui_model_type__ is ui.StringSelect assert func.__discord_ui_model_kwargs__ == {"custom_id": "123"}