Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support non-literal discriminant values #96

Merged
merged 2 commits into from
Jan 15, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions num_enum/tests/try_build/compile_fail/alternative_exprs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
const THREE: u8 = 3;

#[derive(num_enum::TryFromPrimitive)]
#[repr(i8)]
enum Numbers {
Zero = 0,
#[num_enum(alternatives = [-1, 2, THREE])]
One = 1,
}

fn main() {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
error: Only literals are allowed as num_enum alternate values
--> tests/try_build/compile_fail/alternative_exprs.rs:7:39
|
7 | #[num_enum(alternatives = [-1, 2, THREE])]
| ^^^^^
98 changes: 86 additions & 12 deletions num_enum/tests/try_from_primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,79 @@ fn wrong_order() {
assert_eq!(four, Ok(Enum::Four));
}

#[cfg(feature = "complex-expression")]
#[test]
fn negative_values() {
#[derive(Debug, Eq, PartialEq, TryFromPrimitive)]
#[repr(i8)]
enum Enum {
MinusTwo = -2,
MinusOne = -1,
Zero = 0,
One = 1,
Two = 2,
}

let minus_two: Result<Enum, _> = (-2i8).try_into();
assert_eq!(minus_two, Ok(Enum::MinusTwo));

let minus_one: Result<Enum, _> = (-1i8).try_into();
assert_eq!(minus_one, Ok(Enum::MinusOne));

let zero: Result<Enum, _> = 0i8.try_into();
assert_eq!(zero, Ok(Enum::Zero));

let one: Result<Enum, _> = 1i8.try_into();
assert_eq!(one, Ok(Enum::One));

let two: Result<Enum, _> = 2i8.try_into();
assert_eq!(two, Ok(Enum::Two));
}

#[test]
fn discriminant_expressions() {
const ONE: u8 = 1;

#[derive(Debug, Eq, PartialEq, TryFromPrimitive)]
#[repr(u8)]
enum Enum {
Zero,
One = ONE,
Two,
Four = 4u8,
Five,
Six = ONE + ONE + 2u8 + 2,
}

let zero: Result<Enum, _> = 0u8.try_into();
assert_eq!(zero, Ok(Enum::Zero));

let one: Result<Enum, _> = 1u8.try_into();
assert_eq!(one, Ok(Enum::One));

let two: Result<Enum, _> = 2u8.try_into();
assert_eq!(two, Ok(Enum::Two));

let three: Result<Enum, _> = 3u8.try_into();
assert_eq!(
three.unwrap_err().to_string(),
"No discriminant in enum `Enum` matches the value `3`",
);

let four: Result<Enum, _> = 4u8.try_into();
assert_eq!(four, Ok(Enum::Four));

let five: Result<Enum, _> = 5u8.try_into();
assert_eq!(five, Ok(Enum::Five));

let six: Result<Enum, _> = 6u8.try_into();
assert_eq!(six, Ok(Enum::Six));
}

#[cfg(feature = "complex-expressions")]
mod complex {
use num_enum::TryFromPrimitive;
use std::convert::TryInto;

const ONE: u8 = 1;

#[derive(Debug, Eq, PartialEq, TryFromPrimitive)]
Expand Down Expand Up @@ -261,26 +332,29 @@ fn error_variant_is_allowed() {
#[test]
fn alternative_values() {
#[derive(Debug, Eq, PartialEq, TryFromPrimitive)]
#[repr(u8)]
#[repr(i8)]
enum Enum {
Zero = 0,
#[num_enum(alternatives = [2, 3])]
OneTwoOrThree = 1,
#[num_enum(alternatives = [-1, 2, 3])]
OneTwoThreeOrMinusOne = 1,
}

let zero: Result<Enum, _> = 0u8.try_into();
let minus_one: Result<Enum, _> = (-1i8).try_into();
assert_eq!(minus_one, Ok(Enum::OneTwoThreeOrMinusOne));

let zero: Result<Enum, _> = 0i8.try_into();
assert_eq!(zero, Ok(Enum::Zero));

let one: Result<Enum, _> = 1u8.try_into();
assert_eq!(one, Ok(Enum::OneTwoOrThree));
let one: Result<Enum, _> = 1i8.try_into();
assert_eq!(one, Ok(Enum::OneTwoThreeOrMinusOne));

let two: Result<Enum, _> = 2u8.try_into();
assert_eq!(two, Ok(Enum::OneTwoOrThree));
let two: Result<Enum, _> = 2i8.try_into();
assert_eq!(two, Ok(Enum::OneTwoThreeOrMinusOne));

let three: Result<Enum, _> = 3u8.try_into();
assert_eq!(three, Ok(Enum::OneTwoOrThree));
let three: Result<Enum, _> = 3i8.try_into();
assert_eq!(three, Ok(Enum::OneTwoThreeOrMinusOne));

let four: Result<Enum, _> = 4u8.try_into();
let four: Result<Enum, _> = 4i8.try_into();
assert_eq!(
four.unwrap_err().to_string(),
"No discriminant in enum `Enum` matches the value `4`"
Expand Down
148 changes: 104 additions & 44 deletions num_enum_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use syn::{
parse::{Parse, ParseStream},
parse_macro_input, parse_quote,
spanned::Spanned,
Attribute, Data, DeriveInput, Error, Expr, Fields, Ident, Lit, LitInt, LitStr, Meta, Result,
Attribute, Data, DeriveInput, Error, Expr, ExprUnary, Fields, Ident, Lit, LitInt, LitStr, Meta,
Result, UnOp,
};

macro_rules! die {
Expand All @@ -32,14 +33,32 @@ fn literal(i: i128) -> Expr {
}
}

fn expr_to_int(val_exp: &Expr) -> Result<i128> {
Ok(match val_exp {
Expr::Lit(ref val_exp_lit) => match val_exp_lit.lit {
Lit::Int(ref lit_int) => lit_int.base10_parse()?,
_ => die!(val_exp => "Expected integer"),
},
_ => die!(val_exp => "Expected literal"),
})
enum DiscriminantValue {
Literal(i128),
Expr(Expr),
}

fn parse_discriminant(val_exp: &Expr) -> Result<DiscriminantValue> {
match val_exp {
Expr::Lit(ref val_exp_lit) => {
if let Lit::Int(ref lit_int) = val_exp_lit.lit {
return Ok(DiscriminantValue::Literal(lit_int.base10_parse()?));
}
}
Expr::Unary(ExprUnary {
op: UnOp::Neg(..),
expr,
..
}) => {
if let Expr::Lit(ref val_exp_lit) = **expr {
if let Lit::Int(ref lit_int) = val_exp_lit.lit {
return Ok(DiscriminantValue::Literal(-lit_int.base10_parse()?));
}
}
}
_ => {}
}
Ok(DiscriminantValue::Expr(val_exp.clone()))
}

mod kw {
Expand Down Expand Up @@ -307,7 +326,7 @@ impl Parse for EnumInfo {
let mut has_catch_all_variant: bool = false;

// Vec to keep track of the used discriminants and alt values.
let mut val_set: BTreeSet<i128> = BTreeSet::new();
let mut discriminant_int_val_set = BTreeSet::new();

let mut next_discriminant = literal(0);
for variant in data.variants.into_iter() {
Expand All @@ -319,7 +338,7 @@ impl Parse for EnumInfo {
};

let mut attr_spans: AttributeSpans = Default::default();
let mut alternative_values: Vec<Expr> = vec![];
let mut raw_alternative_values: Vec<Expr> = vec![];
// Keep the attribute around for better error reporting.
let mut alt_attr_ref: Vec<&Attribute> = vec![];

Expand Down Expand Up @@ -398,7 +417,7 @@ impl Parse for EnumInfo {
}
NumEnumVariantAttributeItem::Alternatives(alternatives) => {
attr_spans.alternatives.push(alternatives.span());
alternative_values.extend(alternatives.expressions);
raw_alternative_values.extend(alternatives.expressions);
alt_attr_ref.push(attribute);
}
}
Expand All @@ -422,75 +441,116 @@ impl Parse for EnumInfo {
}
}

let canonical_value = discriminant;
let canonical_value_int = expr_to_int(&canonical_value)?;
let discriminant_value = parse_discriminant(&discriminant)?;

// Check for collision.
if val_set.contains(&canonical_value_int) {
die!(ident => format!("The discriminant '{}' collides with a value attributed to a previous variant", canonical_value_int))
// We can't do const evaluation, or even compare arbitrary Exprs,
// so unfortunately we can't check for duplicates.
// That's not the end of the world, just we'll end up with compile errors for
// matches with duplicate branches in generated code instead of nice friendly error messages.
if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value {
if discriminant_int_val_set.contains(&canonical_value_int) {
die!(ident => format!("The discriminant '{}' collides with a value attributed to a previous variant", canonical_value_int))
}
}

// Deal with the alternative values.
let alt_val = alternative_values
let alternate_values = raw_alternative_values
.iter()
.map(expr_to_int)
.map(parse_discriminant)
.collect::<Result<Vec<_>>>()?;

debug_assert_eq!(alt_val.len(), alternative_values.len());

if !alt_val.is_empty() {
let mut alt_val_sorted = alt_val.clone();
alt_val_sorted.sort_unstable();
let alt_val_sorted = alt_val_sorted;

// check if the current discriminant is not in the alternative values.
if let Some(i) = alt_val.iter().position(|&x| x == canonical_value_int) {
die!(&alternative_values[i] => format!("'{}' in the alternative values is already attributed as the discriminant of this variant", canonical_value_int));
debug_assert_eq!(alternate_values.len(), raw_alternative_values.len());

if !alternate_values.is_empty() {
let mut sorted_alternate_int_values = alternate_values
.into_iter()
.map(|v| {
match v {
DiscriminantValue::Literal(value) => Ok(value),
DiscriminantValue::Expr(expr) => {
// We can't do uniqueness checking on non-literals, so we don't allow them as alternate values.
// We could probably allow them, but there doesn't seem to be much of a use-case,
// and it's easier to give good error messages about duplicate values this way,
// rather than rustc errors on conflicting match branches.
die!(expr => format!("Only literals are allowed as num_enum alternate values"))
},
}
})
.collect::<Result<Vec<i128>>>()?;
sorted_alternate_int_values.sort_unstable();
let sorted_alternate_int_values = sorted_alternate_int_values;

// Check if the current discriminant is not in the alternative values.
if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value {
if let Ok(index) =
sorted_alternate_int_values.binary_search(&canonical_value_int)
{
die!(&raw_alternative_values[index] => format!("'{}' in the alternative values is already attributed as the discriminant of this variant", canonical_value_int));
}
}

// Search for duplicates, the vec is sorted. Warn about them.
if (1..alt_val_sorted.len()).any(|i| alt_val_sorted[i] == alt_val_sorted[i - 1])
{
if (1..sorted_alternate_int_values.len()).any(|i| {
sorted_alternate_int_values[i] == sorted_alternate_int_values[i - 1]
}) {
let attr = *alt_attr_ref.last().unwrap();
die!(attr => "There is duplication in the alternative values");
}
// Search if those alt_val where already attributed.
// (The val_set is BTreeSet, and iter().next_back() is the is the maximum in the set.)
if let Some(last_upper_val) = val_set.iter().next_back() {
if alt_val_sorted.first().unwrap() <= last_upper_val {
for (i, val) in alt_val_sorted.iter().enumerate() {
if val_set.contains(val) {
die!(&alternative_values[i] => format!("'{}' in the alternative values is already attributed to a previous variant", val));
// Search if those discriminant_int_val_set where already attributed.
// (discriminant_int_val_set is BTreeSet, and iter().next_back() is the is the maximum in the set.)
if let Some(last_upper_val) = discriminant_int_val_set.iter().next_back() {
if sorted_alternate_int_values.first().unwrap() <= last_upper_val {
for (i, val) in sorted_alternate_int_values.iter().enumerate() {
if discriminant_int_val_set.contains(val) {
die!(&raw_alternative_values[i] => format!("'{}' in the alternative values is already attributed to a previous variant", val));
}
}
}
}

// Reconstruct the alternative_values vec of Expr but sorted.
alternative_values = alt_val_sorted
raw_alternative_values = sorted_alternate_int_values
.iter()
.map(|val| literal(val.to_owned()))
.collect();

// Add the alternative values to the the set to keep track.
val_set.extend(alt_val_sorted);
discriminant_int_val_set.extend(sorted_alternate_int_values);
}

// Add the current discriminant to the the set to keep track.
let newly_inserted = val_set.insert(canonical_value_int);
debug_assert!(newly_inserted);
if let DiscriminantValue::Literal(canonical_value_int) = discriminant_value {
discriminant_int_val_set.insert(canonical_value_int);
}

variants.push(VariantInfo {
ident,
attr_spans,
is_default,
is_catch_all,
canonical_value,
alternative_values,
canonical_value: discriminant,
alternative_values: raw_alternative_values,
});

// Get the next value for the discriminant.
next_discriminant = literal(canonical_value_int + 1);
next_discriminant = match discriminant_value {
DiscriminantValue::Literal(int_value) => {
if int_value >= -1 {
literal(int_value + 1)
} else {
let value = literal((int_value + 1).abs());
parse_quote! {
-#value
}
}
}
DiscriminantValue::Expr(expr) => {
parse_quote! {
#repr::wrapping_add(#expr, 1)
}
}
}
}

EnumInfo {
Expand Down