Skip to content

Commit

Permalink
begin preparations for slash command support
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacherr committed Jun 9, 2024
1 parent 0c566c7 commit 7de59be
Show file tree
Hide file tree
Showing 9 changed files with 145 additions and 42 deletions.
11 changes: 11 additions & 0 deletions assyst-core/src/assyst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ use assyst_common::metrics_handler::MetricsHandler;
use assyst_common::pipe::CACHE_PIPE_PATH;
use assyst_database::DatabaseHandler;
use std::sync::{Arc, Mutex};
use twilight_http::client::InteractionClient;
use twilight_http::Client as HttpClient;
use twilight_model::id::marker::ApplicationMarker;
use twilight_model::id::Id;

pub type ThreadSafeAssyst = Arc<Assyst>;

Expand All @@ -29,6 +32,8 @@ pub struct Assyst {
/// HTTP client for Discord. Handles all HTTP requests to Discord, storing stateful information
/// about current ratelimits.
pub http_client: Arc<HttpClient>,
/// Interaction client for handling Discord interations (i.e., slash commands).
pub application_id: Id<ApplicationMarker>,
/// List of the current premim users of Assyst.
pub premium_users: Arc<Mutex<Vec<Patron>>>,
/// Metrics handler for Prometheus, rate trackers etc.
Expand All @@ -52,11 +57,13 @@ impl Assyst {
let database_handler =
Arc::new(DatabaseHandler::new(CONFIG.database.to_url(), CONFIG.database.to_url_safe()).await?);
let premium_users = Arc::new(Mutex::new(vec![]));
let current_application = http_client.current_user_application().await?.model().await?;

Ok(Assyst {
persistent_cache_handler: PersistentCacheHandler::new(CACHE_PIPE_PATH),
database_handler: database_handler.clone(),
http_client: http_client.clone(),
application_id: current_application.id,
premium_users: premium_users.clone(),
metrics_handler: Arc::new(MetricsHandler::new(database_handler.clone())?),
reqwest_client: reqwest::Client::new(),
Expand All @@ -77,4 +84,8 @@ impl Assyst {
pub fn update_premium_user_list(&self, patrons: Vec<Patron>) {
*self.premium_users.lock().unwrap() = patrons;
}

pub fn interaction_client(&self) -> InteractionClient {
self.http_client.interaction(self.application_id)
}
}
25 changes: 25 additions & 0 deletions assyst-core/src/command/group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,31 @@ macro_rules! define_commandgroup {
crate::command::group::find_subcommand(sub, Self::SUBCOMMANDS)
}

fn as_interaction_command(&self) -> twilight_model::application::command::Command {
let meta = self.metadata();

twilight_model::application::command::Command {
application_id: None,
default_member_permissions: None,
description: meta.description.to_owned(),
description_localizations: None,
// TODO: set based on if dms are allowed
// TODO: update to `contexts` once this is required
// (see https://discord.com/developers/docs/interactions/application-commands#create-global-application-command)
dm_permission: Some(false),
guild_id: None,
id: None,
kind: twilight_model::application::command::CommandType::ChatInput,
// todo: handle properly
name: "".to_owned(),
name_localizations: None,
nsfw: Some(meta.age_restricted),
// TODO: set options properly
options: vec![],
version: twilight_model::id::Id::new(1),
}
}

async fn execute(&self, ctxt: CommandCtxt<'_>) -> Result<(), crate::command::ExecutionError> {
#![allow(unreachable_code)]
match crate::command::group::execute_subcommand(ctxt.fork(), Self::SUBCOMMANDS).await {
Expand Down
10 changes: 8 additions & 2 deletions assyst-core/src/command/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ impl Display for Availability {
}
}

#[derive(Debug)]
pub struct CommandMetadata {
pub name: &'static str,
pub aliases: &'static [&'static str],
Expand All @@ -93,7 +94,7 @@ pub struct CommandMetadata {
pub age_restricted: bool,
}

#[derive(Clone, PartialEq, Eq, Hash)]
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub enum Category {
Fun,
Makesweet,
Expand Down Expand Up @@ -155,6 +156,9 @@ pub trait Command {
/// Tries to find a subcommand given a name, provided that `self` is a command group
fn subcommand(&self, s: &str) -> Option<TCommand>;

/// Creates an interaction command for subitting for Discord on startup
fn as_interaction_command(&self) -> twilight_model::application::command::Command;

/// Parses arguments and executes the command.
async fn execute(&self, ctxt: CommandCtxt<'_>) -> Result<(), ExecutionError>;
}
Expand Down Expand Up @@ -208,6 +212,8 @@ impl<'a> CommandCtxt<'a> {
let builder = builder.into();
match self.data.source {
Source::Gateway => gateway_reply::reply(self, builder).await,
// TODO: reply properly
Source::Interaction => todo!(),
}
}

Expand Down Expand Up @@ -323,7 +329,7 @@ pub async fn check_metadata(
.command_ratelimits
.insert(id, metadata.name, Instant::now());

if metadata.send_processing {
if metadata.send_processing && ctxt.data.source == Source::Gateway {
if let Err(e) = ctxt.reply("Processing...").await {
return Err(ExecutionError::Command(e));
}
Expand Down
28 changes: 28 additions & 0 deletions assyst-core/src/command/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use std::collections::HashMap;
use std::sync::OnceLock;

use tracing::info;
use twilight_model::application::command::Command as InteractionCommand;

use crate::assyst::ThreadSafeAssyst;
use crate::command::CommandMetadata;

use super::{misc, services, wsi, TCommand};
Expand Down Expand Up @@ -58,3 +60,29 @@ pub fn get_or_init_commands() -> &'static HashMap<&'static str, TCommand> {
pub fn find_command_by_name(name: &str) -> Option<TCommand> {
get_or_init_commands().get(name).copied()
}

pub async fn register_interaction_commands(assyst: ThreadSafeAssyst) -> anyhow::Result<Vec<InteractionCommand>> {
// todo: dont register aliases
let commands = get_or_init_commands()
.iter()
.map(|x| {
let mut interaction_command = x.1.as_interaction_command();
interaction_command.name = if interaction_command.name.is_empty() {
"".to_owned()
} else {
x.0.to_owned().to_owned()
};
interaction_command
})
.filter(|y| !y.name.is_empty())
.collect::<Vec<_>>();

let response = assyst
.interaction_client()
.set_global_commands(&commands)
.await?
.model()
.await?;

Ok(response)
}
3 changes: 2 additions & 1 deletion assyst-core/src/command/source.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#[derive(Clone)]
#[derive(Clone, Eq, PartialEq)]
pub enum Source {
Gateway,
Interaction,
}
67 changes: 38 additions & 29 deletions assyst-core/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use assyst_common::config::CONFIG;
use assyst_common::pipe::{Pipe, GATEWAY_PIPE_PATH};
use assyst_common::util::tracing_init;
use assyst_common::{err, ok_or_break};
use command::registry::register_interaction_commands;
use gateway_handler::handle_raw_event;
use gateway_handler::incoming_event::IncomingEvent;
use rest::web_media_download::get_web_download_api_urls;
Expand Down Expand Up @@ -136,41 +137,49 @@ async fn main() {
)
.await;

info!("Connecting to assyst-gateway pipe at {}", GATEWAY_PIPE_PATH);
loop {
let mut gateway_pipe = Pipe::poll_connect(GATEWAY_PIPE_PATH, None).await.unwrap();
info!("Connected to assyst-gateway pipe at {}", GATEWAY_PIPE_PATH);
info!("Registering interaction commands");
register_interaction_commands(assyst.clone()).await.unwrap();

spawn(async move {
info!("Connecting to assyst-gateway pipe at {}", GATEWAY_PIPE_PATH);
loop {
// break if read fails because it means broken pipe
// we need to re-poll the pipe to get a new connection
let event = ok_or_break!(gateway_pipe.read_string().await);
trace!("got event: {}", event);

let parsed_event = twilight_gateway::parse(
event,
EventTypeFlags::GUILD_CREATE
| EventTypeFlags::GUILD_DELETE
| EventTypeFlags::MESSAGE_CREATE
| EventTypeFlags::MESSAGE_DELETE
| EventTypeFlags::MESSAGE_UPDATE
| EventTypeFlags::READY,
)
.ok()
.flatten();

if let Some(parsed_event) = parsed_event {
let try_incoming_event: Result<IncomingEvent, _> = parsed_event.try_into();
if let Ok(incoming_event) = try_incoming_event {
assyst.metrics_handler.add_event();
let assyst_c = assyst.clone();
spawn(async move { handle_raw_event(assyst_c.clone(), incoming_event).await });
let mut gateway_pipe = Pipe::poll_connect(GATEWAY_PIPE_PATH, None).await.unwrap();
info!("Connected to assyst-gateway pipe at {}", GATEWAY_PIPE_PATH);

loop {
// break if read fails because it means broken pipe
// we need to re-poll the pipe to get a new connection
let event = ok_or_break!(gateway_pipe.read_string().await);
trace!("got event: {}", event);

let parsed_event = twilight_gateway::parse(
event,
EventTypeFlags::GUILD_CREATE
| EventTypeFlags::GUILD_DELETE
| EventTypeFlags::MESSAGE_CREATE
| EventTypeFlags::MESSAGE_DELETE
| EventTypeFlags::MESSAGE_UPDATE
| EventTypeFlags::READY,
)
.ok()
.flatten();

if let Some(parsed_event) = parsed_event {
let try_incoming_event: Result<IncomingEvent, _> = parsed_event.try_into();
if let Ok(incoming_event) = try_incoming_event {
assyst.metrics_handler.add_event();
let assyst_c = assyst.clone();
spawn(async move { handle_raw_event(assyst_c.clone(), incoming_event).await });
}
}
}

err!("Connection to assyst-gateway lost, attempting reconnection");
}
});

err!("Connection to assyst-gateway lost, attempting reconnection");
}
// todo: connect to slash client
loop {}
}

#[cfg(test)]
Expand Down
24 changes: 24 additions & 0 deletions assyst-proc-macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,30 @@ pub fn command(attrs: TokenStream, func: TokenStream) -> TokenStream {
None
}

fn as_interaction_command(&self) -> twilight_model::application::command::Command {
let meta = self.metadata();

twilight_model::application::command::Command {
application_id: None,
default_member_permissions: None,
description: meta.description.to_owned(),
description_localizations: None,
// TODO: set based on if dms are allowed
// TODO: update to `contexts` once this is required
// (see https://discord.com/developers/docs/interactions/application-commands#create-global-application-command)
dm_permission: Some(false),
guild_id: None,
id: None,
kind: twilight_model::application::command::CommandType::ChatInput,
name: meta.name.to_owned(),
name_localizations: None,
nsfw: Some(meta.age_restricted),
// TODO: set options properly
options: vec![],
version: twilight_model::id::Id::new(1),
}
}

async fn execute(
&self,
mut ctxt:
Expand Down
1 change: 1 addition & 0 deletions assyst-slash-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ anyhow = "1.0.75"
toml = "0.8.8"
futures-util = { version = "0.3", default-features = false }
tracing = "0.1.37"
assyst-common = { path = "../assyst-common" }
18 changes: 8 additions & 10 deletions assyst-slash-client/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
)]
use std::sync::Arc;

use assyst_common::config::CONFIG;
use command::Cmd;
use context::{Context, InnerContext};
use response::ResponseBuilder;
Expand All @@ -32,15 +33,12 @@ pub mod utils;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
let cfg_file = std::fs::read_to_string("Config.toml").expect("missing Config.toml");
let cfg: Cfg = toml::from_str(&cfg_file).expect("error parsing TOML");

let client = Arc::new(Client::new(cfg.token.clone()));
let config = twilight_gateway::Config::new(cfg.token.clone(), Intents::empty());
let client = Arc::new(Client::new(CONFIG.authentication.discord_token.clone()));
let config = twilight_gateway::Config::new(CONFIG.authentication.discord_token.clone(), Intents::empty());
let application_id = client.current_user_application().await?.model().await?.id;
let interactions = client.interaction(application_id);

let ctx = Context::new(client.clone(), application_id, cfg);
let ctx = Context::new(client.clone(), application_id, ());

let cmds = vec![ping(&ctx)];

Expand All @@ -63,10 +61,10 @@ async fn main() -> anyhow::Result<()> {
.collect::<Vec<_>>()
.join(", ");
if let Some(g) = k {
interactions.set_guild_commands(g, c).await?;
interactions.set_guild_commands(g, &[]).await.unwrap();
println!("\x1b[1;32mRegister\x1b[0m Guild [\x1b[33m{g}\x1b[0m] with [{names}]");
} else {
interactions.set_global_commands(c).await?;
interactions.set_global_commands(c).await.unwrap();
println!("\x1b[1;32mRegister\x1b[0m global commands with [{names}]");
}
}
Expand Down Expand Up @@ -96,12 +94,12 @@ pub async fn runner(mut shard: Shard, tx: UnboundedSender<Event>) {
}

#[must_use]
pub fn ping(ctx: &Context<Cfg>) -> Cmd<Cfg> {
pub fn ping(ctx: &Context<()>) -> Cmd<()> {
Cmd::new(Box::new(ctx.clone()))
.name("test")
.chat_input()
.description("waow")
.guild_id(ctx.data.guild_id)
.guild_id(1099115731301449758)
.respond_with(|ctx| {
Box::pin(async move {
ctx.respond(ResponseBuilder::channel_message_with_source().content("ok!"))
Expand Down

0 comments on commit 7de59be

Please sign in to comment.