Skip to content

Commit

Permalink
Merge pull request #778 from stan-dev/loo-moment-match
Browse files Browse the repository at this point in the history
Add moment-matching support to `$loo()` method
  • Loading branch information
andrjohns authored Jun 27, 2023
2 parents 83bbc35 + a6eb030 commit dbf41cd
Show file tree
Hide file tree
Showing 5 changed files with 363 additions and 5 deletions.
41 changes: 38 additions & 3 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -1398,8 +1398,11 @@ CmdStanMCMC <- R6::R6Class(
#' but will result in a warning from the \pkg{loo} package.
#' * If `r_eff` is anything else, that object will be passed as the `r_eff`
#' argument to [loo::loo.array()].
#' @param moment_match (boolean) Whether to use a moment-matching correction for
#' for problematic observations.
#' @param ... Other arguments (e.g., `cores`, `save_psis`, etc.) passed to
#' [loo::loo.array()].
#' [loo::loo.array()] or [loo::loo_moment_match.default()]
#' (if `moment_match` = `TRUE` is set).
#'
#' @return The object returned by [loo::loo.array()].
#'
Expand All @@ -1416,7 +1419,7 @@ CmdStanMCMC <- R6::R6Class(
#' print(loo_result)
#' }
#'
loo <- function(variables = "log_lik", r_eff = TRUE, ...) {
loo <- function(variables = "log_lik", r_eff = TRUE, moment_match = FALSE, ...) {
require_suggested_package("loo")
LLarray <- self$draws(variables, format = "draws_array")
if (is.logical(r_eff)) {
Expand All @@ -1427,7 +1430,39 @@ loo <- function(variables = "log_lik", r_eff = TRUE, ...) {
r_eff <- NULL
}
}
loo::loo.array(LLarray, r_eff = r_eff, ...)

if (moment_match == TRUE) {
# Moment-matching requires log-prob, constrain, and unconstrain methods
if (is.null(private$model_methods_env_$model_ptr)) {
self$init_model_methods()
}

suppressWarnings(loo_result <- loo::loo.array(LLarray, r_eff = r_eff, ...))

log_lik_i <- function(x, i, parameter_name = "log_lik", ...) {
ll_array <- x$draws(variables = parameter_name, format = "draws_array")[,,i]
# draws_array types don't drop the last dimension when it's 1, so we do this manually
attr(ll_array, "dim") <- attributes(ll_array)$dim[1:2]
ll_array
}

log_lik_i_upars <- function(x, upars, i, parameter_name = "log_lik", ...) {
apply(upars, 1, \(up_i) { x$constrain_variables(up_i)[[parameter_name]][i] })
}

loo::loo_moment_match.default(
x = self,
loo = loo_result,
post_draws = \(x, ...) { x$draws(format = "draws_matrix") },
log_lik_i = log_lik_i,
unconstrain_pars = \(x, pars, ...) { do.call(rbind, lapply(x$unconstrain_draws(), \(chain) { do.call(rbind, chain) })) },
log_prob_upars = \(x, upars, ...) { apply(upars, 1, x$log_prob) },
log_lik_i_upars = log_lik_i_upars,
...
)
} else {
loo::loo.array(LLarray, r_eff = r_eff, ...)
}
}
CmdStanMCMC$set("public", name = "loo", value = loo)

Expand Down
8 changes: 6 additions & 2 deletions man/fit-method-loo.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

272 changes: 272 additions & 0 deletions tests/testthat/resources/data/loo_moment_match.data.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
{
"N": 262,
"K": 3,
"x": [
[17.549928774784245, 1, 0],
[18.200274723201296, 1, 0],
[1.2922847983320085, 1, 0],
[1.7320508075688772, 1, 0],
[1.4142135623730951, 1, 0],
[0, 1, 0],
[8.3666002653407556, 1, 0],
[8.0349237706402672, 1, 0],
[1, 0, 0],
[3.7416573867739413, 0, 0],
[11.757976016304847, 0, 0],
[4, 0, 0],
[9.8488578017961039, 0, 0],
[9.8994949366116654, 0, 0],
[6.6332495807107996, 0, 0],
[21.213203435596427, 0, 0],
[6.0555759428810738, 0, 0],
[8.6602540378443873, 0, 0],
[1.4142135623730951, 0, 0],
[18.506755523321747, 0, 0],
[4.7434164902525691, 0, 0],
[2, 0, 0],
[9.6953597148326587, 0, 0],
[11.494781424629178, 0, 0],
[8.9442719099991592, 0, 0],
[9.7467943448089631, 0, 0],
[5.8420886675914119, 0, 0],
[4.358898943540674, 0, 0],
[5, 0, 0],
[7.5828754440515507, 0, 0],
[1.7635192088548397, 0, 0],
[3.2403703492039302, 0, 0],
[6.8556546004010439, 0, 0],
[1.6217274740226855, 0, 0],
[14.282856857085701, 0, 0],
[0, 0, 0],
[0, 0, 0],
[15.692354826475215, 0, 0],
[8.730406634286858, 0, 0],
[3.872983346207417, 0, 0],
[3.8574603043971822, 0, 0],
[7.1414284285428504, 0, 0],
[3.872983346207417, 0, 0],
[7.245688373094719, 0, 0],
[3.7080992435478315, 0, 0],
[1.0816653826391966, 0, 0],
[1, 0, 0],
[1.4142135623730951, 0, 0],
[0, 0, 0],
[1.4142135623730951, 0, 0],
[1.4142135623730951, 0, 0],
[5.4772255750516612, 0, 0],
[2.6457513110645907, 0, 0],
[4.2426406871192848, 0, 0],
[1.4142135623730951, 0, 0],
[16.30950643030009, 0, 0],
[13.19090595827292, 1, 0],
[15.874507866387544, 1, 0],
[0.93808315196468595, 1, 0],
[19.300259065618782, 1, 0],
[6.4807406984078604, 1, 0],
[4.358898943540674, 1, 0],
[16.217274740226856, 1, 0],
[5.8309518948453007, 1, 0],
[1.3228756555322954, 1, 0],
[14.611639196202457, 1, 0],
[3.6235341863986879, 1, 0],
[12.409673645990857, 1, 0],
[4.5825756949558398, 1, 0],
[14.832396974191326, 1, 0],
[6.1903150162168643, 1, 0],
[18.774983355518586, 1, 0],
[2.6457513110645907, 1, 0],
[2, 1, 0],
[7.713624310270756, 1, 0],
[2.7386127875258306, 1, 0],
[10.62449998823474, 1, 0],
[13.114877048604001, 1, 0],
[3.6055512754639891, 1, 0],
[4.2426406871192848, 1, 0],
[0, 1, 0],
[5.196152422706632, 1, 0],
[12.165525060596439, 1, 0],
[5.6568542494923806, 1, 0],
[5.2915026221291814, 1, 0],
[0, 1, 0],
[5.2915026221291814, 1, 0],
[0, 1, 1],
[3.7416573867739413, 1, 1],
[2.2360679774997898, 1, 1],
[0, 1, 1],
[10.198039027185569, 1, 1],
[5.196152422706632, 1, 1],
[11.489125293076057, 1, 1],
[16.06237840420901, 1, 1],
[1, 1, 1],
[1.4142135623730951, 1, 1],
[2.4494897427831779, 1, 1],
[1.7320508075688772, 1, 1],
[1.7320508075688772, 1, 1],
[0, 1, 1],
[0, 1, 1],
[0, 1, 1],
[1.1180339887498949, 1, 1],
[0, 1, 1],
[4, 1, 1],
[8.2462112512353212, 1, 1],
[1, 1, 1],
[4.2426406871192848, 1, 1],
[11.120701416727274, 1, 1],
[1.4142135623730951, 1, 1],
[9.0829510622924747, 1, 1],
[0, 1, 1],
[1.1180339887498949, 1, 1],
[13.076696830622021, 1, 1],
[9.5854055730574075, 1, 1],
[2.2360679774997898, 1, 1],
[2.6457513110645907, 1, 1],
[0, 1, 1],
[2, 1, 1],
[7.3314391493075899, 1, 1],
[11.749893616539683, 1, 1],
[1, 1, 1],
[1.1180339887498949, 1, 1],
[5.2915026221291814, 1, 1],
[3.872983346207417, 1, 1],
[0.93808315196468595, 1, 1],
[0, 1, 1],
[5.7445626465380286, 1, 1],
[11.672617529928752, 1, 1],
[11.291589790636214, 1, 1],
[1.4142135623730951, 1, 1],
[1.4142135623730951, 1, 1],
[0, 1, 1],
[1.7320508075688772, 1, 1],
[6.7823299831252681, 1, 1],
[8.2462112512353212, 1, 1],
[0, 1, 1],
[7, 1, 1],
[5.2086466572421672, 1, 1],
[6.7453687816160208, 1, 0],
[0, 1, 0],
[2, 1, 0],
[0, 1, 0],
[1, 1, 0],
[0, 1, 0],
[0, 1, 0],
[3.2403703492039302, 1, 0],
[0, 1, 0],
[1.7320508075688772, 1, 0],
[0, 1, 0],
[0, 1, 0],
[0, 1, 0],
[1.9364916731037085, 1, 0],
[0, 1, 0],
[0, 1, 0],
[0, 0, 0],
[0.93808315196468595, 0, 0],
[1.4142135623730951, 1, 0],
[9, 1, 0],
[5.5677643628300215, 1, 0],
[3.1622776601683795, 1, 0],
[3.1984371183438953, 1, 0],
[0, 0, 0],
[3.6055512754639891, 1, 0],
[1, 1, 0],
[5.7662812973353983, 1, 0],
[0, 0, 0],
[7.2801098892805181, 0, 0],
[2.2360679774997898, 1, 0],
[12.529964086141668, 1, 0],
[4.8301138702933288, 1, 0],
[2.4494897427831779, 1, 0],
[3.1622776601683795, 1, 0],
[10, 1, 0],
[7.416198487095663, 1, 0],
[0, 1, 0],
[4.0311288741492746, 1, 0],
[2.7892651361962706, 1, 0],
[7.2801098892805181, 1, 0],
[1.4142135623730951, 1, 0],
[8.5440037453175304, 1, 0],
[0, 1, 0],
[0, 1, 0],
[0, 1, 0],
[1.7832554500127009, 1, 0],
[2.2360679774997898, 0, 0],
[1.7320508075688772, 0, 0],
[0, 0, 0],
[3.1622776601683795, 0, 0],
[1, 0, 0],
[0, 1, 0],
[1, 1, 0],
[3.4641016151377544, 1, 0],
[4.1231056256176606, 1, 0],
[4, 1, 0],
[1.5811388300841898, 1, 0],
[0, 1, 0],
[4.6776062254105994, 1, 0],
[13.152946437965905, 1, 0],
[10.535653752852738, 1, 0],
[5.9160797830996161, 1, 0],
[0, 1, 0],
[1.7320508075688772, 1, 0],
[1.4491376746189439, 1, 0],
[0, 1, 0],
[7, 1, 0],
[1.1180339887498949, 1, 0],
[1, 1, 0],
[1, 1, 0],
[0, 1, 0],
[0, 1, 0],
[7.3484692283495345, 1, 0],
[2.0493901531919199, 1, 0],
[7.1589105316381767, 1, 0],
[5.4772255750516612, 1, 0],
[14, 0, 0],
[0, 0, 0],
[1.4142135623730951, 0, 0],
[1.2247448713915889, 0, 0],
[9.8107084351742913, 0, 0],
[15.540270267920054, 0, 0],
[11.832159566199232, 0, 0],
[4.2426406871192848, 0, 0],
[1.7320508075688772, 0, 0],
[0.73484692283495345, 0, 0],
[9.0553851381374173, 0, 0],
[4.358898943540674, 0, 0],
[4.3301270189221936, 0, 0],
[7.1414284285428504, 0, 0],
[0, 0, 0],
[1, 0, 0],
[0, 0, 0],
[2.3323807579381204, 0, 0],
[0, 0, 0],
[1.7320508075688772, 0, 1],
[0, 0, 1],
[0, 0, 1],
[0, 0, 1],
[0, 0, 1],
[5.3619026473818039, 0, 1],
[0, 0, 1],
[1.7320508075688772, 0, 1],
[1.4142135623730951, 0, 1],
[11.61895003862225, 0, 1],
[0, 0, 1],
[8.2613558209291522, 0, 1],
[1, 0, 1],
[0, 0, 1],
[0, 0, 1],
[0, 0, 1],
[1, 0, 1],
[1, 0, 1],
[1.5811388300841898, 0, 1],
[7.1589105316381767, 0, 1],
[3.6235341863986879, 0, 1],
[0, 0, 1],
[0, 0, 1],
[0, 0, 1],
[0, 0, 1],
[0, 0, 1],
[0, 0, 1]
],
"y": [153, 127, 7, 7, 0, 0, 73, 24, 2, 2, 0, 21, 0, 179, 136, 104, 2, 5, 1, 203, 32, 1, 135, 59, 29, 120, 44, 1, 2, 193, 13, 37, 2, 0, 3, 0, 0, 15, 11, 19, 0, 19, 4, 122, 48, 0, 0, 3, 0, 9, 0, 0, 0, 12, 0, 357, 11, 60, 0, 159, 50, 48, 178, 4, 6, 0, 33, 127, 4, 63, 88, 5, 0, 0, 62, 4, 150, 38, 0, 3, 1, 14, 77, 42, 21, 1, 45, 0, 0, 0, 0, 0, 183, 28, 49, 1, 0, 0, 3, 0, 0, 0, 0, 18, 0, 0, 5, 0, 19, 5, 0, 27, 0, 0, 77, 1, 3, 2, 0, 0, 22, 102, 0, 0, 0, 0, 0, 0, 0, 4, 12, 2, 0, 0, 1, 0, 40, 0, 1, 2, 27, 0, 2, 0, 0, 0, 0, 3, 1, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 53, 69, 15, 0, 2, 4, 6, 8, 0, 0, 0, 18, 38, 0, 2, 18, 34, 1, 109, 5, 15, 0, 64, 0, 1, 0, 1, 3, 5, 7, 18, 1, 0, 0, 3, 3, 0, 19, 0, 8, 26, 50, 15, 0, 19, 5, 17, 121, 1, 0, 0, 0, 0, 4, 1, 14, 1, 25, 0, 14, 0, 59, 243, 80, 69, 14, 9, 38, 37, 48, 293, 7, 10, 19, 24, 91, 1, 0, 0, 0, 0, 148, 3, 26, 12, 77, 0, 7, 0, 1, 0, 17, 0, 7, 11, 6, 50, 1, 0, 0, 0, 171, 8],
"outcome_offset": [-0.22314355131420971, -0.51082562376599072, 0, 0, 0.13353139262452005, 0, -0.22314355131420971, 0.13353139262452005, 0, 0.13353139262452005, 0, 0, 0, -0.22314355131420971, 0, -0.22314355131420971, -0.22314355131420971, 0, 0, 0, 0, 0, 0, -0.1541506798272585, 0, 0, -0.22314355131420971, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -0.22314355131420971, 0.45198512374305638, 0.35667494393873334, 0, 0, 0, -0.22314355131420971, 0, 0, 0, 0, 0, 0, 0.13353139262452005, 0.13353139262452005, 0, 0, -0.22314355131420971, 0, -0.089612158689687402, 0, 0.25131442828090944, 0, 0, -0.25951119548508511, 0, 0.35667494393873334, 0, -0.25951119548508511, 0, 0, 0, 0, 0, -0.51082562376599072, 0, 0, 0, 0, 0.25131442828090944, 0, 0, 0, 0, 0, 0, 0.13353139262452005, 0, 0, 0, -0.22314355131420971, -0.22314355131420971, 0, 0, -0.37729423114146754, 0, 0, 0, 0, 0, 0, -0.22314355131420971, -0.1541506798272585, 0, 0, 0, 0, 0, 0, -0.916290731874155, -0.22314355131420971, 0.13353139262452005, 0, -0.22314355131420971, 0, -1.6094379124341003, 0, 0, 0, 0, 0, 0, 0, 0, 0.13353139262452005, 0.13353139262452005, 0, -0.22314355131420971, 0, -0.22314355131420971, 0, 0, 1.4552872326068431, -0.22314355131420971, -0.22314355131420971, -0.22314355131420971, 0, 0, 0.25131442828090944, 0, 0, 0, 0, 0, -0.7827593392496327, 0, 0.88730319500090338, 0.35667494393873334, 0, 0, 0.13353139262452005, 0, 0.13353139262452005, 0, 0, 0, -0.22314355131420971, 0.13353139262452005, 0, -0.22314355131420971, 0.45198512374305638, 0, 0.13353139262452005, 0.8266785731844698, 0, -0.55961578793542355, -0.22314355131420971, -0.1541506798272585, -0.22314355131420971, -0.1541506798272585, 0.8266785731844698, -0.22314355131420971, 0, -0.1541506798272585, 0, 0.028170876966697733, -0.51082562376599072, -0.1541506798272585, 0, -0.22314355131420971, 0, 0.39589565709201657, 0, 0, 0, -0.22314355131420971, -0.22314355131420971, 0, 0, -0.1541506798272585, -0.22314355131420971, 0, -0.22314355131420971, 0, 0, -0.22314355131420971, 0, 0, 0, 0.35667494393873334, 0.53899650073268446, -0.25951119548508511, -0.22314355131420971, 0, 0.61903920840622506, 0, 0, 0, 0, 0, 0, -0.22314355131420971, 0, 0, 0.8266785731844698, -0.22314355131420971, 0.35667494393873334, 0.028170876966697733, 0.53899650073268446, 0.13353139262452005, 0.13353139262452005, 0, 0, 0.13353139262452005, -0.22314355131420971, 0.35667494393873334, -0.089612158689687402, 0.13353139262452005, 0.25131442828090944, -0.51082562376599072, 0, -0.22314355131420971, 0.13353139262452005, 0, 0.35667494393873334, 0, 0, 0, 0, 0, 0, 0, 0, 0.13353139262452005, -0.1541506798272585, 0, 0, 0, 0, 0, 0, 0, -0.51082562376599072, 0.028170876966697733, 0, 0, -0.37729423114146754, -0.22314355131420971, 0, -0.22314355131420971, 0.39589565709201657, 0, 0, 0, 0],
"beta_prior_scale": 2.5,
"alpha_prior_scale": 5
}
24 changes: 24 additions & 0 deletions tests/testthat/resources/stan/loo_moment_match.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
data {
int<lower=1> K;
int<lower=1> N;
matrix[N,K] x;
array[N] int y;
vector[N] outcome_offset;

real beta_prior_scale;
real alpha_prior_scale;
}
parameters {
vector[K] beta;
real intercept;
}
model {
y ~ poisson(exp(x * beta + intercept + outcome_offset));
beta ~ normal(0,beta_prior_scale);
intercept ~ normal(0,alpha_prior_scale);
}
generated quantities {
vector[N] log_lik;
for (n in 1:N)
log_lik[n] = poisson_lpmf(y[n] | exp(x[n] * beta + intercept + outcome_offset[n]));
}
Loading

0 comments on commit dbf41cd

Please sign in to comment.