313 save final model #340

Merged 11 commits on Jun 28, 2024
15 changes: 10 additions & 5 deletions casanovo/denovo/model_runner.py
@@ -55,17 +55,22 @@ def __init__(
         self.writer = None
 
         # Configure checkpoints.
+        self.callbacks = [
+            ModelCheckpoint(
+                dirpath=config.model_save_folder_path,
+                save_on_train_epoch_end=True,
+            )
+        ]
+
         if config.save_top_k is not None:
-            self.callbacks = [
+            self.callbacks.append(
                 ModelCheckpoint(
                     dirpath=config.model_save_folder_path,
                     monitor="valid_CELoss",
                     mode="min",
                     save_top_k=config.save_top_k,
                 )
-            ]
-        else:
-            self.callbacks = None
+            )
 
     def __enter__(self):
         """Enter the context manager"""
@@ -187,7 +192,7 @@ def initialize_trainer(self, train: bool) -> None:
             additional_cfg = dict(
                 devices=devices,
                 callbacks=self.callbacks,
-                enable_checkpointing=self.config.save_top_k is not None,
+                enable_checkpointing=True,
                 max_epochs=self.config.max_epochs,
                 num_sanity_val_steps=self.config.num_sanity_val_steps,
                 strategy=self._get_strategy(),
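Taken together, the two hunks make checkpointing unconditional: a ModelCheckpoint that saves at the end of every training epoch is always registered, so the final model reaches disk even if validation never runs, while the top-k checkpoint on valid_CELoss stays optional. A minimal standalone sketch of the resulting wiring, assuming PyTorch Lightning's ModelCheckpoint and Trainer APIs (build_callbacks, the folder path, and max_epochs are illustrative, not Casanovo code):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    def build_callbacks(save_folder, save_top_k=None):
        # Always save a checkpoint at the end of each training epoch, so a
        # final model exists regardless of the validation schedule.
        callbacks = [
            ModelCheckpoint(dirpath=save_folder, save_on_train_epoch_end=True)
        ]
        if save_top_k is not None:
            # Optionally also keep the k best checkpoints by validation loss.
            callbacks.append(
                ModelCheckpoint(
                    dirpath=save_folder,
                    monitor="valid_CELoss",
                    mode="min",
                    save_top_k=save_top_k,
                )
            )
        return callbacks

    # Checkpointing is now always enabled, matching the second hunk above.
    trainer = Trainer(
        callbacks=build_callbacks("checkpoints/", save_top_k=2),
        enable_checkpointing=True,
        max_epochs=20,
    )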
36 changes: 32 additions & 4 deletions tests/conftest.py
@@ -4,6 +4,7 @@
 import psims
 import pytest
 import yaml
+import math
 from pyteomics.mass import calculate_mass
 
 
@@ -184,10 +185,9 @@ def _create_mzml(peptides, mzml_file, random_state=42):
     return mzml_file
 
 
-@pytest.fixture
-def tiny_config(tmp_path):
-    """A config file for a tiny model."""
-    cfg = {
+def _get_default_config(tmp_path):
+    """Get default test config (dictionary)"""
+    return {
         "n_head": 2,
         "dim_feedforward": 10,
         "n_layers": 1,
@@ -255,8 +255,36 @@ def tiny_config(tmp_path):
         },
     }
 
+
+def _write_config_file(cfg, tmp_path):
+    """Write config file to temp directory"""
     cfg_file = tmp_path / "config.yml"
     with cfg_file.open("w+") as out_file:
         yaml.dump(cfg, out_file)
 
     return cfg_file
+
+
+@pytest.fixture
+def tiny_config(tmp_path):
+    """A config file for a tiny model."""
+    cfg = _get_default_config(tmp_path)
+    return _write_config_file(cfg, tmp_path)
+
+
+@pytest.fixture
+def tiny_config_interval_greater(tmp_path):
+    """Config file where val_check_interval is greater than the number of training steps"""
+    cfg = _get_default_config(tmp_path)
+    val_check_interval = 50
+    cfg["val_check_interval"] = val_check_interval
+    return _write_config_file(cfg, tmp_path)
+
+
+@pytest.fixture
+def tiny_config_not_factor(tmp_path):
+    """Config file where val_check_interval isn't a factor of the number of training steps"""
+    cfg = _get_default_config(tmp_path)
+    val_check_interval = 15
+    cfg["val_check_interval"] = val_check_interval
+    return _write_config_file(cfg, tmp_path)
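Splitting the old fixture into _get_default_config and _write_config_file makes each config variant a small tweak on shared defaults. As a usage illustration, a hypothetical extra fixture (not part of this PR) for an interval that exactly matches the run length would follow the same pattern:

    @pytest.fixture
    def tiny_config_interval_equal(tmp_path):
        """Hypothetical config where val_check_interval equals the training steps."""
        cfg = _get_default_config(tmp_path)
        cfg["val_check_interval"] = 20  # assumed total step count, for illustration
        return _write_config_file(cfg, tmp_path)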
46 changes: 45 additions & 1 deletion tests/test_integration.py
@@ -8,7 +8,13 @@
 
 
 def test_train_and_run(
-    mgf_small, mzml_small, tiny_config, tmp_path, monkeypatch
+    mgf_small,
+    mzml_small,
+    tiny_config,
+    tmp_path,
+    tiny_config_interval_greater,
+    tiny_config_not_factor,
+    monkeypatch,
 ):
     # We can use this to explicitly test different versions.
     monkeypatch.setattr(casanovo, "__version__", "3.0.1")
@@ -86,6 +92,44 @@ def test_train_and_run(
     assert psms.loc[4, "sequence"] == "PEPTLDEK"
     assert psms.loc[4, "spectra_ref"] == "ms_run[2]:scan=111"
 
+    # Test checkpoint saving when val_check_interval is greater than training steps.
+    Path.unlink(model_file)
+    result = run(
+        [
+            "train",
+            "--validation_peak_path",
+            str(mgf_small),
+            "--config",
+            tiny_config_interval_greater,
+            "--output",
+            str(tmp_path / "train"),
+            str(mgf_small),
+        ]
+    )
+
+    assert result.exit_code == 0
+    assert model_file.exists()
+
+    # Test checkpoint saving when val_check_interval is not a factor of training steps.
+    Path.unlink(model_file)
+    validation_file = tmp_path / "epoch=14-step=15.ckpt"
+    result = run(
+        [
+            "train",
+            "--validation_peak_path",
+            str(mgf_small),
+            "--config",
+            tiny_config_not_factor,
+            "--output",
+            str(tmp_path / "train"),
+            str(mgf_small),
+        ]
+    )
+
+    assert result.exit_code == 0
+    assert model_file.exists()
+    assert validation_file.exists()
+
+
 def test_auxilliary_cli(tmp_path, monkeypatch):
     """Test the secondary CLI commands"""
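The expected filename epoch=14-step=15.ckpt follows Lightning's default checkpoint naming. Assuming this tiny setup runs one training batch per epoch, global step 15 is reached at the end of zero-indexed epoch 14, which is when the val_check_interval=15 validation fires. A small sanity-check sketch of that arithmetic (the helper is illustrative, not part of the test suite):

    def expected_validation_ckpt(val_check_interval: int, steps_per_epoch: int = 1) -> str:
        # Global step N falls in zero-indexed epoch (N - 1) // steps_per_epoch.
        epoch = (val_check_interval - 1) // steps_per_epoch
        return f"epoch={epoch}-step={val_check_interval}.ckpt"

    assert expected_validation_ckpt(15) == "epoch=14-step=15.ckpt"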
1 change: 1 addition & 0 deletions tests/unit_tests/test_runner.py
@@ -10,6 +10,7 @@
 def test_initialize_model(tmp_path, mgf_small):
     """Test initializing a new or existing model."""
     config = Config()
+    config.model_save_folder_path = tmp_path
     # No model filename given, so train from scratch.
     ModelRunner(config=config).initialize_model(train=True)