Skip to content

Commit

Permalink
use proper relative matrix equality check in proptest tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pnevyk committed Nov 22, 2023
1 parent 0d4f8d3 commit 9a1fcf6
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 48 deletions.
1 change: 0 additions & 1 deletion faer-libs/faer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ polars = { version = "0.34", features = ["lazy", "parquet"] }
dbgf = "0.1.0"
proptest = "1.4.0"
matrixcompare = "0.3.0"
approx = "0.5.0"

[[example]]
name = "conversions"
Expand Down
8 changes: 4 additions & 4 deletions faer-libs/faer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4267,14 +4267,14 @@ mod tests {
fn prop_solve_lower_triangular((m, rhs) in mat_rhs(BlockStructure::UnitTriangularLower, prop_real())) {
let sol = m.solve_lower_triangular(rhs.clone());
let m_times_sol = m * sol;
prop_assert_matrix_eq!(m_times_sol, rhs);
prop_assert_matrix_eq!(m_times_sol, rhs, tol = 10.0 * f64::EPSILON);
}

#[test]
fn prop_qr_real(H in square_mat(prop_real())) {
let qr = H.qr();
prop_assert_matrix_eq!(qr.compute_q() * qr.compute_r(), &H);
prop_assert_matrix_eq!(qr.compute_thin_q() * qr.compute_thin_r(), &H);
prop_assert_matrix_eq!(qr.compute_q() * qr.compute_r(), &H, tol = 10.0 * f64::EPSILON);
prop_assert_matrix_eq!(qr.compute_thin_q() * qr.compute_thin_r(), &H, tol = 10.0 * f64::EPSILON);
}

#[test]
Expand All @@ -4290,7 +4290,7 @@ mod tests {
.copy_from(eigen.s_diagonal());
let u = eigen.u();

prop_assert_matrix_eq!(u * &s, &H * u);
prop_assert_matrix_eq!(u * &s, &H * u, tol = 10.0 * f64::EPSILON);
}
}
}
82 changes: 39 additions & 43 deletions faer-libs/faer/src/proptest_support.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::collections::HashSet;

use approx::{relative_eq, AbsDiffEq, RelativeEq};
use faer_core::{mul::triangular::BlockStructure, ComplexField, Mat};
use faer_core::{mul::triangular::BlockStructure, AsMatRef, ComplexField, Mat};
use matrixcompare::comparators::ElementwiseComparator;
use proptest::{
prelude::Rng,
Expand Down Expand Up @@ -557,20 +556,7 @@ where
}

pub struct ComplexFieldComparator<E: ComplexField> {
pub epsilon: E::Real,
pub max_relative: E::Real,
}

impl<E: ComplexField> Default for ComplexFieldComparator<E>
where
E::Real: RelativeEq<Epsilon = E::Real>,
{
fn default() -> Self {
Self {
epsilon: E::Real::default_epsilon().faer_sqrt(),
max_relative: E::Real::default_epsilon(),
}
}
pub tolerance: E::Real,
}

pub struct ComplexFieldError<E: ComplexField> {
Expand All @@ -588,51 +574,61 @@ where

impl<E: ComplexField> ElementwiseComparator<E> for ComplexFieldComparator<E>
where
E::Real: core::fmt::Display + RelativeEq<Epsilon = E::Real>,
E::Real: core::fmt::Display,
{
type Error = ComplexFieldError<E>;

fn compare(&self, x: &E, y: &E) -> Result<(), Self::Error> {
let real = relative_eq!(
&x.faer_real(),
&y.faer_real(),
epsilon = self.epsilon,
max_relative = self.max_relative
);
let imag = relative_eq!(
&x.faer_imag(),
&y.faer_imag(),
epsilon = self.epsilon,
max_relative = self.max_relative
);

if real && imag {
Ok(())
} else if !real {
let real_abs_diff: E::Real = x.faer_real().faer_sub(y.faer_real()).faer_abs();
let imag_abs_diff: E::Real = x.faer_imag().faer_sub(y.faer_imag()).faer_abs();

if real_abs_diff > self.tolerance {
Err(ComplexFieldError {
value: (x.faer_real().faer_sub(y.faer_real())).faer_abs(),
value: real_abs_diff,
})
} else {
} else if imag_abs_diff > self.tolerance {
Err(ComplexFieldError {
value: (x.faer_imag().faer_sub(y.faer_imag())).faer_abs(),
value: imag_abs_diff,
})
} else {
Ok(())
}
}

fn description(&self) -> String {
format!(
"|x - y| <= {} or |x - y| <= max(|x|, |y|) * {}",
self.epsilon, self.max_relative
)
format!("|x - y| <= {}", self.tolerance)
}
}

pub fn relative_epsilon<E>(x: impl AsMatRef<E>, y: impl AsMatRef<E>, threshold: E::Real) -> E::Real
where
E: ComplexField,
{
let x = x.as_mat_ref();
let y = y.as_mat_ref();

let dim_max = x.nrows().max(x.ncols()).max(y.nrows()).max(y.ncols());
let dim_max = E::Real::faer_from_f64(dim_max as f64);

let x_norm = x.norm_max();
let y_norm = y.norm_max();
let norm_max = if x_norm > y_norm { x_norm } else { y_norm };

threshold.faer_mul(dim_max).faer_mul(norm_max)
}

#[macro_export]
macro_rules! prop_assert_matrix_eq {
($x:expr, $y:expr) => {{
let comp = $crate::proptest_support::ComplexFieldComparator::default();
($x:expr, $y:expr, tol = $tol:expr) => {{
let x = $x;
let y = $y;
let x = x.as_mat_ref();
let y = y.as_mat_ref();

let tolerance = $crate::proptest_support::relative_epsilon(&x, &y, $tol);
let comp = $crate::proptest_support::ComplexFieldComparator { tolerance };

let result = ::matrixcompare::compare_matrices(&$x, &$y, &comp);
let result = ::matrixcompare::compare_matrices(&x, &y, &comp);
if let Err(failure) = result {
let message = format!(
"Comparison failure at {}:{}. Error:\n {}",
Expand Down

0 comments on commit 9a1fcf6

Please sign in to comment.