Early-stop mechanism when there is no improvement at all in the last steps
tjiagoM committed Mar 17, 2020
1 parent 064f77e commit 7a17633
Showing 2 changed files with 146 additions and 138 deletions.
32 changes: 18 additions & 14 deletions 01_pyg_hcp.py
@@ -5,6 +5,7 @@
import time
import random
from sys import exit
from collections import deque

import numpy as np
import torch
@@ -137,14 +138,15 @@ def get_array_data(data_fold, num_nodes=50):

return np.array(tmp_array), np.array(tmp_y)


if __name__ == '__main__':

#import warnings
# import warnings

#warnings.filterwarnings("ignore")
# warnings.filterwarnings("ignore")
torch.manual_seed(1)
# torch.backends.cudnn.deterministic = True
#torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.benchmark = False
np.random.seed(1111)
random.seed(1111)
torch.cuda.manual_seed_all(1111)
@@ -158,7 +160,7 @@ def get_array_data(data_fold, num_nodes=50):
parser.add_argument("--activation", default='relu')
parser.add_argument("--threshold", default=5, type=int)
parser.add_argument("--num_nodes", default=50, type=int)
parser.add_argument("--num_epochs", default=20, type=int)
parser.add_argument("--num_epochs", default=100, type=int)
parser.add_argument("--batch_size", default=150, type=int)
parser.add_argument("--add_gcn", type=bool, default=False) # to make true just include flag with 1
parser.add_argument("--add_gat", type=bool, default=False) # to make true just include flag with 1
@@ -173,6 +175,7 @@ def get_array_data(data_fold, num_nodes=50):
parser.add_argument("--analysis_type", default='spatiotemporal')
parser.add_argument("--time_length", type=int)
parser.add_argument("--encoding_strategy", default='none')
parser.add_argument("--early_stop_steps", default=30, type=int)

args = parser.parse_args()

@@ -202,9 +205,7 @@ def get_array_data(data_fold, num_nodes=50):
TIME_LENGTH = args.time_length
TS_SPIT_NUM = int(4800 / TIME_LENGTH)
ENCODING_STRATEGY = EncodingStrategy(args.encoding_strategy)

if NUM_NODES == 300 and CHANNELS_CONV > 1:
BATCH_SIZE = int(BATCH_SIZE / 3)
EARLY_STOP_STEPS = args.early_stop_steps

#if CONV_STRATEGY != ConvStrategy.TCN_ENTIRE:
# print("Setting to deterministic runs")
@@ -306,11 +307,9 @@ def get_array_data(data_fold, num_nodes=50):
# best_metric = -100
# best_params = None
best_model_name_outer_fold_auc = None
best_model_name_outer_fold_acc = None
best_model_name_outer_fold_loss = None
best_outer_metric_loss = 1000
best_outer_metric_auc = -1000
best_outer_metric_acc = -1000
for params in grid:
print("For ", params)

@@ -327,7 +326,7 @@ def get_array_data(data_fold, num_nodes=50):
merged_labels_inner,
groups=[data.hcp_id.item() for data in X_train_out])
model_with_sigmoid = True
metrics = ['acc', 'f1', 'auc', 'loss']
metrics = ['auc', 'loss']

# This for-cycle will only be executed once (for now)
for inner_train_index, inner_val_index in skf_inner_generator:
@@ -366,10 +365,6 @@ def get_array_data(data_fold, num_nodes=50):
THRESHOLD, BATCH_SIZE, REMOVE_NODES, NUM_NODES, CONN_TYPE,
NORMALISATION, ANALYSIS_TYPE,
m)
# If there is one of the metrics saved, then I assume this inner part was already calculated
if os.path.isfile(model_names[metrics[0]]):
print("Saved model exists, thus skipping this search...")
break # break because I'm in the "inner" fold, which is being done only once

X_train_in = X_train_out[torch.tensor(inner_train_index)]
X_val_in = X_train_out[torch.tensor(inner_val_index)]
@@ -412,6 +407,9 @@ def get_array_data(data_fold, num_nodes=50):
best_metrics_fold[m] = 1000
else:
best_metrics_fold[m] = -1000
# Only for loss
last_losses_val = deque([1000 for _ in range(EARLY_STOP_STEPS)], maxlen=EARLY_STOP_STEPS)

for epoch in range(1, N_EPOCHS):
if TARGET_VAR == 'gender':
val_metrics = classifier_step(outer_split_num,
@@ -420,6 +418,12 @@ def get_array_data(data_fold, num_nodes=50):
model,
train_in_loader,
val_loader)
if sum([val_metrics['loss'] > loss for loss in last_losses_val]) == EARLY_STOP_STEPS:
print("EARLY STOPPING IT")
break

last_losses_val.append(val_metrics['loss'])

if val_metrics['loss'] < best_metrics_fold['loss']:
best_metrics_fold['loss'] = val_metrics['loss']
torch.save(model, model_names['loss'])
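For reference, the core of this change is a deque-based patience window: the training loop keeps the last early_stop_steps validation losses and stops as soon as the current loss is worse than every one of them, i.e. there has been no improvement at all over that window. Below is a minimal, self-contained sketch of the same logic; run_one_epoch and its synthetic losses are illustrative stand-ins for the script's classifier_step, not code from this repository.

import random
from collections import deque

def run_one_epoch(epoch):
    """Hypothetical stand-in for classifier_step(); returns a synthetic validation loss."""
    return 1.0 / epoch + random.uniform(0.0, 0.05)

EARLY_STOP_STEPS = 30   # mirrors the new --early_stop_steps argument (default 30)
N_EPOCHS = 100          # mirrors the new --num_epochs default

# Window of the most recent validation losses, pre-filled with a large sentinel
# so the stop condition cannot fire before the window holds real losses.
last_losses_val = deque([1000.0] * EARLY_STOP_STEPS, maxlen=EARLY_STOP_STEPS)
best_loss = float("inf")

for epoch in range(1, N_EPOCHS):
    val_loss = run_one_epoch(epoch)

    # No improvement at all: the current loss is worse than every loss in the window.
    if sum(val_loss > old for old in last_losses_val) == EARLY_STOP_STEPS:
        print(f"Early stopping at epoch {epoch}")
        break

    last_losses_val.append(val_loss)

    if val_loss < best_loss:
        best_loss = val_loss   # the original script also checkpoints the model here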
252 changes: 128 additions & 124 deletions outputs/calculate_means.py
@@ -1,124 +1,128 @@

import numpy as np

dict_results = {
'dummy to copy' : {'aucs': [0., 0., 0., 0., 0.],
'sens': [0., 0., 0., 0., 0.],
'spec': [0., 0., 0., 0., 0.]},

'concat without gcn' : {'aucs': [0.7595, 0.6832, 0.6945, 0.7485, 0.7289],
'sens': [0.9202, 0.5771, 0.6170, 0.7181, 0.7984],
'spec': [0.3820, 0.6800, 0.6497, 0.6197, 0.4973]},

'concat with 5% gcn' : {'aucs': [0.7513, 0.6629, 0.6863, 0.7523, 0.7124],
'sens': [0.6995, 0.6037, 0.7394, 0.7207, 0.6989],
'spec': [0.6207, 0.6400, 0.5455, 0.6649, 0.5802]},

'concat with 20% gcn' : {'aucs': [0.7520, 0.6502, 0.6827, 0.7638, 0.7037],
'sens': [0.7580, 0.5984, 0.5612, 0.7314, 0.6586],
'spec': [0.5995, 0.6213, 0.6952, 0.6516, 0.6471]},

'xgboost binarised 5' : {'aucs': [0.6600, 0.7217, 0.6706, 0.7114, 0.6757],
'sens': [0.6622, 0.7447, 0.6888, 0.6968, 0.7070],
'spec': [0.6578, 0.6987, 0.6524, 0.7261, 0.6444]},
'xgboost binarised 20' : {'aucs': [0.7370, 0.7243, 0.7067, 0.7247, 0.7064],
'sens': [0.7287, 0.7473, 0.6968, 0.6782, 0.7070],
'spec': [0.7454, 0.7013, 0.7166, 0.7713, 0.7059]},

'mean_TCN_GCN5' : {'aucs': [0.6920, 0.6349, 0.5968, 0.7117, 0.6074],
'sens': [1.0000, 1.0000, 0.0000, 0.8537, 1.0000],
'spec': [0.0000, 0.0027, 1.0000, 0.3777, 0.0000]},

'mean_TCN' : {'aucs': [0.6928, 0.6169, 0.6612, 0.7520, 0.6453],
'sens': [0.7287, 0.9973, 0.6064, 0.4601, 0.0000],
'spec': [0.5438, 0.0160, 0.6364, 0.8723, 1.0000]},

'mean_CNN_64split+' : {'aucs': [0.6483, 0.6442, 0.6534, 0.6842, 0.6233],
'sens': [0.6690, 0.6471, 0.5296, 0.6418, 0.6964],
'spec': [0.5463, 0.5519, 0.6791, 0.6250, 0.4898]},

'mean_CNN_64split' : {'aucs': [0.6426, 0.6404, 0.6394, 0.6885, 0.6206],
'sens': [0.6815, 0.6277, 0.5136, 0.6511, 0.6252],
'spec': [0.5167, 0.5712, 0.6749, 0.6219, 0.5500]},

'AUC xgboost 64plit' : {'aucs': [0.6971, 0.6947, 0.6877, 0.6814, 0.6873],
'sens': [0.6867, 0.6780, 0.6672, 0.6795, 0.6788],
'spec': [0.7075, 0.7114, 0.7081, 0.6832, 0.6959]},

'AUC xgboost4plit' : {'aucs': [0.7875, 0.7723, 0.7853, 0.7859, 0.7950],
'accs': [0.7875, 0.7723, 0.7853, 0.7859, 0.7949],
'sens': [0.7686, 0.7899, 0.7819, 0.7660, 0.8118],
'spec': [0.8064, 0.7547, 0.7888, 0.8059, 0.7781]},
############
'AUC diff_pool 5' : {'aucs': [0.6752, 0.6335, 0.6529, 0.6993, 0.6767],
'accs': [0.6016, 0.6005, 0.6147, 0.6184, 0.6327],
'f1s' : [0.6842, 0.6386, 0.6980, 0.6911, 0.6675]},

'AUC diff_pool 20' : {'aucs': [0.6576, 0.6453, 0.6744, 0.7378, 0.6735],
'accs': [0.6255, 0.6165, 0.6400, 0.6343, 0.6434],
'f1s' : [0.7044, 0.6697, 0.6438, 0.7046, 0.6463]},

'AUC mean 5' : {'aucs': [0.6782, 0.6404, 0.6872, 0.7488, 0.7032],
'accs': [0.5007, 0.5819, 0.5947, 0.6622, 0.6180],
'f1s' : [0.0000, 0.5527, 0.4967, 0.6947, 0.4956]},

'AUC mean 20' : {'aucs': [0.6787, 0.6404, 0.6873, 0.7490, 0.7021],
'accs': [0.5007, 0.5819, 0.5693, 0.6622, 0.6072],
'f1s' : [0.0000, 0.5527, 0.3501, 0.6947, 0.4564]},

'Loss diff_pool 5' : {'aucs': [0.5045, 0.5039, 0.6614, 0.5159, 0.6733],
'accs': [0.5007, 0.5007, 0.6227, 0.5000, 0.6206],
'f1s' : [0.0000, 0.6673, 0.6907, 0.6667, 0.6698]},

'Loss diff_pool 20' : {'aucs': [0.6915, 0.6444, 0.6722, 0.5247, 0.4919],
'accs': [0.6255, 0.6192, 0.6173, 0.5000, 0.4987],
'f1s' : [0.6853, 0.6324, 0.6530, 0.0000, 0.6655]},

'Loss mean 5' : {'aucs': [0.6895, 0.6388, 0.6871, 0.7488, 0.6807],
'accs': [0.6321, 0.5925, 0.6267, 0.6622, 0.6059],
'f1s' : [0.6126, 0.5785, 0.6143, 0.6947, 0.6142]},

'Loss mean 20' : {'aucs': [0.6895, 0.6388, 0.6871, 0.7490, 0.6807],
'accs': [0.6321, 0.5925, 0.6267, 0.6622, 0.6059],
'f1s' : [0.6126, 0.5785, 0.6143, 0.6947, 0.6202]},

############
'GCN AUC diff_pool 5' : {'aucs': [0.6097, 0.6304, 0.6719, 0.7066, 0.6454],
'accs': [0.4993, 0.6178, 0.5013, 0.6449, 0.6005],
'f1s' : [0., 0., 0., 0., 0.]},

'GCN AUC diff_pool 20' : {'aucs': [0.6367, 0.6383, 0.6703, 0.6999, 0.6736],
'accs': [0.4993, 0.6152, 0.6293, 0.6609, 0.6247],
'f1s' : [0., 0., 0., 0., 0.]},

'GCN AUC mean 5' : {'aucs': [0.7266, 0.6478, 0.6832, 0.7682, 0.6887],
'accs': [0.4993, 0.5007, 0.6187, 0.6742, 0.5697],
'f1s' : [0., 0., 0., 0., 0.]},

'GCN AUC mean 20' : {'aucs': [0.6778, 0.6506, 0.6706, 0.7282, 0.6904],
'accs': [0.6361, 0.6232, 0.5013, 0.5000, 0.6099],
'f1s' : [0., 0., 0., 0., 0.]},

'GCN Loss diff_pool 5' : {'aucs': [0.5055, 0.4938, 0.6230, 0.5902, 0.4557],
'accs': [0.4993, 0.4993, 0.5013, 0.5066, 0.4987],
'f1s' : [0., 0., 0., 0., 0.]},

'GCN Loss diff_pool 20' : {'aucs': [0.5159, 0.6256, 0.6658, 0.4174, 0.4065],
'accs': [0.5060, 0.5925, 0.6320, 0.4814, 0.5013],
'f1s' : [0., 0., 0., 0., 0.]},

'GCN Loss mean 5' : {'aucs': [0.7206, 0.6477, 0.6855, 0.7581, 0.6935],
'accs': [0.6640, 0.5925, 0.6280, 0.6862, 0.6287],
'f1s' : [0., 0., 0., 0., 0.]},

'GCN Loss mean 20' : {'aucs': [0.7216, 0.6460, 0.6840, 0.7588, 0.6897],
'accs': [0.6521, 0.5939, 0.6320, 0.6902, 0.6247],
'f1s' : [0., 0., 0., 0., 0.]}
}


for key, value in dict_results.items():
print(key, ":")
for metric, values in value.items():
print(metric, ":", round(np.mean(values), 3), "(", round(np.std(values), 3), ")")
print()

import numpy as np

dict_results = {
'dummy to copy' : {'aucs': [0., 0., 0., 0., 0.],
'sens': [0., 0., 0., 0., 0.],
'spec': [0., 0., 0., 0., 0.]},

'ukb xgboost' : {'aucs': [0.8835, 0.8755, 0.8850, 0.8911, 0.8795],
'sens': [0.8932, 0.8725, 0.8780, 0.8816, 0.8871],
'spec': [0.8737, 0.8786, 0.8920, 0.9005, 0.8719]},

'concat without gcn' : {'aucs': [0.7595, 0.6832, 0.6945, 0.7485, 0.7289],
'sens': [0.9202, 0.5771, 0.6170, 0.7181, 0.7984],
'spec': [0.3820, 0.6800, 0.6497, 0.6197, 0.4973]},

'concat with 5% gcn' : {'aucs': [0.7513, 0.6629, 0.6863, 0.7523, 0.7124],
'sens': [0.6995, 0.6037, 0.7394, 0.7207, 0.6989],
'spec': [0.6207, 0.6400, 0.5455, 0.6649, 0.5802]},

'concat with 20% gcn' : {'aucs': [0.7520, 0.6502, 0.6827, 0.7638, 0.7037],
'sens': [0.7580, 0.5984, 0.5612, 0.7314, 0.6586],
'spec': [0.5995, 0.6213, 0.6952, 0.6516, 0.6471]},

'xgboost binarised 5' : {'aucs': [0.6600, 0.7217, 0.6706, 0.7114, 0.6757],
'sens': [0.6622, 0.7447, 0.6888, 0.6968, 0.7070],
'spec': [0.6578, 0.6987, 0.6524, 0.7261, 0.6444]},
'xgboost binarised 20' : {'aucs': [0.7370, 0.7243, 0.7067, 0.7247, 0.7064],
'sens': [0.7287, 0.7473, 0.6968, 0.6782, 0.7070],
'spec': [0.7454, 0.7013, 0.7166, 0.7713, 0.7059]},

'mean_TCN_GCN5' : {'aucs': [0.6920, 0.6349, 0.5968, 0.7117, 0.6074],
'sens': [1.0000, 1.0000, 0.0000, 0.8537, 1.0000],
'spec': [0.0000, 0.0027, 1.0000, 0.3777, 0.0000]},

'mean_TCN' : {'aucs': [0.6928, 0.6169, 0.6612, 0.7520, 0.6453],
'sens': [0.7287, 0.9973, 0.6064, 0.4601, 0.0000],
'spec': [0.5438, 0.0160, 0.6364, 0.8723, 1.0000]},

'mean_CNN_64split+' : {'aucs': [0.6483, 0.6442, 0.6534, 0.6842, 0.6233],
'sens': [0.6690, 0.6471, 0.5296, 0.6418, 0.6964],
'spec': [0.5463, 0.5519, 0.6791, 0.6250, 0.4898]},

'mean_CNN_64split' : {'aucs': [0.6426, 0.6404, 0.6394, 0.6885, 0.6206],
'sens': [0.6815, 0.6277, 0.5136, 0.6511, 0.6252],
'spec': [0.5167, 0.5712, 0.6749, 0.6219, 0.5500]},

'AUC xgboost 64plit' : {'aucs': [0.6971, 0.6947, 0.6877, 0.6814, 0.6873],
'sens': [0.6867, 0.6780, 0.6672, 0.6795, 0.6788],
'spec': [0.7075, 0.7114, 0.7081, 0.6832, 0.6959]},

'AUC xgboost4plit' : {'aucs': [0.7875, 0.7723, 0.7853, 0.7859, 0.7950],
'accs': [0.7875, 0.7723, 0.7853, 0.7859, 0.7949],
'sens': [0.7686, 0.7899, 0.7819, 0.7660, 0.8118],
'spec': [0.8064, 0.7547, 0.7888, 0.8059, 0.7781]},
############
'AUC diff_pool 5' : {'aucs': [0.6752, 0.6335, 0.6529, 0.6993, 0.6767],
'accs': [0.6016, 0.6005, 0.6147, 0.6184, 0.6327],
'f1s' : [0.6842, 0.6386, 0.6980, 0.6911, 0.6675]},

'AUC diff_pool 20' : {'aucs': [0.6576, 0.6453, 0.6744, 0.7378, 0.6735],
'accs': [0.6255, 0.6165, 0.6400, 0.6343, 0.6434],
'f1s' : [0.7044, 0.6697, 0.6438, 0.7046, 0.6463]},

'AUC mean 5' : {'aucs': [0.6782, 0.6404, 0.6872, 0.7488, 0.7032],
'accs': [0.5007, 0.5819, 0.5947, 0.6622, 0.6180],
'f1s' : [0.0000, 0.5527, 0.4967, 0.6947, 0.4956]},

'AUC mean 20' : {'aucs': [0.6787, 0.6404, 0.6873, 0.7490, 0.7021],
'accs': [0.5007, 0.5819, 0.5693, 0.6622, 0.6072],
'f1s' : [0.0000, 0.5527, 0.3501, 0.6947, 0.4564]},

'Loss diff_pool 5' : {'aucs': [0.5045, 0.5039, 0.6614, 0.5159, 0.6733],
'accs': [0.5007, 0.5007, 0.6227, 0.5000, 0.6206],
'f1s' : [0.0000, 0.6673, 0.6907, 0.6667, 0.6698]},

'Loss diff_pool 20' : {'aucs': [0.6915, 0.6444, 0.6722, 0.5247, 0.4919],
'accs': [0.6255, 0.6192, 0.6173, 0.5000, 0.4987],
'f1s' : [0.6853, 0.6324, 0.6530, 0.0000, 0.6655]},

'Loss mean 5' : {'aucs': [0.6895, 0.6388, 0.6871, 0.7488, 0.6807],
'accs': [0.6321, 0.5925, 0.6267, 0.6622, 0.6059],
'f1s' : [0.6126, 0.5785, 0.6143, 0.6947, 0.6142]},

'Loss mean 20' : {'aucs': [0.6895, 0.6388, 0.6871, 0.7490, 0.6807],
'accs': [0.6321, 0.5925, 0.6267, 0.6622, 0.6059],
'f1s' : [0.6126, 0.5785, 0.6143, 0.6947, 0.6202]},

############
'GCN AUC diff_pool 5' : {'aucs': [0.6097, 0.6304, 0.6719, 0.7066, 0.6454],
'accs': [0.4993, 0.6178, 0.5013, 0.6449, 0.6005],
'f1s' : [0., 0., 0., 0., 0.]},

'GCN AUC diff_pool 20' : {'aucs': [0.6367, 0.6383, 0.6703, 0.6999, 0.6736],
'accs': [0.4993, 0.6152, 0.6293, 0.6609, 0.6247],
'f1s' : [0., 0., 0., 0., 0.]},

'GCN AUC mean 5' : {'aucs': [0.7266, 0.6478, 0.6832, 0.7682, 0.6887],
'accs': [0.4993, 0.5007, 0.6187, 0.6742, 0.5697],
'f1s' : [0., 0., 0., 0., 0.]},

'GCN AUC mean 20' : {'aucs': [0.6778, 0.6506, 0.6706, 0.7282, 0.6904],
'accs': [0.6361, 0.6232, 0.5013, 0.5000, 0.6099],
'f1s' : [0., 0., 0., 0., 0.]},

'GCN Loss diff_pool 5' : {'aucs': [0.5055, 0.4938, 0.6230, 0.5902, 0.4557],
'accs': [0.4993, 0.4993, 0.5013, 0.5066, 0.4987],
'f1s' : [0., 0., 0., 0., 0.]},

'GCN Loss diff_pool 20' : {'aucs': [0.5159, 0.6256, 0.6658, 0.4174, 0.4065],
'accs': [0.5060, 0.5925, 0.6320, 0.4814, 0.5013],
'f1s' : [0., 0., 0., 0., 0.]},

'GCN Loss mean 5' : {'aucs': [0.7206, 0.6477, 0.6855, 0.7581, 0.6935],
'accs': [0.6640, 0.5925, 0.6280, 0.6862, 0.6287],
'f1s' : [0., 0., 0., 0., 0.]},

'GCN Loss mean 20' : {'aucs': [0.7216, 0.6460, 0.6840, 0.7588, 0.6897],
'accs': [0.6521, 0.5939, 0.6320, 0.6902, 0.6247],
'f1s' : [0., 0., 0., 0., 0.]}
}


for key, value in dict_results.items():
print(key, ":")
for metric, values in value.items():
print(metric, ":", round(np.mean(values), 3), "(", round(np.std(values), 3), ")")
print()
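
calculate_means.py only aggregates these per-fold results, printing the mean and standard deviation of each metric over the five folds. As a hedged sketch, the same aggregation could be wrapped in a small reusable helper; the summarise name and the single-entry example call are illustrative, not part of the repository.

import numpy as np

def summarise(results):
    """Print 'metric : mean ( std )' for every model entry, mirroring calculate_means.py."""
    for model_name, metrics in results.items():
        print(model_name, ":")
        for metric, values in metrics.items():
            print(metric, ":", round(float(np.mean(values)), 3),
                  "(", round(float(np.std(values)), 3), ")")
        print()

# Example: the 'ukb xgboost' AUCs above average to roughly 0.883 (std about 0.005).
summarise({'ukb xgboost': {'aucs': [0.8835, 0.8755, 0.8850, 0.8911, 0.8795]}})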
