From 65e2d646578c8ca08cc33f09d3a05177ee48c82b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Kijewski?= Date: Sat, 1 Feb 2025 21:52:27 +0100 Subject: [PATCH] derive: for `enum`s, let `Self` refer to the original type --- rinja_derive/src/generator/expr.rs | 11 ++++++++- rinja_derive/src/input.rs | 3 +++ rinja_derive/src/integration.rs | 10 +++++++- rinja_derive/src/lib.rs | 7 +++--- testing/tests/enum.rs | 39 ++++++++++++++++++++++++++++++ 5 files changed, 65 insertions(+), 5 deletions(-) diff --git a/rinja_derive/src/generator/expr.rs b/rinja_derive/src/generator/expr.rs index 9611bbfd..1a796bac 100644 --- a/rinja_derive/src/generator/expr.rs +++ b/rinja_derive/src/generator/expr.rs @@ -5,6 +5,7 @@ use parser::{ Attr, CharLit, CharPrefix, Expr, Filter, IntKind, Num, Span, StrLit, StrPrefix, Target, TyGenerics, WithSpan, }; +use quote::quote; use super::{ DisplayWrap, FILTER_SOURCE, Generator, LocalMeta, TargetIsize, TargetUsize, Writable, @@ -1057,9 +1058,17 @@ impl<'a> Generator<'a, '_> { } fn visit_path(&mut self, buf: &mut Buffer, path: &[&str]) -> DisplayWrap { - for (i, part) in path.iter().enumerate() { + for (i, part) in path.iter().copied().enumerate() { if i > 0 { buf.write("::"); + } else if let Some(enum_ast) = self.input.enum_ast { + if part == "Self" { + let this = &enum_ast.ident; + let (_, generics, _) = enum_ast.generics.split_for_impl(); + let generics = generics.as_turbofish(); + buf.write(quote!(#this #generics)); + continue; + } } buf.write(part); } diff --git a/rinja_derive/src/input.rs b/rinja_derive/src/input.rs index 0243133a..253840bc 100644 --- a/rinja_derive/src/input.rs +++ b/rinja_derive/src/input.rs @@ -19,6 +19,7 @@ use crate::{CompileError, FileInfo, MsgValidEscapers, OnceMap}; pub(crate) struct TemplateInput<'a> { pub(crate) ast: &'a syn::DeriveInput, + pub(crate) enum_ast: Option<&'a syn::DeriveInput>, pub(crate) config: &'a Config, pub(crate) syntax: &'a SyntaxAndCache<'a>, pub(crate) source: &'a Source, @@ -36,6 +37,7 @@ impl TemplateInput<'_> { /// `template()` attribute list fields. pub(crate) fn new<'n>( ast: &'n syn::DeriveInput, + enum_ast: Option<&'n syn::DeriveInput>, config: &'n Config, args: &'n TemplateArgs, ) -> Result, CompileError> { @@ -126,6 +128,7 @@ impl TemplateInput<'_> { Ok(TemplateInput { ast, + enum_ast, config, syntax, source, diff --git a/rinja_derive/src/integration.rs b/rinja_derive/src/integration.rs index fa38b2c0..182872b9 100644 --- a/rinja_derive/src/integration.rs +++ b/rinja_derive/src/integration.rs @@ -210,6 +210,12 @@ impl BufferFmt for Arguments<'_> { } } +impl BufferFmt for TokenStream { + fn append_to(&self, buf: &mut String) { + write!(buf, "{self}").unwrap(); + } +} + /// Similar to `write!(dest, "{src:?}")`, but only escapes the strictly needed characters, /// and without the surrounding `"…"` quotation marks. fn string_escape(dest: &mut String, src: &str) { @@ -271,7 +277,7 @@ pub(crate) fn build_template_enum( }; let var_ast = type_for_enum_variant(enum_ast, &generics, var); - buf.write(quote!(#var_ast).to_string()); + buf.write(quote!(#var_ast)); // not inherited: template, meta_docs, block, print if let Some(enum_args) = &mut enum_args { @@ -285,6 +291,7 @@ pub(crate) fn build_template_enum( let size_hint = biggest_size_hint.max(build_template_item( buf, &var_ast, + Some(enum_ast), &TemplateArgs::from_partial(&var_ast, Some(var_args))?, TmplKind::Variant, )?); @@ -302,6 +309,7 @@ pub(crate) fn build_template_enum( let size_hint = build_template_item( buf, enum_ast, + None, &TemplateArgs::from_partial(enum_ast, enum_args)?, TmplKind::Variant, )?; diff --git a/rinja_derive/src/lib.rs b/rinja_derive/src/lib.rs index fe1b2ece..0a0925fd 100644 --- a/rinja_derive/src/lib.rs +++ b/rinja_derive/src/lib.rs @@ -155,7 +155,7 @@ fn compile_error(msgs: impl Iterator, span: Span) -> TokenStream fn build_skeleton(buf: &mut Buffer, ast: &syn::DeriveInput) -> Result { let template_args = TemplateArgs::fallback(); let config = Config::new("", None, None, None)?; - let input = TemplateInput::new(ast, config, &template_args)?; + let input = TemplateInput::new(ast, None, config, &template_args)?; let mut contexts = HashMap::default(); let parsed = parser::Parsed::default(); contexts.insert(&input.path, Context::empty(&parsed)); @@ -177,7 +177,7 @@ pub(crate) fn build_template( let mut result = match AnyTemplateArgs::new(ast)? { AnyTemplateArgs::Struct(item) => { err_span = item.source.1.or(item.template_span); - build_template_item(buf, ast, &item, TmplKind::Struct) + build_template_item(buf, ast, None, &item, TmplKind::Struct) } AnyTemplateArgs::Enum { enum_args, @@ -203,6 +203,7 @@ pub(crate) fn build_template( fn build_template_item( buf: &mut Buffer, ast: &syn::DeriveInput, + enum_ast: Option<&syn::DeriveInput>, template_args: &TemplateArgs, tmpl_kind: TmplKind, ) -> Result { @@ -214,7 +215,7 @@ fn build_template_item( template_args.whitespace, template_args.config_span, )?; - let input = TemplateInput::new(ast, config, template_args)?; + let input = TemplateInput::new(ast, enum_ast, config, template_args)?; let mut templates = HashMap::default(); input.find_used_templates(&mut templates)?; diff --git a/testing/tests/enum.rs b/testing/tests/enum.rs index 0723af0e..af913b92 100644 --- a/testing/tests/enum.rs +++ b/testing/tests/enum.rs @@ -196,6 +196,45 @@ fn test_enum_blocks() { ); } +#[test] +fn associated_contants() { + #[derive(Template, Debug)] + #[template( + ext = "txt", + source = "\ + {% block a -%} {{ Self::CONST_A }} {{ self.0 }} {%- endblock %} + {% block b -%} {{ Self::CONST_B }} {{ self.0 }} {%- endblock %} + {% block c -%} {{ Self::func_c(self.0) }} {{ self.0 }} {%- endblock %} + " + )] + enum BlockEnum<'a, T: Display> { + #[template(block = "a")] + A(&'a str), + #[template(block = "b")] + B(T), + #[template(block = "c")] + C(&'a T), + } + + impl<'a, T: Display> BlockEnum<'a, T> { + const CONST_A: &'static str = ""; + const CONST_B: &'static str = ""; + + fn func_c(_: &'a T) -> &'static str { + "" + } + } + + let tmpl: BlockEnum<'_, X> = BlockEnum::A("hello"); + assert_eq!(tmpl.render().unwrap(), " hello"); + + let tmpl: BlockEnum<'_, X> = BlockEnum::B(X); + assert_eq!(tmpl.render().unwrap(), " X"); + + let tmpl: BlockEnum<'_, X> = BlockEnum::C(&X); + assert_eq!(tmpl.render().unwrap(), " X"); +} + #[derive(Debug)] struct X;