forked from aj-grant/mvmr-measurement-error
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmrest_me.R
126 lines (112 loc) · 4.52 KB
/
mrest_me.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
################################################################################
#Functions to implement the MLE approach for MVMR with measurement error
################################################################################
# mrest_me_cor: maximum-likelihood estimate of the MVMR causal effects allowing
# for measurement error in the variant-exposure associations, where the
# exposure errors of each variant are correlated according to `corX`.
#
# Arguments:
#   mrob     an object created using the mr_mvinput function in the
#            MendelianRandomization package; its betaX / betaXse (p x K) and
#            betaY / betaYse (length p) slots are used.
#   corX     K x K correlation matrix used to build each variant's exposure
#            error covariance diag(betaXse[j, ]) %*% corX %*% diag(betaXse[j, ]).
#            If unknown, use mrest_me, which uses diagonal (uncorrelated)
#            error covariances.
#   max_iter maximum number of alternating-update iterations per start.
#   no_ini   number of random starting points for the true exposure
#            associations.
#   tol      convergence tolerance on the change in log-likelihood between
#            consecutive iterations.
#
# Returns a list with
#   thest  length-K estimate from the start whose log-likelihood trace reaches
#          the highest value,
#   l      max_iter x no_ini matrix of log-likelihood values per iteration
#          (rows after a start has converged stay NA),
#   Var    K x K sandwich variance estimate for thest.
mrest_me_cor <- function(mrob, corX, max_iter = 100, no_ini = 1, tol = 1e-4) {
  stopifnot(max_iter >= 1, no_ini >= 1)
  bxhat <- as.matrix(mrob@betaX)
  byhat <- mrob@betaY
  sebx <- as.matrix(mrob@betaXse)
  seby <- mrob@betaYse
  p <- length(byhat)
  K <- ncol(bxhat)
  # Outcome precision weights; nrow = p keeps this a matrix even when p == 1.
  W <- diag(seby^-2, nrow = p)
  # Per-variant exposure error covariance; nrow = K keeps diag() correct
  # for a single exposure (diag(scalar) would build an identity instead).
  SigX <- lapply(seq_len(p), function(j) {
    S1 <- diag(sebx[j, ], nrow = K)
    S1 %*% corX %*% S1
  })
  # Loop-invariant pieces of the bxtilde update.
  prior_prec <- lapply(SigX, solve)
  prior_shift <- lapply(seq_len(p), function(j) solve(SigX[[j]], bxhat[j, ]))

  # Profile log-likelihood (up to a constant) of theta.
  loglik <- function(theta) {
    -0.5 * sum(vapply(seq_len(p), function(j) {
      resid <- byhat[j] - sum(bxhat[j, ] * theta)
      resid^2 / as.numeric(seby[j]^2 + t(theta) %*% SigX[[j]] %*% theta)
    }, numeric(1)))
  }

  l <- matrix(nrow = max_iter, ncol = no_ini)
  thest <- matrix(nrow = K, ncol = no_ini)
  for (k in seq_len(no_ini)) {
    # Random start for the true exposure associations. mv_norm() is defined
    # elsewhere; assumed to be mv_norm(n, mean, Sigma) drawing multivariate
    # normals -- confirm. matrix(..., ncol = p) keeps the K == 1 case a matrix.
    bxtilde <- t(matrix(
      sapply(seq_len(p), function(j) mv_norm(1, bxhat[j, ], SigX[[j]])),
      ncol = p
    ))
    for (i in seq_len(max_iter)) {
      # Weighted least squares for theta given the current bxtilde.
      thest[, k] <- solve(
        t(bxtilde) %*% W %*% bxtilde,
        t(bxtilde) %*% W %*% byhat
      )
      th_k <- thest[, k]
      l[i, k] <- loglik(th_k)
      # Update each variant's true exposure associations given theta.
      for (j in seq_len(p)) {
        bxtilde[j, ] <- solve(
          tcrossprod(th_k) / seby[j]^2 + prior_prec[[j]],
          byhat[j] * th_k / seby[j]^2 + prior_shift[[j]]
        )
      }
      # Compare within this start's own column (not linear index into l).
      if (i > 1 && abs(l[i, k] - l[i - 1, k]) < tol) {
        break
      }
    }
  }

  # Pick the start whose log-likelihood trace reaches the highest value.
  k0 <- which.max(apply(l, 2, max, na.rm = TRUE))
  th <- thest[, k0]

  # Sandwich variance from the standardized residuals t_j = e_j / sqrt(v_j).
  sig_th <- lapply(SigX, function(S) drop(S %*% th))
  v <- vapply(
    seq_len(p),
    function(j) seby[j]^2 + sum(th * sig_th[[j]]),
    numeric(1)
  )
  e <- as.numeric(byhat - bxhat %*% th)
  tstat <- e / sqrt(v)
  # K x p matrix of first derivatives d t_j / d theta.
  dt <- matrix(vapply(seq_len(p), function(j) {
    (-v[j] * bxhat[j, ] - e[j] * sig_th[[j]]) / v[j]^(3 / 2)
  }, numeric(K)), nrow = K)
  B <- tcrossprod(dt)
  # Sum over variants of t_j times the Hessian of t_j:
  # (bx s' + s bx' - e SigX) / v^(3/2) + 3 e s s' / v^(5/2), s = SigX theta.
  a <- Reduce(`+`, lapply(seq_len(p), function(j) {
    s <- sig_th[[j]]
    bx <- bxhat[j, ]
    d2 <- (tcrossprod(bx, s) + tcrossprod(s, bx) - e[j] * SigX[[j]]) /
      v[j]^(3 / 2) +
      3 * e[j] * tcrossprod(s) / v[j]^(5 / 2)
    tstat[j] * d2
  }))
  A <- B + a
  Var <- solve(A, B) %*% t(solve(A))
  list(thest = th, l = l, Var = Var)
}
# mrest_me: maximum-likelihood estimate of the MVMR causal effects allowing for
# measurement error in the variant-exposure associations, treating the
# exposure errors of each variant as uncorrelated (diagonal covariance
# diag(betaXse[j, ]^2)).
#
# Arguments:
#   mrob     an object created using the mr_mvinput function in the
#            MendelianRandomization package; its betaX / betaXse (p x K) and
#            betaY / betaYse (length p) slots are used.
#   max_iter maximum number of alternating-update iterations per start.
#   no_ini   number of random starting points for the true exposure
#            associations.
#   tol      convergence tolerance on the change in log-likelihood between
#            consecutive iterations.
#
# Returns a list with
#   thest  length-K estimate from the start whose log-likelihood trace reaches
#          the highest value,
#   l      max_iter x no_ini matrix of log-likelihood values per iteration
#          (rows after a start has converged stay NA),
#   Var    K x K sandwich variance estimate for thest.
mrest_me <- function(mrob, max_iter = 100, no_ini = 1, tol = 1e-4) {
  stopifnot(max_iter >= 1, no_ini >= 1)
  bxhat <- as.matrix(mrob@betaX)
  byhat <- mrob@betaY
  sebx <- as.matrix(mrob@betaXse)
  seby <- mrob@betaYse
  p <- length(byhat)
  K <- ncol(bxhat)
  # Outcome precision weights; nrow = p keeps this a matrix even when p == 1.
  W <- diag(seby^-2, nrow = p)
  # Per-variant exposure error covariance (uncorrelated errors).
  SigX <- lapply(seq_len(p), function(j) diag(sebx[j, ]^2, nrow = K))
  # Loop-invariant pieces of the bxtilde update.
  prior_prec <- lapply(SigX, solve)
  prior_shift <- lapply(seq_len(p), function(j) solve(SigX[[j]], bxhat[j, ]))

  # Profile log-likelihood (up to a constant) of theta.
  loglik <- function(theta) {
    -0.5 * sum(vapply(seq_len(p), function(j) {
      resid <- byhat[j] - sum(bxhat[j, ] * theta)
      resid^2 / as.numeric(seby[j]^2 + t(theta) %*% SigX[[j]] %*% theta)
    }, numeric(1)))
  }

  l <- matrix(nrow = max_iter, ncol = no_ini)
  thest <- matrix(nrow = K, ncol = no_ini)
  for (k in seq_len(no_ini)) {
    # Random start for the true exposure associations. mv_norm() is defined
    # elsewhere; assumed to be mv_norm(n, mean, Sigma) drawing multivariate
    # normals -- confirm. matrix(..., ncol = p) keeps the K == 1 case a matrix.
    bxtilde <- t(matrix(
      sapply(seq_len(p), function(j) mv_norm(1, bxhat[j, ], SigX[[j]])),
      ncol = p
    ))
    for (i in seq_len(max_iter)) {
      # Weighted least squares for theta given the current bxtilde.
      thest[, k] <- solve(
        t(bxtilde) %*% W %*% bxtilde,
        t(bxtilde) %*% W %*% byhat
      )
      th_k <- thest[, k]
      l[i, k] <- loglik(th_k)
      # Update each variant's true exposure associations given theta.
      for (j in seq_len(p)) {
        bxtilde[j, ] <- solve(
          tcrossprod(th_k) / seby[j]^2 + prior_prec[[j]],
          byhat[j] * th_k / seby[j]^2 + prior_shift[[j]]
        )
      }
      # Compare within this start's own column (not linear index into l).
      if (i > 1 && abs(l[i, k] - l[i - 1, k]) < tol) {
        break
      }
    }
  }

  # Pick the start whose log-likelihood trace reaches the highest value.
  k0 <- which.max(apply(l, 2, max, na.rm = TRUE))
  th <- thest[, k0]

  # Sandwich variance from the standardized residuals t_j = e_j / sqrt(v_j).
  sig_th <- lapply(SigX, function(S) drop(S %*% th))
  v <- vapply(
    seq_len(p),
    function(j) seby[j]^2 + sum(th * sig_th[[j]]),
    numeric(1)
  )
  e <- as.numeric(byhat - bxhat %*% th)
  tstat <- e / sqrt(v)
  # K x p matrix of first derivatives d t_j / d theta.
  dt <- matrix(vapply(seq_len(p), function(j) {
    (-v[j] * bxhat[j, ] - e[j] * sig_th[[j]]) / v[j]^(3 / 2)
  }, numeric(K)), nrow = K)
  B <- tcrossprod(dt)
  # Sum over variants of t_j times the Hessian of t_j:
  # (bx s' + s bx' - e SigX) / v^(3/2) + 3 e s s' / v^(5/2), s = SigX theta.
  a <- Reduce(`+`, lapply(seq_len(p), function(j) {
    s <- sig_th[[j]]
    bx <- bxhat[j, ]
    d2 <- (tcrossprod(bx, s) + tcrossprod(s, bx) - e[j] * SigX[[j]]) /
      v[j]^(3 / 2) +
      3 * e[j] * tcrossprod(s) / v[j]^(5 / 2)
    tstat[j] * d2
  }))
  A <- B + a
  Var <- solve(A, B) %*% t(solve(A))
  list(thest = th, l = l, Var = Var)
}
# ll_me: objective for minimisation -- half the sum over variants of the
# squared residual byhat[j] - bxhat[j, ] . theta divided by
# seby[j]^2 + theta' SigX[[j]] theta. This is the negative of the
# log-likelihood traced in the `l` output of mrest_me / mrest_me_cor.
#
# Arguments:
#   theta  length-K vector of causal effects.
#   byhat  length-p vector of variant-outcome associations.
#   bxhat  p x K matrix of variant-exposure associations.
#   seby   length-p vector of standard errors of byhat.
#   SigX   list of p K x K exposure error covariance matrices.
# Returns a single number (0 when there are no variants).
ll_me <- function(theta, byhat, bxhat, seby, SigX) {
  contrib <- vapply(seq_along(byhat), function(j) {
    resid <- byhat[j] - sum(bxhat[j, ] * theta)
    resid^2 / as.numeric(seby[j]^2 + t(theta) %*% SigX[[j]] %*% theta)
  }, numeric(1))
  0.5 * sum(contrib)
}