Skip to content

Commit

Permalink
fix line lengths and modify test
Browse files Browse the repository at this point in the history
  • Loading branch information
VarunAnanth2003 committed May 7, 2024
1 parent 81aa073 commit 0ecbd80
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 11 deletions.
25 changes: 17 additions & 8 deletions casanovo/denovo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,8 +993,10 @@ class DBSpec2Pep(Spec2Pep):
"""
Inherits Spec2Pep
Hijacks teacher-forcing implemented in Spec2Pep and uses it to predict scores between a spectra and associated peptide.
Input format is .mgf, with comma-separated targets and decoys in the SEQ field. Decoys should have a prefix of "decoy_".
Hijacks teacher-forcing implemented in Spec2Pep and
uses it to predict scores between a spectrum and its associated peptide.
Input format is .mgf, with comma-separated targets
and decoys in the SEQ field. Decoys should have a prefix of "decoy_".
"""

num_pairs = None # Modified to be predict_batch_size from config
Expand Down Expand Up @@ -1120,22 +1122,29 @@ def _calc_match_score(
batch_all_aa_scores: torch.Tensor, truth_aa_indicies: torch.Tensor
) -> List[float]:
"""
Take in teacher-forced scoring of amino acids of the peptides (in a batch) and use the truth labels
to calculate a score between the input spectra and associated peptide. The score is the geometric
Take in teacher-forced scoring of amino acids
of the peptides (in a batch) and use the truth labels
to calculate a score between the input spectra and
associated peptide. The score is the geometric
mean of the AA probabilities.
Parameters
----------
batch_all_aa_scores : torch.Tensor
Amino acid scores for all amino acids in the vocabulary for every prediction made to generate the associated peptide (for an entire batch)
Amino acid scores for all amino acids in
the vocabulary for every prediction made to generate
the associated peptide (for an entire batch)
truth_aa_indicies : torch.Tensor
Indicies of the score for each actual amino acid in the peptide (for an entire batch)
Indices of the score for each actual amino acid
in the peptide (for an entire batch)
Returns
-------
score : list[float], list[list[float]]
The score between the input spectra and associated peptide (for an entire batch)
a list of lists of per amino acid scores (for an entire batch)
The score between the input spectra and associated peptide
(for an entire batch), followed by
a list of lists of per-amino-acid scores
(for an entire batch)
"""
# Remove trailing tokens from predictions,
batch_all_aa_scores = batch_all_aa_scores[:, :-1]
Expand Down
7 changes: 4 additions & 3 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,17 @@ def test_annotate(
== "SEQ=LESLIEK,PEPTIDEK,decoy_KEILSEL,decoy_KEDITEPP"
)
assert (
seq_lines[2].strip()
== "SEQ=+42.011LEM+15.995SLIM+15.995EK,+43.006PEN+0.984PTIQ+0.984DEK,decoy_-17.027KM+15.995EILSEL,decoy_+43.006-17.027KEDITEPP,decoy_KEDIQ+0.984TEPPQ+0.984"
seq_lines[2].strip() == "SEQ=+42.011LEM+15.995SLIM+15.995EK,"
"+43.006PEN+0.984PTIQ+0.984DEK,decoy_-17.027KM+15.995EILSEL,"
"decoy_+43.006-17.027KEDITEPP,decoy_KEDIQ+0.984TEPPQ+0.984"
)


def test_db_search(
mgf_small_unannotated, tide_dir_small, tiny_config, tmp_path, monkeypatch
):
# Run a command:
monkeypatch.setattr(casanovo, "__version__", "4.1.1")
monkeypatch.setattr(casanovo, "__version__", "4.1.0")
run = functools.partial(
CliRunner().invoke, casanovo.main, catch_exceptions=False
)
Expand Down

0 comments on commit 0ecbd80

Please sign in to comment.