diff --git a/assyst-core/src/assyst.rs b/assyst-core/src/assyst.rs index 1e1cf9b..d632fda 100644 --- a/assyst-core/src/assyst.rs +++ b/assyst-core/src/assyst.rs @@ -10,7 +10,10 @@ use assyst_common::metrics_handler::MetricsHandler; use assyst_common::pipe::CACHE_PIPE_PATH; use assyst_database::DatabaseHandler; use std::sync::{Arc, Mutex}; +use twilight_http::client::InteractionClient; use twilight_http::Client as HttpClient; +use twilight_model::id::marker::ApplicationMarker; +use twilight_model::id::Id; pub type ThreadSafeAssyst = Arc; @@ -29,6 +32,8 @@ pub struct Assyst { /// HTTP client for Discord. Handles all HTTP requests to Discord, storing stateful information /// about current ratelimits. pub http_client: Arc, + /// Interaction client for handling Discord interations (i.e., slash commands). + pub application_id: Id, /// List of the current premim users of Assyst. pub premium_users: Arc>>, /// Metrics handler for Prometheus, rate trackers etc. @@ -52,11 +57,13 @@ impl Assyst { let database_handler = Arc::new(DatabaseHandler::new(CONFIG.database.to_url(), CONFIG.database.to_url_safe()).await?); let premium_users = Arc::new(Mutex::new(vec![])); + let current_application = http_client.current_user_application().await?.model().await?; Ok(Assyst { persistent_cache_handler: PersistentCacheHandler::new(CACHE_PIPE_PATH), database_handler: database_handler.clone(), http_client: http_client.clone(), + application_id: current_application.id, premium_users: premium_users.clone(), metrics_handler: Arc::new(MetricsHandler::new(database_handler.clone())?), reqwest_client: reqwest::Client::new(), @@ -77,4 +84,8 @@ impl Assyst { pub fn update_premium_user_list(&self, patrons: Vec) { *self.premium_users.lock().unwrap() = patrons; } + + pub fn interaction_client(&self) -> InteractionClient { + self.http_client.interaction(self.application_id) + } } diff --git a/assyst-core/src/command/group.rs b/assyst-core/src/command/group.rs index 70a60af..77eebcc 100644 --- a/assyst-core/src/command/group.rs +++ b/assyst-core/src/command/group.rs @@ -79,6 +79,31 @@ macro_rules! define_commandgroup { crate::command::group::find_subcommand(sub, Self::SUBCOMMANDS) } + fn as_interaction_command(&self) -> twilight_model::application::command::Command { + let meta = self.metadata(); + + twilight_model::application::command::Command { + application_id: None, + default_member_permissions: None, + description: meta.description.to_owned(), + description_localizations: None, + // TODO: set based on if dms are allowed + // TODO: update to `contexts` once this is required + // (see https://discord.com/developers/docs/interactions/application-commands#create-global-application-command) + dm_permission: Some(false), + guild_id: None, + id: None, + kind: twilight_model::application::command::CommandType::ChatInput, + // todo: handle properly + name: "".to_owned(), + name_localizations: None, + nsfw: Some(meta.age_restricted), + // TODO: set options properly + options: vec![], + version: twilight_model::id::Id::new(1), + } + } + async fn execute(&self, ctxt: CommandCtxt<'_>) -> Result<(), crate::command::ExecutionError> { #![allow(unreachable_code)] match crate::command::group::execute_subcommand(ctxt.fork(), Self::SUBCOMMANDS).await { diff --git a/assyst-core/src/command/mod.rs b/assyst-core/src/command/mod.rs index aa8b5d6..2f3d0cc 100644 --- a/assyst-core/src/command/mod.rs +++ b/assyst-core/src/command/mod.rs @@ -79,6 +79,7 @@ impl Display for Availability { } } +#[derive(Debug)] pub struct CommandMetadata { pub name: &'static str, pub aliases: &'static [&'static str], @@ -93,7 +94,7 @@ pub struct CommandMetadata { pub age_restricted: bool, } -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, Debug)] pub enum Category { Fun, Makesweet, @@ -155,6 +156,9 @@ pub trait Command { /// Tries to find a subcommand given a name, provided that `self` is a command group fn subcommand(&self, s: &str) -> Option; + /// Creates an interaction command for subitting for Discord on startup + fn as_interaction_command(&self) -> twilight_model::application::command::Command; + /// Parses arguments and executes the command. async fn execute(&self, ctxt: CommandCtxt<'_>) -> Result<(), ExecutionError>; } @@ -208,6 +212,8 @@ impl<'a> CommandCtxt<'a> { let builder = builder.into(); match self.data.source { Source::Gateway => gateway_reply::reply(self, builder).await, + // TODO: reply properly + Source::Interaction => todo!(), } } @@ -323,7 +329,7 @@ pub async fn check_metadata( .command_ratelimits .insert(id, metadata.name, Instant::now()); - if metadata.send_processing { + if metadata.send_processing && ctxt.data.source == Source::Gateway { if let Err(e) = ctxt.reply("Processing...").await { return Err(ExecutionError::Command(e)); } diff --git a/assyst-core/src/command/registry.rs b/assyst-core/src/command/registry.rs index 75a1d6f..36c890c 100644 --- a/assyst-core/src/command/registry.rs +++ b/assyst-core/src/command/registry.rs @@ -2,7 +2,9 @@ use std::collections::HashMap; use std::sync::OnceLock; use tracing::info; +use twilight_model::application::command::Command as InteractionCommand; +use crate::assyst::ThreadSafeAssyst; use crate::command::CommandMetadata; use super::{misc, services, wsi, TCommand}; @@ -58,3 +60,29 @@ pub fn get_or_init_commands() -> &'static HashMap<&'static str, TCommand> { pub fn find_command_by_name(name: &str) -> Option { get_or_init_commands().get(name).copied() } + +pub async fn register_interaction_commands(assyst: ThreadSafeAssyst) -> anyhow::Result> { + // todo: dont register aliases + let commands = get_or_init_commands() + .iter() + .map(|x| { + let mut interaction_command = x.1.as_interaction_command(); + interaction_command.name = if interaction_command.name.is_empty() { + "".to_owned() + } else { + x.0.to_owned().to_owned() + }; + interaction_command + }) + .filter(|y| !y.name.is_empty()) + .collect::>(); + + let response = assyst + .interaction_client() + .set_global_commands(&commands) + .await? + .model() + .await?; + + Ok(response) +} diff --git a/assyst-core/src/command/source.rs b/assyst-core/src/command/source.rs index 0f3f865..5be23f0 100644 --- a/assyst-core/src/command/source.rs +++ b/assyst-core/src/command/source.rs @@ -1,4 +1,5 @@ -#[derive(Clone)] +#[derive(Clone, Eq, PartialEq)] pub enum Source { Gateway, + Interaction, } diff --git a/assyst-core/src/main.rs b/assyst-core/src/main.rs index 7d79d9e..1ac0414 100644 --- a/assyst-core/src/main.rs +++ b/assyst-core/src/main.rs @@ -18,6 +18,7 @@ use assyst_common::config::CONFIG; use assyst_common::pipe::{Pipe, GATEWAY_PIPE_PATH}; use assyst_common::util::tracing_init; use assyst_common::{err, ok_or_break}; +use command::registry::register_interaction_commands; use gateway_handler::handle_raw_event; use gateway_handler::incoming_event::IncomingEvent; use rest::web_media_download::get_web_download_api_urls; @@ -136,41 +137,49 @@ async fn main() { ) .await; - info!("Connecting to assyst-gateway pipe at {}", GATEWAY_PIPE_PATH); - loop { - let mut gateway_pipe = Pipe::poll_connect(GATEWAY_PIPE_PATH, None).await.unwrap(); - info!("Connected to assyst-gateway pipe at {}", GATEWAY_PIPE_PATH); + info!("Registering interaction commands"); + register_interaction_commands(assyst.clone()).await.unwrap(); + spawn(async move { + info!("Connecting to assyst-gateway pipe at {}", GATEWAY_PIPE_PATH); loop { - // break if read fails because it means broken pipe - // we need to re-poll the pipe to get a new connection - let event = ok_or_break!(gateway_pipe.read_string().await); - trace!("got event: {}", event); - - let parsed_event = twilight_gateway::parse( - event, - EventTypeFlags::GUILD_CREATE - | EventTypeFlags::GUILD_DELETE - | EventTypeFlags::MESSAGE_CREATE - | EventTypeFlags::MESSAGE_DELETE - | EventTypeFlags::MESSAGE_UPDATE - | EventTypeFlags::READY, - ) - .ok() - .flatten(); - - if let Some(parsed_event) = parsed_event { - let try_incoming_event: Result = parsed_event.try_into(); - if let Ok(incoming_event) = try_incoming_event { - assyst.metrics_handler.add_event(); - let assyst_c = assyst.clone(); - spawn(async move { handle_raw_event(assyst_c.clone(), incoming_event).await }); + let mut gateway_pipe = Pipe::poll_connect(GATEWAY_PIPE_PATH, None).await.unwrap(); + info!("Connected to assyst-gateway pipe at {}", GATEWAY_PIPE_PATH); + + loop { + // break if read fails because it means broken pipe + // we need to re-poll the pipe to get a new connection + let event = ok_or_break!(gateway_pipe.read_string().await); + trace!("got event: {}", event); + + let parsed_event = twilight_gateway::parse( + event, + EventTypeFlags::GUILD_CREATE + | EventTypeFlags::GUILD_DELETE + | EventTypeFlags::MESSAGE_CREATE + | EventTypeFlags::MESSAGE_DELETE + | EventTypeFlags::MESSAGE_UPDATE + | EventTypeFlags::READY, + ) + .ok() + .flatten(); + + if let Some(parsed_event) = parsed_event { + let try_incoming_event: Result = parsed_event.try_into(); + if let Ok(incoming_event) = try_incoming_event { + assyst.metrics_handler.add_event(); + let assyst_c = assyst.clone(); + spawn(async move { handle_raw_event(assyst_c.clone(), incoming_event).await }); + } } } + + err!("Connection to assyst-gateway lost, attempting reconnection"); } + }); - err!("Connection to assyst-gateway lost, attempting reconnection"); - } + // todo: connect to slash client + loop {} } #[cfg(test)] diff --git a/assyst-proc-macro/src/lib.rs b/assyst-proc-macro/src/lib.rs index 451c0f7..c5e602a 100644 --- a/assyst-proc-macro/src/lib.rs +++ b/assyst-proc-macro/src/lib.rs @@ -134,6 +134,30 @@ pub fn command(attrs: TokenStream, func: TokenStream) -> TokenStream { None } + fn as_interaction_command(&self) -> twilight_model::application::command::Command { + let meta = self.metadata(); + + twilight_model::application::command::Command { + application_id: None, + default_member_permissions: None, + description: meta.description.to_owned(), + description_localizations: None, + // TODO: set based on if dms are allowed + // TODO: update to `contexts` once this is required + // (see https://discord.com/developers/docs/interactions/application-commands#create-global-application-command) + dm_permission: Some(false), + guild_id: None, + id: None, + kind: twilight_model::application::command::CommandType::ChatInput, + name: meta.name.to_owned(), + name_localizations: None, + nsfw: Some(meta.age_restricted), + // TODO: set options properly + options: vec![], + version: twilight_model::id::Id::new(1), + } + } + async fn execute( &self, mut ctxt: diff --git a/assyst-slash-client/Cargo.toml b/assyst-slash-client/Cargo.toml index 3c4bcb1..3b312f6 100644 --- a/assyst-slash-client/Cargo.toml +++ b/assyst-slash-client/Cargo.toml @@ -13,3 +13,4 @@ anyhow = "1.0.75" toml = "0.8.8" futures-util = { version = "0.3", default-features = false } tracing = "0.1.37" +assyst-common = { path = "../assyst-common" } diff --git a/assyst-slash-client/src/main.rs b/assyst-slash-client/src/main.rs index 59f28c6..961338a 100644 --- a/assyst-slash-client/src/main.rs +++ b/assyst-slash-client/src/main.rs @@ -7,6 +7,7 @@ )] use std::sync::Arc; +use assyst_common::config::CONFIG; use command::Cmd; use context::{Context, InnerContext}; use response::ResponseBuilder; @@ -32,15 +33,12 @@ pub mod utils; #[tokio::main] async fn main() -> anyhow::Result<()> { - let cfg_file = std::fs::read_to_string("Config.toml").expect("missing Config.toml"); - let cfg: Cfg = toml::from_str(&cfg_file).expect("error parsing TOML"); - - let client = Arc::new(Client::new(cfg.token.clone())); - let config = twilight_gateway::Config::new(cfg.token.clone(), Intents::empty()); + let client = Arc::new(Client::new(CONFIG.authentication.discord_token.clone())); + let config = twilight_gateway::Config::new(CONFIG.authentication.discord_token.clone(), Intents::empty()); let application_id = client.current_user_application().await?.model().await?.id; let interactions = client.interaction(application_id); - let ctx = Context::new(client.clone(), application_id, cfg); + let ctx = Context::new(client.clone(), application_id, ()); let cmds = vec![ping(&ctx)]; @@ -63,10 +61,10 @@ async fn main() -> anyhow::Result<()> { .collect::>() .join(", "); if let Some(g) = k { - interactions.set_guild_commands(g, c).await?; + interactions.set_guild_commands(g, &[]).await.unwrap(); println!("\x1b[1;32mRegister\x1b[0m Guild [\x1b[33m{g}\x1b[0m] with [{names}]"); } else { - interactions.set_global_commands(c).await?; + interactions.set_global_commands(c).await.unwrap(); println!("\x1b[1;32mRegister\x1b[0m global commands with [{names}]"); } } @@ -96,12 +94,12 @@ pub async fn runner(mut shard: Shard, tx: UnboundedSender) { } #[must_use] -pub fn ping(ctx: &Context) -> Cmd { +pub fn ping(ctx: &Context<()>) -> Cmd<()> { Cmd::new(Box::new(ctx.clone())) .name("test") .chat_input() .description("waow") - .guild_id(ctx.data.guild_id) + .guild_id(1099115731301449758) .respond_with(|ctx| { Box::pin(async move { ctx.respond(ResponseBuilder::channel_message_with_source().content("ok!"))