From 19e35d3b9e58555f27b170920ce5ce9c553ce5b2 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sat, 15 Feb 2020 15:15:06 -0500 Subject: [PATCH] Generalize lhs scalar ops to more combos of types This doesn't have a noticeable impact on the results of the `scalar_add_2` and `scalar_add_strided_2` benchmarks. --- src/impl_ops.rs | 134 ++++++++++++++++++++++-------------------------- 1 file changed, 60 insertions(+), 74 deletions(-) diff --git a/src/impl_ops.rs b/src/impl_ops.rs index 5e57d610e..8999999f1 100644 --- a/src/impl_ops.rs +++ b/src/impl_ops.rs @@ -166,56 +166,42 @@ impl<'a, A, S, D, B, C> $trt for &'a ArrayBase ); ); -// Pick the expression $a for commutative and $b for ordered binop -macro_rules! if_commutative { - (Commute { $a:expr } or { $b:expr }) => { - $a - }; - (Ordered { $a:expr } or { $b:expr }) => { - $b - }; -} - macro_rules! impl_scalar_lhs_op { - // $commutative flag. Reuse the self + scalar impl if we can. - // We can do this safely since these are the primitive numeric types - ($scalar:ty, $commutative:ident, $operator:tt, $trt:ident, $mth:ident, $doc:expr) => ( -// these have no doc -- they are not visible in rustdoc -// Perform elementwise -// between the scalar `self` and array `rhs`, -// and return the result (based on `self`). -impl $trt> for $scalar - where S: DataOwned + DataMut, - D: Dimension, + ($scalar:ty, $operator:tt, $trt:ident, $mth:ident, $doc:expr) => ( +/// Perform elementwise +#[doc=$doc] +/// between the scalar `self` and array `rhs`, +/// and return the result (based on `self`). +impl $trt> for $scalar +where + $scalar: Clone + $trt, + A: Clone, + S: DataOwned + DataMut, + D: Dimension, { type Output = ArrayBase; - fn $mth(self, rhs: ArrayBase) -> ArrayBase { - if_commutative!($commutative { - rhs.$mth(self) - } or {{ - let mut rhs = rhs; - rhs.unordered_foreach_mut(move |elt| { - *elt = self $operator *elt; - }); - rhs - }}) + fn $mth(self, mut rhs: ArrayBase) -> ArrayBase { + rhs.unordered_foreach_mut(move |elt| { + *elt = self.clone() $operator elt.clone(); + }); + rhs } } -// Perform elementwise -// between the scalar `self` and array `rhs`, -// and return the result as a new `Array`. -impl<'a, S, D> $trt<&'a ArrayBase> for $scalar - where S: Data, - D: Dimension, +/// Perform elementwise +#[doc=$doc] +/// between the scalar `self` and array `rhs`, +/// and return the result as a new `Array`. +impl<'a, A, S, D, B> $trt<&'a ArrayBase> for $scalar +where + $scalar: Clone + $trt, + A: Clone, + S: Data, + D: Dimension, { - type Output = Array<$scalar, D>; - fn $mth(self, rhs: &ArrayBase) -> Array<$scalar, D> { - if_commutative!($commutative { - rhs.$mth(self) - } or { - self.$mth(rhs.to_owned()) - }) + type Output = Array; + fn $mth(self, rhs: &ArrayBase) -> Array { + rhs.map(move |elt| self.clone() $operator elt.clone()) } } ); @@ -241,16 +227,16 @@ mod arithmetic_ops { macro_rules! all_scalar_ops { ($int_scalar:ty) => ( - impl_scalar_lhs_op!($int_scalar, Commute, +, Add, add, "addition"); - impl_scalar_lhs_op!($int_scalar, Ordered, -, Sub, sub, "subtraction"); - impl_scalar_lhs_op!($int_scalar, Commute, *, Mul, mul, "multiplication"); - impl_scalar_lhs_op!($int_scalar, Ordered, /, Div, div, "division"); - impl_scalar_lhs_op!($int_scalar, Ordered, %, Rem, rem, "remainder"); - impl_scalar_lhs_op!($int_scalar, Commute, &, BitAnd, bitand, "bit and"); - impl_scalar_lhs_op!($int_scalar, Commute, |, BitOr, bitor, "bit or"); - impl_scalar_lhs_op!($int_scalar, Commute, ^, BitXor, bitxor, "bit xor"); - impl_scalar_lhs_op!($int_scalar, Ordered, <<, Shl, shl, "left shift"); - impl_scalar_lhs_op!($int_scalar, Ordered, >>, Shr, shr, "right shift"); + impl_scalar_lhs_op!($int_scalar, +, Add, add, "addition"); + impl_scalar_lhs_op!($int_scalar, -, Sub, sub, "subtraction"); + impl_scalar_lhs_op!($int_scalar, *, Mul, mul, "multiplication"); + impl_scalar_lhs_op!($int_scalar, /, Div, div, "division"); + impl_scalar_lhs_op!($int_scalar, %, Rem, rem, "remainder"); + impl_scalar_lhs_op!($int_scalar, &, BitAnd, bitand, "bit and"); + impl_scalar_lhs_op!($int_scalar, |, BitOr, bitor, "bit or"); + impl_scalar_lhs_op!($int_scalar, ^, BitXor, bitxor, "bit xor"); + impl_scalar_lhs_op!($int_scalar, <<, Shl, shl, "left shift"); + impl_scalar_lhs_op!($int_scalar, >>, Shr, shr, "right shift"); ); } all_scalar_ops!(i8); @@ -264,31 +250,31 @@ mod arithmetic_ops { all_scalar_ops!(i128); all_scalar_ops!(u128); - impl_scalar_lhs_op!(bool, Commute, &, BitAnd, bitand, "bit and"); - impl_scalar_lhs_op!(bool, Commute, |, BitOr, bitor, "bit or"); - impl_scalar_lhs_op!(bool, Commute, ^, BitXor, bitxor, "bit xor"); + impl_scalar_lhs_op!(bool, &, BitAnd, bitand, "bit and"); + impl_scalar_lhs_op!(bool, |, BitOr, bitor, "bit or"); + impl_scalar_lhs_op!(bool, ^, BitXor, bitxor, "bit xor"); - impl_scalar_lhs_op!(f32, Commute, +, Add, add, "addition"); - impl_scalar_lhs_op!(f32, Ordered, -, Sub, sub, "subtraction"); - impl_scalar_lhs_op!(f32, Commute, *, Mul, mul, "multiplication"); - impl_scalar_lhs_op!(f32, Ordered, /, Div, div, "division"); - impl_scalar_lhs_op!(f32, Ordered, %, Rem, rem, "remainder"); + impl_scalar_lhs_op!(f32, +, Add, add, "addition"); + impl_scalar_lhs_op!(f32, -, Sub, sub, "subtraction"); + impl_scalar_lhs_op!(f32, *, Mul, mul, "multiplication"); + impl_scalar_lhs_op!(f32, /, Div, div, "division"); + impl_scalar_lhs_op!(f32, %, Rem, rem, "remainder"); - impl_scalar_lhs_op!(f64, Commute, +, Add, add, "addition"); - impl_scalar_lhs_op!(f64, Ordered, -, Sub, sub, "subtraction"); - impl_scalar_lhs_op!(f64, Commute, *, Mul, mul, "multiplication"); - impl_scalar_lhs_op!(f64, Ordered, /, Div, div, "division"); - impl_scalar_lhs_op!(f64, Ordered, %, Rem, rem, "remainder"); + impl_scalar_lhs_op!(f64, +, Add, add, "addition"); + impl_scalar_lhs_op!(f64, -, Sub, sub, "subtraction"); + impl_scalar_lhs_op!(f64, *, Mul, mul, "multiplication"); + impl_scalar_lhs_op!(f64, /, Div, div, "division"); + impl_scalar_lhs_op!(f64, %, Rem, rem, "remainder"); - impl_scalar_lhs_op!(Complex, Commute, +, Add, add, "addition"); - impl_scalar_lhs_op!(Complex, Ordered, -, Sub, sub, "subtraction"); - impl_scalar_lhs_op!(Complex, Commute, *, Mul, mul, "multiplication"); - impl_scalar_lhs_op!(Complex, Ordered, /, Div, div, "division"); + impl_scalar_lhs_op!(Complex, +, Add, add, "addition"); + impl_scalar_lhs_op!(Complex, -, Sub, sub, "subtraction"); + impl_scalar_lhs_op!(Complex, *, Mul, mul, "multiplication"); + impl_scalar_lhs_op!(Complex, /, Div, div, "division"); - impl_scalar_lhs_op!(Complex, Commute, +, Add, add, "addition"); - impl_scalar_lhs_op!(Complex, Ordered, -, Sub, sub, "subtraction"); - impl_scalar_lhs_op!(Complex, Commute, *, Mul, mul, "multiplication"); - impl_scalar_lhs_op!(Complex, Ordered, /, Div, div, "division"); + impl_scalar_lhs_op!(Complex, +, Add, add, "addition"); + impl_scalar_lhs_op!(Complex, -, Sub, sub, "subtraction"); + impl_scalar_lhs_op!(Complex, *, Mul, mul, "multiplication"); + impl_scalar_lhs_op!(Complex, /, Div, div, "division"); impl Neg for ArrayBase where