Skip to content

Commit

Permalink
Add sys arg option to predict_nn_rf
Browse files Browse the repository at this point in the history
  • Loading branch information
gayaldassanayake committed Feb 8, 2022
1 parent e393a80 commit f4c7a70
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 13 deletions.
2 changes: 1 addition & 1 deletion pipeline/predict/predict_combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def get_sequence_lengths(out_path):
return len_df

def get_feature_data(out_path):
feature_df = pd.read_csv(os.path.join(out_path, 'predictions.csv'))
feature_df = pd.read_csv(os.path.join(out_path, 'nn_rf_predictions.csv'))
len_df = get_sequence_lengths(out_path)
df = pd.merge(feature_df, len_df, on='seq_id', how="left")
features = ["fragment_count", "kmer_plas_prob", "biomer_plas_prob", "length"]
Expand Down
16 changes: 4 additions & 12 deletions pipeline/predict/predict_nn_rf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
from scipy.interpolate import make_interp_spline
import numpy as np
import pandas as pd
from numpy.random import randint
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import os
import sys

from pipeline.model.NNModule import Model
from pipeline.constants import *
Expand Down Expand Up @@ -56,22 +55,18 @@ def read_kmer_file(seq_id, out_path):

def get_prediction(value, model):
    """Run ``model`` on ``value`` and return the top-scoring index and raw output.

    Args:
        value: Input handed unchanged to ``model``.
        model: Callable returning a tensor of scores. The maximum is taken
            along dimension 0.

    Returns:
        A ``(index, scores)`` tuple, where ``index`` is the position of the
        largest score as a Python int and ``scores`` is the tensor returned
        by ``model``.
    """
    scores = model(value)
    # torch.max with ``dim`` yields (values, indices); only the index is needed.
    # NOTE(review): ``.item()`` requires a single-element result, so this
    # presumably expects a 1-D score tensor -- confirm against the model.
    _, best_index = torch.max(scores, dim=0)
    return best_index.item(), scores


def predict_kmer(sequence_id, out_path):
kmer_arrays = read_kmer_file(sequence_id, out_path)
# print(f'\nseq_id: {sequence_id} count: {len(kmer_arrays)}')
total = 0
tot_probs = torch.tensor([0.0, 0.0])
for m in kmer_arrays:
prediction, prob_tensor = get_prediction(torch.tensor(m), kmer_model)
tot_probs += prob_tensor
total += prediction
# print(f'kmer_plasmid_avg: {total/len(kmer_arrays)} prob_total = {tot_probs}')
probs_list = (tot_probs/(len(kmer_arrays))).tolist()
probs_list.insert(0, len(kmer_arrays))
probs_list.append(total/len(kmer_arrays))
Expand All @@ -90,15 +85,11 @@ def predict_nn_rf(out_path):
kmer_prediction = predict_kmer(seq, out_path)
full_prediction = [seq]
full_prediction.extend(kmer_prediction)
# print(full_prediction)
selected = sequence_df.loc[sequence_df['id'] == seq][features]
full_prediction.extend(
(biomer_model.predict_proba(selected))[0].tolist())
full_prediction.append((biomer_model.predict(selected)).item())
predictions.append(full_prediction)
# print(
# f"biomer result: {biomer_model.predict(selected)} biomer result probs: {biomer_model.predict_proba(selected)}")
# print(full_prediction)
print('Writing NN, RF predictions...')
predictions_df = pd.DataFrame(predictions, columns=['seq_id', 'fragment_count', 'kmer_chro_prob',
'kmer_plas_prob', 'kmer_prediction_avg', 'biomer_chro_prob', 'biomer_plas_prob', 'biomer_prediction'])
Expand All @@ -108,7 +99,7 @@ def predict_nn_rf(out_path):
predictions_df['product'] = (
predictions_df['kmer_plas_prob'] * predictions_df['biomer_plas_prob'])**0.5
predictions_df.to_csv(os.path.join(
out_path, 'predictions.csv'), index=False)
out_path, 'nn_rf_predictions.csv'), index=False)

print('Plotting graphs...')

Expand Down Expand Up @@ -141,4 +132,5 @@ def predict_nn_rf(out_path):


if __name__ == "__main__":
    # The output directory is taken from the first command-line argument.
    # Exit with a usage message instead of an IndexError when it is missing.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <out_path>")
    predict_nn_rf(sys.argv[1])

0 comments on commit f4c7a70

Please sign in to comment.