diff --git a/utils/maybe_async/README.md b/utils/maybe_async/README.md index 117c70362..0dd895d60 100644 --- a/utils/maybe_async/README.md +++ b/utils/maybe_async/README.md @@ -70,6 +70,58 @@ async fn world() -> String { } ``` +## maybe_async_trait + +The `maybe_async_trait` macro can be applied to traits, and it will conditionally add the `async` keyword to trait methods annotated with `#[maybe_async]`, depending on the async feature being enabled. It also applies `#[async_trait::async_trait(?Send)]` to the trait or impl block when the async feature is on. + +For example: + +```rust +// Adding `maybe_async_trait` to a trait definition +#[maybe_async_trait] +trait ExampleTrait { + #[maybe_async] + fn hello_world(&self); + + fn get_hello(&self) -> String; +} + +// Adding `maybe_async_trait` to an implementation of the trait +#[maybe_async_trait] +impl ExampleTrait for MyStruct { + #[maybe_async] + fn hello_world(&self) { + // ... + } + + fn get_hello(&self) -> String { + // ... + } +} +``` + +When `async` is set, it gets transformed into: + +```rust +#[async_trait::async_trait(?Send)] +trait ExampleTrait { + async fn hello_world(&self); + + fn get_hello(&self) -> String; +} + +#[async_trait::async_trait(?Send)] +impl ExampleTrait for MyStruct { + async fn hello_world(&self) { + // ... + } + + fn get_hello(&self) -> String { + // ... + } +} +``` + ## License This project is [MIT licensed](../../LICENSE). diff --git a/utils/maybe_async/src/lib.rs b/utils/maybe_async/src/lib.rs index c9eb3a056..7904bdeda 100644 --- a/utils/maybe_async/src/lib.rs +++ b/utils/maybe_async/src/lib.rs @@ -5,7 +5,7 @@ use proc_macro::TokenStream; use quote::quote; -use syn::{parse_macro_input, Expr, ItemFn, TraitItemFn}; +use syn::{parse_macro_input, Expr, ImplItem, ItemFn, ItemImpl, ItemTrait, TraitItem, TraitItemFn}; /// Parses a function (regular or trait) and conditionally adds the `async` keyword depending on /// the `async` feature flag being enabled. @@ -67,6 +67,129 @@ pub fn maybe_async(_attr: TokenStream, input: TokenStream) -> TokenStream { } } +/// Parses a trait or an `impl` block and conditionally adds the `async` keyword to methods that +/// are annotated with `#[maybe_async]`, depending on the `async` feature flag being enabled. +/// Additionally, if applied to a trait definition or impl block, it will add +/// `#[async_trait::async_trait(?Send)]` to the it. +/// +/// For example, given the following trait definition: +/// ```ignore +/// #[maybe_async_trait] +/// trait ExampleTrait { +/// #[maybe_async] +/// fn hello_world(&self); +/// +/// fn get_hello(&self) -> String; +/// } +/// ``` +/// +/// And the following implementation: +/// ```ignore +/// #[maybe_async_trait] +/// impl ExampleTrait for MyStruct { +/// #[maybe_async] +/// fn hello_world(&self) { +/// // ... +/// } +/// +/// fn get_hello(&self) -> String { +/// // ... +/// } +/// } +/// ``` +/// +/// When the `async` feature is enabled, this will be transformed into: +/// ```ignore +/// #[async_trait::async_trait(?Send)] +/// trait ExampleTrait { +/// async fn hello_world(&self); +/// +/// fn get_hello(&self) -> String; +/// } +/// +/// #[async_trait::async_trait(?Send)] +/// impl ExampleTrait for MyStruct { +/// async fn hello_world(&self) { +/// // ... +/// } +/// +/// fn get_hello(&self) -> String { +/// // ... +/// } +/// } +/// ``` +/// +/// When the `async` feature is disabled, the code remains unchanged, and neither the `async` +/// keyword nor the `#[async_trait::async_trait(?Send)]` attribute is applied. +#[proc_macro_attribute] +pub fn maybe_async_trait(_attr: TokenStream, input: TokenStream) -> TokenStream { + // Try parsing the input as a trait definition + if let Ok(trait_item) = syn::parse::(input.clone()) { + let output = if cfg!(feature = "async") { + let mut async_trait = trait_item; + + for item in &mut async_trait.items { + if let TraitItem::Fn(method) = item { + // Remove the #[maybe_async] and make method async + method.attrs.retain(|attr| { + if attr.path().is_ident("maybe_async") { + method.sig.asyncness = Some(syn::token::Async::default()); + false + } else { + true + } + }); + } + } + + quote! { + #[async_trait::async_trait(?Send)] + #async_trait + } + } else { + quote! { + #trait_item + } + }; + + return output.into(); + } + // Check if it is an Impl block + else if let Ok(mut impl_item) = syn::parse::(input.clone()) { + let output = if cfg!(feature = "async") { + for item in &mut impl_item.items { + if let ImplItem::Fn(method) = item { + // Remove #[maybe_async] and make method async + method.attrs.retain(|attr| { + if attr.path().is_ident("maybe_async") { + method.sig.asyncness = Some(syn::token::Async::default()); + false // Remove the attribute + } else { + true // Keep other attributes + } + }); + } + } + quote! { + #[async_trait::async_trait(?Send)] + #impl_item + } + } else { + quote! { + #[cfg(not(feature = "async"))] + #impl_item + } + }; + + return output.into(); + } + + // If input is neither a trait nor an impl block, emit a compile-time error + return quote! { + compile_error!("`maybe_async_trait` can only be applied to trait definitions and trait impl blocks"); + }.into(); +} + /// Parses an expression and conditionally adds the `.await` keyword at the end of it depending on /// the `async` feature flag being enabled. ///