Skip to content

Commit

Permalink
Generalize lhs scalar ops to more combos of types
Browse files Browse the repository at this point in the history
This doesn't have a noticeable impact on the results of the
`scalar_add_2` and `scalar_add_strided_2` benchmarks.
  • Loading branch information
jturner314 committed Feb 15, 2020
1 parent 073bc0e commit 19e35d3
Showing 1 changed file with 60 additions and 74 deletions.
134 changes: 60 additions & 74 deletions src/impl_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,56 +166,42 @@ impl<'a, A, S, D, B, C> $trt<B> for &'a ArrayBase<S, D>
);
);

// Pick the expression $a for commutative and $b for ordered binop
macro_rules! if_commutative {
(Commute { $a:expr } or { $b:expr }) => {
$a
};
(Ordered { $a:expr } or { $b:expr }) => {
$b
};
}

macro_rules! impl_scalar_lhs_op {
// $commutative flag. Reuse the self + scalar impl if we can.
// We can do this safely since these are the primitive numeric types
($scalar:ty, $commutative:ident, $operator:tt, $trt:ident, $mth:ident, $doc:expr) => (
// these have no doc -- they are not visible in rustdoc
// Perform elementwise
// between the scalar `self` and array `rhs`,
// and return the result (based on `self`).
impl<S, D> $trt<ArrayBase<S, D>> for $scalar
where S: DataOwned<Elem=$scalar> + DataMut,
D: Dimension,
($scalar:ty, $operator:tt, $trt:ident, $mth:ident, $doc:expr) => (
/// Perform elementwise
#[doc=$doc]
/// between the scalar `self` and array `rhs`,
/// and return the result (based on `self`).
impl<A, S, D> $trt<ArrayBase<S, D>> for $scalar
where
$scalar: Clone + $trt<A, Output=A>,
A: Clone,
S: DataOwned<Elem=A> + DataMut,
D: Dimension,
{
type Output = ArrayBase<S, D>;
fn $mth(self, rhs: ArrayBase<S, D>) -> ArrayBase<S, D> {
if_commutative!($commutative {
rhs.$mth(self)
} or {{
let mut rhs = rhs;
rhs.unordered_foreach_mut(move |elt| {
*elt = self $operator *elt;
});
rhs
}})
fn $mth(self, mut rhs: ArrayBase<S, D>) -> ArrayBase<S, D> {
rhs.unordered_foreach_mut(move |elt| {
*elt = self.clone() $operator elt.clone();
});
rhs
}
}

// Perform elementwise
// between the scalar `self` and array `rhs`,
// and return the result as a new `Array`.
impl<'a, S, D> $trt<&'a ArrayBase<S, D>> for $scalar
where S: Data<Elem=$scalar>,
D: Dimension,
/// Perform elementwise
#[doc=$doc]
/// between the scalar `self` and array `rhs`,
/// and return the result as a new `Array`.
impl<'a, A, S, D, B> $trt<&'a ArrayBase<S, D>> for $scalar
where
$scalar: Clone + $trt<A, Output=B>,
A: Clone,
S: Data<Elem=A>,
D: Dimension,
{
type Output = Array<$scalar, D>;
fn $mth(self, rhs: &ArrayBase<S, D>) -> Array<$scalar, D> {
if_commutative!($commutative {
rhs.$mth(self)
} or {
self.$mth(rhs.to_owned())
})
type Output = Array<B, D>;
fn $mth(self, rhs: &ArrayBase<S, D>) -> Array<B, D> {
rhs.map(move |elt| self.clone() $operator elt.clone())
}
}
);
Expand All @@ -241,16 +227,16 @@ mod arithmetic_ops {

macro_rules! all_scalar_ops {
($int_scalar:ty) => (
impl_scalar_lhs_op!($int_scalar, Commute, +, Add, add, "addition");
impl_scalar_lhs_op!($int_scalar, Ordered, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!($int_scalar, Commute, *, Mul, mul, "multiplication");
impl_scalar_lhs_op!($int_scalar, Ordered, /, Div, div, "division");
impl_scalar_lhs_op!($int_scalar, Ordered, %, Rem, rem, "remainder");
impl_scalar_lhs_op!($int_scalar, Commute, &, BitAnd, bitand, "bit and");
impl_scalar_lhs_op!($int_scalar, Commute, |, BitOr, bitor, "bit or");
impl_scalar_lhs_op!($int_scalar, Commute, ^, BitXor, bitxor, "bit xor");
impl_scalar_lhs_op!($int_scalar, Ordered, <<, Shl, shl, "left shift");
impl_scalar_lhs_op!($int_scalar, Ordered, >>, Shr, shr, "right shift");
impl_scalar_lhs_op!($int_scalar, +, Add, add, "addition");
impl_scalar_lhs_op!($int_scalar, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!($int_scalar, *, Mul, mul, "multiplication");
impl_scalar_lhs_op!($int_scalar, /, Div, div, "division");
impl_scalar_lhs_op!($int_scalar, %, Rem, rem, "remainder");
impl_scalar_lhs_op!($int_scalar, &, BitAnd, bitand, "bit and");
impl_scalar_lhs_op!($int_scalar, |, BitOr, bitor, "bit or");
impl_scalar_lhs_op!($int_scalar, ^, BitXor, bitxor, "bit xor");
impl_scalar_lhs_op!($int_scalar, <<, Shl, shl, "left shift");
impl_scalar_lhs_op!($int_scalar, >>, Shr, shr, "right shift");
);
}
all_scalar_ops!(i8);
Expand All @@ -264,31 +250,31 @@ mod arithmetic_ops {
all_scalar_ops!(i128);
all_scalar_ops!(u128);

impl_scalar_lhs_op!(bool, Commute, &, BitAnd, bitand, "bit and");
impl_scalar_lhs_op!(bool, Commute, |, BitOr, bitor, "bit or");
impl_scalar_lhs_op!(bool, Commute, ^, BitXor, bitxor, "bit xor");
impl_scalar_lhs_op!(bool, &, BitAnd, bitand, "bit and");
impl_scalar_lhs_op!(bool, |, BitOr, bitor, "bit or");
impl_scalar_lhs_op!(bool, ^, BitXor, bitxor, "bit xor");

impl_scalar_lhs_op!(f32, Commute, +, Add, add, "addition");
impl_scalar_lhs_op!(f32, Ordered, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!(f32, Commute, *, Mul, mul, "multiplication");
impl_scalar_lhs_op!(f32, Ordered, /, Div, div, "division");
impl_scalar_lhs_op!(f32, Ordered, %, Rem, rem, "remainder");
impl_scalar_lhs_op!(f32, +, Add, add, "addition");
impl_scalar_lhs_op!(f32, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!(f32, *, Mul, mul, "multiplication");
impl_scalar_lhs_op!(f32, /, Div, div, "division");
impl_scalar_lhs_op!(f32, %, Rem, rem, "remainder");

impl_scalar_lhs_op!(f64, Commute, +, Add, add, "addition");
impl_scalar_lhs_op!(f64, Ordered, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!(f64, Commute, *, Mul, mul, "multiplication");
impl_scalar_lhs_op!(f64, Ordered, /, Div, div, "division");
impl_scalar_lhs_op!(f64, Ordered, %, Rem, rem, "remainder");
impl_scalar_lhs_op!(f64, +, Add, add, "addition");
impl_scalar_lhs_op!(f64, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!(f64, *, Mul, mul, "multiplication");
impl_scalar_lhs_op!(f64, /, Div, div, "division");
impl_scalar_lhs_op!(f64, %, Rem, rem, "remainder");

impl_scalar_lhs_op!(Complex<f32>, Commute, +, Add, add, "addition");
impl_scalar_lhs_op!(Complex<f32>, Ordered, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!(Complex<f32>, Commute, *, Mul, mul, "multiplication");
impl_scalar_lhs_op!(Complex<f32>, Ordered, /, Div, div, "division");
impl_scalar_lhs_op!(Complex<f32>, +, Add, add, "addition");
impl_scalar_lhs_op!(Complex<f32>, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!(Complex<f32>, *, Mul, mul, "multiplication");
impl_scalar_lhs_op!(Complex<f32>, /, Div, div, "division");

impl_scalar_lhs_op!(Complex<f64>, Commute, +, Add, add, "addition");
impl_scalar_lhs_op!(Complex<f64>, Ordered, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!(Complex<f64>, Commute, *, Mul, mul, "multiplication");
impl_scalar_lhs_op!(Complex<f64>, Ordered, /, Div, div, "division");
impl_scalar_lhs_op!(Complex<f64>, +, Add, add, "addition");
impl_scalar_lhs_op!(Complex<f64>, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!(Complex<f64>, *, Mul, mul, "multiplication");
impl_scalar_lhs_op!(Complex<f64>, /, Div, div, "division");

impl<A, S, D> Neg for ArrayBase<S, D>
where
Expand Down

0 comments on commit 19e35d3

Please sign in to comment.