Skip to content

Commit

Permalink
Added ability to load Percolator model
Browse files Browse the repository at this point in the history
  • Loading branch information
wfondrie committed Jul 13, 2020
1 parent b843547 commit acdc1b2
Showing 1 changed file with 27 additions and 3 deletions.
30 changes: 27 additions & 3 deletions mokapot/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ class that contains this logic. A :py:class:`Model` instance can be
import logging
import pickle

import numpy as np
import pandas as pd
import sklearn.base as base
import sklearn.svm as svm
import sklearn.model_selection as ms
Expand Down Expand Up @@ -303,7 +305,11 @@ def save_model(model, out_file):

def load_model(model_file):
"""
Load a saved :py:class:`mokapot.model.Model` object.
Load a saved model for mokapot.
The saved model can either be a saved :py:class:`mokapot.model.Model`
object or the output model weights from Percolator. In Percolator,
these can be obtained using the :code:`--weights` argument.
Parameters
----------
Expand All @@ -315,8 +321,26 @@ def load_model(model_file):
mokapot.model.Model
The loaded :py:class:`mokapot.model.Model` object.
"""
with open(model_file, "rb") as mod_in:
model = pickle.load(mod_in)
# Try a percolator model first:
try:
weights = pd.read_csv(model_file, sep="\t", nrows=2).loc[1, :]
logging.info("Loading the Percolator model.")

weight_cols = [c for c in weights.index if c != "m0"]
model = Model(estimator=svm.LinearSVC(), scaler="as-is",
is_trained=True)

weight_vals = weights.loc[weight_cols]
weight_vals = weight_vals[np.newaxis, :]
model.estimator.coef_ = weight_vals
model.estimator.intercept_ = weights.loc["m0"]
model.features = weight_cols

# Then try loading it with pickle:
except UnicodeDecodeError:
logging.info("Loading mokapot model.")
with open(model_file, "rb") as mod_in:
model = pickle.load(mod_in)

return model

Expand Down

0 comments on commit acdc1b2

Please sign in to comment.