-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path7_MLP_L2.py
91 lines (78 loc) · 3.51 KB
/
7_MLP_L2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import pickle
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy.stats import wilcoxon
import tensorflow as tf
from keras.regularizers import L2
from keras.layers import *
# Load the pre-split dataset. The unpacking below requires the pickle to hold
# exactly four objects, in the order (x_train, y_train, x_test, y_test).
# To run on the other dataset instead, open 'raw/lorenz96.pkl' here.
with open('raw/indexes_std.pkl', 'rb') as data_file:
    x_train, y_train, x_test, y_test = pickle.load(data_file)

# Axis 2 of the inputs enumerates the nodes that are tested pairwise below.
n_nodes = x_train.shape[2]

# Result matrices indexed [i, j]. They start as NaN, so any entry the main loop
# skips (it skips i == j) remains NaN.
z_val = np.full((n_nodes, n_nodes), np.nan)
p_val = np.full_like(z_val, np.nan)

# Fixed seeds: NumPy's generator drives the permutations, TensorFlow's global
# seed drives weight initialisation and dropout.
tf.random.set_seed(748697)
rng = np.random.RandomState(23846)

# One callback instance shared by every fit: stop once 'val_loss' has gone
# 60 epochs without improving.
stop_early = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=60)
def _flatten_windows(x):
    """Swap axes 1 and 2 of ``x`` and collapse them into one feature axis.

    Returns a 2-D array with ``x.shape[0]`` rows.
    """
    return np.swapaxes(x, 1, 2).reshape(x.shape[0], -1)


# The unrestricted inputs do not depend on which node is permuted, so they are
# flattened once here rather than on every outer iteration.
x_ur_train = _flatten_windows(x_train)
x_ur_test = _flatten_windows(x_test)
n_features = x_ur_train.shape[1]

# For every ordered pair (i, j), i != j, compare a model that sees node i's
# values permuted (restricted) against one that sees the data untouched
# (unrestricted) when predicting column j of the targets.
for i in range(n_nodes):
    # Restricted inputs: shuffle node i's slice along axis 0 on a copy.
    # rng.shuffle only reorders the first axis, so each row of node i is kept
    # intact but ends up attached to a different sample.
    x_r_train = x_train.copy()
    rng.shuffle(x_r_train[:, :, i])
    x_r_train = _flatten_windows(x_r_train)
    x_r_test = x_test.copy()
    rng.shuffle(x_r_test[:, :, i])
    x_r_test = _flatten_windows(x_r_test)

    for j in range(n_nodes):
        if i == j:
            continue

        # MLP: two L2-regularised 256-unit ReLU layers, 10% dropout, a 64-unit
        # ReLU layer and a single linear output. The input shape is passed as a
        # 1-tuple, which is the form Keras documents for `shape`.
        l0 = Input(shape=(n_features,))
        l1 = Dense(256, activation='relu', kernel_regularizer=L2(0.1))(l0)
        l2 = Dense(256, activation='relu', kernel_regularizer=L2(0.1))(l1)
        l3 = Dropout(0.1)(l2)
        l4 = Dense(64, activation='relu')(l3)
        l5 = Dense(1)(l4)
        r = tf.keras.Model(l0, l5)
        # Same architecture for the unrestricted model, with its own weights.
        ur = tf.keras.models.clone_model(r)

        # NOTE(review): early stopping monitors the loss on the same test split
        # that is scored below - confirm this is intended.
        r.compile(optimizer='adam', loss='mse')
        r.fit(
            x_r_train, y_train[:, j],
            validation_data=(x_r_test, y_test[:, j]),
            epochs=5000, batch_size=1000, callbacks=[stop_early], verbose=0,
        )  # restricted model
        y_r_test_hat = r.predict(x_r_test, batch_size=2000)

        ur.compile(optimizer='adam', loss='mse')
        ur.fit(
            x_ur_train, y_train[:, j],
            validation_data=(x_ur_test, y_test[:, j]),
            epochs=5000, batch_size=1000, callbacks=[stop_early], verbose=0,
        )  # unrestricted model
        y_ur_test_hat = ur.predict(x_ur_test, batch_size=2000)

        # Per-sample squared errors on the test split.
        err_ur = (y_ur_test_hat.flatten() - y_test[:, j]) ** 2
        err_r = (y_r_test_hat.flatten() - y_test[:, j]) ** 2

        # Paired Wilcoxon signed-rank test on the two error vectors.
        # method='approx' is what makes `zstatistic` available on the result.
        # NOTE(review): the default alternative is two-sided, so a small p does
        # not by itself say which model had the larger errors - check the sign
        # of z_val, or confirm whether a one-sided test was intended.
        wilcoxon_results = wilcoxon(x=err_r, y=err_ur, method='approx')
        p, z = wilcoxon_results.pvalue, wilcoxon_results.zstatistic
        p_val[i, j] = p
        z_val[i, j] = z
        print('i =', i, 'j =', j, 'p =', p)
# %% Export results
# Persist both statistic matrices together under the keys 'z_val' and 'p_val'.
# For the other dataset, write to
# 'raw/lorenz96_conditional_MLP_L2_0.1_wilcoxon_stats.pkl' instead.
wilcoxon_stats = {'z_val': z_val, 'p_val': p_val}
with open('raw/indexes_conditional_MLP_L2_0.1_wilcoxon_stats.pkl', 'wb') as stats_file:
    pickle.dump(wilcoxon_stats, stats_file)

# Names used below as the tick labels of the heatmap.
with open('raw/indexes_names.pkl', 'rb') as names_file:
    col_names = pickle.load(names_file)
# %% Causality heatmap
# Tick styling has to be set before the axes is created to affect it, and both
# keys must live in a single dict: seaborn treats a second positional dict as
# `rc` and drops any of its keys that are missing from the style dict.
sns.set_style({'xtick.bottom': True, 'ytick.left': True})

fig, ax = plt.subplots(figsize=(10, 8))

# Hide the diagonal: the i == j entries are never computed and are NaN.
mask = np.zeros_like(p_val, dtype=bool)
mask[np.diag_indices_from(mask)] = True

# Colour scale is clipped to [0, 0.1]; every cell is annotated with its
# p-value to two decimals. Draw explicitly on `ax` rather than relying on the
# current-axes state.
sns.heatmap(p_val, mask=mask, square=True, linewidths=.5, cmap='coolwarm',
            vmin=0, vmax=0.1, annot=True, fmt='.2f', ax=ax)

# add the column names as labels
# Rows are indexed by i (the permuted node) and columns by j (the predicted
# target), matching how p_val[i, j] is filled.
ax.set_yticklabels(col_names, rotation=0)
ax.set_xticklabels(col_names, rotation=90)
ax.set_ylabel('Cause')
ax.set_xlabel('Effect')
fig.subplots_adjust(bottom=0.15, top=0.95)

fig.savefig('results/indexes_conditional_MLP_L2_0.1_wilcoxon.eps')
# For the other dataset, save to
# 'results/lorenz96_conditional_MLP_L2_0.1_wilcoxon.eps' instead.
plt.close(fig)