Commit d6f58ac
wip: make brew generic again
jspaezp committed Dec 6, 2024
1 parent ce53dee commit d6f58ac
Showing 10 changed files with 351 additions and 162 deletions.
5 changes: 3 additions & 2 deletions .github/PULL_REQUEST_TEMPLATE.md
@@ -11,7 +11,8 @@ Check these if you believe they are true
- [ ] Was it tested locally? (`make test`)
- [ ] Has the changelog been updated?
- [ ] Have the docs been updated?
-- [ ] The level of testing this PR includes is appropriate
+- [ ] The level of testing this PR includes is appropriate.
+- [ ] No features have been dropped (or note which ones and why).

### Reviewers

@@ -24,4 +25,4 @@ Check these if you believe they are true

#### Notes

-(FILL ME IN, Optional) Additional context to this PR.
+(FILL ME IN, Optional) Additional context to this PR.
23 changes: 16 additions & 7 deletions mokapot/brew.py
@@ -21,7 +21,7 @@
LinearPsmDataset,
calibrate_scores,
update_labels,
-OnDiskPsmDataset,
+PsmDataset,
)
from mokapot.model import PercolatorModel, Model
from mokapot.parsers.pin import parse_in_chunks
@@ -32,7 +32,7 @@
# Functions -------------------------------------------------------------------
@typechecked
def brew(
-datasets: list[OnDiskPsmDataset],
+datasets: list[PsmDataset],
model: None | Model | list[Model] = None,
test_fdr: float = 0.01,
folds: int = 3,
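
A note on the annotation swap above: `@typechecked` enforces annotations at
runtime, and a check against the `PsmDataset` base class accepts any subclass,
which is what makes `brew` generic over on-disk and in-memory datasets. A
minimal sketch of that behavior (the classes are stubs standing in for
mokapot's real ones, not the actual implementations):

```python
# A runtime type check against the base class accepts any subclass, so
# both dataset flavors can flow through the same entry point.
from typeguard import typechecked


class PsmDataset: ...  # stub for mokapot's abstract base


class OnDiskPsmDataset(PsmDataset): ...


class LinearPsmDataset(PsmDataset): ...


@typechecked
def brew(datasets: list[PsmDataset]) -> int:
    return len(datasets)


brew([OnDiskPsmDataset(), LinearPsmDataset()])  # both subclasses pass
```
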
@@ -135,7 +135,16 @@ def brew(
test_folds_idx = [dataset._split(folds, rng) for dataset in datasets]

# If trained models are provided, use them as-is.
-try:
+# If the model is not iterable, it means that a single model was passed,
+# thus it is more of a template for training (or was generated within
+# this code, thus "None" was passed).
+is_mod_iterable = hasattr(model, "__iter__")
+if is_mod_iterable:
+# Q: Is this branch ever used?
+# JSPP 2024-12-06 I think it makes sense to split this function
+# to remove the trained case ... which adds a lot of clutter.
+# Furthermore, that function can fall back to this one if it's
+# not actually trained.
fitted = [[m, False] for m in model if m.is_trained]

if len(model) != folds:
@@ -149,7 +158,7 @@
"One or more of the provided models was not previously trained"
)

-except TypeError:
+else:
train_sets = list(
make_train_sets(
test_idx=test_folds_idx,
@@ -353,7 +362,7 @@ def make_train_sets(test_idx, subset_max_train, data_size, rng):

@typechecked
def _create_linear_dataset(
-dataset: OnDiskPsmDataset, psms: pd.DataFrame, enforce_checks: bool = True
+dataset: PsmDataset, psms: pd.DataFrame, enforce_checks: bool = True
):
utils.convert_targets_column(
data=psms, target_column=dataset.target_column
@@ -395,7 +404,7 @@ def predict_fold(
@typechecked
def _predict(
models_idx: list,
-datasets: Iterable[OnDiskPsmDataset],
+datasets: Iterable[PsmDataset],
models: Iterable[Model],
test_fdr: float,
max_workers: int,
@@ -488,7 +497,7 @@ def _predict(

@typechecked
def _predict_with_ensemble(
-dataset: OnDiskPsmDataset, models: Iterable[Model], max_workers
+dataset: PsmDataset, models: Iterable[Model], max_workers
):
"""
Return the new scores for the dataset using ensemble of all trained models
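
The control-flow change in `brew` above swaps EAFP (wrap iteration in
`try`/`except TypeError`) for an explicit `hasattr(model, "__iter__")` test,
so the pre-trained branch and the train-from-template branch are chosen up
front. A minimal sketch of that dispatch, with a stub `Model` and simplified
branch bodies (the real function builds train sets and fits per fold):

```python
# Dispatch sketch: an iterable of models is treated as pre-trained (one
# per fold); a single model or None is a template to train per fold.
class Model:
    def __init__(self, is_trained: bool = False):
        self.is_trained = is_trained


def dispatch(model: None | Model | list[Model], folds: int = 3) -> str:
    if hasattr(model, "__iter__"):
        if len(model) != folds:
            raise ValueError("Expected one model per fold")
        if not all(m.is_trained for m in model):
            raise RuntimeError(
                "One or more of the provided models was not previously trained"
            )
        return "use pre-trained models as-is"
    return "train a copy of the template on each fold"


assert dispatch(None).startswith("train")
assert dispatch([Model(True)] * 3).startswith("use")
```
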
16 changes: 10 additions & 6 deletions mokapot/confidence.py
@@ -27,7 +27,7 @@

from mokapot.column_defs import get_standard_column_name
from mokapot.constants import CONFIDENCE_CHUNK_SIZE
-from mokapot.dataset import OnDiskPsmDataset
+from mokapot.dataset import PsmDataset, OptionalColumns
from mokapot.peps import (
peps_from_scores,
TDHistData,
@@ -62,7 +62,7 @@ class Confidence(object):

def __init__(
self,
-dataset: OnDiskPsmDataset,
+dataset: PsmDataset,
levels: list[str],
level_paths: dict[str, Path],
out_writers: dict[str, Sequence[TabularDataWriter]],
@@ -105,6 +105,7 @@ def __init__(
Save decoys confidence estimates as well?
"""

+self.dataset = dataset
self._score_column = "score"
self._target_column = dataset.target_column
self._protein_column = "proteinIds"
@@ -115,7 +116,7 @@
self.do_rollup = do_rollup

if proteins:
-self.write_protein_level_data(level_paths, proteins, rng)
+self._write_protein_level_data(level_paths, proteins, rng)

self._assign_confidence(
levels=levels,
@@ -197,7 +198,7 @@ def _assign_confidence(

level_path.unlink(missing_ok=True)

-def write_protein_level_data(self, level_paths, proteins, rng):
+def _write_protein_level_data(self, level_paths, proteins, rng):
psms = TabularDataReader.from_path(level_paths["psms"]).read()
proteins = picked_protein(
psms,
@@ -218,6 +219,9 @@ def write_protein_level_data(self, level_paths, proteins, rng):
protein_writer.write(proteins)
LOGGER.info("\t- Found %i unique protein groups.", len(proteins))

+def get_optional_columns(self) -> OptionalColumns:
+return self.dataset.get_optional_columns()

def to_flashlfq(self, out_file="mokapot.flashlfq.txt"):
"""Save confidenct peptides for quantification with FlashLFQ."""
return to_flashlfq(self, out_file)
@@ -226,7 +230,7 @@ def to_flashlfq(self, out_file="mokapot.flashlfq.txt"):
# Functions -------------------------------------------------------------------
@typechecked
def assign_confidence(
-datasets: list[OnDiskPsmDataset],
+datasets: list[PsmDataset],
scores_list: list[np.ndarray[float]],
max_workers: int = 1,
eval_fdr=0.01,
@@ -542,7 +546,7 @@ def hash_data_row(data_row):
@contextmanager
@typechecked
def create_sorted_file_reader(
-dataset: OnDiskPsmDataset,
+dataset: PsmDataset,
score_reader: TabularDataReader,
dest_dir: Path,
file_prefix: str,
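
The new `get_optional_columns` hook is plain delegation: `__init__` now keeps
the dataset (`self.dataset = dataset`), so a `Confidence` object can expose
dataset-level metadata without caring which `PsmDataset` subclass it came
from. A minimal sketch of the pattern (stub classes; the real
`OptionalColumns` fields may differ):

```python
# Delegation sketch: Confidence forwards the call to whatever dataset it
# was built from. Field names below are illustrative, not mokapot's.
from dataclasses import dataclass


@dataclass
class OptionalColumns:
    filename: str | None = None
    scan: str | None = None


class InMemoryDataset:  # any PsmDataset subclass would work the same way
    def get_optional_columns(self) -> OptionalColumns:
        return OptionalColumns(filename="filename", scan="scan")


class Confidence:
    def __init__(self, dataset):
        self.dataset = dataset  # stored so methods can delegate to it

    def get_optional_columns(self) -> OptionalColumns:
        return self.dataset.get_optional_columns()


print(Confidence(InMemoryDataset()).get_optional_columns())
```
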
(Diffs for the remaining 7 changed files did not load and are not shown.)

