Skip to content

Commit

Permalink
add option to discard evaluated states when the walker error exceeds …
Browse files Browse the repository at this point in the history
…a given threshold. This is complementary to the walker reset
  • Loading branch information
svandenhaute committed Jan 24, 2024
1 parent 4b8459a commit 1bf5fbe
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 170 deletions.
4 changes: 2 additions & 2 deletions psiflow/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,15 +447,15 @@ def assign_identifiers(
identifier += 1
for atoms in data: # assign those which were not yet assigned
if ("identifier" not in atoms.info) and atoms.reference_status:
state, identifier = _assign_identifier(atoms, identifier)
state, identifier = _assign_identifier(atoms, False, identifier)
states.append(state)
else:
states.append(atoms)
write_dataset(states, outputs=[outputs[0]])
return identifier
else:
for atoms in data:
state, identifier = _assign_identifier(atoms, identifier)
state, identifier = _assign_identifier(atoms, False, identifier)
states.append(state)
write_dataset(states, outputs=[outputs[0]])
return identifier
Expand Down
5 changes: 5 additions & 0 deletions psiflow/learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ class SequentialLearning(BaseLearning):
temperature_ramp: Optional[tuple[float, float, int]] = None
niterations: int = 10
error_thresholds_for_reset: tuple[float, float] = (10, 200)
error_thresholds_for_discard: tuple[float, float] = (20, 500)

def update_walkers(self, walkers: list[BaseWalker], initialize=False):
if self.temperature_ramp is not None:
Expand Down Expand Up @@ -235,6 +236,7 @@ def run(
walkers,
self.identifier,
self.error_thresholds_for_reset,
self.error_thresholds_for_discard,
self.metrics,
)
assert new_data.length().result() > 0, "no new states were generated!"
Expand Down Expand Up @@ -289,6 +291,7 @@ def run(
self.identifier,
self.nstates_per_iteration,
self.error_thresholds_for_reset,
self.error_thresholds_for_discard,
self.metrics,
)
assert new_data.length().result() > 0, "no new states were generated!"
Expand Down Expand Up @@ -318,6 +321,7 @@ class IncrementalLearning(BaseLearning):
cv_delta: Optional[float] = None
niterations: int = 10
error_thresholds_for_reset: tuple[float, float] = (10, 200)
error_thresholds_for_discard: tuple[float, float] = (10, 200)

def update_walkers(self, walkers: list[BaseWalker], initialize=False):
for walker in walkers: # may not all contain bias
Expand Down Expand Up @@ -378,6 +382,7 @@ def run(
walkers,
self.identifier,
self.error_thresholds_for_reset,
self.error_thresholds_for_discard,
self.metrics,
)
assert new_data.length().result() > 0, "no new states were generated!"
Expand Down
17 changes: 14 additions & 3 deletions psiflow/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
import typeguard
import wandb
from parsl.app.app import python_app
from parsl.app.app import join_app, python_app
from parsl.data_provider.files import File

import psiflow
Expand Down Expand Up @@ -53,6 +53,7 @@ def _save_walker_logs(data: dict[str, list], path: Path) -> str:
field_names = [
"walker_index",
"counter",
"is_discarded",
"is_reset",
        "e_rmse",
"f_rmse",
Expand Down Expand Up @@ -113,6 +114,7 @@ def _log_walker(
walker_index: int,
evaluated_state: FlowAtoms,
error: tuple[Optional[float], Optional[float]],
discard: bool,
condition: bool,
identifier: int,
disagreement: Optional[float] = None,
Expand All @@ -123,6 +125,7 @@ def _log_walker(
data = {}
data["walker_index"] = walker_index
data["counter"] = metadata["counter"]
data["is_discarded"] = discard
data["is_reset"] = condition
data["e_rmse"] = error[0]
data["f_rmse"] = error[1]
Expand Down Expand Up @@ -484,6 +487,7 @@ def log_walker(
metadata,
state,
error,
discard,
condition,
identifier,
disagreement=None,
Expand All @@ -496,6 +500,7 @@ def log_walker(
i,
state,
error,
discard,
condition,
identifier,
disagreement,
Expand All @@ -522,20 +527,26 @@ def save(
model: Optional[BaseModel] = None,
dataset: Optional[Dataset] = None,
):
@join_app
def log_string(s: str) -> None:
logger.info(s)

path = Path(path)
if not path.exists():
path.mkdir()
walker_logs = None
dataset_log = None
if len(self.walker_logs) > 0:
walker_logs = gather_walker_logs(*self.walker_logs)
save_walker_logs(walker_logs, path / "walkers.log")
walker_logs_str = save_walker_logs(walker_logs, path / "walkers.log")
log_string(walker_logs_str)
self.walker_logs = []
if model is not None:
assert dataset is not None
inputs = [dataset.data_future, model.evaluate(dataset).data_future]
dataset_log = log_dataset(inputs=inputs)
save_dataset_log(dataset_log, path / "dataset.log")
dataset_log_str = save_dataset_log(dataset_log, path / "dataset.log")
log_string(dataset_log_str)
if self.wandb_group is not None:
# typically needs a result() from caller
return to_wandb( # noqa: F841
Expand Down
2 changes: 1 addition & 1 deletion psiflow/models/_mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class MACEConfig:
save_cpu: bool = True
clip_grad: Optional[float] = 10
wandb: bool = True
wandb_project: str = "psiflow"
wandb_project: Optional[str] = "psiflow"
wandb_group: Optional[str] = None
wandb_name: str = "mace_training"
wandb_log_hypers: list = field(
Expand Down
Loading

0 comments on commit 1bf5fbe

Please sign in to comment.