## causalMR.R
## Written by Ruilin Li
## ruilinli@stanford.edu
## Rivas and Tibshirani Labs
## 02.12.2020
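##
## Assumed model, inferred from the code below (not documented in the original script):
## an L1-penalized, inverse-variance-weighted regression of outcome summary
## statistics on exposure summary statistics,
##   minimize over theta0, thetaL:
##     sum_j a_j * (beta_Y[j] - theta0[j] - (beta_X %*% thetaL)[j])^2 + lambda * sum_j |theta0[j]|
## with weights a_j = 1/se[j]^2, solved by proximal gradient descent.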
# Soft-thresholding operator: shrink each entry of x toward zero by thresh
soft_thresh = function(x, thresh){
  sign(x) * pmax(0, abs(x) - thresh)
}

# Declare convergence when both iterates move by less than eps (sup norm)
converged = function(x, y, x_prev, y_prev, eps){
  x_converge = max(abs(x - x_prev)) < eps
  y_converge = max(abs(y - y_prev)) < eps
  return(x_converge & y_converge)
}
# Multivariate case: proximal gradient descent for
#   min_{x, y} sum_j a_j * (b_j - x_j - (C y)_j)^2 + lambda * ||x||_1
prox_grad = function(a, b, C, lambda, t=0.01, x=NULL, y=NULL, eps=0.0001){
  J = length(a)
  # Initialize iterates at zero unless warm starts are supplied
  if (is.null(x)){x = rep(0, J)}
  if (is.null(y)){y = rep(0, dim(C)[2])}
  # Pre-compute C^T A C with A = diag(a)
  cac = sweep(C, 1, sqrt(a), FUN="*")
  cac = crossprod(cac)
  while(TRUE){
    x_prev = x
    y_prev = y
    # Gradients of the weighted least-squares term
    # (the step size argument t shadows base::t, but t(C) in call position still resolves to the transpose)
    dy = 2*cac %*% y - 2*t(C) %*% (a*(b - x))
    dx = 2*a*x - 2*a*(b - C %*% y)
    # Proximal (soft-thresholded) step for the penalized x; plain gradient step for y
    x = soft_thresh(x - t*dx, t*lambda)
    y = y - t*dy
    if(converged(x, y, x_prev, y_prev, eps)){break}
  }
  return(list(theta0=x, thetaL=y))
}
# Wrapper: weights each variant by its inverse variance 1/se^2;
# if standardizese = TRUE, standard errors are first rescaled by min(se)
wrapper = function(se, beta_Y, beta_X, lambda, t=0.01, x=NULL, y=NULL, eps=1e-7, standardizese = FALSE){
  if(standardizese){
    prox_grad(1/((se/min(se))^2), beta_Y, as.matrix(beta_X), lambda, t, x, y, eps)
  } else {
    prox_grad(1/(se^2), beta_Y, as.matrix(beta_X), lambda, t, x, y, eps)
  }
}
# Simple test case just to check that the algorithm converges
# J = 20
# p = 2
# se = rnorm(J)^2+1
# beta_Y = rnorm(J)
# beta_X = matrix(rnorm(J*p), J, p)
# lambda = 0.1
# wrapper(se, beta_Y, beta_X, lambda)
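## A minimal usage sketch (added for illustration, not part of the original script),
## assuming the test inputs above are uncommented:
# fit = wrapper(se, beta_Y, beta_X, lambda)
# fit$theta0   # per-variant direct effects (sparse for large lambda)
# fit$thetaL   # causal effect estimates, one per exposure column of beta_X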