-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathplot_results.py
70 lines (50 loc) · 1.98 KB
/
plot_results.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import matplotlib.pyplot as plt
import json
import os
import numpy as np
# Load the saved training histories of the classification model
# (class_feature_compressed_auto / model_1_A): one JSON file per run,
# model_history_1.json ... model_history_5.json, each parsed with json.load
# and appended to `history` in run order (history[0] is run 1).
HISTORY_DIR = os.path.join('saved_models', 'class_feature_compressed_auto', 'model_1_A')
NUM_RUNS = 5
history = []
for run_idx in range(1, NUM_RUNS + 1):
    history_file = os.path.join(HISTORY_DIR, 'model_history_{}.json'.format(run_idx))
    with open(history_file, 'r', encoding='utf-8') as f:
        history.append(json.load(f))
# Disabled code, kept for reference: mean learning curves across all runs.
# It averages the first `min_epoch` entries of each metric over every run in
# `history` and pads each averaged curve with hard-coded values at the start
# and end; `epochs` would then be the matching x-axis (0, 2, 4, ...).
# NOTE(review): the origin of the padding constants is not documented in this
# file -- confirm them before re-enabling.
#
# min_epoch = 16
# # Get mean Val_loss/Loss
# mean_val_loss = np.hstack((np.array(0.479876),
#                            np.mean([vl['val_loss'][0:min_epoch] for vl in history], axis=0),
#                            np.array([0.4191876, 0.41567])))
# mean_loss = np.hstack((np.array(0.479876),
#                        np.mean([vl['loss'][0:min_epoch] for vl in history], axis=0),
#                        np.array([0.4219876, 0.42287])))
# # Get Mean Val_acc/acc
# mean_val_acc = np.hstack((np.array(0.769876),
#                           np.mean([vl['val_acc'][0:min_epoch] for vl in history], axis=0),
#                           np.array([0.8310, 0.8305])))
# mean_acc = np.hstack((np.array(0.759876),
#                       np.mean([vl['acc'][0:min_epoch] for vl in history], axis=0),
#                       np.array([0.8305, 0.8301])))
# epochs = [2*i for i in range(min_epoch+3)]
def _plot_metric(run_history, keys, ylabel, legend_loc, path):
    """Plot metric curves from a single run's history, save and show them.

    Opens a new 16x16-inch figure and draws ``run_history[key]`` for each key
    against its list index on the x-axis. It then styles the axes, saves the
    figure to ``path`` and finally calls ``plt.show()``.

    Args:
        run_history: mapping loaded from a ``model_history_*.json`` file; each
            value used here is a sequence of numbers (presumably one value per
            epoch -- confirm against the training script).
        keys: metric names to plot; also used, in the same order, as the
            legend entries.
        ylabel: text for the y-axis label.
        legend_loc: matplotlib legend location string, e.g. ``'lower right'``.
        path: output file passed to ``plt.savefig``; its parent directory is
            created first if it does not exist.
    """
    # plt.savefig does not create missing directories; without this the save
    # fails with FileNotFoundError when the output folder is absent.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    plt.figure(figsize=(16, 16))
    for key in keys:
        plt.plot(run_history[key])
    plt.legend(list(keys), loc=legend_loc,
               prop={'size': 30})
    plt.xlabel('Epoch', fontsize=30)
    plt.ylabel(ylabel, fontsize=30)
    plt.xticks(fontsize=25)
    plt.yticks(fontsize=25)
    plt.grid()
    plt.savefig(path, bbox_inches='tight')
    plt.show()


# Training vs. validation accuracy of the first loaded run.
figure_path = 'Figures/accuracy.pdf'
_plot_metric(history[0], ('acc', 'val_acc'), 'Accuracy', 'lower right', figure_path)
# ----------------------------------
# Training vs. validation loss of the first loaded run.
figure_path = 'Figures/loss.pdf'
_plot_metric(history[0], ('loss', 'val_loss'), 'Categorical Cross-Entropy',
             'upper right', figure_path)