Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix documentation of --output and updated output settings. #348

Closed
wants to merge 10 commits into from
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased]

### Changed

- Update the description of the `output` command-line argument to reflect that it specifies the root file name of the output (mzTab, log) files.

### Fixed

- Precursor charges are now exported as integers instead of floats in the mzTab output file, in compliance with the mzTab specification.
Expand Down
70 changes: 59 additions & 11 deletions casanovo/casanovo.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,17 @@ def __init__(self, *args, **kwargs) -> None:
click.Option(
("-m", "--model"),
help="""
The model weights (.ckpt file). If not provided, Casanovo
will try to download the latest release.
The model weights (.ckpt file). If not provided, Casanovo will
try to download the latest release (during sequencing).
""",
type=click.Path(exists=True, dir_okay=False),
),
click.Option(
("-o", "--output"),
help="The mzTab file to which results will be written.",
help="The root file name to which results (i.e. the mzTab file "
"during sequencing, as well as the log file during all modes) "
"will be written. If not specified, a default timestamped file "
"name will be used.",
type=click.Path(dir_okay=False),
),
click.Option(
Expand All @@ -88,6 +91,15 @@ def __init__(self, *args, **kwargs) -> None:
),
default="info",
),
click.Option(
("-d", "--overwrite_output"),
help="""
Whether to overwrite sequencing output files (i.e. output .log and .mzTab files)
""",
is_flag=True,
show_default=True,
default=False,
),
]


Expand Down Expand Up @@ -127,13 +139,14 @@ def sequence(
config: Optional[str],
output: Optional[str],
verbosity: str,
overwrite_output: bool,
) -> None:
"""De novo sequence peptides from tandem mass spectra.

PEAK_PATH must be one or more mzML, mzXML, or MGF files from which
to sequence peptides.
"""
output = setup_logging(output, verbosity)
output = setup_logging(output, verbosity, not overwrite_output)
config, model = setup_model(model, config, output, False)
with ModelRunner(config, model) as runner:
logger.info("Sequencing peptides from:")
Expand All @@ -158,13 +171,14 @@ def evaluate(
config: Optional[str],
output: Optional[str],
verbosity: str,
overwrite_output: bool,
) -> None:
"""Evaluate de novo peptide sequencing performance.

ANNOTATED_PEAK_PATH must be one or more annotated MGF files,
such as those provided by MassIVE-KB.
"""
output = setup_logging(output, verbosity)
output = setup_logging(output, verbosity, not overwrite_output)
config, model = setup_model(model, config, output, False)
with ModelRunner(config, model) as runner:
logger.info("Sequencing and evaluating peptides from:")
Expand Down Expand Up @@ -194,22 +208,39 @@ def evaluate(
multiple=True,
type=click.Path(exists=True, dir_okay=False),
)
@click.option(
"-r",
"--root_ckpt_name",
help="""
Root name for all model checkpoints saved during training,
i.e. if root is specified as `--root_ckpt_name foo` than all saved
checkpoint filenames will be formatted as `foo.epoch=2-step=150000.ckpt`.
If root is not specified the checkpoint filenames will instead be formatted
as `epoch=2-step=150000.ckpt`.
""",
required=False,
type=str,
)
def train(
train_peak_path: Tuple[str],
validation_peak_path: Tuple[str],
root_ckpt_name: Optional[str],
model: Optional[str],
config: Optional[str],
output: Optional[str],
verbosity: str,
overwrite_output: bool,
) -> None:
"""Train a Casanovo model on your own data.

TRAIN_PEAK_PATH must be one or more annotated MGF files, such as those
provided by MassIVE-KB, from which to train a new Casanovo model.
"""
output = setup_logging(output, verbosity)
output = setup_logging(output, verbosity, not overwrite_output)
config, model = setup_model(model, config, output, True)
with ModelRunner(config, model) as runner:
with ModelRunner(
config, model, root_checkpoint_name=root_ckpt_name
) as runner:
logger.info("Training a model from:")
for peak_file in train_peak_path:
logger.info(" %s", peak_file)
Expand Down Expand Up @@ -254,8 +285,7 @@ def configure(output: str) -> None:


def setup_logging(
output: Optional[str],
verbosity: str,
output: Optional[str], verbosity: str, check_overwrite: bool = False
) -> Path:
"""Set up the logger.

Expand All @@ -273,11 +303,25 @@ def setup_logging(
output : Path
The output file path.
"""
OUTPUT_SUFFIXES = [".log", ".mztab"]

if output is None:
output = f"casanovo_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"

output = Path(output).expanduser().resolve()

if check_overwrite:
for output_suffix in OUTPUT_SUFFIXES:
next_path = output.with_suffix(output.suffix + output_suffix)
if not next_path.is_file():
continue

raise FileExistsError(
f"Output file {next_path} already exists, existing output files "
f"can't be overwritten without setting the --overwrite_output "
f"flag"
)

logging_levels = {
"debug": logging.DEBUG,
"info": logging.INFO,
Expand All @@ -304,7 +348,9 @@ def setup_logging(
console_handler.setFormatter(console_formatter)
root_logger.addHandler(console_handler)
warnings_logger.addHandler(console_handler)
file_handler = logging.FileHandler(output.with_suffix(".log"))
file_handler = logging.FileHandler(
output.with_suffix(output.suffix + ".log")
)
file_handler.setFormatter(log_formatter)
root_logger.addHandler(file_handler)
warnings_logger.addHandler(file_handler)
Expand All @@ -329,7 +375,7 @@ def setup_model(
config: Optional[str],
output: Optional[Path],
is_train: bool,
) -> Config:
) -> Tuple[Config, str]:
"""Setup Casanovo for most commands.

Parameters
Expand All @@ -348,6 +394,8 @@ def setup_model(
------
config : Config
The parsed configuration
model : str
The name of the model weights.
"""
# Read parameters from the config file.
config = Config(config)
Expand Down
10 changes: 9 additions & 1 deletion casanovo/denovo/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
self,
config: Config,
model_filename: Optional[str] = None,
root_checkpoint_name: Optional[str] = None,
) -> None:
"""Initialize a ModelRunner"""
self.config = config
Expand All @@ -54,6 +55,10 @@ def __init__(
self.loaders = None
self.writer = None

checkpoint_filename = None
if root_checkpoint_name is not None:
checkpoint_filename = root_checkpoint_name + ".{epoch}-{step}"

# Configure checkpoints.
if config.save_top_k is not None:
self.callbacks = [
Expand All @@ -62,6 +67,7 @@ def __init__(
monitor="valid_CELoss",
mode="min",
save_top_k=config.save_top_k,
filename=checkpoint_filename,
)
]
else:
Expand Down Expand Up @@ -146,7 +152,9 @@ def predict(self, peak_path: Iterable[str], output: str) -> None:
-------
self
"""
self.writer = ms_io.MztabWriter(Path(output).with_suffix(".mztab"))
self.writer = ms_io.MztabWriter(
output.with_suffix(output.suffix + ".mztab")
)
self.writer.set_metadata(
self.config,
model=str(self.model_filename),
Expand Down
Loading