diff --git a/macros/src/types/enum.rs b/macros/src/types/enum.rs index 4c73b756a..0cfcfd9b2 100644 --- a/macros/src/types/enum.rs +++ b/macros/src/types/enum.rs @@ -50,7 +50,11 @@ pub(crate) fn r#enum_def(s: &ItemEnum) -> syn::Result { Ok(DerivedTS { inline: quote!([#(#formatted_variants),*].join(" | ")), decl: quote!(format!("type {}{} = {};", #name, #generic_args, Self::inline())), - inline_flattened: None, + inline_flattened: Some( + quote!( + format!("({})", [#(#formatted_variants),*].join(" | ")) + ) + ), dependencies, name, export: enum_attr.export, @@ -130,7 +134,15 @@ fn format_variant( "{{ \"{}\": \"{}\", {} }}", #tag, #name, - #inline_flattened + // At this point inline_flattened looks like + // { /* ...data */ } + // + // To be flattened, an internally tagged enum must not be + // surrounded by braces, otherwise each variant will look like + // { "tag": "name", { /* ...data */ } } + // when we want it to look like + // { "tag": "name", /* ...data */ } + #inline_flattened.trim_matches(&['{', '}', ' ']) ) }, None => match &variant.fields { diff --git a/macros/src/types/named.rs b/macros/src/types/named.rs index 1a3249b42..3ec97d014 100644 --- a/macros/src/types/named.rs +++ b/macros/src/types/named.rs @@ -17,6 +17,7 @@ pub(crate) fn named( generics: &Generics, ) -> Result { let mut formatted_fields = Vec::new(); + let mut flattened_fields = Vec::new(); let mut dependencies = Dependencies::default(); if let Some(tag) = &attr.tag { let formatted = format!("{}: \"{}\",", tag, name); @@ -28,6 +29,7 @@ pub(crate) fn named( for field in &fields.named { format_field( &mut formatted_fields, + &mut flattened_fields, &mut dependencies, field, &attr.rename_all, @@ -36,17 +38,21 @@ pub(crate) fn named( } let fields = quote!(<[String]>::join(&[#(#formatted_fields),*], " ")); + let flattened = quote!(<[String]>::join(&[#(#flattened_fields),*], " & ")); let generic_args = format_generics(&mut dependencies, generics); + let inline = match (formatted_fields.len(), flattened_fields.len()) { + (0, 0) => quote!("{ }".to_owned()), + (_, 0) => quote!(format!("{{ {} }}", #fields)), + (0, 1) => quote!(#flattened.trim_matches(|c| c == '(' || c == ')').to_owned()), + (0, _) => quote!(#flattened), + (_, _) => quote!(format!("{{ {} }} & {}", #fields, #flattened)), + }; + Ok(DerivedTS { - inline: quote! { - format!( - "{{ {} }}", - #fields, - ) - }, + inline: quote!(#inline.replace(" } & { ", " ")), decl: quote!(format!("type {}{} = {}", #name, #generic_args, Self::inline())), - inline_flattened: Some(fields), + inline_flattened: Some(quote!(format!("{{ {} }}", #fields))), name: name.to_owned(), dependencies, export: attr.export, @@ -55,8 +61,18 @@ pub(crate) fn named( } // build an expresion which expands to a string, representing a single field of a struct. +// +// formatted_fields will contain all the fields that do not contain the flatten +// attribute, in the format +// key: type, +// +// flattened_fields will contain all the fields that contain the flatten attribute +// in their respective formats, which for a named struct is the same as formatted_fields, +// but for enums is +// ({ /* variant data */ } | { /* variant data */ }) fn format_field( formatted_fields: &mut Vec, + flattened_fields: &mut Vec, dependencies: &mut Dependencies, field: &Field, rename_all: &Option, @@ -88,7 +104,7 @@ fn format_field( _ => {} } - formatted_fields.push(quote!(<#ty as ts_rs::TS>::inline_flattened())); + flattened_fields.push(quote!(<#ty as ts_rs::TS>::inline_flattened())); dependencies.append_from(ty); return Ok(()); } diff --git a/ts-rs/tests/enum_flattening.rs b/ts-rs/tests/enum_flattening.rs new file mode 100644 index 000000000..c26ef059a --- /dev/null +++ b/ts-rs/tests/enum_flattening.rs @@ -0,0 +1,108 @@ +#[cfg(feature = "serde-compat")] +use serde::Serialize; +use ts_rs::TS; + +#[test] +fn externally_tagged() { + #[allow(dead_code)] + #[cfg_attr(feature = "serde-compat", derive(Serialize, TS))] + #[cfg_attr(not(feature = "serde-compat"), derive(TS))] + struct Foo { + qux: i32, + #[cfg_attr(feature = "serde-compat", serde(flatten))] + #[cfg_attr(not(feature = "serde-compat"), ts(flatten))] + baz: Bar, + biz: Option, + } + + #[cfg_attr(feature = "serde-compat", derive(Serialize, TS))] + #[cfg_attr(not(feature = "serde-compat"), derive(TS))] + #[allow(dead_code)] + enum Bar { + Baz { a: i32, a2: String }, + Biz { b: bool }, + Buz { c: String, d: Option }, + } + + assert_eq!( + Foo::inline(), + r#"{ qux: number, biz: string | null, } & ({ "Baz": { a: number, a2: string, } } | { "Biz": { b: boolean, } } | { "Buz": { c: string, d: number | null, } })"# + ) +} + +#[test] +#[cfg(feature = "serde-compat")] +fn adjacently_tagged() { + #[derive(Serialize, TS)] + struct Foo { + one: i32, + #[serde(flatten)] + baz: Bar, + qux: Option, + } + + #[derive(Serialize, TS)] + #[allow(dead_code)] + #[serde(tag = "type", content = "stuff")] + enum Bar { + Baz { a: i32, a2: String }, + Biz { b: bool }, + Buz { c: String, d: Option }, + } + + assert_eq!( + Foo::inline(), + r#"{ one: number, qux: string | null, } & ({ "type": "Baz", "stuff": { a: number, a2: string, } } | { "type": "Biz", "stuff": { b: boolean, } } | { "type": "Buz", "stuff": { c: string, d: number | null, } })"# + ) +} + +#[test] +#[cfg(feature = "serde-compat")] +fn internally_tagged() { + #[derive(Serialize, TS)] + struct Foo { + qux: Option, + + #[serde(flatten)] + baz: Bar, + } + + #[derive(Serialize, TS)] + #[allow(dead_code)] + #[serde(tag = "type")] + enum Bar { + Baz { a: i32, a2: String }, + Biz { b: bool }, + Buz { c: String, d: Option }, + } + + assert_eq!( + Foo::inline(), + r#"{ qux: string | null, } & ({ "type": "Baz", a: number, a2: string, } | { "type": "Biz", b: boolean, } | { "type": "Buz", c: string, d: number | null, })"# + ) +} + +#[test] +#[cfg(feature = "serde-compat")] +fn untagged() { + #[derive(Serialize, TS)] + struct Foo { + #[serde(flatten)] + baz: Bar, + } + + #[derive(Serialize, TS)] + #[allow(dead_code)] + #[serde(untagged)] + enum Bar { + Baz { a: i32, a2: String }, + Biz { b: bool }, + Buz { c: String }, + } + + assert_eq!( + Foo::inline(), + r#"{ a: number, a2: string, } | { b: boolean, } | { c: string, }"# + ) +} + diff --git a/ts-rs/tests/flatten.rs b/ts-rs/tests/flatten.rs index 837c9b65b..56212498f 100644 --- a/ts-rs/tests/flatten.rs +++ b/ts-rs/tests/flatten.rs @@ -26,6 +26,6 @@ struct C { fn test_def() { assert_eq!( C::inline(), - "{ b: { a: number, b: number, c: number, }, d: number, }" + "{ b: { c: number, a: number, b: number, }, d: number, }" ); }