-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathLeaveOutAlloyCV.py
55 lines (43 loc) · 2.01 KB
/
LeaveOutAlloyCV.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import matplotlib.pyplot as plt
import numpy as np
from sklearn.base import clone
from sklearn.kernel_ridge import KernelRidge
from sklearn.metrics import mean_squared_error

import data_parser
def loacv(model=KernelRidge(alpha=.00518, coef0=1, degree=3, gamma=.518, kernel='laplacian', kernel_params=None),
          datapath="../../DBTT_Data.csv", savepath='../../{}.png',
          X=["N(Cu)", "N(Ni)", "N(Mn)", "N(P)", "N(Si)", "N( C )", "N(log(fluence)", "N(log(flux)", "N(Temp)"],
          Y="delta sigma", alloys=range(1, 60)):
    """Run leave-one-alloy-out cross-validation and plot the per-alloy RMSE.

    For each alloy id in ``alloys``, a fresh copy of ``model`` is fitted on all
    alloys except that one and then used to predict the held-out alloy's rows.
    Alloy ids with no rows in the data are skipped. The mean RMSE is printed,
    and a scatter plot of RMSE versus alloy number (five largest errors
    annotated) is saved to ``savepath.format('Leave out Alloy')``.

    Args:
        model: Estimator template providing ``fit``/``predict``. It is cloned
            for every fold, so the object passed in (including the shared
            default instance) is never fitted or mutated.
        datapath: Path handed to ``data_parser.parse``.
        savepath: Output path with one ``{}`` placeholder, filled with the
            plot title.
        X: Feature column names.
        Y: Target column name.
        alloys: Alloy ids to hold out in turn. Defaults to 1..59, the range
            this script originally hard-coded.

    Raises:
        ValueError: If none of ``alloys`` has any rows in the data.
    """
    alloys = list(alloys)  # materialize so it can be iterated and max()'d
    data = data_parser.parse(datapath)
    data.set_x_features(list(X))  # pass a copy so the default list can't be mutated
    data.set_y_feature(Y)

    alloy_list, rms_list = _loacv_rmse_per_alloy(model, data, alloys)
    if not rms_list:
        raise ValueError("no data found for any of the requested alloys")

    mean_rmse = np.mean(rms_list)
    print('Mean RMSE: ', mean_rmse)
    _plot_loacv(alloy_list, rms_list, mean_rmse, savepath, x_max=max(alloys))


def _loacv_rmse_per_alloy(model, data, alloys):
    """Return parallel lists (alloy ids, RMSEs) for each alloy that has rows in ``data``."""
    alloy_list = []
    rms_list = []
    for alloy in alloys:
        # Collect the held-out rows first so missing alloys are skipped
        # before any (expensive) fitting happens.
        data.remove_all_filters()
        data.add_inclusive_filter("Alloy", '=', alloy)
        x_test = np.array(data.get_x_data())  # copy: the filters change below
        if len(x_test) == 0:
            continue  # alloy does not exist in the data
        y_test = np.array(data.get_y_data()).ravel()

        # Fit a fresh, unfitted copy on every other alloy. safe=False falls
        # back to a deep copy for models that are not sklearn estimators.
        data.remove_all_filters()
        data.add_exclusive_filter("Alloy", '=', alloy)
        fold_model = clone(model, safe=False)
        fold_model.fit(np.asarray(data.get_x_data()), np.asarray(data.get_y_data()).ravel())

        y_pred = fold_model.predict(x_test)
        # sklearn's signature is mean_squared_error(y_true, y_pred).
        rms_list.append(np.sqrt(mean_squared_error(y_test, y_pred)))
        alloy_list.append(alloy)
    return alloy_list, rms_list


def _plot_loacv(alloy_list, rms_list, mean_rmse, savepath, x_max):
    """Scatter RMSE vs. alloy number, annotate the five worst alloys, and save the figure."""
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.set_xticks(np.arange(0, max(alloy_list) + 1, 5))
    ax.scatter(alloy_list, rms_list, color='black', s=10)
    ax.plot((0, x_max), (0, 0), ls="--", c=".3")  # zero-error reference line
    ax.set_xlabel('Alloy Number')
    ax.set_ylabel('RMSE (Mpa)')
    ax.set_title('Leave out Alloy')
    ax.text(.05, .88, 'Mean RMSE: {:.2f}'.format(mean_rmse), fontsize=14, transform=ax.transAxes)
    # The annotation text is passed positionally: the ``s=`` keyword was
    # deprecated in Matplotlib 3.3 (renamed ``text``) and removed afterwards.
    for i in np.argsort(rms_list)[-5:]:
        ax.annotate(str(alloy_list[i]), xy=(alloy_list[i], rms_list[i]))
    fig.savefig(savepath.format(ax.get_title()), dpi=200, bbox_inches='tight')
    plt.close(fig)