Skip to content

Commit

Permalink
fix(interactions): improve fallbacks in resolved data (#646)
Browse files Browse the repository at this point in the history
Co-authored-by: arl <genericusername414+git@gmail.com>
  • Loading branch information
shiftinv and onerandomusername committed Jul 21, 2022
1 parent 3099051 commit 2a175d1
Showing 1 changed file with 52 additions and 24 deletions.
76 changes: 52 additions & 24 deletions disnake/interactions/application_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,23 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, TypeVar, Union
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, TypeVar, Union, cast

from .. import utils
from ..channel import _threaded_channel_factory
from ..enums import ApplicationCommandType, Locale, OptionType, try_enum
from ..channel import (
CategoryChannel,
ForumChannel,
PartialMessageable,
StageChannel,
TextChannel,
VoiceChannel,
_threaded_guild_channel_factory,
)
from ..enums import ApplicationCommandType, ChannelType, Locale, OptionType, try_enum
from ..guild import Guild
from ..member import Member
from ..message import Attachment, Message
from ..object import Object
from ..role import Role
from ..user import User
from .base import Interaction
Expand All @@ -55,14 +64,7 @@
MISSING = utils.MISSING

if TYPE_CHECKING:
from ..channel import (
CategoryChannel,
ForumChannel,
PartialMessageable,
StageChannel,
TextChannel,
VoiceChannel,
)
from ..abc import MessageableChannel
from ..ext.commands import InvokableApplicationCommand
from ..state import ConnectionState
from ..threads import Thread
Expand Down Expand Up @@ -141,7 +143,7 @@ class ApplicationCommandInteraction(Interaction):
def __init__(self, *, data: ApplicationCommandInteractionPayload, state: ConnectionState):
super().__init__(data=data, state=state)
self.data = ApplicationCommandInteractionData(
data=data["data"], state=state, guild=self.guild
data=data["data"], state=state, guild_id=self.guild_id
)
self.application_command: InvokableApplicationCommand = MISSING
self.command_failed: bool = False
Expand Down Expand Up @@ -237,14 +239,14 @@ def __init__(
*,
data: ApplicationCommandInteractionDataPayload,
state: ConnectionState,
guild: Optional[Guild],
guild_id: Optional[int],
):
super().__init__(data)
self.id: int = int(data["id"])
self.name: str = data["name"]
self.type: ApplicationCommandType = try_enum(ApplicationCommandType, data["type"])
self.resolved = ApplicationCommandInteractionDataResolved(
data=data.get("resolved", {}), state=state, guild=guild
data=data.get("resolved", {}), state=state, guild_id=guild_id
)
self.target_id: Optional[int] = utils._get_as_snowflake(data, "target_id")
self.target: Optional[Union[User, Member, Message]] = self.resolved.get(self.target_id) # type: ignore
Expand Down Expand Up @@ -393,7 +395,7 @@ def __init__(
*,
data: ApplicationCommandInteractionDataResolvedPayload,
state: ConnectionState,
guild: Optional[Guild],
guild_id: Optional[int],
):
data = data or {}
super().__init__(data)
Expand All @@ -412,6 +414,14 @@ def __init__(
messages = data.get("messages", {})
attachments = data.get("attachments", {})

guild: Optional[Guild] = None
# `guild_fallback` is only used in guild contexts, so this `MISSING` value should never be used.
# We need to define it anyway to satisfy the typechecker.
guild_fallback: Union[Guild, Object] = MISSING
if guild_id is not None:
guild = state._get_guild(guild_id)
guild_fallback = guild or Object(id=guild_id)

for str_id, user in users.items():
user_id = int(str_id)
member = members.get(str_id)
Expand All @@ -422,32 +432,50 @@ def __init__(
or Member(
data=member,
user_data=user,
guild=guild, # type: ignore
guild=guild_fallback, # type: ignore
state=state,
)
)
else:
self.users[user_id] = User(state=state, data=user)

for str_id, role in roles.items():
self.roles[int(str_id)] = Role(guild=guild, state=state, data=role) # type: ignore
self.roles[int(str_id)] = Role(
guild=guild_fallback, # type: ignore
state=state,
data=role,
)

for str_id, channel in channels.items():
factory, _ = _threaded_channel_factory(channel["type"])
channel_id = int(str_id)
factory, _ = _threaded_guild_channel_factory(channel["type"])
if factory:
channel["position"] = 0 # type: ignore
self.channels[int(str_id)] = ( # type: ignore
self.channels[channel_id] = (
guild
and guild.get_channel(int(str_id))
or factory(guild=guild, state=state, data=channel) # type: ignore
and guild.get_channel(channel_id)
or factory(
guild=guild_fallback, # type: ignore
state=state,
data=channel, # type: ignore
)
)
else:
self.channels[channel_id] = PartialMessageable(
state=state, id=channel_id, type=try_enum(ChannelType, channel["type"])
)

for str_id, message in messages.items():
channel_id = int(message["channel_id"])
channel = guild.get_channel(channel_id) if guild else None
channel = cast(
"Optional[MessageableChannel]",
(guild and guild.get_channel(channel_id) or state.get_channel(channel_id)),
)
if channel is None:
channel = state.get_channel(channel_id)
self.messages[int(str_id)] = Message(state=state, channel=channel, data=message) # type: ignore
# The channel is not part of `resolved.channels`,
# so we need to fall back to partials here.
channel = PartialMessageable(state=state, id=channel_id, type=None)
self.messages[int(str_id)] = Message(state=state, channel=channel, data=message)

for str_id, attachment in attachments.items():
self.attachments[int(str_id)] = Attachment(data=attachment, state=state)
Expand Down

0 comments on commit 2a175d1

Please sign in to comment.