Skip to content

Commit

Permalink
Finished implementation of outputs_folder
Browse files Browse the repository at this point in the history
Improtant: MetaWriter and meta_rust.csv are no longer used. Instead, cycles.csv now includes a column "converged" with the boolean flag for convergence. To faciliate this, the ordering of logging has been changed, and break has been replaced with a stop-flag.
  • Loading branch information
mhovd committed Mar 30, 2024
1 parent 8973ffb commit 90e45af
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 110 deletions.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.rs text eol=lf
48 changes: 27 additions & 21 deletions src/algorithms/npag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ where
_ => panic!("Error type not supported"),
},
converged: false,
cycle_log: CycleLog::new(&settings.random.names()),
cycle_log: CycleLog::new(&settings),
cache: settings.config.cache,
tx,
settings,
Expand Down Expand Up @@ -273,26 +273,12 @@ where

self.optim_gamma();

let state = NPCycle {
cycle: self.cycle,
objf: -2. * self.objf,
delta_objf: (self.last_objf - self.objf).abs(),
nspp: self.theta.shape()[0],
theta: self.theta.clone(),
gamlam: self.gamma,
};

// Log relevant cycle information
tracing::info!("Objective function = {:.4}", -2.0 * self.objf);
tracing::debug!("Support points: {}", self.theta.shape()[0]);
tracing::debug!("Gamma = {:.4}", self.gamma);
tracing::debug!("EPS = {:.4}", self.eps);

match &self.tx {
Some(tx) => tx.send(Comm::NPCycle(state.clone())).unwrap(),
None => (),
}

