From 975ab20d813e0dcb8d06976ec70cc8cf4c462a6b Mon Sep 17 00:00:00 2001 From: Markus Date: Sun, 23 Feb 2025 14:36:09 +0100 Subject: [PATCH] chore: Update dependencies faer-ext has been held back, as it uses an older version of faer, which causes an import collision. --- Cargo.toml | 10 +++++----- src/routines/evaluation/ipm.rs | 2 +- src/routines/evaluation/qr.rs | 14 ++++++++------ src/routines/optimization/mod.rs | 12 ++++++------ 4 files changed, 20 insertions(+), 18 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 12fb8037d..251ba5570 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,12 +16,12 @@ exclude = [".github/*", ".vscode/*"] [dependencies] csv = "1.2.1" -ndarray = { version = "0.15.6", features = ["rayon"] } +ndarray = { version = "0.16.1", features = ["rayon"] } serde = "1.0.188" serde_json = "1.0.66" sobol_burley = "0.5.0" -ndarray-stats = "0.5.1" -linfa-linalg = "0.1.0" +ndarray-stats = "0.6.0" +linfa-linalg = "0.2.0" argmin = { version = "0.10.0", features = [] } argmin-math = { version = "0.4.0", features = ["ndarray_v0_15-nolinalg"] } tracing = "0.1.40" @@ -31,8 +31,8 @@ tracing-subscriber = { version = "0.3.17", features = [ "time", ] } config = { version = "0.15", features = ["preserve_order"] } -faer = "0.19.3" -faer-ext = { version = "0.2.0", features = ["nalgebra", "ndarray"] } +faer = "0.20.2" +faer-ext = { version = "0.4.1", features = ["nalgebra", "ndarray"] } pharmsol = "0.7.6" rand = "0.9.0" anyhow = "1.0.86" diff --git a/src/routines/evaluation/ipm.rs b/src/routines/evaluation/ipm.rs index 2554d63a8..febfb6a9e 100644 --- a/src/routines/evaluation/ipm.rs +++ b/src/routines/evaluation/ipm.rs @@ -97,7 +97,7 @@ pub fn burke( let uph = uph.t(); let smuyinv = smu * (&ecol / &y); let rhsdw = &erow / &w - (psi.dot(&smuyinv)); - let a = rhsdw.clone().into_shape((rhsdw.len(), 1))?; + let a = rhsdw.clone().into_shape_with_order((rhsdw.len(), 1))?; let x = uph .t() diff --git a/src/routines/evaluation/qr.rs b/src/routines/evaluation/qr.rs index bda68f3b8..e8b75b502 100644 --- a/src/routines/evaluation/qr.rs +++ b/src/routines/evaluation/qr.rs @@ -1,18 +1,20 @@ -// use faer::{FaerMat, IntoFaer, IntoNdarray}; use faer_ext::*; use ndarray::parallel::prelude::*; use ndarray::{Array2, Axis}; +use faer::MatRef; + pub fn calculate_r(x: &Array2) -> (Array2, Vec) { let mut n_x = x.clone(); n_x.axis_iter_mut(Axis(0)) .into_par_iter() .for_each(|mut row| row /= row.sum()); - let mat_x = n_x.view().into_faer(); - let qr = mat_x.col_piv_qr(); - let r_mat = qr.compute_r(); + let mat_x: MatRef<'_, f64> = n_x.view().into_faer(); + let qr: faer::sparse::solvers::ColPivQr = mat_x.col_piv_qr(); + let r_mat: faer::Mat = qr.compute_r(); let (forward, _inverse) = qr.col_permutation().arrays(); - let r = r_mat.as_ref().into_ndarray().to_owned(); - let perm = Vec::from(forward); + let r: ndarray::ArrayBase, ndarray::Dim<[usize; 2]>> = + r_mat.as_ref().into_ndarray().to_owned(); + let perm: Vec = Vec::from(forward); (r, perm) } diff --git a/src/routines/optimization/mod.rs b/src/routines/optimization/mod.rs index 973c6df42..e7566bd1e 100644 --- a/src/routines/optimization/mod.rs +++ b/src/routines/optimization/mod.rs @@ -17,10 +17,10 @@ pub struct SppOptimizer<'a, E: Equation> { } impl<'a, E: Equation> CostFunction for SppOptimizer<'a, E> { - type Param = Array1; + type Param = Vec; type Output = f64; fn cost(&self, spp: &Self::Param) -> Result { - let theta = spp.to_owned().insert_axis(Axis(0)); + let theta = Array1::from(spp.clone()).insert_axis(Axis(0)); let psi = psi(self.equation, self.data, &theta, self.sig, false, false); @@ -53,17 +53,17 @@ impl<'a, E: Equation> SppOptimizer<'a, E> { } } pub fn optimize_point(self, spp: Array1) -> Result, Error> { - let simplex = create_initial_simplex(&spp); + let simplex = create_initial_simplex(&spp.to_vec()); let solver = NelderMead::new(simplex).with_sd_tolerance(1e-2)?; let res = Executor::new(self, solver) .configure(|state| state.max_iters(5)) // .add_observer(SlogLogger::term(), ObserverMode::Always) .run()?; - Ok(res.state.best_param.unwrap()) + Ok(Array1::from(res.state.best_param.unwrap())) } } -fn create_initial_simplex(initial_point: &Array1) -> Vec> { +fn create_initial_simplex(initial_point: &[f64]) -> Vec> { let num_dimensions = initial_point.len(); let perturbation_percentage = 0.008; @@ -71,7 +71,7 @@ fn create_initial_simplex(initial_point: &Array1) -> Vec> { let mut vertices = Vec::new(); // Add the initial point to the vertices - vertices.push(initial_point.to_owned()); + vertices.push(initial_point.to_vec()); // Calculate perturbation values for each component for i in 0..num_dimensions {