Skip to content

Commit

Permalink
Merge branch 'main' of github.com:LAPKB/PMcore
Browse files Browse the repository at this point in the history
  • Loading branch information
Siel committed Mar 16, 2024
2 parents bcf6004 + ba034a7 commit 3ba41a3
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 68 deletions.
5 changes: 4 additions & 1 deletion examples/theophylline/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ struct Model {
// This is a helper function to get the parameter value by name
impl Model {
pub fn get_param(&self, str: &str) -> f64 {
*self.params.get(str).expect(format!("Parameter {} not found", str).as_str())
*self
.params
.get(str)
.expect(format!("Parameter {} not found", str).as_str())
}
}

Expand Down
10 changes: 5 additions & 5 deletions src/logger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@ pub fn setup_log(settings: &Settings, ui_tx: Option<UnboundedSender<Comm>>) {
let stdout_layer = match settings.config.tui {
false => {
let layer = fmt::layer()
.with_writer(std::io::stdout)
.with_ansi(true)
.with_target(false)
.with_timer(CompactTimestamp);
.with_writer(std::io::stdout)
.with_ansi(true)
.with_target(false)
.with_timer(CompactTimestamp);
Some(layer)
},
}
true => None,
};

Expand Down
80 changes: 18 additions & 62 deletions src/routines/evaluation/ipm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,113 +38,70 @@ type OneDimArray = ArrayBase<OwnedRepr<f64>, ndarray::Dim<[usize; 1]>>;
pub fn burke(
psi: &ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>>,
) -> Result<(OneDimArray, f64), Box<dyn error::Error>> {
tracing::info!("Profiling {} subjects and {} spp", psi.nrows(), psi.ncols());
// trace_memory("start of burke");
// dbg!(psi.dim());
// let psi_clone = psi.clone();
// trace_memory("after cloning psi");
let psi = psi.mapv(|x| x.abs());
// trace_memory("new psi");
let (row, col) = psi.dim();
// if row>col {
// return Err("The matrix PSI has row>col".into());
// }
if psi.min()? < &0.0 {
return Err("PSI contains negative elements".into());
}
let ecol: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(col);
// trace_memory("after creating ecol");
// dbg!(ecol.dim());
let mut plam = psi.dot(&ecol);
// trace_memory("after creating plam");
// dbg!(plam.dim());

// if plam.min().unwrap() <= &1e-15 {
// return Err("The vector psi*e has a non-positive entry".into());
// }
let eps = 1e-8;
let mut sig = 0.;
let erow: ArrayBase<OwnedRepr<f64>, Dim<[usize; 1]>> = Array::ones(row);
// trace_memory("after creating erow");
// dbg!(erow.dim());

let mut lam = ecol.clone();
// trace_memory("after creating lam");
// dbg!(lam.dim());

let mut w = 1. / &plam;
// trace_memory("after creating w");
// dbg!(w.dim());

let mut ptw = psi.t().dot(&w);
// dbg!(&ptw);
// trace_memory("after creating ptw");
// dbg!(ptw.dim());

let shrink = 2. * *ptw.max().unwrap();
lam *= shrink;
plam *= shrink;
w /= shrink;
ptw /= shrink;
// dbg!(&w);
// dbg!(&plam);
// dbg!(&erow);

let mut y = &ecol - &ptw;
let mut r = &erow - &w * &plam;
// dbg!(&r);

let mut norm_r = norm_inf(r);
// dbg!(&y);
// dbg!(&r);
// dbg!(&norm_r);
// dbg!(&plam);

let sum_log_plam = plam.mapv(|x: f64| x.ln()).sum();
// dbg!(sum_log_plam);
let mut gap = (w.mapv(|x: f64| x.ln()).sum() + sum_log_plam).abs() / (1. + sum_log_plam);
// dbg!(gap);
let mut mu = lam.t().dot(&y) / col as f64;
// dbg!(mu);
// trace_memory("before the loop");
while mu > eps || norm_r > eps || gap > eps {
// log::info!("IPM cycle");
let smu = sig * mu;
let inner = &lam / &y; //divide(&lam, &y);
// dbg!(&inner);
// trace_memory("after creating inner");
// dbg!(inner.dim());
let w_plam = &plam / &w; //divide(&plam, &w);
// trace_memory("after creating w_plam");
// dbg!(&w_plam);
// dbg!(w_plam.dim());
// dbg!(psi.dim());
// dbg!(&Array2::from_diag(&inner).dim());
let inner = &lam / &y;

let w_plam = &plam / &w;

let mut psi_inner: Array2<f64> = psi.clone();
for (mut col, inner_val) in psi_inner.axis_iter_mut(Axis(1)).zip(&inner) {
col *= *inner_val;
}
let h = psi_inner.dot(&psi.t()) + Array2::from_diag(&w_plam);
// dbg!(&h);
// trace_memory("after creating h");
// dbg!(h.dim());

let uph = h.cholesky()?;
// trace_memory("after creating uph");
// dbg!(uph.dim());

let uph = uph.t();
let smuyinv = smu * (&ecol / &y);
let rhsdw = &erow / &w - (psi.dot(&smuyinv));
let a = rhsdw.clone().into_shape((rhsdw.len(), 1))?;
// dbg!(&rhsdw);
//todo: cleanup this aux variable
//dbg!(uph.t().is_triangular(linfa_linalg::triangular::UPLO::Upper));
// uph.solve_into(rhsdw);

let x = uph
.t()
.solve_triangular(&a, linfa_linalg::triangular::UPLO::Lower)?;
// dbg!(&x);

let dw_aux = uph.solve_triangular(&x, linfa_linalg::triangular::UPLO::Upper)?;

let dw = dw_aux.column(0);
let dy = -psi.t().dot(&dw);

let dlam = smuyinv - &lam - inner * &dy;
// dbg!(&dlam);

let mut alfpri = -1. / ((&dlam / &lam).min().unwrap().min(-0.5));
// dbg!(alfpri);

alfpri = (0.99995 * alfpri).min(1.0);

let mut alfdual = -1. / ((&dy / &y).min().unwrap().min(-0.5));
Expand Down Expand Up @@ -177,8 +134,7 @@ pub fn burke(
lam /= row as f64;
let obj = psi.dot(&lam).mapv(|x| x.ln()).sum();
lam = &lam / lam.sum();
// trace_memory("end of ipm");
// abort();

Ok((lam, obj))
}

Expand Down

0 comments on commit 3ba41a3

Please sign in to comment.