// Increasing objf signals instability or model misspecification.
if self.last_objf > self.objf {
tracing::warn!(
Expand All @@ -305,10 +291,8 @@ where
self.w = self.lambda.clone();
let pyl = self.psi.dot(&self.w);

self.cycle_log
.push_and_write(state, self.settings.config.output);

// Stop if we have reached convergence criteria
let mut stop = false;
if (self.last_objf - self.objf).abs() <= THETA_G && self.eps > THETA_E {
self.eps /= 2.;
if self.eps <= THETA_E {
Expand All @@ -318,7 +302,7 @@ where
"The run converged with the following criteria: Log-Likelihood"
);
self.converged = true;
break;
stop = true;
} else {
self.f0 = self.f1;
self.eps = 0.2;
Expand All @@ -328,18 +312,40 @@ where
if self.eps <= THETA_E {
tracing::info!("The run converged with the following criteria: Eps");
self.converged = true;
break;
stop = true;
}

// Stop if we have reached maximum number of cycles
if self.cycle >= self.settings.config.cycles {
tracing::warn!("Maximum number of cycles reached");
break;
stop = true;
}

// Stop if stopfile exists
if std::path::Path::new("stop").exists() {
tracing::warn!("Stopfile detected - breaking");
stop = true;
}

let state = NPCycle {
cycle: self.cycle,
objf: -2. * self.objf,
delta_objf: (self.last_objf - self.objf).abs(),
nspp: self.theta.shape()[0],
theta: self.theta.clone(),
gamlam: self.gamma,
converged: self.converged,
};

match &self.tx {
Some(tx) => tx.send(Comm::NPCycle(state.clone())).unwrap(),
None => (),
}

self.cycle_log
.push_and_write(state, self.settings.config.output);

if stop {
break;
}

Expand Down
74 changes: 38 additions & 36 deletions src/algorithms/npod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ where
_ => panic!("Error type not supported"),
},
converged: false,
cycle_log: CycleLog::new(&settings.random.names()),
cycle_log: CycleLog::new(&settings),
cache: settings.config.cache,
tx,
settings,
Expand Down Expand Up @@ -263,25 +263,6 @@ where

self.optim_gamma();

let state = NPCycle {
cycle: self.cycle,
objf: -2. * self.objf,
delta_objf: (self.last_objf - self.objf).abs(),
nspp: self.theta.shape()[0],
theta: self.theta.clone(),
gamlam: self.gamma,
};

// Log relevant cycle information
tracing::info!("Objective function = {:.4}", -2.0 * self.objf);
tracing::debug!("Support points: {}", self.theta.shape()[0]);
tracing::debug!("Gamma = {:.4}", self.gamma);

match &self.tx {
Some(tx) => tx.send(Comm::NPCycle(state.clone())).unwrap(),
None => (),
}

// Increasing objf signals instability or model misspecification.
if self.last_objf > self.objf {
tracing::warn!(
Expand All @@ -293,35 +274,62 @@ where

self.w = self.lambda.clone();

// Perform checks for convergence or termination
let mut stop = false;
// Stop if objective function convergence is reached
if (self.last_objf - self.objf).abs() <= THETA_F {
tracing::info!("Objective function convergence reached");
self.converged = true;
break;
stop = true;
}
// Stop if we have reached maximum number of cycles
if self.cycle >= self.settings.config.cycles {
tracing::warn!("Maximum number of cycles reached");
break;
stop = true;
}

// Stop if stopfile exists
if std::path::Path::new("stop").exists() {
tracing::warn!("Stopfile detected - breaking");
stop = true;
}

// Create a new NPCycle state and log it
let state = NPCycle {
cycle: self.cycle,
objf: -2. * self.objf,
delta_objf: (self.last_objf - self.objf).abs(),
nspp: self.theta.shape()[0],
theta: self.theta.clone(),
gamlam: self.gamma,
converged: self.converged,
};

// Log relevant cycle information
tracing::info!("Objective function = {:.4}", -2.0 * self.objf);
tracing::debug!("Support points: {}", self.theta.shape()[0]);
tracing::debug!("Gamma = {:.4}", self.gamma);

match &self.tx {
Some(tx) => tx.send(Comm::NPCycle(state.clone())).unwrap(),
None => (),
}

self.cycle_log
.push_and_write(state, self.settings.config.output);

if stop {
break;
}
let pyl = self.psi.dot(&self.w);

// Add new point to theta based on the optimization of the D function
// If no stop signal, add new point to theta based on the optimization of the D function
let pyl = self.psi.dot(&self.w);
let sigma = ErrorPoly {
c: self.c,
gl: self.gamma,
e_type: &self.error_type,
};
// for spp in self.theta.clone().rows() {
// let optimizer = SppOptimizer::new(&self.engine, &self.scenarios, &sigma, &pyl);
// let candidate_point = optimizer.optimize_point(spp.to_owned()).unwrap();
// prune(&mut self.theta, candidate_point, &self.ranges, THETA_D);
// }

let mut candididate_points: Vec<Array1<f64>> = Vec::default();
for spp in self.theta.clone().rows() {
candididate_points.push(spp.to_owned());
Expand All @@ -335,14 +343,8 @@ where
prune(&mut self.theta, cp, &self.ranges, THETA_D);
}

//TODO: the cycle might break before reaching this point
self.cycle_log
.push_and_write(state, self.settings.config.output);

// Increment the cycle count and prepare for the next cycle
self.cycle += 1;

// log::info!("cycle: {}, objf: {}", self.cycle, self.objf);
// dbg!((self.last_objf - self.objf).abs());
}

self.to_npresult()
Expand Down
79 changes: 26 additions & 53 deletions src/routines/output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,9 @@ impl NPResult {
self.write_posterior();
self.write_obs();
self.write_pred(engine, idelta, tad);
self.write_meta();
}
}

// Writes meta_rust.csv
pub fn write_meta(&self) {
let mut meta_writer = MetaWriter::new();
meta_writer.write(self.converged, self.cycles);
}

/// Writes theta, which containts the population support points and their associated probabilities
/// Each row is one support point, the last column being probability
pub fn write_theta(&self) {
Expand Down Expand Up @@ -279,32 +272,38 @@ pub struct CycleLog {
cycle_writer: CycleWriter,
}
impl CycleLog {
pub fn new(par_names: &[String]) -> Self {
let cycle_writer = CycleWriter::new("cycles.csv", par_names.to_vec());
pub fn new(settings: &Settings) -> Self {
let cycle_writer = CycleWriter::new(settings);
Self {
cycles: Vec::new(),
cycle_writer,
}
}
pub fn push_and_write(&mut self, npcycle: NPCycle, write_ouput: bool) {
if write_ouput {
self.cycle_writer
.write(npcycle.cycle, npcycle.objf, npcycle.gamlam, &npcycle.theta);
self.cycle_writer.write(
npcycle.cycle,
npcycle.converged,
npcycle.objf,
npcycle.gamlam,
&npcycle.theta,
);
self.cycle_writer.flush();
}
self.cycles.push(npcycle);
}
}

/// Defines the result objects from a run
/// An [NPResult] contains the necessary information to generate predictions and summary statistics
/// An [NPCycle] contains summary of a cycle
/// It holds the following information:
/// - `cycle`: The cycle number
/// - `objf`: The objective function value
/// - `gamlam`: The assay noise parameter, either gamma or lambda
/// - `theta`: The support points and their associated probabilities
/// - `nspp`: The number of support points
/// - `delta_objf`: The change in objective function value from last cycle
/// - `converged`: Whether the algorithm has reached convergence
#[derive(Debug, Clone)]
pub struct NPCycle {
pub cycle: usize,
Expand All @@ -313,6 +312,7 @@ pub struct NPCycle {
pub theta: Array2<f64>,
pub nspp: usize,
pub delta_objf: f64,
pub converged: bool,
}
impl NPCycle {
pub fn new() -> Self {
Expand All @@ -323,6 +323,7 @@ impl NPCycle {
theta: Array2::default((0, 0)),
nspp: 0,
delta_objf: 0.0,
converged: false,
}
}
}
Expand All @@ -339,16 +340,18 @@ pub struct CycleWriter {
}

impl CycleWriter {
pub fn new(file_path: &str, parameter_names: Vec<String>) -> CycleWriter {
let file = File::create(file_path).unwrap();
pub fn new(settings: &Settings) -> CycleWriter {
let file = create_output_file(settings, "cycles.csv").unwrap();
let mut writer = WriterBuilder::new().has_headers(false).from_writer(file);

// Write headers
writer.write_field("cycle").unwrap();
writer.write_field("converged").unwrap();
writer.write_field("neg2ll").unwrap();
writer.write_field("gamlam").unwrap();
writer.write_field("nspp").unwrap();

let parameter_names = settings.random.names();
for param_name in &parameter_names {
writer.write_field(format!("{}.mean", param_name)).unwrap();
writer
Expand All @@ -362,8 +365,16 @@ impl CycleWriter {
CycleWriter { writer }
}

pub fn write(&mut self, cycle: usize, objf: f64, gamma: f64, theta: &Array2<f64>) {
pub fn write(
&mut self,
cycle: usize,
converged: bool,
objf: f64,
gamma: f64,
theta: &Array2<f64>,
) {
self.writer.write_field(format!("{}", cycle)).unwrap();
self.writer.write_field(format!("{}", converged)).unwrap();
self.writer.write_field(format!("{}", objf)).unwrap();
self.writer.write_field(format!("{}", gamma)).unwrap();
self.writer
Expand Down Expand Up @@ -396,44 +407,6 @@ impl CycleWriter {
}
}

// Meta
#[derive(Debug)]
pub struct MetaWriter {
writer: csv::Writer<File>,
}

impl Default for MetaWriter {
fn default() -> Self {
Self::new()
}
}

impl MetaWriter {
pub fn new() -> MetaWriter {
let meta_file = File::create("meta_rust.csv").unwrap();
let mut meta_writer = WriterBuilder::new()
.has_headers(false)
.from_writer(meta_file);
meta_writer.write_field("converged").unwrap();
meta_writer.write_field("ncycles").unwrap();
meta_writer.write_record(None::<&[u8]>).unwrap();
MetaWriter {
writer: meta_writer,
}
}

pub fn write(&mut self, converged: bool, cycle: usize) {
self.writer.write_field(converged.to_string()).unwrap();
self.writer.write_field(format!("{}", cycle)).unwrap();
self.writer.write_record(None::<&[u8]>).unwrap();
self.flush();
}

fn flush(&mut self) {
self.writer.flush().unwrap();
}
}

pub fn posterior(psi: &Array2<f64>, w: &Array1<f64>) -> Array2<f64> {
let py = psi.dot(w);
let mut post: Array2<f64> = Array2::zeros((psi.nrows(), psi.ncols()));
Expand Down

0 comments on commit 90e45af

Please sign in to comment.