Skip to content

Commit

Permalink
Support non-literal discriminant values
Browse files Browse the repository at this point in the history
While preserving favouring literal values where possible.
  • Loading branch information
illicitonion committed Jan 14, 2023
1 parent 7acc582 commit 927f00a
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 56 deletions.
13 changes: 13 additions & 0 deletions num_enum/tests/try_build/compile_fail/alternative_exprs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
const THREE: u8 = 3;

#[derive(num_enum::TryFromPrimitive)]
#[repr(i8)]
enum Numbers {
Zero = 0,
#[num_enum(alternatives = [-1, 2, THREE])]
One = 1,
}

fn main() {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
error: Only literals are allowed as num_enum alternate values
--> tests/try_build/compile_fail/alternative_exprs.rs:7:39
|
7 | #[num_enum(alternatives = [-1, 2, THREE])]
| ^^^^^
98 changes: 86 additions & 12 deletions num_enum/tests/try_from_primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,79 @@ fn wrong_order() {
assert_eq!(four, Ok(Enum::Four));
}

#[cfg(feature = "complex-expression")]
#[test]
fn negative_values() {
#[derive(Debug, Eq, PartialEq, TryFromPrimitive)]
#[repr(i8)]
enum Enum {
MinusTwo = -2,
MinusOne = -1,
Zero = 0,
One = 1,
Two = 2,
}

let minus_two: Result<Enum, _> = (-2i8).try_into();
assert_eq!(minus_two, Ok(Enum::MinusTwo));

let minus_one: Result<Enum, _> = (-1i8).try_into();
assert_eq!(minus_one, Ok(Enum::MinusOne));

let zero: Result<Enum, _> = 0i8.try_into();
assert_eq!(zero, Ok(Enum::Zero));

let one: Result<Enum, _> = 1i8.try_into();
assert_eq!(one, Ok(Enum::One));

let two: Result<Enum, _> = 2i8.try_into();
assert_eq!(two, Ok(Enum::Two));
}

#[test]
fn discriminant_expressions() {
const ONE: u8 = 1;

#[derive(Debug, Eq, PartialEq, TryFromPrimitive)]
#[repr(u8)]
enum Enum {
Zero,
One = ONE,
Two,
Four = 4u8,
Five,
Six = ONE + ONE + 2u8 + 2,
}

let zero: Result<Enum, _> = 0u8.try_into();
assert_eq!(zero, Ok(Enum::Zero));

let one: Result<Enum, _> = 1u8.try_into();
assert_eq!(one, Ok(Enum::One));

let two: Result<Enum, _> = 2u8.try_into();
assert_eq!(two, Ok(Enum::Two));

let three: Result<Enum, _> = 3u8.try_into();
assert_eq!(
three.unwrap_err().to_string(),
"No discriminant in enum `Enum` matches the value `3`",
);

let four: Result<Enum, _> = 4u8.try_into();
assert_eq!(four, Ok(Enum::Four));

let five: Result<Enum, _> = 5u8.try_into();
assert_eq!(five, Ok(Enum::Five));

let six: Result<Enum, _> = 6u8.try_into();
assert_eq!(six, Ok(Enum::Six));
}

#[cfg(feature = "complex-expressions")]
mod complex {
use num_enum::TryFromPrimitive;
use std::convert::TryInto;

const ONE: u8 = 1;

#[derive(Debug, Eq, PartialEq, TryFromPrimitive)]
Expand Down Expand Up @@ -261,26 +332,29 @@ fn error_variant_is_allowed() {
#[test]
fn alternative_values() {
#[derive(Debug, Eq, PartialEq, TryFromPrimitive)]
#[repr(u8)]
#[repr(i8)]
enum Enum {
Zero = 0,
#[num_enum(alternatives = [2, 3])]
OneTwoOrThree = 1,
#[num_enum(alternatives = [-1, 2, 3])]
OneTwoThreeOrMinusOne = 1,
}

let zero: Result<Enum, _> = 0u8.try_into();
let minus_one: Result<Enum, _> = (-1i8).try_into();
assert_eq!(minus_one, Ok(Enum::OneTwoThreeOrMinusOne));

let zero: Result<Enum, _> = 0i8.try_into();
assert_eq!(zero, Ok(Enum::Zero));

let one: Result<Enum, _> = 1u8.try_into();
assert_eq!(one, Ok(Enum::OneTwoOrThree));
let one: Result<Enum, _> = 1i8.try_into();
assert_eq!(one, Ok(Enum::OneTwoThreeOrMinusOne));

let two: Result<Enum, _> = 2u8.try_into();
assert_eq!(two, Ok(Enum::OneTwoOrThree));
let two: Result<Enum, _> = 2i8.try_into();
assert_eq!(two, Ok(Enum::OneTwoThreeOrMinusOne));

let three: Result<Enum, _> = 3u8.try_into();
assert_eq!(three, Ok(Enum::OneTwoOrThree));
let three: Result<Enum, _> = 3i8.try_into();
assert_eq!(three, Ok(Enum::OneTwoThreeOrMinusOne));

let four: Result<Enum, _> = 4u8.try_into();
let four: Result<Enum, _> = 4i8.try_into();
assert_eq!(
four.unwrap_err().to_string(),
"No discriminant in enum `Enum` matches the value `4`"
Expand Down
148 changes: 104 additions & 44 deletions num_enum_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use syn::{
parse::{Parse, ParseStream},
parse_macro_input, parse_quote,
spanned::Spanned,
Attribute, Data, DeriveInput, Error, Expr, Fields, Ident, Lit, LitInt, LitStr, Meta, Result,
Attribute, Data, DeriveInput, Error, Expr, ExprUnary, Fields, Ident, Lit, LitInt, LitStr, Meta,
Result, UnOp,
};

macro_rules! die {
Expand All @@ -32,14 +33,32 @@ fn literal(i: i128) -> Expr {
}
}

fn expr_to_int(val_exp: &Expr) -> Result<i128> {
Ok(match val_exp {
Expr::Lit(ref val_exp_lit) => match val_exp_lit.lit {
Lit::Int(ref lit_int) => lit_int.base10_parse()?,
_ => die!(val_exp => "Expected integer"),
},
_ => die!(val_exp => "Expected literal"),
})
enum DiscriminantValue {
Literal(i128),
Expr(Expr),
}

fn parse_discriminant(val_exp: &Expr) -> Result<DiscriminantValue> {
match val_exp {
Expr::Lit(ref val_exp_lit) => {
if let Lit::Int(ref lit_int) = val_exp_lit.lit {
return Ok(DiscriminantValue::Literal(lit_int.base10_parse()?));
}
}
Expr::Unary(ExprUnary {
op: UnOp::Neg(..),
expr,
..
}) => {
if let Expr::Lit(ref val_exp_lit) = **expr {
if let Lit::Int(ref lit_int) = val_exp_lit.lit {
return Ok(DiscriminantValue::Literal(-lit_int.base10_parse()?));
}
}
}
_ => {}
}
Ok(DiscriminantValue::Expr(val_exp.clone()))
}

mod kw {
Expand Down Expand Up @@ -307,7 +326,7 @@ impl Parse for EnumInfo {
let mut has_catch_all_variant: bool = false;

// Vec to keep track of the used discriminants and alt values.
let mut val_set: BTreeSet<i128> = BTreeSet::new();
let mut discriminant_int_val_set = BTreeSet::new();

let mut next_discriminant = literal(0);
for variant in data.variants.into_iter() {
Expand All @@ -319,7 +338,7 @@ impl Parse for EnumInfo {
};

let mut attr_spans: AttributeSpans = Default::default();
let mut alternative_values: Vec<Expr> = vec![];
let mut raw_alternative_values: Vec<Expr> = vec![];
// Keep the attribute around for better error reporting.
let mut alt_attr_ref: Vec<&Attribute> = vec![];

Expand Down Expand Up @@ -398,7 +417,7 @@ impl Parse for EnumInfo {
}
NumEnumVariantAttributeItem::Alternatives(alternatives) => {
attr_spans.alternatives.push(alternatives.span());
alternative_values.extend(alternatives.expressions);
raw_alternative_values.extend(alternatives.expressions);
alt_attr_ref.push(attribute);
}
}
Expand All @@ -422,75 +441,116 @@ impl Parse for EnumInfo {
}
}

let canonical_value = discriminant;
let canonical_value_int = expr_to_int(&canonical_value)?;
let discriminant_value = parse_discriminant(&discriminant)?;

// Check for collision.
if val_set.contains(&canonical_value_int) {
die!(ident => format!("The discriminant '{}' collides with a value attributed to a previous variant", canonical_value_int))
// We can't do const evaluation, or even compare arbitrary Exprs,
// so unfortunately we can't check for duplicates.
// That's not the end of the world, just we'll end up with compile errors for
// matches with duplicate branches in generated code instead of nice friendly error messages.
if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value {
if discriminant_int_val_set.contains(&canonical_value_int) {
die!(ident => format!("The discriminant '{}' collides with a value attributed to a previous variant", canonical_value_int))
}
}

// Deal with the alternative values.
let alt_val = alternative_values
let alternate_values = raw_alternative_values
.iter()
.map(expr_to_int)
.map(parse_discriminant)
.collect::<Result<Vec<_>>>()?;

debug_assert_eq!(alt_val.len(), alternative_values.len());

if !alt_val.is_empty() {
let mut alt_val_sorted = alt_val.clone();
alt_val_sorted.sort_unstable();
let alt_val_sorted = alt_val_sorted;

// check if the current discriminant is not in the alternative values.
if let Some(i) = alt_val.iter().position(|&x| x == canonical_value_int) {
die!(&alternative_values[i] => format!("'{}' in the alternative values is already attributed as the discriminant of this variant", canonical_value_int));
debug_assert_eq!(alternate_values.len(), raw_alternative_values.len());

if !alternate_values.is_empty() {
let mut sorted_alternate_int_values = alternate_values
.into_iter()
.map(|v| {
match v {
DiscriminantValue::Literal(value) => Ok(value),
DiscriminantValue::Expr(expr) => {
// We can't do uniqueness checking on non-literals, so we don't allow them as alternate values.
// We could probably allow them, but there doesn't seem to be much of a use-case,
// and it's easier to give good error messages about duplicate values this way,
// rather than rustc errors on conflicting match branches.
die!(expr => format!("Only literals are allowed as num_enum alternate values"))
},
}
})
.collect::<Result<Vec<i128>>>()?;
sorted_alternate_int_values.sort_unstable();
let sorted_alternate_int_values = sorted_alternate_int_values;

// Check if the current discriminant is not in the alternative values.
if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value {
if let Ok(index) =
sorted_alternate_int_values.binary_search(&canonical_value_int)
{
die!(&raw_alternative_values[index] => format!("'{}' in the alternative values is already attributed as the discriminant of this variant", canonical_value_int));
}
}

// Search for duplicates, the vec is sorted. Warn about them.
if (1..alt_val_sorted.len()).any(|i| alt_val_sorted[i] == alt_val_sorted[i - 1])
{
if (1..sorted_alternate_int_values.len()).any(|i| {
sorted_alternate_int_values[i] == sorted_alternate_int_values[i - 1]
}) {
let attr = *alt_attr_ref.last().unwrap();
die!(attr => "There is duplication in the alternative values");
}
// Search if those alt_val where already attributed.
// (The val_set is BTreeSet, and iter().next_back() is the is the maximum in the set.)
if let Some(last_upper_val) = val_set.iter().next_back() {
if alt_val_sorted.first().unwrap() <= last_upper_val {
for (i, val) in alt_val_sorted.iter().enumerate() {
if val_set.contains(val) {
die!(&alternative_values[i] => format!("'{}' in the alternative values is already attributed to a previous variant", val));
// Search if those discriminant_int_val_set where already attributed.
// (discriminant_int_val_set is BTreeSet, and iter().next_back() is the is the maximum in the set.)
if let Some(last_upper_val) = discriminant_int_val_set.iter().next_back() {
if sorted_alternate_int_values.first().unwrap() <= last_upper_val {
for (i, val) in sorted_alternate_int_values.iter().enumerate() {
if discriminant_int_val_set.contains(val) {
die!(&raw_alternative_values[i] => format!("'{}' in the alternative values is already attributed to a previous variant", val));
}
}
}
}

// Reconstruct the alternative_values vec of Expr but sorted.
alternative_values = alt_val_sorted
raw_alternative_values = sorted_alternate_int_values
.iter()
.map(|val| literal(val.to_owned()))
.collect();

// Add the alternative values to the the set to keep track.
val_set.extend(alt_val_sorted);
discriminant_int_val_set.extend(sorted_alternate_int_values);
}

// Add the current discriminant to the the set to keep track.
let newly_inserted = val_set.insert(canonical_value_int);
debug_assert!(newly_inserted);
if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value {
discriminant_int_val_set.insert(canonical_value_int);
}

variants.push(VariantInfo {
ident,
attr_spans,
is_default,
is_catch_all,
canonical_value,
alternative_values,
canonical_value: discriminant,
alternative_values: raw_alternative_values,
});

// Get the next value for the discriminant.
next_discriminant = literal(canonical_value_int + 1);
next_discriminant = match discriminant_value {
DiscriminantValue::Literal(int_value) => {
if int_value >= -1 {
literal(int_value + 1)
} else {
let value = literal((int_value + 1).abs());
parse_quote! {
-#value
}
}
}
DiscriminantValue::Expr(expr) => {
parse_quote! {
#repr::wrapping_add(#expr, 1)
}
}
}
}

EnumInfo {
Expand Down

0 comments on commit 927f00a

Please sign in to comment.