linear.py
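"""OTM (linear): jointly impute missing data and learn a linear DAG.

Note: this docstring is an editorial summary of the code below, not original
documentation. The script alternates two L-BFGS-B sub-problems inside a
NOTEARS-style augmented Lagrangian: `_xfunc` updates a linear imputation
model for the missing entries, and `_wfunc` updates the weighted adjacency
matrix W; an optimal-transport penalty (`auto_ot`) couples the two.
"""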
import ot
import numpy as np
import scipy.linalg as slin
import scipy.optimize as sopt
from config import get_data
from utils.eval import evaluate, write_result
from tqdm import tqdm
from ot_gradient import auto_ot
def otm(X_init, lambda1, max_iter=100, h_tol=1e-8, rho_max=1e+16, eta=0.01):

    class Sup:
        """Shared state between the imputation and structure sub-problems."""
        def __init__(self, X_init):
            mask = np.isnan(X_init)
            X_init[mask] = 0
            self.mask = mask.astype('float')
            self.X_init = X_init
            self.w = None     # doubled weight vector, shape [2 * d * d]
            self.imps = None  # flattened imputation matrix, shape [d * d]

    global supp
    supp = Sup(X_init)
    def _loss(R):
        """Evaluate the least-squares loss on the residual R = X - XW."""
        loss = 0.5 / R.shape[0] * (R ** 2).sum()
        return loss
    def _h(W):
        """Evaluate value and gradient of the acyclicity constraint."""
        E = slin.expm(W * W)
        h = np.trace(E) - d
        G_h = E.T * W * 2
        return h, G_h
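    # Editorial note: _h implements the NOTEARS acyclicity function
    #     h(W) = tr(exp(W ∘ W)) - d,
    # which is zero iff W corresponds to a DAG; its gradient is
    # 2 * exp(W ∘ W)^T ∘ W, matching G_h above.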
    def _adj(w):
        """Convert doubled variables ([2 * d * d] array) back to the original [d, d] matrix."""
        return (w[:d * d] - w[d * d:]).reshape([d, d])
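    # Editorial note: the doubling trick writes W = W+ - W- with W+, W- >= 0
    # (enforced by the bounds defined below), so the l1 penalty ||W||_1 becomes
    # the smooth linear term lambda1 * w.sum() inside _wfunc.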
    def _wfunc(w):
        """Evaluate value and gradient of the augmented Lagrangian for doubled variables ([2 * d * d] array)."""
        imps = supp.imps.reshape(d, d)
        # Fill the missing entries with the current linear imputation X_init @ imps
        X = supp.X_init * (1 - supp.mask) + (supp.X_init @ imps) * supp.mask
        X = X - np.mean(X, axis=0, keepdims=True)  # centre the data (for l2 loss only)
        W = _adj(w)
        M = X @ W
        R = X - M
        loss = _loss(R)
        h, G_h = _h(W)
        ot_loss, G = auto_ot(X, W, 'w')
        # Objective: least squares + Lagrangian terms + l1 penalty + OT regulariser
        obj = loss + alpha * h + lambda1 * w.sum() + 0.5 * rho * h * h + eta * ot_loss
        # Gradient with respect to W, mapped back to the doubled variables
        G_W = -1.0 / X.shape[0] * X.T @ R
        G_smooth = G_W + (rho * h + alpha) * G_h + eta * G
        g_obj = np.concatenate((G_smooth + lambda1, -G_smooth + lambda1), axis=None)
        return obj, g_obj
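    # Editorial note: the objective above is the standard augmented Lagrangian
    #     F(W) = loss(W; X) + lambda1 * ||W||_1 + alpha * h(W)
    #            + (rho / 2) * h(W)^2 + eta * OT(X, W),
    # with the multiplier alpha and penalty rho updated in the outer loop below.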
    def _xfunc(imps):
        """Evaluate value and gradient of the imputation sub-problem for fixed W."""
        imps = imps.reshape(d, d)
        X = supp.X_init * (1 - supp.mask) + (supp.X_init @ imps) * supp.mask
        X = X - np.mean(X, axis=0, keepdims=True)  # centre the data (for l2 loss only)
        W = _adj(supp.w)
        M = X @ W
        R = X - M
        ot_loss, G = auto_ot(X, W, 'x')
        obj = _loss(R) + eta * ot_loss
        # Gradient with respect to imps: chain rule through
        # X = X_init * (1 - mask) + (X_init @ imps) * mask
        I = np.eye(d)
        g_obj = 1.0 / X.shape[0] * (R @ (I - W.T)) * supp.mask
        g_obj = (supp.X_init.T @ g_obj) + eta * (supp.X_init.T @ G)
        g_obj = g_obj.reshape(-1)
        return obj, g_obj
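    # Editorial note: the imputation is itself linear, X_missing ≈ X_init @ imps,
    # so both sub-problems are smooth and share the same L-BFGS-B solver; only
    # the masked (originally NaN) entries of X are ever replaced.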
    n, d = X_init.shape
    w_est = np.zeros(2 * d * d)       # doubled variables for W
    rho, alpha, h = 1.0, 0.0, np.inf  # augmented Lagrangian state
    imps = np.ones(d * d)             # initial imputation matrix
    # Diagonal of W pinned to zero; off-diagonal doubled variables nonnegative
    wbnds = [(0, 0) if i == j else (0, None) for _ in range(2) for i in range(d) for j in range(d)]
    ibnds = [(None, None)] * imps.shape[0]
    supp.w = w_est
    supp.imps = imps
    for i in range(max_iter):
        print(f'Iteration {i} ...')
        w_new, h_new = None, None
        while rho < rho_max:
            # Alternate: update the imputation for fixed W, then W for fixed imputation
            sol = sopt.minimize(_xfunc, imps, method='L-BFGS-B', jac=True, bounds=ibnds)
            imps_new = supp.imps = sol.x
            sol = sopt.minimize(_wfunc, w_est, method='L-BFGS-B', jac=True, bounds=wbnds)
            w_new = supp.w = sol.x
            print(imps_new.max().round(5), imps_new.min().round(5), w_new.sum().round(5))
            h_new, _ = _h(_adj(w_new))
            if h_new > 0.25 * h:
                rho *= 2  # constraint not shrinking fast enough: increase the penalty
            else:
                break
        imps, w_est, h = imps_new, w_new, h_new
        print(f'Current h={h}')
        alpha += rho * h  # dual ascent on the acyclicity multiplier
        if h <= h_tol or rho >= rho_max:
            print(f'Stopping at h={h} and rho={rho}')
            break
    W_est = _adj(w_est)
    imps = imps.reshape(d, d)
    X_filled = supp.X_init * (1 - supp.mask) + (supp.X_init @ imps) * supp.mask
    return W_est, X_filled, supp.mask
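# A minimal standalone usage sketch (editorial; the random data and sizes here
# are illustrative only -- otm expects NaNs to mark the missing entries):
#
#     rng = np.random.default_rng(0)
#     X = rng.normal(size=(100, 5))
#     X[rng.random(X.shape) < 0.1] = np.nan   # 10% missing at random
#     W_est, X_filled, mask = otm(X, lambda1=0.1, max_iter=10)
#
# W_est is the estimated weighted adjacency matrix, X_filled the imputed data
# matrix, and mask marks which entries were originally missing.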
if __name__ == '__main__':
    import sys
    from dag_methods import Notears

    config_id = int(sys.argv[1])
    graph_type = sys.argv[2]
    sem_type = 'linear'
    dataset, config = get_data(config_id, graph_type, sem_type)
    n, d = dataset.X.shape
    W_est, X_filled, mask = otm(dataset.X, lambda1=0.1,
                                max_iter=10, h_tol=1e-8, rho_max=1e+16, eta=0.01)
    raw_result = evaluate(dataset.B_bin, W_est, threshold=0.3)
    # =============== WRITE GRAPH ===============
    saved_path = 'output/otm_linear.txt'
    write_result(raw_result, config['code'], saved_path)