"""predict.py: score the saved per-fold models and write one predictions CSV per classifier."""
import os
from datetime import datetime
import pandas as pd
from rocket import Rocket
from oscnn import OSCNN
from fcn import FCN
from ml import CustomKFold
from loader import load_dataset
from constants import MODELS_PATH, PREDICTION_PATH
from ml import extract_features
def create_classifiers():
    """Build one instance of each classifier used for prediction.

    Returns:
        list: the FCN, OSCNN and Rocket instances, in that order.
    """
    # Instantiate in the order Rocket -> OSCNN -> FCN; the returned list
    # deliberately uses the reverse order.
    rocket, os_cnn, fcn = Rocket(), OSCNN(), FCN()
    return [fcn, os_cnn, rocket]
def main():
    """Write one predictions CSV per classifier from the saved fold models.

    Loads the dataset, splits it with ``CustomKFold`` (k=5) and, for every
    classifier returned by ``create_classifiers``, loads the model stored at
    ``{MODELS_PATH}/{ClassName}_{fold}`` for each fold and calls
    ``predict_proba`` on the features extracted from that fold's second
    split element.  The per-fold frames are concatenated and saved to
    ``{PREDICTION_PATH}/{ClassName}.csv`` without the index.

    Each output row carries the ``benchmark_id``, ``no_fork``, ``starts_at``
    and ``y`` columns copied from the split, plus a ``y_pred`` column.
    """
    print(f"[{datetime.now()}] Loading dataset ...")
    dataset = load_dataset()
    print(f"[{datetime.now()}] Dataset loaded")

    models = create_classifiers()

    # The output directory must exist before the first CSV is written.
    os.makedirs(PREDICTION_PATH, exist_ok=True)

    splitter = CustomKFold(dataset, k=5)

    kept_columns = ['benchmark_id', 'no_fork', 'starts_at', 'y']

    for model in models:
        # The class name keys both the saved model files and the output CSV.
        name = model.__class__.__name__
        fold_frames = []

        # NOTE(review): the first element of every split is ignored here --
        # presumably the training part; confirm against CustomKFold.iter().
        for fold_idx, split in enumerate(splitter.iter()):
            _, held_out = split

            # Swap in the weights saved for this particular fold.
            saved_model = f"{MODELS_PATH}/{name}_{fold_idx}"
            model.load(saved_model)
            print(f"[{datetime.now()}] Model {saved_model} loaded ...")

            # Only the first value returned by extract_features is needed.
            features, _ = extract_features(held_out)

            scores = model.predict_proba(features)
            print(f"[{datetime.now()}] Predictions completed ...")

            # Work on a copy so the new column never touches the source frame.
            fold_frame = held_out[kept_columns].copy()
            fold_frame.loc[:, 'y_pred'] = scores
            fold_frames.append(fold_frame)

        # One file per classifier, covering all folds.
        combined = pd.concat(fold_frames)
        csv_path = f"{PREDICTION_PATH}/{name}.csv"
        combined.to_csv(csv_path, index=False)
        print(f"[{datetime.now()}] Predictions saved to {csv_path}")
# Run the prediction pipeline only when this file is executed as a script,
# so importing the module does not trigger it.
if __name__ == "__main__":
    main()