From acdc1b2babde16d00f32157cd8cda3c80fbfd050 Mon Sep 17 00:00:00 2001 From: wfondrie Date: Mon, 13 Jul 2020 14:40:27 -0700 Subject: [PATCH] Added ability to load Percolator model --- mokapot/model.py | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/mokapot/model.py b/mokapot/model.py index 84447d79..e2182035 100644 --- a/mokapot/model.py +++ b/mokapot/model.py @@ -12,6 +12,8 @@ class that contains this logic. A :py:class:`Model` instance can be import logging import pickle +import numpy as np +import pandas as pd import sklearn.base as base import sklearn.svm as svm import sklearn.model_selection as ms @@ -303,7 +305,11 @@ def save_model(model, out_file): def load_model(model_file): """ - Load a saved :py:class:`mokapot.model.Model` object. + Load a saved model for mokapot. + + The saved model can either be a saved :py:class:`mokapot.model.Model` + object or the output model weights from Percolator. In Percolator, + these can be obtained using the :code:`--weights` argument. Parameters ---------- @@ -315,8 +321,26 @@ def load_model(model_file): mokapot.model.Model The loaded :py:class:`mokapot.model.Model` object. """ - with open(model_file, "rb") as mod_in: - model = pickle.load(mod_in) + # Try a percolator model first: + try: + weights = pd.read_csv(model_file, sep="\t", nrows=2).loc[1, :] + logging.info("Loading the Percolator model.") + + weight_cols = [c for c in weights.index if c != "m0"] + model = Model(estimator=svm.LinearSVC(), scaler="as-is", + is_trained=True) + + weight_vals = weights.loc[weight_cols] + weight_vals = weight_vals[np.newaxis, :] + model.estimator.coef_ = weight_vals + model.estimator.intercept_ = weights.loc["m0"] + model.features = weight_cols + + # Then try loading it with pickle: + except UnicodeDecodeError: + logging.info("Loading mokapot model.") + with open(model_file, "rb") as mod_in: + model = pickle.load(mod_in) return model