diff --git a/disnake/interactions/application_command.py b/disnake/interactions/application_command.py index fb9ecdc832..16c0317a1d 100644 --- a/disnake/interactions/application_command.py +++ b/disnake/interactions/application_command.py @@ -22,14 +22,23 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, TypeVar, Union +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, TypeVar, Union, cast from .. import utils -from ..channel import _threaded_channel_factory -from ..enums import ApplicationCommandType, Locale, OptionType, try_enum +from ..channel import ( + CategoryChannel, + ForumChannel, + PartialMessageable, + StageChannel, + TextChannel, + VoiceChannel, + _threaded_guild_channel_factory, +) +from ..enums import ApplicationCommandType, ChannelType, Locale, OptionType, try_enum from ..guild import Guild from ..member import Member from ..message import Attachment, Message +from ..object import Object from ..role import Role from ..user import User from .base import Interaction @@ -55,14 +64,7 @@ MISSING = utils.MISSING if TYPE_CHECKING: - from ..channel import ( - CategoryChannel, - ForumChannel, - PartialMessageable, - StageChannel, - TextChannel, - VoiceChannel, - ) + from ..abc import MessageableChannel from ..ext.commands import InvokableApplicationCommand from ..state import ConnectionState from ..threads import Thread @@ -141,7 +143,7 @@ class ApplicationCommandInteraction(Interaction): def __init__(self, *, data: ApplicationCommandInteractionPayload, state: ConnectionState): super().__init__(data=data, state=state) self.data = ApplicationCommandInteractionData( - data=data["data"], state=state, guild=self.guild + data=data["data"], state=state, guild_id=self.guild_id ) self.application_command: InvokableApplicationCommand = MISSING self.command_failed: bool = False @@ -237,14 +239,14 @@ def __init__( *, data: ApplicationCommandInteractionDataPayload, state: ConnectionState, - guild: Optional[Guild], + guild_id: Optional[int], ): super().__init__(data) self.id: int = int(data["id"]) self.name: str = data["name"] self.type: ApplicationCommandType = try_enum(ApplicationCommandType, data["type"]) self.resolved = ApplicationCommandInteractionDataResolved( - data=data.get("resolved", {}), state=state, guild=guild + data=data.get("resolved", {}), state=state, guild_id=guild_id ) self.target_id: Optional[int] = utils._get_as_snowflake(data, "target_id") self.target: Optional[Union[User, Member, Message]] = self.resolved.get(self.target_id) # type: ignore @@ -393,7 +395,7 @@ def __init__( *, data: ApplicationCommandInteractionDataResolvedPayload, state: ConnectionState, - guild: Optional[Guild], + guild_id: Optional[int], ): data = data or {} super().__init__(data) @@ -412,6 +414,14 @@ def __init__( messages = data.get("messages", {}) attachments = data.get("attachments", {}) + guild: Optional[Guild] = None + # `guild_fallback` is only used in guild contexts, so this `MISSING` value should never be used. + # We need to define it anyway to satisfy the typechecker. + guild_fallback: Union[Guild, Object] = MISSING + if guild_id is not None: + guild = state._get_guild(guild_id) + guild_fallback = guild or Object(id=guild_id) + for str_id, user in users.items(): user_id = int(str_id) member = members.get(str_id) @@ -422,7 +432,7 @@ def __init__( or Member( data=member, user_data=user, - guild=guild, # type: ignore + guild=guild_fallback, # type: ignore state=state, ) ) @@ -430,24 +440,42 @@ def __init__( self.users[user_id] = User(state=state, data=user) for str_id, role in roles.items(): - self.roles[int(str_id)] = Role(guild=guild, state=state, data=role) # type: ignore + self.roles[int(str_id)] = Role( + guild=guild_fallback, # type: ignore + state=state, + data=role, + ) for str_id, channel in channels.items(): - factory, _ = _threaded_channel_factory(channel["type"]) + channel_id = int(str_id) + factory, _ = _threaded_guild_channel_factory(channel["type"]) if factory: channel["position"] = 0 # type: ignore - self.channels[int(str_id)] = ( # type: ignore + self.channels[channel_id] = ( guild - and guild.get_channel(int(str_id)) - or factory(guild=guild, state=state, data=channel) # type: ignore + and guild.get_channel(channel_id) + or factory( + guild=guild_fallback, # type: ignore + state=state, + data=channel, # type: ignore + ) + ) + else: + self.channels[channel_id] = PartialMessageable( + state=state, id=channel_id, type=try_enum(ChannelType, channel["type"]) ) for str_id, message in messages.items(): channel_id = int(message["channel_id"]) - channel = guild.get_channel(channel_id) if guild else None + channel = cast( + "Optional[MessageableChannel]", + (guild and guild.get_channel(channel_id) or state.get_channel(channel_id)), + ) if channel is None: - channel = state.get_channel(channel_id) - self.messages[int(str_id)] = Message(state=state, channel=channel, data=message) # type: ignore + # The channel is not part of `resolved.channels`, + # so we need to fall back to partials here. + channel = PartialMessageable(state=state, id=channel_id, type=None) + self.messages[int(str_id)] = Message(state=state, channel=channel, data=message) for str_id, attachment in attachments.items(): self.attachments[int(str_id)] = Attachment(data=attachment, state=state)