diff --git a/casanovo/denovo/model.py b/casanovo/denovo/model.py index 4efe0f92..ec234691 100644 --- a/casanovo/denovo/model.py +++ b/casanovo/denovo/model.py @@ -993,8 +993,10 @@ class DBSpec2Pep(Spec2Pep): """ Inherits Spec2Pep - Hijacks teacher-forcing implemented in Spec2Pep and uses it to predict scores between a spectra and associated peptide. - Input format is .mgf, with comma-separated targets and decoys in the SEQ field. Decoys should have a prefix of "decoy_". + Hijacks teacher-forcing implemented in Spec2Pep and + uses it to predict scores between a spectrum and associated peptide. + Input format is .mgf, with comma-separated targets + and decoys in the SEQ field. Decoys should have a prefix of "decoy_". """ num_pairs = None # Modified to be predict_batch_size from config @@ -1120,22 +1122,29 @@ def _calc_match_score( batch_all_aa_scores: torch.Tensor, truth_aa_indicies: torch.Tensor ) -> List[float]: """ - Take in teacher-forced scoring of amino acids of the peptides (in a batch) and use the truth labels - to calculate a score between the input spectra and associated peptide. The score is the geometric + Take in teacher-forced scoring of amino acids + of the peptides (in a batch) and use the truth labels + to calculate a score between the input spectra and + associated peptide. 
The score is the geometric mean of the AA probabilities Parameters ---------- batch_all_aa_scores : torch.Tensor - Amino acid scores for all amino acids in the vocabulary for every prediction made to generate the associated peptide (for an entire batch) + Amino acid scores for all amino acids in + the vocabulary for every prediction made to generate + the associated peptide (for an entire batch) truth_aa_indicies : torch.Tensor - Indicies of the score for each actual amino acid in the peptide (for an entire batch) + Indices of the score for each actual amino acid + in the peptide (for an entire batch) Returns ------- score : list[float], list[list[float]] - The score between the input spectra and associated peptide (for an entire batch) - a list of lists of per amino acid scores (for an entire batch) + The score between the input spectra and associated peptide + (for an entire batch) + a list of lists of per amino acid scores + (for an entire batch) """ # Remove trailing tokens from predictions, batch_all_aa_scores = batch_all_aa_scores[:, :-1] diff --git a/tests/test_integration.py b/tests/test_integration.py index e8654c68..3ad1a4f4 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -47,8 +47,9 @@ def test_annotate( == "SEQ=LESLIEK,PEPTIDEK,decoy_KEILSEL,decoy_KEDITEPP" ) assert ( - seq_lines[2].strip() - == "SEQ=+42.011LEM+15.995SLIM+15.995EK,+43.006PEN+0.984PTIQ+0.984DEK,decoy_-17.027KM+15.995EILSEL,decoy_+43.006-17.027KEDITEPP,decoy_KEDIQ+0.984TEPPQ+0.984" + seq_lines[2].strip() == "SEQ=+42.011LEM+15.995SLIM+15.995EK," + "+43.006PEN+0.984PTIQ+0.984DEK,decoy_-17.027KM+15.995EILSEL," + "decoy_+43.006-17.027KEDITEPP,decoy_KEDIQ+0.984TEPPQ+0.984" ) @@ -56,7 +57,7 @@ def test_db_search( mgf_small_unannotated, tide_dir_small, tiny_config, tmp_path, monkeypatch ): # Run a command: - monkeypatch.setattr(casanovo, "__version__", "4.1.1") + monkeypatch.setattr(casanovo, "__version__", "4.1.0") run = functools.partial( CliRunner().invoke, 
casanovo.main, catch_exceptions=False )