From ba034a76722d5f8661c2443cc932e8eabd215304 Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Sat, 16 Mar 2024 08:19:08 +0100 Subject: [PATCH] Clean comments and format --- examples/theophylline/main.rs | 5 ++- src/logger.rs | 10 ++--- src/routines/evaluation/ipm.rs | 80 ++++++++-------------------------- src/routines/settings.rs | 2 +- 4 files changed, 28 insertions(+), 69 deletions(-) diff --git a/examples/theophylline/main.rs b/examples/theophylline/main.rs index c99d02feb..65ed022ee 100644 --- a/examples/theophylline/main.rs +++ b/examples/theophylline/main.rs @@ -26,7 +26,10 @@ struct Model { // This is a helper function to get the parameter value by name impl Model { pub fn get_param(&self, str: &str) -> f64 { - *self.params.get(str).expect(format!("Parameter {} not found", str).as_str()) + *self + .params + .get(str) + .expect(format!("Parameter {} not found", str).as_str()) } } diff --git a/src/logger.rs b/src/logger.rs index e2482fc07..67aaf939a 100644 --- a/src/logger.rs +++ b/src/logger.rs @@ -48,12 +48,12 @@ pub fn setup_log(settings: &Settings, ui_tx: Option>) { let stdout_layer = match settings.config.tui { false => { let layer = fmt::layer() - .with_writer(std::io::stdout) - .with_ansi(true) - .with_target(false) - .with_timer(CompactTimestamp); + .with_writer(std::io::stdout) + .with_ansi(true) + .with_target(false) + .with_timer(CompactTimestamp); Some(layer) - }, + } true => None, }; diff --git a/src/routines/evaluation/ipm.rs b/src/routines/evaluation/ipm.rs index 69fab9224..c6225c283 100644 --- a/src/routines/evaluation/ipm.rs +++ b/src/routines/evaluation/ipm.rs @@ -38,113 +38,70 @@ type OneDimArray = ArrayBase, ndarray::Dim<[usize; 1]>>; pub fn burke( psi: &ArrayBase, Dim<[usize; 2]>>, ) -> Result<(OneDimArray, f64), Box> { - tracing::info!("Profiling {} subjects and {} spp", psi.nrows(), psi.ncols()); - // trace_memory("start of burke"); - // dbg!(psi.dim()); - // let psi_clone = psi.clone(); - // trace_memory("after cloning psi"); let psi = psi.mapv(|x| x.abs()); - // trace_memory("new psi"); let (row, col) = psi.dim(); - // if row>col { - // return Err("The matrix PSI has row>col".into()); - // } if psi.min()? < &0.0 { return Err("PSI contains negative elements".into()); } let ecol: ArrayBase, Dim<[usize; 1]>> = Array::ones(col); - // trace_memory("after creating ecol"); - // dbg!(ecol.dim()); let mut plam = psi.dot(&ecol); - // trace_memory("after creating plam"); - // dbg!(plam.dim()); - - // if plam.min().unwrap() <= &1e-15 { - // return Err("The vector psi*e has a non-positive entry".into()); - // } let eps = 1e-8; let mut sig = 0.; let erow: ArrayBase, Dim<[usize; 1]>> = Array::ones(row); - // trace_memory("after creating erow"); - // dbg!(erow.dim()); + let mut lam = ecol.clone(); - // trace_memory("after creating lam"); - // dbg!(lam.dim()); + let mut w = 1. / &plam; - // trace_memory("after creating w"); - // dbg!(w.dim()); + let mut ptw = psi.t().dot(&w); - // dbg!(&ptw); - // trace_memory("after creating ptw"); - // dbg!(ptw.dim()); + let shrink = 2. * *ptw.max().unwrap(); lam *= shrink; plam *= shrink; w /= shrink; ptw /= shrink; - // dbg!(&w); - // dbg!(&plam); - // dbg!(&erow); + let mut y = &ecol - &ptw; let mut r = &erow - &w * &plam; - // dbg!(&r); + let mut norm_r = norm_inf(r); - // dbg!(&y); - // dbg!(&r); - // dbg!(&norm_r); - // dbg!(&plam); + let sum_log_plam = plam.mapv(|x: f64| x.ln()).sum(); - // dbg!(sum_log_plam); let mut gap = (w.mapv(|x: f64| x.ln()).sum() + sum_log_plam).abs() / (1. + sum_log_plam); - // dbg!(gap); let mut mu = lam.t().dot(&y) / col as f64; - // dbg!(mu); - // trace_memory("before the loop"); while mu > eps || norm_r > eps || gap > eps { // log::info!("IPM cycle"); let smu = sig * mu; - let inner = &lam / &y; //divide(&lam, &y); - // dbg!(&inner); - // trace_memory("after creating inner"); - // dbg!(inner.dim()); - let w_plam = &plam / &w; //divide(&plam, &w); - // trace_memory("after creating w_plam"); - // dbg!(&w_plam); - // dbg!(w_plam.dim()); - // dbg!(psi.dim()); - // dbg!(&Array2::from_diag(&inner).dim()); + let inner = &lam / &y; + + let w_plam = &plam / &w; + let mut psi_inner: Array2 = psi.clone(); for (mut col, inner_val) in psi_inner.axis_iter_mut(Axis(1)).zip(&inner) { col *= *inner_val; } let h = psi_inner.dot(&psi.t()) + Array2::from_diag(&w_plam); - // dbg!(&h); - // trace_memory("after creating h"); - // dbg!(h.dim()); + let uph = h.cholesky()?; - // trace_memory("after creating uph"); - // dbg!(uph.dim()); + let uph = uph.t(); let smuyinv = smu * (&ecol / &y); let rhsdw = &erow / &w - (psi.dot(&smuyinv)); let a = rhsdw.clone().into_shape((rhsdw.len(), 1))?; - // dbg!(&rhsdw); - //todo: cleanup this aux variable - //dbg!(uph.t().is_triangular(linfa_linalg::triangular::UPLO::Upper)); - // uph.solve_into(rhsdw); + let x = uph .t() .solve_triangular(&a, linfa_linalg::triangular::UPLO::Lower)?; - // dbg!(&x); + let dw_aux = uph.solve_triangular(&x, linfa_linalg::triangular::UPLO::Upper)?; let dw = dw_aux.column(0); let dy = -psi.t().dot(&dw); let dlam = smuyinv - &lam - inner * &dy; - // dbg!(&dlam); + let mut alfpri = -1. / ((&dlam / &lam).min().unwrap().min(-0.5)); - // dbg!(alfpri); + alfpri = (0.99995 * alfpri).min(1.0); let mut alfdual = -1. / ((&dy / &y).min().unwrap().min(-0.5)); @@ -177,8 +134,7 @@ pub fn burke( lam /= row as f64; let obj = psi.dot(&lam).mapv(|x| x.ln()).sum(); lam = &lam / lam.sum(); - // trace_memory("end of ipm"); - // abort(); + Ok((lam, obj)) } diff --git a/src/routines/settings.rs b/src/routines/settings.rs index 7506b31b8..5cdc0542d 100644 --- a/src/routines/settings.rs +++ b/src/routines/settings.rs @@ -252,4 +252,4 @@ fn default_10k() -> usize { fn default_cycles() -> usize { 100 -} \ No newline at end of file +}